diff --git a/.github/workflows/build_rocm.yml b/.github/workflows/build_rocm.yml new file mode 100644 index 0000000000..7faf187bca --- /dev/null +++ b/.github/workflows/build_rocm.yml @@ -0,0 +1,97 @@ +name: Build ROCm and Test + +on: + push: + branches: [ rocm-support ] + workflow_dispatch: + +jobs: + build-and-test: + runs-on: strix-halo + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + run: | + uv venv venv + source venv/bin/activate + uv pip install --upgrade mlx-lm + + - name: Build and install MLX ROCm wheel + run: | + source venv/bin/activate + export CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151 -DBLA_VENDOR=OpenBLAS -DCMAKE_BUILD_TYPE=RelWithDebInfo" + rm -rf wheelhouse + mkdir -p wheelhouse + uv build --wheel --out-dir wheelhouse . + uv pip install --force-reinstall wheelhouse/mlx-*.whl + + - name: Basic MLX GPU test + run: | + source venv/bin/activate + python3 -c " + import mlx.core as mx + print('MLX version:', mx.__version__) + print('Default device:', mx.default_device()) + mx.set_default_device(mx.gpu) + print('GPU device set') + + # Test basic operations + a = mx.ones((10, 10)) + mx.eval(a) + print('Basic array creation: OK') + + # Test matmul + b = mx.random.normal((256, 256)) + c = mx.matmul(b, b) + mx.eval(c) + print('Matmul test: OK') + + # Test softmax + d = mx.softmax(b, axis=-1) + mx.eval(d) + print('Softmax test: OK') + + print('All basic tests passed!') + " + + - name: Run inference tests + run: | + source venv/bin/activate + export HIP_LAUNCH_BLOCKING=1 + export PYTHONFAULTHANDLER=1 + mkdir -p "${GITHUB_WORKSPACE}/rocm-stacktraces" + + run_and_trace() { + local name="$1" + shift + lldb -Q -b \ + -o "run" \ + -k "bt" \ + -k "quit 1" \ + -- python3 "$(which mlx_lm.generate)" "$@" \ + > >(tee "${GITHUB_WORKSPACE}/rocm-stacktraces/${name}.log") 2>&1 + } + + run_and_trace qwen3_bf16 --model mlx-community/Qwen3-0.6B-bf16 --prompt "Hi" --max-tokens 5 + run_and_trace qwen3_8bit --model mlx-community/Qwen3-0.6B-8bit --prompt "How tall is Mt Everest?" --max-tokens 128 + + - name: Upload ROCm wheel artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v6 + with: + name: rocm-wheel-${{ github.run_attempt }} + path: wheelhouse/mlx-*.whl + if-no-files-found: warn + retention-days: 14 + + - name: Upload ROCm stacktrace artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v6 + with: + name: rocm-stacktraces-${{ github.run_attempt }} + path: ${{ github.workspace }}/rocm-stacktraces/* + if-no-files-found: warn + retention-days: 14 diff --git a/.gitignore b/.gitignore index 1daaa46d12..4da73eccf5 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,10 @@ uv.lock .cache/ # vim *.swp + +# keys +*.pem + +build.sh +github-runner/ +sync_fork.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index a14ea9ffc5..56ec705d2c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,7 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF) option(MLX_BUILD_METAL "Build metal backend" ON) option(MLX_BUILD_CPU "Build cpu backend" ON) option(MLX_BUILD_CUDA "Build cuda backend" OFF) +option(MLX_BUILD_ROCM "Build rocm backend" OFF) option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF) option(MLX_BUILD_GGUF "Include support for GGUF format" ON) @@ -164,6 +165,43 @@ if(MLX_BUILD_CUDA) endif() endif() +if(MLX_BUILD_ROCM) + # Set HIP architectures - these will be used by the ROCm backend + # CMakeLists.txt + # + # Supported architectures from ROCm 6.4.0 - 7.2.0 compatibility matrix: CDNA: + # gfx908 (MI100), gfx90a (MI200), gfx942 (MI300) CDNA4: gfx950 (MI400 series) + # RDNA2: gfx1030 (RX 6000 series) RDNA3: gfx1100 (RX 7900), gfx1101 (RX 7600) + # RDNA4: gfx1200, gfx1201 (RX 8000 series) + if(NOT DEFINED CMAKE_HIP_ARCHITECTURES) + if(DEFINED MLX_ROCM_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES + ${MLX_ROCM_ARCHITECTURES} + CACHE STRING "HIP architectures") + else() + set(CMAKE_HIP_ARCHITECTURES + "gfx908;gfx90a;gfx942;gfx1010;gfx1011;gfx1012;gfx1030;gfx1031;gfx1032;gfx1100;gfx1101;gfx1102" + CACHE STRING "HIP architectures") + endif() + endif() + message( + STATUS "Setting CMAKE_HIP_ARCHITECTURES to: ${CMAKE_HIP_ARCHITECTURES}") + # Note: We don't enable_language(HIP) here because it causes CMake to add -x + # hip to all CXX files in targets that link to HIP libraries. Instead, we + # compile HIP files using custom commands in the ROCm backend CMakeLists.txt. + # Find the HIP compiler + find_program( + CMAKE_HIP_COMPILER + NAMES hipcc clang++ + PATHS /opt/rocm/bin /opt/rocm-6.0.0/bin /opt/rocm/llvm/bin + PATH_SUFFIXES bin + DOC "HIP compiler") + if(NOT CMAKE_HIP_COMPILER) + message(FATAL_ERROR "Could not find HIP compiler (hipcc or clang++)") + endif() + message(STATUS "Found HIP compiler: ${CMAKE_HIP_COMPILER}") +endif() + if(MLX_BUILD_METAL) find_library(METAL_LIB Metal) find_library(FOUNDATION_LIB Foundation) @@ -310,10 +348,12 @@ if(MLX_BUILD_CPU) message(FATAL_ERROR "Must have LAPACK installed") endif() find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include - /usr/local/opt/openblas/include) + /usr/local/opt/openblas/include /usr/include/openblas) message(STATUS "Lapack lib " ${LAPACK_LIBRARIES}) message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS}) - target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS}) + if(LAPACK_INCLUDE_DIRS) + target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS}) + endif() target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES}) # List blas after lapack otherwise we may accidentally incldue an old # version of lapack.h from the include dirs of blas. diff --git a/benchmark_llm_rocm.py b/benchmark_llm_rocm.py new file mode 100644 index 0000000000..bd739bfa08 --- /dev/null +++ b/benchmark_llm_rocm.py @@ -0,0 +1,685 @@ +#!/usr/bin/env python3 + +import argparse +import re +import shlex +import subprocess +import sys +from dataclasses import dataclass + +MODEL_VARIANTS: dict[str, dict[str, str]] = { + "glm_4_7_flash_bf16": { + "mlx_repo": "mlx-community/GLM-4.7-Flash-bf16", + "llama_hf": "unsloth/GLM-4.7-Flash-GGUF:BF16", + }, + "glm_4_7_flash_8bit": { + "mlx_repo": "mlx-community/GLM-4.7-Flash-8bit", + "llama_hf": "unsloth/GLM-4.7-Flash-GGUF:Q8_0", + }, + "qwen3_0_6b_bf16": { + "mlx_repo": "mlx-community/Qwen3-0.6B-bf16", + "llama_hf": "unsloth/Qwen3-0.6B-GGUF:BF16", + }, + "qwen3_0_6b_8bit": { + "mlx_repo": "mlx-community/Qwen3-0.6B-8bit", + "llama_hf": "unsloth/Qwen3-0.6B-GGUF:Q8_0", + }, + "qwen3_coder_next_4bit": { + "mlx_repo": "mlx-community/Qwen3-Coder-Next-4bit", + "llama_hf": "unsloth/Qwen3-Coder-Next-GGUF:Q4_K_M", + }, +} + +DEFAULT_PROMPT = """ +You are a coding assistant with deep expertise in GPU programming, machine learning systems, and performance optimization. + +Explain, in plain English, how a GPU inference benchmark should be designed to fairly compare two runtimes (such as MLX vs llama.cpp). Provide a comprehensive analysis covering the following aspects: + +1. Prompt Length Considerations: + - Why varying prompt lengths (short, medium, long) reveal different performance characteristics + - How prompt length affects memory bandwidth utilization vs compute utilization + - The relationship between prompt length and KV cache behavior + - Recommended prompt lengths for realistic benchmarks (128, 512, 1024, 2048 tokens) + +2. Decode Length Impact: + - How generation length affects time-to-first-token vs sustained throughput + - Why short decodes may not represent real-world usage + - The effect of decode length on memory allocation patterns + - Recommendations for decode lengths to test (64, 128, 256, 512 tokens) + +3. Sampling Settings: + - Why temperature, top-k, top-p, and min-p settings affect benchmark consistency + - The trade-off between deterministic (greedy) and stochastic sampling + - How to choose sampling settings for fair comparisons + - The impact of different sampling strategies on kernel utilization + +4. Warmup Considerations: + - Why warmup runs are essential for accurate GPU benchmarks + - How CUDA/ROCm kernel compilation affects first-run latency + - Memory allocation warmup vs kernel warmup + - Recommended warmup strategies (number of runs, timing) + +5. Memory Pressure Testing: + - How to test under realistic memory constraints + - The effect of batch size on memory utilization + - KV cache memory scaling with sequence length + - Out-of-memory behavior and graceful degradation + +6. Deterministic Seeds: + - Why deterministic seeds are critical for reproducibility + - How random seed affects sampling and therefore timing + - Recommendations for seed management in benchmarks + +7. Additional Considerations: + - GPU temperature throttling and thermal equilibrium + - Power management and clock frequency stability + - Multi-GPU scaling considerations + - Quantization format comparisons (BF16, FP16, INT8, INT4) + +Keep the answer structured with clear sections and bullet points. Provide specific numerical recommendations where applicable. +""" + + +@dataclass +class RunStats: + variant: str + backend: str + model: str + prompt_tokens: int | None = None + prompt_tps: float | None = None + gen_tokens: int | None = None + gen_tps: float | None = None + peak_mem_gb: float | None = None + error: str | None = None + + +def run_command(cmd: list[str]) -> str: + # Redact prompt from printed command to reduce clutter + printed_cmd = [] + skip_next = False + for arg in cmd: + if skip_next: + printed_cmd.append("") + skip_next = False + else: + printed_cmd.append(arg) + if arg == "--prompt": + skip_next = True + print(f"\n$ {shlex.join(printed_cmd)}") + proc = subprocess.run(cmd, capture_output=True, text=True) + output = (proc.stdout or "") + (proc.stderr or "") + if proc.returncode != 0: + raise RuntimeError(f"Command failed with exit code {proc.returncode}\n{output}") + return output + + +def parse_mlx_stats(output: str, variant: str, model: str) -> RunStats: + stats = RunStats(variant=variant, backend="mlx", model=model) + + m = re.search(r"Prompt:\s*(\d+)\s*tokens,\s*([0-9.]+)\s*tokens-per-sec", output) + if m: + stats.prompt_tokens = int(m.group(1)) + stats.prompt_tps = float(m.group(2)) + + m = re.search(r"Generation:\s*(\d+)\s*tokens,\s*([0-9.]+)\s*tokens-per-sec", output) + if m: + stats.gen_tokens = int(m.group(1)) + stats.gen_tps = float(m.group(2)) + + m = re.search(r"Peak memory:\s*([0-9.]+)\s*GB", output) + if m: + stats.peak_mem_gb = float(m.group(1)) + + return stats + + +def maybe_fmt_float(v: float | None, digits: int = 3) -> str: + if v is None: + return "n/a" + return f"{v:.{digits}f}" + + +def maybe_fmt_int(v: int | None) -> str: + if v is None: + return "n/a" + return str(v) + + +def parse_int_token_count(s: str) -> int: + return int(s.replace(",", "")) + + +def parse_tps_value(s: str) -> float | None: + if s.lower() == "inf": + return None + return float(s) + + +def parse_llama_cli_stats(output: str, variant: str, model: str) -> RunStats: + stats = RunStats(variant=variant, backend="llama", model=model) + + # Typical llama.cpp timing format examples: + # common_perf_print: prompt eval time = ... / 60 tokens (..., 332.12 tokens per second) + # common_perf_print: eval time = ... / 7 runs (..., 46.40 tokens per second) + prompt_re = re.compile( + r"/\s*([0-9,]+)\s*tokens?\s*\(\s*[0-9.]+\s*ms per token,\s*([0-9.]+|inf)\s*(?:tok/s|tokens per second)", + flags=re.IGNORECASE, + ) + eval_re = re.compile( + r"/\s*([0-9,]+)\s*(?:runs|tokens?)\s*\(\s*[0-9.]+\s*ms per token,\s*([0-9.]+|inf)\s*(?:tok/s|tokens per second)", + flags=re.IGNORECASE, + ) + + for line in output.splitlines(): + low = line.lower() + if "prompt eval time" in low: + m = prompt_re.search(line) + if m: + stats.prompt_tokens = parse_int_token_count(m.group(1)) + stats.prompt_tps = parse_tps_value(m.group(2)) + elif "eval time" in low: + m = eval_re.search(line) + if m: + stats.gen_tokens = parse_int_token_count(m.group(1)) + stats.gen_tps = parse_tps_value(m.group(2)) + + # Fallback for interactive llama-cli output format: + # [ Prompt: 84.9 t/s | Generation: 50.3 t/s ] + if stats.prompt_tps is None or stats.gen_tps is None: + m = re.search( + r"Prompt:\s*([0-9.]+)\s*t/s\s*\|\s*Generation:\s*([0-9.]+)\s*t/s", + output, + flags=re.IGNORECASE, + ) + if m: + stats.prompt_tps = parse_tps_value(m.group(1)) + stats.gen_tps = parse_tps_value(m.group(2)) + + return stats + + +def run_mlx(cfg: dict[str, str], variant: str, args: argparse.Namespace) -> RunStats: + mlx_model = cfg["mlx_repo"] + + try: + import time + + import mlx.core as mx + + try: + import mlx_lm + from mlx_lm.generate import stream_generate as lm_stream_generate + except Exception: + mlx_lm = None + lm_stream_generate = None + + try: + from mlx_vlm import load as vlm_load + from mlx_vlm import stream_generate as vlm_stream_generate + except Exception: + vlm_load = None + vlm_stream_generate = None + + if mlx_lm is None and vlm_load is None: + raise RuntimeError( + "No MLX generation backend available. Install mlx-lm and/or mlx-vlm." + ) + + def likely_vision_model(model_id: str) -> bool: + model_id = model_id.lower() + return any( + token in model_id + for token in ( + "qwen3.5", + "vision", + "multimodal", + "llava", + "internvl", + "gemma3", + ) + ) + + def looks_like_vision_weight_mismatch(exc: Exception) -> bool: + message = str(exc).lower() + return "vision_tower" in message or ( + "parameters not in model" in message and "vision" in message + ) + + backend = "mlx_lm" + stream_generate_fn = lm_stream_generate + + if likely_vision_model(mlx_model) and vlm_load is not None: + backend = "mlx_vlm" + stream_generate_fn = vlm_stream_generate + print(f" Loading MLX model ({backend}): {mlx_model}") + model, processor = vlm_load(mlx_model) + elif mlx_lm is not None: + try: + print(f" Loading MLX model ({backend}): {mlx_model}") + model, processor = mlx_lm.load(mlx_model) + except Exception as exc: + if vlm_load is None or not looks_like_vision_weight_mismatch(exc): + raise + backend = "mlx_vlm" + stream_generate_fn = vlm_stream_generate + print(f" Falling back to {backend} for: {mlx_model}") + model, processor = vlm_load(mlx_model) + else: + backend = "mlx_vlm" + stream_generate_fn = vlm_stream_generate + print(f" Loading MLX model ({backend}): {mlx_model}") + model, processor = vlm_load(mlx_model) + + # Load model once + # Warmup runs (model stays loaded, JIT compiles kernels) + if args.warmup_runs > 0: + print(f" Warming up MLX ({args.warmup_runs} runs)...") + for i in range(args.warmup_runs): + _ = next( + stream_generate_fn( + model, + processor, + prompt=args.prompt, + max_tokens=1, + sampler=lambda x: mx.argmax(x, axis=-1), + ) + ) + mx.synchronize() + + # Timed run + print(f" Running timed generation...") + + # Use stream_generate to get accurate per-token timings in a single pass + # This avoids running the prompt twice and eliminates tokenization overhead from the timing + start_time = time.perf_counter() + final_stats = None + output_text = "" + stream_kwargs = { + "prompt": args.prompt, + "max_tokens": args.max_tokens, + "sampler": lambda x: mx.argmax(x, axis=-1) if args.temp == 0 else None, + } + if backend == "mlx_vlm": + stream_kwargs.update({"temp": args.temp, "top_p": args.top_p}) + + for response in stream_generate_fn(model, processor, **stream_kwargs): + output_text += response.text + final_stats = response + + mx.synchronize() + total_time = time.perf_counter() - start_time + + if final_stats is None: + raise RuntimeError("Generation produced no output.") + + num_prompt_tokens = final_stats.prompt_tokens + gen_tokens = final_stats.generation_tokens + prompt_tps = final_stats.prompt_tps + gen_tps = final_stats.generation_tps + + # Get peak memory + peak_mem_gb = None + try: + peak_mem_gb = mx.metal.get_peak_memory() / (1024**3) + except: + try: + peak_mem_gb = mx.gpu.get_peak_memory() / (1024**3) + except: + try: + peak_mem_gb = mx.get_peak_memory() / (1024**3) + except: + pass + + if args.show_raw_output: + print(f" Output: {output_text[:200]}...") + print(f" Prompt: {num_prompt_tokens} tokens, {prompt_tps:.2f} tok/s") + print(f" Generation: {gen_tokens} tokens, {gen_tps:.2f} tok/s") + + return RunStats( + variant=variant, + backend="mlx", + model=mlx_model, + prompt_tokens=num_prompt_tokens, + prompt_tps=prompt_tps, + gen_tokens=gen_tokens, + gen_tps=gen_tps, + peak_mem_gb=peak_mem_gb, + ) + except Exception as e: + import traceback + + traceback.print_exc() + return RunStats( + variant=variant, + backend="mlx", + model=mlx_model, + error=str(e), + ) + + +def run_llama_cli( + cfg: dict[str, str], variant: str, args: argparse.Namespace +) -> RunStats: + model_name = ( + cfg.get("gguf_path") + or cfg.get("llama_hf") + or (f"{cfg.get('gguf_repo', 'n/a')}:{cfg.get('gguf_filename', 'n/a')}") + ) + + cmd = [ + args.llama_cli_path, + "--prompt", + args.prompt, + "--n-predict", + str(args.max_tokens), + "--temp", + str(args.temp), + "--top-k", + str(args.top_k), + "--top-p", + str(args.top_p), + "--min-p", + str(args.min_p), + "--seed", + str(args.seed), + "--ctx-size", + str(args.llama_n_ctx), + "--batch-size", + str(args.llama_n_batch), + "--gpu-layers", + str(args.llama_n_gpu_layers), + "--simple-io", + "--no-mmap", + "--no-display-prompt", + "--no-conversation", + "--perf", + "-fa", + "1", + ] + + if args.llama_n_threads is not None: + cmd.extend(["--threads", str(args.llama_n_threads)]) + + gguf_path = cfg.get("gguf_path") + if gguf_path: + cmd.extend(["--model", gguf_path]) + elif cfg.get("llama_hf"): + cmd.extend(["-hf", cfg["llama_hf"]]) + else: + gguf_repo = cfg.get("gguf_repo") + gguf_filename = cfg.get("gguf_filename") + if not gguf_repo or not gguf_filename: + return RunStats( + variant=variant, + backend="llama", + model=model_name, + error=( + "Variant must provide one of: gguf_path, llama_hf, or " + "(gguf_repo + gguf_filename) for llama-completion" + ), + ) + cmd.extend(["--hf-repo", gguf_repo, "--hf-file", gguf_filename]) + + try: + output = run_command(cmd) + if args.show_raw_output: + print(output) + return parse_llama_cli_stats(output, variant=variant, model=model_name) + except Exception as e: + return RunStats( + variant=variant, + backend="llama", + model=model_name, + error=str(e), + ) + + +def format_row(cols: list[str], widths: list[int]) -> str: + return " | ".join(col.ljust(width) for col, width in zip(cols, widths)) + + +def print_results_table(results: list[RunStats]) -> None: + headers = [ + "variant", + "backend", + "prompt_tok/s", + "decode_tok/s", + "prompt_tok", + "gen_tok", + "peak_gb", + "status", + ] + + rows: list[list[str]] = [] + for r in results: + rows.append( + [ + r.variant, + r.backend, + maybe_fmt_float(r.prompt_tps, 3), + maybe_fmt_float(r.gen_tps, 3), + maybe_fmt_int(r.prompt_tokens), + maybe_fmt_int(r.gen_tokens), + maybe_fmt_float(r.peak_mem_gb, 3), + "ok" if r.error is None else "error", + ] + ) + + widths = [len(h) for h in headers] + for row in rows: + for i, col in enumerate(row): + widths[i] = max(widths[i], len(col)) + + print("\n=== Benchmark results ===") + print(format_row(headers, widths)) + print("-+-".join("-" * w for w in widths)) + for row in rows: + print(format_row(row, widths)) + + +def print_results_table_compact(results: list[RunStats], variants: list[str]) -> None: + backend_names = {"llama": "llama", "mlx": "mlx"} + + headers = [ + "variant", + "backend", + "prompt_tps", + "decode_tps", + "p_tok", + "g_tok", + "mem_gb", + "status", + ] + rows: list[list[str]] = [] + + for r in results: + rows.append( + [ + r.variant, + backend_names.get(r.backend, r.backend), + maybe_fmt_float(r.prompt_tps, 2), + maybe_fmt_float(r.gen_tps, 2), + maybe_fmt_int(r.prompt_tokens), + maybe_fmt_int(r.gen_tokens), + maybe_fmt_float(r.peak_mem_gb, 1), + "ok" if r.error is None else "er", + ] + ) + + widths = [len(h) for h in headers] + for row in rows: + for i, col in enumerate(row): + widths[i] = max(widths[i], len(col)) + + print("\n=== Results (compact) ===") + print(format_row(headers, widths)) + print("-+-".join("-" * w for w in widths)) + for row in rows: + print(format_row(row, widths)) + + +def print_comparison( + results: list[RunStats], variants: list[str], compact: bool = False +) -> None: + by_variant: dict[str, dict[str, RunStats]] = {} + for r in results: + by_variant.setdefault(r.variant, {})[r.backend] = r + + print("\n=== Decode ratio (MLX / llama-completion) ===") + for variant in variants: + mlx = by_variant.get(variant, {}).get("mlx") + llama = by_variant.get(variant, {}).get("llama") + label = variant + if not mlx or not llama: + print(f"- {label}: n/a") + continue + if mlx.error or llama.error: + print(f"- {label}: n/a (one or both runs failed)") + continue + if not mlx.gen_tps or not llama.gen_tps: + print(f"- {label}: n/a (missing decode stats)") + continue + ratio = mlx.gen_tps / llama.gen_tps + if compact: + print( + f"- {label}: {ratio:.3f}x ({mlx.gen_tps:.2f}/{llama.gen_tps:.2f} tok/s)" + ) + else: + print( + f"- {label}: {ratio:.3f}x " + f"(mlx {mlx.gen_tps:.3f} tok/s vs llama {llama.gen_tps:.3f} tok/s)" + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Benchmark MLX generate CLI vs llama-completion across model variants." + ) + ) + parser.add_argument("--prompt", default=DEFAULT_PROMPT) + parser.add_argument("--max-tokens", type=int, default=1000) + + parser.add_argument("--temp", type=float, default=0.0) + parser.add_argument("--top-k", type=int, default=1) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--min-p", type=float, default=0.0) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--warmup-runs", + type=int, + default=2, + help="Number of warmup runs for MLX (default: 2). Use 0 to disable.", + ) + + parser.add_argument( + "--variants", + nargs="*", + default=["all"], + help="Variant keys from MODEL_VARIANTS. Use 'all' for every variant.", + ) + parser.add_argument( + "--list-variants", + action="store_true", + help="List variants and exit.", + ) + + parser.add_argument("--llama-n-ctx", type=int, default=8192) + parser.add_argument("--llama-n-batch", type=int, default=2048) + parser.add_argument("--llama-n-gpu-layers", type=int, default=-1) + parser.add_argument("--llama-n-threads", type=int, default=None) + parser.add_argument( + "--llama-cli-path", + default="llama-completion", + help="Path to the llama-completion executable.", + ) + + parser.add_argument( + "--show-raw-output", + action="store_true", + help="Print raw MLX CLI output for each run.", + ) + parser.add_argument( + "--table-mode", + choices=["compact", "full"], + default="full", + help="Table format: full (default) or compact.", + ) + return parser.parse_args() + + +def resolve_variants(arg_variants: list[str]) -> list[str]: + if len(arg_variants) == 1 and arg_variants[0] == "all": + return list(MODEL_VARIANTS.keys()) + + unknown = [v for v in arg_variants if v not in MODEL_VARIANTS] + if unknown: + raise ValueError( + f"Unknown variant(s): {', '.join(unknown)}. " + f"Known: {', '.join(MODEL_VARIANTS.keys())}" + ) + return arg_variants + + +def list_variants() -> None: + print("Available variants:") + for key, cfg in MODEL_VARIANTS.items(): + mlx_repo = cfg.get("mlx_repo", "n/a") + gguf = ( + cfg.get("gguf_path") + or cfg.get("llama_hf") + or (f"{cfg.get('gguf_repo', 'n/a')}:{cfg.get('gguf_filename', 'n/a')}") + ) + print(f"- {key}") + print(f" mlx: {mlx_repo}") + print(f" llama: {gguf}") + + +def main() -> int: + args = parse_args() + + if args.list_variants: + list_variants() + return 0 + + try: + variants = resolve_variants(args.variants) + except ValueError as e: + print(f"ERROR: {e}", file=sys.stderr) + return 2 + + print("Running benchmark with shared decode settings:") + prompt_summary = args.prompt[:50] + "..." if len(args.prompt) > 50 else args.prompt + print(f"- prompt: {prompt_summary!r} (total {len(args.prompt)} chars)") + print(f"- max_tokens: {args.max_tokens}") + print( + f"- sampling: temp={args.temp}, top_k={args.top_k}, " + f"top_p={args.top_p}, min_p={args.min_p}, seed={args.seed}" + ) + print("- execution: strictly serial (no concurrent model loads)") + print(f"- variants: {', '.join(variants)}") + + results: list[RunStats] = [] + for variant in variants: + cfg = MODEL_VARIANTS[variant] + print(f"\n--- Variant: {variant} ---") + results.append(run_llama_cli(cfg, variant, args)) + results.append(run_mlx(cfg, variant, args)) + + if args.table_mode == "compact": + print_results_table_compact(results, variants) + else: + print_results_table(results) + print_comparison(results, variants, compact=(args.table_mode == "compact")) + + errors = [r for r in results if r.error] + if errors: + print("\n=== Errors ===") + for r in errors: + print(f"- {r.variant} [{r.backend}]: {r.error}") + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmarks/python/qwen3_quantized_generate_bench.py b/benchmarks/python/qwen3_quantized_generate_bench.py new file mode 100644 index 0000000000..1588623da6 --- /dev/null +++ b/benchmarks/python/qwen3_quantized_generate_bench.py @@ -0,0 +1,259 @@ +# Copyright © 2026 Apple Inc. + +"""Benchmark Qwen3-0.6B bf16 and quantized generation throughput. + +Example: + python benchmarks/python/qwen3_quantized_generate_bench.py +""" + +from __future__ import annotations + +import argparse +import statistics +import time +from dataclasses import dataclass +from typing import Callable + +import mlx.core as mx + +try: + from mlx_lm import load as lm_load + from mlx_lm.generate import stream_generate as lm_stream_generate +except Exception: # pragma: no cover + lm_load = None + lm_stream_generate = None + +try: + from mlx_vlm import load as vlm_load + from mlx_vlm import stream_generate as vlm_stream_generate +except Exception: # pragma: no cover + vlm_load = None + vlm_stream_generate = None + +if lm_load is None and vlm_load is None: # pragma: no cover + raise RuntimeError( + "No generation backend available. Install mlx-lm and/or mlx-vlm." + ) + + +DEFAULT_MODELS = ( + "mlx-community/Qwen3-0.6B-bf16", + "mlx-community/Qwen3-0.6B-4bit", + "mlx-community/Qwen3-0.6B-8bit", +) + +DEFAULT_PROMPT = "Explain matrix multiplication in one short paragraph." + + +@dataclass +class RunStats: + wall_s: float + prompt_tokens: int + prompt_tps: float + generation_tokens: int + generation_tps: float + + +def greedy_sampler(logprobs: mx.array) -> mx.array: + return mx.argmax(logprobs, axis=-1) + + +def _is_likely_vision_model(model_id: str) -> bool: + model_id = model_id.lower() + return any( + token in model_id + for token in ( + "qwen3.5", + "vision", + "multimodal", + "llava", + "internvl", + "gemma3", + ) + ) + + +def _looks_like_vision_weight_mismatch(exc: Exception) -> bool: + message = str(exc).lower() + return "vision_tower" in message or ( + "parameters not in model" in message and "vision" in message + ) + + +def load_with_backend( + model_id: str, +) -> tuple[object, object, Callable[..., object], str]: + if _is_likely_vision_model(model_id) and vlm_load is not None: + model, processor = vlm_load(model_id) + return model, processor, vlm_stream_generate, "mlx_vlm" + + if lm_load is not None: + try: + model, tokenizer = lm_load(model_id) + return model, tokenizer, lm_stream_generate, "mlx_lm" + except Exception as exc: + if vlm_load is not None and _looks_like_vision_weight_mismatch(exc): + model, processor = vlm_load(model_id) + return model, processor, vlm_stream_generate, "mlx_vlm" + raise + + if vlm_load is not None: + model, processor = vlm_load(model_id) + return model, processor, vlm_stream_generate, "mlx_vlm" + + raise RuntimeError("Unable to load model with mlx-lm or mlx-vlm.") + + +def run_once( + model, + processor, + stream_fn: Callable[..., object], + prompt: str, + max_tokens: int, +) -> RunStats: + start = time.perf_counter() + final = None + for response in stream_fn( + model, + processor, + prompt=prompt, + max_tokens=max_tokens, + sampler=greedy_sampler, + ): + final = response + wall_s = time.perf_counter() - start + + if final is None: + raise RuntimeError("Generation produced no output.") + + return RunStats( + wall_s=wall_s, + prompt_tokens=final.prompt_tokens, + prompt_tps=final.prompt_tps, + generation_tokens=final.generation_tokens, + generation_tps=final.generation_tps, + ) + + +def summarize(values: list[float]) -> tuple[float, float]: + mean = statistics.fmean(values) + stdev = statistics.stdev(values) if len(values) > 1 else 0.0 + return mean, stdev + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + default=list(DEFAULT_MODELS), + help="Model ids to benchmark.", + ) + parser.add_argument( + "--prompt", + default=DEFAULT_PROMPT, + help="Prompt text for generation.", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=64, + help="Maximum generated tokens.", + ) + parser.add_argument( + "--warmup-runs", + type=int, + default=1, + help="Warmup runs before timed runs.", + ) + parser.add_argument( + "--runs", + type=int, + default=3, + help="Timed runs per model.", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed used before each run.", + ) + parser.add_argument( + "--device", + choices=("gpu", "cpu"), + default="gpu", + help="MLX device to run on.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + device = mx.gpu if args.device == "gpu" else mx.cpu + mx.set_default_device(device) + + print(f"device={args.device} max_tokens={args.max_tokens} runs={args.runs}") + print(f"prompt={args.prompt!r}") + print() + + for model_id in args.models: + print(f"=== {model_id} ===") + + load_start = time.perf_counter() + model, processor, stream_fn, backend = load_with_backend(model_id) + load_s = time.perf_counter() - load_start + print(f"load_s={load_s:.3f} backend={backend}") + + for _ in range(args.warmup_runs): + mx.random.seed(args.seed) + _ = run_once(model, processor, stream_fn, args.prompt, args.max_tokens) + + runs: list[RunStats] = [] + for run_idx in range(args.runs): + mx.random.seed(args.seed + run_idx) + runs.append( + run_once(model, processor, stream_fn, args.prompt, args.max_tokens) + ) + + wall_mean, wall_std = summarize([r.wall_s for r in runs]) + gen_tps_mean, gen_tps_std = summarize([r.generation_tps for r in runs]) + prompt_tps_mean, prompt_tps_std = summarize([r.prompt_tps for r in runs]) + eff_gen_tps_mean, eff_gen_tps_std = summarize( + [r.generation_tokens / r.wall_s for r in runs] + ) + + print( + "prompt_tokens={} generation_tokens={}".format( + runs[-1].prompt_tokens, + runs[-1].generation_tokens, + ) + ) + print( + "prompt_tps_mean={:.2f} prompt_tps_std={:.2f}".format( + prompt_tps_mean, + prompt_tps_std, + ) + ) + print( + "generation_tps_mean={:.2f} generation_tps_std={:.2f}".format( + gen_tps_mean, + gen_tps_std, + ) + ) + print( + "effective_gen_tps_mean={:.2f} effective_gen_tps_std={:.2f}".format( + eff_gen_tps_mean, + eff_gen_tps_std, + ) + ) + print("wall_s_mean={:.3f} wall_s_std={:.3f}".format(wall_mean, wall_std)) + print() + + del model + del processor + mx.clear_cache() + + +if __name__ == "__main__": + main() diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 0df4f42349..567c2e7210 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -103,7 +103,16 @@ else() PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp) endif() -if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) +if(MLX_BUILD_ROCM) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm) +else() + target_sources(mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/rocm/no_rocm.cpp) +endif() + +if(MLX_BUILD_METAL + OR MLX_BUILD_CUDA + OR MLX_BUILD_ROCM) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu) diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index aceeb1f7fd..1a960f7519 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -84,13 +84,19 @@ std::string get_type_string(Dtype d) { bool compiled_check_contiguity( const std::vector& inputs, - const Shape& shape) { + const Shape& shape, + const std::function& is_constant) { bool contiguous = true; bool all_contig = true; bool all_row_contig = true; bool all_col_contig = true; int non_scalar_inputs = 0; - for (const auto& x : inputs) { + for (size_t i = 0; i < inputs.size(); ++i) { + // Skip constants. + if (is_constant(i)) { + continue; + } + const auto& x = inputs[i]; if (is_scalar(x)) { continue; } @@ -175,7 +181,7 @@ std::tuple> compiled_collapse_contiguous_dims( const array& out, const std::function& is_constant) { const Shape& shape = out.shape(); - bool contiguous = compiled_check_contiguity(inputs, shape); + bool contiguous = compiled_check_contiguity(inputs, shape, is_constant); if (contiguous) { return {true, shape, {}}; } diff --git a/mlx/backend/common/compiled.h b/mlx/backend/common/compiled.h index 84a3460459..7ac3bc8a38 100644 --- a/mlx/backend/common/compiled.h +++ b/mlx/backend/common/compiled.h @@ -69,7 +69,10 @@ inline bool is_scalar(const array& x) { // Check if we can use a contiguous operation given inputs and the output shape bool compiled_check_contiguity( const std::vector& inputs, - const Shape& shape); + const Shape& shape, + const std::function& is_constant = [](size_t) { + return false; + }); // Allocate space for the outputs possibly with input donation void compiled_allocate_outputs( diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp index 3918d0fb45..8304120985 100644 --- a/mlx/backend/cuda/custom_kernel.cpp +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -281,9 +281,19 @@ void CustomKernel::eval_gpu( std::vector copies; + // Output index -> aliased input index (output reuses the input's buffer). + std::vector alias_of(outputs.size(), -1); + for (auto& [oi, ii] : output_input_aliases_) { + if (oi >= 0 && oi < (int)outputs.size() && ii >= 0 && ii < (int)inputs.size()) + alias_of[oi] = ii; + } + // Allocate and initialize the output arrays - for (auto& out : outputs) { - if (init_value_) { + for (size_t i = 0; i < outputs.size(); ++i) { + auto& out = outputs[i]; + if (alias_of[i] >= 0) { + out.copy_shared_buffer(inputs[alias_of[i]]); + } else if (init_value_) { copies.emplace_back(init_value_.value(), out.dtype()); fill_gpu(copies.back(), out, s); } else { diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index 268d6290bf..6f3d3f4923 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -18,6 +18,13 @@ #define MLX_PROFILER_RANGE(message) #endif +#if defined(MLX_USE_ROCM) +namespace mlx::core::rocm { +// True from HIP-graph capture start until the captured graph is destroyed. +bool graph_active(); +} +#endif + namespace mlx::core { void AsStrided::eval_gpu(const std::vector& inputs, array& out) { @@ -125,12 +132,29 @@ void DynamicSliceUpdate::eval_gpu( return; } - // Copy or donate input to output + // Donate the input buffer when uniquely owned, else copy. During HIP-graph + // capture the async pipeline inflates the buffer's use_count, forcing a full + // copy into a FRESH buffer — the captured graph then reconstructs + // (frozen capture input + current row) every replay and loses accumulation + // (e.g. a growing KV cache → frozen/repeated tokens). For a contiguous, + // fully-materialized buffer the in-place donation is the intended semantics, + // so force it while a graph is being captured. auto s = stream(); - auto ctype = in.flags().contiguous && in.size() == in.data_size() - ? CopyType::Vector - : CopyType::General; - copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s); + bool can_donate = in.data_shared_ptr() != nullptr && in.flags().contiguous && + in.data_size() == in.size() && + (in.data_shared_ptr().use_count() == 1 +#if defined(MLX_USE_ROCM) + || mlx::core::rocm::graph_active() +#endif + ); + if (can_donate) { + out.copy_shared_buffer(in); + } else { + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s); + } auto out_offset = compute_dynamic_offset(start_indices, out.strides(), axes_, s); diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 6d33ff5007..a3f21fc793 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -335,8 +335,18 @@ void CustomKernel::eval_gpu( std::vector copies; - for (auto& out : outputs) { - if (init_value_) { + // Output index -> aliased input index (output reuses the input's buffer). + std::vector alias_of(outputs.size(), -1); + for (auto& [oi, ii] : output_input_aliases_) { + if (oi >= 0 && oi < (int)outputs.size() && ii >= 0 && ii < (int)inputs.size()) + alias_of[oi] = ii; + } + + for (size_t i = 0; i < outputs.size(); ++i) { + auto& out = outputs[i]; + if (alias_of[i] >= 0) { + out.copy_shared_buffer(inputs[alias_of[i]]); + } else if (init_value_) { copies.emplace_back(init_value_.value(), out.dtype()); fill_gpu(copies.back(), out, s); } else { diff --git a/mlx/backend/rocm/CMakeLists.txt b/mlx/backend/rocm/CMakeLists.txt new file mode 100644 index 0000000000..3fce8d6450 --- /dev/null +++ b/mlx/backend/rocm/CMakeLists.txt @@ -0,0 +1,325 @@ +# Filename rules in ROCm backend: +# +# * Use .hip/.hpp if code contains device code, and .cpp/.h if not. +# * Device-only code should be put in device/ subdir. +# * Files in device/ subdir should not include files outside. + +# Find ROCm packages +find_package(hip REQUIRED CONFIG) +find_package(rocblas REQUIRED CONFIG) +find_package(rocthrust REQUIRED CONFIG) +find_package(rocprim REQUIRED CONFIG) +find_package(hiprand REQUIRED CONFIG) +# Ensure HIP architectures are set - respect user-provided value from command +# line The user can set this via -DCMAKE_HIP_ARCHITECTURES=gfx1011 +# +# Supported architectures from ROCm 6.4.0 - 7.2.0 compatibility matrix: CDNA: +# gfx908 (MI100), gfx90a (MI200), gfx942 (MI300) RDNA2: gfx1030 (RX 6000 +# series) RDNA3: gfx1100 (RX 7900), gfx1101 (RX 7600) RDNA3.5: gfx1150, +# gfx1151, gfx1152 (Ryzen AI / Radeon 8060S) RDNA4: gfx1200, gfx1201 (RX 9000 +# series) +if(NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES + "gfx908;gfx90a;gfx942;gfx1010;gfx1011;gfx1012;gfx1030;gfx1031;gfx1032;gfx1100;gfx1101;gfx1102;gfx1150;gfx1151;gfx1152;gfx1200;gfx1201" + CACHE STRING "HIP architectures" FORCE) +endif() +message( + STATUS "ROCm backend using HIP architectures: ${CMAKE_HIP_ARCHITECTURES}") + +# Check if any target architecture supports WMMA (RDNA 3 / gfx11xx and RDNA 4 / +# gfx12xx) +set(MLX_HAS_ROCM_WMMA OFF) +foreach(arch ${CMAKE_HIP_ARCHITECTURES}) + if(arch MATCHES "^gfx1[12]") + set(MLX_HAS_ROCM_WMMA ON) + break() + endif() +endforeach() +message(STATUS "ROCm WMMA support: ${MLX_HAS_ROCM_WMMA}") + +if(MLX_HAS_ROCM_WMMA) + find_package(rocwmma REQUIRED CONFIG) +endif() + +# Build architecture flags +set(HIP_ARCH_FLAGS "") +foreach(arch ${CMAKE_HIP_ARCHITECTURES}) + list(APPEND HIP_ARCH_FLAGS "--offload-arch=${arch}") +endforeach() + +# Get HIP include directories +get_target_property(HIP_DEVICE_INCLUDES hip::device + INTERFACE_INCLUDE_DIRECTORIES) +get_target_property(ROCTHRUST_INCLUDES roc::rocthrust + INTERFACE_INCLUDE_DIRECTORIES) +get_target_property(ROCPRIM_INCLUDES roc::rocprim INTERFACE_INCLUDE_DIRECTORIES) +get_target_property(HIPRAND_INCLUDES hip::hiprand INTERFACE_INCLUDE_DIRECTORIES) +if(MLX_HAS_ROCM_WMMA) + get_target_property(ROCWMMA_INCLUDES roc::rocwmma + INTERFACE_INCLUDE_DIRECTORIES) +endif() + +# Find GCC installation for C++ standard library headers ROCm's clang needs to +# know where to find libstdc++ headers +execute_process( + COMMAND ${CMAKE_CXX_COMPILER} -print-file-name=include/c++ + OUTPUT_VARIABLE GCC_CXX_INCLUDE_BASE + OUTPUT_STRIP_TRAILING_WHITESPACE) +get_filename_component(GCC_CXX_INCLUDE_BASE "${GCC_CXX_INCLUDE_BASE}" DIRECTORY) + +# Get GCC version for the target-specific include directory +execute_process( + COMMAND ${CMAKE_CXX_COMPILER} -dumpversion + OUTPUT_VARIABLE GCC_VERSION + OUTPUT_STRIP_TRAILING_WHITESPACE) +string(REGEX MATCH "^[0-9]+" GCC_MAJOR_VERSION "${GCC_VERSION}") + +# Build include flags - use PROJECT_SOURCE_DIR for correct path +set(HIP_INCLUDE_FLAGS "-I${PROJECT_SOURCE_DIR}" "-I${HIP_INCLUDE_DIRS}") + +# Add C++ standard library include paths for HIP compiler +if(EXISTS "${GCC_CXX_INCLUDE_BASE}/c++/${GCC_MAJOR_VERSION}") + list(APPEND HIP_INCLUDE_FLAGS + "-I${GCC_CXX_INCLUDE_BASE}/c++/${GCC_MAJOR_VERSION}") + list(APPEND HIP_INCLUDE_FLAGS + "-I${GCC_CXX_INCLUDE_BASE}/c++/${GCC_MAJOR_VERSION}/x86_64-linux-gnu") + list(APPEND HIP_INCLUDE_FLAGS + "-I${GCC_CXX_INCLUDE_BASE}/c++/${GCC_MAJOR_VERSION}/backward") +endif() + +# Also try to find system include directories +if(EXISTS "/usr/include/c++/${GCC_MAJOR_VERSION}") + list(APPEND HIP_INCLUDE_FLAGS "-I/usr/include/c++/${GCC_MAJOR_VERSION}") + list(APPEND HIP_INCLUDE_FLAGS + "-I/usr/include/x86_64-linux-gnu/c++/${GCC_MAJOR_VERSION}") + list(APPEND HIP_INCLUDE_FLAGS + "-I/usr/include/c++/${GCC_MAJOR_VERSION}/backward") +endif() + +# Add standard system include paths +list(APPEND HIP_INCLUDE_FLAGS "-I/usr/include/x86_64-linux-gnu") +list(APPEND HIP_INCLUDE_FLAGS "-I/usr/include") + +foreach(inc ${HIP_DEVICE_INCLUDES}) + if(inc) + list(APPEND HIP_INCLUDE_FLAGS "-I${inc}") + endif() +endforeach() +foreach(inc ${ROCTHRUST_INCLUDES}) + if(inc) + list(APPEND HIP_INCLUDE_FLAGS "-I${inc}") + endif() +endforeach() +foreach(inc ${ROCPRIM_INCLUDES}) + if(inc) + list(APPEND HIP_INCLUDE_FLAGS "-I${inc}") + endif() +endforeach() +foreach(inc ${HIPRAND_INCLUDES}) + if(inc) + list(APPEND HIP_INCLUDE_FLAGS "-I${inc}") + endif() +endforeach() +if(MLX_HAS_ROCM_WMMA) + foreach(inc ${ROCWMMA_INCLUDES}) + if(inc) + list(APPEND HIP_INCLUDE_FLAGS "-I${inc}") + endif() + endforeach() +endif() + +message(STATUS "HIP include flags: ${HIP_INCLUDE_FLAGS}") + +# HIP source files +set(HIP_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/event.hip + ${CMAKE_CURRENT_SOURCE_DIR}/arange.hip + ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.hip + ${CMAKE_CURRENT_SOURCE_DIR}/binary.hip + ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.hip + ${CMAKE_CURRENT_SOURCE_DIR}/copy.hip + ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.hip + ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.hip + ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.hip + ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.hip + ${CMAKE_CURRENT_SOURCE_DIR}/distributed.hip + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.hip + ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.hip + ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.hip + ${CMAKE_CURRENT_SOURCE_DIR}/random.hip + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.hip + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.hip + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.hip + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.hip + ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.hip + ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.hip + ${CMAKE_CURRENT_SOURCE_DIR}/rope.hip + ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.hip + ${CMAKE_CURRENT_SOURCE_DIR}/flash_attention.hip + ${CMAKE_CURRENT_SOURCE_DIR}/scan.hip + ${CMAKE_CURRENT_SOURCE_DIR}/softmax.hip + ${CMAKE_CURRENT_SOURCE_DIR}/sort.hip + ${CMAKE_CURRENT_SOURCE_DIR}/ternary.hip + ${CMAKE_CURRENT_SOURCE_DIR}/unary.hip + ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.hip + ${CMAKE_CURRENT_SOURCE_DIR}/gemms/naive_gemm.hip + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.hip + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.hip + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.hip + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm.hip + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmv_tiled_kernel.hip + ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.hip) + +if(MLX_HAS_ROCM_WMMA) + list(APPEND HIP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/flash_attention_wmma.hip) +endif() + +# Create output directory for compiled objects +set(HIP_OBJ_DIR "${CMAKE_CURRENT_BINARY_DIR}/hip_objs") +file(MAKE_DIRECTORY ${HIP_OBJ_DIR}) + +# Detect CPU count for parallel HIP offload compilation Use half of available +# CPUs for parallel HIP offload compilation per file (Ninja already parallelizes +# across files, so this avoids oversubscription) +include(ProcessorCount) +ProcessorCount(NPROC) +if(NPROC EQUAL 0) + set(NPROC 4) +else() + math(EXPR NPROC "${NPROC} / 2") + if(NPROC LESS 2) + set(NPROC 2) + endif() +endif() + +# Compile each HIP file to object file using custom commands Use -fno-gpu-rdc to +# avoid needing device link step +set(HIP_OBJECTS "") +foreach(hip_src ${HIP_SOURCES}) + get_filename_component(hip_name ${hip_src} NAME_WE) + get_filename_component(hip_dir ${hip_src} DIRECTORY) + file(RELATIVE_PATH rel_dir ${CMAKE_CURRENT_SOURCE_DIR} ${hip_dir}) + + # Create subdirectory for object if needed + if(rel_dir) + set(obj_subdir "${HIP_OBJ_DIR}/${rel_dir}") + file(MAKE_DIRECTORY ${obj_subdir}) + set(hip_obj "${obj_subdir}/${hip_name}.o") + else() + set(hip_obj "${HIP_OBJ_DIR}/${hip_name}.o") + endif() + + add_custom_command( + OUTPUT ${hip_obj} + COMMAND + ${CMAKE_HIP_COMPILER} -c ${hip_src} -o ${hip_obj} -fPIC -DMLX_USE_ROCM + ${HIP_ARCH_FLAGS} ${HIP_INCLUDE_FLAGS} -std=c++17 -parallel-jobs=${NPROC} + DEPENDS ${hip_src} + COMMENT "Compiling HIP source ${hip_src}" + VERBATIM) + + list(APPEND HIP_OBJECTS ${hip_obj}) +endforeach() + +# Create a custom target for all HIP objects +add_custom_target(mlx_hip_objects DEPENDS ${HIP_OBJECTS}) + +# Create static library from all objects (no device link needed without +# -fgpu-rdc) +set(HIP_STATIC_LIB "${CMAKE_CURRENT_BINARY_DIR}/libmlx_rocm_kernels.a") +add_custom_command( + OUTPUT ${HIP_STATIC_LIB} + COMMAND ${CMAKE_AR} rcs ${HIP_STATIC_LIB} ${HIP_OBJECTS} + DEPENDS ${HIP_OBJECTS} + COMMENT "Creating static library from HIP objects" + VERBATIM) + +add_custom_target(mlx_rocm_kernels_lib DEPENDS ${HIP_STATIC_LIB}) + +# Add C++ sources directly to mlx target +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/rocm.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/gemms/rocblas_gemm.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/gemms/hipblaslt_gemm.cpp) + +target_compile_definitions(mlx PRIVATE MLX_USE_ROCM) +if(MLX_HAS_ROCM_WMMA) + target_compile_definitions(mlx PRIVATE MLX_HAS_ROCM_WMMA) +endif() + +# Make mlx depend on the HIP kernels library +add_dependencies(mlx mlx_rocm_kernels_lib) + +# Get the library paths from the imported targets (without propagating compile +# options) +get_target_property(ROCBLAS_LIB roc::rocblas IMPORTED_LOCATION) +if(NOT ROCBLAS_LIB) + get_target_property(ROCBLAS_LIB roc::rocblas IMPORTED_LOCATION_RELEASE) +endif() +if(NOT ROCBLAS_LIB) + # Fallback to finding the library directly + find_library(ROCBLAS_LIB rocblas PATHS ${ROCM_PATH}/lib /opt/rocm/lib + /opt/rocm-6.0.0/lib) +endif() + +get_target_property(HIPRAND_LIB hip::hiprand IMPORTED_LOCATION) +if(NOT HIPRAND_LIB) + get_target_property(HIPRAND_LIB hip::hiprand IMPORTED_LOCATION_RELEASE) +endif() +if(NOT HIPRAND_LIB) + find_library(HIPRAND_LIB hiprand PATHS ${ROCM_PATH}/lib /opt/rocm/lib + /opt/rocm-6.0.0/lib) +endif() + +# Find amdhip64 library +find_library(AMDHIP64_LIB amdhip64 PATHS ${ROCM_PATH}/lib /opt/rocm/lib + /opt/rocm-6.0.0/lib) + +# Find hiprtc library (needed for JIT compilation) +find_library(HIPRTC_LIB hiprtc PATHS ${ROCM_PATH}/lib /opt/rocm/lib + /opt/rocm-6.0.0/lib) + +# Find hipBLASLt library (optimized GEMM for half-precision) +find_library(HIPBLASLT_LIB hipblaslt PATHS ${ROCM_PATH}/lib /opt/rocm/lib + /opt/rocm-6.0.0/lib) + +message( + STATUS + "ROCm libraries: rocblas=${ROCBLAS_LIB}, hiprand=${HIPRAND_LIB}, amdhip64=${AMDHIP64_LIB}, hiprtc=${HIPRTC_LIB}, hipblaslt=${HIPBLASLT_LIB}" +) + +# Link the static library and ROCm libraries to mlx We link directly to the .so +# files instead of using CMake targets to avoid propagating compile options like +# -x hip +target_link_libraries( + mlx PRIVATE ${HIP_STATIC_LIB} ${AMDHIP64_LIB} ${ROCBLAS_LIB} ${HIPRAND_LIB} + ${HIPRTC_LIB} ${HIPBLASLT_LIB}) + +# Include ROCm headers for mlx C++ files Get the HIP include directory from the +# hip package +get_target_property(HIP_HOST_INCLUDES hip::host INTERFACE_INCLUDE_DIRECTORIES) +if(HIP_HOST_INCLUDES) + target_include_directories(mlx PRIVATE ${HIP_HOST_INCLUDES}) +endif() +target_include_directories(mlx PRIVATE ${HIP_INCLUDE_DIRS}) + +# Add HIP platform define for C++ files +target_compile_definitions(mlx PRIVATE __HIP_PLATFORM_AMD__=1) diff --git a/mlx/backend/rocm/all_reduce.hip b/mlx/backend/rocm/all_reduce.hip new file mode 100644 index 0000000000..4cdba9eacd --- /dev/null +++ b/mlx/backend/rocm/all_reduce.hip @@ -0,0 +1,323 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/reduce/reduce.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/device/fp16_math.hpp" + +#include +#include + +namespace mlx::core { + +namespace rocm { + +// Helper to handle warp shuffle for different types +template +__device__ T warp_shfl_down_all(T val, int offset) { + return __shfl_down(val, offset); +} + +// Specialization for hip_bfloat16 - convert to float for shuffle +template <> +__device__ hip_bfloat16 warp_shfl_down_all(hip_bfloat16 val, int offset) { + float f = bf16_to_float(val); + f = __shfl_down(f, offset); + return float_to_bf16(f); +} + +// Specialization for __half - convert to float for shuffle +template <> +__device__ __half warp_shfl_down_all(__half val, int offset) { + float f = __half2float(val); + f = __shfl_down(f, offset); + return __float2half(f); +} + +template +__device__ U warp_reduce(U val, Op op) { + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val = op(val, warp_shfl_down_all(val, offset)); + } + return val; +} + +template +__global__ void all_reduce_kernel( + const T* __restrict__ in, + U* __restrict__ out, + size_t block_step, + size_t size) { + __shared__ U shared_data[32]; + + const U init = ReduceInit::value(); + Op op; + + U acc = init; + + size_t start = blockIdx.x * block_step; + size_t end = min(start + block_step, size); + + // Each thread processes multiple elements + for (size_t i = start + threadIdx.x * N; i < end; i += blockDim.x * N) { + #pragma unroll + for (int j = 0; j < N && (i + j) < end; ++j) { + acc = op(acc, static_cast(in[i + j])); + } + } + + // Warp-level reduction + int lane = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + + acc = warp_reduce(acc, op); + + if (lane == 0) { + shared_data[warp_id] = acc; + } + __syncthreads(); + + // Final reduction by first warp + int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; + if (warp_id == 0) { + acc = (lane < num_warps) ? shared_data[lane] : init; + acc = warp_reduce(acc, op); + + if (lane == 0) { + out[blockIdx.x] = acc; + } + } +} + +} // namespace rocm + +void all_reduce( + rocm::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type) { + constexpr int N_READS = 4; + + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + + auto get_args = [](size_t size, int N) { + int threads = std::min(512, static_cast((size + N - 1) / N)); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int reductions_per_step = threads * N; + size_t steps_needed = (size + reductions_per_step - 1) / reductions_per_step; + + int blocks; + if (steps_needed < 32) { + blocks = 1; + } else if (steps_needed < 128) { + blocks = 32; + } else if (steps_needed < 512) { + blocks = 128; + } else if (steps_needed < 1024) { + blocks = 512; + } else { + blocks = 1024; + } + + size_t steps_per_block = (steps_needed + blocks - 1) / blocks; + size_t block_step = steps_per_block * reductions_per_step; + + return std::make_tuple(blocks, threads, block_step); + }; + + int blocks, threads; + size_t block_step; + size_t insize = in.size(); + + std::tie(blocks, threads, block_step) = get_args(insize, N_READS); + + encoder.set_input_array(in); + encoder.set_output_array(out); + + // For multi-block reduction, we need an intermediate buffer + if (blocks > 1) { + array intermediate({blocks}, out.dtype(), nullptr, {}); + intermediate.set_data(mlx::core::rocm::malloc_async(intermediate.nbytes(), encoder)); + encoder.add_temporary(intermediate); + + // First pass: reduce to intermediate + encoder.launch_kernel([&](hipStream_t stream) { + #define LAUNCH_ALL_REDUCE(T, U, OP) \ + hipLaunchKernelGGL( \ + (rocm::all_reduce_kernel), \ + dim3(blocks), dim3(threads), 0, stream, \ + gpu_ptr(in), gpu_ptr(intermediate), block_step, insize) + + switch (in.dtype()) { + case float32: + switch (reduce_type) { + case Reduce::Sum: LAUNCH_ALL_REDUCE(float, float, Sum); break; + case Reduce::Prod: LAUNCH_ALL_REDUCE(float, float, Prod); break; + case Reduce::Max: LAUNCH_ALL_REDUCE(float, float, Max); break; + case Reduce::Min: LAUNCH_ALL_REDUCE(float, float, Min); break; + default: break; + } + break; + case float16: + switch (reduce_type) { + case Reduce::Sum: LAUNCH_ALL_REDUCE(__half, float, Sum); break; + case Reduce::Prod: LAUNCH_ALL_REDUCE(__half, float, Prod); break; + case Reduce::Max: LAUNCH_ALL_REDUCE(__half, __half, Max); break; + case Reduce::Min: LAUNCH_ALL_REDUCE(__half, __half, Min); break; + default: break; + } + break; + case int32: + switch (reduce_type) { + case Reduce::Sum: LAUNCH_ALL_REDUCE(int32_t, int32_t, Sum); break; + case Reduce::Prod: LAUNCH_ALL_REDUCE(int32_t, int32_t, Prod); break; + case Reduce::Max: LAUNCH_ALL_REDUCE(int32_t, int32_t, Max); break; + case Reduce::Min: LAUNCH_ALL_REDUCE(int32_t, int32_t, Min); break; + default: break; + } + break; + case int64: + switch (reduce_type) { + case Reduce::Sum: LAUNCH_ALL_REDUCE(int64_t, int64_t, Sum); break; + case Reduce::Prod: LAUNCH_ALL_REDUCE(int64_t, int64_t, Prod); break; + case Reduce::Max: LAUNCH_ALL_REDUCE(int64_t, int64_t, Max); break; + case Reduce::Min: LAUNCH_ALL_REDUCE(int64_t, int64_t, Min); break; + default: break; + } + break; + case bool_: + switch (reduce_type) { + case Reduce::And: LAUNCH_ALL_REDUCE(bool, bool, And); break; + case Reduce::Or: LAUNCH_ALL_REDUCE(bool, bool, Or); break; + default: break; + } + break; + default: + throw std::runtime_error("Unsupported type for all_reduce"); + } + #undef LAUNCH_ALL_REDUCE + }); + + // Second pass: reduce intermediate to output + std::tie(blocks, threads, block_step) = get_args(intermediate.size(), N_READS); + + encoder.launch_kernel([&](hipStream_t stream) { + #define LAUNCH_ALL_REDUCE_FINAL(T, U, OP) \ + hipLaunchKernelGGL( \ + (rocm::all_reduce_kernel), \ + dim3(1), dim3(threads), 0, stream, \ + gpu_ptr(intermediate), gpu_ptr(out), block_step, intermediate.size()) + + switch (out.dtype()) { + case float32: + switch (reduce_type) { + case Reduce::Sum: LAUNCH_ALL_REDUCE_FINAL(float, float, Sum); break; + case Reduce::Prod: LAUNCH_ALL_REDUCE_FINAL(float, float, Prod); break; + case Reduce::Max: LAUNCH_ALL_REDUCE_FINAL(float, float, Max); break; + case Reduce::Min: LAUNCH_ALL_REDUCE_FINAL(float, float, Min); break; + default: break; + } + break; + case float16: + switch (reduce_type) { + case Reduce::Sum: LAUNCH_ALL_REDUCE_FINAL(float, __half, Sum); break; + case Reduce::Prod: LAUNCH_ALL_REDUCE_FINAL(float, __half, Prod); break; + case Reduce::Max: LAUNCH_ALL_REDUCE_FINAL(__half, __half, Max); break; + case Reduce::Min: LAUNCH_ALL_REDUCE_FINAL(__half, __half, Min); break; + default: break; + } + break; + case int32: + switch (reduce_type) { + case Reduce::Sum: LAUNCH_ALL_REDUCE_FINAL(int32_t, int32_t, Sum); break; + case Reduce::Prod: LAUNCH_ALL_REDUCE_FINAL(int32_t, int32_t, Prod); break; + case Reduce::Max: LAUNCH_ALL_REDUCE_FINAL(int32_t, int32_t, Max); break; + case Reduce::Min: LAUNCH_ALL_REDUCE_FINAL(int32_t, int32_t, Min); break; + default: break; + } + break; + case int64: + switch (reduce_type) { + case Reduce::Sum: LAUNCH_ALL_REDUCE_FINAL(int64_t, int64_t, Sum); break; + case Reduce::Prod: LAUNCH_ALL_REDUCE_FINAL(int64_t, int64_t, Prod); break; + case Reduce::Max: LAUNCH_ALL_REDUCE_FINAL(int64_t, int64_t, Max); break; + case Reduce::Min: LAUNCH_ALL_REDUCE_FINAL(int64_t, int64_t, Min); break; + default: break; + } + break; + case bool_: + switch (reduce_type) { + case Reduce::And: LAUNCH_ALL_REDUCE_FINAL(bool, bool, And); break; + case Reduce::Or: LAUNCH_ALL_REDUCE_FINAL(bool, bool, Or); break; + default: break; + } + break; + default: + throw std::runtime_error("Unsupported type for all_reduce"); + } + #undef LAUNCH_ALL_REDUCE_FINAL + }); + } else { + // Single block reduction + encoder.launch_kernel([&](hipStream_t stream) { + #define LAUNCH_ALL_REDUCE_SINGLE(T, U, OP) \ + hipLaunchKernelGGL( \ + (rocm::all_reduce_kernel), \ + dim3(1), dim3(threads), 0, stream, \ + gpu_ptr(in), gpu_ptr(out), block_step, insize) + + switch (in.dtype()) { + case float32: + switch (reduce_type) { + case Reduce::Sum: LAUNCH_ALL_REDUCE_SINGLE(float, float, Sum); break; + case Reduce::Prod: LAUNCH_ALL_REDUCE_SINGLE(float, float, Prod); break; + case Reduce::Max: LAUNCH_ALL_REDUCE_SINGLE(float, float, Max); break; + case Reduce::Min: LAUNCH_ALL_REDUCE_SINGLE(float, float, Min); break; + default: break; + } + break; + case float16: + switch (reduce_type) { + case Reduce::Sum: LAUNCH_ALL_REDUCE_SINGLE(__half, __half, Sum); break; + case Reduce::Prod: LAUNCH_ALL_REDUCE_SINGLE(__half, __half, Prod); break; + case Reduce::Max: LAUNCH_ALL_REDUCE_SINGLE(__half, __half, Max); break; + case Reduce::Min: LAUNCH_ALL_REDUCE_SINGLE(__half, __half, Min); break; + default: break; + } + break; + case int32: + switch (reduce_type) { + case Reduce::Sum: LAUNCH_ALL_REDUCE_SINGLE(int32_t, int32_t, Sum); break; + case Reduce::Prod: LAUNCH_ALL_REDUCE_SINGLE(int32_t, int32_t, Prod); break; + case Reduce::Max: LAUNCH_ALL_REDUCE_SINGLE(int32_t, int32_t, Max); break; + case Reduce::Min: LAUNCH_ALL_REDUCE_SINGLE(int32_t, int32_t, Min); break; + default: break; + } + break; + case int64: + switch (reduce_type) { + case Reduce::Sum: LAUNCH_ALL_REDUCE_SINGLE(int64_t, int64_t, Sum); break; + case Reduce::Prod: LAUNCH_ALL_REDUCE_SINGLE(int64_t, int64_t, Prod); break; + case Reduce::Max: LAUNCH_ALL_REDUCE_SINGLE(int64_t, int64_t, Max); break; + case Reduce::Min: LAUNCH_ALL_REDUCE_SINGLE(int64_t, int64_t, Min); break; + default: break; + } + break; + case bool_: + switch (reduce_type) { + case Reduce::And: LAUNCH_ALL_REDUCE_SINGLE(bool, bool, And); break; + case Reduce::Or: LAUNCH_ALL_REDUCE_SINGLE(bool, bool, Or); break; + default: break; + } + break; + default: + throw std::runtime_error("Unsupported type for all_reduce"); + } + #undef LAUNCH_ALL_REDUCE_SINGLE + }); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/allocator.cpp b/mlx/backend/rocm/allocator.cpp new file mode 100644 index 0000000000..62652d6fea --- /dev/null +++ b/mlx/backend/rocm/allocator.cpp @@ -0,0 +1,1017 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/utils.h" +#include "mlx/device.h" +#include "mlx/memory.h" +#include "mlx/utils.h" + +#include +#include + +#include +#include +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +constexpr int page_size = 16384; + +// Check if ROCm device is available +static bool rocm_available() { + static int available = -1; + if (available < 0) { + int device_count = 0; + hipError_t err = hipGetDeviceCount(&device_count); + available = (err == hipSuccess && device_count > 0) ? 1 : 0; + } + return available == 1; +} + +// Check if managed memory (HMM) is supported on this device. +static bool managed_memory_supported() { + static int supported = -1; + if (supported < 0) { + if (!rocm_available()) { + supported = 0; + } else { + void* test_ptr = nullptr; + hipError_t err = hipMallocManaged(&test_ptr, 64); + if (err == hipSuccess) { + (void)hipFree(test_ptr); + supported = 1; + } else { + supported = 0; + } + } + } + return supported == 1; +} + +static bool is_integrated() { + static int integrated = -1; + if (integrated < 0) { + if (!rocm_available()) { + integrated = 0; + } else { + int device = 0; + (void)hipGetDevice(&device); + hipDeviceProp_t props; + hipError_t err = hipGetDeviceProperties(&props, device); + integrated = (err == hipSuccess && props.integrated == 1) ? 1 : 0; + } + } + return integrated == 1; +} + +static bool device_is_integrated(int dev) { + static int cache[16] = {-1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1}; + if (dev < 0 || dev >= 16) + return false; + if (cache[dev] < 0) { + hipDeviceProp_t p; + cache[dev] = + (hipGetDeviceProperties(&p, dev) == hipSuccess && p.integrated == 1) ? 1 + : 0; + } + return cache[dev] == 1; +} + +static bool use_finegrained() { + if (const char* e = std::getenv("MLX_ROCM_FINEGRAINED")) + return std::atoi(e) != 0; + return true; +} + +// CUDA-style stream-ordered device pool (hipMallocAsync/hipFreeAsync). Always +// on where the device supports memory pools; allocations fall back to the +// unified path only for pool-less devices or stream-less requests. +static bool use_async_pool() { + return true; +} + +static int alloc_device_tag() { + return use_finegrained() ? -1 : 0; +} + +inline void* rocm_unified_malloc(size_t size, bool& is_managed) { + void* data = nullptr; + hipError_t err; + // Bind the alloc to the MLX-selected GPU. set_default_device(gpu,N) only sets + // MLX bookkeeping; it never calls hipSetDevice. Without this, allocations made + // OUTSIDE the eval path — notably the slab warmup at allocator construction — + // land on whatever device is current (device 0 at startup), so the model's + // small/intermediate tensors live on the APU while weights live on the dGPU. + // A dGPU kernel then reads APU memory across the (TB5) link and hangs. Use raw + // hipSetDevice (NOT device().make_current(), whose Device construction + device- + // flags loop faults against device-0's already-created context). + { + mlx::core::Device dd = mlx::core::default_device(); + if (dd.type == mlx::core::Device::gpu) { + int cur = -1; + if (hipGetDevice(&cur) == hipSuccess && cur != dd.index) + (void)hipSetDevice(dd.index); + } + } + if (size > (16ull << 20) && std::getenv("MLX_ALLOC_DEBUG")) { + int d = -1; + (void)hipGetDevice(&d); + fprintf(stderr, "[alloc] %zu MB curdev=%d defdev=%d finegrained=%d\n", + size >> 20, d, mlx::core::default_device().index, + (int)use_finegrained()); + } + if (use_finegrained()) { + // Integrated APU: unified LPDDR5, host-coherent. One pointer feeds kernels + // (gpu_ptr) and the CPU (raw_ptr) — no host shadow, coherent at sync points. + err = hipExtMallocWithFlags(&data, size, hipDeviceMallocFinegrained); + } else { + // Discrete GPU: coarse-grained VRAM (no coherency requirement). CPU access + // goes through the pinned host shadow (ensure_host_shadow/flush_host_shadow). + err = hipMalloc(&data, size); + } + if (err == hipSuccess) { + is_managed = true; + return data; + } + // Fallbacks for platforms without fine-grained device memory. + if (managed_memory_supported()) { + err = hipMallocManaged(&data, size); + is_managed = true; + } else { + err = hipHostMalloc(&data, size, hipHostMallocDefault); + is_managed = false; + } + if (err != hipSuccess) { + std::ostringstream oss; + oss << "hipMalloc (unified) failed: " << hipGetErrorString(err) << "."; + throw std::runtime_error(oss.str()); + } + return data; +} + +inline void rocm_unified_free(void* data, bool is_managed) { + if (is_managed) { + (void)hipFree(data); + } else { + (void)hipHostFree(data); + } +} + +// Apply memory hints for the managed-memory fallback path. Fine-grained device +// memory (the primary path) is already VRAM-resident, so these are no-ops there +// (errors swallowed); they only matter if rocm_unified_malloc fell back to HMM. +static void apply_slab_hints(void* data, size_t size) { + if (!rocm_available()) + return; + int device = 0; + (void)hipGetDevice(&device); + // Managed/SVM hints apply only to integrated (APU) memory. On discrete GPUs + // they fail (hsa_amd_svm_attributes_set) and corrupt the HIP runtime. + if (!device_is_integrated(device)) + return; + (void)hipMemAdvise(data, size, hipMemAdviseSetAccessedBy, device); + (void)hipMemPrefetchAsync(data, size, device, nullptr); +} + +// --------------------------------------------------------------------------- +// SizeClassPool +// --------------------------------------------------------------------------- + +void SizeClassPool::init(size_t block_size, size_t slab_page_size) { + block_size_ = block_size; + slab_page_size_ = slab_page_size; +} + +SizeClassPool::~SizeClassPool() { + for (size_t i = 0; i < backing_pages_.size(); i++) { + rocm_unified_free(backing_pages_[i], is_managed_); + delete[] block_arrays_[i]; + } +} + +bool SizeClassPool::grow() { + if (!rocm_available() || block_size_ == 0) + return false; + + void* data = nullptr; + try { + data = rocm_unified_malloc(slab_page_size_, is_managed_); + } catch (...) { + return false; + } + + // Apply memory hints for GPU access + apply_slab_hints(data, slab_page_size_); + + size_t num_blocks = slab_page_size_ / block_size_; + auto* blocks = new Block[num_blocks]; + + // Chain blocks into the free list + for (size_t i = 0; i < num_blocks; i++) { + blocks[i].next = (i + 1 < num_blocks) ? &blocks[i + 1] : next_free_; + } + next_free_ = &blocks[0]; + + backing_pages_.push_back(data); + block_arrays_.push_back(blocks); + blocks_per_page_.push_back(num_blocks); + free_count_ += num_blocks; + total_blocks_ += num_blocks; + + return true; +} + +RocmBuffer* SizeClassPool::malloc() { + if (next_free_ == nullptr) + return nullptr; + + Block* b = next_free_; + next_free_ = next_free_->next; + free_count_--; + + // Fast path: single page (common case after warmup) + if (block_arrays_.size() == 1) { + size_t idx = static_cast(b - block_arrays_[0]); + b->buf.data = static_cast(backing_pages_[0]) + idx * block_size_; + b->buf.size = block_size_; + b->buf.is_managed = is_managed_; + b->buf.device = alloc_device_tag(); + b->buf.host_shadow = nullptr; + b->buf.host_dirty = false; + return &b->buf; + } + + // Multi-page: find which backing page this block belongs to + for (size_t page = 0; page < block_arrays_.size(); page++) { + Block* base = block_arrays_[page]; + size_t count = blocks_per_page_[page]; + if (b >= base && b < base + count) { + size_t idx = static_cast(b - base); + b->buf.data = + static_cast(backing_pages_[page]) + idx * block_size_; + b->buf.size = block_size_; + b->buf.is_managed = is_managed_; + b->buf.device = alloc_device_tag(); + b->buf.host_shadow = nullptr; + b->buf.host_dirty = false; + return &b->buf; + } + } + + return nullptr; +} + +void SizeClassPool::free(RocmBuffer* buf) { + auto* b = reinterpret_cast(buf); + b->next = next_free_; + next_free_ = b; + free_count_++; +} + +bool SizeClassPool::in_pool(RocmBuffer* buf) const { + if (block_arrays_.empty()) + return false; + auto* b = reinterpret_cast(buf); + + // Fast path: single page + if (block_arrays_.size() == 1) { + return b >= block_arrays_[0] && b < block_arrays_[0] + blocks_per_page_[0]; + } + + for (size_t page = 0; page < block_arrays_.size(); page++) { + if (b >= block_arrays_[page] && + b < block_arrays_[page] + blocks_per_page_[page]) { + return true; + } + } + return false; +} + +// --------------------------------------------------------------------------- +// SlabAllocator +// --------------------------------------------------------------------------- + +// Slab page sizes per tier (indexed by size class) +static constexpr size_t kSlabPageSizes[SlabAllocator::kNumSizeClasses] = { + 64 * 1024, // 8B blocks + 64 * 1024, // 16B + 64 * 1024, // 32B + 64 * 1024, // 64B + 64 * 1024, // 128B + 256 * 1024, // 256B + 256 * 1024, // 512B + 1024 * 1024, // 1KB + 1024 * 1024, // 2KB + 1024 * 1024, // 4KB + 1024 * 1024, // 8KB + 1024 * 1024, // 16KB + 2 * 1024 * 1024, // 32KB + 4 * 1024 * 1024, // 64KB + 8 * 1024 * 1024, // 128KB + 16 * 1024 * 1024, // 256KB + 32 * 1024 * 1024, // 512KB + 64 * 1024 * 1024, // 1MB +}; + +// Whether to pre-allocate each tier at startup +static constexpr bool kPreallocate[SlabAllocator::kNumSizeClasses] = { + true, + true, + true, + true, + true, // 8B-128B + true, + true, // 256B-512B + true, + true, + true, + true, + true, // 1KB-16KB + false, + false, + false, + false, + false, + false, // 32KB-1MB: on demand +}; + +SlabAllocator::SlabAllocator() { + for (int i = 0; i < kNumSizeClasses; i++) { + size_t block_size = static_cast(1) + << (i + 3); // 2^3=8 through 2^20=1MB + pools_[i].init(block_size, kSlabPageSizes[i]); + } +} + +int SlabAllocator::size_class_index(size_t size) { + if (size == 0 || size > kMaxSlabSize) + return -1; + if (size <= 8) + return 0; + // ceil(log2(size)) - 3, computed via bit manipulation + int bits = 64 - __builtin_clzll(size - 1); // ceil(log2(size)) + return bits - 3; +} + +size_t SlabAllocator::round_to_size_class(size_t size) { + if (size <= 8) + return 8; + if (size > kMaxSlabSize) + return size; + // Round up to next power of 2 + return static_cast(1) << (64 - __builtin_clzll(size - 1)); +} + +void SlabAllocator::warmup() { + if (!rocm_available()) + return; + for (int i = 0; i < kNumSizeClasses; i++) { + if (kPreallocate[i]) { + pools_[i].grow(); + } + } +} + +RocmBuffer* SlabAllocator::malloc(size_t size) { + int idx = size_class_index(size); + if (idx < 0) + return nullptr; + return pools_[idx].malloc(); +} + +void SlabAllocator::free(RocmBuffer* buf) { + // O(1) dispatch: use buf->size to find the correct pool + int idx = size_class_index(buf->size); + if (idx >= 0 && pools_[idx].initialized()) { + pools_[idx].free(buf); + } +} + +bool SlabAllocator::in_pool(RocmBuffer* buf) const { + // O(1) dispatch: size determines the pool, then verify membership + int idx = size_class_index(buf->size); + if (idx >= 0 && pools_[idx].initialized()) { + return pools_[idx].in_pool(buf); + } + return false; +} + +bool SlabAllocator::grow(size_t size) { + int idx = size_class_index(size); + if (idx < 0) + return false; + return pools_[idx].grow(); +} + +size_t SlabAllocator::total_allocated() const { + size_t total = 0; + for (int i = 0; i < kNumSizeClasses; i++) { + total += pools_[i].total_allocated(); + } + return total; +} + +size_t SlabAllocator::free_memory() const { + size_t total = 0; + for (int i = 0; i < kNumSizeClasses; i++) { + total += pools_[i].free_memory(); + } + return total; +} + +// --------------------------------------------------------------------------- +// RocmAllocator +// --------------------------------------------------------------------------- + +RocmAllocator::RocmAllocator() + : buffer_cache_( + page_size, + [](RocmBuffer* buf) { return buf->size; }, + [this](RocmBuffer* buf) { rocm_free(buf); }), + memory_limit_(0), + max_pool_size_(0), + active_memory_(0), + peak_memory_(0) { + if (!rocm_available()) { + return; + } + + size_t free, total; + hipError_t err = hipMemGetInfo(&free, &total); + if (err == hipSuccess) { + int dev = 0; + (void)hipGetDevice(&dev); + // Integrated APU: unified memory is shared with the CPU/system, so keep a + // conservative cap. Discrete GPU: it is dedicated VRAM — use almost all of + // it. The old 0.8 cap stranded ~6GB on a 32GB card, so once the working set + // crossed 0.8*total every allocation evicted the buffer cache, and on a + // discrete GPU each eviction is a blocking hipFree (waits on GPU drain) — + // which stalls decode. Leave only a small reserve for driver/fragmentation. + if (device_is_integrated(dev)) { + // The APU's managed/fine-grained allocations live in the large unified + // pool (system RAM / GTT), but hipMemGetInfo reports only the tiny + // device-visible VRAM carveout. Sizing the cache to that carveout makes + // the allocator evict on nearly every allocation, and each eviction is a + // blocking hipFree that deadlocks under heavy async load (MTP). Size the + // limit to system RAM, which is what the unified pool actually draws from. + size_t sys_ram = static_cast(sysconf(_SC_PHYS_PAGES)) * + static_cast(sysconf(_SC_PAGE_SIZE)); + memory_limit_ = std::max( + static_cast(total * 0.8), static_cast(sys_ram * 0.8)); + } else { + size_t reserve = 512ull << 20; // 512 MB driver/TTM headroom + memory_limit_ = (total > reserve) ? (total - reserve) : total; + } + max_pool_size_ = memory_limit_; + total_memory_ = total; + free_limit_ = (total > memory_limit_) ? (total - memory_limit_) : 0; + } + + // Per-device hipMemPool + dedicated free stream for the async pool path. + if (use_async_pool()) { + int n = 0; + (void)hipGetDeviceCount(&n); + mem_pools_.resize(n, nullptr); + free_streams_.resize(n, nullptr); + int saved = 0; + (void)hipGetDevice(&saved); + for (int i = 0; i < n; ++i) { + int supported = 0; + (void)hipDeviceGetAttribute( + &supported, hipDeviceAttributeMemoryPoolsSupported, i); + if (!supported) + continue; + (void)hipSetDevice(i); + hipMemPool_t pool = nullptr; + if (hipDeviceGetDefaultMemPool(&pool, i) == hipSuccess) { + mem_pools_[i] = pool; + hipStream_t s = nullptr; + if (hipStreamCreateWithFlags(&s, hipStreamNonBlocking) == hipSuccess) + free_streams_[i] = s; + } + } + (void)hipSetDevice(saved); + } + + // Pre-allocate slab pages for common allocation sizes + slab_allocator_.warmup(); +} + +// Unified-path frees deferred out of free() so they never run a blocking +// hipFree on the completion-worker thread (which self-deadlocks). Drained by +// malloc on the eval thread, where a blocking hipFree is safe. Pool buffers +// don't use this — they free non-blocking via hipFreeAsync. +static std::mutex g_pending_free_mutex; +static std::vector g_pending_frees; + +Buffer RocmAllocator::malloc(size_t size) { + if (!rocm_available()) { + throw std::runtime_error( + "Cannot allocate ROCm memory: no ROCm-capable device detected. " + "Please use CPU backend instead."); + } + + // Drain deferred unified frees on this (eval) thread, outside any lock. + { + std::vector to_free; + { + std::lock_guard lk(g_pending_free_mutex); + to_free.swap(g_pending_frees); + } + for (auto* b : to_free) { + rocm_free(b); + } + } + + // Arena fast path: deterministic bump allocation for HIP Graph capture + if (arena_.active()) { + RocmBuffer* buf = arena_.malloc(size); + if (buf) + return Buffer{buf}; + // Arena exhausted — fall through to normal path + } + + auto orig_size = size; + std::unique_lock lock(mutex_); + + // Round size to appropriate boundary + if (size <= SlabAllocator::kMaxSlabSize) { + size = SlabAllocator::round_to_size_class(size); + + // Try slab allocator (O(1) free-list pop) + RocmBuffer* buf = slab_allocator_.malloc(size); + if (buf) { + active_memory_ += size; + peak_memory_ = std::max(active_memory_, peak_memory_); + return Buffer{buf}; + } + + // Pool exhausted — grow (holds lock during HIP alloc, acceptable for rare + // path) + if (slab_allocator_.grow(size)) { + buf = slab_allocator_.malloc(size); + if (buf) { + active_memory_ += size; + peak_memory_ = std::max(active_memory_, peak_memory_); + return Buffer{buf}; + } + } + + // Slab growth failed — fall through to BufferCache + // Slab growth failed — fall through to BufferCache + } else { + // Large allocation: page-align + size = page_size * ((size + page_size - 1) / page_size); + } + + // Stream-less allocations (model load, KV, non-wired primitives) use unified + // memory + the BufferCache. The wired primitives route their outputs through + // malloc_async (the pool) instead; this path is the safe fallback. + RocmBuffer* buf = buffer_cache_.reuse_from_cache(size); + if (!buf) { + int64_t mem_to_free = + get_active_memory() + get_cache_memory() + size - memory_limit_; + if (mem_to_free > 0) { + buffer_cache_.release_cached_buffers(mem_to_free); + } + lock.unlock(); + bool is_managed = false; + void* data = rocm_unified_malloc(size, is_managed); + buf = new RocmBuffer{data, size, is_managed, alloc_device_tag(), nullptr, false, nullptr}; + lock.lock(); + } + active_memory_ += size; + peak_memory_ = std::max(active_memory_, peak_memory_); + if (get_cache_memory() > max_pool_size_) { + buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); + } + return Buffer{buf}; +} + +Buffer RocmAllocator::malloc_async(size_t size, int device, void* stream_v) { + // During HIP-graph capture, route through the DecodeArena like malloc() does. + // Otherwise hipMallocAsync below records a MemAlloc node into the captured + // graph; such graphs allocate on the first replay but FAIL on the second + // ("invalid argument") because the memory node can't re-allocate — freezing + // decode after one token. The arena hands back pre-allocated deterministic + // addresses, so no MemAlloc node is recorded. + if (arena_.active()) { + RocmBuffer* buf = arena_.malloc(size); + if (buf) + return Buffer{buf}; + // arena exhausted — fall through to the pool path + } + hipStream_t stream = static_cast(stream_v); + // Fall back to the unified path unless the pool is usable for this request. + if (!use_async_pool() || stream == nullptr || device < 0 || + device >= static_cast(mem_pools_.size()) || + mem_pools_[device] == nullptr || size == 0 || + size <= SlabAllocator::kMaxSlabSize) { + return malloc(size); + } + + size = page_size * ((size + page_size - 1) / page_size); + + // Bypass our BufferCache entirely: the hipMemPool already manages reuse and + // retention (ReleaseThreshold=MAX). Layering our own eviction on top causes + // hipFreeAsync storms that starve the HSA handler pool and wedge. Let the GPU + // manage its own memory — alloc straight from the pool. + void* data = nullptr; + hipError_t err = hipMallocAsync(&data, size, stream); + if (err != hipSuccess || !data) { + (void)hipGetLastError(); + return malloc(size); // pool exhausted: fall back to unified + } + // is_managed=false marks this as a stream-ordered pool buffer (freed via + // hipFreeAsync); device>=0 routes CPU access through the host shadow. + RocmBuffer* buf = new RocmBuffer{data, size, false, device, nullptr, false, stream}; + std::lock_guard lock(mutex_); + active_memory_ += buf->size; + peak_memory_ = std::max(active_memory_, peak_memory_); + return Buffer{buf}; +} + +void RocmAllocator::free_async(RocmBuffer* buf, void* stream_v) { + hipStream_t stream = static_cast(stream_v); + // Free on the buffer's own alloc/eval stream so the free retires in order + // behind its last use and the pool reclaims it (a separate idle free-stream + // never executes during a forward, so the pool can't reuse and VRAM grows). + if (!stream) + stream = static_cast(buf->alloc_stream); + if (!stream && buf->device >= 0 && + buf->device < static_cast(free_streams_.size())) { + stream = static_cast(free_streams_[buf->device]); + } + if (buf->host_shadow) { + (void)hipHostFree(buf->host_shadow); + buf->host_shadow = nullptr; + } + if (stream) { + (void)hipFreeAsync(buf->data, stream); + } else { + (void)hipFree(buf->data); + } + delete buf; +} + +static std::mutex g_deferred_mutex; +static std::vector g_deferred_frees; + +void flush_graph_deferred_frees() { + std::vector to_free; + { + std::lock_guard lk(g_deferred_mutex); + to_free.swap(g_deferred_frees); + } + for (auto b : to_free) { + allocator().free(b); + } +} + +void RocmAllocator::free(Buffer buffer) { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return; + } + + // Defer all frees while a captured graph is alive so its baked buffer + // addresses stay valid through replay. + if (graph_active()) { + std::lock_guard lk(g_deferred_mutex); + g_deferred_frees.push_back(buffer); + return; + } + + // Arena fast path: no-op (memory freed in bulk on arena.end()) + if (arena_.active()) { + arena_.free(buf); + return; + } + + std::unique_lock lock(mutex_); + active_memory_ -= buf->size; + + // Slab-allocated buffers go back to the slab free list + if (slab_allocator_.in_pool(buf)) { + slab_allocator_.free(buf); + return; + } + + // Stream-ordered pool buffer (the common case): return it straight to the + // hipMemPool via hipFreeAsync on its own stream. The pool owns reuse/retention. + if (buf->device >= 0 && !buf->is_managed) { + free_async(buf, nullptr); + return; + } + + // Unified buffer (model load / KV / non-wired primitives). Recycle to the + // BufferCache, or defer the blocking hipFree off the worker thread. + if (get_cache_memory() < max_pool_size_) { + buffer_cache_.recycle_to_cache(buf); + } else { + std::lock_guard lk(g_pending_free_mutex); + g_pending_frees.push_back(buf); + } +} + +size_t RocmAllocator::size(Buffer buffer) const { + auto* buf = static_cast(buffer.ptr()); + if (!buf) { + return 0; + } + return buf->size; +} + +void RocmAllocator::rocm_free(RocmBuffer* buf) { + // Stream-ordered pool buffer: free non-blocking via hipFreeAsync. + if (buf->device >= 0 && !buf->is_managed) { + free_async(buf, nullptr); + return; + } + if (buf->host_shadow) { + (void)hipHostFree(buf->host_shadow); + buf->host_shadow = nullptr; + } + if (buf->device == -1) { + rocm_unified_free(buf->data, buf->is_managed); + } else { + (void)hipFree(buf->data); + } + delete buf; +} + +void RocmAllocator::ensure_host_shadow(RocmBuffer& buf) { + // Integrated APU buffers are already host-coherent — never reached. + if (buf.device == -1) { + return; + } + // Allocate the pinned host mirror once, then refresh it from VRAM. The VRAM + // copy in buf.data is KEPT (no hipFree, device stays != -1) so gpu_ptr() + // keeps feeding kernels the resident device pointer; only CPU reads see the + // host mirror. No per-weight VRAM doubling / migration. + if (buf.host_shadow == nullptr) { + hipError_t err = + hipHostMalloc(&buf.host_shadow, buf.size, hipHostMallocDefault); + if (err != hipSuccess) { + buf.host_shadow = nullptr; + std::ostringstream oss; + oss << "hipHostMalloc (host shadow) failed: " << hipGetErrorString(err) + << "."; + throw std::runtime_error(oss.str()); + } + } + // Refresh from VRAM only when the shadow is NOT already the authoritative copy + // (i.e. no un-flushed CPU writes pending) — otherwise we'd clobber them. + if (!buf.host_dirty) { + hipError_t err = + hipMemcpy(buf.host_shadow, buf.data, buf.size, hipMemcpyDeviceToHost); + if (err != hipSuccess) { + std::ostringstream oss; + oss << "hipMemcpy (host shadow) failed: " << hipGetErrorString(err) << "."; + throw std::runtime_error(oss.str()); + } + } +} + +void RocmAllocator::flush_host_shadow(RocmBuffer& buf) { + if (buf.host_shadow == nullptr || !buf.host_dirty) { + return; + } + (void)hipMemcpy(buf.data, buf.host_shadow, buf.size, hipMemcpyHostToDevice); + buf.host_dirty = false; +} + +size_t RocmAllocator::get_active_memory() const { + return active_memory_; +} + +size_t RocmAllocator::get_peak_memory() const { + return peak_memory_; +} + +void RocmAllocator::reset_peak_memory() { + std::lock_guard lock(mutex_); + peak_memory_ = 0; +} + +size_t RocmAllocator::get_memory_limit() { + return memory_limit_; +} + +size_t RocmAllocator::set_memory_limit(size_t limit) { + std::lock_guard lock(mutex_); + std::swap(limit, memory_limit_); + return limit; +} + +size_t RocmAllocator::get_cache_memory() const { + // Only report BufferCache size. Slab free memory is infrastructure, + // not cache — including it inflates the count and causes premature + // eviction of large buffers from the BufferCache. + return buffer_cache_.cache_size(); +} + +size_t RocmAllocator::set_cache_limit(size_t limit) { + std::lock_guard lk(mutex_); + std::swap(limit, max_pool_size_); + // Trim the reuse pool down to the new cap NOW, while the caller is at an idle + // point (e.g. just after warmup). Otherwise the trim happens lazily on the + // next malloc — i.e. during the first forward — and its blocking hipFree + // (which on a discrete GPU implicitly synchronizes the device and can force a + // TTM eviction) wedges the command queue mid-pass. + if (get_cache_memory() > max_pool_size_) { + buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); + } + return limit; +} + +void RocmAllocator::clear_cache() { + // The hipMemPool owns reuse/retention for pool buffers; releasing memory means + // trimming it. Drain the device first so trimmed blocks have no outstanding + // work, then drain deferred unified frees on this (safe) thread. Do NOT + // blocking-clear the unified BufferCache under pool handler pressure — those + // buffers are bounded by max_pool_size_ and reused. + (void)hipDeviceSynchronize(); + for (void* p : mem_pools_) { + if (p) + (void)hipMemPoolTrimTo(static_cast(p), 0); + } + std::vector to_free; + { + std::lock_guard lk(g_pending_free_mutex); + to_free.swap(g_pending_frees); + } + for (auto* b : to_free) { + rocm_free(b); + } +} + +// --------------------------------------------------------------------------- +// DecodeArena implementation +// --------------------------------------------------------------------------- + +DecodeArena::~DecodeArena() { + end(); +} + +bool DecodeArena::begin(size_t capacity_bytes) { + if (base_) + end(); + + // Align capacity to page boundary + capacity_bytes = (capacity_bytes + 4095) & ~size_t(4095); + + bool managed = false; + void* data = nullptr; + try { + data = rocm_unified_malloc(capacity_bytes, managed); + } catch (...) { + return false; + } + + base_ = data; + capacity_ = capacity_bytes; + offset_ = 0; + is_managed_ = managed; + desc_index_ = 0; + paused_ = false; + descriptors_.clear(); + // Reserve a hard upper bound so the vector NEVER reallocates: malloc() returns + // RocmBuffer* pointers INTO this vector, and the captured graph + live arrays + // hold those pointers for the whole decode. A realloc would dangle all of them + // (heap corruption). A decode step + per-token sampling stays well under this. + descriptors_.reserve(16384); + return true; +} + +void DecodeArena::reset() { + offset_ = 0; + desc_index_ = 0; +} + +void DecodeArena::end() { + if (!base_) + return; + rocm_unified_free(base_, is_managed_); + base_ = nullptr; + capacity_ = 0; + offset_ = 0; + descriptors_.clear(); + desc_index_ = 0; +} + +RocmBuffer* DecodeArena::malloc(size_t size) { + if (!base_) + return nullptr; + + // Align to 256 bytes for GPU access patterns + size_t aligned = (size + 255) & ~size_t(255); + if (offset_ + aligned > capacity_) + return nullptr; + + void* ptr = static_cast(base_) + offset_; + offset_ += aligned; + + // Reuse or create a RocmBuffer descriptor + if (desc_index_ < descriptors_.size()) { + auto& d = descriptors_[desc_index_]; + d.data = ptr; + d.size = size; + d.host_shadow = nullptr; + d.host_dirty = false; + desc_index_++; + return &d; + } + + // Fully initialize host_shadow/host_dirty: gpu_ptr() reads host_dirty, so an + // uninitialized value could spuriously trigger a flush of a garbage pointer. + descriptors_.push_back(RocmBuffer{ptr, size, is_managed_, -1, nullptr, false, nullptr}); + desc_index_++; + return &descriptors_.back(); +} + +RocmAllocator& allocator() { + static RocmAllocator* allocator_ = new RocmAllocator; + return *allocator_; +} + +Buffer malloc_async(size_t size, CommandEncoder& encoder) { + return allocator().malloc_async( + size, + encoder.device().hip_device(), + static_cast(encoder.stream())); +} + +} // namespace rocm + +namespace allocator { + +Allocator& allocator() { + return rocm::allocator(); +} + +void* Buffer::raw_ptr() { + if (!ptr_) { + return nullptr; + } + auto& cbuf = *static_cast(ptr_); + + if (cbuf.device == -1) { + // Unified memory on iGPU: fine-grained coherent memory means CPU sees + // GPU writes without explicit sync. Only sync if the stream has pending + // work (hipStreamQuery returns hipErrorNotReady when busy). + if (hipStreamQuery(nullptr) != hipSuccess) { + (void)hipStreamSynchronize(nullptr); + } + } else { + // Discrete GPU: serve CPU access from the pinned host mirror (fresh D2H), + // keeping the VRAM copy authoritative. Synchronize the device first so the + // producing kernel has finished before the D2H read — a lighter null-stream + // query is NOT sufficient (the value may be produced on a non-default stream) + // and reading early returns stale zeros (crashes / garbage). + (void)hipDeviceSynchronize(); + rocm::allocator().ensure_host_shadow(cbuf); + return cbuf.host_shadow; + } + return cbuf.data; +} + +} // namespace allocator + +size_t get_active_memory() { + return rocm::allocator().get_active_memory(); +} +size_t get_peak_memory() { + return rocm::allocator().get_peak_memory(); +} +void reset_peak_memory() { + return rocm::allocator().reset_peak_memory(); +} +size_t set_memory_limit(size_t limit) { + return rocm::allocator().set_memory_limit(limit); +} +size_t get_memory_limit() { + return rocm::allocator().get_memory_limit(); +} +size_t get_cache_memory() { + return rocm::allocator().get_cache_memory(); +} +size_t set_cache_limit(size_t limit) { + return rocm::allocator().set_cache_limit(limit); +} +void clear_cache() { + rocm::allocator().clear_cache(); +} + +// Not supported in ROCm. +size_t set_wired_limit(size_t) { + return 0; +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/allocator.h b/mlx/backend/rocm/allocator.h new file mode 100644 index 0000000000..6220165459 --- /dev/null +++ b/mlx/backend/rocm/allocator.h @@ -0,0 +1,276 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/allocator.h" +#include "mlx/backend/common/buffer_cache.h" + +#include +#include +#include +#include + +namespace mlx::core::rocm { + +using allocator::Buffer; + +struct RocmBuffer { + void* data; + size_t size; + bool is_managed; + int device; + // Discrete-GPU only: pinned host mirror that serves CPU reads (raw_ptr) + // WITHOUT migrating/freeing the resident VRAM copy in `data`. No default + // initializer (keeps RocmBuffer trivial for the SizeClassPool union); set + // explicitly in the slab path, aggregate-init'd to null in the large-alloc + // paths. Always null on the integrated APU (device == -1). + void* host_shadow; + // True while host_shadow is the authoritative copy (CPU may have written + // through raw_ptr). gpu_ptr() flushes host_shadow -> VRAM and clears it so + // kernels see CPU writes; raw_ptr() won't re-pull from VRAM while dirty. + bool host_dirty; + // For stream-ordered pool buffers: the stream the buffer was allocated/used + // on. hipFreeAsync must run on this same (actively-executing) stream so the + // free retires in order behind the buffer's last use and the pool reclaims it. + void* alloc_stream; +}; + +// --------------------------------------------------------------------------- +// SizeClassPool — fixed-size block pool with free list +// --------------------------------------------------------------------------- + +class SizeClassPool { + public: + SizeClassPool() = default; + ~SizeClassPool(); + + SizeClassPool(const SizeClassPool&) = delete; + SizeClassPool& operator=(const SizeClassPool&) = delete; + + void init(size_t block_size, size_t slab_page_size); + RocmBuffer* malloc(); + void free(RocmBuffer* buf); + bool in_pool(RocmBuffer* buf) const; + bool grow(); + + size_t block_size() const { + return block_size_; + } + size_t free_count() const { + return free_count_; + } + size_t total_allocated() const { + return backing_pages_.size() * slab_page_size_; + } + size_t free_memory() const { + return free_count_ * block_size_; + } + bool initialized() const { + return block_size_ > 0; + } + + private: + union Block { + Block* next; + RocmBuffer buf; + }; + + size_t block_size_{0}; + size_t slab_page_size_{0}; + bool is_managed_{false}; + + std::vector backing_pages_; + std::vector block_arrays_; + std::vector blocks_per_page_; + + Block* next_free_{nullptr}; + size_t free_count_{0}; + size_t total_blocks_{0}; +}; + +// --------------------------------------------------------------------------- +// SlabAllocator — multi-tier slab allocator for sizes <= 1MB +// --------------------------------------------------------------------------- + +class SlabAllocator { + public: + static constexpr int kNumSizeClasses = 18; + static constexpr size_t kMaxSlabSize = 1 << 20; + + SlabAllocator(); + ~SlabAllocator() = default; + + RocmBuffer* malloc(size_t size); + void free(RocmBuffer* buf); + bool in_pool(RocmBuffer* buf) const; + bool grow(size_t size); + void warmup(); + + size_t total_allocated() const; + size_t free_memory() const; + + static int size_class_index(size_t size); + static size_t round_to_size_class(size_t size); + + private: + SizeClassPool pools_[kNumSizeClasses]; +}; + +// --------------------------------------------------------------------------- +// DecodeArena — deterministic bump allocator for HIP Graph capture +// --------------------------------------------------------------------------- +// During decode, the allocation pattern is fixed: same sizes in the same +// order every step. The arena allocates from a pre-sized contiguous buffer, +// guaranteeing identical pointers on each reset+replay cycle. +// +// Usage: +// arena.begin(estimated_bytes); // allocate backing buffer +// // ... run decode step (allocations go through arena) ... +// arena.reset(); // rewind bump pointer for next step +// // ... replay same step (same pointers) ... +// arena.end(); // release backing buffer + +class DecodeArena { + public: + DecodeArena() = default; + ~DecodeArena(); + + // Allocate the backing buffer and enter arena mode. + bool begin(size_t capacity_bytes); + + // Rewind the bump pointer. Next cycle returns same addresses. + void reset(); + + // Number of descriptors handed out so far (descriptor mark companion to used()). + size_t desc_used() const { + return desc_index_; + } + + // Rewind BOTH the byte bump pointer and the descriptor index to a recorded + // mark (the state right after a captured graph's buffers). The graph region + // [0, byte_mark) / descriptors [0, desc_mark) stays reserved and untouched; + // per-token sampling reuses the region after the mark each cycle. Rewinding + // only bytes (not desc_index_) would grow the descriptor vector unboundedly + // (realloc → dangling pointers); rewinding desc_index_ to 0 would reuse and + // mutate the graph's descriptor objects (corrupting live arrays). + void reset_to(size_t byte_mark, size_t desc_mark) { + offset_ = byte_mark; + desc_index_ = desc_mark; + } + + // Leave arena mode and free the backing buffer. + void end(); + + // Bump-allocate from the arena. Returns nullptr if inactive or exhausted. + RocmBuffer* malloc(size_t size); + + // No-op free (bulk-freed on end()). + void free(RocmBuffer* /*buf*/) {} + + // active() drives the allocator's routing to the arena. When paused, the + // backing stays allocated (so captured-graph buffers remain valid at their + // baked addresses) but NEW allocations fall through to the pool. Used after a + // capture-once graph is built: the graph keeps its arena buffers, while + // per-token sampling allocates from the pool and can't clobber graph buffers. + bool active() const { + return base_ != nullptr && !paused_; + } + void set_paused(bool p) { + paused_ = p; + } + size_t used() const { + return offset_; + } + size_t capacity() const { + return capacity_; + } + + private: + void* base_{nullptr}; + size_t capacity_{0}; + size_t offset_{0}; + bool is_managed_{false}; + bool paused_{false}; + + // Pre-allocated RocmBuffer descriptors (recycled on reset) + std::vector descriptors_; + size_t desc_index_{0}; +}; + +// --------------------------------------------------------------------------- +// RocmAllocator +// --------------------------------------------------------------------------- + +class RocmAllocator : public allocator::Allocator { + public: + Buffer malloc(size_t size) override; + void free(Buffer buffer) override; + size_t size(Buffer buffer) const override; + + // CUDA-style stream-ordered allocation. When the async pool is enabled and a + // real stream is given for a discrete device, allocates GPU-only pool memory + // (hipMallocAsync) freed non-blocking (hipFreeAsync). Otherwise falls back to + // the unified path (== malloc). CPU access to pool buffers is served by the + // existing host-shadow path (device != -1) in Buffer::raw_ptr(). + Buffer malloc_async(size_t size, int device, void* stream); + void free_async(RocmBuffer* buf, void* stream); + + // Discrete GPU: ensure buf has an up-to-date pinned host mirror for CPU reads. + // Keeps the VRAM copy resident (does not free it or flip device to -1). + void ensure_host_shadow(RocmBuffer& buf); + + // Discrete GPU: if buf's host shadow was written by the CPU, copy it back to + // VRAM so kernels (gpu_ptr) see the update. No-op otherwise. + void flush_host_shadow(RocmBuffer& buf); + + size_t get_active_memory() const; + size_t get_peak_memory() const; + void reset_peak_memory(); + size_t get_memory_limit(); + size_t set_memory_limit(size_t limit); + size_t get_cache_memory() const; + size_t set_cache_limit(size_t limit); + void clear_cache(); + + private: + void rocm_free(RocmBuffer* buf); + + RocmAllocator(); + friend RocmAllocator& allocator(); + + std::mutex mutex_; + size_t memory_limit_; + size_t max_pool_size_; + size_t total_memory_{0}; + size_t free_limit_{0}; + BufferCache buffer_cache_; + size_t active_memory_{0}; + size_t peak_memory_{0}; + SlabAllocator slab_allocator_; + + // Per-device hipMemPool + a dedicated free stream for stream-less frees + // (mirrors the CUDA backend). Empty entry => device has no pool support and + // uses the blocking path. + std::vector mem_pools_; + std::vector free_streams_; + + public: + // Arena mode for HIP Graph capture. + // When active, malloc() returns deterministic addresses from the arena. + DecodeArena& arena() { + return arena_; + } + + private: + DecodeArena arena_; +}; + +RocmAllocator& allocator(); + +class CommandEncoder; +// Stream-ordered allocation bound to an encoder's device/stream. Primitives +// call this for their output buffers so transient activations come from the +// device pool (fast, non-blocking free, in-eval reuse) instead of unified mem. +Buffer malloc_async(size_t size, CommandEncoder& encoder); + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/arange.hip b/mlx/backend/rocm/arange.hip new file mode 100644 index 0000000000..944b226090 --- /dev/null +++ b/mlx/backend/rocm/arange.hip @@ -0,0 +1,139 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device/arange.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/primitives.h" + +#include +#include +#include + +namespace mlx::core { + +void Arange::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + + size_t size = out.size(); + int block_size = 256; + int num_blocks = (size + block_size - 1) / block_size; + + switch (out.dtype()) { + case float32: { + float start = static_cast(start_); + float step = static_cast(step_); + encoder.add_kernel_node( + &rocm::arange_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(out), start, step, size); + break; + } + case float64: { + double start = start_; + double step = step_; + encoder.add_kernel_node( + &rocm::arange_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(out), start, step, size); + break; + } + case float16: { + __half start = __float2half(static_cast(start_)); + __half step = __float2half(static_cast(step_)); + encoder.add_kernel_node( + &rocm::arange_kernel<__half>, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr<__half>(out), start, step, size); + break; + } + case bfloat16: { + hip_bfloat16 start = hip_bfloat16(static_cast(start_)); + hip_bfloat16 step = hip_bfloat16(static_cast(step_)); + encoder.add_kernel_node( + &rocm::arange_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(out), start, step, size); + break; + } + case int32: { + int32_t start = static_cast(start_); + int32_t step = static_cast(step_); + encoder.add_kernel_node( + &rocm::arange_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(out), start, step, size); + break; + } + case int64: { + int64_t start = static_cast(start_); + int64_t step = static_cast(step_); + encoder.add_kernel_node( + &rocm::arange_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(out), start, step, size); + break; + } + case uint32: { + uint32_t start = static_cast(start_); + uint32_t step = static_cast(step_); + encoder.add_kernel_node( + &rocm::arange_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(out), start, step, size); + break; + } + case uint64: { + uint64_t start = static_cast(start_); + uint64_t step = static_cast(step_); + encoder.add_kernel_node( + &rocm::arange_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(out), start, step, size); + break; + } + case int8: { + int8_t start = static_cast(start_); + int8_t step = static_cast(step_); + encoder.add_kernel_node( + &rocm::arange_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(out), start, step, size); + break; + } + case int16: { + int16_t start = static_cast(start_); + int16_t step = static_cast(step_); + encoder.add_kernel_node( + &rocm::arange_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(out), start, step, size); + break; + } + case uint8: { + uint8_t start = static_cast(start_); + uint8_t step = static_cast(step_); + encoder.add_kernel_node( + &rocm::arange_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(out), start, step, size); + break; + } + case uint16: { + uint16_t start = static_cast(start_); + uint16_t step = static_cast(step_); + encoder.add_kernel_node( + &rocm::arange_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(out), start, step, size); + break; + } + default: + throw std::runtime_error("Unsupported type for arange"); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/arg_reduce.hip b/mlx/backend/rocm/arg_reduce.hip new file mode 100644 index 0000000000..1f08385ad4 --- /dev/null +++ b/mlx/backend/rocm/arg_reduce.hip @@ -0,0 +1,276 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/device/fp16_math.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include + +#include + +namespace mlx::core { + +namespace rocm { + +template +struct IndexValPair { + uint32_t index; + T val; +}; + +// Type-safe shuffle wrappers for __shfl_xor +template +__device__ __forceinline__ T shfl_xor_arg(T val, int lane_mask) { + return __shfl_xor(val, lane_mask); +} + +// Specialization for __half - __shfl_xor returns float +template <> +__device__ __forceinline__ __half shfl_xor_arg(__half val, int lane_mask) { + return __half(__shfl_xor(__half2float(val), lane_mask)); +} + +// Specialization for hip_bfloat16 +template <> +__device__ __forceinline__ hip_bfloat16 shfl_xor_arg(hip_bfloat16 val, int lane_mask) { + return hip_bfloat16(__shfl_xor(static_cast(val), lane_mask)); +} + +template +struct ArgMin { + __device__ T init() const { + return numeric_limits::max(); + } + + __device__ IndexValPair operator()( + const IndexValPair& best, + const IndexValPair& current) const { + if (best.val > current.val || + (best.val == current.val && best.index > current.index)) { + return current; + } else { + return best; + } + } +}; + +template +struct ArgMax { + __device__ T init() const { + return numeric_limits::lowest(); + } + + __device__ IndexValPair operator()( + const IndexValPair& best, + const IndexValPair& current) const { + if (best.val < current.val || + (best.val == current.val && best.index > current.index)) { + return current; + } else { + return best; + } + } +}; + +// Warp reduce for IndexValPair - uses runtime warp size +template +__device__ IndexValPair warp_reduce_arg(IndexValPair val, Op op) { + // Use warpSize which is a built-in variable in HIP + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + IndexValPair other; + other.index = __shfl_xor(val.index, offset); + other.val = shfl_xor_arg(val.val, offset); + val = op(val, other); + } + return val; +} + +// Block reduce for IndexValPair +template +__device__ IndexValPair block_reduce_arg(IndexValPair val, Op op) { + // Use warpSize built-in for correct behavior on both RDNA (32) and CDNA (64) + constexpr int MAX_WARPS = BLOCK_DIM / 32 + 1; // Conservative estimate + __shared__ IndexValPair shared[MAX_WARPS]; + + int lane = threadIdx.x % warpSize; + int warp_id = threadIdx.x / warpSize; + int num_warps = (BLOCK_DIM + warpSize - 1) / warpSize; + + // Warp-level reduction + val = warp_reduce_arg(val, op); + + // Write reduced value to shared memory + if (lane == 0) { + shared[warp_id] = val; + } + __syncthreads(); + + // Final reduction in first warp + if (warp_id == 0) { + val = (lane < num_warps) ? shared[lane] : IndexValPair{0, op.init()}; + val = warp_reduce_arg(val, op); + } + + return val; +} + +template +__global__ void arg_reduce_general( + const T* in, + uint32_t* out, + size_t size, + const Shape shape, + const Strides in_strides, + const Strides out_strides, + int32_t ndim, + int64_t axis_stride, + int32_t axis_size) { + int64_t index = blockIdx.x + blockIdx.y * gridDim.x; + if (index >= size) { + return; + } + + // Compute input and output indices using elem_to_loc + int64_t in_idx = elem_to_loc(index, shape.data_, in_strides.data_, ndim); + int64_t out_idx = elem_to_loc(index, shape.data_, out_strides.data_, ndim); + in += in_idx; + + Op op; + T init_val = op.init(); + IndexValPair best{0, init_val}; + + // Each thread processes multiple elements + for (int i = threadIdx.x; i < axis_size; i += BLOCK_DIM) { + T val = in[i * axis_stride]; + IndexValPair current{static_cast(i), val}; + best = op(best, current); + } + + // Block reduction + best = block_reduce_arg(best, op); + + if (threadIdx.x == 0) { + out[out_idx] = best.index; + } +} + +} // namespace rocm + +void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + out.set_data(allocator::malloc(out.nbytes())); + auto& s = stream(); + + // Handle scalar case - just output 0 + if (in.ndim() == 0 || in.size() == 1) { + auto& encoder = rocm::get_command_encoder(s); + encoder.set_output_array(out); + encoder.launch_kernel([&](hipStream_t stream) { + (void)hipMemsetAsync(gpu_ptr(out), 0, sizeof(uint32_t), stream); + }); + return; + } + + // Prepare the shapes, strides and axis arguments. + Shape shape = remove_index(in.shape(), axis_); + Strides in_strides = remove_index(in.strides(), axis_); + Strides out_strides = out.ndim() == in.ndim() + ? remove_index(out.strides(), axis_) + : out.strides(); + int64_t axis_stride = in.strides()[axis_]; + int32_t axis_size = in.shape()[axis_]; + int32_t ndim = shape.size(); + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + constexpr int BLOCK_DIM = 256; + dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); + + // Use const_param to pass shape and strides by value (like CUDA) + auto shape_param = const_param(shape); + auto in_strides_param = const_param(in_strides); + auto out_strides_param = const_param(out_strides); + + size_t out_size = out.size(); + switch (in.dtype()) { + case float32: + if (reduce_type_ == ArgReduce::ArgMax) { + encoder.add_kernel_node( + &rocm::arg_reduce_general, BLOCK_DIM, 4>, + num_blocks, dim3(BLOCK_DIM), 0, + gpu_ptr(in), gpu_ptr(out), out_size, + shape_param, in_strides_param, out_strides_param, + ndim, axis_stride, axis_size); + } else { + encoder.add_kernel_node( + &rocm::arg_reduce_general, BLOCK_DIM, 4>, + num_blocks, dim3(BLOCK_DIM), 0, + gpu_ptr(in), gpu_ptr(out), out_size, + shape_param, in_strides_param, out_strides_param, + ndim, axis_stride, axis_size); + } + break; + case int32: + if (reduce_type_ == ArgReduce::ArgMax) { + encoder.add_kernel_node( + &rocm::arg_reduce_general, BLOCK_DIM, 4>, + num_blocks, dim3(BLOCK_DIM), 0, + gpu_ptr(in), gpu_ptr(out), out_size, + shape_param, in_strides_param, out_strides_param, + ndim, axis_stride, axis_size); + } else { + encoder.add_kernel_node( + &rocm::arg_reduce_general, BLOCK_DIM, 4>, + num_blocks, dim3(BLOCK_DIM), 0, + gpu_ptr(in), gpu_ptr(out), out_size, + shape_param, in_strides_param, out_strides_param, + ndim, axis_stride, axis_size); + } + break; + case float16: + if (reduce_type_ == ArgReduce::ArgMax) { + encoder.add_kernel_node( + &rocm::arg_reduce_general<__half, rocm::ArgMax<__half>, BLOCK_DIM, 4>, + num_blocks, dim3(BLOCK_DIM), 0, + gpu_ptr<__half>(in), gpu_ptr(out), out_size, + shape_param, in_strides_param, out_strides_param, + ndim, axis_stride, axis_size); + } else { + encoder.add_kernel_node( + &rocm::arg_reduce_general<__half, rocm::ArgMin<__half>, BLOCK_DIM, 4>, + num_blocks, dim3(BLOCK_DIM), 0, + gpu_ptr<__half>(in), gpu_ptr(out), out_size, + shape_param, in_strides_param, out_strides_param, + ndim, axis_stride, axis_size); + } + break; + case bfloat16: + if (reduce_type_ == ArgReduce::ArgMax) { + encoder.add_kernel_node( + &rocm::arg_reduce_general, BLOCK_DIM, 4>, + num_blocks, dim3(BLOCK_DIM), 0, + gpu_ptr(in), gpu_ptr(out), out_size, + shape_param, in_strides_param, out_strides_param, + ndim, axis_stride, axis_size); + } else { + encoder.add_kernel_node( + &rocm::arg_reduce_general, BLOCK_DIM, 4>, + num_blocks, dim3(BLOCK_DIM), 0, + gpu_ptr(in), gpu_ptr(out), out_size, + shape_param, in_strides_param, out_strides_param, + ndim, axis_stride, axis_size); + } + break; + default: + throw std::runtime_error("Unsupported type for ArgReduce"); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/bin2h.cmake b/mlx/backend/rocm/bin2h.cmake new file mode 100644 index 0000000000..1766b27c92 --- /dev/null +++ b/mlx/backend/rocm/bin2h.cmake @@ -0,0 +1,47 @@ +# Copyright © 2025 Apple Inc. + +# Script to embed kernel source files as header for JIT compilation + +set(MLX_OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/gen/rocm_jit_sources.h") +set(MLX_KERNEL_HEADER + "#pragma once\n\n#include \n#include \n\nnamespace mlx::core::rocm {\n\n" +) +set(MLX_KERNEL_FOOTER "\n} // namespace mlx::core::rocm\n") + +# Create output directory +get_filename_component(MLX_OUTPUT_DIR ${MLX_OUTPUT_FILE} DIRECTORY) +file(MAKE_DIRECTORY ${MLX_OUTPUT_DIR}) + +# Write header +file(WRITE ${MLX_OUTPUT_FILE} ${MLX_KERNEL_HEADER}) + +# Process JIT sources +string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES}) + +set(MLX_SOURCE_MAP + "const std::unordered_map kernel_sources = {\n") + +foreach(source IN LISTS MLX_JIT_SOURCES_LIST) + set(source_file "${MLX_SOURCE_ROOT}/${source}") + if(EXISTS ${source_file}) + # Read source file + file(READ ${source_file} source_content) + + # Escape content for C++ string literal + string(REPLACE "\\" "\\\\" source_content "${source_content}") + string(REPLACE "\"" "\\\"" source_content "${source_content}") + string(REPLACE "\n" "\\n\"\n\"" source_content "${source_content}") + + # Add to map + set(MLX_SOURCE_MAP + "${MLX_SOURCE_MAP} {\"${source}\", \"${source_content}\"},\n") + endif() +endforeach() + +set(MLX_SOURCE_MAP "${MLX_SOURCE_MAP}};\n") + +# Write source map +file(APPEND ${MLX_OUTPUT_FILE} ${MLX_SOURCE_MAP}) + +# Write footer +file(APPEND ${MLX_OUTPUT_FILE} ${MLX_KERNEL_FOOTER}) diff --git a/mlx/backend/rocm/binary.hip b/mlx/backend/rocm/binary.hip new file mode 100644 index 0000000000..fd4de2fe26 --- /dev/null +++ b/mlx/backend/rocm/binary.hip @@ -0,0 +1,397 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/binary.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/device/binary_ops.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace rocm { + +template +__global__ void binary_ss(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + IdxT stride = blockDim.x * gridDim.x; + + for (IdxT i = index * N_READS; i < size; i += stride * N_READS) { + if (i + N_READS <= size) { + #pragma unroll + for (int j = 0; j < N_READS; ++j) { + out[i + j] = Op{}(a[0], b[0]); + } + } else { + for (IdxT j = i; j < size; ++j) { + out[j] = Op{}(a[0], b[0]); + } + } + } +} + +template +__global__ void binary_sv(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + IdxT stride = blockDim.x * gridDim.x; + + for (IdxT i = index * N_READS; i < size; i += stride * N_READS) { + if (i + N_READS <= size) { + #pragma unroll + for (int j = 0; j < N_READS; ++j) { + out[i + j] = Op{}(a[0], b[i + j]); + } + } else { + for (IdxT j = i; j < size; ++j) { + out[j] = Op{}(a[0], b[j]); + } + } + } +} + +template +__global__ void binary_vs(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + IdxT stride = blockDim.x * gridDim.x; + + for (IdxT i = index * N_READS; i < size; i += stride * N_READS) { + if (i + N_READS <= size) { + #pragma unroll + for (int j = 0; j < N_READS; ++j) { + out[i + j] = Op{}(a[i + j], b[0]); + } + } else { + for (IdxT j = i; j < size; ++j) { + out[j] = Op{}(a[j], b[0]); + } + } + } +} + +template +__global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + IdxT stride = blockDim.x * gridDim.x; + + for (IdxT i = index * N_READS; i < size; i += stride * N_READS) { + if (i + N_READS <= size) { + #pragma unroll + for (int j = 0; j < N_READS; ++j) { + out[i + j] = Op{}(a[i + j], b[i + j]); + } + } else { + for (IdxT j = i; j < size; ++j) { + out[j] = Op{}(a[j], b[j]); + } + } + } +} + +template +__global__ void binary_g( + const In* a, + const In* b, + Out* out, + IdxT size, + hip_array shape, + hip_array a_strides, + hip_array b_strides, + int ndim) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= size) { + return; + } + + // Compute offsets using elem_to_loc style + IdxT a_idx = 0, b_idx = 0; + IdxT tmp = index; + for (int i = ndim - 1; i >= 0 && tmp > 0; --i) { + IdxT coord = tmp % shape[i]; + a_idx += coord * a_strides[i]; + b_idx += coord * b_strides[i]; + tmp /= shape[i]; + } + + out[index] = Op{}(a[a_idx], b[b_idx]); +} + +template +constexpr bool supports_binary_op() { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v; + } else if constexpr (std::is_same_v || std::is_same_v) { + return std::is_same_v && !is_complex_v; + } else if constexpr (std::is_same_v || std::is_same_v) { + return std::is_same_v; + } else if constexpr (std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && !is_complex_v; + } else if constexpr (std::is_same_v || std::is_same_v) { + return std::is_same_v && std::is_same_v; + } else if constexpr (std::is_same_v) { + return std::is_same_v; + } else if constexpr (std::is_same_v) { + return std::is_same_v && is_inexact_v; + } else if constexpr (std::is_same_v) { + return std::is_same_v && !is_complex_v && + (std::is_floating_point_v || std::is_same_v || std::is_same_v); + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && std::is_integral_v; + } else if constexpr (std::is_same_v || std::is_same_v) { + return std::is_same_v && std::is_integral_v && + !std::is_same_v; + } else { + return false; + } +} + +} // namespace rocm + +namespace rocm { + +// Helper to launch general binary kernel +template +void launch_binary_general( + CommandEncoder& encoder, + const array& a, + const array& b, + array& out, + const ShapeType& shape, + const StridesVecType& strides_vec) { + auto& strides_a = strides_vec[0]; + auto& strides_b = strides_vec[1]; + int ndim = shape.size(); + size_t data_size = out.size(); + + hip_array shape_arg = {}; + hip_array strides_a_arg = {}; + hip_array strides_b_arg = {}; + for (int i = 0; i < ndim; i++) { + shape_arg.data_[i] = static_cast(shape[i]); + strides_a_arg.data_[i] = strides_a[i]; + strides_b_arg.data_[i] = strides_b[i]; + } + + int block_size = 256; + int num_blocks = (data_size + block_size - 1) / block_size; + + int64_t size_arg = static_cast(data_size); + encoder.add_kernel_node( + &binary_g, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), + size_arg, + shape_arg, + strides_a_arg, + strides_b_arg, + ndim); +} + +} // namespace rocm + +template +void binary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s) { + assert(inputs.size() > 1); + const auto& a = inputs[0]; + const auto& b = inputs[1]; + if (out.size() == 0) { + return; + } + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + + auto bopt = get_binary_op_type(a, b); + bool large = out.data_size() > UINT32_MAX; + + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + using InType = hip_type_t; + using OutType = hip_type_t; + + if constexpr (rocm::supports_binary_op()) { + if (bopt == BinaryOpType::General) { + auto [shape, strides_vec] = collapse_contiguous_dims(a, b, out); + rocm::launch_binary_general(encoder, a, b, out, shape, strides_vec); + } else { + constexpr int N_READS = 4; + int block_size = 256; + auto size = out.data_size(); + int num_blocks = (size + block_size * N_READS - 1) / (block_size * N_READS); + num_blocks = std::max(1, std::min(num_blocks, 65535)); + + int64_t size_large = static_cast(size); + uint32_t size_small = static_cast(size); + if (bopt == BinaryOpType::ScalarScalar) { + if (large) { + encoder.add_kernel_node( + &rocm::binary_ss, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), + size_large); + } else { + encoder.add_kernel_node( + &rocm::binary_ss, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), + size_small); + } + } else if (bopt == BinaryOpType::ScalarVector) { + if (large) { + encoder.add_kernel_node( + &rocm::binary_sv, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), + size_large); + } else { + encoder.add_kernel_node( + &rocm::binary_sv, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), + size_small); + } + } else if (bopt == BinaryOpType::VectorScalar) { + if (large) { + encoder.add_kernel_node( + &rocm::binary_vs, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), + size_large); + } else { + encoder.add_kernel_node( + &rocm::binary_vs, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), + size_small); + } + } else { + if (large) { + encoder.add_kernel_node( + &rocm::binary_vv, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), + size_large); + } else { + encoder.add_kernel_node( + &rocm::binary_vv, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out), + size_small); + } + } + } + } else { + throw std::runtime_error( + std::string("Unsupported type for binary op ") + op); + } + }); + }); +} + +template +void binary_op_gpu( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + binary_op_gpu_inplace(inputs, out, op, s); +} + +#define BINARY_GPU(prim) \ + void prim::eval_gpu(const std::vector& inputs, array& out) { \ + auto& s = out.primitive().stream(); \ + binary_op_gpu(inputs, out, name(), s); \ + } + +BINARY_GPU(Add) +BINARY_GPU(ArcTan2) +BINARY_GPU(Divide) +BINARY_GPU(Greater) +BINARY_GPU(GreaterEqual) +BINARY_GPU(Less) +BINARY_GPU(LessEqual) +BINARY_GPU(LogAddExp) +BINARY_GPU(LogicalAnd) +BINARY_GPU(LogicalOr) +BINARY_GPU(Maximum) +BINARY_GPU(Minimum) +BINARY_GPU(Multiply) +BINARY_GPU(NotEqual) +BINARY_GPU(Power) +BINARY_GPU(Remainder) +BINARY_GPU(Subtract) + +#undef BINARY_GPU + +void Equal::eval_gpu(const std::vector& inputs, array& out) { + auto& s = out.primitive().stream(); + if (equal_nan_) { + binary_op_gpu(inputs, out, name(), s); + } else { + binary_op_gpu(inputs, out, name(), s); + } +} + +void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { + auto& s = out.primitive().stream(); + switch (op_) { + case BitwiseBinary::And: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::Or: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::Xor: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::LeftShift: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::RightShift: + binary_op_gpu(inputs, out, name(), s); + break; + } +} + +void DivMod::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + // DivMod outputs two arrays: quotient and remainder + auto& s = outputs[0].primitive().stream(); + auto& a = inputs[0]; + auto& b = inputs[1]; + + // Set output data + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, outputs[0], bopt); + set_binary_op_output_data(a, b, outputs[1], bopt); + + // Compute floor divide for first output + binary_op_gpu_inplace(inputs, outputs[0], "FloorDivide", s); + + // Compute remainder for second output + binary_op_gpu_inplace(inputs, outputs[1], "Remainder", s); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/binary_two.hip b/mlx/backend/rocm/binary_two.hip new file mode 100644 index 0000000000..9a908b541d --- /dev/null +++ b/mlx/backend/rocm/binary_two.hip @@ -0,0 +1,245 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/binary.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device/binary_ops.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { + +namespace rocm { + +// Use DivMod from binary_ops.hpp + +template +__global__ void binary_two_ss( + const In* a, + const In* b, + Out* out_a, + Out* out_b, + IdxT size) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + IdxT stride = blockDim.x * gridDim.x; + + Op op; + for (IdxT i = index * N_READS; i < size; i += stride * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && (i + j) < size; ++j) { + auto result = op(a[0], b[0]); + out_a[i + j] = result[0]; + out_b[i + j] = result[1]; + } + } +} + +template +__global__ void binary_two_sv( + const In* a, + const In* b, + Out* out_a, + Out* out_b, + IdxT size) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + IdxT stride = blockDim.x * gridDim.x; + + Op op; + for (IdxT i = index * N_READS; i < size; i += stride * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && (i + j) < size; ++j) { + auto result = op(a[0], b[i + j]); + out_a[i + j] = result[0]; + out_b[i + j] = result[1]; + } + } +} + +template +__global__ void binary_two_vs( + const In* a, + const In* b, + Out* out_a, + Out* out_b, + IdxT size) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + IdxT stride = blockDim.x * gridDim.x; + + Op op; + for (IdxT i = index * N_READS; i < size; i += stride * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && (i + j) < size; ++j) { + auto result = op(a[i + j], b[0]); + out_a[i + j] = result[0]; + out_b[i + j] = result[1]; + } + } +} + +template +__global__ void binary_two_vv( + const In* a, + const In* b, + Out* out_a, + Out* out_b, + IdxT size) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + IdxT stride = blockDim.x * gridDim.x; + + Op op; + for (IdxT i = index * N_READS; i < size; i += stride * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && (i + j) < size; ++j) { + auto result = op(a[i + j], b[i + j]); + out_a[i + j] = result[0]; + out_b[i + j] = result[1]; + } + } +} + +template +__global__ void binary_two_g( + const In* a, + const In* b, + Out* out_a, + Out* out_b, + IdxT size, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides, + int ndim) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= size) return; + + // Compute input indices + int64_t a_idx = 0; + int64_t b_idx = 0; + IdxT tmp = index; + for (int i = ndim - 1; i >= 0; --i) { + int coord = tmp % shape[i]; + a_idx += coord * a_strides[i]; + b_idx += coord * b_strides[i]; + tmp /= shape[i]; + } + + Op op; + auto result = op(a[a_idx], b[b_idx]); + out_a[index] = result[0]; + out_b[index] = result[1]; +} + +template +constexpr bool supports_binary_two_op() { + if constexpr (std::is_same_v) { + return std::is_same_v && (std::is_integral_v || std::is_floating_point_v); + } + return false; +} + +} // namespace rocm + +template +void binary_two_op_gpu_inplace( + const std::vector& inputs, + std::vector& outputs, + const char* op_name, + const Stream& s) { + assert(inputs.size() > 1); + const auto& a = inputs[0]; + const auto& b = inputs[1]; + auto& out_a = outputs[0]; + auto& out_b = outputs[1]; + auto bopt = get_binary_op_type(a, b); + auto& encoder = rocm::get_command_encoder(s); + + set_binary_op_output_data( + a, b, out_a, bopt, [&](auto n) { return mlx::core::rocm::malloc_async(n, encoder); }); + set_binary_op_output_data( + a, b, out_b, bopt, [&](auto n) { return mlx::core::rocm::malloc_async(n, encoder); }); + + if (out_a.size() == 0) { + return; + } + + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out_a); + encoder.set_output_array(out_b); + + constexpr int N_READS = 4; + int block_size = 256; + size_t size = out_a.data_size(); + int num_blocks = std::min((size + block_size * N_READS - 1) / (block_size * N_READS), (size_t)65535); + + int64_t size_arg = static_cast(size); + #define LAUNCH_BINARY_TWO(T, OP_TYPE) \ + switch (bopt) { \ + case BinaryOpType::ScalarScalar: \ + encoder.add_kernel_node( \ + &rocm::binary_two_ss, \ + dim3(num_blocks), dim3(block_size), 0, \ + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out_a), gpu_ptr(out_b), \ + size_arg); \ + break; \ + case BinaryOpType::ScalarVector: \ + encoder.add_kernel_node( \ + &rocm::binary_two_sv, \ + dim3(num_blocks), dim3(block_size), 0, \ + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out_a), gpu_ptr(out_b), \ + size_arg); \ + break; \ + case BinaryOpType::VectorScalar: \ + encoder.add_kernel_node( \ + &rocm::binary_two_vs, \ + dim3(num_blocks), dim3(block_size), 0, \ + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out_a), gpu_ptr(out_b), \ + size_arg); \ + break; \ + case BinaryOpType::VectorVector: \ + encoder.add_kernel_node( \ + &rocm::binary_two_vv, \ + dim3(num_blocks), dim3(block_size), 0, \ + gpu_ptr(a), gpu_ptr(b), gpu_ptr(out_a), gpu_ptr(out_b), \ + size_arg); \ + break; \ + default: \ + throw std::runtime_error("Unsupported binary op type for binary_two"); \ + } + + if constexpr (std::is_same_v) { + switch (a.dtype()) { + case float32: LAUNCH_BINARY_TWO(float, DivMod); break; + case int32: LAUNCH_BINARY_TWO(int32_t, DivMod); break; + case int64: LAUNCH_BINARY_TWO(int64_t, DivMod); break; + default: + throw std::runtime_error("Unsupported type for DivMod"); + } + } + #undef LAUNCH_BINARY_TWO +} + +template +void binary_two_op_gpu( + const std::vector& inputs, + std::vector& outputs, + const char* op_name, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, outputs[0], bopt); + set_binary_op_output_data(a, b, outputs[1], bopt); + binary_two_op_gpu_inplace(inputs, outputs, op_name, s); +} + +void DivMod::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = outputs[0].primitive().stream(); + binary_two_op_gpu(inputs, outputs, name(), s); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/compiled.cpp b/mlx/backend/rocm/compiled.cpp new file mode 100644 index 0000000000..db0b67560e --- /dev/null +++ b/mlx/backend/rocm/compiled.cpp @@ -0,0 +1,851 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/jit_module.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/graph_utils.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { + +namespace rocm { + +struct FusedKernelBuilder { + std::string os; + const std::string& kernel_name; + const std::vector& inputs; + const std::vector& outputs; + const std::vector& tape; + const std::function& is_constant; + + void build(const char* name, bool contiguous) { + NodeNamer namer; + + // Function parameters. + std::vector params; + for (size_t i = 0; i < inputs.size(); ++i) { + if (is_constant(i)) { + continue; + } + const auto& x = inputs[i]; + const std::string& xname = namer.get_name(x); + params.push_back( + std::string("const ") + dtype_to_hip_type(x.dtype()) + "* " + xname); + if (!is_scalar(x) && !contiguous) { + params.push_back( + std::string("const hip::std::array ") + xname + + "_strides"); + } + } + for (const auto& x : outputs) { + params.push_back( + std::string(dtype_to_hip_type(x.dtype())) + "* " + namer.get_name(x)); + } + if (!contiguous) { + params.push_back("const hip::std::array shape"); + } + params.push_back("IdxT size"); + + // Build function signature. + if (contiguous) { + os += "template \n"; + } else { + os += + "template \n"; + } + os += "__global__ void " + kernel_name + name + "(\n"; + for (size_t i = 0; i < params.size(); ++i) { + os += " "; + os += params[i]; + if (i != params.size() - 1) { + os += ",\n"; + } + } + os += ") {\n"; + + // Index. For non contiguous kernels we create a separate index + // variable per variable otherwise everyone uses `index`. + os += + " IdxT index = (blockIdx.x * blockDim.x + threadIdx.x) * work_per_thread;\n" + " if (index >= size) {\n" + " return;\n" + " }\n"; + if (!contiguous) { + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& x = inputs[i]; + const std::string& xname = namer.get_name(x); + if (is_scalar(x) || is_constant(i)) { + continue; + } + os += " IdxT " + xname + "_idx = 0;\n"; + } + os += " {\n"; + os += " IdxT loc = index;\n"; + os += + " #pragma unroll\n" + " for (int i = NDIM - 1; i >= 0; i--) {\n"; + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& x = inputs[i]; + const std::string& xname = namer.get_name(x); + if (is_scalar(x) || is_constant(i)) { + continue; + } + os += " " + xname + "_idx += (loc \% shape[i]) * IdxT(" + xname + + "_strides[i]);\n"; + } + os += + " loc /= shape[i];\n" + " }\n" + " }\n"; + } + + // Work loop + if (!contiguous) { + os += + "\n" + " for (int i = 0; i < work_per_thread && index + i < size; i++) {\n"; + } else { + os += + "\n" + " #pragma unroll\n" + " for (int i = 0; i < work_per_thread; i++) {\n" + " if (index + i >= size) break;\n"; + } + + // Read inputs. + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& x = inputs[i]; + const std::string& xname = namer.get_name(x); + std::string type = dtype_to_hip_type(x.dtype()); + std::string value; + if (is_constant(i)) { + std::ostringstream ss; + print_constant(ss, x); + value = std::string("static_cast<") + type + ">(" + ss.str() + ")"; + } else if (is_scalar(x)) { + value = xname + "[0]"; + } else if (contiguous) { + value = xname + "[index + i]"; + } else { + value = xname + "[" + xname + "_idx]"; + } + os += + std::string(" ") + type + " tmp_" + xname + " = " + value + ";\n"; + } + + // Write tape. + for (const auto& x : tape) { + const std::string& xname = namer.get_name(x); + std::string type = dtype_to_hip_type(x.dtype()); + std::string value; + if (is_static_cast(x.primitive())) { + value = std::string("static_cast<") + type + ">(tmp_" + + namer.get_name(x.inputs()[0]) + ")"; + } else { + value = x.primitive().name(); + value += "{}("; + for (size_t i = 0; i < x.inputs().size() - 1; ++i) { + value += "tmp_" + namer.get_name(x.inputs()[i]) + ", "; + } + value += "tmp_" + namer.get_name(x.inputs().back()) + ")"; + } + os += + std::string(" ") + type + " tmp_" + xname + " = " + value + ";\n"; + } + + // Write output. + for (const auto& x : outputs) { + std::string xname = namer.get_name(x); + if (contiguous) { + os += + std::string(" ") + xname + "[index + i] = tmp_" + xname + ";\n"; + } else { + os += + std::string(" ") + xname + "[index + i] = tmp_" + xname + ";\n"; + } + } + + // End of work loop + if (!contiguous) { + os += "\n"; + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& x = inputs[i]; + const std::string& xname = namer.get_name(x); + if (is_scalar(x) || is_constant(i)) { + continue; + } + os += std::string(" ") + xname + "_idx += " + xname + + "_strides[NDIM - 1];\n"; + } + } + os += " }\n"; + + os += "}\n"; + } +}; + +} // namespace rocm + +constexpr const char* g_jit_includes = R"( +#include +#include +#include + +// Standard type definitions for JIT compilation +using uint32_t = unsigned int; +using int32_t = signed int; +using uint64_t = unsigned long long; +using int64_t = signed long long; +using uint16_t = unsigned short; +using int16_t = signed short; +using uint8_t = unsigned char; +using int8_t = signed char; +using size_t = unsigned long; + +// Simple array type for JIT compilation (hip/std/array not available in hiprtc) +namespace hip { +namespace std { +template +struct array { + T data_[N]; + __device__ T& operator[](int i) { return data_[i]; } + __device__ const T& operator[](int i) const { return data_[i]; } +}; + +template +struct numeric_limits; + +template <> +struct numeric_limits { + __device__ static float infinity() { return __int_as_float(0x7f800000); } +}; +} // namespace std +} // namespace hip + +// Math function overloads for bfloat16 and half types +// HIP doesn't provide native math functions for these types, +// so we convert to float, compute, and convert back. + +__device__ inline hip_bfloat16 abs(hip_bfloat16 x) { + return hip_bfloat16(fabsf(static_cast(x))); +} +__device__ inline __half abs(__half x) { + return __float2half(fabsf(__half2float(x))); +} + +__device__ inline hip_bfloat16 exp(hip_bfloat16 x) { + return hip_bfloat16(expf(static_cast(x))); +} +__device__ inline __half exp(__half x) { + return __float2half(expf(__half2float(x))); +} + +__device__ inline hip_bfloat16 log(hip_bfloat16 x) { + return hip_bfloat16(logf(static_cast(x))); +} +__device__ inline __half log(__half x) { + return __float2half(logf(__half2float(x))); +} + +__device__ inline hip_bfloat16 sqrt(hip_bfloat16 x) { + return hip_bfloat16(sqrtf(static_cast(x))); +} +__device__ inline __half sqrt(__half x) { + return __float2half(sqrtf(__half2float(x))); +} + +__device__ inline hip_bfloat16 rsqrt(hip_bfloat16 x) { + return hip_bfloat16(rsqrtf(static_cast(x))); +} +__device__ inline __half rsqrt(__half x) { + return __float2half(rsqrtf(__half2float(x))); +} + +__device__ inline hip_bfloat16 sin(hip_bfloat16 x) { + return hip_bfloat16(sinf(static_cast(x))); +} +__device__ inline __half sin(__half x) { + return __float2half(sinf(__half2float(x))); +} + +__device__ inline hip_bfloat16 cos(hip_bfloat16 x) { + return hip_bfloat16(cosf(static_cast(x))); +} +__device__ inline __half cos(__half x) { + return __float2half(cosf(__half2float(x))); +} + +__device__ inline hip_bfloat16 tan(hip_bfloat16 x) { + return hip_bfloat16(tanf(static_cast(x))); +} +__device__ inline __half tan(__half x) { + return __float2half(tanf(__half2float(x))); +} + +__device__ inline hip_bfloat16 sinh(hip_bfloat16 x) { + return hip_bfloat16(sinhf(static_cast(x))); +} +__device__ inline __half sinh(__half x) { + return __float2half(sinhf(__half2float(x))); +} + +__device__ inline hip_bfloat16 cosh(hip_bfloat16 x) { + return hip_bfloat16(coshf(static_cast(x))); +} +__device__ inline __half cosh(__half x) { + return __float2half(coshf(__half2float(x))); +} + +__device__ inline hip_bfloat16 tanh(hip_bfloat16 x) { + return hip_bfloat16(tanhf(static_cast(x))); +} +__device__ inline __half tanh(__half x) { + return __float2half(tanhf(__half2float(x))); +} + +__device__ inline hip_bfloat16 asin(hip_bfloat16 x) { + return hip_bfloat16(asinf(static_cast(x))); +} +__device__ inline __half asin(__half x) { + return __float2half(asinf(__half2float(x))); +} + +__device__ inline hip_bfloat16 acos(hip_bfloat16 x) { + return hip_bfloat16(acosf(static_cast(x))); +} +__device__ inline __half acos(__half x) { + return __float2half(acosf(__half2float(x))); +} + +__device__ inline hip_bfloat16 atan(hip_bfloat16 x) { + return hip_bfloat16(atanf(static_cast(x))); +} +__device__ inline __half atan(__half x) { + return __float2half(atanf(__half2float(x))); +} + +__device__ inline hip_bfloat16 asinh(hip_bfloat16 x) { + return hip_bfloat16(asinhf(static_cast(x))); +} +__device__ inline __half asinh(__half x) { + return __float2half(asinhf(__half2float(x))); +} + +__device__ inline hip_bfloat16 acosh(hip_bfloat16 x) { + return hip_bfloat16(acoshf(static_cast(x))); +} +__device__ inline __half acosh(__half x) { + return __float2half(acoshf(__half2float(x))); +} + +__device__ inline hip_bfloat16 atanh(hip_bfloat16 x) { + return hip_bfloat16(atanhf(static_cast(x))); +} +__device__ inline __half atanh(__half x) { + return __float2half(atanhf(__half2float(x))); +} + +__device__ inline hip_bfloat16 ceil(hip_bfloat16 x) { + return hip_bfloat16(ceilf(static_cast(x))); +} +__device__ inline __half ceil(__half x) { + return __float2half(ceilf(__half2float(x))); +} + +__device__ inline hip_bfloat16 floor(hip_bfloat16 x) { + return hip_bfloat16(floorf(static_cast(x))); +} +__device__ inline __half floor(__half x) { + return __float2half(floorf(__half2float(x))); +} + +__device__ inline hip_bfloat16 rint(hip_bfloat16 x) { + return hip_bfloat16(rintf(static_cast(x))); +} +__device__ inline __half rint(__half x) { + return __float2half(rintf(__half2float(x))); +} + +__device__ inline hip_bfloat16 log2(hip_bfloat16 x) { + return hip_bfloat16(log2f(static_cast(x))); +} +__device__ inline __half log2(__half x) { + return __float2half(log2f(__half2float(x))); +} + +__device__ inline hip_bfloat16 log10(hip_bfloat16 x) { + return hip_bfloat16(log10f(static_cast(x))); +} +__device__ inline __half log10(__half x) { + return __float2half(log10f(__half2float(x))); +} + +__device__ inline hip_bfloat16 log1pf(hip_bfloat16 x) { + return hip_bfloat16(::log1pf(static_cast(x))); +} +__device__ inline __half log1pf(__half x) { + return __float2half(::log1pf(__half2float(x))); +} + +__device__ inline hip_bfloat16 expm1f(hip_bfloat16 x) { + return hip_bfloat16(::expm1f(static_cast(x))); +} +__device__ inline __half expm1f(__half x) { + return __float2half(::expm1f(__half2float(x))); +} + +__device__ inline hip_bfloat16 erff(hip_bfloat16 x) { + return hip_bfloat16(::erff(static_cast(x))); +} +__device__ inline __half erff(__half x) { + return __float2half(::erff(__half2float(x))); +} + +__device__ inline hip_bfloat16 erfinvf(hip_bfloat16 x) { + return hip_bfloat16(::erfinvf(static_cast(x))); +} +__device__ inline __half erfinvf(__half x) { + return __float2half(::erfinvf(__half2float(x))); +} + +__device__ inline hip_bfloat16 powf(hip_bfloat16 base, hip_bfloat16 exp) { + return hip_bfloat16(::powf(static_cast(base), static_cast(exp))); +} +__device__ inline __half powf(__half base, __half exp) { + return __float2half(::powf(__half2float(base), __half2float(exp))); +} + +__device__ inline hip_bfloat16 fmodf(hip_bfloat16 x, hip_bfloat16 y) { + return hip_bfloat16(::fmodf(static_cast(x), static_cast(y))); +} +__device__ inline __half fmodf(__half x, __half y) { + return __float2half(::fmodf(__half2float(x), __half2float(y))); +} + +__device__ inline hip_bfloat16 truncf(hip_bfloat16 x) { + return hip_bfloat16(::truncf(static_cast(x))); +} +__device__ inline __half truncf(__half x) { + return __float2half(::truncf(__half2float(x))); +} + +__device__ inline hip_bfloat16 atan2f(hip_bfloat16 y, hip_bfloat16 x) { + return hip_bfloat16(::atan2f(static_cast(y), static_cast(x))); +} +__device__ inline __half atan2f(__half y, __half x) { + return __float2half(::atan2f(__half2float(y), __half2float(x))); +} + +// Include device operations +namespace mlx::core::rocm { + +// Binary ops — promote half/bfloat16 through float to avoid precision loss +// that compounds across 28-36 transformer layers in LLM inference. +struct Add { + template + __device__ T operator()(T x, T y) { + return T(static_cast(x) + static_cast(y)); + } +}; + +struct Subtract { + template + __device__ T operator()(T x, T y) { + return T(static_cast(x) - static_cast(y)); + } +}; + +struct Multiply { + template + __device__ T operator()(T x, T y) { + return T(static_cast(x) * static_cast(y)); + } +}; + +struct Divide { + template + __device__ T operator()(T x, T y) { + return T(static_cast(x) / static_cast(y)); + } +}; + +struct Maximum { + template + __device__ T operator()(T x, T y) { return x > y ? x : y; } +}; + +struct Minimum { + template + __device__ T operator()(T x, T y) { return x < y ? x : y; } +}; + +struct Power { + template + __device__ T operator()(T base, T exp) { + return T(powf(static_cast(base), static_cast(exp))); + } +}; + +struct Equal { + template + __device__ bool operator()(T x, T y) { return x == y; } +}; + +struct NotEqual { + template + __device__ bool operator()(T x, T y) { return x != y; } +}; + +struct Greater { + template + __device__ bool operator()(T x, T y) { return x > y; } +}; + +struct GreaterEqual { + template + __device__ bool operator()(T x, T y) { return x >= y; } +}; + +struct Less { + template + __device__ bool operator()(T x, T y) { return x < y; } +}; + +struct LessEqual { + template + __device__ bool operator()(T x, T y) { return x <= y; } +}; + +struct LogicalAnd { + template + __device__ bool operator()(T x, T y) { return x && y; } +}; + +struct LogicalOr { + template + __device__ bool operator()(T x, T y) { return x || y; } +}; + +struct ArcTan2 { + template + __device__ T operator()(T y, T x) { + return T(atan2f(static_cast(y), static_cast(x))); + } +}; + +struct Remainder { + template + __device__ T operator()(T x, T y) { + return T(fmodf(static_cast(x), static_cast(y))); + } +}; + +struct FloorDivide { + template + __device__ T operator()(T x, T y) { + return T(truncf(static_cast(x) / static_cast(y))); + } +}; + +struct LogAddExp { + __device__ hip_bfloat16 operator()(hip_bfloat16 x, hip_bfloat16 y) { + float fx = static_cast(x); + float fy = static_cast(y); + float maxval = fx > fy ? fx : fy; + float minval = fx > fy ? fy : fx; + return hip_bfloat16(maxval + log1pf(expf(minval - maxval))); + } + + __device__ __half operator()(__half x, __half y) { + float fx = __half2float(x); + float fy = __half2float(y); + float maxval = fx > fy ? fx : fy; + float minval = fx > fy ? fy : fx; + return __float2half(maxval + log1pf(expf(minval - maxval))); + } + + template + __device__ T operator()(T x, T y) { + float fx = static_cast(x); + float fy = static_cast(y); + float maxval = fx > fy ? fx : fy; + float minval = fx > fy ? fy : fx; + return T(maxval + log1pf(expf(minval - maxval))); + } +}; + +struct BitwiseAnd { + template + __device__ T operator()(T x, T y) { return x & y; } +}; + +struct BitwiseOr { + template + __device__ T operator()(T x, T y) { return x | y; } +}; + +struct BitwiseXor { + template + __device__ T operator()(T x, T y) { return x ^ y; } +}; + +struct LeftShift { + template + __device__ T operator()(T x, T y) { return x << y; } +}; + +struct RightShift { + template + __device__ T operator()(T x, T y) { return x >> y; } +}; + +// All unary math ops promote through float to support half/bfloat16. +// For float inputs the static_cast is a no-op. +#define UNARY_FLOAT_OP(name, op) \ +struct name { \ + template \ + __device__ T operator()(T x) { \ + return T(op(static_cast(x))); \ + } \ +}; + +// Unary ops +UNARY_FLOAT_OP(Abs, fabsf) +UNARY_FLOAT_OP(Exp, expf) +UNARY_FLOAT_OP(Log, logf) +UNARY_FLOAT_OP(Sqrt, sqrtf) + +struct Negative { + template + __device__ T operator()(T x) { return -x; } +}; + +struct Square { + template + __device__ T operator()(T x) { + float fx = static_cast(x); + return T(fx * fx); + } +}; + +struct Sigmoid { + __device__ hip_bfloat16 operator()(hip_bfloat16 x) { + float fx = static_cast(x); + float y = 1.0f / (1.0f + expf(-fabsf(fx))); + return hip_bfloat16((fx < 0.0f) ? 1.0f - y : y); + } + + __device__ __half operator()(__half x) { + float fx = __half2float(x); + float y = 1.0f / (1.0f + expf(-fabsf(fx))); + return __float2half((fx < 0.0f) ? 1.0f - y : y); + } + + template + __device__ T operator()(T x) { + float fx = static_cast(x); + float y = 1.0f / (1.0f + expf(-fabsf(fx))); + return T((fx < 0.0f) ? 1.0f - y : y); + } +}; + +UNARY_FLOAT_OP(Tanh, tanhf) +UNARY_FLOAT_OP(Sin, sinf) +UNARY_FLOAT_OP(Cos, cosf) +UNARY_FLOAT_OP(Tan, tanf) +UNARY_FLOAT_OP(Sinh, sinhf) +UNARY_FLOAT_OP(Cosh, coshf) +UNARY_FLOAT_OP(Erf, erff) +UNARY_FLOAT_OP(ErfInv, erfinvf) +UNARY_FLOAT_OP(Expm1, expm1f) +UNARY_FLOAT_OP(Log1p, log1pf) +UNARY_FLOAT_OP(Log2, log2f) +UNARY_FLOAT_OP(Log10, log10f) +UNARY_FLOAT_OP(Ceil, ceilf) +UNARY_FLOAT_OP(Floor, floorf) +UNARY_FLOAT_OP(Round, rintf) +UNARY_FLOAT_OP(Rsqrt, rsqrtf) + +struct Sign { + template + __device__ T operator()(T x) { + float fx = static_cast(x); + return T((fx > 0.0f) - (fx < 0.0f)); + } +}; + +UNARY_FLOAT_OP(Asin, asinf) +UNARY_FLOAT_OP(Acos, acosf) +UNARY_FLOAT_OP(Atan, atanf) +UNARY_FLOAT_OP(Asinh, asinhf) +UNARY_FLOAT_OP(Acosh, acoshf) +UNARY_FLOAT_OP(Atanh, atanhf) + +struct LogicalNot { + template + __device__ bool operator()(T x) { return !x; } +}; + +struct BitwiseNot { + template + __device__ T operator()(T x) { return ~x; } +}; + +#undef UNARY_FLOAT_OP + +struct Reciprocal { + template + __device__ T operator()(T x) { return T(1.0f / static_cast(x)); } +}; + +// Ternary ops +struct Select { + template + __device__ T operator()(bool c, T x, T y) { return c ? x : y; } +}; + +// Broadcast is a no-op in fused kernels (handled by indexing) +struct Broadcast { + template + __device__ T operator()(T x) { return x; } +}; + +} // namespace mlx::core::rocm + +#define inf hip::std::numeric_limits::infinity() +)"; + +void Compiled::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + + // Determine the work per thread for the vectorized reads/writes. + int max_size = 1; + for (const auto& x : outputs) { + max_size = (max_size > x.itemsize()) ? max_size : x.itemsize(); + } + int work_per_thread = 16 / max_size; + + rocm::JitModule& mod = rocm::get_jit_module(s.device, lib_name(), [&]() { + // Build source code. + rocm::FusedKernelBuilder builder{ + g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_}; + builder.os += "namespace mlx::core::rocm {\n\n"; + builder.build("_contiguous", true); + builder.os += "\n"; + builder.build("_strided", false); + builder.os += "\n} // namespace mlx::core::rocm\n"; + + // Build kernel names. + std::vector kernel_names; + kernel_names.push_back( + std::string("mlx::core::rocm::") + lib_name() + + "_contiguous"); + kernel_names.push_back( + std::string("mlx::core::rocm::") + lib_name() + + "_contiguous"); + for (auto wpt : std::array{1, work_per_thread}) { + for (int i = 1; i <= MAX_NDIM; ++i) { + kernel_names.push_back( + std::string("mlx::core::rocm::") + lib_name() + "_strided<" + + std::to_string(i) + ", uint32_t, " + std::to_string(wpt) + ">"); + kernel_names.push_back( + std::string("mlx::core::rocm::") + lib_name() + "_strided<" + + std::to_string(i) + ", int64_t, " + std::to_string(wpt) + ">"); + } + } + + return std::make_tuple( + false, std::move(builder.os), std::move(kernel_names)); + }); + + // Collapse contiguous dims to route to a faster kernel if possible. + auto [contiguous, shape, strides_vec] = + compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); + + // Whether to use large index. + bool large = compiled_use_large_index(inputs, outputs, contiguous); + + rocm::KernelArgs args; + // Put inputs. + int strides_index = 1; + for (size_t i = 0; i < inputs.size(); ++i) { + if (is_constant_(i)) { + continue; + } + const auto& x = inputs[i]; + args.append(x); + if (!contiguous && !is_scalar(x)) { + args.append_ptr(strides_vec[strides_index++].data()); + } + } + + // Put outputs. + compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous); + for (auto& x : outputs) { + args.append(x); + } + + // Put shape and size. + if (!contiguous) { + args.append_ptr(shape.data()); + } + if (large) { + args.append(outputs[0].data_size()); + } else { + args.append(outputs[0].data_size()); + } + + // Choose work per thread + if (!contiguous && shape.back() % work_per_thread != 0) { + work_per_thread = 1; + } + + // Launch kernel. + const char* index_type = large ? "int64_t" : "uint32_t"; + std::string kernel_name = std::string("mlx::core::rocm::") + lib_name(); + if (contiguous) { + kernel_name += std::string("_contiguous<") + index_type + ", " + + std::to_string(work_per_thread) + ">"; + } else { + kernel_name += std::string("_strided<") + std::to_string(shape.size()) + + ", " + index_type + ", " + std::to_string(work_per_thread) + ">"; + } + + auto& encoder = rocm::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + for (const auto& out : outputs) { + encoder.set_output_array(out); + } + + auto kernel = mod.get_kernel(kernel_name); + + // Calculate launch configuration + int block_size = 256; + int64_t total_work = + (outputs[0].data_size() + work_per_thread - 1) / work_per_thread; + int num_blocks = (total_work + block_size - 1) / block_size; + + encoder.launch_kernel([&](hipStream_t stream) { + (void)hipModuleLaunchKernel( + kernel, + num_blocks, + 1, + 1, + block_size, + 1, + 1, + 0, + stream, + args.args(), + nullptr); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/conv/conv.cpp b/mlx/backend/rocm/conv/conv.cpp new file mode 100644 index 0000000000..0780719d4d --- /dev/null +++ b/mlx/backend/rocm/conv/conv.cpp @@ -0,0 +1,93 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/conv/conv.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { + +// Forward declaration of gemm_conv functions +void gemm_conv( + rocm::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + bool flip, + Stream s); + +void gemm_grouped_conv( + rocm::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip, + Stream s); + +void Convolution::eval_gpu(const std::vector& inputs, array& out) { + if (out.size() == 0) { + return; + } + + auto& s = stream(); + auto& d = rocm::device(s.device); + auto& encoder = d.get_command_encoder(s); + + array in = inputs[0]; + array wt = inputs[1]; + + // Allocate output + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + + // Ensure inputs are contiguous + if (!in.flags().row_contiguous) { + in = contiguous_copy_gpu(in, s); + encoder.add_temporary(in); + } + if (!wt.flags().row_contiguous) { + wt = contiguous_copy_gpu(wt, s); + encoder.add_temporary(wt); + } + + // Use GEMM-based convolution + if (groups_ == 1) { + gemm_conv( + encoder, + in, + wt, + out, + kernel_strides_, + padding_lo_, + kernel_dilation_, + input_dilation_, + flip_, + s); + } else { + gemm_grouped_conv( + encoder, + in, + wt, + out, + kernel_strides_, + padding_lo_, + kernel_dilation_, + input_dilation_, + groups_, + flip_, + s); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/conv/conv.h b/mlx/backend/rocm/conv/conv.h new file mode 100644 index 0000000000..3a7e30c6e3 --- /dev/null +++ b/mlx/backend/rocm/conv/conv.h @@ -0,0 +1,126 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/rocm/device.h" + +namespace mlx::core { + +template +struct ConvParams { + int N; // Batch size + int C; // In channels + int O; // Out channels + int strides[NDIM]; + int padding[NDIM]; + int kernel_dilation[NDIM]; + int input_dilation[NDIM]; + int groups; + bool flip; + int in_spatial_dims[NDIM]; + int wt_spatial_dims[NDIM]; + int out_spatial_dims[NDIM]; + int64_t in_strides[NDIM + 2]; + + ConvParams( + const array& in, + const array& wt, + const array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip) + : N(in.shape(0)), + C(in.shape(-1)), + O(wt.shape(0)), + groups(groups), + flip(flip) { + std::copy_n(strides.begin(), NDIM, this->strides); + std::copy_n(padding.begin(), NDIM, this->padding); + std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation); + std::copy_n(input_dilation.begin(), NDIM, this->input_dilation); + std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims); + std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims); + std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims); + std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides); + } +}; + +void gemm_grouped_conv( + rocm::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip, + Stream s); + +void gemm_conv( + rocm::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + bool flip, + Stream s); + +inline void gemm_conv( + rocm::CommandEncoder& encoder, + array in, + array wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip, + Stream s) { + if (!in.flags().row_contiguous) { + in = contiguous_copy_gpu(in, s); + encoder.add_temporary(in); + } + if (!wt.flags().row_contiguous) { + wt = contiguous_copy_gpu(wt, s); + encoder.add_temporary(wt); + } + + if (groups == 1) { + gemm_conv( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + flip, + s); + } else { + gemm_grouped_conv( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + groups, + flip, + s); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/conv/gemm_conv.hip b/mlx/backend/rocm/conv/gemm_conv.hip new file mode 100644 index 0000000000..cabf351960 --- /dev/null +++ b/mlx/backend/rocm/conv/gemm_conv.hip @@ -0,0 +1,585 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/rocm/conv/conv.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/gemms/naive_gemm.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" + +#include + +namespace mlx::core { + +namespace { + +template +__global__ void depthwise_conv1d_kernel( + const T* __restrict__ in, + const T* __restrict__ wt, + T* __restrict__ out, + ConvParams<1> params) { + int out_channel = blockIdx.x * blockDim.x + threadIdx.x; + int out_pos = blockIdx.y; + int batch = blockIdx.z; + + if (out_channel >= params.O || out_pos >= params.out_spatial_dims[0] || + batch >= params.N) { + return; + } + + float acc = 0.0f; + int kernel_size = params.wt_spatial_dims[0]; + int index_max = + 1 + params.input_dilation[0] * (params.in_spatial_dims[0] - 1); + + for (int k = 0; k < kernel_size; ++k) { + int k_input = params.flip ? (kernel_size - 1 - k) : k; + int in_index = out_pos * params.strides[0] - params.padding[0] + + k_input * params.kernel_dilation[0]; + if (in_index >= 0 && in_index < index_max && + (in_index % params.input_dilation[0] == 0)) { + int in_pos = in_index / params.input_dilation[0]; + int64_t in_offset = static_cast(batch) * params.in_strides[0] + + static_cast(in_pos) * params.in_strides[1] + + static_cast(out_channel) * params.in_strides[2]; + int64_t wt_offset = static_cast(out_channel) * kernel_size + k; + acc += + static_cast(in[in_offset]) * static_cast(wt[wt_offset]); + } + } + + int64_t out_offset = + (static_cast(batch) * params.out_spatial_dims[0] + out_pos) * + params.O + + out_channel; + out[out_offset] = static_cast(acc); +} + +void depthwise_conv1d( + rocm::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip, + Stream s) { + (void)s; + ConvParams<1> params( + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + groups, + flip); + + int block_size = 256; + dim3 block_dims(block_size); + dim3 num_blocks( + (params.O + block_size - 1) / block_size, + params.out_spatial_dims[0], + params.N); + + encoder.set_input_array(in); + encoder.set_input_array(wt); + encoder.set_output_array(out); + + encoder.launch_kernel([&](hipStream_t stream) { + switch (in.dtype()) { + case float32: + depthwise_conv1d_kernel<<>>( + gpu_ptr(in), gpu_ptr(wt), gpu_ptr(out), params); + break; + case float16: + depthwise_conv1d_kernel<__half><<>>( + gpu_ptr<__half>(in), gpu_ptr<__half>(wt), gpu_ptr<__half>(out), params); + break; + case bfloat16: + depthwise_conv1d_kernel + <<>>( + gpu_ptr(in), + gpu_ptr(wt), + gpu_ptr(out), + params); + break; + default: + throw std::runtime_error("Unsupported dtype for depthwise conv1d"); + } + }); +} + +// N-dimensional grouped unfold kernel +template +__global__ void naive_grouped_unfold_transpose_nd( + const T* __restrict__ in, + T* __restrict__ out, + int filter_size, + int out_pixels, + ConvParams params) { + int index_batch = blockIdx.z / out_pixels; + int index_out_spatial = blockIdx.z % out_pixels; + int index_wt_spatial = blockIdx.x * blockDim.x + threadIdx.x; + + if (index_wt_spatial >= filter_size / params.C) { + return; + } + + in += blockIdx.y; // Channel offset + out += blockIdx.z * filter_size + blockIdx.y * (filter_size / params.C); + + bool valid = index_batch < params.N; + + // Get coordinates in input + int index_in[NDIM] = {}; + int wt_stride = 1; + int tmp_out_spatial = index_out_spatial; + int tmp_wt_spatial = index_wt_spatial; + + for (int i = NDIM - 1; i >= 0; --i) { + int index_out = tmp_out_spatial % params.out_spatial_dims[i]; + int index_wt = tmp_wt_spatial % params.wt_spatial_dims[i]; + out += index_wt * wt_stride; + + if (params.flip) { + index_wt = params.wt_spatial_dims[i] - index_wt - 1; + } + + int index = index_out * params.strides[i] - params.padding[i] + + index_wt * params.kernel_dilation[i]; + int index_max = + 1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1); + + valid &= (index >= 0) && (index < index_max) && + (index % params.input_dilation[i] == 0); + + index_in[i] = index / params.input_dilation[i]; + + tmp_out_spatial /= params.out_spatial_dims[i]; + tmp_wt_spatial /= params.wt_spatial_dims[i]; + wt_stride *= params.wt_spatial_dims[i]; + } + + if (valid) { + int64_t in_offset = index_batch * params.in_strides[0]; + for (int i = 0; i < NDIM; ++i) { + in_offset += index_in[i] * params.in_strides[i + 1]; + } + *out = in[in_offset]; + } else { + *out = T{0}; + } +} + +// Helper to launch unfold kernel for specific NDIM +template +void launch_unfold_kernel( + hipStream_t stream, + const array& in, + array& unfolded, + dim3 num_blocks, + dim3 block_dims, + int filter_size, + int out_pixels, + const ConvParams& params) { + switch (in.dtype()) { + case float32: + naive_grouped_unfold_transpose_nd + <<>>( + gpu_ptr(in), + gpu_ptr(unfolded), + filter_size, + out_pixels, + params); + break; + case float16: + naive_grouped_unfold_transpose_nd<__half, NDIM> + <<>>( + gpu_ptr<__half>(in), + gpu_ptr<__half>(unfolded), + filter_size, + out_pixels, + params); + break; + case bfloat16: + naive_grouped_unfold_transpose_nd + <<>>( + gpu_ptr(in), + gpu_ptr(unfolded), + filter_size, + out_pixels, + params); + break; + default: + throw std::runtime_error("Unsupported dtype for conv unfold"); + } +} + +// Implementation for specific NDIM +template +void gemm_conv_nd( + rocm::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + bool flip, + Stream s) { + ConvParams params( + in, wt, out, strides, padding, kernel_dilation, input_dilation, 1, flip); + + int mat_M = out.size() / params.O; + int mat_K = wt.size() / params.O; + int mat_N = params.O; + + bool is_pointwise = !flip; + for (int i = 0; i < NDIM; ++i) { + is_pointwise = is_pointwise && params.wt_spatial_dims[i] == 1 && + params.strides[i] == 1 && params.padding[i] == 0 && + params.kernel_dilation[i] == 1 && params.input_dilation[i] == 1; + } + + if (is_pointwise) { + array wt_2d({params.O, params.C}, wt.dtype(), nullptr, {}); + wt_2d.copy_shared_buffer( + wt, {wt.strides(0), wt.strides(-1)}, wt.flags(), wt.size()); + array wt_contig = contiguous_copy_gpu(wt_2d, s); + encoder.add_temporary(wt_contig); + + rocm::naive_gemm( + encoder, + in, + wt_contig, + out, + mat_M, + mat_N, + mat_K, + false, + mat_K, + true, + mat_K, + 1.0f, + 0.0f); + return; + } + + int filter_size = params.C; + for (int i = 0; i < NDIM; ++i) { + filter_size *= params.wt_spatial_dims[i]; + } + + int out_pixels = 1; + for (int i = 0; i < NDIM; ++i) { + out_pixels *= params.out_spatial_dims[i]; + } + + array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {}); + unfolded.set_data(mlx::core::rocm::malloc_async(unfolded.nbytes(), encoder)); + encoder.add_temporary(unfolded); + + int wt_spatial_size = mat_K / params.C; + dim3 block_dims(std::min(std::max(wt_spatial_size, 32), 1024)); + dim3 num_blocks( + (wt_spatial_size + block_dims.x - 1) / block_dims.x, params.C, mat_M); + + encoder.set_input_array(in); + encoder.set_output_array(unfolded); + + encoder.launch_kernel([&](hipStream_t stream) { + launch_unfold_kernel( + stream, + in, + unfolded, + num_blocks, + block_dims, + filter_size, + out_pixels, + params); + }); + + int wt_spatial_total = 1; + for (int i = 0; i < NDIM; ++i) { + wt_spatial_total *= params.wt_spatial_dims[i]; + } + + array wt_view( + {params.O, params.C, wt_spatial_total}, wt.dtype(), nullptr, {}); + wt_view.copy_shared_buffer( + wt, {wt.strides(0), 1, params.C}, wt.flags(), wt.size()); + array wt_reshaped = contiguous_copy_gpu(wt_view, s); + encoder.add_temporary(wt_reshaped); + + rocm::naive_gemm( + encoder, + unfolded, + wt_reshaped, + out, + mat_M, + mat_N, + mat_K, + false, + mat_K, + true, + mat_K, + 1.0f, + 0.0f); +} + +template +void gemm_grouped_conv_nd( + rocm::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip, + Stream s) { + ConvParams params( + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + groups, + flip); + + int C_per_group = params.C / params.groups; + int O_per_group = params.O / params.groups; + int mat_M = out.size() / params.O; + int mat_K = wt.size() / params.O; + int mat_N = O_per_group; + + int filter_size = params.C; + for (int i = 0; i < NDIM; ++i) { + filter_size *= params.wt_spatial_dims[i]; + } + + int out_pixels = 1; + for (int i = 0; i < NDIM; ++i) { + out_pixels *= params.out_spatial_dims[i]; + } + + array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {}); + unfolded.set_data(mlx::core::rocm::malloc_async(unfolded.nbytes(), encoder)); + encoder.add_temporary(unfolded); + + int wt_spatial_size = (mat_K * params.groups) / params.C; + dim3 block_dims(std::min(std::max(wt_spatial_size, 32), 1024)); + dim3 num_blocks( + (wt_spatial_size + block_dims.x - 1) / block_dims.x, params.C, mat_M); + + encoder.set_input_array(in); + encoder.set_output_array(unfolded); + + encoder.launch_kernel([&](hipStream_t stream) { + launch_unfold_kernel( + stream, + in, + unfolded, + num_blocks, + block_dims, + filter_size, + out_pixels, + params); + }); + + int wt_spatial_total = 1; + for (int i = 0; i < NDIM; ++i) { + wt_spatial_total *= params.wt_spatial_dims[i]; + } + + array wt_view( + {params.O, C_per_group, wt_spatial_total}, wt.dtype(), nullptr, {}); + wt_view.copy_shared_buffer( + wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size()); + array wt_reshaped = contiguous_copy_gpu(wt_view, s); + encoder.add_temporary(wt_reshaped); + + for (int g = 0; g < params.groups; ++g) { + int64_t a_offset = g * mat_K; + int64_t b_offset = g * O_per_group * mat_K; + int64_t c_offset = g * O_per_group; + + rocm::naive_gemm_with_offset_ldc( + encoder, + unfolded, + wt_reshaped, + out, + mat_M, + mat_N, + mat_K, + false, + mat_K * params.groups, + a_offset, + true, + mat_K, + b_offset, + mat_N * params.groups, + c_offset, // ldc = full output row width + 1.0f, + 0.0f); + } +} + +} // namespace + +void gemm_conv( + rocm::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + bool flip, + Stream s) { + int conv_ndim = in.ndim() - 2; + + switch (conv_ndim) { + case 1: + gemm_conv_nd<1>( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + flip, + s); + break; + case 2: + gemm_conv_nd<2>( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + flip, + s); + break; + case 3: + gemm_conv_nd<3>( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + flip, + s); + break; + default: + throw std::runtime_error( + "[conv] ROCm GEMM-based convolution only supports 1D, 2D, 3D."); + } +} + +void gemm_grouped_conv( + rocm::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip, + Stream s) { + int conv_ndim = in.ndim() - 2; + + // Depthwise 1D convolution with channel multiplier 1 (C == O == groups) + // is a common decode-time pattern (e.g. Qwen3-Next linear attention). + // Running it through unfold + per-group GEMMs is very launch-heavy. + // Use a direct kernel in this configuration. + if (conv_ndim == 1 && in.shape(-1) == groups && wt.shape(0) == groups && + out.shape(-1) == groups && wt.shape(-1) == 1) { + depthwise_conv1d( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + groups, + flip, + s); + return; + } + + switch (conv_ndim) { + case 1: + gemm_grouped_conv_nd<1>( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + groups, + flip, + s); + break; + case 2: + gemm_grouped_conv_nd<2>( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + groups, + flip, + s); + break; + case 3: + gemm_grouped_conv_nd<3>( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + groups, + flip, + s); + break; + default: + throw std::runtime_error( + "[conv] ROCm grouped convolution only supports 1D, 2D, 3D."); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/copy.hip b/mlx/backend/rocm/copy.hip new file mode 100644 index 0000000000..d4a3950074 --- /dev/null +++ b/mlx/backend/rocm/copy.hip @@ -0,0 +1,155 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/rocm/copy/copy.hpp" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/kernel_utils.hpp" + +namespace mlx::core { + +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { + auto& encoder = rocm::get_command_encoder(s); + bool donated = set_copy_output_data( + in, out, ctype, [&](auto n) { return mlx::core::rocm::malloc_async(n, encoder); }); + if (donated && in.dtype() == out.dtype()) { + // If the output has the same type as the input then there is nothing to + // copy, just use the buffer. + return; + } + if (ctype == CopyType::GeneralGeneral) { + ctype = CopyType::General; + } + copy_gpu_inplace(in, out, ctype, s); +} + +void copy_gpu_inplace( + const array& in, + array& out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out, + int64_t offset_in, + int64_t offset_out, + CopyType ctype, + const Stream& s, + std::optional dynamic_offset_in, + std::optional dynamic_offset_out) { + if (out.size() == 0) { + return; + } + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + // Handle dynamic offsets + if (dynamic_offset_in.has_value() || dynamic_offset_out.has_value()) { + auto [shape_collapsed, strides_vec] = collapse_contiguous_dims( + shape, std::vector{strides_in, strides_out}, INT32_MAX); + + // Create zero offset arrays for missing dynamic offsets + // We need to allocate and initialize on GPU to avoid hipDeviceSynchronize + if (!dynamic_offset_in) { + dynamic_offset_in = array({1}, int64, nullptr, {}); + dynamic_offset_in->set_data(mlx::core::rocm::malloc_async(sizeof(int64_t), encoder)); + encoder.add_temporary(*dynamic_offset_in); + // Initialize to zero on GPU using hipMemset + int64_t* ptr = gpu_ptr(*dynamic_offset_in); + encoder.launch_kernel([ptr](hipStream_t stream) { + (void)hipMemsetAsync(ptr, 0, sizeof(int64_t), stream); + }); + } + if (!dynamic_offset_out) { + dynamic_offset_out = array({1}, int64, nullptr, {}); + dynamic_offset_out->set_data(mlx::core::rocm::malloc_async(sizeof(int64_t), encoder)); + encoder.add_temporary(*dynamic_offset_out); + // Initialize to zero on GPU using hipMemset + int64_t* ptr = gpu_ptr(*dynamic_offset_out); + encoder.launch_kernel([ptr](hipStream_t stream) { + (void)hipMemsetAsync(ptr, 0, sizeof(int64_t), stream); + }); + } + encoder.set_input_array(*dynamic_offset_in); + encoder.set_input_array(*dynamic_offset_out); + + copy_general_dynamic( + encoder, + ctype, + in, + out, + offset_in, + offset_out, + shape_collapsed, + strides_vec[0], + strides_vec[1], + *dynamic_offset_in, + *dynamic_offset_out); + return; + } + + if (ctype == CopyType::Scalar || ctype == CopyType::Vector) { + copy_contiguous(encoder, ctype, in, out, offset_in, offset_out); + return; + } + + if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { + auto [shape_collapsed, strides_vec] = collapse_contiguous_dims( + shape, std::vector{strides_in, strides_out}, INT32_MAX); + if (ctype == CopyType::General) { + copy_general_input( + encoder, + ctype, + in, + out, + offset_in, + offset_out, + shape_collapsed, + strides_vec[0]); + } else { + copy_general( + encoder, + ctype, + in, + out, + offset_in, + offset_out, + shape_collapsed, + strides_vec[0], + strides_vec[1]); + } + return; + } +} + +void fill_gpu(const array& in, array& out, const Stream& s) { + if (out.size() == 0) { + return; + } + auto& encoder = rocm::get_command_encoder(s); + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + encoder.set_input_array(in); + encoder.set_output_array(out); + copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0); +} + +void reshape_gpu(const array& in, array& out, Stream s) { + auto [copy_necessary, out_strides] = prepare_reshape(in, out); + if (copy_necessary) { + auto& encoder = rocm::get_command_encoder(s); + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + copy_gpu_inplace( + in, + out, + in.shape(), + in.strides(), + make_contiguous_strides(in.shape()), + 0, + 0, + CopyType::General, + s); + } else { + shared_buffer_reshape(in, out_strides, out); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/copy/copy.hpp b/mlx/backend/rocm/copy/copy.hpp new file mode 100644 index 0000000000..b7363db263 --- /dev/null +++ b/mlx/backend/rocm/copy/copy.hpp @@ -0,0 +1,238 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/device/utils.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" + +#include +#include + +namespace mlx::core { + +namespace rocm { + +// Cast operation for copy - general case +template +struct CastOp { + static constexpr bool is_castable = std::is_convertible_v; + + __device__ DstT operator()(SrcT x) { + return static_cast(x); + } +}; + +// Castings between complex and boolean +template <> +struct CastOp { + static constexpr bool is_castable = true; + + __device__ bool operator()(hipFloatComplex x) { + return x.x != 0 && x.y != 0; + } +}; + +template <> +struct CastOp { + static constexpr bool is_castable = true; + + __device__ hipFloatComplex operator()(bool x) { + return x ? make_hipFloatComplex(1.0f, 1.0f) + : make_hipFloatComplex(0.0f, 0.0f); + } +}; + +// Converting a complex number to real number discards the imaginary part +template +struct CastOp< + hipFloatComplex, + DstT, + std::enable_if_t && !std::is_same_v>> { + static constexpr bool is_castable = true; + + __device__ DstT operator()(hipFloatComplex x) { + return static_cast(x.x); // x.x is the real part + } +}; + +// Allow converting a real number to complex number +template +struct CastOp< + SrcT, + hipFloatComplex, + std::enable_if_t && !std::is_same_v>> { + static constexpr bool is_castable = true; + + __device__ hipFloatComplex operator()(SrcT x) { + return make_hipFloatComplex(static_cast(x), 0.0f); + } +}; + +// Do nothing when no casting is needed +template +struct CastOp { + static constexpr bool is_castable = true; + + __device__ T operator()(T x) { + return x; + } +}; + +// Specializations for half types +template <> +struct CastOp<__half, float> { + static constexpr bool is_castable = true; + __device__ float operator()(__half x) { + return __half2float(x); + } +}; + +template <> +struct CastOp { + static constexpr bool is_castable = true; + __device__ __half operator()(float x) { + return __float2half(x); + } +}; + +template <> +struct CastOp { + static constexpr bool is_castable = true; + __device__ float operator()(hip_bfloat16 x) { + return static_cast(x); + } +}; + +template <> +struct CastOp { + static constexpr bool is_castable = true; + __device__ hip_bfloat16 operator()(float x) { + return hip_bfloat16(x); + } +}; + +// Conversions through float for half types +template +struct CastOp< + __half, + DstT, + std::enable_if_t< + !std::is_same_v && !std::is_same_v && + !is_complex_v>> { + static constexpr bool is_castable = true; + __device__ DstT operator()(__half x) { + return static_cast(__half2float(x)); + } +}; + +template +struct CastOp< + SrcT, + __half, + std::enable_if_t< + !std::is_same_v && !std::is_same_v && + !is_complex_v>> { + static constexpr bool is_castable = true; + __device__ __half operator()(SrcT x) { + return __float2half(static_cast(x)); + } +}; + +template +struct CastOp< + hip_bfloat16, + DstT, + std::enable_if_t< + !std::is_same_v && !std::is_same_v && + !is_complex_v>> { + static constexpr bool is_castable = true; + __device__ DstT operator()(hip_bfloat16 x) { + return static_cast(static_cast(x)); + } +}; + +template +struct CastOp< + SrcT, + hip_bfloat16, + std::enable_if_t< + !std::is_same_v && !std::is_same_v && + !is_complex_v>> { + static constexpr bool is_castable = true; + __device__ hip_bfloat16 operator()(SrcT x) { + return hip_bfloat16(static_cast(x)); + } +}; + +// Conversion between __half and hip_bfloat16 +template <> +struct CastOp<__half, hip_bfloat16> { + static constexpr bool is_castable = true; + __device__ hip_bfloat16 operator()(__half x) { + return hip_bfloat16(__half2float(x)); + } +}; + +template <> +struct CastOp { + static constexpr bool is_castable = true; + __device__ __half operator()(hip_bfloat16 x) { + return __float2half(static_cast(x)); + } +}; + +// Helper to deduce the SrcT +template +inline __device__ auto cast_to(SrcT x) { + return CastOp{}(x); +} + +} // namespace rocm + +// Forward declarations +void copy_contiguous( + rocm::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t in_offset, + int64_t out_offset); + +void copy_general_input( + rocm::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t in_offset, + int64_t out_offset, + const Shape& shape, + const Strides& strides_in); + +void copy_general( + rocm::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t in_offset, + int64_t out_offset, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out); + +void copy_general_dynamic( + rocm::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out, + const array& dynamic_offset_in, + const array& dynamic_offset_out); + +} // namespace mlx::core diff --git a/mlx/backend/rocm/copy/copy_contiguous.hip b/mlx/backend/rocm/copy/copy_contiguous.hip new file mode 100644 index 0000000000..9713aec4ae --- /dev/null +++ b/mlx/backend/rocm/copy/copy_contiguous.hip @@ -0,0 +1,96 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/copy/copy.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" + +#include + +namespace mlx::core { + +namespace rocm { + +template +__global__ void copy_s(const In* in, Out* out, IdxT size) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + IdxT stride = blockDim.x * gridDim.x; + + for (IdxT i = index * N_READS; i < size; i += stride * N_READS) { + if (i + N_READS <= size) { + #pragma unroll + for (int j = 0; j < N_READS; ++j) { + out[i + j] = cast_to(in[0]); + } + } else { + for (IdxT j = i; j < size; ++j) { + out[j] = cast_to(in[0]); + } + } + } +} + +template +__global__ void copy_v(const In* in, Out* out, IdxT size) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + IdxT stride = blockDim.x * gridDim.x; + + for (IdxT i = index * N_READS; i < size; i += stride * N_READS) { + if (i + N_READS <= size) { + #pragma unroll + for (int j = 0; j < N_READS; ++j) { + out[i + j] = cast_to(in[i + j]); + } + } else { + for (IdxT j = i; j < size; ++j) { + out[j] = cast_to(in[j]); + } + } + } +} + +} // namespace rocm + +void copy_contiguous( + rocm::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t in_offset, + int64_t out_offset) { + + // Handle empty arrays + size_t size = out.data_size(); + if (size == 0) { + return; + } + + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool(out.data_size() > UINT32_MAX, [&](auto large) { + using InType = hip_type_t; + using OutType = hip_type_t; + using IdxT = std::conditional_t; + constexpr int N_READS = 4; + + int block_size = 256; + int num_blocks = (size + block_size * N_READS - 1) / (block_size * N_READS); + num_blocks = std::max(1, std::min(num_blocks, 65535)); + + const InType* in_ptr = gpu_ptr(in) + in_offset; + OutType* out_ptr = gpu_ptr(out) + out_offset; + IdxT size_arg = static_cast(size); + + auto kernel = &rocm::copy_s; + if (ctype != CopyType::Scalar) { + kernel = &rocm::copy_v; + } + encoder.add_kernel_node( + kernel, + dim3(num_blocks), dim3(block_size), 0, + in_ptr, out_ptr, size_arg); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/copy/copy_general.hip b/mlx/backend/rocm/copy/copy_general.hip new file mode 100644 index 0000000000..d7ad2207e1 --- /dev/null +++ b/mlx/backend/rocm/copy/copy_general.hip @@ -0,0 +1,102 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/copy/copy.hpp" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/device/utils.hpp" +#include "mlx/dtype_utils.h" + +#include + +namespace mlx::core { + +namespace rocm { + +// General copy kernel with by-value shape/strides (no hipMemcpyAsync needed) +template +__global__ void copy_gg_byval( + const In* in, + Out* out, + IdxT size, + hip_array shape, + hip_array strides_in, + hip_array strides_out, + int ndim) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= size) return; + + IdxT loc_in = 0, loc_out = 0; + IdxT elem = index; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + IdxT dim_idx = elem % shape[i]; + loc_in += dim_idx * IdxT(strides_in[i]); + loc_out += dim_idx * IdxT(strides_out[i]); + elem /= shape[i]; + } + out[loc_out] = cast_to(in[loc_in]); +} + +} // namespace rocm + +void copy_general( + rocm::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out) { + + int ndim = shape.size(); + size_t data_size = 1; + for (auto& s : shape) { + data_size *= s; + } + + if (data_size == 0) { + return; + } + + // Pack shape/strides into by-value structs (no device allocation needed) + rocm::hip_array shape_arg = {}; + rocm::hip_array strides_in_arg = {}; + rocm::hip_array strides_out_arg = {}; + for (int i = 0; i < ndim; i++) { + shape_arg.data_[i] = static_cast(shape[i]); + strides_in_arg.data_[i] = strides_in[i]; + strides_out_arg.data_[i] = strides_out[i]; + } + + const void* in_ptr = gpu_ptr(in); + void* out_ptr = gpu_ptr(out); + + int block_size = 256; + int num_blocks = (data_size + block_size - 1) / block_size; + int64_t size_arg = static_cast(data_size); + + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using InType = hip_type_t; + using OutType = hip_type_t; + + const InType* in_typed = static_cast(in_ptr) + offset_in; + OutType* out_typed = static_cast(out_ptr) + offset_out; + + encoder.add_kernel_node( + &rocm::copy_gg_byval, + dim3(num_blocks), dim3(block_size), 0, + in_typed, + out_typed, + size_arg, + shape_arg, + strides_in_arg, + strides_out_arg, + ndim); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/copy/copy_general_dynamic.hip b/mlx/backend/rocm/copy/copy_general_dynamic.hip new file mode 100644 index 0000000000..865c08ddb3 --- /dev/null +++ b/mlx/backend/rocm/copy/copy_general_dynamic.hip @@ -0,0 +1,261 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/copy/copy.hpp" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/device/utils.hpp" + +#include +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +// Kernel with fixed-size arrays passed by value (no device memory needed) +template +__global__ void copy_gg_dynamic_nd( + const In* in, + Out* out, + IdxT size, + const int32_t shape0, const int32_t shape1, const int32_t shape2, + const int64_t strides_in0, const int64_t strides_in1, const int64_t strides_in2, + const int64_t strides_out0, const int64_t strides_out1, const int64_t strides_out2, + const int64_t* offset_in, + const int64_t* offset_out) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= size) return; + + // Compute input and output locations + IdxT idx_in = 0; + IdxT idx_out = 0; + IdxT elem = index; + + // Unroll based on NDIM + if constexpr (NDIM >= 3) { + IdxT dim_idx = elem % shape2; + elem /= shape2; + idx_in += dim_idx * strides_in2; + idx_out += dim_idx * strides_out2; + } + if constexpr (NDIM >= 2) { + IdxT dim_idx = elem % shape1; + elem /= shape1; + idx_in += dim_idx * strides_in1; + idx_out += dim_idx * strides_out1; + } + if constexpr (NDIM >= 1) { + IdxT dim_idx = elem % shape0; + idx_in += dim_idx * strides_in0; + idx_out += dim_idx * strides_out0; + } + + out[idx_out + *offset_out] = static_cast(in[idx_in + *offset_in]); +} + +// General kernel for ndim > 3 (shape/strides passed by value) +template +__global__ void copy_gg_dynamic( + const In* in, + Out* out, + IdxT size, + hip_array shape, + hip_array strides_in, + hip_array strides_out, + int ndim, + const int64_t* offset_in, + const int64_t* offset_out) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= size) return; + + // Compute input and output locations + IdxT idx_in = 0; + IdxT idx_out = 0; + IdxT elem = index; + + for (int i = ndim - 1; i >= 0; --i) { + IdxT dim_idx = elem % shape[i]; + elem /= shape[i]; + idx_in += dim_idx * strides_in[i]; + idx_out += dim_idx * strides_out[i]; + } + + out[idx_out + *offset_out] = static_cast(in[idx_in + *offset_in]); +} + +} // namespace rocm + +void copy_general_dynamic( + rocm::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in, + const Strides& strides_out, + const array& dynamic_offset_in, + const array& dynamic_offset_out) { + + encoder.set_input_array(in); + encoder.set_input_array(dynamic_offset_in); + encoder.set_input_array(dynamic_offset_out); + encoder.set_output_array(out); + + int ndim = shape.size(); + size_t size = out.size(); + + int block_size = 256; + int num_blocks = (size + block_size - 1) / block_size; + + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; + + // Get GPU pointers before lambda to avoid synchronization issues + const void* in_ptr_base = gpu_ptr(in); + void* out_ptr_base = gpu_ptr(out); + const int64_t* dyn_offset_in_ptr = gpu_ptr(dynamic_offset_in); + const int64_t* dyn_offset_out_ptr = gpu_ptr(dynamic_offset_out); + + // For ndim <= 3, pass shape and strides as kernel arguments (no device memory needed) + if (ndim <= 3) { + // Pad arrays to size 3 + int32_t s0 = ndim > 0 ? static_cast(shape[0]) : 1; + int32_t s1 = ndim > 1 ? static_cast(shape[1]) : 1; + int32_t s2 = ndim > 2 ? static_cast(shape[2]) : 1; + int64_t si0 = ndim > 0 ? strides_in[0] : 0; + int64_t si1 = ndim > 1 ? strides_in[1] : 0; + int64_t si2 = ndim > 2 ? strides_in[2] : 0; + int64_t so0 = ndim > 0 ? strides_out[0] : 0; + int64_t so1 = ndim > 1 ? strides_out[1] : 0; + int64_t so2 = ndim > 2 ? strides_out[2] : 0; + + #define LAUNCH_COPY_DYNAMIC_ND(InT, OutT, IdxT, NDIM) \ + do { \ + const InT* in_typed = \ + static_cast(in_ptr_base) + offset_in; \ + OutT* out_typed = static_cast(out_ptr_base) + offset_out; \ + IdxT size_arg = static_cast(size); \ + encoder.add_kernel_node( \ + &rocm::copy_gg_dynamic_nd, \ + dim3(num_blocks), dim3(block_size), 0, \ + in_typed, \ + out_typed, \ + size_arg, \ + s0, s1, s2, si0, si1, si2, so0, so1, so2, \ + dyn_offset_in_ptr, dyn_offset_out_ptr); \ + } while (0) + + #define DISPATCH_NDIM_ND(InT, OutT, IdxT) \ + switch (ndim) { \ + case 1: LAUNCH_COPY_DYNAMIC_ND(InT, OutT, IdxT, 1); break; \ + case 2: LAUNCH_COPY_DYNAMIC_ND(InT, OutT, IdxT, 2); break; \ + case 3: LAUNCH_COPY_DYNAMIC_ND(InT, OutT, IdxT, 3); break; \ + default: break; \ + } + + #define DISPATCH_OUT_TYPE_ND(InT, IdxT) \ + switch (out.dtype()) { \ + case float32: DISPATCH_NDIM_ND(InT, float, IdxT); break; \ + case float16: DISPATCH_NDIM_ND(InT, __half, IdxT); break; \ + case bfloat16: DISPATCH_NDIM_ND(InT, hip_bfloat16, IdxT); break; \ + case int32: DISPATCH_NDIM_ND(InT, int32_t, IdxT); break; \ + case int64: DISPATCH_NDIM_ND(InT, int64_t, IdxT); break; \ + case uint32: DISPATCH_NDIM_ND(InT, uint32_t, IdxT); break; \ + case uint8: DISPATCH_NDIM_ND(InT, uint8_t, IdxT); break; \ + case bool_: DISPATCH_NDIM_ND(InT, bool, IdxT); break; \ + default: break; \ + } + + #define DISPATCH_IN_TYPE_ND(IdxT) \ + switch (in.dtype()) { \ + case float32: DISPATCH_OUT_TYPE_ND(float, IdxT); break; \ + case float16: DISPATCH_OUT_TYPE_ND(__half, IdxT); break; \ + case bfloat16: DISPATCH_OUT_TYPE_ND(hip_bfloat16, IdxT); break; \ + case int32: DISPATCH_OUT_TYPE_ND(int32_t, IdxT); break; \ + case int64: DISPATCH_OUT_TYPE_ND(int64_t, IdxT); break; \ + case uint32: DISPATCH_OUT_TYPE_ND(uint32_t, IdxT); break; \ + case uint8: DISPATCH_OUT_TYPE_ND(uint8_t, IdxT); break; \ + case bool_: DISPATCH_OUT_TYPE_ND(bool, IdxT); break; \ + default: break; \ + } + + if (large) { + DISPATCH_IN_TYPE_ND(int64_t); + } else { + DISPATCH_IN_TYPE_ND(int32_t); + } + + #undef DISPATCH_IN_TYPE_ND + #undef DISPATCH_OUT_TYPE_ND + #undef DISPATCH_NDIM_ND + #undef LAUNCH_COPY_DYNAMIC_ND + return; + } + + // For ndim > 3, pack shape/strides into by-value structs + rocm::hip_array shape_arg = {}; + rocm::hip_array strides_in_arg = {}; + rocm::hip_array strides_out_arg = {}; + for (int i = 0; i < ndim; i++) { + shape_arg.data_[i] = static_cast(shape[i]); + strides_in_arg.data_[i] = strides_in[i]; + strides_out_arg.data_[i] = strides_out[i]; + } + + #define LAUNCH_COPY_DYNAMIC_GENERAL(InT, OutT, IdxT) \ + do { \ + const InT* in_typed = static_cast(in_ptr_base) + offset_in; \ + OutT* out_typed = static_cast(out_ptr_base) + offset_out; \ + IdxT size_arg = static_cast(size); \ + encoder.add_kernel_node( \ + &rocm::copy_gg_dynamic, \ + dim3(num_blocks), dim3(block_size), 0, \ + in_typed, \ + out_typed, \ + size_arg, shape_arg, \ + strides_in_arg, strides_out_arg, \ + ndim, dyn_offset_in_ptr, dyn_offset_out_ptr); \ + } while (0) + + #define DISPATCH_OUT_TYPE_GEN(InT, IdxT) \ + switch (out.dtype()) { \ + case float32: LAUNCH_COPY_DYNAMIC_GENERAL(InT, float, IdxT); break; \ + case float16: LAUNCH_COPY_DYNAMIC_GENERAL(InT, __half, IdxT); break; \ + case bfloat16: LAUNCH_COPY_DYNAMIC_GENERAL(InT, hip_bfloat16, IdxT); break; \ + case int32: LAUNCH_COPY_DYNAMIC_GENERAL(InT, int32_t, IdxT); break; \ + case int64: LAUNCH_COPY_DYNAMIC_GENERAL(InT, int64_t, IdxT); break; \ + case uint32: LAUNCH_COPY_DYNAMIC_GENERAL(InT, uint32_t, IdxT); break; \ + case uint8: LAUNCH_COPY_DYNAMIC_GENERAL(InT, uint8_t, IdxT); break; \ + case bool_: LAUNCH_COPY_DYNAMIC_GENERAL(InT, bool, IdxT); break; \ + default: break; \ + } + + #define DISPATCH_IN_TYPE_GEN(IdxT) \ + switch (in.dtype()) { \ + case float32: DISPATCH_OUT_TYPE_GEN(float, IdxT); break; \ + case float16: DISPATCH_OUT_TYPE_GEN(__half, IdxT); break; \ + case bfloat16: DISPATCH_OUT_TYPE_GEN(hip_bfloat16, IdxT); break; \ + case int32: DISPATCH_OUT_TYPE_GEN(int32_t, IdxT); break; \ + case int64: DISPATCH_OUT_TYPE_GEN(int64_t, IdxT); break; \ + case uint32: DISPATCH_OUT_TYPE_GEN(uint32_t, IdxT); break; \ + case uint8: DISPATCH_OUT_TYPE_GEN(uint8_t, IdxT); break; \ + case bool_: DISPATCH_OUT_TYPE_GEN(bool, IdxT); break; \ + default: break; \ + } + + if (large) { + DISPATCH_IN_TYPE_GEN(int64_t); + } else { + DISPATCH_IN_TYPE_GEN(int32_t); + } + + #undef DISPATCH_IN_TYPE_GEN + #undef DISPATCH_OUT_TYPE_GEN + #undef LAUNCH_COPY_DYNAMIC_GENERAL +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/copy/copy_general_input.hip b/mlx/backend/rocm/copy/copy_general_input.hip new file mode 100644 index 0000000000..60c8a62780 --- /dev/null +++ b/mlx/backend/rocm/copy/copy_general_input.hip @@ -0,0 +1,147 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/copy/copy.hpp" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/device/utils.hpp" +#include "mlx/dtype_utils.h" + +#include + +namespace mlx::core { + +static constexpr int TILE_SIZE = 16; + +namespace rocm { + +// General copy kernel - strided input to contiguous output (by-value args) +template +__global__ void copy_g_byval( + const In* in, + Out* out, + IdxT size, + hip_array shape, + hip_array strides, + int ndim) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= size) return; + + IdxT loc = 0; + IdxT elem = index; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + out[index] = cast_to(in[loc]); +} + +// Column to row transpose kernel +template +__global__ void copy_col_row( + const T* in, + T* out, + int64_t rows, + int64_t cols) { + __shared__ T tile[TILE_SIZE][TILE_SIZE + 1]; + + int tile_row = blockIdx.x * TILE_SIZE; + int tile_col = blockIdx.y * TILE_SIZE; + + int tidx = threadIdx.x; + int tidy = threadIdx.y; + + int in_row = tile_row + tidx; + int in_col = tile_col + tidy; + if (in_row < rows && in_col < cols) { + tile[tidx][tidy] = in[in_col * rows + in_row]; + } + + __syncthreads(); + + int out_row = tile_row + tidy; + int out_col = tile_col + tidx; + if (out_row < rows && out_col < cols) { + out[out_row * cols + out_col] = tile[tidy][tidx]; + } +} + +} // namespace rocm + +void copy_general_input( + rocm::CommandEncoder& encoder, + CopyType ctype, + const array& in, + array& out, + int64_t offset_in, + int64_t offset_out, + const Shape& shape, + const Strides& strides_in) { + + int ndim = shape.size(); + size_t data_size = out.size(); + + if (data_size == 0) { + return; + } + + // Column contiguous to row contiguous specialization (same type only) + if (ndim == 2 && strides_in[0] == 1 && strides_in[1] == shape[0] && in.dtype() == out.dtype()) { + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using T = hip_type_t; + dim3 block(TILE_SIZE, TILE_SIZE); + dim3 grid((shape[0] + TILE_SIZE - 1) / TILE_SIZE, + (shape[1] + TILE_SIZE - 1) / TILE_SIZE); + const T* in_typed = + reinterpret_cast(gpu_ptr(in)) + offset_in; + T* out_typed = reinterpret_cast(gpu_ptr(out)) + offset_out; + int64_t rows_arg = static_cast(shape[0]); + int64_t cols_arg = static_cast(shape[1]); + encoder.add_kernel_node( + &rocm::copy_col_row, + grid, block, 0, + in_typed, + out_typed, + rows_arg, + cols_arg); + }); + return; + } + + // Pack shape/strides into by-value structs (no device allocation or hipMemcpyAsync) + rocm::hip_array shape_arg = {}; + rocm::hip_array strides_arg = {}; + for (int i = 0; i < ndim; i++) { + shape_arg.data_[i] = static_cast(shape[i]); + strides_arg.data_[i] = strides_in[i]; + } + + const void* in_ptr = gpu_ptr(in); + void* out_ptr = gpu_ptr(out); + + int block_size = 256; + int num_blocks = (data_size + block_size - 1) / block_size; + int64_t size_arg = static_cast(data_size); + + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using InType = hip_type_t; + using OutType = hip_type_t; + + const InType* in_typed = static_cast(in_ptr) + offset_in; + OutType* out_typed = static_cast(out_ptr) + offset_out; + + encoder.add_kernel_node( + &rocm::copy_g_byval, + dim3(num_blocks), dim3(block_size), 0, + in_typed, + out_typed, + size_arg, + shape_arg, + strides_arg, + ndim); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/custom_kernel.cpp b/mlx/backend/rocm/custom_kernel.cpp new file mode 100644 index 0000000000..5a81186652 --- /dev/null +++ b/mlx/backend/rocm/custom_kernel.cpp @@ -0,0 +1,376 @@ +// Copyright © 2025 Apple Inc. + +#include +#include + +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/rocm/jit_module.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/utils.h" +#include "mlx/fast.h" +#include "mlx/fast_primitives.h" + +#include + +namespace mlx::core::fast { + +namespace { + +// Inline the essential definitions for custom kernels +// This avoids the need for include paths in JIT compilation +constexpr const char* default_header = R"( +#include +#include +#include +#include + +#define inf (1.0f / 0.0f) + +namespace mlx::core::rocm { + +// Type aliases for convenience +using float16_t = __half; +using bfloat16_t = hip_bfloat16; + +// Ceil division +template +__host__ __device__ T ceildiv(T a, T b) { + return (a + b - 1) / b; +} + +// Thread/block index helpers +__device__ inline int thread_index() { + return threadIdx.x + threadIdx.y * blockDim.x + + threadIdx.z * blockDim.x * blockDim.y; +} + +__device__ inline int block_index() { + return blockIdx.x + blockIdx.y * gridDim.x + + blockIdx.z * gridDim.x * gridDim.y; +} + +__device__ inline int global_thread_index() { + return thread_index() + + block_index() * (blockDim.x * blockDim.y * blockDim.z); +} + +// Indexing helper +template +__device__ IdxT +elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) { + IdxT loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +} // namespace mlx::core::rocm + +)"; + +std::string template_arguments_hash( + const std::vector>& template_args) { + if (template_args.empty()) { + return ""; + } + + std::ostringstream hash; + + for (const auto& [name, arg] : template_args) { + if (std::holds_alternative(arg)) { + hash << "_" << std::get(arg); + } else if (std::holds_alternative(arg)) { + hash << (std::get(arg) ? "_t" : "_f"); + } else if (std::holds_alternative(arg)) { + hash << "_" << get_type_string(std::get(arg)); + } + } + + return hash.str(); +} + +std::string build_kernel( + const std::string& func_name, + const std::string& header, + const std::string& source, + const std::vector& input_names, + const std::vector& inputs, + const std::vector& output_names, + const std::vector& output_dtypes, + const std::vector>& template_args, + const std::vector>& shape_infos) { + std::ostringstream kernel_source; + kernel_source << default_header; + kernel_source << header; + kernel_source << "namespace mlx::core::rocm {\n\n"; + + kernel_source << "__global__ void " << func_name << "(\n"; + + // Add inputs + for (size_t i = 0; i < inputs.size(); ++i) { + const auto& name = input_names[i]; + const auto& arr = inputs[i]; + kernel_source << " const " << dtype_to_hip_type(arr.dtype()) << "* " + << name << ",\n"; + // Add input shape, strides and ndim if present in the source + if (arr.ndim() > 0) { + if (std::get<0>(shape_infos[i])) { + kernel_source << " const int32_t* " << name << "_shape,\n"; + } + if (std::get<1>(shape_infos[i])) { + kernel_source << " const int64_t* " << name << "_strides,\n"; + } + if (std::get<2>(shape_infos[i])) { + kernel_source << " const int " << name << "_ndim,\n"; + } + } + } + + // Add outputs + for (size_t i = 0; i < output_names.size(); ++i) { + const auto& name = output_names[i]; + const auto& dtype = output_dtypes[i]; + kernel_source << " " << dtype_to_hip_type(dtype) << "* " << name; + if (i < output_names.size() - 1) { + kernel_source << ",\n"; + } else { + kernel_source << ") {\n"; + } + } + + // Set compile time constants + if (!template_args.empty()) { + for (const auto& [name, arg] : template_args) { + if (std::holds_alternative(arg)) { + kernel_source << " constexpr int " << name << " = " + << std::get(arg) << ";\n"; + } else if (std::holds_alternative(arg)) { + kernel_source << " constexpr bool " << name << " = " + << (std::get(arg) ? "true" : "false") << ";\n"; + } else { + kernel_source << " using " << name << " = " + << dtype_to_hip_type(std::get(arg)) << ";\n"; + } + } + kernel_source << "\n"; + } + + kernel_source << source; + kernel_source << "\n}\n\n} // namespace mlx::core::rocm\n"; + + return kernel_source.str(); +} + +} // namespace + +CustomKernelFunction hip_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header, + bool ensure_row_contiguous, + int shared_memory, + std::vector> output_input_aliases) { + if (output_names.empty()) { + throw std::invalid_argument( + "[custom_kernel] Must specify at least one output."); + } + + std::vector> shape_infos; + for (auto& n : input_names) { + std::tuple shape_info; + std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos; + std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos; + std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos; + shape_infos.push_back(shape_info); + } + + return [=, shape_infos = std::move(shape_infos)]( + const std::vector& inputs, + const std::vector& output_shapes, + const std::vector& output_dtypes, + std::tuple grid, + std::tuple threadgroup, + const std::vector>& + template_args = {}, + std::optional init_value = std::nullopt, + bool /*ensure_row_contiguous_unused*/ = false, + StreamOrDevice s_ = {}) { + if (inputs.size() != input_names.size()) { + std::ostringstream msg; + msg << "[custom_kernel] Expected `inputs` to have size " + << input_names.size() << " but got size " << inputs.size() << "." + << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_shapes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[custom_kernel] Expected `output_shapes` to have size " + << output_names.size() << " but got size " << output_shapes.size() + << "." << std::endl; + throw std::invalid_argument(msg.str()); + } + if (output_dtypes.size() != output_names.size()) { + std::ostringstream msg; + msg << "[custom_kernel] Expected `output_dtypes` to have size " + << output_names.size() << " but got size " << output_dtypes.size() + << "." << std::endl; + throw std::invalid_argument(msg.str()); + } + + auto s = to_stream(s_); + if (s.device != Device::gpu) { + throw std::invalid_argument("[custom_kernel] Only supports the GPU."); + } + + std::string kernel_name = + "custom_kernel_" + name + template_arguments_hash(template_args); + std::string kernel_source = build_kernel( + kernel_name, + header, + source, + input_names, + inputs, + output_names, + output_dtypes, + template_args, + shape_infos); + + return array::make_arrays( + std::move(output_shapes), + std::move(output_dtypes), + std::make_shared( + s, + std::move(kernel_name), + std::move(kernel_source), + grid, + threadgroup, + shape_infos, + ensure_row_contiguous, + init_value, + std::vector{}, + false, + shared_memory, + output_input_aliases), + std::move(inputs)); + }; +} + +void CustomKernel::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + std::vector copies; + + // Output index -> input index it aliases (reuses the buffer in place). + std::vector alias_of(outputs.size(), -1); + for (auto& [oi, ii] : output_input_aliases_) { + if (oi >= 0 && oi < (int)outputs.size() && ii >= 0 && ii < (int)inputs.size()) + alias_of[oi] = ii; + } + + // Allocate and initialize the output arrays + for (size_t i = 0; i < outputs.size(); ++i) { + auto& out = outputs[i]; + if (alias_of[i] >= 0) { + // In-place: output shares the aliased input's device buffer. + out.copy_shared_buffer(inputs[alias_of[i]]); + } else if (init_value_) { + copies.emplace_back(init_value_.value(), out.dtype()); + fill_gpu(copies.back(), out, s); + } else { + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + } + } + + // Create the input arrays and copy if needed + auto check_input = [&copies, &s, this](const array& x) -> const array { + bool no_copy = x.flags().row_contiguous; + if (!ensure_row_contiguous_ || no_copy) { + return x; + } else { + copies.push_back(array(x.shape(), x.dtype(), nullptr, {})); + copy_gpu(x, copies.back(), CopyType::General, s); + return copies.back(); + } + }; + std::vector checked_inputs; + for (const array& in : inputs) { + checked_inputs.push_back(check_input(in)); + } + + // Compile the custom kernel + std::string kernel_name = + (is_precompiled_) ? name_ : "mlx::core::rocm::" + name_; + rocm::JitModule& mod = rocm::get_jit_module( + s.device, + name_, + [&]() { + return std::make_tuple( + is_precompiled_, source_, std::vector{kernel_name}); + }, + false); + + // Build argument list using KernelArgs helper + rocm::KernelArgs args; + for (int i = 0; i < checked_inputs.size(); i++) { + const array& in = checked_inputs[i]; + auto& shape_info = shape_infos_[i]; + args.append(in); + if (std::get<0>(shape_info)) { + args.append_ndim(in.shape()); + } + if (std::get<1>(shape_info)) { + args.append_ndim(in.strides()); + } + if (std::get<2>(shape_info)) { + args.append(in.ndim()); + } + } + for (auto& out : outputs) { + args.append(out); + } + + // Make the grid + const auto [tx, ty, tz] = threadgroup_; + const auto [gx, gy, gz] = grid_; + dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz)); + dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz); + + // Set up arrays for kernel + for (const auto& in : checked_inputs) { + encoder.set_input_array(in); + } + for (const auto& out : outputs) { + encoder.set_output_array(out); + } + for (const auto& t : copies) { + encoder.add_temporary(t); + } + + // Launch kernel + encoder.launch_kernel([&](hipStream_t stream) { + auto kernel = mod.get_kernel(kernel_name); + + (void)hipModuleLaunchKernel( + kernel, + grid.x, + grid.y, + grid.z, + block.x, + block.y, + block.z, + shared_memory_, + stream, + args.args(), + nullptr); + }); +} + +} // namespace mlx::core::fast diff --git a/mlx/backend/rocm/device.cpp b/mlx/backend/rocm/device.cpp new file mode 100644 index 0000000000..1d08a15b04 --- /dev/null +++ b/mlx/backend/rocm/device.cpp @@ -0,0 +1,855 @@ +// Copyright © 2025 Apple Inc. + +#include +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/utils.h" +#include "mlx/backend/rocm/worker.h" +#include "mlx/utils.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace mlx::core::rocm { + +namespace { + +// Can be tuned with MLX_MAX_OPS_PER_BUFFER +constexpr int default_max_ops_per_buffer = 2000; + +inline bool is_empty_dim(dim3 dim) { + return (dim.x == 0 && dim.y == 0 && dim.z == 0) || + (dim.x == 1 && dim.y == 1 && dim.z == 1); +} + +} // namespace + +bool use_hip_graphs() { + static bool use_graphs = std::getenv("MLX_USE_HIP_GRAPHS") != nullptr; + return use_graphs; +} + +// Per-arch op/MB caps for the build graph. Tunable via env. +// The earlier "corrupts at >3 nodes" was actually one bad op (Concatenate, whose +// multi-copy kernels corrupt when co-grouped); it is now graph-split in +// gpu::eval (is_graph_split_op), so large graphs are correct again. +static std::pair get_graph_limits() { + int ops = env::max_ops_per_buffer(50); + int mb = env::max_mb_per_buffer(200); + return {ops, mb}; +} + +Device::Device(int device) : device_(device) { + make_current(); + { + hipDeviceProp_t p; + if (hipGetDeviceProperties(&p, device_) == hipSuccess) { + fprintf(stderr, "[mlx-rocm] bound HIP device %d: %s (%s)\n", + device_, p.gcnArchName, p.name); + if (p.sharedMemPerBlock > 0) { + max_shared_memory_per_block_ = static_cast(p.sharedMemPerBlock); + } + } + } + // rocBLAS initialization is now lazy - done in get_rocblas_handle() +} + +Device::~Device() { + if (rocblas_) { + rocblas_destroy_handle(rocblas_); + } +} + +rocblas_handle Device::get_rocblas_handle() { + if (!rocblas_initialized_) { + rocblas_initialized_ = true; + make_current(); + + // Check if the GPU architecture is supported by rocBLAS + hipDeviceProp_t props; + hipGetDeviceProperties(&props, device_); + std::string arch_name = props.gcnArchName; + + // List of architectures supported by rocBLAS (based on TensileLibrary + // files). These are the architectures that have TensileLibrary_lazy_*.dat. + static const std::vector supported_archs = { + "gfx908", + "gfx90a", + "gfx942", + "gfx950", + "gfx1030", + "gfx1100", + "gfx1101", + "gfx1102", + "gfx1150", + "gfx1151", + "gfx1152", + "gfx1200", + "gfx1201"}; + + // Extract base architecture name (remove any suffix like :sramecc+:xnack-) + std::string base_arch = arch_name; + size_t colon_pos = base_arch.find(':'); + if (colon_pos != std::string::npos) { + base_arch = base_arch.substr(0, colon_pos); + } + + bool arch_supported = false; + for (const auto& supported : supported_archs) { + if (base_arch == supported) { + arch_supported = true; + break; + } + } + + if (!arch_supported) { + rocblas_available_ = false; + rocblas_ = nullptr; + std::cerr << "Warning: rocBLAS does not support GPU architecture '" + << arch_name << "'. " + << "Matrix multiplication operations will not be available. " + << "Supported architectures: gfx908, gfx90a, gfx942, gfx950, " + << "gfx1030, gfx1100, gfx1101, gfx1102, gfx1150, gfx1151, " + << "gfx1200, gfx1201." << std::endl; + } else { + rocblas_status status = rocblas_create_handle(&rocblas_); + if (status != rocblas_status_success) { + rocblas_available_ = false; + rocblas_ = nullptr; + std::cerr + << "Warning: rocBLAS initialization failed (status " + << static_cast(status) + << "). Matrix multiplication operations will not be available." + << std::endl; + } + } + } + if (!rocblas_available_) { + throw std::runtime_error( + "rocBLAS is not available on this GPU architecture. " + "Matrix multiplication operations are not supported."); + } + return rocblas_; +} + +bool Device::is_rocblas_available() { + if (!rocblas_initialized_) { + try { + get_rocblas_handle(); + } catch (...) { + } + } + return rocblas_available_; +} + +bool Device::is_rocblas_bf16_available() { + if (!rocblas_bf16_probed_) { + rocblas_bf16_probed_ = true; + rocblas_bf16_available_ = false; + + if (!is_rocblas_available()) { + return false; + } + + // Probe: run a tiny bf16 GEMM and check if the GPU survives. + // rocBLAS may claim support but crash if the Tensile .co files + // are corrupt or missing specific kernel variants. + make_current(); + void* a_ptr = nullptr; + void* b_ptr = nullptr; + void* c_ptr = nullptr; + hipError_t err; + + err = hipMalloc(&a_ptr, 4 * 4 * 2); // 4x4 bf16 + if (err != hipSuccess) + return false; + err = hipMalloc(&b_ptr, 4 * 4 * 2); + if (err != hipSuccess) { + hipFree(a_ptr); + return false; + } + err = hipMalloc(&c_ptr, 4 * 4 * 2); + if (err != hipSuccess) { + hipFree(a_ptr); + hipFree(b_ptr); + return false; + } + + (void)hipMemset(a_ptr, 0, 4 * 4 * 2); + (void)hipMemset(b_ptr, 0, 4 * 4 * 2); + (void)hipMemset(c_ptr, 0, 4 * 4 * 2); + + float alpha = 1.0f, beta = 0.0f; + rocblas_status status = rocblas_gemm_ex( + rocblas_, + rocblas_operation_none, + rocblas_operation_none, + 4, + 4, + 4, + &alpha, + a_ptr, + rocblas_datatype_bf16_r, + 4, + b_ptr, + rocblas_datatype_bf16_r, + 4, + &beta, + c_ptr, + rocblas_datatype_bf16_r, + 4, + c_ptr, + rocblas_datatype_bf16_r, + 4, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, + 0, + 0); + + // Sync and check if the GPU is still alive + hipError_t sync_err = hipDeviceSynchronize(); + // Clear any lingering error + (void)hipGetLastError(); + + hipFree(a_ptr); + hipFree(b_ptr); + hipFree(c_ptr); + + if (status == rocblas_status_success && sync_err == hipSuccess) { + rocblas_bf16_available_ = true; + } else { + // GPU may be in a bad state — need to reset + (void)hipDeviceReset(); + // Re-initialize device + make_current(); + // Re-create rocBLAS handle + if (rocblas_) { + rocblas_destroy_handle(rocblas_); + rocblas_ = nullptr; + } + rocblas_status rs = rocblas_create_handle(&rocblas_); + if (rs != rocblas_status_success) { + rocblas_available_ = false; + } + std::cerr << "Warning: rocBLAS bfloat16 GEMM probe failed on this GPU. " + << "Using fallback kernels for bf16 matmul." << std::endl; + } + } + return rocblas_bf16_available_; +} + +bool Device::has_native_wmma() { + if (!wmma_probed_) { + wmma_probed_ = true; + + hipDeviceProp_t props; + if (hipGetDeviceProperties(&props, device_) != hipSuccess) { + has_native_wmma_ = false; + return has_native_wmma_; + } + + // Strip any ":sramecc+:xnack-" style suffix from gcnArchName. + std::string base_arch = props.gcnArchName; + size_t colon_pos = base_arch.find(':'); + if (colon_pos != std::string::npos) { + base_arch = base_arch.substr(0, colon_pos); + } + + // rocWMMA arch allowlist (AMD's official support matrix). Keep in sync + // with detect_rocm_hw_info() in mlx/backend/rocm/quantized/qmm.hip. + static const std::vector rocwmma_archs = { + "gfx908", + "gfx90a", + "gfx942", + "gfx1100", + "gfx1101", + "gfx1102", + "gfx1151", + "gfx1200", + "gfx1201", + }; + for (const auto& a : rocwmma_archs) { + if (base_arch == a) { + has_native_wmma_ = true; + break; + } + } + } + return has_native_wmma_; +} + +void Device::make_current() { + // HIP's current device is per-thread, so the cache must be too — a process + // global lets one thread's binding suppress another's, stranding allocations + // on the wrong device in a multi-GPU / multi-stream-thread run. + thread_local int current = -1; + if (current != device_) { + CHECK_HIP_ERROR(hipSetDevice(device_)); + current = device_; + } +} + +void Device::set_rocblas_stream(hipStream_t stream) { + if (rocblas_stream_ != stream) { + rocblas_set_stream(get_rocblas_handle(), stream); + rocblas_stream_ = stream; + } +} + +CommandEncoder& Device::get_command_encoder(Stream s) { + // Bind this device current before constructing/returning the encoder. Callers + // reach this member directly (e.g. QuantizedMatmul::eval_gpu), and the + // encoder's stream + the kernels launched on it must land on this device, not + // whatever was current on the calling thread. + make_current(); + auto it = encoders_.find(s.index); + if (it == encoders_.end()) { + auto [inserted_it, success] = + encoders_.emplace(s.index, std::make_unique(*this)); + it = inserted_it; + } + return *it->second; +} + +void Device::clear_encoders() { + encoders_.clear(); +} + +CommandEncoder::CommandEncoder(Device& d) + : device_(d), + stream_(d), + worker_(std::make_unique(d.hip_device())) { + std::tie(max_ops_per_graph_, max_mb_per_graph_) = get_graph_limits(); + if (use_hip_graphs()) { + device_.make_current(); + CHECK_HIP_ERROR(hipGraphCreate(&build_graph_, 0)); + set_graph_active(true); + } +} + +CommandEncoder::~CommandEncoder() { + if (build_graph_) { + hipGraphDestroy(build_graph_); + build_graph_ = nullptr; + } +} + +void CommandEncoder::add_temporary(const array& arr) { + auto data = arr.data_shared_ptr(); + const array::Data* ptr = data.get(); + if (temporary_ptrs_.insert(ptr).second) { + temporaries_.push_back(std::move(data)); + } +} + +void CommandEncoder::add_completed_handler(std::function task) { + worker_->add_task(std::move(task)); +} + +void CommandEncoder::set_input_array(const array& arr) { + if (!use_hip_graphs()) { + return; + } + bytes_in_graph_ += arr.data_size(); + auto id = reinterpret_cast(arr.buffer().ptr()); + active_deps_.push_back(id); +} + +void CommandEncoder::set_output_array(const array& arr) { + if (!use_hip_graphs()) { + return; + } + auto id = reinterpret_cast(arr.buffer().ptr()); + active_deps_.push_back(id); + active_outputs_.push_back(id); +} + +void CommandEncoder::insert_graph_dependencies(GraphNode node) { + node.id = std::to_string(node_count_++); + std::vector nodes; + nodes.push_back(std::move(node)); + insert_graph_dependencies(std::move(nodes)); +} + +void CommandEncoder::insert_graph_dependencies(std::vector nodes) { + // Serialize the graph into a linear chain in submission order. This matches + // eager single-stream execution order exactly (correct), while still + // collapsing all kernels into one hipGraphLaunch (the batching win). The + // dep-based edges from set_input/output_array were unreliable because not + // every migrated kernel registers all of its inputs/outputs, leaving missing + // edges and races; a linear chain is robust and costs nothing over eager + // (which is already serial on the stream). + active_deps_.clear(); + active_outputs_.clear(); + for (auto& node : nodes) { + graph_nodes_key_ += node.node_type; + graph_nodes_key_ += "-"; + if (last_node_ != nullptr) { + from_nodes_.push_back(last_node_); + to_nodes_.push_back(node.node); + } + last_node_ = node.node; + } +} + +void CommandEncoder::add_kernel_node_raw( + void* func, + dim3 grid_dim, + dim3 block_dim, + uint32_t smem_bytes, + void** params) { + if (!use_hip_graphs()) { + device_.make_current(); + CHECK_HIP_ERROR(hipLaunchKernel( + func, grid_dim, block_dim, params, smem_bytes, stream_)); + node_count_++; + return; + } + + hipKernelNodeParams kernel_params = {}; + kernel_params.func = func; + kernel_params.gridDim = grid_dim; + kernel_params.blockDim = block_dim; + kernel_params.kernelParams = params; + kernel_params.sharedMemBytes = smem_bytes; + hipGraphNode_t node; + CHECK_HIP_ERROR( + hipGraphAddKernelNode(&node, build_graph_, nullptr, 0, &kernel_params)); + // Key the node by its kernel FUNCTION (+ launch dims), not just "K": the exec + // cache is reused via hipGraphExecUpdate only on a matching key, and update + // can only re-point params of an IDENTICAL kernel sequence. A type-only key + // collides distinct kernels and reuses the wrong exec -> garbage output. + std::string key = "K"; + key += std::to_string(reinterpret_cast(func)); + key += "_"; + key += std::to_string(grid_dim.x * grid_dim.y * grid_dim.z); + key += "x"; + key += std::to_string(block_dim.x * block_dim.y * block_dim.z); + insert_graph_dependencies(GraphNode{node, key}); +} + +void CommandEncoder::add_child_graph_node( + hipGraph_t child, + const std::string& key) { + hipGraphNode_t node; + CHECK_HIP_ERROR( + hipGraphAddChildGraphNode(&node, build_graph_, nullptr, 0, child)); + insert_graph_dependencies(GraphNode{node, key}); +} + +void CommandEncoder::maybe_commit() { + if (use_hip_graphs()) { + if (needs_commit()) { + commit(); + } + return; + } + if (node_count_ >= env::max_ops_per_buffer(default_max_ops_per_buffer)) { + commit(); + } +} + +bool CommandEncoder::needs_commit() { + if (!use_hip_graphs()) { + return node_count_ >= env::max_ops_per_buffer(default_max_ops_per_buffer); + } + return (node_count_ > max_ops_per_graph_) || + ((bytes_in_graph_ >> 20) > static_cast(max_mb_per_graph_)); +} + +void CommandEncoder::commit() { + // During graph capture, record ONLY the compute kernels into the graph. The + // host-function completion callbacks (which release temporaries) are not + // executed under stream capture and would otherwise be baked into the graph + // as host nodes that fire on every replay. Temporaries are arena-backed while + // capturing (the arena is freed in bulk on end, never per-buffer), so we can + // simply drop our references without scheduling a cleanup task. NOTE: this + // relies on the DecodeArena being active during capture — otherwise dropping + // these refs would return live buffers to the pool while later recorded + // kernels still reference them. + if (capturing_) { + // Keep capture-time buffers alive (unique, stable addresses) until the + // graph is destroyed — do NOT free them (which would alias graph nodes) and + // do NOT schedule a host-function completion (it can't fire under capture). + for (auto& d : temporaries_) + capture_held_.push_back(std::move(d)); + temporaries_.clear(); + temporary_ptrs_.clear(); + node_count_ = 0; + return; + } + + if (!temporaries_.empty()) { + add_completed_handler([temporaries = std::move(temporaries_)]() {}); + } + temporary_ptrs_.clear(); + + if (use_hip_graphs() && node_count_ > 0) { + if (!from_nodes_.empty()) { + CHECK_HIP_ERROR(hipGraphAddDependencies( + build_graph_, + from_nodes_.data(), + to_nodes_.data(), + from_nodes_.size())); + } + + device_.make_current(); + + static const bool reuse = std::getenv("MLX_GRAPH_REUSE") != nullptr; + hipGraphExec_t graph_exec = nullptr; + bool cached_exec = false; + if (reuse) { + // Reuse the exec for an identical kernel sequence via hipGraphExecUpdate + // (the key includes each kernel's func ptr + dims, so no cross-sequence + // mis-reuse). Avoids a fresh hipGraphInstantiate every commit. + size_t key = std::hash{}(graph_nodes_key_); + auto c = graph_cache_.get(key); + if (c) { + graph_exec = *c; + hipGraphExecUpdateResult ur; + hipGraphNode_t en; + if (hipGraphExecUpdate(graph_exec, build_graph_, &en, &ur) == + hipSuccess && + ur == hipGraphExecUpdateSuccess) { + cached_exec = true; + } else { + (void)hipGetLastError(); + graph_exec = nullptr; + } + } + if (graph_exec == nullptr) { + CHECK_HIP_ERROR( + hipGraphInstantiate(&graph_exec, build_graph_, nullptr, nullptr, 0)); + graph_cache_.put(key, graph_exec); + cached_exec = true; + } + } else { + CHECK_HIP_ERROR( + hipGraphInstantiate(&graph_exec, build_graph_, nullptr, nullptr, 0)); + } + + static const bool dump = std::getenv("MLX_HIP_GRAPH_DUMP") != nullptr; + if (dump) { + size_t n = 0; + hipGraphGetNodes(build_graph_, nullptr, &n); + std::vector nodes(n); + hipGraphGetNodes(build_graph_, nodes.data(), &n); + size_t nedges = 0; + hipGraphGetEdges(build_graph_, nullptr, nullptr, &nedges); + int k = 0, mcpy = 0, mset = 0, host = 0, child = 0, empty = 0, malloc_n = 0, + free_n = 0, other = 0; + for (auto nd : nodes) { + hipGraphNodeType t; + if (hipGraphNodeGetType(nd, &t) != hipSuccess) { other++; continue; } + switch (t) { + case hipGraphNodeTypeKernel: k++; break; + case hipGraphNodeTypeMemcpy: mcpy++; break; + case hipGraphNodeTypeMemset: mset++; break; + case hipGraphNodeTypeHost: host++; break; + case hipGraphNodeTypeGraph: child++; break; + case hipGraphNodeTypeEmpty: empty++; break; + case hipGraphNodeTypeMemAlloc: malloc_n++; break; + case hipGraphNodeTypeMemFree: free_n++; break; + default: other++; break; + } + } + fprintf(stderr, + "[graph] nodes=%zu edges=%zu kernel=%d memcpy=%d memset=%d " + "host=%d child=%d empty=%d memAlloc=%d memFree=%d other=%d\n", + n, nedges, k, mcpy, mset, host, child, empty, malloc_n, free_n, + other); + static int dn = 0; + char path[64]; + snprintf(path, sizeof(path), "/tmp/hipgraph_%d.dot", dn++); + hipGraphDebugDotPrint(build_graph_, path, 0); + fprintf(stderr, "[graph] dot -> %s\n", path); + } + + CHECK_HIP_ERROR(hipGraphLaunch(graph_exec, stream_)); + // Destroy the exec once its (async) launch completes (completion handler + // fires after the stream passes this commit) — unless it's cached for reuse. + if (!cached_exec) + add_completed_handler([graph_exec]() { hipGraphExecDestroy(graph_exec); }); + + // Reset build state for the next chunk. + from_nodes_.clear(); + to_nodes_.clear(); + graph_nodes_key_.clear(); + graph_deps_key_.clear(); + node_map_.clear(); + active_deps_.clear(); + active_outputs_.clear(); + bytes_in_graph_ = 0; + last_node_ = nullptr; + hipGraphDestroy(build_graph_); + CHECK_HIP_ERROR(hipGraphCreate(&build_graph_, 0)); + // NOTE: do NOT free graph_node_args_ here. hipGraphLaunch is async and the + // exec references the kernelParams until the stream drains. They are freed + // in synchronize() once the stream is idle. + } + + node_count_ = 0; + + // Put completion handlers in a batch. + worker_->commit(stream_); +} + +void CommandEncoder::synchronize() { + // A capturing stream cannot be synchronized, and there is nothing to wait for + // — recorded kernels do not execute until the captured graph is replayed. + if (capturing_) { + return; + } + (void)hipStreamSynchronize(stream_); + auto p = std::make_shared>(); + std::future f = p->get_future(); + add_completed_handler([p = std::move(p)]() { p->set_value(); }); + commit(); + f.wait(); + (void)hipStreamSynchronize(stream_); + // Stream is fully drained; graph execs are done and no longer reference the + // kernelParams. Destroy the retained execs and release the arg packs. + graph_node_args_.clear(); + graph_node_args_prev_.clear(); + if (use_hip_graphs()) flush_graph_deferred_frees(); +} + +// Global flag: true while any stream on this process is recording a HIP graph. +// Lazy library inits (e.g. hipblasLtCreate) abort the process if first called +// during capture, so they consult this to defer to a non-capturing path. +std::atomic g_stream_capturing{false}; +bool stream_capturing() { + return g_stream_capturing.load(std::memory_order_relaxed); +} +void set_stream_capturing(bool v) { + g_stream_capturing.store(v, std::memory_order_relaxed); +} + +std::atomic g_graph_active{false}; +bool graph_active() { + return g_graph_active.load(std::memory_order_relaxed); +} +void set_graph_active(bool v) { + g_graph_active.store(v, std::memory_order_relaxed); +} + +void CommandEncoder::begin_capture() { + if (capturing_) + return; + g_stream_capturing.store(true, std::memory_order_relaxed); + g_graph_active.store(true, std::memory_order_relaxed); + device_.make_current(); + // hipStreamBeginCapture records all subsequent operations on this stream + // into a graph instead of executing them. Use ThreadLocal (not Global) mode + // so only THIS thread's stream activity is captured — the Worker thread may + // still be running completion/free callbacks from prior eager steps, and + // capturing those cross-thread ops bakes spurious nodes into the graph that + // hang on replay. + hipError_t err = + hipStreamBeginCapture(stream_, hipStreamCaptureModeThreadLocal); + if (err == hipSuccess) { + capturing_ = true; + } +} + +bool CommandEncoder::end_capture() { + if (!capturing_) + return false; + capturing_ = false; + g_stream_capturing.store(false, std::memory_order_relaxed); + + hipGraph_t new_graph = nullptr; + hipError_t err = hipStreamEndCapture(stream_, &new_graph); + if (err != hipSuccess || new_graph == nullptr) { + return false; + } + + // Destroy previous graph if any + reset_graph(); + + graph_ = new_graph; + + // Patch host->device constant-upload memcpy nodes. Stream capture records + // these with the HOST source pointer, but those host buffers are freed before + // replay, so on replay the H2D copy reads stale host memory and stalls the + // GPU queue. While the host data is still valid (right after capture), copy + // each into a persistent device staging buffer and rewrite the node as + // device->device so replay reads valid device memory. The staging buffers are + // intentionally leaked for the lifetime of the graph. + { + size_t n = 0; + hipGraphGetNodes(graph_, nullptr, &n); + std::vector nodes(n); + hipGraphGetNodes(graph_, nodes.data(), &n); + for (size_t i = 0; i < n; i++) { + hipGraphNodeType t; + if (hipGraphNodeGetType(nodes[i], &t) != hipSuccess || + t != hipGraphNodeTypeMemcpy) + continue; + hipMemcpy3DParms p{}; + if (hipGraphMemcpyNodeGetParams(nodes[i], &p) != hipSuccess) + continue; + if (p.kind != hipMemcpyHostToDevice) + continue; + size_t bytes = p.extent.width * std::max(p.extent.height, 1) * + std::max(p.extent.depth, 1); + if (bytes == 0 || p.srcPtr.ptr == nullptr) + continue; + void* stage = nullptr; + if (hipMalloc(&stage, bytes) != hipSuccess) + continue; + // Copy the host constant into the staging buffer now (host source is still + // valid right after capture) and rewrite the node as device->device. + if (hipMemcpy(stage, p.srcPtr.ptr, bytes, hipMemcpyHostToDevice) != + hipSuccess) { + hipFree(stage); + continue; + } + p.srcPtr = make_hipPitchedPtr(stage, p.srcPtr.pitch ? p.srcPtr.pitch : bytes, + p.extent.width, std::max(p.extent.height, 1)); + p.kind = hipMemcpyDeviceToDevice; + (void)hipGraphMemcpyNodeSetParams(nodes[i], &p); + } + } + + static const bool dbg = std::getenv("MLX_GRAPH_DEBUG") != nullptr; + if (dbg) { + size_t n = 0; + hipGraphGetNodes(graph_, nullptr, &n); + std::vector nodes(n); + hipGraphGetNodes(graph_, nodes.data(), &n); + int kKernel = 0, kMemcpy = 0, kMemset = 0, kHost = 0, kEmpty = 0, + kWaitEvent = 0, kEventRecord = 0, kMemAlloc = 0, kMemFree = 0, kOther = 0; + for (size_t i = 0; i < n; i++) { + hipGraphNodeType t; + if (hipGraphNodeGetType(nodes[i], &t) != hipSuccess) { kOther++; continue; } + switch (t) { + case hipGraphNodeTypeKernel: kKernel++; break; + case hipGraphNodeTypeMemcpy: kMemcpy++; break; + case hipGraphNodeTypeMemset: kMemset++; break; + case hipGraphNodeTypeHost: kHost++; break; + case hipGraphNodeTypeEmpty: kEmpty++; break; + case hipGraphNodeTypeWaitEvent: kWaitEvent++; break; + case hipGraphNodeTypeEventRecord: kEventRecord++; break; + case hipGraphNodeTypeMemAlloc: kMemAlloc++; break; + case hipGraphNodeTypeMemFree: kMemFree++; break; + default: kOther++; break; + } + } + fprintf(stderr, + "[capture] nodes=%zu kernel=%d memcpy=%d memset=%d host=%d empty=%d " + "waitEvent=%d eventRecord=%d memAlloc=%d memFree=%d other=%d\n", + n, kKernel, kMemcpy, kMemset, kHost, kEmpty, kWaitEvent, + kEventRecord, kMemAlloc, kMemFree, kOther); + // Inspect memcpy nodes — host->device copies with a stale host source would + // fault/stall on replay. + for (size_t i = 0; i < n; i++) { + hipGraphNodeType t; + if (hipGraphNodeGetType(nodes[i], &t) != hipSuccess || + t != hipGraphNodeTypeMemcpy) + continue; + hipMemcpy3DParms p{}; + if (hipGraphMemcpyNodeGetParams(nodes[i], &p) == hipSuccess) { + fprintf(stderr, "[capture] memcpy kind=%d bytes=%zu\n", (int)p.kind, + p.extent.width * p.extent.height * p.extent.depth); + } + } + } + + err = hipGraphInstantiate(&graph_exec_, graph_, nullptr, nullptr, 0); + if (err != hipSuccess) { + hipGraphDestroy(graph_); + graph_ = nullptr; + graph_exec_ = nullptr; + return false; + } + return true; +} + +bool CommandEncoder::replay(bool sync) { + if (!graph_exec_) + return false; + device_.make_current(); + static const bool dbg = std::getenv("MLX_GRAPH_DEBUG") != nullptr; + if (dbg) fprintf(stderr, "[replay] launching graph (sync=%d)...\n", (int)sync); + hipError_t err = hipGraphLaunch(graph_exec_, stream_); + if (dbg) fprintf(stderr, "[replay] launch returned %d (%s)\n", + (int)err, hipGetErrorString(err)); + if (err != hipSuccess) + return false; + // The captured kernels run asynchronously on stream_. The completion Events + // that eval() would normally wait on were skipped during capture. When sync + // is requested, wait here for the replayed work to finish before the caller + // reads outputs. When async, the caller orders its output reads after this + // launch on the SAME stream (subsequent MLX eval on the generation stream), + // so no drain is needed and per-token work can pipeline. + if (!sync) + return true; + err = hipStreamSynchronize(stream_); + if (dbg) fprintf(stderr, "[replay] sync returned %d (%s)\n", + (int)err, hipGetErrorString(err)); + return err == hipSuccess; +} + +void CommandEncoder::reset_graph() { + if (graph_exec_) { + hipGraphExecDestroy(graph_exec_); + graph_exec_ = nullptr; + } + if (graph_) { + hipGraphDestroy(graph_); + graph_ = nullptr; + } + // The captured graph is gone — release the buffers it referenced. + capture_held_.clear(); + g_graph_active.store(false, std::memory_order_relaxed); + flush_graph_deferred_frees(); +} + +std::unordered_map& get_devices() { + static std::unordered_map devices; + return devices; +} + +Device& device(mlx::core::Device device) { + auto& devices = get_devices(); + auto it = devices.find(device.index); + if (it == devices.end()) { + // Set blocking sync flags on THIS device (per index, not a single global + // bool: if device 0 were touched first the global gate would leave device 1 + // unflagged). Must happen while this device is current and before its + // context is created — i.e. before the Device is constructed. Iterating every + // device would create a context/queue on the other GPU too; on a multi-GPU + // host that cross-device coexistence is what wedges the discrete GPU's queue + // over a TB5 link, so touch only this device. + hipSetDevice(device.index); + hipSetDeviceFlags(hipDeviceScheduleBlockingSync); + it = devices.try_emplace(device.index, device.index).first; + } + return it->second; +} + +CommandEncoder& get_command_encoder(Stream s) { + // Bind the HIP current device to this stream's device. HIP's current device is + // per-thread; everything that touches a stream goes through here (eval, kernel + // launches, event record/wait, commit, completion callbacks). Without binding, + // operations for a non-default GPU (--device 1) execute against device 0 — the + // stream/event/kernel land on the wrong device and the queue hangs. With + // HIP_VISIBLE_DEVICES the only device IS index 0 so the bug is hidden. + auto& d = device(s.device); + d.make_current(); + return d.get_command_encoder(s); +} + +void clear_all_encoders() { + auto& devices = get_devices(); + for (auto& [idx, dev] : devices) { + dev.clear_encoders(); + } +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device.h b/mlx/backend/rocm/device.h new file mode 100644 index 0000000000..5b77b9a325 --- /dev/null +++ b/mlx/backend/rocm/device.h @@ -0,0 +1,326 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/rocm/lru_cache.h" +#include "mlx/backend/rocm/utils.h" +#include "mlx/stream.h" + +#include +#include + +// Only include thrust headers when compiling with HIP compiler +// (thrust headers have dependencies on CUDA/HIP-specific headers) +#ifdef __HIPCC__ +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mlx::core::rocm { + +// Forward declaration +class Device; +class Worker; + +// Gate for the automatic HIP-graph batching path. Default OFF so the legacy +// immediate-launch path is unaffected unless MLX_USE_HIP_GRAPHS is set. +bool use_hip_graphs(); + +class CommandEncoder { + public: + explicit CommandEncoder(Device& d); + ~CommandEncoder(); + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + + void set_input_array(const array& arr); + void set_output_array(const array& arr); + + template + void launch_kernel(F&& func); + + template + void add_kernel_node( + Func* func, + dim3 grid_dim, + dim3 block_dim, + uint32_t smem_bytes, + Params&&... params) { + add_kernel_node_ex(func, grid_dim, block_dim, smem_bytes, params...); + } + + template + void add_kernel_node_ex( + Func* func, + dim3 grid_dim, + dim3 block_dim, + uint32_t smem_bytes, + Params&&... params) { + constexpr size_t num = sizeof...(Params); + if (!use_hip_graphs()) { + // Immediate launch: kernelParams are consumed synchronously, so + // addresses of the caller's locals are fine. + void* ptrs[num > 0 ? num : 1]; + size_t i = 0; + ([&](auto&& p) { + ptrs[i++] = + const_cast(static_cast(std::addressof(p))); + }(std::forward(params)), + ...); + add_kernel_node_raw( + reinterpret_cast(func), grid_dim, block_dim, smem_bytes, ptrs); + return; + } + // Graph build: a HIP graph kernel node references its kernelParams until the + // node is instantiated/updated into the exec graph, which happens later in + // commit(). The caller's argument locals are gone by then, so copy the + // argument VALUES (and the pointer array) into a heap pack kept alive until + // commit() finishes (cleared there). + struct Pack { + std::tuple...> vals; + std::array 0 ? num : 1)> ptrs; + }; + auto pack = std::make_shared(); + pack->vals = std::tuple...>( + std::forward(params)...); + fill_param_ptrs(pack->vals, pack->ptrs, std::index_sequence_for{}); + graph_node_args_.push_back(pack); + add_kernel_node_raw( + reinterpret_cast(func), + grid_dim, + block_dim, + smem_bytes, + pack->ptrs.data()); + } + + template + static void + fill_param_ptrs(Tuple& vals, Arr& ptrs, std::index_sequence) { + ((ptrs[I] = const_cast( + static_cast(std::addressof(std::get(vals))))), + ...); + } + + void add_kernel_node_raw( + void* func, + dim3 grid_dim, + dim3 block_dim, + uint32_t smem_bytes, + void** params); + + void add_temporary(const array& arr); + + void add_completed_handler(std::function task); + void maybe_commit(); + bool needs_commit(); + void commit(); + + Device& device() { + return device_; + } + + HipStream& stream() { + return stream_; + } + + // Wait until kernels and completion handlers are finished + void synchronize(); + + // --- Graph capture API --- + // Begin recording all kernel launches into a HIP graph. + // While capturing, launch_kernel dispatches are recorded (not executed). + void begin_capture(); + + // End recording and instantiate the captured graph. + // Returns true if capture succeeded (graph is ready to replay). + bool end_capture(); + + // Replay the previously captured graph. All recorded kernels execute + // in a single GPU dispatch. Returns false if no graph is available. + // If sync is true (default) the call blocks until the replayed work + // finishes. If false it only launches the graph onto the stream and + // returns immediately — the caller must order any reads of the graph's + // outputs after it on the SAME stream (subsequent MLX eval on the + // generation stream does exactly this), which lets per-token sampling + // pipeline instead of draining the GPU every token. + bool replay(bool sync = true); + + // Returns true if a captured graph is ready to replay. + bool has_graph() const { + return graph_exec_ != nullptr; + } + + // True while this encoder's stream is recording into a HIP graph. Used by the + // Event layer to avoid recording completion events onto the captured stream + // (they would be baked into the graph and never fire, deadlocking eval). + bool capturing() const { + return capturing_; + } + + // Discard the captured graph. + void reset_graph(); + + private: + struct GraphNode { + hipGraphNode_t node; + // K = kernel, E = empty, () = subgraph + std::string node_type; + std::string id; + }; + + void insert_graph_dependencies(GraphNode node); + void insert_graph_dependencies(std::vector nodes); + void add_child_graph_node(hipGraph_t child, const std::string& key); + + Device& device_; + HipStream stream_; + std::unique_ptr worker_; + int node_count_{0}; + std::vector> temporaries_; + std::unordered_set temporary_ptrs_; + bool capturing_{false}; + + // --- Automatic graph-batching state (mirrors CUDA CommandEncoder) --- + hipGraph_t build_graph_{nullptr}; + std::vector from_nodes_; + std::vector to_nodes_; + hipGraphNode_t last_node_{nullptr}; + std::string graph_nodes_key_; + std::string graph_deps_key_; + std::vector active_deps_; + std::vector active_outputs_; + std::unordered_map node_map_; + size_t bytes_in_graph_{0}; + int max_ops_per_graph_{50}; + int max_mb_per_graph_{200}; + LRUCache graph_cache_{400}; + // Per-build kernel-arg packs: keep the kernelParams values alive while the + // (async) exec may reference them. Held one extra commit via _prev_. + std::vector> graph_node_args_; + std::vector> graph_node_args_prev_; + // Instantiated execs retained until the stream drains (destroyed in + // synchronize()), since hipGraphLaunch is async. + std::vector graph_execs_; + // Buffers allocated during capture are held alive here (not freed) so their + // addresses stay valid and unique for the lifetime of the captured graph — + // freeing them mid-capture would let later allocations reuse the same + // address, aliasing distinct graph nodes. Released in reset_graph(). + std::vector> capture_held_; + hipGraph_t graph_{nullptr}; + hipGraphExec_t graph_exec_{nullptr}; +}; + +class Device { + public: + explicit Device(int device); + ~Device(); + + Device(const Device&) = delete; + Device& operator=(const Device&) = delete; + + // Make this device the current HIP device, required by some HIP calls. + void make_current(); + + CommandEncoder& get_command_encoder(Stream s); + void clear_encoders(); + + int hip_device() const { + return device_; + } + + rocblas_handle get_rocblas_handle(); + void set_rocblas_stream(hipStream_t stream); + + // Check if rocBLAS is available for the current GPU architecture + bool is_rocblas_available(); + + // Check if rocBLAS bf16 GEMM works on this device (probed at init) + bool is_rocblas_bf16_available(); + + // True iff this device's gcnArchName is on the rocWMMA arch allowlist + // (CDNA1/2/3 + RDNA3 dGPU + gfx1151 + RDNA4). Lazy-cached on first call. + bool has_native_wmma(); + + // Max shared memory (LDS) a single block may use on this device, in bytes, + // queried from hipDeviceProp at construction. RDNA3/3.5 report 64 KB; RDNA4 + // and CDNA may report more. Kernels that size LDS tiles must read this from + // the device actually running the op rather than assume a fixed budget. + int max_shared_memory_per_block() const { + return max_shared_memory_per_block_; + } + + private: + int device_; + rocblas_handle rocblas_{nullptr}; + hipStream_t rocblas_stream_{nullptr}; + bool rocblas_initialized_{false}; + bool rocblas_available_{true}; + bool rocblas_bf16_probed_{false}; + bool rocblas_bf16_available_{false}; + bool wmma_probed_{false}; + bool has_native_wmma_{false}; + int max_shared_memory_per_block_{65536}; + std::unordered_map> encoders_; +}; + +Device& device(mlx::core::Device device); +CommandEncoder& get_command_encoder(Stream s); +void clear_all_encoders(); + +// True while a HIP graph capture is in progress on any stream. Lazy library +// inits that abort under capture (e.g. hipblasLtCreate) check this. +bool stream_capturing(); +void set_stream_capturing(bool v); +void set_graph_active(bool v); + +// True from capture start until the captured graph is destroyed. The allocator +// defers all frees while set so graph-referenced buffers stay valid through replay. +bool graph_active(); +void flush_graph_deferred_frees(); + +// Return an execution policy that does not sync for result. +// Only available when compiling with HIP compiler +#ifdef __HIPCC__ +inline auto thrust_policy(hipStream_t stream) { + return thrust::hip::par.on(stream); +} +#endif + +// Template implementation (must be after Device is defined) +template +void CommandEncoder::launch_kernel(F&& func) { + device_.make_current(); + // Under the automatic graph-batching path, capture this lambda's launches + // into a child graph node so the build graph stays complete while individual + // kernels are migrated to add_kernel_node. The legacy whole-stream capture + // path (capturing_) and the immediate path are left untouched. + // Residual ops not migrated to add_kernel_node (library GEMM, JIT module + // kernels, memsets) can't be HIP graph kernel nodes (no module-func field) + // and child-graph capture wedges the GPU on this ROCm. Instead graph-split: + // flush+launch the accumulated graph, then run this op immediately on the + // same stream (ordered after the graph), and the next op starts a fresh + // graph. Library GEMM thus runs OUTSIDE capture, so hipBLASLt won't abort. + if (use_hip_graphs() && !capturing_) { + commit(); + func(static_cast(stream_)); + return; + } + // When the legacy path is capturing, kernel launches are recorded into the + // HIP graph automatically. Otherwise hipLaunchKernel executes immediately. + func(static_cast(stream_)); + node_count_++; +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/arange.hpp b/mlx/backend/rocm/device/arange.hpp new file mode 100644 index 0000000000..e33a65a790 --- /dev/null +++ b/mlx/backend/rocm/device/arange.hpp @@ -0,0 +1,17 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core::rocm { + +template +__global__ void arange_kernel(T* out, T start, T step, size_t size) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + out[idx] = start + static_cast(idx) * step; + } +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/atomic_ops.hpp b/mlx/backend/rocm/device/atomic_ops.hpp new file mode 100644 index 0000000000..970a515dec --- /dev/null +++ b/mlx/backend/rocm/device/atomic_ops.hpp @@ -0,0 +1,302 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +namespace mlx::core::rocm { + +// Generic atomic reduce using CAS loop +template +__device__ void atomic_reduce(T* addr, T val) { + Op op; + T old = *addr; + T assumed; + do { + assumed = old; + T new_val = op(assumed, val); + old = atomicCAS(addr, assumed, new_val); + } while (old != assumed); +} + +// Atomic add for various types +template +__device__ void atomic_add(T* addr, T val) { + atomicAdd(addr, val); +} + +// Specialization for float +template <> +__device__ inline void atomic_add(float* addr, float val) { + atomicAdd(addr, val); +} + +// Specialization for double +template <> +__device__ inline void atomic_add(double* addr, double val) { + atomicAdd(addr, val); +} + +// Specialization for int +template <> +__device__ inline void atomic_add(int* addr, int val) { + atomicAdd(addr, val); +} + +// Specialization for unsigned int +template <> +__device__ inline void atomic_add( + unsigned int* addr, + unsigned int val) { + atomicAdd(addr, val); +} + +// Specialization for unsigned long long +template <> +__device__ inline void atomic_add( + unsigned long long* addr, + unsigned long long val) { + atomicAdd(addr, val); +} + +// Specialization for int64_t (maps to long long on most platforms) +template <> +__device__ inline void atomic_add(long long* addr, long long val) { + atomicAdd( + reinterpret_cast(addr), + static_cast(val)); +} + +// CAS-based atomic add for unsupported types +template +__device__ void atomic_add_general(T* addr, T val) { + // Use CAS loop for types without native atomic support + T old = *addr; + T assumed; + do { + assumed = old; + T new_val = assumed + val; + // Reinterpret as unsigned int for CAS + unsigned int* addr_as_uint = reinterpret_cast(addr); + unsigned int old_as_uint = + __float_as_uint(*reinterpret_cast(&assumed)); + unsigned int new_as_uint = + __float_as_uint(*reinterpret_cast(&new_val)); + unsigned int result = atomicCAS(addr_as_uint, old_as_uint, new_as_uint); + old = *reinterpret_cast(&result); + } while (old != assumed); +} + +// Specialization for __half using CAS +template <> +__device__ inline void atomic_add<__half>(__half* addr, __half val) { + // Use 32-bit CAS for half precision + unsigned int* addr_as_uint = reinterpret_cast( + reinterpret_cast(addr) & ~size_t(0x3)); + unsigned int shift = (reinterpret_cast(addr) & 0x2) ? 16 : 0; + + unsigned int old = *addr_as_uint; + unsigned int assumed; + do { + assumed = old; + __half old_half = __ushort_as_half((assumed >> shift) & 0xFFFF); + __half new_half = __hadd(old_half, val); + unsigned int new_val = + (assumed & ~(0xFFFF << shift)) | (__half_as_ushort(new_half) << shift); + old = atomicCAS(addr_as_uint, assumed, new_val); + } while (old != assumed); +} + +// Specialization for hip_bfloat16 using CAS +template <> +__device__ inline void atomic_add( + hip_bfloat16* addr, + hip_bfloat16 val) { + // Use 32-bit CAS for bfloat16 + unsigned int* addr_as_uint = reinterpret_cast( + reinterpret_cast(addr) & ~size_t(0x3)); + unsigned int shift = (reinterpret_cast(addr) & 0x2) ? 16 : 0; + + unsigned int old = *addr_as_uint; + unsigned int assumed; + do { + assumed = old; + hip_bfloat16 old_bf16; + old_bf16.data = (assumed >> shift) & 0xFFFF; + hip_bfloat16 new_bf16 = + hip_bfloat16(static_cast(old_bf16) + static_cast(val)); + unsigned int new_val = + (assumed & ~(0xFFFF << shift)) | (new_bf16.data << shift); + old = atomicCAS(addr_as_uint, assumed, new_val); + } while (old != assumed); +} + +// Specialization for hipFloatComplex using CAS +template <> +__device__ inline void atomic_add( + hipFloatComplex* addr, + hipFloatComplex val) { + // Atomic add for real and imaginary parts separately + atomic_add(&(addr->x), val.x); + atomic_add(&(addr->y), val.y); +} + +// Atomic product using CAS loop +template +__device__ void atomic_prod(T* addr, T val) { + T old = *addr; + T assumed; + do { + assumed = old; + T new_val = assumed * val; + old = atomicCAS(addr, assumed, new_val); + } while (old != assumed); +} + +// Specialization for float +template <> +__device__ inline void atomic_prod(float* addr, float val) { + unsigned int* addr_as_uint = reinterpret_cast(addr); + unsigned int old = *addr_as_uint; + unsigned int assumed; + do { + assumed = old; + float old_float = __uint_as_float(assumed); + float new_float = old_float * val; + old = atomicCAS(addr_as_uint, assumed, __float_as_uint(new_float)); + } while (old != assumed); +} + +// Specialization for double +template <> +__device__ inline void atomic_prod(double* addr, double val) { + unsigned long long* addr_as_ull = reinterpret_cast(addr); + unsigned long long old = *addr_as_ull; + unsigned long long assumed; + do { + assumed = old; + double old_double = __longlong_as_double(assumed); + double new_double = old_double * val; + old = atomicCAS(addr_as_ull, assumed, __double_as_longlong(new_double)); + } while (old != assumed); +} + +// Atomic max for various types +template +__device__ void atomic_max(T* addr, T val) { + atomicMax(addr, val); +} + +// Specialization for float using CAS +template <> +__device__ inline void atomic_max(float* addr, float val) { + if (val < 0.0f) { + // For negative values, use integer atomicMin on the bit representation + int* addr_as_int = reinterpret_cast(addr); + atomicMin(addr_as_int, __float_as_int(val)); + } else { + // For non-negative values, use integer atomicMax + unsigned int* addr_as_uint = reinterpret_cast(addr); + atomicMax(addr_as_uint, __float_as_uint(val)); + } +} + +// Specialization for double using CAS +template <> +__device__ inline void atomic_max(double* addr, double val) { + unsigned long long* addr_as_ull = reinterpret_cast(addr); + unsigned long long old = *addr_as_ull; + unsigned long long assumed; + do { + assumed = old; + double old_double = __longlong_as_double(assumed); + double new_double = (old_double > val) ? old_double : val; + old = atomicCAS(addr_as_ull, assumed, __double_as_longlong(new_double)); + } while (old != assumed && __longlong_as_double(old) < val); +} + +// Atomic min for various types +template +__device__ void atomic_min(T* addr, T val) { + atomicMin(addr, val); +} + +// Specialization for float using CAS +template <> +__device__ inline void atomic_min(float* addr, float val) { + if (val < 0.0f) { + // For negative values, use integer atomicMax on the bit representation + int* addr_as_int = reinterpret_cast(addr); + atomicMax(addr_as_int, __float_as_int(val)); + } else { + // For non-negative values, use integer atomicMin + unsigned int* addr_as_uint = reinterpret_cast(addr); + atomicMin(addr_as_uint, __float_as_uint(val)); + } +} + +// Specialization for double using CAS +template <> +__device__ inline void atomic_min(double* addr, double val) { + unsigned long long* addr_as_ull = reinterpret_cast(addr); + unsigned long long old = *addr_as_ull; + unsigned long long assumed; + do { + assumed = old; + double old_double = __longlong_as_double(assumed); + double new_double = (old_double < val) ? old_double : val; + old = atomicCAS(addr_as_ull, assumed, __double_as_longlong(new_double)); + } while (old != assumed && __longlong_as_double(old) > val); +} + +// Atomic CAS (Compare-And-Swap) +template +__device__ T atomic_cas(T* addr, T compare, T val) { + return atomicCAS(addr, compare, val); +} + +// Atomic exchange +template +__device__ T atomic_exchange(T* addr, T val) { + return atomicExch(addr, val); +} + +// Atomic and +template +__device__ void atomic_and(T* addr, T val) { + atomicAnd(addr, val); +} + +// Atomic or +template +__device__ void atomic_or(T* addr, T val) { + atomicOr(addr, val); +} + +// Specialization for bool +template <> +__device__ inline void atomic_and(bool* addr, bool val) { + if (!val) { + // If val is false, set to false + unsigned int* addr_as_uint = reinterpret_cast( + reinterpret_cast(addr) & ~size_t(0x3)); + unsigned int shift = (reinterpret_cast(addr) & 0x3) * 8; + atomicAnd(addr_as_uint, ~(0xFF << shift)); + } +} + +template <> +__device__ inline void atomic_or(bool* addr, bool val) { + if (val) { + // If val is true, set to true + unsigned int* addr_as_uint = reinterpret_cast( + reinterpret_cast(addr) & ~size_t(0x3)); + unsigned int shift = (reinterpret_cast(addr) & 0x3) * 8; + atomicOr(addr_as_uint, 0x01 << shift); + } +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/binary_ops.hpp b/mlx/backend/rocm/device/binary_ops.hpp new file mode 100644 index 0000000000..59dd1c8e69 --- /dev/null +++ b/mlx/backend/rocm/device/binary_ops.hpp @@ -0,0 +1,486 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/rocm/device/unary_ops.hpp" + +#include + +namespace mlx::core::rocm { + +struct Add { + template + __device__ T operator()(T x, T y) { + if constexpr (is_complex_v) { + return hipCaddf(x, y); + } else if constexpr (std::is_same_v) { + return hip_bfloat16(static_cast(x) + static_cast(y)); + } else if constexpr (std::is_same_v) { + return __float2half(__half2float(x) + __half2float(y)); + } else { + return x + y; + } + } +}; + +struct FloorDivide { + template + __device__ T operator()(T x, T y) { + if constexpr (std::is_integral_v) { + return x / y; + } else if constexpr (std::is_same_v) { + return hip_bfloat16( + truncf(static_cast(x) / static_cast(y))); + } else if constexpr (std::is_same_v) { + return __float2half(truncf(__half2float(x) / __half2float(y))); + } else { + return truncf(x / y); + } + } +}; + +struct Divide { + template + __device__ T operator()(T x, T y) { + if constexpr (is_complex_v) { + return hipCdivf(x, y); + } else if constexpr (std::is_same_v) { + return hip_bfloat16(static_cast(x) / static_cast(y)); + } else if constexpr (std::is_same_v) { + return __float2half(__half2float(x) / __half2float(y)); + } else { + return x / y; + } + } +}; + +struct Remainder { + template + __device__ T operator()(T x, T y) { + if constexpr (std::is_integral_v) { + if constexpr (std::is_signed_v) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } else { + return x % y; + } + } else if constexpr (is_complex_v) { + // Complex modulo not typically defined, return x + return x; + } else if constexpr (std::is_same_v) { + float fx = static_cast(x); + float fy = static_cast(y); + float r = fmodf(fx, fy); + if (r != 0 && (r < 0 != fy < 0)) { + r = r + fy; + } + return hip_bfloat16(r); + } else if constexpr (std::is_same_v) { + float fx = __half2float(x); + float fy = __half2float(y); + float r = fmodf(fx, fy); + if (r != 0 && (r < 0 != fy < 0)) { + r = r + fy; + } + return __float2half(r); + } else { + T r = fmodf(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r = r + y; + } + return r; + } + } +}; + +struct Equal { + template + __device__ bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + __device__ bool operator()(T x, T y) { + if constexpr (is_complex_v) { + return (x.x == y.x && x.y == y.y) || + (__isnanf(x.x) && __isnanf(y.x) && __isnanf(x.y) && __isnanf(y.y)) || + (x.x == y.x && __isnanf(x.y) && __isnanf(y.y)) || + (__isnanf(x.x) && __isnanf(y.x) && x.y == y.y); + } else if constexpr (std::is_same_v) { + float fx = static_cast(x); + float fy = static_cast(y); + return fx == fy || (__isnanf(fx) && __isnanf(fy)); + } else if constexpr (std::is_same_v) { + float fx = __half2float(x); + float fy = __half2float(y); + return fx == fy || (__isnanf(fx) && __isnanf(fy)); + } else { + return x == y || (__isnanf(x) && __isnanf(y)); + } + } +}; + +struct Greater { + template + __device__ bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + __device__ bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + __device__ bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + __device__ bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + __device__ T operator()(T x, T y) { + if constexpr (std::is_integral_v) { + // LogAddExp doesn't make sense for integers, but handle it gracefully + return x > y ? x : y; + } else if constexpr (is_complex_v) { + if (isnan(x.x) || isnan(x.y) || isnan(y.x) || isnan(y.y)) { + return { + numeric_limits::quiet_NaN(), + numeric_limits::quiet_NaN()}; + } + auto maxv = x.x > y.x ? x : y; + auto minv = x.x < y.x ? x : y; + auto min_real = minv.x; + auto max_real = maxv.x; + if (!isfinite(min_real) && (min_real == max_real)) { + if (min_real < 0) { + return minv; + } else { + return Log{}(hipCaddf(Exp{}(minv), Exp{}(maxv))); + } + } else { + return hipCaddf(Log1p{}(Exp{}(hipCsubf(minv, maxv))), maxv); + } + } else if constexpr (std::is_same_v) { + float fx = static_cast(x); + float fy = static_cast(y); + if (isnan(fx) || isnan(fy)) { + return hip_bfloat16(numeric_limits::quiet_NaN()); + } + float maxval = fmaxf(fx, fy); + float minval = fminf(fx, fy); + float result = (minval == -numeric_limits::infinity() || + maxval == numeric_limits::infinity()) + ? maxval + : maxval + log1pf(expf(minval - maxval)); + return hip_bfloat16(result); + } else if constexpr (std::is_same_v) { + float fx = __half2float(x); + float fy = __half2float(y); + if (isnan(fx) || isnan(fy)) { + return __float2half(numeric_limits::quiet_NaN()); + } + float maxval = fmaxf(fx, fy); + float minval = fminf(fx, fy); + float result = (minval == -numeric_limits::infinity() || + maxval == numeric_limits::infinity()) + ? maxval + : maxval + log1pf(expf(minval - maxval)); + return __float2half(result); + } else { + if (isnan(x) || isnan(y)) { + return numeric_limits::quiet_NaN(); + } + T maxval = fmaxf(x, y); + T minval = fminf(x, y); + return (minval == -numeric_limits::infinity() || + maxval == numeric_limits::infinity()) + ? maxval + : T(float(maxval) + log1pf(expf(minval - maxval))); + } + }; +}; + +struct Maximum { + template + __device__ T operator()(T x, T y) { + if constexpr (std::is_integral_v) { + return max(x, y); + } else if constexpr (is_complex_v) { + if (__isnanf(x.x) || __isnanf(x.y)) { + return x; + } + // Compare by real part first, then imaginary + if (x.x > y.x || (x.x == y.x && x.y > y.y)) { + return x; + } + return y; + } else if constexpr (std::is_same_v) { + float fx = static_cast(x); + float fy = static_cast(y); + if (__isnanf(fx)) { + return x; + } + return fx > fy ? x : y; + } else if constexpr (std::is_same_v) { + float fx = __half2float(x); + float fy = __half2float(y); + if (__isnanf(fx)) { + return x; + } + return fx > fy ? x : y; + } else { + if (__isnanf(x)) { + return x; + } + return x > y ? x : y; + } + } +}; + +struct Minimum { + template + __device__ T operator()(T x, T y) { + if constexpr (std::is_integral_v) { + return min(x, y); + } else if constexpr (is_complex_v) { + if (__isnanf(x.x) || __isnanf(x.y)) { + return x; + } + // Compare by real part first, then imaginary + if (x.x < y.x || (x.x == y.x && x.y < y.y)) { + return x; + } + return y; + } else if constexpr (std::is_same_v) { + float fx = static_cast(x); + float fy = static_cast(y); + if (__isnanf(fx)) { + return x; + } + return fx < fy ? x : y; + } else if constexpr (std::is_same_v) { + float fx = __half2float(x); + float fy = __half2float(y); + if (__isnanf(fx)) { + return x; + } + return fx < fy ? x : y; + } else { + if (__isnanf(x)) { + return x; + } + return x < y ? x : y; + } + } +}; + +struct Multiply { + template + __device__ T operator()(T x, T y) { + if constexpr (is_complex_v) { + return hipCmulf(x, y); + } else if constexpr (std::is_same_v) { + return hip_bfloat16(static_cast(x) * static_cast(y)); + } else if constexpr (std::is_same_v) { + return __float2half(__half2float(x) * __half2float(y)); + } else { + return x * y; + } + } +}; + +struct NotEqual { + template + __device__ bool operator()(T x, T y) { + if constexpr (is_complex_v) { + return x.x != y.x || x.y != y.y; + } else { + return x != y; + } + } +}; + +struct Power { + template + __device__ T operator()(T base, T exp) { + if constexpr (std::is_integral_v) { + T res = 1; + // Raising an integer to a negative power is undefined + if constexpr (std::is_signed_v) { + if (exp < 0) { + return 0; + } + } + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } else if constexpr (is_complex_v) { + // Complex power: base^exp = exp(exp * log(base)) + float r = hypotf(base.x, base.y); + float theta = atan2f(base.y, base.x); + float log_r = logf(r); + float new_r = expf(exp.x * log_r - exp.y * theta); + float new_theta = exp.x * theta + exp.y * log_r; + return make_hipFloatComplex( + new_r * cosf(new_theta), new_r * sinf(new_theta)); + } else if constexpr (std::is_same_v) { + return hip_bfloat16( + powf(static_cast(base), static_cast(exp))); + } else if constexpr (std::is_same_v) { + return __float2half(powf(__half2float(base), __half2float(exp))); + } else { + return powf(base, exp); + } + } +}; + +struct Subtract { + template + __device__ T operator()(T x, T y) { + if constexpr (is_complex_v) { + return hipCsubf(x, y); + } else if constexpr (std::is_same_v) { + return hip_bfloat16(static_cast(x) - static_cast(y)); + } else if constexpr (std::is_same_v) { + return __float2half(__half2float(x) - __half2float(y)); + } else { + return x - y; + } + } +}; + +struct LogicalAnd { + template + __device__ bool operator()(T x, T y) { + if constexpr (std::is_same_v) { + return (static_cast(x) != 0.0f) && (static_cast(y) != 0.0f); + } else if constexpr (std::is_same_v) { + return (__half2float(x) != 0.0f) && (__half2float(y) != 0.0f); + } else if constexpr (std::is_floating_point_v) { + return (x != T(0)) && (y != T(0)); + } else { + return x && y; + } + }; +}; + +struct LogicalOr { + template + __device__ bool operator()(T x, T y) { + if constexpr (std::is_same_v) { + return (static_cast(x) != 0.0f) || (static_cast(y) != 0.0f); + } else if constexpr (std::is_same_v) { + return (__half2float(x) != 0.0f) || (__half2float(y) != 0.0f); + } else if constexpr (std::is_floating_point_v) { + return (x != T(0)) || (y != T(0)); + } else { + return x || y; + } + }; +}; + +struct BitwiseAnd { + template + __device__ T operator()(T x, T y) { + if constexpr (std::is_integral_v) { + return x & y; + } else { + // This branch should never be taken due to supports_binary_op filtering + return T{}; + } + }; +}; + +struct BitwiseOr { + template + __device__ T operator()(T x, T y) { + if constexpr (std::is_integral_v) { + return x | y; + } else { + return T{}; + } + }; +}; + +struct BitwiseXor { + template + __device__ T operator()(T x, T y) { + if constexpr (std::is_integral_v) { + return x ^ y; + } else { + return T{}; + } + }; +}; + +struct LeftShift { + template + __device__ T operator()(T x, T y) { + if constexpr (std::is_integral_v) { + return x << y; + } else { + return T{}; + } + }; +}; + +struct RightShift { + template + __device__ T operator()(T x, T y) { + if constexpr (std::is_integral_v) { + return x >> y; + } else { + return T{}; + } + }; +}; + +struct ArcTan2 { + template + __device__ T operator()(T y, T x) { + if constexpr (std::is_same_v || std::is_integral_v) { + return static_cast( + atan2f(static_cast(y), static_cast(x))); + } else if constexpr (std::is_same_v) { + return hip_bfloat16(atan2f(static_cast(y), static_cast(x))); + } else if constexpr (std::is_same_v) { + return __float2half(atan2f(__half2float(y), __half2float(x))); + } else if constexpr (std::is_same_v) { + return atan2(y, x); + } else { + return atan2f(y, x); + } + } +}; + +struct DivMod { + template + __device__ hip_array operator()(T x, T y) { + return {FloorDivide{}(x, y), Remainder{}(x, y)}; + }; +}; + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/cast_op.hpp b/mlx/backend/rocm/device/cast_op.hpp new file mode 100644 index 0000000000..859eb7d8cb --- /dev/null +++ b/mlx/backend/rocm/device/cast_op.hpp @@ -0,0 +1,294 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +#include + +namespace mlx::core::rocm { + +// Type trait to check if a type is castable +template +struct is_castable : std::true_type {}; + +// Cast operation for type conversion +template +struct Cast { + __device__ To operator()(From x) { + return static_cast(x); + } +}; + +// Same type - no-op +template +struct Cast { + __device__ T operator()(T x) { + return x; + } +}; + +// Specializations for half types +template +struct Cast<__half, To> { + __device__ To operator()(__half x) { + return static_cast(__half2float(x)); + } +}; + +template +struct Cast { + __device__ __half operator()(From x) { + return __float2half(static_cast(x)); + } +}; + +template <> +struct Cast<__half, __half> { + __device__ __half operator()(__half x) { + return x; + } +}; + +// Specializations for bfloat16 types +template +struct Cast { + __device__ To operator()(hip_bfloat16 x) { + return static_cast(static_cast(x)); + } +}; + +template +struct Cast { + __device__ hip_bfloat16 operator()(From x) { + return hip_bfloat16(static_cast(x)); + } +}; + +template <> +struct Cast { + __device__ hip_bfloat16 operator()(hip_bfloat16 x) { + return x; + } +}; + +// Conversion between half and bfloat16 +template <> +struct Cast<__half, hip_bfloat16> { + __device__ hip_bfloat16 operator()(__half x) { + return hip_bfloat16(__half2float(x)); + } +}; + +template <> +struct Cast { + __device__ __half operator()(hip_bfloat16 x) { + return __float2half(static_cast(x)); + } +}; + +// Complex type conversions +// Complex to bool +template <> +struct Cast { + __device__ bool operator()(hipFloatComplex x) { + return x.x != 0.0f || x.y != 0.0f; + } +}; + +// Bool to complex +template <> +struct Cast { + __device__ hipFloatComplex operator()(bool x) { + return make_hipFloatComplex(x ? 1.0f : 0.0f, 0.0f); + } +}; + +// Complex to real types (discards imaginary part) +template <> +struct Cast { + __device__ float operator()(hipFloatComplex x) { + return x.x; + } +}; + +template <> +struct Cast { + __device__ double operator()(hipFloatComplex x) { + return static_cast(x.x); + } +}; + +template <> +struct Cast { + __device__ int operator()(hipFloatComplex x) { + return static_cast(x.x); + } +}; + +template <> +struct Cast { + __device__ int64_t operator()(hipFloatComplex x) { + return static_cast(x.x); + } +}; + +template <> +struct Cast { + __device__ uint32_t operator()(hipFloatComplex x) { + return static_cast(x.x); + } +}; + +template <> +struct Cast { + __device__ uint64_t operator()(hipFloatComplex x) { + return static_cast(x.x); + } +}; + +template <> +struct Cast { + __device__ int8_t operator()(hipFloatComplex x) { + return static_cast(x.x); + } +}; + +template <> +struct Cast { + __device__ uint8_t operator()(hipFloatComplex x) { + return static_cast(x.x); + } +}; + +template <> +struct Cast { + __device__ int16_t operator()(hipFloatComplex x) { + return static_cast(x.x); + } +}; + +template <> +struct Cast { + __device__ uint16_t operator()(hipFloatComplex x) { + return static_cast(x.x); + } +}; + +template <> +struct Cast { + __device__ __half operator()(hipFloatComplex x) { + return __float2half(x.x); + } +}; + +template <> +struct Cast { + __device__ hip_bfloat16 operator()(hipFloatComplex x) { + return hip_bfloat16(x.x); + } +}; + +// Real types to complex (sets imaginary to 0) +template <> +struct Cast { + __device__ hipFloatComplex operator()(float x) { + return make_hipFloatComplex(x, 0.0f); + } +}; + +template <> +struct Cast { + __device__ hipFloatComplex operator()(double x) { + return make_hipFloatComplex(static_cast(x), 0.0f); + } +}; + +template <> +struct Cast { + __device__ hipFloatComplex operator()(int x) { + return make_hipFloatComplex(static_cast(x), 0.0f); + } +}; + +template <> +struct Cast { + __device__ hipFloatComplex operator()(int64_t x) { + return make_hipFloatComplex(static_cast(x), 0.0f); + } +}; + +template <> +struct Cast { + __device__ hipFloatComplex operator()(uint32_t x) { + return make_hipFloatComplex(static_cast(x), 0.0f); + } +}; + +template <> +struct Cast { + __device__ hipFloatComplex operator()(uint64_t x) { + return make_hipFloatComplex(static_cast(x), 0.0f); + } +}; + +template <> +struct Cast { + __device__ hipFloatComplex operator()(int8_t x) { + return make_hipFloatComplex(static_cast(x), 0.0f); + } +}; + +template <> +struct Cast { + __device__ hipFloatComplex operator()(uint8_t x) { + return make_hipFloatComplex(static_cast(x), 0.0f); + } +}; + +template <> +struct Cast { + __device__ hipFloatComplex operator()(int16_t x) { + return make_hipFloatComplex(static_cast(x), 0.0f); + } +}; + +template <> +struct Cast { + __device__ hipFloatComplex operator()(uint16_t x) { + return make_hipFloatComplex(static_cast(x), 0.0f); + } +}; + +template <> +struct Cast<__half, hipFloatComplex> { + __device__ hipFloatComplex operator()(__half x) { + return make_hipFloatComplex(__half2float(x), 0.0f); + } +}; + +template <> +struct Cast { + __device__ hipFloatComplex operator()(hip_bfloat16 x) { + return make_hipFloatComplex(static_cast(x), 0.0f); + } +}; + +// Complex to complex (identity) +template <> +struct Cast { + __device__ hipFloatComplex operator()(hipFloatComplex x) { + return x; + } +}; + +// Helper function for casting (similar to CUDA's cast_to) +template +__device__ DstT cast_to(SrcT x) { + return Cast{}(x); +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/config.h b/mlx/backend/rocm/device/config.h new file mode 100644 index 0000000000..d23de7f747 --- /dev/null +++ b/mlx/backend/rocm/device/config.h @@ -0,0 +1,166 @@ +// Copyright © 2025 Apple Inc. + +// This file is used by both HIP kernel code and host-only C++ code. + +#pragma once + +// The maximum dimensions of shape/strides passed as kernel parameters. +#define MAX_NDIM 10 + +// AMD GPU warp (wavefront) size varies by architecture: +// - CDNA/GCN (gfx9xx and earlier): 64 +// - RDNA (gfx10xx, gfx11xx, gfx12xx): 32 +// +// The __AMDGCN_WAVEFRONT_SIZE__ macro is defined by the HIP compiler +// based on the target architecture. We use it when available for device code. +// +// IMPORTANT: For host code, we need a consistent value that matches the +// compiled device code. Since we compile for specific architectures via +// CMAKE_HIP_ARCHITECTURES, we need to ensure host and device agree. +// +// For now, we default to 32 (RDNA) since that's the most common consumer GPU. +// If targeting CDNA/GCN architectures, change this to 64. +#if defined(__AMDGCN_WAVEFRONT_SIZE__) +// Device code: use the compiler-provided value +#define WARP_SIZE __AMDGCN_WAVEFRONT_SIZE__ +#elif defined(__HIP_DEVICE_COMPILE__) +// Device code without wavefront size macro - check architecture macros +#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) || \ + defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ + defined(__gfx1033__) || defined(__gfx1034__) || defined(__gfx1035__) || \ + defined(__gfx1036__) || defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1102__) || defined(__gfx1103__) || defined(__gfx1150__) || \ + defined(__gfx1151__) || defined(__gfx1200__) || defined(__gfx1201__) +#define WARP_SIZE 32 +#else +#define WARP_SIZE 64 +#endif +#else +// Host code: use a fixed value that matches the target architecture. +// This MUST match the CMAKE_HIP_ARCHITECTURES setting. +// For RDNA (gfx10xx, gfx11xx, gfx12xx): 32 +// For CDNA/GCN (gfx9xx): 64 +#define WARP_SIZE 32 +#endif + +namespace mlx::core::rocm { + +// Configuration constants for ROCm kernels + +// Default thread block size +constexpr int kDefaultBlockSize = 256; + +// Maximum threads per block (typical for AMD GPUs) +constexpr int kMaxThreadsPerBlock = 1024; + +// Warp size (wavefront size) - use the macro for compile-time value +constexpr int kWarpSize = WARP_SIZE; + +// Maximum shared memory per block (in bytes) +constexpr int kMaxSharedMemoryPerBlock = 65536; + +// Maximum number of dimensions supported +constexpr int kMaxNdim = 8; + +// Reduce constants +constexpr int kReduceBlockSize = 256; +constexpr int kReduceMaxBlocks = 1024; + +// Copy constants +constexpr int kCopyBlockSize = 256; + +// Softmax constants +constexpr int kSoftmaxBlockSize = 256; + +// Layer norm constants +constexpr int kLayerNormBlockSize = 256; + +// RMS norm constants +constexpr int kRMSNormBlockSize = 256; + +// Attention constants +constexpr int kAttentionBlockSize = 256; + +// ---- Architecture tier detection and per-arch kernel tuning ---- +// +// RocmArchTier provides fine-grained GPU generation identification. +// ArchTuning holds per-arch parameters for kernel dispatch decisions. +// Both are usable from host code and kernel dispatch logic. + +enum class RocmArchTier { + Rdna2, // gfx10xx: RDNA 2, Wave32, no WMMA + Rdna3, // gfx1100-gfx1103: RDNA 3, Wave32, WMMA, 96KB LDS + Rdna35, // gfx1150-gfx1152: RDNA 3.5, Wave32, WMMA, 64KB LDS, 32MB IC + Rdna4, // gfx1200-gfx1201: RDNA 4, Wave32, enhanced WMMA + Cdna, // gfx9xx: MI-series, Wave64 +}; + +// Hardware capabilities detected at runtime from hipDeviceProp_t. +struct HWInfo { + RocmArchTier tier; + int num_cus; // Compute units (multiProcessorCount) + int simds_per_cu; // SIMDs per CU (2 for RDNA, 4 for CDNA) + int max_threads_per_cu; // Max resident threads per CU + int shared_mem_per_cu; // Shared/LDS memory per CU in bytes + int l2_cache_bytes; // L2/Infinity Cache size + bool has_native_wmma; // True if arch is on rocWMMA allowlist + // (CDNA1/2/3 + RDNA3 dGPU + gfx1151 + RDNA4) +}; + +// Per-architecture tuning parameters for quantized matvec and attention +// kernels. +struct ArchTuning { + // QMV tiled kernel + int qmv_tile_n; // Output columns per block (L2 reuse) + // QMV↔GEMM crossover M thresholds + int qmv_crossover_small; // For K<=2048, N<=2048 + int qmv_crossover_medium; // For K<=4096, N<=4096 + int qmv_crossover_large; // For larger shapes + // Flash attention + int fa_block_m; // Queries per flash attention block + int fa_block_n; // Keys per iteration +}; + +// Auto-tune based on detected hardware. Adjusts tile sizes based on actual +// CU count to balance occupancy vs L2 reuse. +inline ArchTuning get_arch_tuning(RocmArchTier tier) { + // Defaults per tier — used when HWInfo isn't available + switch (tier) { + case RocmArchTier::Rdna2: + return ArchTuning{8, 28, 20, 14, 128, 64}; + case RocmArchTier::Rdna3: + return ArchTuning{16, 36, 24, 16, 64, 64}; + case RocmArchTier::Rdna35: + // 40 CUs: TILE_N=16 gives best occupancy/reuse balance + return ArchTuning{16, 36, 24, 16, 64, 64}; + case RocmArchTier::Rdna4: + return ArchTuning{32, 40, 28, 18, 64, 64}; + case RocmArchTier::Cdna: + default: + return ArchTuning{16, 20, 14, 10, 128, 64}; + } +} + +// Auto-tune using full hardware info. Adjusts TILE_N based on CU count: +// fewer CUs → larger tiles for more L2 reuse per block. +inline ArchTuning get_arch_tuning(const HWInfo& hw) { + auto t = get_arch_tuning(hw.tier); + + // TILE_N is bounded by how many column streams L2 holds without evicting the + // reused X/scales. RDNA 3/3.5 (2 MB L2): 16. RDNA 4 (8 MB L2): 24. + if (hw.tier == RocmArchTier::Rdna3 || hw.tier == RocmArchTier::Rdna35) { + t.qmv_tile_n = (hw.num_cus <= 16) ? 8 : 16; + } else if (hw.tier == RocmArchTier::Rdna4) { + if (hw.num_cus <= 16) { + t.qmv_tile_n = 8; + } else if (hw.l2_cache_bytes >= (6 << 20)) { + t.qmv_tile_n = 24; // >=6 MB L2 (Navi 48 = 8 MB): wider tile, less waste + } else { + t.qmv_tile_n = 16; + } + } + + return t; +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/fp16_math.hpp b/mlx/backend/rocm/device/fp16_math.hpp new file mode 100644 index 0000000000..52770d683f --- /dev/null +++ b/mlx/backend/rocm/device/fp16_math.hpp @@ -0,0 +1,436 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +namespace mlx::core::rocm { + +// Half-precision math functions for HIP +// Note: bfloat16 operations are computed in float since HIP doesn't have native +// bfloat16 math + +// Helper to convert bfloat16 to float and back +__device__ inline float bf16_to_float(hip_bfloat16 x) { + return static_cast(x); +} + +__device__ inline hip_bfloat16 float_to_bf16(float x) { + return hip_bfloat16(x); +} + +// Abs for half types +__device__ inline __half abs(__half x) { + return __habs(x); +} + +__device__ inline hip_bfloat16 abs(hip_bfloat16 x) { + return float_to_bf16(fabsf(bf16_to_float(x))); +} + +// Sqrt for half types +__device__ inline __half sqrt(__half x) { + return hsqrt(x); +} + +__device__ inline hip_bfloat16 sqrt(hip_bfloat16 x) { + return float_to_bf16(sqrtf(bf16_to_float(x))); +} + +// Rsqrt for half types +__device__ inline __half rsqrt(__half x) { + return hrsqrt(x); +} + +__device__ inline hip_bfloat16 rsqrt(hip_bfloat16 x) { + return float_to_bf16(rsqrtf(bf16_to_float(x))); +} + +// Exp for half types +__device__ inline __half exp(__half x) { + return hexp(x); +} + +__device__ inline hip_bfloat16 exp(hip_bfloat16 x) { + return float_to_bf16(expf(bf16_to_float(x))); +} + +// Log for half types +__device__ inline __half log(__half x) { + return hlog(x); +} + +__device__ inline hip_bfloat16 log(hip_bfloat16 x) { + return float_to_bf16(logf(bf16_to_float(x))); +} + +// Log2 for half types +__device__ inline __half log2(__half x) { + return hlog2(x); +} + +__device__ inline hip_bfloat16 log2(hip_bfloat16 x) { + return float_to_bf16(log2f(bf16_to_float(x))); +} + +// Log10 for half types +__device__ inline __half log10(__half x) { + return hlog10(x); +} + +__device__ inline hip_bfloat16 log10(hip_bfloat16 x) { + return float_to_bf16(log10f(bf16_to_float(x))); +} + +// Sin for half types +__device__ inline __half sin(__half x) { + return hsin(x); +} + +__device__ inline hip_bfloat16 sin(hip_bfloat16 x) { + return float_to_bf16(sinf(bf16_to_float(x))); +} + +// Cos for half types +__device__ inline __half cos(__half x) { + return hcos(x); +} + +__device__ inline hip_bfloat16 cos(hip_bfloat16 x) { + return float_to_bf16(cosf(bf16_to_float(x))); +} + +// Ceil for half types +__device__ inline __half ceil(__half x) { + return hceil(x); +} + +__device__ inline hip_bfloat16 ceil(hip_bfloat16 x) { + return float_to_bf16(ceilf(bf16_to_float(x))); +} + +// Floor for half types +__device__ inline __half floor(__half x) { + return hfloor(x); +} + +__device__ inline hip_bfloat16 floor(hip_bfloat16 x) { + return float_to_bf16(floorf(bf16_to_float(x))); +} + +// Rint (round to nearest integer) for half types +__device__ inline __half rint(__half x) { + return hrint(x); +} + +__device__ inline hip_bfloat16 rint(hip_bfloat16 x) { + return float_to_bf16(rintf(bf16_to_float(x))); +} + +// Trunc for half types +__device__ inline __half trunc(__half x) { + return htrunc(x); +} + +__device__ inline hip_bfloat16 trunc(hip_bfloat16 x) { + return float_to_bf16(truncf(bf16_to_float(x))); +} + +// Conversion helpers +__device__ inline float half2float(__half x) { + return __half2float(x); +} + +__device__ inline __half float2half(float x) { + return __float2half(x); +} + +__device__ inline float bfloat162float(hip_bfloat16 x) { + return bf16_to_float(x); +} + +__device__ inline hip_bfloat16 float2bfloat16(float x) { + return float_to_bf16(x); +} + +// Erf for half types (compute in float) +__device__ inline __half erf(__half x) { + return __float2half(erff(__half2float(x))); +} + +__device__ inline hip_bfloat16 erf(hip_bfloat16 x) { + return float_to_bf16(erff(bf16_to_float(x))); +} + +// Erfinv for half types (compute in float) +__device__ inline __half erfinv(__half x) { + return __float2half(erfinvf(__half2float(x))); +} + +__device__ inline hip_bfloat16 erfinv(hip_bfloat16 x) { + return float_to_bf16(erfinvf(bf16_to_float(x))); +} + +// Expm1 for half types (compute in float) +__device__ inline __half expm1(__half x) { + return __float2half(expm1f(__half2float(x))); +} + +__device__ inline hip_bfloat16 expm1(hip_bfloat16 x) { + return float_to_bf16(expm1f(bf16_to_float(x))); +} + +// Log1p for half types (compute in float) +__device__ inline __half log1p(__half x) { + return __float2half(log1pf(__half2float(x))); +} + +__device__ inline hip_bfloat16 log1p(hip_bfloat16 x) { + return float_to_bf16(log1pf(bf16_to_float(x))); +} + +// Tanh for half types +__device__ inline __half tanh(__half x) { + // HIP may not have htanh, compute in float + return __float2half(tanhf(__half2float(x))); +} + +__device__ inline hip_bfloat16 tanh(hip_bfloat16 x) { + return float_to_bf16(tanhf(bf16_to_float(x))); +} + +// Sinh for half types +__device__ inline __half sinh(__half x) { + return __float2half(sinhf(__half2float(x))); +} + +__device__ inline hip_bfloat16 sinh(hip_bfloat16 x) { + return float_to_bf16(sinhf(bf16_to_float(x))); +} + +// Cosh for half types +__device__ inline __half cosh(__half x) { + return __float2half(coshf(__half2float(x))); +} + +__device__ inline hip_bfloat16 cosh(hip_bfloat16 x) { + return float_to_bf16(coshf(bf16_to_float(x))); +} + +// Asin for half types +__device__ inline __half asin(__half x) { + return __float2half(asinf(__half2float(x))); +} + +__device__ inline hip_bfloat16 asin(hip_bfloat16 x) { + return float_to_bf16(asinf(bf16_to_float(x))); +} + +// Acos for half types +__device__ inline __half acos(__half x) { + return __float2half(acosf(__half2float(x))); +} + +__device__ inline hip_bfloat16 acos(hip_bfloat16 x) { + return float_to_bf16(acosf(bf16_to_float(x))); +} + +// Atan for half types +__device__ inline __half atan(__half x) { + return __float2half(atanf(__half2float(x))); +} + +__device__ inline hip_bfloat16 atan(hip_bfloat16 x) { + return float_to_bf16(atanf(bf16_to_float(x))); +} + +// Asinh for half types +__device__ inline __half asinh(__half x) { + return __float2half(asinhf(__half2float(x))); +} + +__device__ inline hip_bfloat16 asinh(hip_bfloat16 x) { + return float_to_bf16(asinhf(bf16_to_float(x))); +} + +// Acosh for half types +__device__ inline __half acosh(__half x) { + return __float2half(acoshf(__half2float(x))); +} + +__device__ inline hip_bfloat16 acosh(hip_bfloat16 x) { + return float_to_bf16(acoshf(bf16_to_float(x))); +} + +// Atanh for half types +__device__ inline __half atanh(__half x) { + return __float2half(atanhf(__half2float(x))); +} + +__device__ inline hip_bfloat16 atanh(hip_bfloat16 x) { + return float_to_bf16(atanhf(bf16_to_float(x))); +} + +// Tan for half types +__device__ inline __half tan(__half x) { + return __float2half(tanf(__half2float(x))); +} + +__device__ inline hip_bfloat16 tan(hip_bfloat16 x) { + return float_to_bf16(tanf(bf16_to_float(x))); +} + +// Complex math functions +// exp(z) = exp(x) * (cos(y) + i*sin(y)) +__device__ inline hipFloatComplex exp(hipFloatComplex z) { + float ex = expf(z.x); + // Handle special case: if real part is -inf, result is 0 + if (isinf(z.x) && z.x < 0) { + return make_hipFloatComplex(0.0f, 0.0f); + } + float s, c; + sincosf(z.y, &s, &c); + return make_hipFloatComplex(ex * c, ex * s); +} + +// log(z) = log(|z|) + i*arg(z) +__device__ inline hipFloatComplex log(hipFloatComplex z) { + float r = hypotf(z.x, z.y); + float theta = atan2f(z.y, z.x); + return make_hipFloatComplex(logf(r), theta); +} + +// log10(z) = log(z) / log(10) +__device__ inline hipFloatComplex log10(hipFloatComplex z) { + hipFloatComplex lz = log(z); + constexpr float ln10 = 2.302585092994045684017991454684364208f; + return make_hipFloatComplex(lz.x / ln10, lz.y / ln10); +} + +// sin(z) = sin(x)*cosh(y) + i*cos(x)*sinh(y) +__device__ inline hipFloatComplex sin(hipFloatComplex z) { + float sx, cx; + sincosf(z.x, &sx, &cx); + return make_hipFloatComplex(sx * coshf(z.y), cx * sinhf(z.y)); +} + +// cos(z) = cos(x)*cosh(y) - i*sin(x)*sinh(y) +__device__ inline hipFloatComplex cos(hipFloatComplex z) { + float sx, cx; + sincosf(z.x, &sx, &cx); + return make_hipFloatComplex(cx * coshf(z.y), -sx * sinhf(z.y)); +} + +// tan(z) = sin(z) / cos(z) +__device__ inline hipFloatComplex tan(hipFloatComplex z) { + return hipCdivf(sin(z), cos(z)); +} + +// sinh(z) = sinh(x)*cos(y) + i*cosh(x)*sin(y) +__device__ inline hipFloatComplex sinh(hipFloatComplex z) { + float sy, cy; + sincosf(z.y, &sy, &cy); + return make_hipFloatComplex(sinhf(z.x) * cy, coshf(z.x) * sy); +} + +// cosh(z) = cosh(x)*cos(y) + i*sinh(x)*sin(y) +__device__ inline hipFloatComplex cosh(hipFloatComplex z) { + float sy, cy; + sincosf(z.y, &sy, &cy); + return make_hipFloatComplex(coshf(z.x) * cy, sinhf(z.x) * sy); +} + +// tanh(z) = sinh(z) / cosh(z) +__device__ inline hipFloatComplex tanh(hipFloatComplex z) { + return hipCdivf(sinh(z), cosh(z)); +} + +// sqrt(z) = sqrt(|z|) * (cos(arg(z)/2) + i*sin(arg(z)/2)) +__device__ inline hipFloatComplex sqrt(hipFloatComplex z) { + float r = hypotf(z.x, z.y); + float theta = atan2f(z.y, z.x); + float sr = sqrtf(r); + float half_theta = theta * 0.5f; + float s, c; + sincosf(half_theta, &s, &c); + return make_hipFloatComplex(sr * c, sr * s); +} + +// abs(z) = |z| (returns complex with real part = magnitude, imag = 0) +__device__ inline hipFloatComplex abs(hipFloatComplex z) { + return make_hipFloatComplex(hypotf(z.x, z.y), 0.0f); +} + +// asin(z) = -i * log(i*z + sqrt(1 - z^2)) +__device__ inline hipFloatComplex asin(hipFloatComplex z) { + // i*z + hipFloatComplex iz = make_hipFloatComplex(-z.y, z.x); + // z^2 + hipFloatComplex z2 = hipCmulf(z, z); + // 1 - z^2 + hipFloatComplex one_minus_z2 = make_hipFloatComplex(1.0f - z2.x, -z2.y); + // sqrt(1 - z^2) + hipFloatComplex sqrt_term = sqrt(one_minus_z2); + // i*z + sqrt(1 - z^2) + hipFloatComplex sum = + make_hipFloatComplex(iz.x + sqrt_term.x, iz.y + sqrt_term.y); + // log(...) + hipFloatComplex log_term = log(sum); + // -i * log(...) = (log.y, -log.x) + return make_hipFloatComplex(log_term.y, -log_term.x); +} + +// acos(z) = pi/2 - asin(z) +__device__ inline hipFloatComplex acos(hipFloatComplex z) { + hipFloatComplex asin_z = asin(z); + constexpr float pi_2 = 1.5707963267948966192313216916397514f; + return make_hipFloatComplex(pi_2 - asin_z.x, -asin_z.y); +} + +// atan(z) = (i/2) * log((i+z)/(i-z)) +__device__ inline hipFloatComplex atan(hipFloatComplex z) { + // i + z + hipFloatComplex i_plus_z = make_hipFloatComplex(z.x, 1.0f + z.y); + // i - z + hipFloatComplex i_minus_z = make_hipFloatComplex(-z.x, 1.0f - z.y); + // (i+z)/(i-z) + hipFloatComplex ratio = hipCdivf(i_plus_z, i_minus_z); + // log(...) + hipFloatComplex log_term = log(ratio); + // (i/2) * log(...) = (-log.y/2, log.x/2) + return make_hipFloatComplex(-log_term.y * 0.5f, log_term.x * 0.5f); +} + +// asinh(z) = log(z + sqrt(z^2 + 1)) +__device__ inline hipFloatComplex asinh(hipFloatComplex z) { + hipFloatComplex z2 = hipCmulf(z, z); + hipFloatComplex z2_plus_1 = make_hipFloatComplex(z2.x + 1.0f, z2.y); + hipFloatComplex sqrt_term = sqrt(z2_plus_1); + hipFloatComplex sum = + make_hipFloatComplex(z.x + sqrt_term.x, z.y + sqrt_term.y); + return log(sum); +} + +// acosh(z) = log(z + sqrt(z^2 - 1)) +__device__ inline hipFloatComplex acosh(hipFloatComplex z) { + hipFloatComplex z2 = hipCmulf(z, z); + hipFloatComplex z2_minus_1 = make_hipFloatComplex(z2.x - 1.0f, z2.y); + hipFloatComplex sqrt_term = sqrt(z2_minus_1); + hipFloatComplex sum = + make_hipFloatComplex(z.x + sqrt_term.x, z.y + sqrt_term.y); + return log(sum); +} + +// atanh(z) = (1/2) * log((1+z)/(1-z)) +__device__ inline hipFloatComplex atanh(hipFloatComplex z) { + hipFloatComplex one_plus_z = make_hipFloatComplex(1.0f + z.x, z.y); + hipFloatComplex one_minus_z = make_hipFloatComplex(1.0f - z.x, -z.y); + hipFloatComplex ratio = hipCdivf(one_plus_z, one_minus_z); + hipFloatComplex log_term = log(ratio); + return make_hipFloatComplex(log_term.x * 0.5f, log_term.y * 0.5f); +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/gather.hpp b/mlx/backend/rocm/device/gather.hpp new file mode 100644 index 0000000000..947d97fa6e --- /dev/null +++ b/mlx/backend/rocm/device/gather.hpp @@ -0,0 +1,48 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/rocm/device/indexing.hpp" +#include "mlx/backend/rocm/device/utils.hpp" + +#include + +namespace mlx::core::rocm { + +template +__global__ void gather( + const T* src, + T* out, + LocT size, + const int32_t* src_shape, + const int64_t* src_strides, + int32_t src_ndim, + const int32_t* slice_sizes, + uint32_t slice_size, + const int32_t* axes, + const IdxT* const* indices, + const int32_t* indices_shape, + const int64_t* indices_strides) { + LocT out_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (out_idx >= size) { + return; + } + + LocT src_elem = out_idx % slice_size; + LocT idx_elem = out_idx / slice_size; + + LocT src_loc = elem_to_loc(src_elem, slice_sizes, src_strides, src_ndim); + +#pragma unroll + for (int i = 0; i < NIDX; ++i) { + LocT idx_loc = elem_to_loc_nd( + idx_elem, indices_shape + i * IDX_NDIM, indices_strides + i * IDX_NDIM); + int32_t axis = axes[i]; + LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]); + src_loc += idx_val * src_strides[axis]; + } + + out[out_idx] = src[src_loc]; +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/gather_axis.hpp b/mlx/backend/rocm/device/gather_axis.hpp new file mode 100644 index 0000000000..7138109ade --- /dev/null +++ b/mlx/backend/rocm/device/gather_axis.hpp @@ -0,0 +1,66 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/rocm/device/indexing.hpp" +#include "mlx/backend/rocm/device/utils.hpp" + +#include + +namespace mlx::core::rocm { + +template < + typename T, + typename IdxT, + int NDIM, + bool SrcC, + bool IdxC, + typename LocT = int64_t> +__global__ void gather_axis_kernel( + const T* src, + const IdxT* indices, + T* out, + LocT idx_size_pre, + LocT idx_size_axis, + LocT idx_size_post, + const hip_array shape, + const hip_array src_strides, + const hip_array idx_strides, + int32_t axis, + int32_t axis_size, + int64_t src_stride_axis, + int64_t idx_stride_axis) { + LocT index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= idx_size_pre * idx_size_axis * idx_size_post) { + return; + } + + LocT x, y, z; + index_to_dims(index, idx_size_axis, idx_size_pre, x, y, z); + + LocT elem_idx = z * idx_size_post; + + LocT idx_loc = y * idx_stride_axis; + if constexpr (IdxC) { + idx_loc += elem_idx * idx_size_axis + x; + } else { + idx_loc += + elem_to_loc_nd(elem_idx + x, shape.data_, idx_strides.data_); + } + + auto idx_val = absolute_index(indices[idx_loc], axis_size); + + LocT src_loc = idx_val * src_stride_axis; + if constexpr (SrcC) { + src_loc += elem_idx * axis_size + x; + } else { + src_loc += + elem_to_loc_nd(elem_idx + x, shape.data_, src_strides.data_); + } + + LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x; + + out[out_idx] = src[src_loc]; +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/hip_complex_math.hpp b/mlx/backend/rocm/device/hip_complex_math.hpp new file mode 100644 index 0000000000..22c69853b7 --- /dev/null +++ b/mlx/backend/rocm/device/hip_complex_math.hpp @@ -0,0 +1,172 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::rocm { + +// Complex number type alias +using complex64_t = hipFloatComplex; + +// Make complex from real and imaginary parts +__device__ inline hipFloatComplex make_complex(float real, float imag) { + return make_hipFloatComplex(real, imag); +} + +// Get real part +__device__ inline float real(hipFloatComplex z) { + return hipCrealf(z); +} + +// Get imaginary part +__device__ inline float imag(hipFloatComplex z) { + return hipCimagf(z); +} + +// Complex conjugate +__device__ inline hipFloatComplex conj(hipFloatComplex z) { + return hipConjf(z); +} + +// Complex absolute value (magnitude) +__device__ inline float abs(hipFloatComplex z) { + return hipCabsf(z); +} + +// Complex addition +__device__ inline hipFloatComplex operator+( + hipFloatComplex a, + hipFloatComplex b) { + return hipCaddf(a, b); +} + +// Complex subtraction +__device__ inline hipFloatComplex operator-( + hipFloatComplex a, + hipFloatComplex b) { + return hipCsubf(a, b); +} + +// Complex multiplication +__device__ inline hipFloatComplex operator*( + hipFloatComplex a, + hipFloatComplex b) { + return hipCmulf(a, b); +} + +// Complex division +__device__ inline hipFloatComplex operator/( + hipFloatComplex a, + hipFloatComplex b) { + return hipCdivf(a, b); +} + +// Complex negation +__device__ inline hipFloatComplex operator-(hipFloatComplex z) { + return make_hipFloatComplex(-hipCrealf(z), -hipCimagf(z)); +} + +// Complex comparison (by magnitude, for sorting) +__device__ inline bool operator<(hipFloatComplex a, hipFloatComplex b) { + float mag_a = hipCabsf(a); + float mag_b = hipCabsf(b); + return mag_a < mag_b; +} + +__device__ inline bool operator>(hipFloatComplex a, hipFloatComplex b) { + float mag_a = hipCabsf(a); + float mag_b = hipCabsf(b); + return mag_a > mag_b; +} + +__device__ inline bool operator<=(hipFloatComplex a, hipFloatComplex b) { + return !(a > b); +} + +__device__ inline bool operator>=(hipFloatComplex a, hipFloatComplex b) { + return !(a < b); +} + +__device__ inline bool operator==(hipFloatComplex a, hipFloatComplex b) { + return hipCrealf(a) == hipCrealf(b) && hipCimagf(a) == hipCimagf(b); +} + +__device__ inline bool operator!=(hipFloatComplex a, hipFloatComplex b) { + return !(a == b); +} + +// Complex exponential +__device__ inline hipFloatComplex exp(hipFloatComplex z) { + float r = expf(hipCrealf(z)); + float i = hipCimagf(z); + return make_hipFloatComplex(r * cosf(i), r * sinf(i)); +} + +// Complex logarithm +__device__ inline hipFloatComplex log(hipFloatComplex z) { + return make_hipFloatComplex( + logf(hipCabsf(z)), atan2f(hipCimagf(z), hipCrealf(z))); +} + +// Complex square root +__device__ inline hipFloatComplex sqrt(hipFloatComplex z) { + float r = hipCabsf(z); + float x = hipCrealf(z); + float y = hipCimagf(z); + float t = sqrtf((r + fabsf(x)) / 2.0f); + if (x >= 0) { + return make_hipFloatComplex(t, y / (2.0f * t)); + } else { + return make_hipFloatComplex(fabsf(y) / (2.0f * t), copysignf(t, y)); + } +} + +// Complex sine +__device__ inline hipFloatComplex sin(hipFloatComplex z) { + float x = hipCrealf(z); + float y = hipCimagf(z); + return make_hipFloatComplex(sinf(x) * coshf(y), cosf(x) * sinhf(y)); +} + +// Complex cosine +__device__ inline hipFloatComplex cos(hipFloatComplex z) { + float x = hipCrealf(z); + float y = hipCimagf(z); + return make_hipFloatComplex(cosf(x) * coshf(y), -sinf(x) * sinhf(y)); +} + +// Complex tangent +__device__ inline hipFloatComplex tan(hipFloatComplex z) { + return hipCdivf(sin(z), cos(z)); +} + +// Complex hyperbolic sine +__device__ inline hipFloatComplex sinh(hipFloatComplex z) { + float x = hipCrealf(z); + float y = hipCimagf(z); + return make_hipFloatComplex(sinhf(x) * cosf(y), coshf(x) * sinf(y)); +} + +// Complex hyperbolic cosine +__device__ inline hipFloatComplex cosh(hipFloatComplex z) { + float x = hipCrealf(z); + float y = hipCimagf(z); + return make_hipFloatComplex(coshf(x) * cosf(y), sinhf(x) * sinf(y)); +} + +// Complex hyperbolic tangent +__device__ inline hipFloatComplex tanh(hipFloatComplex z) { + return hipCdivf(sinh(z), cosh(z)); +} + +// Complex power +__device__ inline hipFloatComplex pow( + hipFloatComplex base, + hipFloatComplex exp) { + // base^exp = exp(exp * log(base)) + return rocm::exp(hipCmulf(exp, rocm::log(base))); +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/indexing.hpp b/mlx/backend/rocm/device/indexing.hpp new file mode 100644 index 0000000000..3861316917 --- /dev/null +++ b/mlx/backend/rocm/device/indexing.hpp @@ -0,0 +1,31 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::rocm { + +// Convert an absolute index to positions in a 3d grid, assuming the index is +// calculated with: +// index = x * dim1 * dim2 + y * dim2 + z +template +inline __host__ __device__ void +index_to_dims(T index, T dim1, T dim2, T& x, T& y, T& z) { + x = index / (dim1 * dim2); + y = (index % (dim1 * dim2)) / dim2; + z = index % dim2; +} + +// Get absolute index from possible negative index. +template +inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) { + if constexpr (std::is_unsigned_v) { + return idx; + } else { + return static_cast(idx < 0 ? idx + size : idx); + } +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/scatter.hpp b/mlx/backend/rocm/device/scatter.hpp new file mode 100644 index 0000000000..5b842ac190 --- /dev/null +++ b/mlx/backend/rocm/device/scatter.hpp @@ -0,0 +1,64 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/rocm/device/indexing.hpp" +#include "mlx/backend/rocm/device/scatter_ops.hpp" +#include "mlx/backend/rocm/device/utils.hpp" + +#include + +namespace mlx::core::rocm { + +template < + typename T, + typename IdxT, + typename Op, + int NIDX, + int IDX_NDIM, + typename LocT> +__global__ void scatter( + const T* upd, + T* out, + LocT size, + const int32_t* upd_shape, + const int64_t* upd_strides, + int32_t upd_ndim, + LocT upd_post_idx_size, + const int32_t* out_shape, + const int64_t* out_strides, + int32_t out_ndim, + const int32_t* axes, + const IdxT* const* indices, + const int32_t* indices_shape, + const int64_t* indices_strides) { + LocT upd_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (upd_idx >= size) { + return; + } + + LocT out_elem = upd_idx % upd_post_idx_size; + LocT idx_elem = upd_idx / upd_post_idx_size; + + LocT out_idx = + elem_to_loc(out_elem, upd_shape + IDX_NDIM, out_strides, out_ndim); + +#pragma unroll + for (int i = 0; i < NIDX; ++i) { + LocT idx_loc = elem_to_loc_nd( + idx_elem, indices_shape + i * IDX_NDIM, indices_strides + i * IDX_NDIM); + int32_t axis = axes[i]; + LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]); + out_idx += idx_val * out_strides[axis]; + } + + LocT upd_loc = elem_to_loc( + out_elem + idx_elem * upd_post_idx_size, + upd_shape, + upd_strides, + upd_ndim); + + Op{}(out + out_idx, upd[upd_loc]); +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/scatter_axis.hpp b/mlx/backend/rocm/device/scatter_axis.hpp new file mode 100644 index 0000000000..6aee595afb --- /dev/null +++ b/mlx/backend/rocm/device/scatter_axis.hpp @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/rocm/device/indexing.hpp" +#include "mlx/backend/rocm/device/scatter_ops.hpp" +#include "mlx/backend/rocm/device/utils.hpp" + +#include + +namespace mlx::core::rocm { + +template < + typename T, + typename IdxT, + typename Op, + int NDIM, + bool UpdC, + bool IdxC, + typename LocT = int64_t> +__global__ void scatter_axis_kernel( + const T* upd, + const IdxT* indices, + T* out, + LocT idx_size_pre, + LocT idx_size_axis, + LocT idx_size_post, + const hip_array shape, + const hip_array upd_strides, + const hip_array idx_strides, + int32_t axis, + int32_t axis_size, + int64_t upd_stride_axis, + int64_t idx_stride_axis) { + LocT index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= idx_size_pre * idx_size_axis * idx_size_post) { + return; + } + + LocT x, y, z; + index_to_dims(index, idx_size_axis, idx_size_pre, x, y, z); + + LocT elem_idx = z * idx_size_post; + + LocT idx_loc = y * idx_stride_axis; + if constexpr (IdxC) { + idx_loc += elem_idx * idx_size_axis + x; + } else { + idx_loc += + elem_to_loc_nd(elem_idx + x, shape.data_, idx_strides.data_); + } + + auto idx_val = absolute_index(indices[idx_loc], axis_size); + + LocT upd_loc = y * upd_stride_axis; + if constexpr (UpdC) { + upd_loc += elem_idx * idx_size_axis + x; + } else { + upd_loc += + elem_to_loc_nd(elem_idx + x, shape.data_, upd_strides.data_); + } + + LocT out_idx = idx_val * idx_size_post + elem_idx * axis_size + x; + + Op{}(out + out_idx, upd[upd_loc]); +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/scatter_ops.hpp b/mlx/backend/rocm/device/scatter_ops.hpp new file mode 100644 index 0000000000..c8973d39da --- /dev/null +++ b/mlx/backend/rocm/device/scatter_ops.hpp @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/rocm/device/atomic_ops.hpp" + +namespace mlx::core::rocm { + +struct ScatterAssign { + template + __device__ void operator()(T* out, T val) const { + *out = val; + } +}; + +struct ScatterSum { + template + __device__ void operator()(T* out, T val) const { + atomic_add(out, val); + } +}; + +struct ScatterProd { + template + __device__ void operator()(T* out, T val) const { + atomic_prod(out, val); + } +}; + +struct ScatterMax { + template + __device__ void operator()(T* out, T val) const { + atomic_max(out, val); + } +}; + +struct ScatterMin { + template + __device__ void operator()(T* out, T val) const { + atomic_min(out, val); + } +}; + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/ternary_ops.hpp b/mlx/backend/rocm/device/ternary_ops.hpp new file mode 100644 index 0000000000..1a12404851 --- /dev/null +++ b/mlx/backend/rocm/device/ternary_ops.hpp @@ -0,0 +1,33 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core::rocm { + +struct Select { + template + __device__ T operator()(bool condition, T x, T y) { + if constexpr (std::is_same_v) { + // hip_bfloat16 may not work well with ternary operator + if (condition) { + return x; + } else { + return y; + } + } else if constexpr (std::is_same_v) { + if (condition) { + return x; + } else { + return y; + } + } else { + return condition ? x : y; + } + } +}; + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/unary_ops.hpp b/mlx/backend/rocm/device/unary_ops.hpp new file mode 100644 index 0000000000..3b31c75303 --- /dev/null +++ b/mlx/backend/rocm/device/unary_ops.hpp @@ -0,0 +1,556 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/rocm/device/fp16_math.hpp" +#include "mlx/backend/rocm/device/utils.hpp" + +#include + +namespace mlx::core::rocm { + +struct Abs { + template + __device__ T operator()(T x) { + if constexpr (std::is_unsigned_v) { + return x; + } else if constexpr (std::is_same_v) { + return fabsf(x); + } else if constexpr (std::is_same_v) { + return fabs(x); + } else if constexpr (std::is_same_v) { + return __habs(x); + } else if constexpr (std::is_same_v) { + return hip_bfloat16(fabsf(static_cast(x))); + } else if constexpr (is_complex_v) { + return make_hipFloatComplex(hypotf(x.x, x.y), 0.0f); + } else { + // For integral types + return abs(x); + } + } +}; + +struct ArcCos { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return ::acosf(x); + } else if constexpr (std::is_same_v) { + return ::acos(x); + } else if constexpr (std::is_same_v) { + return __float2half(acosf(__half2float(x))); + } else { + return acos(x); + } + } +}; + +struct ArcCosh { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return ::acoshf(x); + } else if constexpr (std::is_same_v) { + return ::acosh(x); + } else if constexpr (std::is_same_v) { + return __float2half(acoshf(__half2float(x))); + } else { + return acosh(x); + } + } +}; + +struct ArcSin { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return ::asinf(x); + } else if constexpr (std::is_same_v) { + return ::asin(x); + } else if constexpr (std::is_same_v) { + return __float2half(asinf(__half2float(x))); + } else { + return asin(x); + } + } +}; + +struct ArcSinh { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return ::asinhf(x); + } else if constexpr (std::is_same_v) { + return ::asinh(x); + } else if constexpr (std::is_same_v) { + return __float2half(asinhf(__half2float(x))); + } else { + return asinh(x); + } + } +}; + +struct ArcTan { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return ::atanf(x); + } else if constexpr (std::is_same_v) { + return ::atan(x); + } else if constexpr (std::is_same_v) { + return __float2half(atanf(__half2float(x))); + } else { + return atan(x); + } + } +}; + +struct ArcTanh { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return ::atanhf(x); + } else if constexpr (std::is_same_v) { + return ::atanh(x); + } else if constexpr (std::is_same_v) { + return __float2half(atanhf(__half2float(x))); + } else { + return atanh(x); + } + } +}; + +struct BitwiseInvert { + template + __device__ T operator()(T x) { + if constexpr (std::is_integral_v) { + return ~x; + } else { + // BitwiseInvert only makes sense for integral types + return T{}; + } + } +}; + +struct Ceil { + template + __device__ T operator()(T x) { + if constexpr (std::is_integral_v) { + return x; + } else if constexpr (is_complex_v) { + return T{::ceilf(x.x), ::ceilf(x.y)}; + } else if constexpr (std::is_same_v) { + return ::ceilf(x); + } else if constexpr (std::is_same_v) { + return ::ceil(x); + } else { + return ceil(x); + } + } +}; + +struct Conjugate { + template + __device__ T operator()(T x) { + if constexpr (is_complex_v) { + return hipConjf(x); + } else { + // For non-complex types, conjugate is identity + return x; + } + } +}; + +struct Cos { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return cosf(x); + } else if constexpr (std::is_same_v) { + return ::cos(x); + } else if constexpr (std::is_same_v) { + return __float2half(cosf(__half2float(x))); + } else { + return cos(x); + } + } +}; + +struct Cosh { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return ::coshf(x); + } else if constexpr (std::is_same_v) { + return ::cosh(x); + } else if constexpr (std::is_same_v) { + return __float2half(coshf(__half2float(x))); + } else { + return cosh(x); + } + } +}; + +struct Erf { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v || std::is_integral_v) { + return static_cast(erff(static_cast(x))); + } else if constexpr (std::is_same_v) { + return erf(x); + } else if constexpr (std::is_same_v) { + return erf(x); + } else { + return erff(x); + } + } +}; + +struct ErfInv { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v || std::is_integral_v) { + return static_cast(erfinvf(static_cast(x))); + } else if constexpr (std::is_same_v) { + return erfinv(x); + } else if constexpr (std::is_same_v) { + return erfinv(x); + } else { + return erfinvf(x); + } + } +}; + +struct Exp { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return expf(x); + } else if constexpr (std::is_same_v) { + return ::exp(x); + } else if constexpr (std::is_same_v) { + return __float2half(expf(__half2float(x))); + } else { + return exp(x); + } + } +}; + +struct Expm1 { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v || std::is_integral_v) { + return static_cast(expm1f(static_cast(x))); + } else if constexpr (std::is_same_v) { + return expm1(x); + } else if constexpr (std::is_same_v) { + return expm1(x); + } else { + return expm1f(x); + } + } +}; + +struct Floor { + template + __device__ T operator()(T x) { + if constexpr (std::is_integral_v) { + return x; + } else if constexpr (is_complex_v) { + return T{::floorf(x.x), ::floorf(x.y)}; + } else if constexpr (std::is_same_v) { + return ::floorf(x); + } else if constexpr (std::is_same_v) { + return ::floor(x); + } else { + return floor(x); + } + } +}; + +struct Imag { + template + __device__ auto operator()(T x) { + if constexpr (is_complex_v) { + return x.y; + } else { + // For non-complex types, imaginary part is 0 + return T(0); + } + } +}; + +struct Log { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return logf(x); + } else if constexpr (std::is_same_v) { + return ::log(x); + } else if constexpr (std::is_same_v) { + return __float2half(logf(__half2float(x))); + } else { + return log(x); + } + } +}; + +struct Log2 { + template + __device__ T operator()(T x) { + if constexpr (is_complex_v) { + auto y = Log{}(x); + constexpr float ln2 = 0.693147180559945309417232121458176568f; + return {y.x / ln2, y.y / ln2}; + } else if constexpr (std::is_same_v) { + return ::log2f(x); + } else if constexpr (std::is_same_v) { + return ::log2(x); + } else if constexpr (std::is_same_v) { + return __float2half(log2f(__half2float(x))); + } else { + return log2(x); + } + } +}; + +struct Log10 { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return ::log10f(x); + } else if constexpr (std::is_same_v) { + return ::log10(x); + } else if constexpr (std::is_same_v) { + return __float2half(log10f(__half2float(x))); + } else { + return log10(x); + } + } +}; + +struct Log1p { + template + __device__ T operator()(T z) { + if constexpr (is_complex_v) { + float x = z.x; + float y = z.y; + float zabs = Abs{}(z).x; + float theta = atan2f(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1pf(r), theta}; + } else { + float z0 = hypotf(x + 1, y); + return {logf(z0), theta}; + } + } else if constexpr (std::is_same_v) { + return log1pf(z); + } else if constexpr (std::is_same_v) { + return ::log1p(z); + } else { + return log1p(z); + } + } +}; + +struct LogicalNot { + __device__ bool operator()(bool x) { + return !x; + } +}; + +struct Negative { + template + __device__ T operator()(T x) { + if constexpr (is_complex_v) { + return make_hipFloatComplex(-x.x, -x.y); + } else { + return -x; + } + } +}; + +struct Real { + template + __device__ auto operator()(T x) { + if constexpr (is_complex_v) { + return x.x; + } else { + // For non-complex types, real part is the value itself + return x; + } + } +}; + +struct Round { + template + __device__ T operator()(T x) { + if constexpr (is_complex_v) { + return {::rintf(x.x), ::rintf(x.y)}; + } else if constexpr (std::is_same_v) { + return ::rintf(x); + } else if constexpr (std::is_same_v) { + return ::rint(x); + } else { + return rint(x); + } + } +}; + +struct Sigmoid { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + float fx = static_cast(x); + float y = 1.0f / (1.0f + expf(-fabsf(fx))); + return T((fx < 0.0f) ? 1.0f - y : y); + } else if constexpr (std::is_same_v) { + float fx = __half2float(x); + float y = 1.0f / (1.0f + expf(-fabsf(fx))); + return __float2half((fx < 0.0f) ? 1.0f - y : y); + } else { + float fx = static_cast(x); + float y = 1.0f / (1.0f + expf(-fabsf(fx))); + return T((fx < 0.0f) ? 1.0f - y : y); + } + } +}; + +struct Sign { + template + __device__ T operator()(T x) { + if constexpr (std::is_unsigned_v) { + return x != 0; + } else if constexpr (is_complex_v) { + if (x.x == 0 && x.y == 0) { + return x; + } else { + return hipCdivf(x, Abs()(x)); + } + } else if constexpr (std::is_same_v) { + float fx = static_cast(x); + return T((fx > 0.0f) - (fx < 0.0f)); + } else if constexpr (std::is_same_v) { + float fx = __half2float(x); + return __float2half((fx > 0.0f) - (fx < 0.0f)); + } else { + return (x > T(0)) - (x < T(0)); + } + } +}; + +struct Sin { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return sinf(x); + } else if constexpr (std::is_same_v) { + return ::sin(x); + } else if constexpr (std::is_same_v) { + return __float2half(sinf(__half2float(x))); + } else { + return sin(x); + } + } +}; + +struct Sinh { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return ::sinhf(x); + } else if constexpr (std::is_same_v) { + return ::sinh(x); + } else if constexpr (std::is_same_v) { + return __float2half(sinhf(__half2float(x))); + } else { + return sinh(x); + } + } +}; + +struct Square { + template + __device__ T operator()(T x) { + if constexpr (is_complex_v) { + return hipCmulf(x, x); + } else if constexpr (std::is_same_v) { + float fx = static_cast(x); + return hip_bfloat16(fx * fx); + } else if constexpr (std::is_same_v) { + float fx = __half2float(x); + return __float2half(fx * fx); + } else { + return x * x; + } + } +}; + +struct Sqrt { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return ::sqrtf(x); + } else if constexpr (std::is_same_v) { + return ::sqrt(x); + } else if constexpr (std::is_same_v) { + return __float2half(sqrtf(__half2float(x))); + } else { + return sqrt(x); + } + } +}; + +struct Rsqrt { + template + __device__ T operator()(T x) { + if constexpr (is_complex_v) { + return hipCdivf(make_hipFloatComplex(1.0f, 0.0f), Sqrt{}(x)); + } else if constexpr (std::is_same_v) { + return ::rsqrtf(x); + } else if constexpr (std::is_same_v) { + return ::rsqrt(x); + } else if constexpr (std::is_same_v) { + return __float2half(rsqrtf(__half2float(x))); + } else { + return rsqrt(x); + } + } +}; + +struct Tan { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return ::tanf(x); + } else if constexpr (std::is_same_v) { + return ::tan(x); + } else if constexpr (std::is_same_v) { + return __float2half(tanf(__half2float(x))); + } else { + return tan(x); + } + } +}; + +struct Tanh { + template + __device__ T operator()(T x) { + if constexpr (std::is_same_v) { + return ::tanhf(x); + } else if constexpr (std::is_same_v) { + return ::tanh(x); + } else if constexpr (std::is_same_v) { + return __float2half(tanhf(__half2float(x))); + } else { + return tanh(x); + } + } +}; + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device/utils.hpp b/mlx/backend/rocm/device/utils.hpp new file mode 100644 index 0000000000..d9cc3907cd --- /dev/null +++ b/mlx/backend/rocm/device/utils.hpp @@ -0,0 +1,774 @@ +// Copyright © 2025 Apple Inc. + +// This file must not include any host-only code, utilities that work under both +// host and device can be put here. + +#pragma once + +#include "mlx/backend/rocm/device/config.h" + +#include +#include +#include +#include + +#include +#include +#include + +namespace mlx::core::rocm { + +/////////////////////////////////////////////////////////////////////////////// +// Type traits +/////////////////////////////////////////////////////////////////////////////// + +// Type traits for complex types +template +struct is_complex : std::false_type {}; + +template <> +struct is_complex : std::true_type {}; + +template +inline constexpr bool is_complex_v = is_complex::value; + +// Type traits for floating point types (including half precision) +template +inline constexpr bool is_floating_v = + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v; + +// Type traits for inexact types (floating point or complex) +template +inline constexpr bool is_inexact_v = is_floating_v || is_complex_v; + +// Complex type alias +template +using complex_t = hipFloatComplex; + +/////////////////////////////////////////////////////////////////////////////// +// Shape and Strides types +/////////////////////////////////////////////////////////////////////////////// + +// HIP array type (similar to cuda::std::array) +// This is usable from both host and device code +template +struct hip_array { + T data_[N]; + +#ifdef __HIPCC__ + __host__ __device__ T& operator[](int i) { + return data_[i]; + } + __host__ __device__ const T& operator[](int i) const { + return data_[i]; + } + __host__ __device__ constexpr int size() const { + return N; + } + __host__ __device__ T* data() { + return data_; + } + __host__ __device__ const T* data() const { + return data_; + } +#else + T& operator[](int i) { + return data_[i]; + } + const T& operator[](int i) const { + return data_[i]; + } + constexpr int size() const { + return N; + } + T* data() { + return data_; + } + const T* data() const { + return data_; + } +#endif +}; + +// To pass shape/strides to kernels via constant memory, their size must be +// known at compile time. +using Shape = hip_array; +using Strides = hip_array; + +/////////////////////////////////////////////////////////////////////////////// +// Vectorized load/store +/////////////////////////////////////////////////////////////////////////////// + +template +struct alignas(sizeof(T) * N) AlignedVector { + T val[N]; + +#ifdef __HIPCC__ + __device__ T& operator[](int i) { + return val[i]; + } + + __device__ T operator[](int i) const { + return val[i]; + } +#endif +}; + +template +inline __host__ __device__ bool is_aligned(T* x) { + return (reinterpret_cast(x) % (N * sizeof(T))) == 0; +} + +#ifdef __HIPCC__ + +template +inline __device__ AlignedVector unsafe_load_vector( + const T* ptr, + uint32_t offset) { + auto* from = reinterpret_cast*>(ptr); + return from[offset]; +} + +template +inline __device__ AlignedVector load_vector( + const T* ptr, + uint32_t offset) { + if (is_aligned(ptr)) { + auto* from = reinterpret_cast*>(ptr); + return from[offset]; + } else { + AlignedVector v; +#pragma unroll + for (int i = 0; i < N; ++i) { + v[i] = ptr[offset * N + i]; + } + return v; + } +} + +template +inline __device__ AlignedVector +load_vector(const T* ptr, uint32_t offset, SizeT size, T fallback) { + if (is_aligned(ptr) && (offset + 1) * N <= size) { + auto* from = reinterpret_cast*>(ptr); + return from[offset]; + } else { + AlignedVector v; +#pragma unroll + for (int i = 0; i < N; ++i) { + v[i] = (N * offset + i) < size ? ptr[offset * N + i] : fallback; + } + return v; + } +} + +template +inline __device__ AlignedVector load_vector( + const T* ptr, + uint32_t offset, + SizeT size, + int64_t stride, + T fallback) { + if (is_aligned(ptr) && stride == 1 && (offset + 1) * N <= size) { + auto* from = reinterpret_cast*>(ptr); + return from[offset]; + } else { + AlignedVector v; +#pragma unroll + for (int i = 0; i < N; ++i) { + v[i] = + (N * offset + i) < size ? ptr[stride * (offset * N + i)] : fallback; + } + return v; + } +} + +template +inline __device__ void +unsafe_store_vector(T* ptr, uint32_t offset, const AlignedVector& vec) { + auto* to = reinterpret_cast*>(ptr); + to[offset] = vec; +} + +template +inline __device__ void +store_vector(T* ptr, uint32_t offset, const AlignedVector& vec) { + if (is_aligned(ptr)) { + auto* to = reinterpret_cast*>(ptr); + to[offset] = vec; + } else { +#pragma unroll + for (int i = 0; i < N; ++i) { + ptr[offset * N + i] = vec[i]; + } + } +} + +template +inline __device__ void store_vector( + T* ptr, + uint32_t offset, + const AlignedVector& vec, + SizeT size) { + if (is_aligned(ptr) && (offset + 1) * N <= size) { + auto* to = reinterpret_cast*>(ptr); + to[offset] = vec; + } else { + for (int i = 0; (offset * N + i) < size && i < N; ++i) { + ptr[offset * N + i] = vec[i]; + } + } +} + +template +inline __device__ void store_vector( + T* ptr, + uint32_t offset, + const AlignedVector& vec, + SizeT size, + int64_t stride) { + if (is_aligned(ptr) && (offset + 1) * N <= size && stride == 1) { + auto* to = reinterpret_cast*>(ptr); + to[offset] = vec; + } else { + for (int i = 0; (offset * N + i) < size && i < N; ++i) { + ptr[stride * (offset * N + i)] = vec[i]; + } + } +} + +#endif // __HIPCC__ + +/////////////////////////////////////////////////////////////////////////////// +// Utility functions +/////////////////////////////////////////////////////////////////////////////// + +// Ceil division - available on both host and device +template +#ifdef __HIPCC__ +__host__ __device__ +#endif + T ceildiv(T a, T b) { + return (a + b - 1) / b; +} + +// ============================================================================ +// Device-only code below - only compiled when using HIP compiler +// ============================================================================ +#ifdef __HIPCC__ + +/////////////////////////////////////////////////////////////////////////////// +// Numeric limits for device code +/////////////////////////////////////////////////////////////////////////////// + +template +struct numeric_limits; + +template <> +struct numeric_limits { + __device__ static float infinity() { + unsigned int i = 0x7f800000; + return *reinterpret_cast(&i); + } + __device__ static float quiet_NaN() { + unsigned int i = 0x7fc00000; + return *reinterpret_cast(&i); + } + __device__ static constexpr float lowest() { + return -3.402823466e+38f; + } + __device__ static constexpr float max() { + return 3.402823466e+38f; + } +}; + +template <> +struct numeric_limits { + __device__ static double infinity() { + unsigned long long i = 0x7ff0000000000000ULL; + return *reinterpret_cast(&i); + } + __device__ static double quiet_NaN() { + unsigned long long i = 0x7ff8000000000000ULL; + return *reinterpret_cast(&i); + } + __device__ static constexpr double lowest() { + return -1.7976931348623158e+308; + } + __device__ static constexpr double max() { + return 1.7976931348623158e+308; + } +}; + +template <> +struct numeric_limits<__half> { + __device__ static __half infinity() { + return __ushort_as_half(0x7c00); + } + __device__ static __half quiet_NaN() { + return __ushort_as_half(0x7e00); + } + __device__ static __half lowest() { + return __ushort_as_half(0xfbff); + } + __device__ static __half max() { + return __ushort_as_half(0x7bff); + } +}; + +template <> +struct numeric_limits { + __device__ static hip_bfloat16 infinity() { + hip_bfloat16 val; + val.data = 0x7f80; + return val; + } + __device__ static hip_bfloat16 quiet_NaN() { + hip_bfloat16 val; + val.data = 0x7fc0; + return val; + } + __device__ static hip_bfloat16 lowest() { + hip_bfloat16 val; + val.data = 0xff7f; + return val; + } + __device__ static hip_bfloat16 max() { + hip_bfloat16 val; + val.data = 0x7f7f; + return val; + } +}; + +template <> +struct numeric_limits { + __device__ static constexpr int32_t lowest() { + return INT32_MIN; + } + __device__ static constexpr int32_t max() { + return INT32_MAX; + } +}; + +template <> +struct numeric_limits { + __device__ static constexpr int64_t lowest() { + return INT64_MIN; + } + __device__ static constexpr int64_t max() { + return INT64_MAX; + } +}; + +template <> +struct numeric_limits { + __device__ static constexpr uint32_t lowest() { + return 0; + } + __device__ static constexpr uint32_t max() { + return UINT32_MAX; + } +}; + +template <> +struct numeric_limits { + __device__ static constexpr uint64_t lowest() { + return 0; + } + __device__ static constexpr uint64_t max() { + return UINT64_MAX; + } +}; + +template <> +struct numeric_limits { + __device__ static constexpr int8_t lowest() { + return INT8_MIN; + } + __device__ static constexpr int8_t max() { + return INT8_MAX; + } +}; + +template <> +struct numeric_limits { + __device__ static constexpr uint8_t lowest() { + return 0; + } + __device__ static constexpr uint8_t max() { + return UINT8_MAX; + } +}; + +template <> +struct numeric_limits { + __device__ static constexpr int16_t lowest() { + return INT16_MIN; + } + __device__ static constexpr int16_t max() { + return INT16_MAX; + } +}; + +template <> +struct numeric_limits { + __device__ static constexpr uint16_t lowest() { + return 0; + } + __device__ static constexpr uint16_t max() { + return UINT16_MAX; + } +}; + +template <> +struct numeric_limits { + __device__ static constexpr bool lowest() { + return false; + } + __device__ static constexpr bool max() { + return true; + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Type limits utils (returns infinity for floats, max for integers) +/////////////////////////////////////////////////////////////////////////////// + +template +struct Limits { + __device__ static T max() { + return numeric_limits::max(); + } + __device__ static T min() { + return numeric_limits::lowest(); + } + __device__ static T finite_max() { + return numeric_limits::max(); + } + __device__ static T finite_min() { + return numeric_limits::lowest(); + } +}; + +template +struct Limits< + T, + std::enable_if_t || std::is_same_v>> { + __device__ static T max() { + return numeric_limits::infinity(); + } + __device__ static T min() { + return -numeric_limits::infinity(); + } + __device__ static T finite_max() { + return numeric_limits::max(); + } + __device__ static T finite_min() { + return numeric_limits::lowest(); + } +}; + +template +struct Limits< + T, + std::enable_if_t< + std::is_same_v || std::is_same_v>> { + __device__ static T max() { + return numeric_limits::infinity(); + } + __device__ static T min() { + // Use float infinity for half types to avoid precision issues + return static_cast(-numeric_limits::infinity()); + } + __device__ static T finite_max() { + return numeric_limits::max(); + } + __device__ static T finite_min() { + return numeric_limits::lowest(); + } +}; + +template <> +struct Limits { + __device__ static bool max() { + return true; + } + __device__ static bool min() { + return false; + } + __device__ static bool finite_max() { + return true; + } + __device__ static bool finite_min() { + return false; + } +}; + +template <> +struct numeric_limits { + __device__ static hipFloatComplex lowest() { + return make_hipFloatComplex( + numeric_limits::lowest(), numeric_limits::lowest()); + } + __device__ static hipFloatComplex max() { + return make_hipFloatComplex( + numeric_limits::max(), numeric_limits::max()); + } +}; + +template <> +struct Limits { + __device__ static hipFloatComplex max() { + return make_hipFloatComplex(Limits::max(), Limits::max()); + } + __device__ static hipFloatComplex min() { + return make_hipFloatComplex(Limits::min(), Limits::min()); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Indexing utils +/////////////////////////////////////////////////////////////////////////////// + +template +__device__ IdxT +elem_to_loc(IdxT elem, const int* shape, const int64_t* strides, int ndim) { + IdxT loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +// Optimize when the ndim is known at compile time. +template +__device__ IdxT +elem_to_loc_nd(IdxT elem, const int* shape, const int64_t* strides) { + IdxT loc = 0; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +// Two-array version +template +__device__ void elem_to_loc_nd( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides, + IdxT& a_loc, + IdxT& b_loc) { + a_loc = 0; + b_loc = 0; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * IdxT(a_strides[i]); + b_loc += dim_idx * IdxT(b_strides[i]); + elem /= shape[i]; + } +} + +// Three-array version +template +__device__ void elem_to_loc_nd( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides, + const int64_t* c_strides, + IdxT& a_loc, + IdxT& b_loc, + IdxT& c_loc) { + a_loc = 0; + b_loc = 0; + c_loc = 0; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * IdxT(a_strides[i]); + b_loc += dim_idx * IdxT(b_strides[i]); + c_loc += dim_idx * IdxT(c_strides[i]); + elem /= shape[i]; + } +} + +// Dynamic ndim two-array version +template +__device__ void elem_to_loc( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides, + int ndim, + IdxT& a_loc, + IdxT& b_loc) { + a_loc = 0; + b_loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * IdxT(a_strides[i]); + b_loc += dim_idx * IdxT(b_strides[i]); + elem /= shape[i]; + } +} + +// Dynamic ndim three-array version +template +__device__ void elem_to_loc( + IdxT elem, + const int* shape, + const int64_t* a_strides, + const int64_t* b_strides, + const int64_t* c_strides, + int ndim, + IdxT& a_loc, + IdxT& b_loc, + IdxT& c_loc) { + a_loc = 0; + b_loc = 0; + c_loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + int dim_idx = elem % shape[i]; + a_loc += dim_idx * IdxT(a_strides[i]); + b_loc += dim_idx * IdxT(b_strides[i]); + c_loc += dim_idx * IdxT(c_strides[i]); + elem /= shape[i]; + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Elem to loc in a loop utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct LoopedElemToLoc { + int dim; + LoopedElemToLoc inner_looper; + OffsetT offset{0}; + int index{0}; + + __device__ LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {} + + __device__ void next(const int* shape, const int64_t* strides) { + if (dim == 0) { + return; + } + index++; + offset += OffsetT(strides[dim - 1]); + if (index >= shape[dim - 1]) { + index = 0; + inner_looper.next(shape, strides); + offset = inner_looper.offset; + } + } + + __device__ void next(int n, const int* shape, const int64_t* strides) { + if (dim == 0) { + return; + } + index += n; + offset += n * OffsetT(strides[dim - 1]); + + if (index >= shape[dim - 1]) { + int extra = index - shape[dim - 1]; + if (extra >= shape[dim - 1]) { + inner_looper.next(1 + extra / shape[dim - 1], shape, strides); + extra = extra % shape[dim - 1]; + } else { + inner_looper.next(shape, strides); + } + index = 0; + offset = inner_looper.offset; + if (extra > 0) { + next(extra, shape, strides); + } + } + } + + __device__ OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, true, OffsetT> { + int dim; + OffsetT offset{0}; + int index{0}; + + __device__ LoopedElemToLoc(int dim) : dim(dim) {} + + __device__ void next(const int* shape, const int64_t* strides) { + index++; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset += OffsetT(strides[0]); + } + } + + __device__ void next(int n, const int* shape, const int64_t* strides) { + index += n; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset = index * OffsetT(strides[0]); + } + } + + __device__ OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, false, OffsetT> { + OffsetT offset{0}; + + __device__ LoopedElemToLoc(int) {} + + __device__ void next(const int*, const int64_t* strides) { + offset += OffsetT(strides[0]); + } + + __device__ void next(int n, const int*, const int64_t* strides) { + offset += n * OffsetT(strides[0]); + } + + __device__ OffsetT location() { + return offset; + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Thread/block index helpers +/////////////////////////////////////////////////////////////////////////////// + +// Get the thread index in the block +__device__ inline int thread_index() { + return threadIdx.x + threadIdx.y * blockDim.x + + threadIdx.z * blockDim.x * blockDim.y; +} + +// Get the block index in the grid +__device__ inline int block_index() { + return blockIdx.x + blockIdx.y * gridDim.x + + blockIdx.z * gridDim.x * gridDim.y; +} + +// Get the global thread index +__device__ inline int global_thread_index() { + return thread_index() + + block_index() * (blockDim.x * blockDim.y * blockDim.z); +} + +#endif // __HIPCC__ + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/device_info.cpp b/mlx/backend/rocm/device_info.cpp new file mode 100644 index 0000000000..a3d780e90c --- /dev/null +++ b/mlx/backend/rocm/device_info.cpp @@ -0,0 +1,140 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/device_info.h" +#include "mlx/backend/rocm/utils.h" + +#include + +#include +#include +#include +#include + +namespace mlx::core { + +namespace { + +std::string format_uuid(const hipUUID& uuid) { + char buf[64]; + snprintf( + buf, + sizeof(buf), + "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + (unsigned char)uuid.bytes[0], + (unsigned char)uuid.bytes[1], + (unsigned char)uuid.bytes[2], + (unsigned char)uuid.bytes[3], + (unsigned char)uuid.bytes[4], + (unsigned char)uuid.bytes[5], + (unsigned char)uuid.bytes[6], + (unsigned char)uuid.bytes[7], + (unsigned char)uuid.bytes[8], + (unsigned char)uuid.bytes[9], + (unsigned char)uuid.bytes[10], + (unsigned char)uuid.bytes[11], + (unsigned char)uuid.bytes[12], + (unsigned char)uuid.bytes[13], + (unsigned char)uuid.bytes[14], + (unsigned char)uuid.bytes[15]); + return buf; +} + +const std::unordered_map>& +device_info_impl(int device_index) { + // Static cache of device properties + static auto all_devices = []() { + // Get device count + int count = 0; + (void)hipGetDeviceCount(&count); + + // Collect info for all devices + struct DeviceInfo { + std::unordered_map> info; + }; + + std::vector devices; + + for (int i = 0; i < count; ++i) { + hipDeviceProp_t prop; + (void)hipGetDeviceProperties(&prop, i); + + DeviceInfo dev; + dev.info["device_name"] = std::string(prop.name); + + // Format UUID + dev.info["uuid"] = format_uuid(prop.uuid); + + // Architecture string (e.g., "gfx1011") + dev.info["architecture"] = std::string(prop.gcnArchName); + + // PCI bus ID (domain:bus:device.function) + char pci_id[32]; + snprintf( + pci_id, + sizeof(pci_id), + "%04x:%02x:%02x.0", + prop.pciDomainID, + prop.pciBusID, + prop.pciDeviceID); + dev.info["pci_bus_id"] = std::string(pci_id); + + // Compute capability equivalent for AMD (GCN version) + dev.info["compute_capability_major"] = static_cast(prop.major); + dev.info["compute_capability_minor"] = static_cast(prop.minor); + + devices.push_back(std::move(dev)); + } + return devices; + }(); + + if (device_index < 0 || + device_index >= static_cast(all_devices.size())) { + static auto empty = + std::unordered_map>(); + return empty; + } + + // Return a copy with fresh memory info + // Using thread_local to avoid locks while keeping free_memory fresh + thread_local auto device_info_copy = + std::unordered_map>(); + + device_info_copy = all_devices[device_index].info; + + // Get fresh memory info using hipMemGetInfo + size_t free_mem, total_mem; + + int prev_device; + (void)hipGetDevice(&prev_device); + (void)hipSetDevice(device_index); + (void)hipMemGetInfo(&free_mem, &total_mem); + (void)hipSetDevice(prev_device); + + device_info_copy["free_memory"] = free_mem; + device_info_copy["total_memory"] = total_mem; + + return device_info_copy; +} + +} // anonymous namespace + +namespace gpu { + +bool is_available() { + return true; +} + +int device_count() { + int count = 0; + (void)hipGetDeviceCount(&count); + return count; +} + +const std::unordered_map>& +device_info(int device_index) { + return device_info_impl(device_index); +} + +} // namespace gpu + +} // namespace mlx::core diff --git a/mlx/backend/rocm/distributed.hip b/mlx/backend/rocm/distributed.hip new file mode 100644 index 0000000000..f548177370 --- /dev/null +++ b/mlx/backend/rocm/distributed.hip @@ -0,0 +1,132 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/distributed/primitives.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core::distributed { + +void AllReduce::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 1); + assert(outputs.size() == 1); + + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + auto set_input_output = [&](const array& in, + array& out) -> std::pair { + if (!in.flags().row_contiguous) { + copy_gpu(in, out, CopyType::General, s); + return {out, out}; + } else if (in.is_donatable()) { + out.copy_shared_buffer(in); + return {in, out}; + } else { + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + return {in, out}; + } + }; + + auto [input, output] = set_input_output(inputs[0], outputs[0]); + + encoder.set_input_array(input); + encoder.set_output_array(output); + + switch (reduce_type_) { + case Sum: + distributed::detail::all_sum(group(), input, output, s); + break; + case Max: + distributed::detail::all_max(group(), input, output, s); + break; + case Min: + distributed::detail::all_min(group(), input, output, s); + break; + default: + throw std::runtime_error( + "Only all reduce sum, max, and min are supported."); + } +} + +void AllGather::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 1); + assert(outputs.size() == 1); + + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + auto ensure_contiguous = [&s, &encoder](const array& x) { + if (x.flags().row_contiguous) { + return x; + } else { + array x_copy = contiguous_copy_gpu(x, s); + encoder.add_temporary(x_copy); + return x_copy; + } + }; + + auto input = ensure_contiguous(inputs[0]); + outputs[0].set_data(mlx::core::rocm::malloc_async(outputs[0].nbytes(), encoder)); + + encoder.set_input_array(input); + encoder.set_output_array(outputs[0]); + + distributed::detail::all_gather(group(), input, outputs[0], s); +} + +void ReduceScatter::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 1); + assert(outputs.size() == 1); + + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + auto ensure_contiguous = [&s, &encoder](const array& x) { + if (x.flags().row_contiguous) { + return x; + } else { + array x_copy = contiguous_copy_gpu(x, s); + encoder.add_temporary(x_copy); + return x_copy; + } + }; + + auto input = ensure_contiguous(inputs[0]); + outputs[0].set_data(mlx::core::rocm::malloc_async(outputs[0].nbytes(), encoder)); + + encoder.set_input_array(input); + encoder.set_output_array(outputs[0]); + + switch (reduce_type_) { + case Sum: + distributed::detail::sum_scatter(group(), input, outputs[0], s); + break; + default: + throw std::runtime_error("Only sum scatter is supported. "); + } +} + +void Send::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("Send::eval_gpu not yet implemented for ROCm"); +} + +void Recv::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("Recv::eval_gpu not yet implemented for ROCm"); +} + +} // namespace mlx::core::distributed diff --git a/mlx/backend/rocm/eval.cpp b/mlx/backend/rocm/eval.cpp new file mode 100644 index 0000000000..310572bbfd --- /dev/null +++ b/mlx/backend/rocm/eval.cpp @@ -0,0 +1,188 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/eval.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/event.h" +#include "mlx/primitives.h" +#include "mlx/scheduler.h" + +#include +#include +#include +#include + +namespace mlx::core::gpu { + +void init() { + // Initialize the SELECTED default GPU's primary context — not device 0. On a + // multi-GPU host, creating a context/queue on the other GPU (the APU) too is + // what differs from HIP_VISIBLE_DEVICES, and that cross-device queue coexistence + // is what wedges the discrete GPU's command queue over a TB5 link. Touch only + // the chosen device so the runtime behaves as if it were the only one. + auto d = mlx::core::default_device(); + if (d.type == mlx::core::Device::gpu) { + (void)hipSetDevice(d.index); + } + hipFree(nullptr); +} + +void new_stream(Stream s) { + // Bind the stream's device FIRST (creates/selects its Device + context), then + // warm the event pool on that device — creating the HipEvent before binding + // would put it (and its queue interaction) on whatever device is current. + rocm::get_command_encoder(s); + rocm::HipEvent(hipEventDefault); +} + +// Ops whose kernels corrupt when batched into a multi-node HIP graph with +// neighbors (a ROCm CLR kernarg-pool interaction; found by per-op force-execute +// bisection). Isolate them: flush the graph before AND after so they run alone. +static bool is_graph_split_op(const char* name) { + static const bool no_split = std::getenv("MLX_NO_CONCAT_SPLIT") != nullptr; + if (no_split) return false; + return std::strcmp(name, "Concatenate") == 0; +} + +void eval(array& arr) { + auto outputs = arr.outputs(); + auto& encoder = rocm::get_command_encoder(arr.primitive().stream()); + const bool split = + rocm::use_hip_graphs() && is_graph_split_op(arr.primitive().name()); + if (split) { + encoder.commit(); // flush ops accumulated before this one + } + // Bind the stream's device before eval_gpu so output buffers allocate on the + // same device the kernels run on. Otherwise (multi-GPU) outputs land on + // whatever device is current (often device 0) while kernels run on the + // stream's device, stranding the model on the wrong GPU. + encoder.device().make_current(); + { + std::vector inputs; + if (arr.is_tracer()) { + inputs = arr.inputs(); + } + arr.primitive().eval_gpu(arr.inputs(), outputs); + } + + for (auto& in : arr.inputs()) { + if (in.data_shared_ptr() != arr.data_shared_ptr()) { + encoder.add_temporary(in); + } + } + for (auto& s : arr.siblings()) { + encoder.add_temporary(s); + } + + if (rocm::use_hip_graphs()) { + auto& stream = arr.primitive().stream(); + if (split || encoder.needs_commit()) { + scheduler::notify_new_task(stream); + encoder.add_completed_handler( + [stream]() { scheduler::notify_task_completion(stream); }); + encoder.commit(); + } + } else { + encoder.maybe_commit(); + } + + // Bisection: batch ops [0, FORCE_FROM) into graphs (per cap), force-execute + // (commit+sync = correct) every op >= FORCE_FROM. The smallest FORCE_FROM that + // turns the output to garbage pinpoints the first op whose batching breaks. + if (rocm::use_hip_graphs()) { + static const int force_from = std::getenv("MLX_GRAPH_FORCE_FROM") + ? std::atoi(std::getenv("MLX_GRAPH_FORCE_FROM")) + : -1; + if (force_from >= 0) { + static int gidx = 0; + int my = gidx++; + if (my >= force_from) { + encoder.commit(); + encoder.synchronize(); + } + static const bool ftr = std::getenv("MLX_GRAPH_FORCE_TRACE") != nullptr; + if (ftr && my >= force_from - 6 && my <= force_from + 1) + fprintf(stderr, "[ff] op %d : %s\n", my, arr.primitive().name()); + } + } +} + +void finalize(Stream s) { + rocm::get_command_encoder(s).commit(); +} + +void synchronize(Stream s) { + rocm::get_command_encoder(s).synchronize(); +} + +void clear_streams() { + rocm::clear_all_encoders(); +} + +} // namespace mlx::core::gpu + +// --- GPU memcpy for direct KV cache writes --- +extern "C" void mlx_gpu_memcpy_async(void* dst, const void* src, size_t bytes) { + // Use the SELECTED default device's stream, not Device::gpu (which is the + // device TYPE = gpu index 0). On a multi-GPU box, --device 1 would otherwise + // memcpy KV data on device 0's stream while the data lives on device 1. + auto& enc = mlx::core::rocm::get_command_encoder( + mlx::core::default_stream(mlx::core::default_device())); + enc.launch_kernel([=](hipStream_t stream) { + (void)hipMemcpyAsync(dst, src, bytes, hipMemcpyDeviceToDevice, stream); + }); +} + +// --- Arena + Graph wrappers (called from engine code without HIP headers) --- +namespace mlx::core { + +bool gpu_arena_begin(size_t capacity) { + return rocm::allocator().arena().begin(capacity); +} +void gpu_arena_reset() { + rocm::allocator().arena().reset(); +} +size_t gpu_arena_desc_used() { + return rocm::allocator().arena().desc_used(); +} +void gpu_arena_reset_to(size_t byte_mark, size_t desc_mark) { + rocm::allocator().arena().reset_to(byte_mark, desc_mark); +} +void gpu_arena_set_paused(bool p) { + rocm::allocator().arena().set_paused(p); +} +void gpu_arena_end() { + rocm::allocator().arena().end(); +} +size_t gpu_arena_used() { + return rocm::allocator().arena().used(); +} +bool gpu_arena_active() { + return rocm::allocator().arena().active(); +} + +static rocm::CommandEncoder& graph_encoder() { + return rocm::get_command_encoder(default_stream(default_device())); +} + +bool gpu_graph_begin_capture() { + graph_encoder().begin_capture(); + return true; +} +bool gpu_graph_end_capture() { + return graph_encoder().end_capture(); +} +bool gpu_graph_replay() { + return graph_encoder().replay(/*sync=*/true); +} +bool gpu_graph_replay_async() { + return graph_encoder().replay(/*sync=*/false); +} +void gpu_graph_reset() { + graph_encoder().reset_graph(); +} +bool gpu_graph_available() { + return graph_encoder().has_graph(); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/event.h b/mlx/backend/rocm/event.h new file mode 100644 index 0000000000..f0237b7a40 --- /dev/null +++ b/mlx/backend/rocm/event.h @@ -0,0 +1,80 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/utils.h" +#include "mlx/stream.h" + +#include + +#include + +namespace mlx::core::rocm { + +// RAII-managed move-only wrapper of hipEvent_t. +struct HipEventHandle : public HipHandle { + HipEventHandle(int flags); + int flags; + // The HIP device the event was created on. A hipEvent is bound to its device: + // recording it on a stream of a DIFFERENT device is invalid and on a multi-GPU + // host hangs the queue. The pool must hand back an event from the right device. + int device{0}; +}; + +// Wrapper of native HIP event. It can synchronize between GPU streams, or wait +// on GPU stream in CPU stream, but can not wait on CPU stream. +class HipEvent { + public: + explicit HipEvent(int flags); + ~HipEvent(); + + HipEvent(HipEvent&&) = default; + HipEvent& operator=(HipEvent&&) = default; + + HipEvent(const HipEvent&) = delete; + HipEvent& operator=(const HipEvent&) = delete; + + void wait(); + void wait(hipStream_t stream); + void record(hipStream_t stream); + + // Return whether the recorded kernels have completed. Note that this method + // returns true if record() has not been called. + bool completed() const; + + private: + HipEventHandle event_; +}; + +// Event that can synchronize between CPU and GPU. It is much slower than +// HipEvent so the latter should always be preferred when possible. +class AtomicEvent { + public: + AtomicEvent(); + + void wait(uint64_t value); + void wait(hipStream_t stream, uint64_t value); + void wait(Stream s, uint64_t value); + void signal(uint64_t value); + void signal(hipStream_t stream, uint64_t value); + void signal(Stream s, uint64_t value); + bool is_signaled(uint64_t value) const; + uint64_t value() const; + + private: + std::atomic* atomic() const { + return atomic_; + } + + // The completion counter lives in PINNED HOST memory, not device memory. The + // GPU writes it (hipStreamWriteValue64) and the CPU polls it (wait()). Device + // memory — even fine-grained — is not reliably CPU-coherent on a discrete GPU + // over a non-coherent link (e.g. an R9700 in a TB5 eGPU enclosure), so the + // host poll would spin forever. Pinned host memory is the canonical GPU->host + // signaling path and works on both the integrated APU and a discrete dGPU. + std::shared_ptr mem_; + std::atomic* atomic_{nullptr}; +}; + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/event.hip b/mlx/backend/rocm/event.hip new file mode 100644 index 0000000000..2d0b4e4a95 --- /dev/null +++ b/mlx/backend/rocm/event.hip @@ -0,0 +1,395 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/event.h" +#include "mlx/event.h" +#include "mlx/scheduler.h" + +#include +#include +#include +#include +#include + +#include + +namespace mlx::core { + +namespace rocm { + +/////////////////////////////////////////////////////////////////////////////// +// HipEvent implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +// Manage cached hipEvent_t objects. Keyed by (device, flags): a hipEvent is +// bound to the device current when it was created, and recording it on another +// device's stream is invalid (hangs the queue on a multi-GPU host). So pool and +// hand back events per device, creating new ones on the CURRENT device. +struct HipEventPool { + static HipEventHandle create(int flags) { + int dev = 0; + (void)hipGetDevice(&dev); + auto& cache = cache_for(dev, flags); + if (cache.empty()) { + return HipEventHandle(flags); // created on the current device + } else { + HipEventHandle ret = std::move(cache.back()); + cache.pop_back(); + return ret; + } + } + + static void release(HipEventHandle event) { + cache_for(event.device, event.flags).push_back(std::move(event)); + } + + static std::vector& cache_for(int device, int flags) { + static std::map, std::vector> cache; + return cache[{device, flags}]; + } +}; + +} // namespace + +HipEventHandle::HipEventHandle(int flags) : flags(flags) { + CHECK_HIP_ERROR(hipEventCreateWithFlags(&handle_, flags)); + assert(handle_ != nullptr); + (void)hipGetDevice(&device); // event is bound to the current device +} + +HipEvent::HipEvent(int flags) : event_(HipEventPool::create(flags)) {} + +HipEvent::~HipEvent() { + HipEventPool::release(std::move(event_)); +} + +void HipEvent::wait() { + // Spin-wait with hipEventQuery instead of hipEventSynchronize. + // On iGPU, the blocking wait in hipEventSynchronize causes CPU-GPU + // contention since they share compute resources. Polling is cheaper. + // Use progressive backoff to reduce hipEventQuery call overhead. + for (int spins = 0; hipEventQuery(event_) != hipSuccess; spins++) { + if (spins < 100) { + // Tight spin for fast completions + } else if (spins < 1000) { + _mm_pause(); // x86 pause hint (reduces power, avoids pipeline stall) + } else { + std::this_thread::yield(); + } + } +} + +void HipEvent::wait(hipStream_t stream) { + (void)hipStreamWaitEvent(stream, event_, 0); +} + +void HipEvent::record(hipStream_t stream) { + (void)hipEventRecord(event_, stream); +} + +bool HipEvent::completed() const { + return hipEventQuery(event_) == hipSuccess; +} + +// Wraps HipEvent with a few features: +// 1. The class can be copied. +// 2. Make wait/record work with CPU streams. +// 3. Add checks for waiting on un-recorded event. +class CopyableHipEvent { + public: + CopyableHipEvent() + : event_(std::make_shared( + hipEventDisableTiming)) {} + // Note: hipEventBlockingSync removed — on iGPU the blocking wait + // contends with GPU for CPU resources. Polling is cheaper. + + void wait() { + event_->wait(); + } + + void wait(Stream s) { + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this]() mutable { + check_recorded(); + event_->wait(); + }); + } else { + check_recorded(); + auto& encoder = rocm::get_command_encoder(s); + encoder.commit(); + event_->wait(encoder.stream()); + } + } + + void record(Stream s) { + if (s.device == mlx::core::Device::cpu) { + throw std::runtime_error("HipEvent can not wait on CPU stream."); + } else { + auto& encoder = rocm::get_command_encoder(s); + encoder.commit(); + event_->record(encoder.stream()); + recorded_ = true; + } + } + + bool is_signaled() const { + return recorded_ && event_->completed(); + } + + private: + void check_recorded() const { + if (!recorded_) { + throw std::runtime_error( + "Should not wait on a HipEvent before recording."); + } + } + + std::shared_ptr event_; + bool recorded_{false}; +}; + +/////////////////////////////////////////////////////////////////////////////// +// AtomicEvent implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +void signal_atomic_callback(void* data) { + auto* pair = static_cast*, uint64_t>*>(data); + pair->first->store(pair->second); + delete pair; +} + +// Pool of 8-byte SIGNAL-MEMORY counters for AtomicEvent. The counter is used two +// ways: the GPU waits on it via hipStreamWaitValue64 (cross-stream Fence) and +// signals it via hipStreamWriteValue64, AND the CPU polls it (host wait()). HIP +// REQUIRES the pointer passed to hipStreamWaitValue64/WriteValue64 to be +// allocated with the hipMallocSignalMemory flag — plain hipHostMalloc memory is +// accepted by the call (returns success) but the GPU-side wait never observes the +// value, so on a discrete GPU the stream spins forever (busy 100%, mem-ctrl 0%). +// Signal memory is host-accessible, so the CPU poll works too. Recycle the blocks +// (hipFree blocks on a discrete GPU); intentionally leaked at process exit. +struct PinnedCounterPool { + std::mutex m; + std::vector free_list; + + void* acquire() { + { + std::lock_guard lk(m); + if (!free_list.empty()) { + void* p = free_list.back(); + free_list.pop_back(); + return p; + } + } + void* p = nullptr; + CHECK_HIP_ERROR(hipHostMalloc( + &p, sizeof(std::atomic), + hipHostMallocMapped | hipHostMallocCoherent)); + return p; + } + + void release(void* p) { + std::lock_guard lk(m); + free_list.push_back(p); + } +}; + +PinnedCounterPool& counter_pool() { + static PinnedCounterPool* pool = new PinnedCounterPool; + return *pool; +} + +} // namespace + +AtomicEvent::AtomicEvent() { + // Completion counter in pinned, device-mapped host memory: the GPU signals it + // (hipStreamWriteValue64) and the CPU polls it (wait()). On a discrete GPU over + // a non-coherent link (TB5 eGPU) device memory is NOT CPU-visible, so the + // counter must live in host memory or the host poll spins forever. Drawn from a + // recycling pool — see PinnedCounterPool (hipHostFree blocks on a discrete GPU). + void* p = counter_pool().acquire(); + atomic_ = static_cast*>(p); + atomic_->store(0, std::memory_order_release); + mem_ = std::shared_ptr(p, [](void* q) { counter_pool().release(q); }); +} + +void AtomicEvent::wait(uint64_t value) { + auto* ac = atomic(); + while (ac->load(std::memory_order_acquire) < value) { + std::this_thread::yield(); + } +} + +void AtomicEvent::wait(hipStream_t stream, uint64_t value) { + // Do NOT use hipStreamWaitValue64 on the host counter: it requires + // hipMallocSignalMemory and silently never observes a plain pinned-host value, + // wedging the queue. The counter is signaled by a host callback when the + // producer stream reaches the signal point, so block the host here until the + // value lands; subsequent work on this stream is correctly ordered after it. + wait(value); +} + +void AtomicEvent::wait(Stream s, uint64_t value) { + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this, value]() mutable { wait(value); }); + } else { + auto& encoder = get_command_encoder(s); + encoder.commit(); + wait(encoder.stream(), value); + // Keep the buffer alive until the wait is finished + encoder.add_completed_handler([buf = mem_]() {}); + } +} + +void AtomicEvent::signal(uint64_t value) { + atomic()->store(value, std::memory_order_release); +} + +void AtomicEvent::signal(hipStream_t stream, uint64_t value) { + // Signal the host-resident counter from the stream via a host callback that + // fires when the stream reaches this point. We do NOT use hipStreamWriteValue64 + // here: it REQUIRES a hipMallocSignalMemory pointer, but returns success on a + // plain pinned-host counter while never actually landing the write — so the + // host poll (wait()) spins forever and the discrete-GPU queue wedges (busy + // 100%, mem-ctrl 0%). The host callback always delivers the value correctly. + auto* data = new std::pair*, uint64_t>(atomic(), value); + CHECK_HIP_ERROR(hipLaunchHostFunc(stream, signal_atomic_callback, data)); +} + +void AtomicEvent::signal(Stream s, uint64_t value) { + if (s.device == mlx::core::Device::cpu) { + scheduler::enqueue(s, [*this, value]() mutable { signal(value); }); + } else { + auto& encoder = get_command_encoder(s); + encoder.commit(); + signal(encoder.stream(), value); + // Keep the buffer alive until it's signaled + encoder.add_completed_handler([buf = mem_]() {}); + } +} + +bool AtomicEvent::is_signaled(uint64_t value) const { + return atomic()->load() >= value; +} + +uint64_t AtomicEvent::value() const { + return atomic()->load(); +} + +} // namespace rocm + +/////////////////////////////////////////////////////////////////////////////// +// Event implementations +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +struct EventImpl { + std::unique_ptr hip; + std::unique_ptr atomic; + // Set when the event is "signaled" purely on the host because it was created + // during HIP graph capture: recording a real completion event onto the + // captured stream would bake it into the graph (never fires under capture), + // so we skip the device record and treat the event as already satisfied. The + // recorded compute kernels still execute on replay; completion is handled by + // an explicit stream sync after hipGraphLaunch. + bool host_signaled{false}; + + bool is_created() const { + return hip || atomic; + } + + void ensure_created(Stream s, uint64_t signal_value) { + if (is_created()) { + return; + } + if (s.device == mlx::core::Device::cpu || signal_value > 1) { + atomic = std::make_unique(); + } else { + // Bind the stream's device current before creating the hipEvent — it is + // bound to whatever device is current at creation, and recording it on a + // different device's stream hangs the queue on a multi-GPU host. + (void)rocm::get_command_encoder(s); + hip = std::make_unique(); + } + } +}; + +} // namespace + +Event::Event(Stream s) : stream_(s) { + event_ = std::shared_ptr( + new EventImpl(), [](void* ptr) { delete static_cast(ptr); }); +} + +void Event::wait() { + auto* event = static_cast(event_.get()); + // Capture-created event: nothing to wait for on the host (see EventImpl). + if (event->host_signaled) { + return; + } + assert(event->is_created()); + if (event->hip) { + assert(value() == 1); + event->hip->wait(); + } else { + event->atomic->wait(value()); + } +} + +void Event::wait(Stream s) { + auto* event = static_cast(event_.get()); + if (event->host_signaled) { + return; + } + assert(event->is_created()); + if (event->hip) { + assert(value() == 1); + event->hip->wait(s); + } else { + event->atomic->wait(s, value()); + } +} + +void Event::signal(Stream s) { + auto* event = static_cast(event_.get()); + event->ensure_created(s, value()); + // During graph capture, do NOT record a completion event onto the captured + // stream — it would become a graph node that never fires under capture and + // would deadlock eval()'s wait(). Mark it satisfied on the host instead; the + // compute kernels are recorded and run on replay, after which the caller + // performs an explicit stream sync. + if (!(s.device == mlx::core::Device::cpu) && + rocm::get_command_encoder(s).capturing()) { + event->host_signaled = true; + return; + } + if (event->hip) { + assert(value() == 1); + event->hip->record(s); + } else { + event->atomic->signal(s, value()); + } +} + +bool Event::is_signaled() const { + auto* event = static_cast(event_.get()); + if (event->host_signaled) { + return true; + } + if (!event->is_created()) { + return false; + } + if (event->hip) { + assert(value() == 1); + return event->hip->is_signaled(); + } else { + return event->atomic->is_signaled(value()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/fence.cpp b/mlx/backend/rocm/fence.cpp new file mode 100644 index 0000000000..00392c4c1f --- /dev/null +++ b/mlx/backend/rocm/fence.cpp @@ -0,0 +1,29 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/fence.h" +#include "mlx/backend/rocm/event.h" + +namespace mlx::core { + +struct FenceImpl { + uint32_t count; + rocm::AtomicEvent event; +}; + +Fence::Fence(Stream s) { + fence_ = std::shared_ptr( + new FenceImpl{0}, [](void* ptr) { delete static_cast(ptr); }); +} + +void Fence::wait(Stream s, const array&) { + auto* fence = static_cast(fence_.get()); + fence->event.wait(fence->count); +} + +void Fence::update(Stream s, const array&, bool cross_device) { + auto* fence = static_cast(fence_.get()); + fence->count++; + fence->event.signal(s, fence->count); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/flash_attention.hip b/mlx/backend/rocm/flash_attention.hip new file mode 100644 index 0000000000..9e5d1de4c9 --- /dev/null +++ b/mlx/backend/rocm/flash_attention.hip @@ -0,0 +1,671 @@ +// Copyright © 2025 Apple Inc. + +#define _USE_MATH_DEFINES + +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" + +#include +#include +#include + +namespace mlx::core { +namespace rocm { + +struct AttnParams { + int B; + int H; + int D_q; // Query/Key head dimension + int D_v; // Value head dimension + int qL; + int kL; + int gqa_factor; + float scale; + int64_t Q_strides[3]; + int64_t K_strides[3]; + int64_t V_strides[3]; + int64_t O_strides[3]; + int64_t M_strides[4]; // Mask strides [B, H, qL, kL] + bool has_mask; +}; + +// Standard flash attention kernel (D_q == D_v, no array mask) +template < + typename T, + bool do_causal, + int D, + int BLOCK_M = 128, + int BLOCK_N = 64> +__global__ __launch_bounds__(BLOCK_M) void kernel_sdpa_flash_opt( + const T* __restrict__ Q, + const T* __restrict__ K, + const T* __restrict__ V, + T* __restrict__ O, + const T* __restrict__ sinks, + const AttnParams params) { + // Grid: (H, ceil(qL / BLOCK_M), B) + // Block: (BLOCK_M, 1, 1) -> 128 threads + + int batch_idx = blockIdx.z; + int head_idx = blockIdx.x; + int kv_head_idx = head_idx / params.gqa_factor; + int q_seq_start = blockIdx.y * BLOCK_M; + int thread_idx = threadIdx.x; // 0 to BLOCK_M - 1 + int q_seq_idx = q_seq_start + thread_idx; + + if (q_seq_start >= params.qL) + return; + + const T* Q_ptr = Q + batch_idx * params.Q_strides[0] + + head_idx * params.Q_strides[1] + q_seq_idx * params.Q_strides[2]; + T* O_ptr = O + batch_idx * params.O_strides[0] + + head_idx * params.O_strides[1] + q_seq_idx * params.O_strides[2]; + + bool valid_q = q_seq_idx < params.qL; + + typedef float U; + + // Registers for Q and O - use max of 256 for MLA value dimension + U q[256]; + U o[256]; + + if (valid_q) { +#pragma unroll + for (int i = 0; i < D; i++) { + q[i] = static_cast(Q_ptr[i]); + o[i] = 0.f; + } + } + + U max_score = -__int_as_float(0x7f7fffff); // -FLT_MAX + U sum_exp_score = 0.f; + + if (sinks) { + max_score = static_cast(sinks[head_idx]); + sum_exp_score = 1.f; + } + + __shared__ T K_sh[BLOCK_N][D]; + __shared__ T V_sh[BLOCK_N][D]; + + const int K_seq_len = params.kL; + + for (int k_seq_start = 0; k_seq_start < K_seq_len; k_seq_start += BLOCK_N) { + if constexpr (do_causal) { + int earliest_valid_key = (K_seq_len - params.qL) + q_seq_start; + int block_end_key = k_seq_start + BLOCK_N - 1; + if (earliest_valid_key < block_end_key) { + int max_q_seq_idx = min(q_seq_start + BLOCK_M - 1, params.qL - 1); + int latest_valid_key = (K_seq_len - params.qL) + max_q_seq_idx; + if (latest_valid_key < k_seq_start) { + continue; // Block is completely causal-masked + } + } + } + + __syncthreads(); + + // Collaborative loading of K_sh and V_sh + // BLOCK_N * D total elements = 64 * 128 = 8192. + // We have BLOCK_M = 128 threads. + // Each thread loads 8192 / 128 = 64 elements. + const int elements_per_thread = (BLOCK_N * D) / BLOCK_M; + +#pragma unroll + for (int i = 0; i < elements_per_thread; i++) { + int load_idx = i * BLOCK_M + thread_idx; + int r = load_idx / D; + int c = load_idx % D; + int k_idx = k_seq_start + r; + if (k_idx < K_seq_len) { + K_sh[r][c] = + K[batch_idx * params.K_strides[0] + + kv_head_idx * params.K_strides[1] + k_idx * params.K_strides[2] + + c]; + V_sh[r][c] = + V[batch_idx * params.V_strides[0] + + kv_head_idx * params.V_strides[1] + k_idx * params.V_strides[2] + + c]; + } else { + K_sh[r][c] = static_cast(0.f); + V_sh[r][c] = static_cast(0.f); + } + } + + __syncthreads(); + + if (valid_q) { + // Loop over keys in the shared memory + for (int i = 0; i < BLOCK_N; i++) { + int k_idx = k_seq_start + i; + if (k_idx >= K_seq_len) + break; + + bool use_key = true; + if constexpr (do_causal) { + use_key = k_idx <= (K_seq_len - params.qL + q_seq_idx); + } + + if (use_key) { + U score = 0.f; + +#pragma unroll 16 + for (int j = 0; j < D; j++) { + score += q[j] * static_cast(K_sh[i][j]); + } + + score *= params.scale; + + U new_max = max(max_score, score); + U factor = expf(max_score - new_max); + U exp_score = expf(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + +#pragma unroll 16 + for (int j = 0; j < D; j++) { + o[j] = o[j] * factor + exp_score * static_cast(V_sh[i][j]); + } + } + } + } + } + + if (valid_q) { + U inv_sum = sum_exp_score == 0 ? 0.f : 1.0f / sum_exp_score; +#pragma unroll 16 + for (int i = 0; i < D; i++) { + O_ptr[i] = static_cast(o[i] * inv_sum); + } + } +} + +// MLA flash attention kernel with array mask support +// Supports different Q and V dimensions and additive mask (pe_scores) +// Note: BLOCK_N=32 to fit shared memory constraints (K_sh: 24KB + V_sh: 32KB = +// 56KB < 64KB) +template < + typename T, + bool do_causal, + int D_Q, + int D_V, + int BLOCK_M = 64, + int BLOCK_N = 32> +__global__ __launch_bounds__(BLOCK_M) void kernel_sdpa_flash_mla( + const T* __restrict__ Q, + const T* __restrict__ K, + const T* __restrict__ V, + const T* __restrict__ mask, // Additive mask (pe_scores) [B, H, qL, kL] + T* __restrict__ O, + const T* __restrict__ sinks, + const AttnParams params) { + // Grid: (H, ceil(qL / BLOCK_M), B) + // Block: (BLOCK_M, 1, 1) + + int batch_idx = blockIdx.z; + int head_idx = blockIdx.x; + int kv_head_idx = head_idx / params.gqa_factor; + int q_seq_start = blockIdx.y * BLOCK_M; + int thread_idx = threadIdx.x; + int q_seq_idx = q_seq_start + thread_idx; + + if (q_seq_start >= params.qL) + return; + + const T* Q_ptr = Q + batch_idx * params.Q_strides[0] + + head_idx * params.Q_strides[1] + q_seq_idx * params.Q_strides[2]; + T* O_ptr = O + batch_idx * params.O_strides[0] + + head_idx * params.O_strides[1] + q_seq_idx * params.O_strides[2]; + + // Mask pointer for this query position + const T* M_ptr = params.has_mask + ? (mask + batch_idx * params.M_strides[0] + + head_idx * params.M_strides[1] + q_seq_idx * params.M_strides[2]) + : nullptr; + + bool valid_q = q_seq_idx < params.qL; + + typedef float U; + + // Registers for Q and O + U q[D_Q]; + U o[D_V]; + + if (valid_q) { +#pragma unroll + for (int i = 0; i < D_Q; i++) { + q[i] = static_cast(Q_ptr[i]); + } +#pragma unroll + for (int i = 0; i < D_V; i++) { + o[i] = 0.f; + } + } + + U max_score = -__int_as_float(0x7f7fffff); // -FLT_MAX + U sum_exp_score = 0.f; + + if (sinks) { + max_score = static_cast(sinks[head_idx]); + sum_exp_score = 1.f; + } + + __shared__ T K_sh[BLOCK_N][D_Q]; + __shared__ T V_sh[BLOCK_N][D_V]; + + const int K_seq_len = params.kL; + + for (int k_seq_start = 0; k_seq_start < K_seq_len; k_seq_start += BLOCK_N) { + if constexpr (do_causal) { + int earliest_valid_key = (K_seq_len - params.qL) + q_seq_start; + int block_end_key = k_seq_start + BLOCK_N - 1; + if (earliest_valid_key < block_end_key) { + int max_q_seq_idx = min(q_seq_start + BLOCK_M - 1, params.qL - 1); + int latest_valid_key = (K_seq_len - params.qL) + max_q_seq_idx; + if (latest_valid_key < k_seq_start) { + continue; + } + } + } + + __syncthreads(); + + // Collaborative loading of K_sh (D_Q elements per row) + { + const int total_k_elements = BLOCK_N * D_Q; + const int k_per_thread = (total_k_elements + BLOCK_M - 1) / BLOCK_M; +#pragma unroll + for (int i = 0; i < k_per_thread; i++) { + int load_idx = i * BLOCK_M + thread_idx; + if (load_idx < total_k_elements) { + int r = load_idx / D_Q; + int c = load_idx % D_Q; + int k_idx = k_seq_start + r; + if (k_idx < K_seq_len) { + K_sh[r][c] = + K[batch_idx * params.K_strides[0] + + kv_head_idx * params.K_strides[1] + + k_idx * params.K_strides[2] + c]; + } else { + K_sh[r][c] = static_cast(0.f); + } + } + } + } + + // Collaborative loading of V_sh (D_V elements per row) + { + const int total_v_elements = BLOCK_N * D_V; + const int v_per_thread = (total_v_elements + BLOCK_M - 1) / BLOCK_M; +#pragma unroll + for (int i = 0; i < v_per_thread; i++) { + int load_idx = i * BLOCK_M + thread_idx; + if (load_idx < total_v_elements) { + int r = load_idx / D_V; + int c = load_idx % D_V; + int k_idx = k_seq_start + r; + if (k_idx < K_seq_len) { + V_sh[r][c] = + V[batch_idx * params.V_strides[0] + + kv_head_idx * params.V_strides[1] + + k_idx * params.V_strides[2] + c]; + } else { + V_sh[r][c] = static_cast(0.f); + } + } + } + } + + __syncthreads(); + + if (valid_q) { +// Loop over keys in the shared memory +#pragma unroll 4 + for (int i = 0; i < BLOCK_N; i++) { + int k_idx = k_seq_start + i; + if (k_idx >= K_seq_len) + break; + + bool use_key = true; + if constexpr (do_causal) { + use_key = k_idx <= (K_seq_len - params.qL + q_seq_idx); + } + + if (use_key) { + // Compute Q @ K score + U score = 0.f; + +#pragma unroll 16 + for (int j = 0; j < D_Q; j++) { + score += q[j] * static_cast(K_sh[i][j]); + } + + score *= params.scale; + + // Add mask bias (pe_scores) if present + if (M_ptr) { + score += static_cast(M_ptr[k_idx * params.M_strides[3]]); + } + + U new_max = max(max_score, score); + U factor = expf(max_score - new_max); + U exp_score = expf(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + +#pragma unroll 16 + for (int j = 0; j < D_V; j++) { + o[j] = o[j] * factor + exp_score * static_cast(V_sh[i][j]); + } + } + } + } + } + + if (valid_q) { + U inv_sum = sum_exp_score == 0 ? 0.f : 1.0f / sum_exp_score; +#pragma unroll 16 + for (int i = 0; i < D_V; i++) { + O_ptr[i] = static_cast(o[i] * inv_sum); + } + } +} + +} // namespace rocm + +bool supports_sdpa_flash( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + bool output_logsumexp) { + if (output_logsumexp) { + return false; + } + if (q.dtype() != float32 && q.dtype() != float16 && q.dtype() != bfloat16) { + return false; + } + const int D_q = q.shape(-1); + const int D_v = v.shape(-1); + + // Standard attention dimensions (D_q == D_v) + bool standard_dims = (D_q == 64 || D_q == 96 || D_q == 128 || D_q == 256); + + // MLA attention dimensions (D_q=192, D_v=256) + bool mla_dims = (D_q == 192 && D_v == 256); + + if (D_q == D_v && standard_dims) { + if (D_q == 256 && q.dtype() == float32) { + return false; + } + // Standard attention: no array mask needed for flash kernel + return !has_arr_mask; + } else if (mla_dims) { + // MLA attention: supports array mask (additive bias) + return true; + } + return false; +} + +void sdpa_flash( + const array& q, + const array& k, + const array& v, + float scale, + array& o, + bool do_causal, + const std::optional& mask, + const std::optional& sinks, + Stream s) { + auto& d = rocm::device(s.device); + auto& encoder = d.get_command_encoder(s); + + int B = q.shape(0); + int H = q.shape(1); + int qL = q.shape(2); + int kL = k.shape(2); + int D_q = q.shape(3); + int D_v = v.shape(3); + int gqa_factor = q.shape(1) / k.shape(1); + + o.set_data(mlx::core::rocm::malloc_async(o.nbytes(), encoder)); + + rocm::AttnParams params; + params.B = B; + params.H = H; + params.D_q = D_q; + params.D_v = D_v; + params.qL = qL; + params.kL = kL; + params.gqa_factor = gqa_factor; + params.scale = scale; + params.Q_strides[0] = q.strides(0); + params.Q_strides[1] = q.strides(1); + params.Q_strides[2] = q.strides(2); + params.K_strides[0] = k.strides(0); + params.K_strides[1] = k.strides(1); + params.K_strides[2] = k.strides(2); + params.V_strides[0] = v.strides(0); + params.V_strides[1] = v.strides(1); + params.V_strides[2] = v.strides(2); + params.O_strides[0] = o.strides(0); + params.O_strides[1] = o.strides(1); + params.O_strides[2] = o.strides(2); + + params.has_mask = mask.has_value(); + if (mask) { + params.M_strides[0] = mask->strides(0); + params.M_strides[1] = mask->strides(1); + params.M_strides[2] = mask->strides(2); + params.M_strides[3] = mask->strides(3); + } + + bool has_sinks = sinks.has_value(); + bool has_mask_val = mask.has_value(); + bool is_mla = (D_q == 192 && D_v == 256); + + encoder.set_input_array(q); + encoder.set_input_array(k); + encoder.set_input_array(v); + if (mask) { + encoder.set_input_array(*mask); + } + if (sinks) { + encoder.set_input_array(*sinks); + } + encoder.set_output_array(o); + + { + if (is_mla) { + // MLA kernel with D_q=192, D_v=256 + // Use BLOCK_N=32 to fit shared memory (K_sh: 24KB + V_sh: 32KB = 56KB < + // 64KB limit) + constexpr int BLOCK_M = 64; + constexpr int BLOCK_N = 32; + int grid_y = (qL + BLOCK_M - 1) / BLOCK_M; + dim3 grid_dim(H, grid_y, B); + dim3 block_dim(BLOCK_M, 1, 1); + + auto launch_mla_kernel = [&](auto type_tag, auto causal_tag) { + using DataType = decltype(type_tag); + constexpr bool causal = decltype(causal_tag)::value; + + encoder.add_kernel_node( + &rocm::kernel_sdpa_flash_mla< + DataType, + causal, + 192, + 256, + BLOCK_M, + BLOCK_N>, + grid_dim, + block_dim, + 0, + gpu_ptr(q), + gpu_ptr(k), + gpu_ptr(v), + has_mask_val ? gpu_ptr(*mask) : nullptr, + gpu_ptr(o), + has_sinks ? gpu_ptr(*sinks) : nullptr, + params); + }; + + if (o.dtype() == float32) { + if (do_causal) + launch_mla_kernel(float(), std::true_type()); + else + launch_mla_kernel(float(), std::false_type()); + } else if (o.dtype() == float16) { + if (do_causal) + launch_mla_kernel(__half(), std::true_type()); + else + launch_mla_kernel(__half(), std::false_type()); + } else if (o.dtype() == bfloat16) { + if (do_causal) + launch_mla_kernel(hip_bfloat16(), std::true_type()); + else + launch_mla_kernel(hip_bfloat16(), std::false_type()); + } + } else { + // Standard flash attention kernel + constexpr int BLOCK_M = 128; + constexpr int BLOCK_N = 64; + int grid_y = (qL + BLOCK_M - 1) / BLOCK_M; + dim3 grid_dim(H, grid_y, B); + dim3 block_dim(BLOCK_M, 1, 1); + + auto launch_kernel = + [&](auto type_tag, auto causal_tag, auto headdim_tag) { + using DataType = decltype(type_tag); + constexpr bool causal = decltype(causal_tag)::value; + constexpr int headdim = decltype(headdim_tag)::value; + + encoder.add_kernel_node( + &rocm::kernel_sdpa_flash_opt< + DataType, + causal, + headdim, + BLOCK_M, + BLOCK_N>, + grid_dim, + block_dim, + 0, + gpu_ptr(q), + gpu_ptr(k), + gpu_ptr(v), + gpu_ptr(o), + has_sinks ? gpu_ptr(*sinks) : nullptr, + params); + }; + + if (o.dtype() == float32) { + if (do_causal) { + if (D_q == 64) + launch_kernel( + float(), std::true_type(), std::integral_constant()); + else if (D_q == 96) + launch_kernel( + float(), std::true_type(), std::integral_constant()); + else if (D_q == 128) + launch_kernel( + float(), std::true_type(), std::integral_constant()); + } else { + if (D_q == 64) + launch_kernel( + float(), std::false_type(), std::integral_constant()); + else if (D_q == 96) + launch_kernel( + float(), std::false_type(), std::integral_constant()); + else if (D_q == 128) + launch_kernel( + float(), std::false_type(), std::integral_constant()); + } + } else if (o.dtype() == float16) { + if (do_causal) { + if (D_q == 64) + launch_kernel( + __half(), std::true_type(), std::integral_constant()); + else if (D_q == 96) + launch_kernel( + __half(), std::true_type(), std::integral_constant()); + else if (D_q == 128) + launch_kernel( + __half(), std::true_type(), std::integral_constant()); + else if (D_q == 256) + launch_kernel( + __half(), std::true_type(), std::integral_constant()); + } else { + if (D_q == 64) + launch_kernel( + __half(), std::false_type(), std::integral_constant()); + else if (D_q == 96) + launch_kernel( + __half(), std::false_type(), std::integral_constant()); + else if (D_q == 128) + launch_kernel( + __half(), + std::false_type(), + std::integral_constant()); + else if (D_q == 256) + launch_kernel( + __half(), + std::false_type(), + std::integral_constant()); + } + } else if (o.dtype() == bfloat16) { + if (do_causal) { + if (D_q == 64) + launch_kernel( + hip_bfloat16(), + std::true_type(), + std::integral_constant()); + else if (D_q == 96) + launch_kernel( + hip_bfloat16(), + std::true_type(), + std::integral_constant()); + else if (D_q == 128) + launch_kernel( + hip_bfloat16(), + std::true_type(), + std::integral_constant()); + else if (D_q == 256) + launch_kernel( + hip_bfloat16(), + std::true_type(), + std::integral_constant()); + } else { + if (D_q == 64) + launch_kernel( + hip_bfloat16(), + std::false_type(), + std::integral_constant()); + else if (D_q == 96) + launch_kernel( + hip_bfloat16(), + std::false_type(), + std::integral_constant()); + else if (D_q == 128) + launch_kernel( + hip_bfloat16(), + std::false_type(), + std::integral_constant()); + else if (D_q == 256) + launch_kernel( + hip_bfloat16(), + std::false_type(), + std::integral_constant()); + } + } + } + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/flash_attention_wmma.hip b/mlx/backend/rocm/flash_attention_wmma.hip new file mode 100644 index 0000000000..82bbcd8543 --- /dev/null +++ b/mlx/backend/rocm/flash_attention_wmma.hip @@ -0,0 +1,462 @@ +// WMMA-accelerated flash attention for RDNA 3+ (gfx1100+) +// +// Uses rocwmma 16x16x16 bf16→f32 tiles for Q@K^T and P@V matmuls. +// Implements FlashAttention-2 online softmax. +// +// BLOCK_M=64, BLOCK_N=64, 128 threads (4 waves), each wave owns 16 query rows. +// Shared memory ~50 KB for D=128. + +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" + +#include +#include + +#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ +#define ROCM_FA_WMMA 1 +#include +#else +#define ROCM_FA_WMMA 0 +#endif + +namespace mlx::core { +namespace rocm { + +struct FAWmmaParams { + int B, H, D, qL, kL, gqa_factor; + float scale; + int64_t Q_strides[3], K_strides[3], V_strides[3], O_strides[3]; +}; + +// Helper: collaborative load from global to shared memory +template +__device__ void load_tile( + T* __restrict__ dst, + const T* __restrict__ src_base, + int64_t row_stride, + int rows, + int cols, + int valid_rows, + int tid) { + const int total = rows * cols; + const int per_t = (total + NTHREADS - 1) / NTHREADS; + for (int i = 0; i < per_t; i++) { + int idx = i * NTHREADS + tid; + if (idx < total) { + int r = idx / cols; + int c = idx % cols; + dst[r * STRIDE + c] = (r < valid_rows) ? + src_base[r * row_stride + c] : static_cast(0.0f); + } + } + // Zero padding columns + const int pad = STRIDE - cols; + if (pad > 0) { + const int total_pad = rows * pad; + const int per_p = (total_pad + NTHREADS - 1) / NTHREADS; + for (int i = 0; i < per_p; i++) { + int idx = i * NTHREADS + tid; + if (idx < total_pad) { + int r = idx / pad; + dst[r * STRIDE + cols + (idx % pad)] = static_cast(0.0f); + } + } + } +} + +template < + typename T, + bool do_causal, + int D, + int BLOCK_M = 64, + int BLOCK_N = 64> +__global__ void __launch_bounds__(128) + kernel_sdpa_flash_wmma( + const T* __restrict__ Q, + const T* __restrict__ K, + const T* __restrict__ V, + T* __restrict__ O, + const FAWmmaParams params) { +#if ROCM_FA_WMMA + constexpr int WT = 16; // WMMA tile size + // Stage Q/K/V one DW-wide head-dim slice at a time so the LDS footprint is + // capped at the D=128 layout for any head dim (D=256 needs two slices). This + // keeps the kernel within RDNA's 64 KB LDS budget while still using WMMA. + constexpr int DW = (D < 128) ? D : 128; + constexpr int N_DCHUNKS = (D + DW - 1) / DW; + constexpr int Q_PAD = DW + 4; + constexpr int KV_PAD = DW + 4; + constexpr int S_PAD = BLOCK_N + 4; + constexpr int P_PAD = BLOCK_N + 4; + constexpr int M_TILES = BLOCK_M / WT; + constexpr int N_TILES = BLOCK_N / WT; + constexpr int DW_TILES = DW / WT; + constexpr int D_TILES = (D + WT - 1) / WT; + constexpr int NTHREADS = 128; + + const int bid_b = blockIdx.z; + const int bid_h = blockIdx.x; + const int bid_kv_h = bid_h / params.gqa_factor; + const int q_start = blockIdx.y * BLOCK_M; + const int tid = threadIdx.x; + const int wave = tid / 32; + const int lane = tid % 32; + + if (q_start >= params.qL) return; + + // ---- Shared memory layout ---- + // Persistent: m[BLOCK_M], l[BLOCK_M] + // Tile A: Q_sh [BLOCK_M][Q_PAD] (bf16, loaded once) + // Tile B: KV_sh[BLOCK_N][KV_PAD] (bf16, K then P_bf16) + // Tile C: S_sh [BLOCK_M][S_PAD] (f32, scores then V_bf16) + extern __shared__ char smem[]; + float* m_arr = reinterpret_cast(smem); + float* l_arr = m_arr + BLOCK_M; + T* Q_sh = reinterpret_cast(l_arr + BLOCK_M); + T* KV_sh = Q_sh + BLOCK_M * Q_PAD; + float* S_sh = reinterpret_cast(KV_sh + BLOCK_N * KV_PAD); + + // Fragment types + using frag_a = rocwmma::fragment; + using frag_b_col = rocwmma::fragment; + using frag_b_row = rocwmma::fragment; + using frag_acc = rocwmma::fragment; + + // Output accumulators: each wave owns D_TILES [16x16] f32 tiles + frag_acc o_acc[D_TILES]; + for (int d = 0; d < D_TILES; d++) + rocwmma::fill_fragment(o_acc[d], 0.0f); + + // Init online softmax state + if (tid < BLOCK_M) { + m_arr[tid] = -1e30f; + l_arr[tid] = 0.0f; + } + __syncthreads(); + + const T* Q_base = Q + bid_b * params.Q_strides[0] + + bid_h * params.Q_strides[1] + + q_start * params.Q_strides[2]; + const int q_valid = min(BLOCK_M, params.qL - q_start); + + // ---- K/V block loop ---- + for (int k_start = 0; k_start < params.kL; k_start += BLOCK_N) { + + // Causal skip: if entire K block is after all queries' causal limits + if constexpr (do_causal) { + int last_q = q_start + min(BLOCK_M, params.qL - q_start) - 1; + int max_k_allowed = (params.kL - params.qL) + last_q; + if (k_start > max_k_allowed) break; + } + + int k_valid = min(BLOCK_N, params.kL - k_start); + + const T* K_base = K + bid_b * params.K_strides[0] + + bid_kv_h * params.K_strides[1] + + k_start * params.K_strides[2]; + + // ---- S = Q @ K^T via WMMA, accumulated over the head dim in DW slices ---- + // Each wave computes S[wave*16 : (wave+1)*16, 0:BLOCK_N]. For D>DW the Q/K + // slices are staged and consumed one DW-wide chunk at a time so LDS stays + // bounded; s_acc carries the partial dot products across chunks. + { + frag_acc s_acc[N_TILES]; + for (int n = 0; n < N_TILES; n++) + rocwmma::fill_fragment(s_acc[n], 0.0f); + + for (int dc = 0; dc < N_DCHUNKS; dc++) { + const int d0 = dc * DW; + const int dwid = min(DW, D - d0); + load_tile( + Q_sh, Q_base + d0, params.Q_strides[2], BLOCK_M, dwid, q_valid, tid); + load_tile( + KV_sh, K_base + d0, params.K_strides[2], BLOCK_N, dwid, k_valid, tid); + __syncthreads(); + + for (int d = 0; d < DW_TILES; d++) { + frag_a q_frag; + rocwmma::load_matrix_sync( + q_frag, Q_sh + wave * WT * Q_PAD + d * WT, Q_PAD); + + for (int n = 0; n < N_TILES; n++) { + frag_b_col k_frag; + // col_major load of K[n*16][d*16] gives K^T + rocwmma::load_matrix_sync( + k_frag, KV_sh + n * WT * KV_PAD + d * WT, KV_PAD); + rocwmma::mma_sync(s_acc[n], q_frag, k_frag, s_acc[n]); + } + } + __syncthreads(); + } + + // Scale and store S to shared memory + for (int n = 0; n < N_TILES; n++) { + for (int e = 0; e < s_acc[n].num_elements; e++) + s_acc[n].x[e] *= params.scale; + rocwmma::store_matrix_sync( + S_sh + wave * WT * S_PAD + n * WT, s_acc[n], S_PAD, + rocwmma::mem_row_major); + } + } + __syncthreads(); + + // ---- Online softmax (scalar, 64 threads handle 64 rows) ---- + float my_scale_old = 0.0f; + if (tid < BLOCK_M) { + int q_idx = q_start + tid; + bool valid = q_idx < params.qL; + float old_m = m_arr[tid]; + float old_l = l_arr[tid]; + + float new_m = old_m; + if (valid) { + for (int j = 0; j < BLOCK_N && (k_start + j) < params.kL; j++) { + bool use = true; + if constexpr (do_causal) + use = (k_start + j) <= (params.kL - params.qL + q_idx); + if (use) + new_m = fmaxf(new_m, S_sh[tid * S_PAD + j]); + } + } + + float scale_old = (old_m > -1e29f) ? expf(old_m - new_m) : 0.0f; + float row_sum = 0.0f; + + for (int j = 0; j < BLOCK_N; j++) { + bool use = valid && (k_start + j) < params.kL; + if constexpr (do_causal) + use = use && ((k_start + j) <= (params.kL - params.qL + q_idx)); + + if (use) { + float p = expf(S_sh[tid * S_PAD + j] - new_m); + S_sh[tid * S_PAD + j] = p; + row_sum += p; + } else { + S_sh[tid * S_PAD + j] = 0.0f; + } + } + + m_arr[tid] = new_m; + l_arr[tid] = old_l * scale_old + row_sum; + my_scale_old = scale_old; + } + // Broadcast scale_old to all lanes in each wave via shared memory + __shared__ float wave_scale[BLOCK_M]; + if (tid < BLOCK_M) wave_scale[tid] = my_scale_old; + __syncthreads(); + + // ---- Rescale O accumulators by scale_old ---- + // Each wave rescales its 16-row O tiles. + // Since fragment layout is opaque, store→scale→reload via shared memory. + { + // Use KV_sh area as f32 temp (it held K which we no longer need) + float* o_tmp = reinterpret_cast(KV_sh); + constexpr int OT_PAD = WT + 4; + + for (int d = 0; d < D_TILES; d++) { + rocwmma::store_matrix_sync( + o_tmp + wave * WT * OT_PAD, o_acc[d], OT_PAD, + rocwmma::mem_row_major); + __syncthreads(); + + // Scale rows (only 16 threads needed per wave) + if (lane < WT) { + int row = wave * WT + lane; + float s = wave_scale[row]; + for (int c = 0; c < WT; c++) + o_tmp[wave * WT * OT_PAD + lane * OT_PAD + c] *= s; + } + __syncthreads(); + + rocwmma::load_matrix_sync( + o_acc[d], o_tmp + wave * WT * OT_PAD, OT_PAD, + rocwmma::mem_row_major); + __syncthreads(); + } + } + + // ---- Convert P (f32 in S_sh) to bf16 in KV_sh for WMMA P@V ---- + { + int total = BLOCK_M * BLOCK_N; + int per_t = (total + NTHREADS - 1) / NTHREADS; + for (int i = 0; i < per_t; i++) { + int idx = i * NTHREADS + tid; + if (idx < total) { + int r = idx / BLOCK_N; + int c = idx % BLOCK_N; + KV_sh[r * P_PAD + c] = static_cast(S_sh[r * S_PAD + c]); + } + } + // Zero padding + int pad_total = BLOCK_M * 4; + int pad_per_t = (pad_total + NTHREADS - 1) / NTHREADS; + for (int i = 0; i < pad_per_t; i++) { + int idx = i * NTHREADS + tid; + if (idx < pad_total) + KV_sh[(idx / 4) * P_PAD + BLOCK_N + (idx % 4)] = static_cast(0.0f); + } + } + __syncthreads(); + + // ---- O += P @ V via WMMA, output head dim tiled in DW slices ---- + // P (bf16) stays resident in KV_sh; V is streamed one DW-wide slice into + // S_sh (the f32 scores are no longer needed) and consumed immediately, so + // a full D=256 V tile never has to coexist with P in LDS. + { + T* P_sh = KV_sh; + T* V_sh = reinterpret_cast(S_sh); + const T* V_base = V + bid_b * params.V_strides[0] + + bid_kv_h * params.V_strides[1] + + k_start * params.V_strides[2]; + + for (int dc = 0; dc < N_DCHUNKS; dc++) { + const int d0 = dc * DW; + const int dwid = min(DW, D - d0); + load_tile( + V_sh, V_base + d0, params.V_strides[2], BLOCK_N, dwid, k_valid, tid); + __syncthreads(); + + for (int dl = 0; dl < DW_TILES; dl++) { + const int dglobal = dc * DW_TILES + dl; + if (dglobal >= D_TILES) + break; + for (int n = 0; n < N_TILES; n++) { + frag_a p_frag; + frag_b_row v_frag; + rocwmma::load_matrix_sync( + p_frag, P_sh + wave * WT * P_PAD + n * WT, P_PAD); + rocwmma::load_matrix_sync( + v_frag, V_sh + n * WT * KV_PAD + dl * WT, KV_PAD); + rocwmma::mma_sync(o_acc[dglobal], p_frag, v_frag, o_acc[dglobal]); + } + } + __syncthreads(); + } + } + __syncthreads(); + } // end K/V loop + + // ---- Finalize: normalize O and write to global ---- + { + float* o_tmp = reinterpret_cast(Q_sh); // Reuse Q_sh + constexpr int OT_PAD = WT + 4; + + for (int d = 0; d < D_TILES; d++) { + rocwmma::store_matrix_sync( + o_tmp + wave * WT * OT_PAD, o_acc[d], OT_PAD, + rocwmma::mem_row_major); + __syncthreads(); + + if (lane < WT) { + int row = wave * WT + lane; + int q_idx = q_start + row; + if (q_idx < params.qL) { + float inv_l = (l_arr[row] > 0.0f) ? (1.0f / l_arr[row]) : 0.0f; + T* dst = O + bid_b * params.O_strides[0] + + bid_h * params.O_strides[1] + + q_idx * params.O_strides[2] + + d * WT; + float* src = o_tmp + wave * WT * OT_PAD + lane * OT_PAD; + for (int c = 0; c < WT && (d * WT + c) < D; c++) + dst[c] = static_cast(src[c] * inv_l); + } + } + __syncthreads(); + } + } +#endif // ROCM_FA_WMMA +} + +} // namespace rocm + +// ---- Host interface ---- + +bool supports_sdpa_flash_wmma( + const array& q, const array& k, const array& v, + bool has_arr_mask, bool output_logsumexp) { + // Host-side check: always enabled when compiled for WMMA-capable targets. + // The kernel itself guards with #if ROCM_FA_WMMA for device code. + if (output_logsumexp || has_arr_mask) return false; + if (q.dtype() != bfloat16 && q.dtype() != float16) return false; + int D = q.shape(-1); + if (D != v.shape(-1)) return false; + return (D == 64 || D == 128 || D == 256); +} + +// Shared-memory (LDS) bytes the WMMA flash kernel needs for a given head dim. +// The kernel stages Q/K/V one DW=min(D,128)-wide head-dim slice at a time, so +// the footprint is capped at the D=128 layout regardless of D. Callers compare +// this against the running device's LDS budget before selecting the kernel. +int sdpa_flash_wmma_smem(int D) { + constexpr int BM = 64, BN = 64; + const int DW = (D < 128) ? D : 128; + return 2 * BM * (int)sizeof(float) + + BM * (DW + 4) * (int)sizeof(uint16_t) + + BN * (DW + 4) * (int)sizeof(uint16_t) + + BM * (BN + 4) * (int)sizeof(float); +} + +void sdpa_flash_wmma( + const array& q, const array& k, const array& v, + float scale, array& o, bool do_causal, Stream s) { + auto& d = rocm::device(s.device); + auto& enc = d.get_command_encoder(s); + + int B = q.shape(0), H = q.shape(1), qL = q.shape(2), kL = k.shape(2); + int D = q.shape(3); + + o.set_data(mlx::core::rocm::malloc_async(o.nbytes(), enc)); + + rocm::FAWmmaParams p{}; + p.B = B; p.H = H; p.D = D; p.qL = qL; p.kL = kL; + p.gqa_factor = H / k.shape(1); + p.scale = scale; + for (int i = 0; i < 3; i++) { + p.Q_strides[i] = q.strides(i); + p.K_strides[i] = k.strides(i); + p.V_strides[i] = v.strides(i); + p.O_strides[i] = o.strides(i); + } + + constexpr int BM = 64, BN = 64; + dim3 grid(H, (qL + BM - 1) / BM, B); + dim3 block(128); + + // Shared memory: m/l + Q + KV + S, with Q/KV sized to one DW-wide slice. + int smem = sdpa_flash_wmma_smem(D); + + enc.set_input_array(q); + enc.set_input_array(k); + enc.set_input_array(v); + enc.set_output_array(o); + + auto launch = [&](auto type_tag, auto causal_tag, auto dim_tag) { + using DT = decltype(type_tag); + constexpr bool C = decltype(causal_tag)::value; + constexpr int DD = decltype(dim_tag)::value; + enc.add_kernel_node_ex( + &rocm::kernel_sdpa_flash_wmma, + grid, block, static_cast(smem), + gpu_ptr(q), gpu_ptr(k), + gpu_ptr(v), gpu_ptr
(o), p); + }; + + auto dispatch_dim = [&](auto tt, auto ct) { + if (D == 64) launch(tt, ct, std::integral_constant()); + else if (D == 128) launch(tt, ct, std::integral_constant()); + else if (D == 256) launch(tt, ct, std::integral_constant()); + }; + + if (o.dtype() == bfloat16) { + if (do_causal) dispatch_dim(hip_bfloat16(), std::true_type()); + else dispatch_dim(hip_bfloat16(), std::false_type()); + } else { + if (do_causal) dispatch_dim(__half(), std::true_type()); + else dispatch_dim(__half(), std::false_type()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/gemms/gemv.h b/mlx/backend/rocm/gemms/gemv.h new file mode 100644 index 0000000000..bb7f60c9e6 --- /dev/null +++ b/mlx/backend/rocm/gemms/gemv.h @@ -0,0 +1,34 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/rocm/device.h" + +namespace mlx::core::rocm { + +bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed); + +void gemv( + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + uint32_t batch_count, + const mlx::core::Shape& batch_shape, + const mlx::core::Strides& a_batch_strides, + const mlx::core::Strides& b_batch_strides, + CommandEncoder& encoder); + +void gather_mv( + const array& mat, + const array& vec, + const array& mat_indices, + const array& vec_indices, + array& out, + int N, + int K, + CommandEncoder& encoder); + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/gemms/gemv.hip b/mlx/backend/rocm/gemms/gemv.hip new file mode 100644 index 0000000000..34100ca2f8 --- /dev/null +++ b/mlx/backend/rocm/gemms/gemv.hip @@ -0,0 +1,826 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/gemms/gemv.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" + +#include +#include +#include + +namespace mlx::core::rocm { + +static constexpr int rows_per_block = 16; +static constexpr int kMaxInlineBatchDims = 8; + +struct GemvBatchParams { + int batch_ndim; + int32_t batch_shape[kMaxInlineBatchDims]; + int64_t mat_batch_strides[kMaxInlineBatchDims]; + int64_t vec_batch_strides[kMaxInlineBatchDims]; +}; + +struct GemvGatherParams { + int mat_batch_ndim; + int vec_batch_ndim; + int index_batch_ndim; + int32_t mat_batch_shape[kMaxInlineBatchDims]; + int64_t mat_batch_strides[kMaxInlineBatchDims]; + int32_t vec_batch_shape[kMaxInlineBatchDims]; + int64_t vec_batch_strides[kMaxInlineBatchDims]; + int32_t index_shape[kMaxInlineBatchDims]; + int64_t mat_index_strides[kMaxInlineBatchDims]; + int64_t vec_index_strides[kMaxInlineBatchDims]; +}; + +// Accumulator type selection per input element type T. +template +struct GemvAccType { + using type = T; +}; + +template <> +struct GemvAccType<__half> { + using type = float; +}; + +template <> +struct GemvAccType { + using type = float; +}; + +template <> +struct GemvAccType { + using type = float; +}; + +template <> +struct GemvAccType { + using type = double; +}; + +// Warp reduction for sum +template +__device__ __forceinline__ T warp_reduce_sum_gemv(T val) { +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val += __shfl_down(val, offset); + } + return val; +} + +// Specialization for hip_bfloat16 +template <> +__device__ __forceinline__ float warp_reduce_sum_gemv(float val) { +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val += __shfl_down(val, offset); + } + return val; +} + +template +__device__ void +gemv_impl(const T* mat, const T* vec, T* out, int rows, int cols) { + int row = blockIdx.x * rows_per_block + threadIdx.y; + + if (row < rows) { + using Acc = typename GemvAccType::type; + Acc sum = Acc(0); + + // Each thread processes multiple elements + for (int col = n_per_thread * threadIdx.x; col < cols; + col += (WARP_SIZE * n_per_thread)) { + // Load and accumulate using vectorized loads if possible + auto mat_v = load_vector(mat + row * cols, col / n_per_thread, cols, T(0)); + auto vec_v = load_vector(vec, col / n_per_thread, cols, T(0)); + +#pragma unroll + for (int j = 0; j < n_per_thread; ++j) { + sum += static_cast(mat_v[j]) * static_cast(vec_v[j]); + } + } + + // Warp reduction + sum = warp_reduce_sum_gemv(sum); + + if (threadIdx.x == 0) { + out[row] = static_cast(sum); + } + } +} + +template +__global__ void +gemv_single(const T* mat, const T* vec, T* out, int rows, int cols) { + gemv_impl(mat, vec, out, rows, cols); +} + +// Helper to compute batch offset +template +__device__ __forceinline__ int64_t elem_to_loc_1d( + int64_t idx, + const ShapeT* shape, + const int64_t* strides, + int ndim) { + int64_t offset = 0; + for (int i = ndim - 1; i >= 0; --i) { + offset += (idx % shape[i]) * strides[i]; + idx /= shape[i]; + } + return offset; +} + +template +__global__ void gemv_batched( + const T* mat, + const T* vec, + T* out, + int rows, + int cols, + const int32_t* batch_shape, + const int64_t* mat_batch_strides, + const int64_t* vec_batch_strides, + int batch_ndim) { + int batch_idx = blockIdx.y; + + int64_t mat_offset = + elem_to_loc_1d(batch_idx, batch_shape, mat_batch_strides, batch_ndim); + int64_t vec_offset = + elem_to_loc_1d(batch_idx, batch_shape, vec_batch_strides, batch_ndim); + + gemv_impl( + mat + mat_offset, vec + vec_offset, out + batch_idx * rows, rows, cols); +} + +template +__global__ void gemv_batched_inline( + const T* mat, + const T* vec, + T* out, + int rows, + int cols, + GemvBatchParams params) { + int batch_idx = blockIdx.y; + + int64_t mat_offset = elem_to_loc_1d( + batch_idx, + params.batch_shape, + params.mat_batch_strides, + params.batch_ndim); + int64_t vec_offset = elem_to_loc_1d( + batch_idx, + params.batch_shape, + params.vec_batch_strides, + params.batch_ndim); + + gemv_impl( + mat + mat_offset, vec + vec_offset, out + batch_idx * rows, rows, cols); +} + +template +__global__ void gemv_gather( + const T* mat, + const T* vec, + T* out, + const uint32_t* mat_indices, + const uint32_t* vec_indices, + int rows, + int cols, + const int32_t* mat_batch_shape, + const int64_t* mat_batch_strides, + int mat_batch_ndim, + const int32_t* vec_batch_shape, + const int64_t* vec_batch_strides, + int vec_batch_ndim, + const int32_t* index_shape, + const int64_t* mat_index_strides, + const int64_t* vec_index_strides, + int index_batch_ndim); + +__device__ __forceinline__ uint32_t gather_index( + const uint32_t* indices, + int64_t indices_idx, + const int32_t* index_shape, + const int64_t* index_strides, + int index_batch_ndim) { + if (index_batch_ndim > 1) { + auto index_offset = elem_to_loc_1d( + indices_idx, index_shape, index_strides, index_batch_ndim); + return indices[index_offset]; + } + if (index_batch_ndim == 1) { + return indices[indices_idx * index_strides[0]]; + } + return indices[0]; +} + +__device__ __forceinline__ int64_t gather_batch_offset( + uint32_t index, + const int32_t* batch_shape, + const int64_t* batch_strides, + int batch_ndim) { + if (batch_ndim > 1) { + return elem_to_loc_1d(index, batch_shape, batch_strides, batch_ndim); + } + if (batch_ndim == 1) { + return index * batch_strides[0]; + } + return 0; +} + +template +__device__ void gemv_gather_impl( + const T* mat, + const T* vec, + T* out, + const uint32_t* mat_indices, + const uint32_t* vec_indices, + int rows, + int cols, + int indices_idx, + const int32_t* mat_batch_shape, + const int64_t* mat_batch_strides, + int mat_batch_ndim, + const int32_t* vec_batch_shape, + const int64_t* vec_batch_strides, + int vec_batch_ndim, + const int32_t* index_shape, + const int64_t* mat_index_strides, + const int64_t* vec_index_strides, + int index_batch_ndim) { + uint32_t index_mat = gather_index( + mat_indices, + indices_idx, + index_shape, + mat_index_strides, + index_batch_ndim); + uint32_t index_vec = gather_index( + vec_indices, + indices_idx, + index_shape, + vec_index_strides, + index_batch_ndim); + + int64_t mat_offset = gather_batch_offset( + index_mat, mat_batch_shape, mat_batch_strides, mat_batch_ndim); + int64_t vec_offset = gather_batch_offset( + index_vec, vec_batch_shape, vec_batch_strides, vec_batch_ndim); + + gemv_impl( + mat + mat_offset, vec + vec_offset, out + indices_idx * rows, rows, cols); +} + +template +__global__ void gemv_gather( + const T* mat, + const T* vec, + T* out, + const uint32_t* mat_indices, + const uint32_t* vec_indices, + int rows, + int cols, + const int32_t* mat_batch_shape, + const int64_t* mat_batch_strides, + int mat_batch_ndim, + const int32_t* vec_batch_shape, + const int64_t* vec_batch_strides, + int vec_batch_ndim, + const int32_t* index_shape, + const int64_t* mat_index_strides, + const int64_t* vec_index_strides, + int index_batch_ndim) { + int indices_idx = blockIdx.y; + + gemv_gather_impl( + mat, + vec, + out, + mat_indices, + vec_indices, + rows, + cols, + indices_idx, + mat_batch_shape, + mat_batch_strides, + mat_batch_ndim, + vec_batch_shape, + vec_batch_strides, + vec_batch_ndim, + index_shape, + mat_index_strides, + vec_index_strides, + index_batch_ndim); +} + +template +__global__ void gemv_gather_inline( + const T* mat, + const T* vec, + T* out, + const uint32_t* mat_indices, + const uint32_t* vec_indices, + int rows, + int cols, + GemvGatherParams params) { + int indices_idx = blockIdx.y; + + gemv_gather_impl( + mat, + vec, + out, + mat_indices, + vec_indices, + rows, + cols, + indices_idx, + params.mat_batch_shape, + params.mat_batch_strides, + params.mat_batch_ndim, + params.vec_batch_shape, + params.vec_batch_strides, + params.vec_batch_ndim, + params.index_shape, + params.mat_index_strides, + params.vec_index_strides, + params.index_batch_ndim); +} + +bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed) { + return K % 32 == 0 && ((M == 1 && b_transposed) || (N == 1 && !a_transposed)); +} + +template +void dispatch_n_per_thread(int n_per_thread, F&& f) { + switch (n_per_thread) { + case 1: + f(std::integral_constant{}); + break; + case 2: + f(std::integral_constant{}); + break; + case 4: + f(std::integral_constant{}); + break; + case 8: + f(std::integral_constant{}); + break; + case 16: + f(std::integral_constant{}); + break; + } +} + +void gemv( + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + uint32_t batch_count, + const mlx::core::Shape& batch_shape, + const mlx::core::Strides& a_batch_strides, + const mlx::core::Strides& b_batch_strides, + CommandEncoder& encoder) { + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + + dim3 block_dims{WARP_SIZE, rows_per_block}; + int rows; + int cols = K; + + // Determine which array is the matrix and which is the vector + const void* mat_ptr; + const void* vec_ptr; + const mlx::core::Strides* mat_strides_ptr; + const mlx::core::Strides* vec_strides_ptr; + + if (M == 1) { + mat_ptr = gpu_ptr(b); + vec_ptr = gpu_ptr(a); + rows = N; + mat_strides_ptr = &b_batch_strides; + vec_strides_ptr = &a_batch_strides; + } else { + mat_ptr = gpu_ptr(a); + vec_ptr = gpu_ptr(b); + rows = M; + mat_strides_ptr = &a_batch_strides; + vec_strides_ptr = &b_batch_strides; + } + void* out_base_ptr = gpu_ptr(out); + + uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block; + + // Determine n_per_thread based on alignment + int n_per_t = 1; + if (K % 512 == 0) { + n_per_t = 16; + } else if (K % 256 == 0) { + n_per_t = 8; + } else if (K % 128 == 0) { + n_per_t = 4; + } else if (K % 64 == 0) { + n_per_t = 2; + } + + // For batched operations, allocate device memory for parameters + int32_t* d_batch_shape = nullptr; + int64_t* d_mat_strides = nullptr; + int64_t* d_vec_strides = nullptr; + GemvBatchParams inline_batch_params{}; + bool use_inline_batch_params = false; + + if (batch_count > 1) { + size_t batch_ndim = batch_shape.size(); + if (batch_ndim <= kMaxInlineBatchDims) { + use_inline_batch_params = true; + inline_batch_params.batch_ndim = static_cast(batch_ndim); + for (size_t i = 0; i < batch_ndim; ++i) { + inline_batch_params.batch_shape[i] = batch_shape[i]; + inline_batch_params.mat_batch_strides[i] = (*mat_strides_ptr)[i]; + inline_batch_params.vec_batch_strides[i] = (*vec_strides_ptr)[i]; + } + } else { + (void)hipMalloc(&d_batch_shape, batch_ndim * sizeof(int32_t)); + (void)hipMalloc(&d_mat_strides, batch_ndim * sizeof(int64_t)); + (void)hipMalloc(&d_vec_strides, batch_ndim * sizeof(int64_t)); + + (void)hipMemcpy( + d_batch_shape, + batch_shape.data(), + batch_ndim * sizeof(int32_t), + hipMemcpyHostToDevice); + (void)hipMemcpy( + d_mat_strides, + mat_strides_ptr->data(), + batch_ndim * sizeof(int64_t), + hipMemcpyHostToDevice); + (void)hipMemcpy( + d_vec_strides, + vec_strides_ptr->data(), + batch_ndim * sizeof(int64_t), + hipMemcpyHostToDevice); + } + } + + if (batch_count > 1 && !use_inline_batch_params) { + // Rare fallback: batch_ndim > kMaxInlineBatchDims. Strides live in + // stream-ordered device memory freed via hipFreeAsync, so keep this on the + // legacy stream-capture path. + encoder.launch_kernel([&, + mat_ptr, + vec_ptr, + out_base_ptr, + d_batch_shape, + d_mat_strides, + d_vec_strides](hipStream_t stream) { + auto launch_kernel = [&](auto type_tag, auto n_per_thread) { + using T = typename decltype(type_tag)::type; + const T* mat = static_cast(mat_ptr); + const T* vec = static_cast(vec_ptr); + T* out_ptr = static_cast(out_base_ptr); + hipLaunchKernelGGL( + (gemv_batched), + dim3(num_blocks_x, batch_count), + block_dims, + 0, + stream, + mat, + vec, + out_ptr, + rows, + cols, + d_batch_shape, + d_mat_strides, + d_vec_strides, + static_cast(batch_shape.size())); + }; + + dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) { + switch (out.dtype()) { + case float32: + launch_kernel(type_identity{}, n_per_thread); + break; + case float16: + launch_kernel(type_identity<__half>{}, n_per_thread); + break; + case bfloat16: + launch_kernel(type_identity{}, n_per_thread); + break; + case float64: + launch_kernel(type_identity{}, n_per_thread); + break; + default: + break; + } + }); + + (void)hipFreeAsync(d_batch_shape, stream); + (void)hipFreeAsync(d_mat_strides, stream); + (void)hipFreeAsync(d_vec_strides, stream); + }); + return; + } + + auto add_node = [&](auto type_tag, auto n_per_thread) { + using T = typename decltype(type_tag)::type; + const T* mat = static_cast(mat_ptr); + const T* vec = static_cast(vec_ptr); + T* out_ptr = static_cast(out_base_ptr); + + if (batch_count == 1) { + encoder.add_kernel_node( + &gemv_single, + dim3(num_blocks_x), + block_dims, + 0u, + mat, + vec, + out_ptr, + rows, + cols); + } else { + encoder.add_kernel_node( + &gemv_batched_inline, + dim3(num_blocks_x, batch_count), + block_dims, + 0u, + mat, + vec, + out_ptr, + rows, + cols, + inline_batch_params); + } + }; + + dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) { + switch (out.dtype()) { + case float32: + add_node(type_identity{}, n_per_thread); + break; + case float16: + add_node(type_identity<__half>{}, n_per_thread); + break; + case bfloat16: + add_node(type_identity{}, n_per_thread); + break; + case float64: + add_node(type_identity{}, n_per_thread); + break; + default: + break; + } + }); +} + +void gather_mv( + const array& mat_, + const array& vec_, + const array& mat_indices, + const array& vec_indices, + array& out, + int N, + int K, + CommandEncoder& encoder) { + encoder.set_input_array(mat_); + encoder.set_input_array(vec_); + encoder.set_input_array(mat_indices); + encoder.set_input_array(vec_indices); + encoder.set_output_array(out); + + dim3 block_dims{WARP_SIZE, rows_per_block}; + int rows = N; + int cols = K; + uint32_t batch_size = static_cast(out.size() / N); + + uint32_t num_blocks_x = (rows + rows_per_block - 1) / rows_per_block; + + int n_per_t = 1; + if (K % 128 == 0) { + n_per_t = 4; + } else if (K % 64 == 0) { + n_per_t = 2; + } + + auto [index_shape, index_strides] = collapse_contiguous_dims( + mat_indices.shape(), {mat_indices.strides(), vec_indices.strides()}); + auto mat_index_strides = index_strides[0]; + auto vec_index_strides = index_strides[1]; + + mlx::core::Shape mat_batch_shape{ + mat_.shape().begin(), mat_.shape().end() - 2}; + mlx::core::Strides mat_batch_strides{ + mat_.strides().begin(), mat_.strides().end() - 2}; + int mat_batch_ndim = mat_batch_shape.size(); + + mlx::core::Shape vec_batch_shape{ + vec_.shape().begin(), vec_.shape().end() - 2}; + mlx::core::Strides vec_batch_strides{ + vec_.strides().begin(), vec_.strides().end() - 2}; + int vec_batch_ndim = vec_batch_shape.size(); + + int index_batch_ndim = index_shape.size(); + + int32_t* d_mat_batch_shape = nullptr; + int64_t* d_mat_batch_strides = nullptr; + int32_t* d_vec_batch_shape = nullptr; + int64_t* d_vec_batch_strides = nullptr; + int32_t* d_index_shape = nullptr; + int64_t* d_mat_index_strides = nullptr; + int64_t* d_vec_index_strides = nullptr; + + GemvGatherParams inline_gather_params{}; + bool use_inline_gather_params = mat_batch_ndim <= kMaxInlineBatchDims && + vec_batch_ndim <= kMaxInlineBatchDims && + index_batch_ndim <= kMaxInlineBatchDims; + + if (use_inline_gather_params) { + inline_gather_params.mat_batch_ndim = mat_batch_ndim; + inline_gather_params.vec_batch_ndim = vec_batch_ndim; + inline_gather_params.index_batch_ndim = index_batch_ndim; + for (int i = 0; i < mat_batch_ndim; ++i) { + inline_gather_params.mat_batch_shape[i] = mat_batch_shape[i]; + inline_gather_params.mat_batch_strides[i] = mat_batch_strides[i]; + } + for (int i = 0; i < vec_batch_ndim; ++i) { + inline_gather_params.vec_batch_shape[i] = vec_batch_shape[i]; + inline_gather_params.vec_batch_strides[i] = vec_batch_strides[i]; + } + for (int i = 0; i < index_batch_ndim; ++i) { + inline_gather_params.index_shape[i] = index_shape[i]; + inline_gather_params.mat_index_strides[i] = mat_index_strides[i]; + inline_gather_params.vec_index_strides[i] = vec_index_strides[i]; + } + } else { + auto copy_shape_to_device = [](const mlx::core::Shape& shape, + int32_t** dst_shape) { + if (shape.empty()) { + return; + } + (void)hipMalloc(dst_shape, shape.size() * sizeof(int32_t)); + (void)hipMemcpy( + *dst_shape, + shape.data(), + shape.size() * sizeof(int32_t), + hipMemcpyHostToDevice); + }; + + auto copy_strides_to_device = [](const mlx::core::Strides& strides, + int64_t** dst_strides) { + if (strides.empty()) { + return; + } + (void)hipMalloc(dst_strides, strides.size() * sizeof(int64_t)); + (void)hipMemcpy( + *dst_strides, + strides.data(), + strides.size() * sizeof(int64_t), + hipMemcpyHostToDevice); + }; + + copy_shape_to_device(mat_batch_shape, &d_mat_batch_shape); + copy_strides_to_device(mat_batch_strides, &d_mat_batch_strides); + copy_shape_to_device(vec_batch_shape, &d_vec_batch_shape); + copy_strides_to_device(vec_batch_strides, &d_vec_batch_strides); + copy_shape_to_device(index_shape, &d_index_shape); + copy_strides_to_device(mat_index_strides, &d_mat_index_strides); + copy_strides_to_device(vec_index_strides, &d_vec_index_strides); + } + + const void* mat_ptr = gpu_ptr(mat_); + const void* vec_ptr = gpu_ptr(vec_); + void* out_ptr = gpu_ptr(out); + const uint32_t* mat_indices_ptr = gpu_ptr(mat_indices); + const uint32_t* vec_indices_ptr = gpu_ptr(vec_indices); + + if (!use_inline_gather_params) { + // Rare fallback: batch dims exceed kMaxInlineBatchDims. Params live in + // stream-ordered device memory freed via hipFreeAsync, so keep this on the + // legacy stream-capture path. + encoder.launch_kernel([&, + mat_ptr, + vec_ptr, + out_ptr, + mat_indices_ptr, + vec_indices_ptr, + d_mat_batch_shape, + d_mat_batch_strides, + d_vec_batch_shape, + d_vec_batch_strides, + d_index_shape, + d_mat_index_strides, + d_vec_index_strides](hipStream_t stream) { + auto launch_kernel = [&](auto type_tag, auto n_per_thread) { + using T = typename decltype(type_tag)::type; + hipLaunchKernelGGL( + (gemv_gather), + dim3(num_blocks_x, batch_size), + block_dims, + 0, + stream, + static_cast(mat_ptr), + static_cast(vec_ptr), + static_cast(out_ptr), + mat_indices_ptr, + vec_indices_ptr, + rows, + cols, + d_mat_batch_shape, + d_mat_batch_strides, + mat_batch_ndim, + d_vec_batch_shape, + d_vec_batch_strides, + vec_batch_ndim, + d_index_shape, + d_mat_index_strides, + d_vec_index_strides, + index_batch_ndim); + }; + + dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) { + switch (out.dtype()) { + case float32: + launch_kernel(type_identity{}, n_per_thread); + break; + case float16: + launch_kernel(type_identity<__half>{}, n_per_thread); + break; + case bfloat16: + launch_kernel(type_identity{}, n_per_thread); + break; + case float64: + launch_kernel(type_identity{}, n_per_thread); + break; + default: + break; + } + }); + + if (d_mat_batch_shape != nullptr) { + (void)hipFreeAsync(d_mat_batch_shape, stream); + } + if (d_mat_batch_strides != nullptr) { + (void)hipFreeAsync(d_mat_batch_strides, stream); + } + if (d_vec_batch_shape != nullptr) { + (void)hipFreeAsync(d_vec_batch_shape, stream); + } + if (d_vec_batch_strides != nullptr) { + (void)hipFreeAsync(d_vec_batch_strides, stream); + } + if (d_index_shape != nullptr) { + (void)hipFreeAsync(d_index_shape, stream); + } + if (d_mat_index_strides != nullptr) { + (void)hipFreeAsync(d_mat_index_strides, stream); + } + if (d_vec_index_strides != nullptr) { + (void)hipFreeAsync(d_vec_index_strides, stream); + } + }); + return; + } + + auto add_node = [&](auto type_tag, auto n_per_thread) { + using T = typename decltype(type_tag)::type; + encoder.add_kernel_node( + &gemv_gather_inline, + dim3(num_blocks_x, batch_size), + block_dims, + 0u, + static_cast(mat_ptr), + static_cast(vec_ptr), + static_cast(out_ptr), + mat_indices_ptr, + vec_indices_ptr, + rows, + cols, + inline_gather_params); + }; + + dispatch_n_per_thread(n_per_t, [&](auto n_per_thread) { + switch (out.dtype()) { + case float32: + add_node(type_identity{}, n_per_thread); + break; + case float16: + add_node(type_identity<__half>{}, n_per_thread); + break; + case bfloat16: + add_node(type_identity{}, n_per_thread); + break; + case float64: + add_node(type_identity{}, n_per_thread); + break; + default: + break; + } + }); +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/gemms/hipblaslt_gemm.cpp b/mlx/backend/rocm/gemms/hipblaslt_gemm.cpp new file mode 100644 index 0000000000..40e07689af --- /dev/null +++ b/mlx/backend/rocm/gemms/hipblaslt_gemm.cpp @@ -0,0 +1,1057 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/gemms/hipblaslt_gemm.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/kernel_utils.hpp" + +#include +#include + +#include +#include +#include +#include + +namespace mlx::core::rocm { + +namespace { + +// Maximum workspace size for hipBLASLt algorithms (32 MB). +// hipBLASLt may request scratch memory for certain algorithm choices. +constexpr size_t kMaxWorkspaceBytes = 32u * 1024u * 1024u; + +// Per-device hipBLASLt handle cache. Lazily initialised, thread-safe. +struct HipblasltState { + hipblasLtHandle_t handle{nullptr}; + bool initialized{false}; + bool available{false}; + std::mutex mutex; + + // Persistent workspace allocation (grown as needed, never shrunk). + void* workspace{nullptr}; + size_t workspace_size{0}; +}; + +// One state per device (indexed by HIP device ordinal). +// 16 devices should be more than enough for any system. +static constexpr int kMaxDevices = 16; +static HipblasltState g_state[kMaxDevices]; + +HipblasltState& get_state(int device_id) { + if (device_id < 0 || device_id >= kMaxDevices) { + throw std::runtime_error( + "hipBLASLt: device id out of range: " + std::to_string(device_id)); + } + return g_state[device_id]; +} + +// Initialise the hipBLASLt handle for the given device. +// Must be called with state.mutex held. +void init_handle(HipblasltState& state, int device_id) { + if (state.initialized) { + return; + } + state.initialized = true; + + hipblasStatus_t status = hipblasLtCreate(&state.handle); + if (status != HIPBLAS_STATUS_SUCCESS) { + state.available = false; + state.handle = nullptr; + std::cerr << "Warning: hipBLASLt initialization failed (status " + << static_cast(status) << ")." << std::endl; + return; + } + state.available = true; + + // Pre-allocate the matmul workspace to the maximum size NOW so that + // ensure_workspace() never calls hipMalloc during a HIP-graph capture (a + // device alloc on the capturing stream invalidates the graph). Any algorithm + // the heuristic returns fits within kMaxWorkspaceBytes, so a single up-front + // allocation makes hipblasLtMatmul capture-safe. + int prev_dev = 0; + (void)hipGetDevice(&prev_dev); + (void)hipSetDevice(device_id); + if (hipMalloc(&state.workspace, kMaxWorkspaceBytes) == hipSuccess) { + state.workspace_size = kMaxWorkspaceBytes; + } else { + state.workspace = nullptr; + state.workspace_size = 0; + } + (void)hipSetDevice(prev_dev); +} + +hipblasLtHandle_t get_handle(int device_id) { + auto& state = get_state(device_id); + if (!state.initialized) { + std::lock_guard lock(state.mutex); + init_handle(state, device_id); + } + if (!state.available) { + throw std::runtime_error("hipBLASLt is not available on this device."); + } + return state.handle; +} + +// Ensure the per-device workspace is at least `required` bytes. +// Returns the workspace pointer and the actual allocated size. +// Must be called from within a launch_kernel callback (i.e., on the +// stream-submission thread for this device), so no extra locking is needed +// beyond the device serialisation that CommandEncoder already provides. +std::pair ensure_workspace(int device_id, size_t required) { + auto& state = get_state(device_id); + if (required <= state.workspace_size && state.workspace != nullptr) { + return {state.workspace, state.workspace_size}; + } + // Free old allocation (hipFree is a no-op on nullptr). + if (state.workspace) { + (void)hipFree(state.workspace); + state.workspace = nullptr; + state.workspace_size = 0; + } + if (required == 0) { + return {nullptr, 0}; + } + hipError_t err = hipMalloc(&state.workspace, required); + if (err != hipSuccess) { + state.workspace = nullptr; + state.workspace_size = 0; + return {nullptr, 0}; + } + state.workspace_size = required; + return {state.workspace, state.workspace_size}; +} + +hipDataType to_hipblaslt_dtype(Dtype dtype) { + switch (dtype) { + case float32: + return HIP_R_32F; + case float16: + return HIP_R_16F; + case bfloat16: + return HIP_R_16BF; + default: + throw std::runtime_error("Unsupported dtype for hipBLASLt GEMM"); + } +} + +hipblasOperation_t to_hipblas_op(bool transpose) { + return transpose ? HIPBLAS_OP_T : HIPBLAS_OP_N; +} + +// Per-device GEMM capability table, discovered at load time by asking +// hipBLASLt's heuristic which input types yield kernels on this GPU. This is a +// runtime probe rather than a hardcoded arch list, so it tracks whatever the +// installed Tensile library actually supports. +struct GemmCaps { + bool probed{false}; + bool bf16{false}; + bool fp8_e4m3{false}; + bool fp8_e5m2{false}; + bool int8{false}; +}; +static GemmCaps g_caps[kMaxDevices]; +static std::mutex g_caps_mutex; + +// Does this (input, output, compute) combination have any hipBLASLt algorithm +// on the given handle? AlgoGetHeuristic only inspects descriptors, so no device +// memory is touched. Uses a representative GEMM shape. +bool probe_gemm_combo( + hipblasLtHandle_t handle, + hipDataType in_type, + hipDataType out_type, + hipblasComputeType_t compute_type) { + hipblasLtMatmulDesc_t desc = nullptr; + if (hipblasLtMatmulDescCreate(&desc, compute_type, HIP_R_32F) != + HIPBLAS_STATUS_SUCCESS) { + return false; + } + int32_t op_t = HIPBLAS_OP_T, op_n = HIPBLAS_OP_N; + hipblasLtMatmulDescSetAttribute( + desc, HIPBLASLT_MATMUL_DESC_TRANSA, &op_t, sizeof(op_t)); + hipblasLtMatmulDescSetAttribute( + desc, HIPBLASLT_MATMUL_DESC_TRANSB, &op_n, sizeof(op_n)); + const int M = 2048, N = 512, K = 2048; + hipblasLtMatrixLayout_t la = nullptr, lb = nullptr, lc = nullptr, ld = nullptr; + hipblasLtMatrixLayoutCreate(&la, in_type, K, M, K); + hipblasLtMatrixLayoutCreate(&lb, in_type, K, N, K); + hipblasLtMatrixLayoutCreate(&lc, out_type, M, N, M); + hipblasLtMatrixLayoutCreate(&ld, out_type, M, N, M); + hipblasLtMatmulPreference_t pref = nullptr; + hipblasLtMatmulPreferenceCreate(&pref); + uint64_t ws = kMaxWorkspaceBytes; + hipblasLtMatmulPreferenceSetAttribute( + pref, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &ws, sizeof(ws)); + hipblasLtMatmulHeuristicResult_t res[4]; + int count = 0; + hipblasStatus_t st = hipblasLtMatmulAlgoGetHeuristic( + handle, desc, la, lb, lc, ld, pref, 4, res, &count); + if (pref) + hipblasLtMatmulPreferenceDestroy(pref); + if (ld) + hipblasLtMatrixLayoutDestroy(ld); + if (lc) + hipblasLtMatrixLayoutDestroy(lc); + if (lb) + hipblasLtMatrixLayoutDestroy(lb); + if (la) + hipblasLtMatrixLayoutDestroy(la); + if (desc) + hipblasLtMatmulDescDestroy(desc); + return st == HIPBLAS_STATUS_SUCCESS && count > 0; +} + +const GemmCaps& gemm_caps(int device_id) { + std::lock_guard lock(g_caps_mutex); + GemmCaps& caps = g_caps[device_id]; + if (caps.probed) { + return caps; + } + caps.probed = true; + hipblasLtHandle_t handle = nullptr; + try { + handle = get_handle(device_id); + } catch (...) { + return caps; + } + caps.bf16 = probe_gemm_combo(handle, HIP_R_16BF, HIP_R_16BF, HIPBLAS_COMPUTE_32F); + caps.fp8_e4m3 = + probe_gemm_combo(handle, HIP_R_8F_E4M3, HIP_R_16BF, HIPBLAS_COMPUTE_32F); + caps.fp8_e5m2 = + probe_gemm_combo(handle, HIP_R_8F_E5M2, HIP_R_16BF, HIPBLAS_COMPUTE_32F); + caps.int8 = probe_gemm_combo(handle, HIP_R_8I, HIP_R_32I, HIPBLAS_COMPUTE_32I); + + hipDeviceProp_t props; + const char* arch = + (hipGetDeviceProperties(&props, device_id) == hipSuccess) + ? props.gcnArchName + : "?"; + fprintf( + stderr, + "[hipBLASLt caps] device %d (%s): bf16=%d fp8_e4m3=%d fp8_e5m2=%d int8=%d\n", + device_id, + arch, + caps.bf16, + caps.fp8_e4m3, + caps.fp8_e5m2, + caps.int8); + return caps; +} + +// Input precision chosen for a GEMM on a given device. The hardware/library +// capability table decides which is reachable; accuracy ranks them e4m3 > bf16 +// for our (already-quantized) weights. +enum class GemmPrecision { Bf16, Fp8E4M3, Fp8E5M2, Int8 }; + +// Highest-throughput input precision this device can run for half-precision +// GEMMs while preserving accuracy: fp8 e4m3 where the library has kernels +// (RDNA4), otherwise bf16 (RDNA3.5 and anything without fp8 Tensile kernels). +GemmPrecision preferred_gemm_precision(int device_id) { + const GemmCaps& caps = gemm_caps(device_id); + if (caps.fp8_e4m3) { + return GemmPrecision::Fp8E4M3; + } + return GemmPrecision::Bf16; +} + +// RAII wrappers for hipBLASLt descriptors to avoid leaks on error paths. +struct MatmulDescGuard { + hipblasLtMatmulDesc_t desc{nullptr}; + ~MatmulDescGuard() { + if (desc) + hipblasLtMatmulDescDestroy(desc); + } +}; +struct MatrixLayoutGuard { + hipblasLtMatrixLayout_t layout{nullptr}; + ~MatrixLayoutGuard() { + if (layout) + hipblasLtMatrixLayoutDestroy(layout); + } +}; +struct PreferenceGuard { + hipblasLtMatmulPreference_t pref{nullptr}; + ~PreferenceGuard() { + if (pref) + hipblasLtMatmulPreferenceDestroy(pref); + } +}; + +// Core implementation: set up descriptors, find the best algorithm, and +// execute the matmul on the given stream. +void hipblaslt_gemm_impl( + hipblasLtHandle_t handle, + int device_id, + hipblasOperation_t op_a, + hipblasOperation_t op_b, + int M, + int N, + int K, + const float* alpha, + const void* a_ptr, + int lda, + int64_t stride_a, + const void* b_ptr, + int ldb, + int64_t stride_b, + const float* beta, + void* c_ptr, + int ldc, + int64_t stride_c, + int batch_count, + hipDataType data_type, + hipStream_t stream) { + hipblasStatus_t status; + + // Discover this device's GEMM capability table on first use (prints once). + GemmPrecision precision = preferred_gemm_precision(device_id); + (void)precision; + + hipDataType scale_type = HIP_R_32F; + int32_t trans_a_val = static_cast(op_a); + int32_t trans_b_val = static_cast(op_b); + + // --- Matrix layouts (column-major, as expected by BLAS) --- + // A is (op_a == N) ? M x K : K x M in column-major + // B is (op_b == N) ? K x N : N x K in column-major + // C is M x N in column-major + uint64_t a_rows = (op_a == HIPBLAS_OP_N) ? M : K; + uint64_t a_cols = (op_a == HIPBLAS_OP_N) ? K : M; + uint64_t b_rows = (op_b == HIPBLAS_OP_N) ? K : N; + uint64_t b_cols = (op_b == HIPBLAS_OP_N) ? N : K; + + MatrixLayoutGuard layout_a, layout_b, layout_c, layout_d; + + status = hipblasLtMatrixLayoutCreate( + &layout_a.layout, data_type, a_rows, a_cols, lda); + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatrixLayoutCreate(A) failed: " + + std::to_string(static_cast(status))); + } + + status = hipblasLtMatrixLayoutCreate( + &layout_b.layout, data_type, b_rows, b_cols, ldb); + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatrixLayoutCreate(B) failed: " + + std::to_string(static_cast(status))); + } + + status = hipblasLtMatrixLayoutCreate(&layout_c.layout, data_type, M, N, ldc); + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatrixLayoutCreate(C) failed: " + + std::to_string(static_cast(status))); + } + + // D has the same layout as C (in-place: D == C). + status = hipblasLtMatrixLayoutCreate(&layout_d.layout, data_type, M, N, ldc); + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatrixLayoutCreate(D) failed: " + + std::to_string(static_cast(status))); + } + + // Set batch attributes when doing strided batched GEMM. + if (batch_count > 1) { + int32_t bc = batch_count; + hipblasLtMatrixLayoutSetAttribute( + layout_a.layout, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &bc, sizeof(bc)); + hipblasLtMatrixLayoutSetAttribute( + layout_a.layout, + HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &stride_a, + sizeof(stride_a)); + + hipblasLtMatrixLayoutSetAttribute( + layout_b.layout, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &bc, sizeof(bc)); + hipblasLtMatrixLayoutSetAttribute( + layout_b.layout, + HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &stride_b, + sizeof(stride_b)); + + hipblasLtMatrixLayoutSetAttribute( + layout_c.layout, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &bc, sizeof(bc)); + hipblasLtMatrixLayoutSetAttribute( + layout_c.layout, + HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &stride_c, + sizeof(stride_c)); + + hipblasLtMatrixLayoutSetAttribute( + layout_d.layout, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &bc, sizeof(bc)); + hipblasLtMatrixLayoutSetAttribute( + layout_d.layout, + HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &stride_c, + sizeof(stride_c)); + } + + // --- Algorithm selection via heuristic --- + PreferenceGuard pref_guard; + status = hipblasLtMatmulPreferenceCreate(&pref_guard.pref); + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatmulPreferenceCreate failed: " + + std::to_string(static_cast(status))); + } + + uint64_t max_ws = kMaxWorkspaceBytes; + hipblasLtMatmulPreferenceSetAttribute( + pref_guard.pref, + HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &max_ws, + sizeof(max_ws)); + + // Request multiple algorithms for better occupancy/performance + static constexpr int kMaxAlgos = 8; + hipblasLtMatmulHeuristicResult_t heuristics[kMaxAlgos]; + int returned_algo_count = 0; + + MatmulDescGuard matmul_guard; + status = hipblasLtMatmulDescCreate( + &matmul_guard.desc, HIPBLAS_COMPUTE_32F, scale_type); + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatmulDescCreate failed: " + + std::to_string(static_cast(status))); + } + hipblasLtMatmulDescSetAttribute( + matmul_guard.desc, + HIPBLASLT_MATMUL_DESC_TRANSA, + &trans_a_val, + sizeof(trans_a_val)); + hipblasLtMatmulDescSetAttribute( + matmul_guard.desc, + HIPBLASLT_MATMUL_DESC_TRANSB, + &trans_b_val, + sizeof(trans_b_val)); + + // Per-(shape,dtype,transpose,device) algorithm cache. The chosen heuristic + // result is reusable across calls with identical problem geometry, so warm + // calls skip AlgoGetHeuristic — the dominant per-call cost for the many small + // GEMMs in a forward pass. + struct AlgoKey { + int M, N, K, batch, dt, ta, tb, dev; + bool operator==(const AlgoKey& o) const { + return M == o.M && N == o.N && K == o.K && batch == o.batch && + dt == o.dt && ta == o.ta && tb == o.tb && dev == o.dev; + } + }; + struct AlgoKeyHash { + size_t operator()(const AlgoKey& k) const { + size_t h = 1469598103934665603ULL; + for (int v : {k.M, k.N, k.K, k.batch, k.dt, k.ta, k.tb, k.dev}) { + h = (h ^ static_cast(v)) * 1099511628211ULL; + } + return h; + } + }; + static std::mutex algo_mutex; + static std::unordered_map + algo_cache; + + AlgoKey key{ + M, + N, + K, + batch_count, + static_cast(data_type), + trans_a_val, + trans_b_val, + device_id}; + hipblasLtMatmulHeuristicResult_t heuristic; + bool cache_hit = false; + { + std::lock_guard lock(algo_mutex); + auto cached = algo_cache.find(key); + if (cached != algo_cache.end()) { + heuristic = cached->second; + cache_hit = true; + } + } + + if (!cache_hit) { + status = hipblasLtMatmulAlgoGetHeuristic( + handle, + matmul_guard.desc, + layout_a.layout, + layout_b.layout, + layout_c.layout, + layout_d.layout, + pref_guard.pref, + kMaxAlgos, + heuristics, + &returned_algo_count); + + if (status != HIPBLAS_STATUS_SUCCESS || returned_algo_count == 0) { + throw std::runtime_error( + "hipblasLtMatmulAlgoGetHeuristic failed (status=" + + std::to_string(static_cast(status)) + + ", returned=" + std::to_string(returned_algo_count) + ")"); + } + + int best_algo_idx = 0; + + // Auto-tuning: benchmark all algorithms to find the fastest for each shape. + // Disabled by default — for quantized models the GEMM path is rarely used + // and the tuning overhead causes warm prompt regression. + // Enable with MLX_ROCM_HIPBLASLT_TUNE=1 for non-quantized models. + static bool do_tune = std::getenv("MLX_ROCM_HIPBLASLT_TUNE") != nullptr; + + if (do_tune && returned_algo_count > 1) { + double best_time = 1e30; + for (int algo_idx = 0; algo_idx < returned_algo_count; algo_idx++) { + size_t ws_need = heuristics[algo_idx].workspaceSize; + void* ws_p = nullptr; + size_t ws_s = 0; + if (ws_need > 0) { + auto [p, s] = ensure_workspace(device_id, ws_need); + ws_p = p; + ws_s = s; + if (!ws_p) + continue; + } + + // Warm-up + (void)hipblasLtMatmul( + handle, + matmul_guard.desc, + alpha, + a_ptr, + layout_a.layout, + b_ptr, + layout_b.layout, + beta, + c_ptr, + layout_c.layout, + c_ptr, + layout_d.layout, + &heuristics[algo_idx].algo, + ws_p, + ws_s, + stream); + (void)hipStreamSynchronize(stream); + + // Timed run + hipEvent_t start_ev, stop_ev; + (void)hipEventCreate(&start_ev); + (void)hipEventCreate(&stop_ev); + (void)hipEventRecord(start_ev, stream); + + static constexpr int kBenchIters = 3; + for (int r = 0; r < kBenchIters; r++) { + (void)hipblasLtMatmul( + handle, + matmul_guard.desc, + alpha, + a_ptr, + layout_a.layout, + b_ptr, + layout_b.layout, + beta, + c_ptr, + layout_c.layout, + c_ptr, + layout_d.layout, + &heuristics[algo_idx].algo, + ws_p, + ws_s, + stream); + } + + (void)hipEventRecord(stop_ev, stream); + (void)hipStreamSynchronize(stream); + float ms = 0; + (void)hipEventElapsedTime(&ms, start_ev, stop_ev); + (void)hipEventDestroy(start_ev); + (void)hipEventDestroy(stop_ev); + + double avg = ms / kBenchIters; + if (avg < best_time) { + best_time = avg; + best_algo_idx = algo_idx; + } + } + } + + heuristic = heuristics[best_algo_idx]; + { + std::lock_guard lock(algo_mutex); + algo_cache[key] = heuristic; + } + } + + // --- Workspace allocation --- + size_t ws_needed = heuristic.workspaceSize; + void* ws_ptr = nullptr; + size_t ws_actual = 0; + if (ws_needed > 0) { + auto [p, s] = ensure_workspace(device_id, ws_needed); + ws_ptr = p; + ws_actual = s; + if (ws_ptr == nullptr && ws_needed > 0) { + throw std::runtime_error( + "hipBLASLt: failed to allocate workspace of " + + std::to_string(ws_needed) + " bytes"); + } + } + + // --- Execute the matmul --- + status = hipblasLtMatmul( + handle, + matmul_guard.desc, + alpha, + a_ptr, + layout_a.layout, + b_ptr, + layout_b.layout, + beta, + c_ptr, + layout_c.layout, + c_ptr, // D == C (in-place) + layout_d.layout, + &heuristic.algo, + ws_ptr, + ws_actual, + stream); + + if (status != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "hipblasLtMatmul failed: " + std::to_string(static_cast(status))); + } +} + +} // namespace + +bool is_hipblaslt_available() { + // Diagnostic: force the rocBLAS path everywhere to test whether rocBLAS bf16 + // GEMM is numerically correct for this model. + static const bool g_force_rocblas = std::getenv("MLX_NO_HIPBLASLT") != nullptr; + if (g_force_rocblas) + return false; + // When automatic HIP-graph batching is on, the GEMM is graph-split and run + // immediately, but hipBLASLt's lazy hipblasLtCreate / AlgoGetHeuristic / + // workspace hipMalloc are non-capturable and abort the process if the stream + // is mid-graph. rocBLAS is graph-safe here, so force it whenever graphs are + // enabled. (rocBLAS == hipBLASLt speed at decode, so this costs nothing.) + static const bool g_graphs = + std::getenv("MLX_USE_HIP_GRAPHS") != nullptr; + if (g_graphs) + return false; + // hipBLASLt's lazy init is non-capturable; force rocBLAS during any capture. + if (stream_capturing()) + return false; + int device_id = 0; + (void)hipGetDevice(&device_id); + auto& state = get_state(device_id); + if (!state.initialized) { + std::lock_guard lock(state.mutex); + init_handle(state, device_id); + } + return state.available; +} + +void hipblaslt_gemm( + CommandEncoder& encoder, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + const array& b, + int ldb, + float beta, + array& c, + int ldc, + Dtype dtype) { + int device_id = encoder.device().hip_device(); + hipblasLtHandle_t handle = get_handle(device_id); + hipDataType hip_dtype = to_hipblaslt_dtype(dtype); + + // hipBLASLt uses column-major layout. MLX stores row-major, so we swap A + // and B and compute C^T = B^T * A^T, just like the rocBLAS path. + hipblasOperation_t op_a = to_hipblas_op(transpose_b); + hipblasOperation_t op_b = to_hipblas_op(transpose_a); + + // Per-call GEMM tracing, gated behind an env flag. + static const bool kGemmDebug = std::getenv("MLX_ROCM_GEMM_DEBUG") != nullptr; + if (kGemmDebug) { + fprintf( + stderr, + "[hipBLASLt] M=%d N=%d K=%d ta=%d tb=%d lda=%d ldb=%d ldc=%d\n", + M, N, K, (int)transpose_a, (int)transpose_b, lda, ldb, ldc); + } + + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* c_ptr = gpu_ptr(c); + + encoder.launch_kernel([=, &encoder](hipStream_t stream) { + hipblaslt_gemm_impl( + handle, + device_id, + op_a, + op_b, + N, // swap M/N for col-major trick + M, + K, + &alpha, + b_ptr, // swap A/B + ldb, + 0, // stride_a (unused for non-batched) + a_ptr, + lda, + 0, // stride_b (unused for non-batched) + &beta, + c_ptr, + ldc, + 0, // stride_c (unused for non-batched) + 1, // batch_count + hip_dtype, + stream); + }); +} + +void hipblaslt_gemm_batched( + CommandEncoder& encoder, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + int64_t stride_a, + const array& b, + int ldb, + int64_t stride_b, + float beta, + array& c, + int ldc, + int64_t stride_c, + int batch_count, + Dtype dtype) { + int device_id = encoder.device().hip_device(); + hipblasLtHandle_t handle = get_handle(device_id); + hipDataType hip_dtype = to_hipblaslt_dtype(dtype); + + // Same column-major swap as above. + hipblasOperation_t op_a = to_hipblas_op(transpose_b); + hipblasOperation_t op_b = to_hipblas_op(transpose_a); + + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* c_ptr = gpu_ptr(c); + + encoder.launch_kernel([=, &encoder](hipStream_t stream) { + hipblaslt_gemm_impl( + handle, + device_id, + op_a, + op_b, + N, + M, + K, + &alpha, + b_ptr, + ldb, + stride_b, // swapped: was b, now is "A" in col-major + a_ptr, + lda, + stride_a, // swapped: was a, now is "B" in col-major + &beta, + c_ptr, + ldc, + stride_c, + batch_count, + hip_dtype, + stream); + }); +} + +void hipblaslt_gemm_raw( + hipStream_t stream, + int op_a, + int op_b, + int M, + int N, + int K, + const float* alpha, + const void* a_ptr, + int lda, + const void* b_ptr, + int ldb, + const float* beta, + void* c_ptr, + int ldc, + int data_type_hint, + int /*compute_type_hint*/) { + int device_id = 0; + (void)hipGetDevice(&device_id); + hipblasLtHandle_t handle = get_handle(device_id); + + // Map data_type_hint: 1=fp16, 2=bf16, 3=fp32 + hipDataType hip_dtype; + switch (data_type_hint) { + case 1: + hip_dtype = HIP_R_16F; + break; + case 2: + hip_dtype = HIP_R_16BF; + break; + default: + hip_dtype = HIP_R_32F; + break; + } + + hipblaslt_gemm_impl( + handle, + device_id, + static_cast(op_a), + static_cast(op_b), + M, + N, + K, + alpha, + a_ptr, + lda, + 0, + b_ptr, + ldb, + 0, + beta, + c_ptr, + ldc, + 0, + 1, // batch_count + hip_dtype, + stream); +} + +bool device_has_fp8_gemm(int device_id) { + return gemm_caps(device_id).fp8_e4m3; +} + +void hipblaslt_gemm_fp8_raw( + hipStream_t stream, + int op_a, + int op_b, + int M, + int N, + int K, + const void* a_ptr, + int lda, + const void* b_ptr, + int ldb, + void* c_ptr, + int ldc, + const float* a_scale, + const float* b_scale) { + int device_id = 0; + (void)hipGetDevice(&device_id); + hipblasLtHandle_t handle = get_handle(device_id); + + MatmulDescGuard desc_guard; + if (hipblasLtMatmulDescCreate( + &desc_guard.desc, HIPBLAS_COMPUTE_32F, HIP_R_32F) != + HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error("fp8 GEMM: descriptor create failed"); + } + int32_t ta = op_a, tb = op_b; + hipblasLtMatmulDescSetAttribute( + desc_guard.desc, HIPBLASLT_MATMUL_DESC_TRANSA, &ta, sizeof(ta)); + hipblasLtMatmulDescSetAttribute( + desc_guard.desc, HIPBLASLT_MATMUL_DESC_TRANSB, &tb, sizeof(tb)); + hipblasLtMatmulDescSetAttribute( + desc_guard.desc, + HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &a_scale, + sizeof(a_scale)); + hipblasLtMatmulDescSetAttribute( + desc_guard.desc, + HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &b_scale, + sizeof(b_scale)); + + hipblasOperation_t oa = static_cast(op_a); + hipblasOperation_t ob = static_cast(op_b); + uint64_t a_rows = (oa == HIPBLAS_OP_N) ? M : K; + uint64_t a_cols = (oa == HIPBLAS_OP_N) ? K : M; + uint64_t b_rows = (ob == HIPBLAS_OP_N) ? K : N; + uint64_t b_cols = (ob == HIPBLAS_OP_N) ? N : K; + MatrixLayoutGuard la, lb, lc, ld; + hipblasLtMatrixLayoutCreate(&la.layout, HIP_R_8F_E4M3, a_rows, a_cols, lda); + hipblasLtMatrixLayoutCreate(&lb.layout, HIP_R_8F_E4M3, b_rows, b_cols, ldb); + hipblasLtMatrixLayoutCreate(&lc.layout, HIP_R_16BF, M, N, ldc); + hipblasLtMatrixLayoutCreate(&ld.layout, HIP_R_16BF, M, N, ldc); + + PreferenceGuard pref_guard; + hipblasLtMatmulPreferenceCreate(&pref_guard.pref); + uint64_t max_ws = kMaxWorkspaceBytes; + hipblasLtMatmulPreferenceSetAttribute( + pref_guard.pref, + HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &max_ws, + sizeof(max_ws)); + + // Best algorithm per (shape, device), tuned once. hipBLASLt's heuristic + // top-pick is poor for fp8; timing all candidates on the first call and + // caching the winner is worth the one-time cost (shapes repeat every layer). + struct Key { + int M, N, K, dev; + bool operator==(const Key& o) const { + return M == o.M && N == o.N && K == o.K && dev == o.dev; + } + }; + struct KeyHash { + size_t operator()(const Key& k) const { + size_t h = 1469598103934665603ULL; + for (int v : {k.M, k.N, k.K, k.dev}) { + h = (h ^ static_cast(v)) * 1099511628211ULL; + } + return h; + } + }; + static std::mutex mtx; + static std::unordered_map + algo_cache; + + Key key{M, N, K, device_id}; + hipblasLtMatmulHeuristicResult_t chosen; + bool hit = false; + { + std::lock_guard lock(mtx); + auto it = algo_cache.find(key); + if (it != algo_cache.end()) { + chosen = it->second; + hit = true; + } + } + + float alpha = 1.0f, beta = 0.0f; + if (!hit) { + static constexpr int kNA = 16; + hipblasLtMatmulHeuristicResult_t res[kNA]; + int cnt = 0; + if (hipblasLtMatmulAlgoGetHeuristic( + handle, + desc_guard.desc, + la.layout, + lb.layout, + lc.layout, + ld.layout, + pref_guard.pref, + kNA, + res, + &cnt) != HIPBLAS_STATUS_SUCCESS || + cnt == 0) { + throw std::runtime_error("fp8 GEMM: no algorithm for shape"); + } + double best = 1e30; + int best_idx = 0; + for (int a = 0; a < cnt; ++a) { + size_t need = res[a].workspaceSize; + void* wp = nullptr; + size_t ws = 0; + if (need > 0) { + auto [p, s] = ensure_workspace(device_id, need); + wp = p; + ws = s; + if (!wp) + continue; + } + if (hipblasLtMatmul( + handle, + desc_guard.desc, + &alpha, + a_ptr, + la.layout, + b_ptr, + lb.layout, + &beta, + c_ptr, + lc.layout, + c_ptr, + ld.layout, + &res[a].algo, + wp, + ws, + stream) != HIPBLAS_STATUS_SUCCESS) { + continue; + } + (void)hipStreamSynchronize(stream); + hipEvent_t e0, e1; + (void)hipEventCreate(&e0); + (void)hipEventCreate(&e1); + (void)hipEventRecord(e0, stream); + for (int r = 0; r < 3; ++r) { + (void)hipblasLtMatmul( + handle, + desc_guard.desc, + &alpha, + a_ptr, + la.layout, + b_ptr, + lb.layout, + &beta, + c_ptr, + lc.layout, + c_ptr, + ld.layout, + &res[a].algo, + wp, + ws, + stream); + } + (void)hipEventRecord(e1, stream); + (void)hipEventSynchronize(e1); + float ms = 0; + (void)hipEventElapsedTime(&ms, e0, e1); + (void)hipEventDestroy(e0); + (void)hipEventDestroy(e1); + if (ms < best) { + best = ms; + best_idx = a; + } + } + chosen = res[best_idx]; + std::lock_guard lock(mtx); + algo_cache[key] = chosen; + } + + size_t need = chosen.workspaceSize; + void* wp = nullptr; + size_t ws = 0; + if (need > 0) { + auto [p, s] = ensure_workspace(device_id, need); + wp = p; + ws = s; + } + if (hipblasLtMatmul( + handle, + desc_guard.desc, + &alpha, + a_ptr, + la.layout, + b_ptr, + lb.layout, + &beta, + c_ptr, + lc.layout, + c_ptr, + ld.layout, + &chosen.algo, + wp, + ws, + stream) != HIPBLAS_STATUS_SUCCESS) { + throw std::runtime_error("fp8 GEMM: hipblasLtMatmul failed"); + } +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/gemms/hipblaslt_gemm.h b/mlx/backend/rocm/gemms/hipblaslt_gemm.h new file mode 100644 index 0000000000..5a8ca8b326 --- /dev/null +++ b/mlx/backend/rocm/gemms/hipblaslt_gemm.h @@ -0,0 +1,99 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/rocm/device.h" + +namespace mlx::core::rocm { + +// hipBLASLt GEMM wrapper functions +// hipBLASLt provides optimized GEMM kernels that can outperform rocBLAS +// for half-precision (fp16/bf16) matrix multiplications by using hardware +// matrix cores more efficiently and selecting algorithms via heuristics. + +// Returns true if hipBLASLt is available and usable on the current device. +bool is_hipblaslt_available(); + +void hipblaslt_gemm( + CommandEncoder& encoder, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + const array& b, + int ldb, + float beta, + array& c, + int ldc, + Dtype dtype); + +void hipblaslt_gemm_batched( + CommandEncoder& encoder, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + int64_t stride_a, + const array& b, + int ldb, + int64_t stride_b, + float beta, + array& c, + int ldc, + int64_t stride_c, + int batch_count, + Dtype dtype); + +// Raw hipBLASLt GEMM — parameters already in column-major convention +// (A/B swapped, M/N swapped). Call directly from inside kernel lambdas. +void hipblaslt_gemm_raw( + hipStream_t stream, + int op_a, // rocblas_operation / hipblasOperation_t value + int op_b, + int M, + int N, + int K, + const float* alpha, + const void* a_ptr, + int lda, + const void* b_ptr, + int ldb, + const float* beta, + void* c_ptr, + int ldc, + int data_type, // hipDataType value (HIP_R_16BF, HIP_R_16F, HIP_R_32F) + int compute_type); // hipblasComputeType_t value + +// True iff this device has e4m3 fp8 GEMM kernels (probed once, cached). +bool device_has_fp8_gemm(int device_id); + +// Raw fp8 (e4m3) GEMM: A/B are e4m3 buffers in column-major convention, +// a_scale/b_scale are device float scalars applied as descale factors +// (out = a_scale*b_scale * (A@B)), output written as bf16. Picks the fastest +// available algorithm for the shape (heuristic top-pick is poor for fp8). +void hipblaslt_gemm_fp8_raw( + hipStream_t stream, + int op_a, + int op_b, + int M, + int N, + int K, + const void* a_ptr, + int lda, + const void* b_ptr, + int ldb, + void* c_ptr, + int ldc, + const float* a_scale, + const float* b_scale); + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/gemms/naive_gemm.h b/mlx/backend/rocm/gemms/naive_gemm.h new file mode 100644 index 0000000000..610ea29432 --- /dev/null +++ b/mlx/backend/rocm/gemms/naive_gemm.h @@ -0,0 +1,105 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/rocm/device.h" + +namespace mlx::core::rocm { + +// Naive GEMM implementation for when rocBLAS is not available +// C = alpha * op(A) * op(B) + beta * C +// where op(X) = X if not transposed, X^T if transposed +void naive_gemm( + CommandEncoder& encoder, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + bool b_transposed, + int64_t ldb, + float alpha = 1.0f, + float beta = 0.0f); + +// Batched naive GEMM +void naive_gemm_batched( + CommandEncoder& encoder, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + int64_t stride_a, + bool b_transposed, + int64_t ldb, + int64_t stride_b, + int64_t stride_c, + int batch_count, + float alpha = 1.0f, + float beta = 0.0f); + +// Batched gather GEMM where matrix selection is driven by index arrays. +void naive_gemm_gather( + CommandEncoder& encoder, + const array& a, + const array& b, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + bool b_transposed, + int64_t ldb, + float alpha = 1.0f, + float beta = 0.0f); + +// Naive GEMM with explicit offsets (for non-uniform batch strides) +void naive_gemm_with_offset( + CommandEncoder& encoder, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + int64_t a_offset, + bool b_transposed, + int64_t ldb, + int64_t b_offset, + int64_t out_offset, + float alpha = 1.0f, + float beta = 0.0f); + +// Naive GEMM with explicit offsets and custom ldc (for grouped conv) +void naive_gemm_with_offset_ldc( + CommandEncoder& encoder, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + int64_t a_offset, + bool b_transposed, + int64_t ldb, + int64_t b_offset, + int64_t ldc, + int64_t out_offset, + float alpha = 1.0f, + float beta = 0.0f); + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/gemms/naive_gemm.hip b/mlx/backend/rocm/gemms/naive_gemm.hip new file mode 100644 index 0000000000..ac9b2e21bd --- /dev/null +++ b/mlx/backend/rocm/gemms/naive_gemm.hip @@ -0,0 +1,1011 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/gemms/naive_gemm.h" +#include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" + +#include +#include +#include + +namespace mlx::core::rocm { + +// Tile sizes for the naive GEMM kernel +static constexpr int TILE_M = 16; +static constexpr int TILE_N = 16; +static constexpr int TILE_K = 16; + +// Accumulator type selection +template +struct GemmAccType { + using type = T; +}; + +template <> +struct GemmAccType<__half> { + using type = float; +}; + +template <> +struct GemmAccType { + using type = float; +}; + +// Naive GEMM kernel: C = alpha * A * B + beta * C +// A is M x K, B is K x N, C is M x N +// All matrices are row-major +template +__global__ void naive_gemm_kernel( + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + float alpha, + float beta) { + using Acc = typename GemmAccType::type; + + int row = blockIdx.y * TILE_M + threadIdx.y; + int col = blockIdx.x * TILE_N + threadIdx.x; + + if (row < M && col < N) { + Acc sum = Acc(0); + + for (int k = 0; k < K; ++k) { + Acc a_val, b_val; + + if constexpr (TransA) { + a_val = static_cast(A[k * lda + row]); + } else { + a_val = static_cast(A[row * lda + k]); + } + + if constexpr (TransB) { + b_val = static_cast(B[col * ldb + k]); + } else { + b_val = static_cast(B[k * ldb + col]); + } + + sum += a_val * b_val; + } + + if (beta != 0.0f) { + C[row * ldc + col] = static_cast(alpha * sum + beta * static_cast(C[row * ldc + col])); + } else { + C[row * ldc + col] = static_cast(alpha * sum); + } + } +} + +// Tiled GEMM kernel with shared memory for better performance +template +__global__ void tiled_gemm_kernel( + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + float alpha, + float beta) { + using Acc = typename GemmAccType::type; + + __shared__ Acc As[TILE_M][TILE_K]; + __shared__ Acc Bs[TILE_K][TILE_N]; + + int bx = blockIdx.x; + int by = blockIdx.y; + int tx = threadIdx.x; + int ty = threadIdx.y; + + int row = by * TILE_M + ty; + int col = bx * TILE_N + tx; + + Acc sum = Acc(0); + + // Loop over tiles + for (int t = 0; t < (K + TILE_K - 1) / TILE_K; ++t) { + // Load A tile into shared memory + int a_col = t * TILE_K + tx; + if (row < M && a_col < K) { + if constexpr (TransA) { + As[ty][tx] = static_cast(A[a_col * lda + row]); + } else { + As[ty][tx] = static_cast(A[row * lda + a_col]); + } + } else { + As[ty][tx] = Acc(0); + } + + // Load B tile into shared memory + int b_row = t * TILE_K + ty; + if (b_row < K && col < N) { + if constexpr (TransB) { + Bs[ty][tx] = static_cast(B[col * ldb + b_row]); + } else { + Bs[ty][tx] = static_cast(B[b_row * ldb + col]); + } + } else { + Bs[ty][tx] = Acc(0); + } + + __syncthreads(); + + // Compute partial dot product + #pragma unroll + for (int k = 0; k < TILE_K; ++k) { + sum += As[ty][k] * Bs[k][tx]; + } + + __syncthreads(); + } + + // Write result + if (row < M && col < N) { + if (beta != 0.0f) { + C[row * ldc + col] = static_cast(alpha * sum + beta * static_cast(C[row * ldc + col])); + } else { + C[row * ldc + col] = static_cast(alpha * sum); + } + } +} + +// Batched GEMM kernel +template +__global__ void batched_gemm_kernel( + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + int64_t stride_a, + int64_t stride_b, + int64_t stride_c, + float alpha, + float beta) { + using Acc = typename GemmAccType::type; + + int batch = blockIdx.z; + int row = blockIdx.y * TILE_M + threadIdx.y; + int col = blockIdx.x * TILE_N + threadIdx.x; + + const T* A_batch = A + batch * stride_a; + const T* B_batch = B + batch * stride_b; + T* C_batch = C + batch * stride_c; + + if (row < M && col < N) { + Acc sum = Acc(0); + + for (int k = 0; k < K; ++k) { + Acc a_val, b_val; + + if constexpr (TransA) { + a_val = static_cast(A_batch[k * lda + row]); + } else { + a_val = static_cast(A_batch[row * lda + k]); + } + + if constexpr (TransB) { + b_val = static_cast(B_batch[col * ldb + k]); + } else { + b_val = static_cast(B_batch[k * ldb + col]); + } + + sum += a_val * b_val; + } + + if (beta != 0.0f) { + C_batch[row * ldc + col] = static_cast(alpha * sum + beta * static_cast(C_batch[row * ldc + col])); + } else { + C_batch[row * ldc + col] = static_cast(alpha * sum); + } + } +} + +// Gathered batched GEMM kernel. Each output matrix chooses its lhs/rhs matrix +// from index arrays on device. +template +__global__ void gather_batched_gemm_kernel( + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + const uint32_t* __restrict__ lhs_indices, + const uint32_t* __restrict__ rhs_indices, + Shape idx_batch_shape, + Strides lhs_idx_strides, + Strides rhs_idx_strides, + int idx_batch_ndim, + Shape a_batch_shape, + Strides a_batch_strides, + int a_batch_ndim, + Shape b_batch_shape, + Strides b_batch_strides, + int b_batch_ndim, + int M, + int N, + int K, + int lda, + int ldb, + int64_t stride_c, + float alpha, + float beta) { + using Acc = typename GemmAccType::type; + + int batch = blockIdx.z; + int row = blockIdx.y * TILE_M + threadIdx.y; + int col = blockIdx.x * TILE_N + threadIdx.x; + + int64_t lhs_idx_loc = 0; + int64_t rhs_idx_loc = 0; + if (idx_batch_ndim == 1) { + lhs_idx_loc = static_cast(batch) * lhs_idx_strides[0]; + rhs_idx_loc = static_cast(batch) * rhs_idx_strides[0]; + } else if (idx_batch_ndim > 1) { + elem_to_loc( + static_cast(batch), + idx_batch_shape.data_, + lhs_idx_strides.data_, + rhs_idx_strides.data_, + idx_batch_ndim, + lhs_idx_loc, + rhs_idx_loc); + } + + uint32_t lhs_idx = lhs_indices[lhs_idx_loc]; + uint32_t rhs_idx = rhs_indices[rhs_idx_loc]; + + int64_t a_offset = 0; + int64_t b_offset = 0; + if (a_batch_ndim == 1) { + a_offset = static_cast(lhs_idx) * a_batch_strides[0]; + } else if (a_batch_ndim > 1) { + a_offset = elem_to_loc( + static_cast(lhs_idx), + a_batch_shape.data_, + a_batch_strides.data_, + a_batch_ndim); + } + + if (b_batch_ndim == 1) { + b_offset = static_cast(rhs_idx) * b_batch_strides[0]; + } else if (b_batch_ndim > 1) { + b_offset = elem_to_loc( + static_cast(rhs_idx), + b_batch_shape.data_, + b_batch_strides.data_, + b_batch_ndim); + } + + const T* A_batch = A + a_offset; + const T* B_batch = B + b_offset; + T* C_batch = C + static_cast(batch) * stride_c; + + if (row < M && col < N) { + Acc sum = Acc(0); + + for (int k = 0; k < K; ++k) { + Acc a_val; + Acc b_val; + + if constexpr (TransA) { + a_val = static_cast(A_batch[k * lda + row]); + } else { + a_val = static_cast(A_batch[row * lda + k]); + } + + if constexpr (TransB) { + b_val = static_cast(B_batch[col * ldb + k]); + } else { + b_val = static_cast(B_batch[k * ldb + col]); + } + + sum += a_val * b_val; + } + + if (beta != 0.0f) { + C_batch[row * N + col] = static_cast( + alpha * sum + beta * static_cast(C_batch[row * N + col])); + } else { + C_batch[row * N + col] = static_cast(alpha * sum); + } + } +} + +template +void launch_naive_gemm( + hipStream_t stream, + const T* A, + const T* B, + T* C, + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + bool trans_a, + bool trans_b, + float alpha, + float beta) { + dim3 block(TILE_N, TILE_M); + dim3 grid((N + TILE_N - 1) / TILE_N, (M + TILE_M - 1) / TILE_M); + + // Use tiled kernel for larger matrices, naive for smaller ones + bool use_tiled = (M >= 32 && N >= 32 && K >= 32); + + if (trans_a && trans_b) { + if (use_tiled) { + hipLaunchKernelGGL((tiled_gemm_kernel), + grid, block, 0, stream, + A, B, C, M, N, K, lda, ldb, ldc, alpha, beta); + } else { + hipLaunchKernelGGL((naive_gemm_kernel), + grid, block, 0, stream, + A, B, C, M, N, K, lda, ldb, ldc, alpha, beta); + } + } else if (trans_a && !trans_b) { + if (use_tiled) { + hipLaunchKernelGGL((tiled_gemm_kernel), + grid, block, 0, stream, + A, B, C, M, N, K, lda, ldb, ldc, alpha, beta); + } else { + hipLaunchKernelGGL((naive_gemm_kernel), + grid, block, 0, stream, + A, B, C, M, N, K, lda, ldb, ldc, alpha, beta); + } + } else if (!trans_a && trans_b) { + if (use_tiled) { + hipLaunchKernelGGL((tiled_gemm_kernel), + grid, block, 0, stream, + A, B, C, M, N, K, lda, ldb, ldc, alpha, beta); + } else { + hipLaunchKernelGGL((naive_gemm_kernel), + grid, block, 0, stream, + A, B, C, M, N, K, lda, ldb, ldc, alpha, beta); + } + } else { + if (use_tiled) { + hipLaunchKernelGGL((tiled_gemm_kernel), + grid, block, 0, stream, + A, B, C, M, N, K, lda, ldb, ldc, alpha, beta); + } else { + hipLaunchKernelGGL((naive_gemm_kernel), + grid, block, 0, stream, + A, B, C, M, N, K, lda, ldb, ldc, alpha, beta); + } + } +} + +template +void launch_batched_gemm( + hipStream_t stream, + const T* A, + const T* B, + T* C, + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + int64_t stride_a, + int64_t stride_b, + int64_t stride_c, + int batch_count, + bool trans_a, + bool trans_b, + float alpha, + float beta) { + dim3 block(TILE_N, TILE_M); + dim3 grid((N + TILE_N - 1) / TILE_N, (M + TILE_M - 1) / TILE_M, batch_count); + + if (trans_a && trans_b) { + hipLaunchKernelGGL((batched_gemm_kernel), + grid, block, 0, stream, + A, B, C, M, N, K, lda, ldb, ldc, stride_a, stride_b, stride_c, alpha, beta); + } else if (trans_a && !trans_b) { + hipLaunchKernelGGL((batched_gemm_kernel), + grid, block, 0, stream, + A, B, C, M, N, K, lda, ldb, ldc, stride_a, stride_b, stride_c, alpha, beta); + } else if (!trans_a && trans_b) { + hipLaunchKernelGGL((batched_gemm_kernel), + grid, block, 0, stream, + A, B, C, M, N, K, lda, ldb, ldc, stride_a, stride_b, stride_c, alpha, beta); + } else { + hipLaunchKernelGGL((batched_gemm_kernel), + grid, block, 0, stream, + A, B, C, M, N, K, lda, ldb, ldc, stride_a, stride_b, stride_c, alpha, beta); + } +} + +template +void launch_gather_batched_gemm( + hipStream_t stream, + const T* A, + const T* B, + T* C, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices, + Shape idx_batch_shape, + Strides lhs_idx_strides, + Strides rhs_idx_strides, + int idx_batch_ndim, + Shape a_batch_shape, + Strides a_batch_strides, + int a_batch_ndim, + Shape b_batch_shape, + Strides b_batch_strides, + int b_batch_ndim, + int M, + int N, + int K, + int lda, + int ldb, + int64_t stride_c, + int batch_count, + bool trans_a, + bool trans_b, + float alpha, + float beta) { + dim3 block(TILE_N, TILE_M); + dim3 grid((N + TILE_N - 1) / TILE_N, (M + TILE_M - 1) / TILE_M, batch_count); + + if (trans_a && trans_b) { + hipLaunchKernelGGL( + (gather_batched_gemm_kernel), + grid, + block, + 0, + stream, + A, + B, + C, + lhs_indices, + rhs_indices, + idx_batch_shape, + lhs_idx_strides, + rhs_idx_strides, + idx_batch_ndim, + a_batch_shape, + a_batch_strides, + a_batch_ndim, + b_batch_shape, + b_batch_strides, + b_batch_ndim, + M, + N, + K, + lda, + ldb, + stride_c, + alpha, + beta); + } else if (trans_a && !trans_b) { + hipLaunchKernelGGL( + (gather_batched_gemm_kernel), + grid, + block, + 0, + stream, + A, + B, + C, + lhs_indices, + rhs_indices, + idx_batch_shape, + lhs_idx_strides, + rhs_idx_strides, + idx_batch_ndim, + a_batch_shape, + a_batch_strides, + a_batch_ndim, + b_batch_shape, + b_batch_strides, + b_batch_ndim, + M, + N, + K, + lda, + ldb, + stride_c, + alpha, + beta); + } else if (!trans_a && trans_b) { + hipLaunchKernelGGL( + (gather_batched_gemm_kernel), + grid, + block, + 0, + stream, + A, + B, + C, + lhs_indices, + rhs_indices, + idx_batch_shape, + lhs_idx_strides, + rhs_idx_strides, + idx_batch_ndim, + a_batch_shape, + a_batch_strides, + a_batch_ndim, + b_batch_shape, + b_batch_strides, + b_batch_ndim, + M, + N, + K, + lda, + ldb, + stride_c, + alpha, + beta); + } else { + hipLaunchKernelGGL( + (gather_batched_gemm_kernel), + grid, + block, + 0, + stream, + A, + B, + C, + lhs_indices, + rhs_indices, + idx_batch_shape, + lhs_idx_strides, + rhs_idx_strides, + idx_batch_ndim, + a_batch_shape, + a_batch_strides, + a_batch_ndim, + b_batch_shape, + b_batch_strides, + b_batch_ndim, + M, + N, + K, + lda, + ldb, + stride_c, + alpha, + beta); + } +} + +void naive_gemm( + CommandEncoder& encoder, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + bool b_transposed, + int64_t ldb, + float alpha, + float beta) { + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + + int ldc = N; + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* out_ptr = gpu_ptr(out); + + encoder.launch_kernel([&, a_ptr, b_ptr, out_ptr](hipStream_t stream) { + switch (a.dtype()) { + case float32: + launch_naive_gemm( + stream, + static_cast(a_ptr), + static_cast(b_ptr), + static_cast(out_ptr), + M, N, K, lda, ldb, ldc, + a_transposed, b_transposed, alpha, beta); + break; + case float64: + launch_naive_gemm( + stream, + static_cast(a_ptr), + static_cast(b_ptr), + static_cast(out_ptr), + M, N, K, lda, ldb, ldc, + a_transposed, b_transposed, alpha, beta); + break; + case float16: + launch_naive_gemm<__half>( + stream, + static_cast(a_ptr), + static_cast(b_ptr), + static_cast<__half*>(out_ptr), + M, N, K, lda, ldb, ldc, + a_transposed, b_transposed, alpha, beta); + break; + case bfloat16: + launch_naive_gemm( + stream, + static_cast(a_ptr), + static_cast(b_ptr), + static_cast(out_ptr), + M, N, K, lda, ldb, ldc, + a_transposed, b_transposed, alpha, beta); + break; + default: + throw std::runtime_error("Unsupported dtype for naive GEMM"); + } + }); +} + +void naive_gemm_batched( + CommandEncoder& encoder, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + int64_t stride_a, + bool b_transposed, + int64_t ldb, + int64_t stride_b, + int64_t stride_c, + int batch_count, + float alpha, + float beta) { + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + + int ldc = N; + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* out_ptr = gpu_ptr(out); + + encoder.launch_kernel([&, a_ptr, b_ptr, out_ptr](hipStream_t stream) { + switch (a.dtype()) { + case float32: + launch_batched_gemm( + stream, + static_cast(a_ptr), + static_cast(b_ptr), + static_cast(out_ptr), + M, N, K, lda, ldb, ldc, + stride_a, stride_b, stride_c, batch_count, + a_transposed, b_transposed, alpha, beta); + break; + case float64: + launch_batched_gemm( + stream, + static_cast(a_ptr), + static_cast(b_ptr), + static_cast(out_ptr), + M, N, K, lda, ldb, ldc, + stride_a, stride_b, stride_c, batch_count, + a_transposed, b_transposed, alpha, beta); + break; + case float16: + launch_batched_gemm<__half>( + stream, + static_cast(a_ptr), + static_cast(b_ptr), + static_cast<__half*>(out_ptr), + M, N, K, lda, ldb, ldc, + stride_a, stride_b, stride_c, batch_count, + a_transposed, b_transposed, alpha, beta); + break; + case bfloat16: + launch_batched_gemm( + stream, + static_cast(a_ptr), + static_cast(b_ptr), + static_cast(out_ptr), + M, N, K, lda, ldb, ldc, + stride_a, stride_b, stride_c, batch_count, + a_transposed, b_transposed, alpha, beta); + break; + default: + throw std::runtime_error("Unsupported dtype for batched naive GEMM"); + } + }); +} + +void naive_gemm_gather( + CommandEncoder& encoder, + const array& a, + const array& b, + const array& lhs_indices, + const array& rhs_indices, + array& out, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + bool b_transposed, + int64_t ldb, + float alpha, + float beta) { + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(lhs_indices); + encoder.set_input_array(rhs_indices); + encoder.set_output_array(out); + + auto [idx_batch_shape, idx_batch_strides] = collapse_contiguous_dims( + lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()}); + auto lhs_idx_strides = idx_batch_strides[0]; + auto rhs_idx_strides = idx_batch_strides[1]; + int idx_batch_ndim = idx_batch_shape.size(); + + mlx::core::Shape a_batch_shape{a.shape().begin(), a.shape().end() - 2}; + mlx::core::Strides a_batch_strides{a.strides().begin(), a.strides().end() - 2}; + int a_batch_ndim = a_batch_shape.size(); + + mlx::core::Shape b_batch_shape{b.shape().begin(), b.shape().end() - 2}; + mlx::core::Strides b_batch_strides{b.strides().begin(), b.strides().end() - 2}; + int b_batch_ndim = b_batch_shape.size(); + + auto idx_batch_shape_param = const_param(idx_batch_shape); + auto lhs_idx_strides_param = const_param(lhs_idx_strides); + auto rhs_idx_strides_param = const_param(rhs_idx_strides); + + auto a_batch_shape_param = const_param(a_batch_shape); + auto a_batch_strides_param = const_param(a_batch_strides); + auto b_batch_shape_param = const_param(b_batch_shape); + auto b_batch_strides_param = const_param(b_batch_strides); + + const int64_t stride_c = static_cast(M) * N; + const int batch_count = out.size() / (M * N); + + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* out_ptr = gpu_ptr(out); + const uint32_t* lhs_indices_ptr = gpu_ptr(lhs_indices); + const uint32_t* rhs_indices_ptr = gpu_ptr(rhs_indices); + + encoder.launch_kernel([&, + a_ptr, + b_ptr, + out_ptr, + lhs_indices_ptr, + rhs_indices_ptr](hipStream_t stream) { + switch (a.dtype()) { + case float32: + launch_gather_batched_gemm( + stream, + static_cast(a_ptr), + static_cast(b_ptr), + static_cast(out_ptr), + lhs_indices_ptr, + rhs_indices_ptr, + idx_batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + idx_batch_ndim, + a_batch_shape_param, + a_batch_strides_param, + a_batch_ndim, + b_batch_shape_param, + b_batch_strides_param, + b_batch_ndim, + M, + N, + K, + lda, + ldb, + stride_c, + batch_count, + a_transposed, + b_transposed, + alpha, + beta); + break; + case float64: + launch_gather_batched_gemm( + stream, + static_cast(a_ptr), + static_cast(b_ptr), + static_cast(out_ptr), + lhs_indices_ptr, + rhs_indices_ptr, + idx_batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + idx_batch_ndim, + a_batch_shape_param, + a_batch_strides_param, + a_batch_ndim, + b_batch_shape_param, + b_batch_strides_param, + b_batch_ndim, + M, + N, + K, + lda, + ldb, + stride_c, + batch_count, + a_transposed, + b_transposed, + alpha, + beta); + break; + case float16: + launch_gather_batched_gemm<__half>( + stream, + static_cast(a_ptr), + static_cast(b_ptr), + static_cast<__half*>(out_ptr), + lhs_indices_ptr, + rhs_indices_ptr, + idx_batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + idx_batch_ndim, + a_batch_shape_param, + a_batch_strides_param, + a_batch_ndim, + b_batch_shape_param, + b_batch_strides_param, + b_batch_ndim, + M, + N, + K, + lda, + ldb, + stride_c, + batch_count, + a_transposed, + b_transposed, + alpha, + beta); + break; + case bfloat16: + launch_gather_batched_gemm( + stream, + static_cast(a_ptr), + static_cast(b_ptr), + static_cast(out_ptr), + lhs_indices_ptr, + rhs_indices_ptr, + idx_batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + idx_batch_ndim, + a_batch_shape_param, + a_batch_strides_param, + a_batch_ndim, + b_batch_shape_param, + b_batch_strides_param, + b_batch_ndim, + M, + N, + K, + lda, + ldb, + stride_c, + batch_count, + a_transposed, + b_transposed, + alpha, + beta); + break; + default: + throw std::runtime_error("Unsupported dtype for gathered naive GEMM"); + } + }); +} + +void naive_gemm_with_offset( + CommandEncoder& encoder, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + int64_t a_offset, + bool b_transposed, + int64_t ldb, + int64_t b_offset, + int64_t out_offset, + float alpha, + float beta) { + // Default ldc = N (contiguous output) + naive_gemm_with_offset_ldc( + encoder, a, b, out, M, N, K, + a_transposed, lda, a_offset, + b_transposed, ldb, b_offset, + N, out_offset, alpha, beta); +} + +void naive_gemm_with_offset_ldc( + CommandEncoder& encoder, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + int64_t a_offset, + bool b_transposed, + int64_t ldb, + int64_t b_offset, + int64_t ldc, + int64_t out_offset, + float alpha, + float beta) { + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* out_ptr = gpu_ptr(out); + + encoder.launch_kernel([&, a_ptr, b_ptr, out_ptr](hipStream_t stream) { + switch (a.dtype()) { + case float32: + launch_naive_gemm( + stream, + static_cast(a_ptr) + a_offset, + static_cast(b_ptr) + b_offset, + static_cast(out_ptr) + out_offset, + M, N, K, lda, ldb, ldc, + a_transposed, b_transposed, alpha, beta); + break; + case float64: + launch_naive_gemm( + stream, + static_cast(a_ptr) + a_offset, + static_cast(b_ptr) + b_offset, + static_cast(out_ptr) + out_offset, + M, N, K, lda, ldb, ldc, + a_transposed, b_transposed, alpha, beta); + break; + case float16: + launch_naive_gemm<__half>( + stream, + static_cast(a_ptr) + a_offset, + static_cast(b_ptr) + b_offset, + static_cast<__half*>(out_ptr) + out_offset, + M, N, K, lda, ldb, ldc, + a_transposed, b_transposed, alpha, beta); + break; + case bfloat16: + launch_naive_gemm( + stream, + static_cast(a_ptr) + a_offset, + static_cast(b_ptr) + b_offset, + static_cast(out_ptr) + out_offset, + M, N, K, lda, ldb, ldc, + a_transposed, b_transposed, alpha, beta); + break; + default: + throw std::runtime_error("Unsupported dtype for naive GEMM with offset"); + } + }); +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/gemms/rocblas_gemm.cpp b/mlx/backend/rocm/gemms/rocblas_gemm.cpp new file mode 100644 index 0000000000..4c68e70209 --- /dev/null +++ b/mlx/backend/rocm/gemms/rocblas_gemm.cpp @@ -0,0 +1,549 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/gemms/rocblas_gemm.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/gemms/naive_gemm.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/types/half_types.h" + +#include +#include +#include + +#include +#include +#include + +namespace mlx::core::rocm { + +namespace { + +rocblas_operation to_rocblas_op(bool transpose) { + return transpose ? rocblas_operation_transpose : rocblas_operation_none; +} + +rocblas_datatype to_rocblas_dtype(Dtype dtype) { + switch (dtype) { + case float32: + return rocblas_datatype_f32_r; + case float16: + return rocblas_datatype_f16_r; + case bfloat16: + return rocblas_datatype_bf16_r; + default: + throw std::runtime_error("Unsupported dtype for rocBLAS GEMM"); + } +} + +int parse_non_negative_int_env(const char* env_name, int default_value) { + const char* raw = std::getenv(env_name); + if (raw == nullptr || *raw == '\0') { + return default_value; + } + + char* end = nullptr; + long value = std::strtol(raw, &end, 10); + if (end == raw || *end != '\0' || value < 0) { + return default_value; + } + return static_cast(value); +} + +int gemm_solution_index_f32(bool batched) { + static int single_index = + parse_non_negative_int_env("MLX_ROCM_GEMM_F32_SOLUTION_INDEX", 0); + static int batched_index = parse_non_negative_int_env( + "MLX_ROCM_GEMM_F32_BATCHED_SOLUTION_INDEX", -1); + if (!batched) { + return single_index; + } + return batched_index >= 0 ? batched_index : single_index; +} + +int gemm_solution_index_bf16(bool batched) { + static int single_index = + parse_non_negative_int_env("MLX_ROCM_GEMM_BF16_SOLUTION_INDEX", 0); + static int batched_index = parse_non_negative_int_env( + "MLX_ROCM_GEMM_BF16_BATCHED_SOLUTION_INDEX", -1); + if (!batched) { + return single_index; + } + return batched_index >= 0 ? batched_index : single_index; +} + +} // namespace + +void rocblas_gemm( + CommandEncoder& encoder, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + const array& b, + int ldb, + float beta, + array& c, + int ldc, + Dtype dtype) { + // Check if rocBLAS is available + if (!encoder.device().is_rocblas_available()) { + // Use naive GEMM fallback + naive_gemm( + encoder, + a, + b, + c, + M, + N, + K, + transpose_a, + lda, + transpose_b, + ldb, + alpha, + beta); + return; + } + + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* c_ptr = gpu_ptr(c); + + encoder.launch_kernel([&, a_ptr, b_ptr, c_ptr](hipStream_t stream) { + encoder.device().set_rocblas_stream(stream); + rocblas_handle handle = encoder.device().get_rocblas_handle(); + + rocblas_operation op_a = to_rocblas_op(transpose_a); + rocblas_operation op_b = to_rocblas_op(transpose_b); + + switch (dtype) { + case float32: { + float alpha_f = alpha; + float beta_f = beta; + int solution_index = gemm_solution_index_f32(false); + static std::atomic solution_valid{true}; + + if (solution_index > 0 && + solution_valid.load(std::memory_order_relaxed)) { + rocblas_status status = rocblas_gemm_ex( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_f32_r, + ldb, + a_ptr, + rocblas_datatype_f32_r, + lda, + &beta_f, + c_ptr, + rocblas_datatype_f32_r, + ldc, + c_ptr, + rocblas_datatype_f32_r, + ldc, + rocblas_datatype_f32_r, + rocblas_gemm_algo_solution_index, + solution_index, + 0); + if (status != rocblas_status_success) { + solution_valid.store(false, std::memory_order_relaxed); + rocblas_sgemm( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + ldb, + static_cast(a_ptr), + lda, + &beta_f, + static_cast(c_ptr), + ldc); + } + } else { + rocblas_sgemm( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + ldb, + static_cast(a_ptr), + lda, + &beta_f, + static_cast(c_ptr), + ldc); + } + break; + } + case float16: { + rocblas_half alpha_h, beta_h; + float16_t alpha_f16 = static_cast(alpha); + float16_t beta_f16 = static_cast(beta); + std::memcpy(&alpha_h, &alpha_f16, sizeof(rocblas_half)); + std::memcpy(&beta_h, &beta_f16, sizeof(rocblas_half)); + rocblas_hgemm( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_h, + reinterpret_cast( + static_cast(b_ptr)), + ldb, + reinterpret_cast( + static_cast(a_ptr)), + lda, + &beta_h, + reinterpret_cast(static_cast(c_ptr)), + ldc); + break; + } + case bfloat16: { + float alpha_f = alpha; + float beta_f = beta; + int solution_index = gemm_solution_index_bf16(false); + static std::atomic solution_valid{true}; + + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + if (solution_index > 0 && + solution_valid.load(std::memory_order_relaxed)) { + algo = rocblas_gemm_algo_solution_index; + } else { + solution_index = 0; + } + + rocblas_status status = rocblas_gemm_ex( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_bf16_r, + ldb, + a_ptr, + rocblas_datatype_bf16_r, + lda, + &beta_f, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + rocblas_datatype_f32_r, + algo, + solution_index, + 0); + if (status != rocblas_status_success && + algo == rocblas_gemm_algo_solution_index) { + solution_valid.store(false, std::memory_order_relaxed); + rocblas_gemm_ex( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_bf16_r, + ldb, + a_ptr, + rocblas_datatype_bf16_r, + lda, + &beta_f, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, + 0, + 0); + } + break; + } + default: + throw std::runtime_error("Unsupported dtype for rocBLAS GEMM"); + } + }); +} + +void rocblas_gemm_batched( + CommandEncoder& encoder, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + int64_t stride_a, + const array& b, + int ldb, + int64_t stride_b, + float beta, + array& c, + int ldc, + int64_t stride_c, + int batch_count, + Dtype dtype) { + // Check if rocBLAS is available + if (!encoder.device().is_rocblas_available()) { + // Use naive batched GEMM fallback + naive_gemm_batched( + encoder, + a, + b, + c, + M, + N, + K, + transpose_a, + lda, + stride_a, + transpose_b, + ldb, + stride_b, + stride_c, + batch_count, + alpha, + beta); + return; + } + + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* c_ptr = gpu_ptr(c); + + encoder.launch_kernel([&, a_ptr, b_ptr, c_ptr](hipStream_t stream) { + encoder.device().set_rocblas_stream(stream); + rocblas_handle handle = encoder.device().get_rocblas_handle(); + + rocblas_operation op_a = to_rocblas_op(transpose_a); + rocblas_operation op_b = to_rocblas_op(transpose_b); + + switch (dtype) { + case float32: { + float alpha_f = alpha; + float beta_f = beta; + int solution_index = gemm_solution_index_f32(true); + static std::atomic solution_valid{true}; + + if (solution_index > 0 && + solution_valid.load(std::memory_order_relaxed)) { + rocblas_status status = rocblas_gemm_strided_batched_ex( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_f32_r, + ldb, + stride_b, + a_ptr, + rocblas_datatype_f32_r, + lda, + stride_a, + &beta_f, + c_ptr, + rocblas_datatype_f32_r, + ldc, + stride_c, + c_ptr, + rocblas_datatype_f32_r, + ldc, + stride_c, + batch_count, + rocblas_datatype_f32_r, + rocblas_gemm_algo_solution_index, + solution_index, + 0); + if (status != rocblas_status_success) { + solution_valid.store(false, std::memory_order_relaxed); + rocblas_sgemm_strided_batched( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + ldb, + stride_b, + static_cast(a_ptr), + lda, + stride_a, + &beta_f, + static_cast(c_ptr), + ldc, + stride_c, + batch_count); + } + } else { + rocblas_sgemm_strided_batched( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + ldb, + stride_b, + static_cast(a_ptr), + lda, + stride_a, + &beta_f, + static_cast(c_ptr), + ldc, + stride_c, + batch_count); + } + break; + } + case float16: { + rocblas_half alpha_h, beta_h; + float16_t alpha_f16 = static_cast(alpha); + float16_t beta_f16 = static_cast(beta); + std::memcpy(&alpha_h, &alpha_f16, sizeof(rocblas_half)); + std::memcpy(&beta_h, &beta_f16, sizeof(rocblas_half)); + rocblas_hgemm_strided_batched( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_h, + reinterpret_cast( + static_cast(b_ptr)), + ldb, + stride_b, + reinterpret_cast( + static_cast(a_ptr)), + lda, + stride_a, + &beta_h, + reinterpret_cast(static_cast(c_ptr)), + ldc, + stride_c, + batch_count); + break; + } + case bfloat16: { + float alpha_f = alpha; + float beta_f = beta; + int solution_index = gemm_solution_index_bf16(true); + static std::atomic solution_valid{true}; + + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + if (solution_index > 0 && + solution_valid.load(std::memory_order_relaxed)) { + algo = rocblas_gemm_algo_solution_index; + } else { + solution_index = 0; + } + + rocblas_status status = rocblas_gemm_strided_batched_ex( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_bf16_r, + ldb, + stride_b, + a_ptr, + rocblas_datatype_bf16_r, + lda, + stride_a, + &beta_f, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + stride_c, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + stride_c, + batch_count, + rocblas_datatype_f32_r, + algo, + solution_index, + 0); + if (status != rocblas_status_success && + algo == rocblas_gemm_algo_solution_index) { + solution_valid.store(false, std::memory_order_relaxed); + rocblas_gemm_strided_batched_ex( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_bf16_r, + ldb, + stride_b, + a_ptr, + rocblas_datatype_bf16_r, + lda, + stride_a, + &beta_f, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + stride_c, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + stride_c, + batch_count, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, + 0, + 0); + } + break; + } + default: + throw std::runtime_error("Unsupported dtype for rocBLAS batched GEMM"); + } + }); +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/gemms/rocblas_gemm.h b/mlx/backend/rocm/gemms/rocblas_gemm.h new file mode 100644 index 0000000000..56ac79c454 --- /dev/null +++ b/mlx/backend/rocm/gemms/rocblas_gemm.h @@ -0,0 +1,52 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/rocm/device.h" + +#include + +namespace mlx::core::rocm { + +// rocBLAS GEMM wrapper functions + +void rocblas_gemm( + CommandEncoder& encoder, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + const array& b, + int ldb, + float beta, + array& c, + int ldc, + Dtype dtype); + +void rocblas_gemm_batched( + CommandEncoder& encoder, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + int64_t stride_a, + const array& b, + int ldb, + int64_t stride_b, + float beta, + array& c, + int ldc, + int64_t stride_c, + int batch_count, + Dtype dtype); + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/indexing.hip b/mlx/backend/rocm/indexing.hip new file mode 100644 index 0000000000..69c7463fc7 --- /dev/null +++ b/mlx/backend/rocm/indexing.hip @@ -0,0 +1,1576 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/jit_module.h" +#include "mlx/backend/rocm/device/indexing.hpp" +#include "mlx/backend/rocm/device/binary_ops.hpp" +#include "mlx/backend/rocm/device/utils.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/common/slicing.h" +#include "mlx/backend/common/utils.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include + +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +// General gather kernel - handles arbitrary indexing +template +__global__ void gather_general_kernel( + const T* src, + T* out, + int64_t size, + // Metadata passed BY VALUE (hip_array in kernel args) rather than via device + // pointers. The previous by-pointer form required uploading these to device + // buffers via hipMemcpyAsync; under HIP graph capture those H2D nodes record + // the (transient) host source pointer and read freed memory on replay, + // producing a garbage source offset -> out-of-bounds read -> GPU queue hang + // on RDNA4 (gfx1201). By-value metadata is captured correctly and replays. + hip_array src_shape, + hip_array src_strides, + int32_t src_ndim, + hip_array slice_sizes, + uint32_t slice_size, + hip_array axes, + hip_array indices, + hip_array indices_shape, + hip_array indices_strides, + int32_t idx_ndim) { + int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (out_idx >= size) { + return; + } + + int64_t src_elem = out_idx % slice_size; + int64_t idx_elem = out_idx / slice_size; + + // Compute source location from slice element + int64_t src_loc = 0; + int64_t tmp = src_elem; + for (int i = src_ndim - 1; i >= 0; --i) { + src_loc += (tmp % slice_sizes[i]) * src_strides[i]; + tmp /= slice_sizes[i]; + } + + // Add index contributions + for (int i = 0; i < NIDX; ++i) { + // Compute index location + int64_t idx_loc = 0; + int64_t tmp_idx = idx_elem; + for (int j = idx_ndim - 1; j >= 0; --j) { + idx_loc += (tmp_idx % indices_shape[i * idx_ndim + j]) * indices_strides[i * idx_ndim + j]; + tmp_idx /= indices_shape[i * idx_ndim + j]; + } + + int32_t axis = axes[i]; + IdxT idx_val = indices[i][idx_loc]; + + // Handle negative indices + if (idx_val < 0) { + idx_val += src_shape[axis]; + } + + src_loc += idx_val * src_strides[axis]; + } + + out[out_idx] = src[src_loc]; +} + +// Fast contiguous row gather: out[row, :] = src[idx[row], :] where each row is a +// contiguous block of `row_size` elements. One block per output row, coalesced +// copy, with the source-row base computed once (no per-element index math). +// Covers the common axis-0 gather of a row-contiguous source (e.g. the MoE +// token reorder), which the general kernel does per-element with mod/div loops. +template +__global__ void gather_rows_kernel( + const T* src, + const IdxT* idx, + T* out, + int64_t n_rows, + uint32_t row_size, + int32_t src_dim0) { + int64_t row = blockIdx.x; + if (row >= n_rows) { + return; + } + int64_t r = static_cast(idx[row]); + if (r < 0) { + r += src_dim0; + } + const T* srow = src + r * static_cast(row_size); + T* orow = out + row * static_cast(row_size); + for (uint32_t e = threadIdx.x; e < row_size; e += blockDim.x) { + orow[e] = srow[e]; + } +} + +// Simple gather kernel for axis-based gather (for contiguous arrays) +template +__global__ void gather_axis_kernel( + const T* src, + const IdxT* idx, + T* out, + int64_t idx_size_pre, + int64_t idx_size_axis, + int64_t idx_size_post, + const hip_array shape, + const hip_array src_strides, + const hip_array idx_strides, + int32_t axis, + int32_t axis_size, + int64_t src_stride_axis, + int64_t idx_stride_axis) { + int64_t index = blockIdx.x * blockDim.x + threadIdx.x; + int64_t total = idx_size_pre * idx_size_axis * idx_size_post; + if (index >= total) return; + + // Decompose index into x (post), y (axis), z (pre) coordinates + int64_t x, y, z; + index_to_dims(index, idx_size_axis, idx_size_pre, x, y, z); + + int64_t elem_idx = z * idx_size_post; + + // Compute index location + int64_t idx_loc = y * idx_stride_axis; + if constexpr (IdxC) { + idx_loc += elem_idx * idx_size_axis + x; + } else { + idx_loc += elem_to_loc_nd(elem_idx + x, shape.data_, idx_strides.data_); + } + + // Get index value and handle negative indices + IdxT idx_val = idx[idx_loc]; + if (idx_val < 0) { + idx_val += axis_size; + } + + // Compute source location + int64_t src_loc = idx_val * src_stride_axis; + if constexpr (SrcC) { + src_loc += elem_idx * axis_size + x; + } else { + src_loc += elem_to_loc_nd(elem_idx + x, shape.data_, src_strides.data_); + } + + // Output is always contiguous + int64_t out_idx = y * idx_size_post + elem_idx * idx_size_axis + x; + + out[out_idx] = src[src_loc]; +} + +// Simple scatter kernel for axis-based scatter +template +__global__ void scatter_axis_kernel( + const T* upd, + const IdxT* idx, + T* out, + int64_t idx_size_pre, + int64_t idx_size_axis, + int64_t idx_size_post, + const hip_array shape, + const hip_array upd_strides, + const hip_array idx_strides, + const hip_array out_strides, + int32_t axis, + int32_t axis_size, + int64_t upd_stride_axis, + int64_t idx_stride_axis, + int64_t out_stride_axis) { + int64_t index = blockIdx.x * blockDim.x + threadIdx.x; + int64_t total = idx_size_pre * idx_size_axis * idx_size_post; + if (index >= total) return; + + // Decompose index into x (post), y (axis), z (pre) coordinates + int64_t x, y, z; + index_to_dims(index, idx_size_axis, idx_size_pre, x, y, z); + + int64_t elem_idx = z * idx_size_post; + + // Compute index location + int64_t idx_loc = y * idx_stride_axis; + if constexpr (IdxC) { + idx_loc += elem_idx * idx_size_axis + x; + } else { + idx_loc += elem_to_loc_nd(elem_idx + x, shape.data_, idx_strides.data_); + } + + // Get index value and handle negative indices + IdxT idx_val = idx[idx_loc]; + if (idx_val < 0) { + idx_val += axis_size; + } + + // Compute update location + int64_t upd_loc = y * upd_stride_axis; + if constexpr (UpdC) { + upd_loc += elem_idx * idx_size_axis + x; + } else { + upd_loc += elem_to_loc_nd(elem_idx + x, shape.data_, upd_strides.data_); + } + + // Compute output location + int64_t out_loc = idx_val * out_stride_axis; + out_loc += elem_to_loc_nd(elem_idx + x, shape.data_, out_strides.data_); + + if constexpr (IS_SUM) { + atomicAdd(&out[out_loc], upd[upd_loc]); + } else { + out[out_loc] = upd[upd_loc]; + } +} + +// General scatter kernel - handles arbitrary indexing +template +__global__ void scatter_general_kernel( + const T* upd, + T* out, + int64_t upd_size, + // Metadata passed BY VALUE (hip_array in kernel args) rather than via device + // pointers. The previous by-pointer form required uploading these to device + // buffers via hipMemcpyAsync; under HIP graph capture those H2D nodes record + // the (transient) host source pointer and read freed memory on replay, + // producing a garbage source offset -> out-of-bounds read -> GPU queue hang + // on RDNA4 (gfx1201). By-value metadata is captured correctly and replays. + hip_array upd_shape, + hip_array upd_strides, + int32_t upd_ndim, + int64_t upd_post_idx_size, + hip_array out_shape, + hip_array out_strides, + int32_t out_ndim, + hip_array axes, + hip_array indices, + hip_array indices_shape, + hip_array indices_strides, + int32_t idx_ndim) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= upd_size) { + return; + } + + int64_t out_elem = gid % upd_post_idx_size; + int64_t idx_elem = gid / upd_post_idx_size; + + // Compute output location from out_elem using upd_shape after idx_ndim dimensions + // This matches the CUDA implementation: elem_to_loc(out_elem, upd_shape + IDX_NDIM, out_strides, out_ndim) + int64_t out_loc = 0; + int64_t tmp = out_elem; + for (int i = out_ndim - 1; i >= 0; --i) { + // Use upd_shape[idx_ndim + i] for the shape dimensions after the index dimensions + int32_t dim_size = (idx_ndim + i < upd_ndim) ? upd_shape[idx_ndim + i] : 1; + out_loc += (tmp % dim_size) * out_strides[i]; + tmp /= dim_size; + } + + // Add index contributions + for (int i = 0; i < NIDX; ++i) { + // Compute index location + int64_t idx_loc = 0; + int64_t tmp_idx = idx_elem; + for (int j = idx_ndim - 1; j >= 0; --j) { + idx_loc += (tmp_idx % indices_shape[i * idx_ndim + j]) * indices_strides[i * idx_ndim + j]; + tmp_idx /= indices_shape[i * idx_ndim + j]; + } + + int32_t axis = axes[i]; + IdxT idx_val = indices[i][idx_loc]; + + // Handle negative indices + if (idx_val < 0) { + idx_val += out_shape[axis]; + } + + out_loc += idx_val * out_strides[axis]; + } + + // Compute update location + int64_t upd_loc = 0; + tmp = out_elem + idx_elem * upd_post_idx_size; + for (int i = upd_ndim - 1; i >= 0; --i) { + upd_loc += (tmp % upd_shape[i]) * upd_strides[i]; + tmp /= upd_shape[i]; + } + + T val = upd[upd_loc]; + + // Apply reduce operation + if constexpr (ReduceType == 0) { // Assign + out[out_loc] = val; + } else if constexpr (ReduceType == 1) { // Sum + // Use appropriate atomic based on type + if constexpr (std::is_same_v) { + atomicAdd(&out[out_loc], val); + } else if constexpr (std::is_same_v) { + atomicAdd(&out[out_loc], val); + } else if constexpr (std::is_same_v) { + atomicAdd(&out[out_loc], val); + } else if constexpr (std::is_same_v) { + atomicAdd(reinterpret_cast(&out[out_loc]), + static_cast(val)); + } else if constexpr (std::is_same_v) { + atomicAdd(&out[out_loc], val); + } else { + // Fallback for types without atomic support - use CAS loop + T* addr = &out[out_loc]; + T old_val = *addr; + T new_val; + do { + new_val = old_val + val; + } while (!__hip_atomic_compare_exchange_strong(addr, &old_val, new_val, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + } + } else if constexpr (ReduceType == 2) { // Prod + // Use CAS loop for atomic multiply + if constexpr (std::is_same_v) { + float* addr = &out[out_loc]; + float old_val = *addr; + float new_val; + do { + new_val = old_val * val; + } while (!__hip_atomic_compare_exchange_strong(addr, &old_val, new_val, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + } else if constexpr (std::is_same_v) { + int32_t* addr = &out[out_loc]; + int32_t old_val = *addr; + int32_t new_val; + do { + new_val = old_val * val; + } while (!__hip_atomic_compare_exchange_strong(addr, &old_val, new_val, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + } else { + // Fallback for other types + T* addr = &out[out_loc]; + T old_val = *addr; + T new_val; + do { + new_val = old_val * val; + } while (!__hip_atomic_compare_exchange_strong(addr, &old_val, new_val, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + } + } else if constexpr (ReduceType == 3) { // Max + // Use CAS loop for atomic max + if constexpr (std::is_same_v) { + int32_t* addr = &out[out_loc]; + int32_t old_val = *addr; + while (val > old_val) { + if (__hip_atomic_compare_exchange_strong(addr, &old_val, val, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)) { + break; + } + } + } else if constexpr (std::is_same_v) { + uint32_t* addr = &out[out_loc]; + uint32_t old_val = *addr; + while (val > old_val) { + if (__hip_atomic_compare_exchange_strong(addr, &old_val, val, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)) { + break; + } + } + } else if constexpr (std::is_same_v) { + // Use CAS loop for float max + float* addr = &out[out_loc]; + float old_val = *addr; + while (val > old_val) { + if (__hip_atomic_compare_exchange_strong(addr, &old_val, val, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)) { + break; + } + } + } else { + // Fallback for other types + T* addr = &out[out_loc]; + T old_val = *addr; + while (val > old_val) { + if (__hip_atomic_compare_exchange_strong(addr, &old_val, val, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)) { + break; + } + } + } + } else if constexpr (ReduceType == 4) { // Min + // Use CAS loop for atomic min + if constexpr (std::is_same_v) { + int32_t* addr = &out[out_loc]; + int32_t old_val = *addr; + while (val < old_val) { + if (__hip_atomic_compare_exchange_strong(addr, &old_val, val, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)) { + break; + } + } + } else if constexpr (std::is_same_v) { + uint32_t* addr = &out[out_loc]; + uint32_t old_val = *addr; + while (val < old_val) { + if (__hip_atomic_compare_exchange_strong(addr, &old_val, val, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)) { + break; + } + } + } else if constexpr (std::is_same_v) { + // Use CAS loop for float min + float* addr = &out[out_loc]; + float old_val = *addr; + while (val < old_val) { + if (__hip_atomic_compare_exchange_strong(addr, &old_val, val, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)) { + break; + } + } + } else { + // Fallback for other types + T* addr = &out[out_loc]; + T old_val = *addr; + while (val < old_val) { + if (__hip_atomic_compare_exchange_strong(addr, &old_val, val, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)) { + break; + } + } + } + } +} + +// SliceUpdate kernel: applies Op to combine existing output values with +// update values at computed slice positions. +template < + typename T, + typename IdxT, + typename Op, + bool OUT_ROW_CONTIG, + bool UPD_ROW_CONTIG, + bool UPD_SCALAR, + int NWORK> +__global__ void slice_update_op_kernel( + const T* updates, + T* out, + int64_t update_size, + hip_array update_shape, + hip_array update_strides, + int32_t update_ndim, + hip_array output_strides, + int64_t output_offset) { + Op op; + + IdxT idx = (IdxT(blockIdx.x) * IdxT(blockDim.x) + IdxT(threadIdx.x)) * NWORK; + IdxT out_idx; + IdxT update_idx; + + if constexpr (OUT_ROW_CONTIG) { + out_idx = idx; + } else { + out_idx = elem_to_loc( + idx, update_shape.data_, output_strides.data_, update_ndim); + } + + if constexpr (!UPD_SCALAR) { + if constexpr (UPD_ROW_CONTIG) { + update_idx = idx; + } else { + update_idx = elem_to_loc( + idx, update_shape.data_, update_strides.data_, update_ndim); + } + } else { + update_idx = 0; + } + + out += output_offset; + + for (int j = 0; j < NWORK && idx < update_size; j++) { + out[out_idx] = op(out[out_idx], updates[update_idx]); + idx++; + + if constexpr (OUT_ROW_CONTIG) { + out_idx = idx; + } else { + out_idx += output_strides[update_ndim - 1]; + } + + if constexpr (UPD_ROW_CONTIG) { + update_idx = idx; + } else if constexpr (!UPD_SCALAR) { + update_idx += update_strides[update_ndim - 1]; + } + } +} + +template +__global__ void masked_scatter_offsets_kernel( + const bool* mask, + uint32_t* scatter_offsets, + int64_t mask_batch_size) { + const int64_t batch_idx = static_cast(blockIdx.x); + const int tid = threadIdx.x; + const int64_t batch_base = batch_idx * mask_batch_size; + + __shared__ uint32_t scan_vals[BLOCK_SIZE]; + uint32_t batch_prefix = 0; + + for (int64_t i = 0; i < mask_batch_size; i += BLOCK_SIZE) { + const int64_t mask_idx = i + tid; + const bool in_range = mask_idx < mask_batch_size; + const uint32_t mask_value = + (in_range && mask[batch_base + mask_idx]) ? 1u : 0u; + + scan_vals[tid] = mask_value; + __syncthreads(); + + // In-place inclusive scan for a fixed-size block. + for (int offset = 1; offset < BLOCK_SIZE; offset <<= 1) { + uint32_t add = 0; + if (tid >= offset) { + add = scan_vals[tid - offset]; + } + __syncthreads(); + scan_vals[tid] += add; + __syncthreads(); + } + + if (in_range) { + // Convert the in-block inclusive scan to an exclusive offset. + scatter_offsets[batch_base + mask_idx] = + batch_prefix + (scan_vals[tid] - mask_value); + } + + __syncthreads(); + batch_prefix += scan_vals[BLOCK_SIZE - 1]; + __syncthreads(); + } +} + +template +__global__ void masked_scatter_assign_kernel( + const bool* mask, + const uint32_t* scatter_offsets, + const T* src, + T* out, + int64_t total, + const rocm::hip_array src_shape, + const rocm::hip_array src_strides, + int32_t src_ndim, + int64_t src_batch_size, + int64_t mask_batch_size) { + const int64_t idx = static_cast(blockIdx.x) * blockDim.x + + threadIdx.x; + if (idx >= total || !mask[idx]) { + return; + } + + const uint32_t src_index = scatter_offsets[idx]; + if (static_cast(src_index) >= src_batch_size) { + return; + } + + const int64_t batch_idx = idx / mask_batch_size; + const int64_t src_elem = + batch_idx * src_batch_size + static_cast(src_index); + + if constexpr (SrcContiguous) { + out[idx] = src[src_elem]; + } else { + const int64_t src_loc = rocm::elem_to_loc( + src_elem, src_shape.data_, src_strides.data_, src_ndim); + out[idx] = src[src_loc]; + } +} + +} // namespace rocm + +void Gather::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() > 0); + const auto& src = inputs[0]; + + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + if (out.size() == 0) { + return; + } + + int nidx = inputs.size() - 1; + int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; + + uint32_t slice_size = std::accumulate( + slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies()); + + // Prepare host data for parameters + std::vector h_src_shape(src.shape().begin(), src.shape().end()); + std::vector h_src_strides(src.strides().begin(), src.strides().end()); + std::vector h_slice_sizes(slice_sizes_.begin(), slice_sizes_.end()); + std::vector h_axes(axes_.begin(), axes_.end()); + + // Prepare indices pointers and metadata + std::vector h_indices(std::max(nidx, 1)); + std::vector h_indices_shape(std::max(nidx, 1) * std::max(idx_ndim, 1)); + std::vector h_indices_strides(std::max(nidx, 1) * std::max(idx_ndim, 1)); + + for (int i = 0; i < nidx; ++i) { + h_indices[i] = gpu_ptr(inputs[i + 1]); + for (int j = 0; j < idx_ndim; ++j) { + h_indices_shape[i * idx_ndim + j] = inputs[i + 1].shape(j); + h_indices_strides[i * idx_ndim + j] = inputs[i + 1].strides(j); + } + } + + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + + int64_t total = out.size(); + int block_size = 256; + int num_blocks = (total + block_size - 1) / block_size; + + // Fast path: axis-0 full-row gather of a row-contiguous source with a single + // contiguous (flattened) index -> coalesced per-row copy (gather_rows_kernel). + bool fast_rows = (nidx == 1) && (axes_.size() == 1) && (axes_[0] == 0) && + src.ndim() >= 2 && src.flags().row_contiguous && + inputs[1].flags().contiguous && + ((int)slice_sizes_.size() == (int)src.ndim()) && (slice_sizes_[0] == 1); + if (fast_rows) { + for (int d = 1; d < (int)src.ndim(); ++d) { + if (slice_sizes_[d] != src.shape(d)) { + fast_rows = false; + break; + } + } + } + // Grid is one block per row; keep within the launch limit. + fast_rows = fast_rows && (total / (int64_t)slice_size) <= 0x7fffffffLL; + + // Pass all metadata BY VALUE (see gather_general_kernel) — no device buffers, + // no H2D uploads, so nothing reads stale host memory on HIP graph replay. + auto p_src_shape = const_param(h_src_shape); + auto p_src_strides = const_param(h_src_strides); + auto p_slice_sizes = const_param(h_slice_sizes); + auto p_axes = const_param<8>(h_axes); + auto p_indices_shape = const_param<8 * MAX_NDIM>(h_indices_shape); + auto p_indices_strides = const_param<8 * MAX_NDIM>(h_indices_strides); + int32_t src_ndim_v = static_cast(src.ndim()); + + { + if (fast_rows) { + int64_t n_rows = total / (int64_t)slice_size; + dim3 grid((unsigned int)n_rows); + dim3 blk(256); + Dtype it = inputs[1].dtype(); + int32_t src_dim0 = (int32_t)src.shape(0); + #define LAUNCH_ROWS(T, IdxT) \ + encoder.add_kernel_node((&rocm::gather_rows_kernel), grid, blk, 0, \ + gpu_ptr(src), \ + reinterpret_cast(h_indices[0]), gpu_ptr(out), \ + n_rows, slice_size, src_dim0) + #define ROWS_BY_T(IdxT) \ + switch (out.dtype()) { \ + case float32: LAUNCH_ROWS(float, IdxT); break; \ + case float16: LAUNCH_ROWS(__half, IdxT); break; \ + case bfloat16: LAUNCH_ROWS(hip_bfloat16, IdxT); break; \ + case int32: LAUNCH_ROWS(int32_t, IdxT); break; \ + case int64: LAUNCH_ROWS(int64_t, IdxT); break; \ + case uint32: LAUNCH_ROWS(uint32_t, IdxT); break; \ + case uint64: LAUNCH_ROWS(uint64_t, IdxT); break; \ + case int8: LAUNCH_ROWS(int8_t, IdxT); break; \ + case int16: LAUNCH_ROWS(int16_t, IdxT); break; \ + case uint8: LAUNCH_ROWS(uint8_t, IdxT); break; \ + case uint16: LAUNCH_ROWS(uint16_t, IdxT); break; \ + case bool_: LAUNCH_ROWS(bool, IdxT); break; \ + default: throw std::runtime_error("Unsupported dtype for Gather"); \ + } + if (it == int32 || it == uint32) { + ROWS_BY_T(int32_t); + } else { + ROWS_BY_T(int64_t); + } + #undef ROWS_BY_T + #undef LAUNCH_ROWS + return; + } + // Dispatch based on dtype and number of indices + #define LAUNCH_GATHER(T, IdxT, NIDX) \ + do { \ + rocm::hip_array idx_ptrs; \ + for (int _i = 0; _i < (NIDX); ++_i) \ + idx_ptrs[_i] = reinterpret_cast(h_indices[_i]); \ + encoder.add_kernel_node( \ + (&rocm::gather_general_kernel), \ + dim3(num_blocks), dim3(block_size), 0, \ + gpu_ptr(src), gpu_ptr(out), total, \ + p_src_shape, p_src_strides, src_ndim_v, \ + p_slice_sizes, slice_size, p_axes, \ + idx_ptrs, p_indices_shape, p_indices_strides, idx_ndim); \ + } while (0) + + #define DISPATCH_NIDX(T, IdxT) \ + switch (nidx) { \ + case 0: LAUNCH_GATHER(T, IdxT, 0); break; \ + case 1: LAUNCH_GATHER(T, IdxT, 1); break; \ + case 2: LAUNCH_GATHER(T, IdxT, 2); break; \ + case 3: LAUNCH_GATHER(T, IdxT, 3); break; \ + case 4: LAUNCH_GATHER(T, IdxT, 4); break; \ + default: LAUNCH_GATHER(T, IdxT, 8); break; \ + } + + Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; + + if (idx_dtype == int32 || idx_dtype == uint32) { + switch (out.dtype()) { + case float32: DISPATCH_NIDX(float, int32_t); break; + case float16: DISPATCH_NIDX(__half, int32_t); break; + case bfloat16: DISPATCH_NIDX(hip_bfloat16, int32_t); break; + case int32: DISPATCH_NIDX(int32_t, int32_t); break; + case int64: DISPATCH_NIDX(int64_t, int32_t); break; + case uint32: DISPATCH_NIDX(uint32_t, int32_t); break; + case uint64: DISPATCH_NIDX(uint64_t, int32_t); break; + case int8: DISPATCH_NIDX(int8_t, int32_t); break; + case int16: DISPATCH_NIDX(int16_t, int32_t); break; + case uint8: DISPATCH_NIDX(uint8_t, int32_t); break; + case uint16: DISPATCH_NIDX(uint16_t, int32_t); break; + case bool_: DISPATCH_NIDX(bool, int32_t); break; + default: + throw std::runtime_error("Unsupported dtype for Gather"); + } + } else { + switch (out.dtype()) { + case float32: DISPATCH_NIDX(float, int64_t); break; + case float16: DISPATCH_NIDX(__half, int64_t); break; + case bfloat16: DISPATCH_NIDX(hip_bfloat16, int64_t); break; + case int32: DISPATCH_NIDX(int32_t, int64_t); break; + case int64: DISPATCH_NIDX(int64_t, int64_t); break; + case uint32: DISPATCH_NIDX(uint32_t, int64_t); break; + case uint64: DISPATCH_NIDX(uint64_t, int64_t); break; + case int8: DISPATCH_NIDX(int8_t, int64_t); break; + case int16: DISPATCH_NIDX(int16_t, int64_t); break; + case uint8: DISPATCH_NIDX(uint8_t, int64_t); break; + case uint16: DISPATCH_NIDX(uint16_t, int64_t); break; + case bool_: DISPATCH_NIDX(bool, int64_t); break; + default: + throw std::runtime_error("Unsupported dtype for Gather"); + } + } + + #undef DISPATCH_NIDX + #undef LAUNCH_GATHER + } +} + +void Scatter::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() > 1); + auto& upd = inputs.back(); + + // Copy src into out + CopyType copy_type; + if (inputs[0].data_size() == 1) { + copy_type = CopyType::Scalar; + } else if (inputs[0].flags().row_contiguous) { + copy_type = CopyType::Vector; + } else { + copy_type = CopyType::General; + } + copy_gpu(inputs[0], out, copy_type); + + if (upd.size() == 0) { + return; + } + + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + int nidx = axes_.size(); + int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; + + int32_t upd_post_idx_size = std::accumulate( + upd.shape().begin() + idx_ndim, + upd.shape().end(), + 1, + std::multiplies()); + + // Prepare host data for parameters + std::vector h_upd_shape(upd.shape().begin(), upd.shape().end()); + std::vector h_upd_strides(upd.strides().begin(), upd.strides().end()); + std::vector h_out_shape(out.shape().begin(), out.shape().end()); + std::vector h_out_strides(out.strides().begin(), out.strides().end()); + std::vector h_axes(axes_.begin(), axes_.end()); + + // Prepare indices pointers and metadata + std::vector h_indices(std::max(nidx, 1)); + std::vector h_indices_shape(std::max(nidx, 1) * std::max(idx_ndim, 1)); + std::vector h_indices_strides(std::max(nidx, 1) * std::max(idx_ndim, 1)); + + for (int i = 0; i < nidx; ++i) { + h_indices[i] = gpu_ptr(inputs[i + 1]); + for (int j = 0; j < idx_ndim; ++j) { + h_indices_shape[i * idx_ndim + j] = inputs[i + 1].shape(j); + h_indices_strides[i * idx_ndim + j] = inputs[i + 1].strides(j); + } + } + + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + + int64_t total = upd.size(); + int block_size = 256; + int num_blocks = (total + block_size - 1) / block_size; + + // Pass all metadata BY VALUE (see scatter_general_kernel) — no device buffers, + // no H2D uploads, so nothing reads stale host memory on HIP graph replay. + auto p_upd_shape = const_param(h_upd_shape); + auto p_upd_strides = const_param(h_upd_strides); + auto p_out_shape = const_param(h_out_shape); + auto p_out_strides = const_param(h_out_strides); + auto p_axes = const_param<8>(h_axes); + auto p_indices_shape = const_param<8 * MAX_NDIM>(h_indices_shape); + auto p_indices_strides = const_param<8 * MAX_NDIM>(h_indices_strides); + int32_t upd_ndim_v = static_cast(upd.ndim()); + int32_t out_ndim_v = static_cast(out.ndim()); + + int reduce_type = reduce_type_; // Scatter::ReduceType: Max=0, Min=1, Sum=2, Prod=3, None=4 + // Map to kernel ReduceType: Assign=0, Sum=1, Prod=2, Max=3, Min=4 + int kernel_reduce_type; + switch (reduce_type) { + case 0: kernel_reduce_type = 3; break; // Max + case 1: kernel_reduce_type = 4; break; // Min + case 2: kernel_reduce_type = 1; break; // Sum + case 3: kernel_reduce_type = 2; break; // Prod + case 4: kernel_reduce_type = 0; break; // None -> Assign + default: kernel_reduce_type = 0; break; + } + + { + #define LAUNCH_SCATTER(T, IdxT, NIDX, RT) \ + do { \ + rocm::hip_array idx_ptrs; \ + for (int _i = 0; _i < (NIDX); ++_i) \ + idx_ptrs[_i] = reinterpret_cast(h_indices[_i]); \ + encoder.add_kernel_node( \ + (&rocm::scatter_general_kernel), \ + dim3(num_blocks), dim3(block_size), 0, \ + gpu_ptr(upd), gpu_ptr(out), total, \ + p_upd_shape, p_upd_strides, upd_ndim_v, upd_post_idx_size, \ + p_out_shape, p_out_strides, out_ndim_v, \ + p_axes, idx_ptrs, \ + p_indices_shape, p_indices_strides, idx_ndim); \ + } while (0) + + #define DISPATCH_REDUCE(T, IdxT, NIDX) \ + switch (kernel_reduce_type) { \ + case 0: LAUNCH_SCATTER(T, IdxT, NIDX, 0); break; \ + case 1: LAUNCH_SCATTER(T, IdxT, NIDX, 1); break; \ + case 2: LAUNCH_SCATTER(T, IdxT, NIDX, 2); break; \ + case 3: LAUNCH_SCATTER(T, IdxT, NIDX, 3); break; \ + case 4: LAUNCH_SCATTER(T, IdxT, NIDX, 4); break; \ + default: LAUNCH_SCATTER(T, IdxT, NIDX, 0); break; \ + } + + #define DISPATCH_NIDX(T, IdxT) \ + switch (nidx) { \ + case 0: DISPATCH_REDUCE(T, IdxT, 0); break; \ + case 1: DISPATCH_REDUCE(T, IdxT, 1); break; \ + case 2: DISPATCH_REDUCE(T, IdxT, 2); break; \ + case 3: DISPATCH_REDUCE(T, IdxT, 3); break; \ + default: DISPATCH_REDUCE(T, IdxT, 4); break; \ + } + + Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; + + if (idx_dtype == int32 || idx_dtype == uint32) { + switch (out.dtype()) { + case float32: DISPATCH_NIDX(float, int32_t); break; + case float16: DISPATCH_NIDX(__half, int32_t); break; + case bfloat16: DISPATCH_NIDX(hip_bfloat16, int32_t); break; + case int32: DISPATCH_NIDX(int32_t, int32_t); break; + case int64: DISPATCH_NIDX(int64_t, int32_t); break; + case uint32: DISPATCH_NIDX(uint32_t, int32_t); break; + case uint64: DISPATCH_NIDX(uint64_t, int32_t); break; + case int8: DISPATCH_NIDX(int8_t, int32_t); break; + case int16: DISPATCH_NIDX(int16_t, int32_t); break; + case uint8: DISPATCH_NIDX(uint8_t, int32_t); break; + case uint16: DISPATCH_NIDX(uint16_t, int32_t); break; + case bool_: DISPATCH_NIDX(bool, int32_t); break; + default: + throw std::runtime_error("Unsupported dtype for Scatter"); + } + } else { + switch (out.dtype()) { + case float32: DISPATCH_NIDX(float, int64_t); break; + case float16: DISPATCH_NIDX(__half, int64_t); break; + case bfloat16: DISPATCH_NIDX(hip_bfloat16, int64_t); break; + case int32: DISPATCH_NIDX(int32_t, int64_t); break; + case int64: DISPATCH_NIDX(int64_t, int64_t); break; + case uint32: DISPATCH_NIDX(uint32_t, int64_t); break; + case uint64: DISPATCH_NIDX(uint64_t, int64_t); break; + case int8: DISPATCH_NIDX(int8_t, int64_t); break; + case int16: DISPATCH_NIDX(int16_t, int64_t); break; + case uint8: DISPATCH_NIDX(uint8_t, int64_t); break; + case uint16: DISPATCH_NIDX(uint16_t, int64_t); break; + case bool_: DISPATCH_NIDX(bool, int64_t); break; + default: + throw std::runtime_error("Unsupported dtype for Scatter"); + } + } + + #undef DISPATCH_NIDX + #undef DISPATCH_REDUCE + #undef LAUNCH_SCATTER + } +} + +void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() > 1); + const auto& src = inputs[0]; + const auto& idx = inputs[1]; + + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + if (out.size() == 0) { + return; + } + + encoder.set_input_array(src); + encoder.set_input_array(idx); + encoder.set_output_array(out); + + size_t idx_size_pre = 1; + size_t idx_size_post = 1; + for (int i = 0; i < axis_; ++i) { + idx_size_pre *= idx.shape(i); + } + for (int i = axis_ + 1; i < idx.ndim(); ++i) { + idx_size_post *= idx.shape(i); + } + size_t idx_size_axis = idx.shape(axis_); + + // Create shape and strides with axis dimension removed + int ndim = src.ndim() - 1; + if (ndim == 0) { + ndim = 1; // Ensure at least 1 dimension for elem_to_loc_nd + } + + std::vector shape_vec(ndim, 1); + std::vector src_strides_vec(ndim, 0); + std::vector idx_strides_vec(ndim, 0); + + for (int i = 0, j = 0; i < src.ndim(); ++i) { + if (i != axis_) { + if (j < ndim) { + shape_vec[j] = idx.shape(i); + src_strides_vec[j] = src.strides(i); + idx_strides_vec[j] = idx.strides(i); + } + ++j; + } + } + + // Use const_param to pass shape and strides by value (like CUDA) + auto shape_param = const_param(shape_vec); + auto src_strides_param = const_param(src_strides_vec); + auto idx_strides_param = const_param(idx_strides_vec); + + int64_t src_stride_axis = src.strides(axis_); + int64_t idx_stride_axis = idx.strides(axis_); + int32_t axis_size = src.shape(axis_); + + bool src_contiguous = src.flags().row_contiguous; + bool idx_contiguous = idx.flags().row_contiguous; + + int64_t total = idx_size_pre * idx_size_axis * idx_size_post; + int block_size = 256; + int num_blocks = (total + block_size - 1) / block_size; + + // Dispatch based on ndim, contiguity, and index type + #define LAUNCH_GATHER_KERNEL(T, IdxT, NDIM, SrcC, IdxC) \ + encoder.add_kernel_node( \ + (&rocm::gather_axis_kernel), \ + dim3(num_blocks), dim3(block_size), 0, \ + gpu_ptr(src), gpu_ptr(idx), gpu_ptr(out), \ + idx_size_pre, idx_size_axis, idx_size_post, \ + shape_param, \ + src_strides_param, \ + idx_strides_param, \ + axis_, axis_size, src_stride_axis, idx_stride_axis) + + #define DISPATCH_CONTIGUOUS(T, IdxT, NDIM) \ + if (src_contiguous && idx_contiguous) { \ + LAUNCH_GATHER_KERNEL(T, IdxT, NDIM, true, true); \ + } else if (src_contiguous) { \ + LAUNCH_GATHER_KERNEL(T, IdxT, NDIM, true, false); \ + } else if (idx_contiguous) { \ + LAUNCH_GATHER_KERNEL(T, IdxT, NDIM, false, true); \ + } else { \ + LAUNCH_GATHER_KERNEL(T, IdxT, NDIM, false, false); \ + } + + #define DISPATCH_NDIM(T, IdxT) \ + switch (ndim) { \ + case 0: DISPATCH_CONTIGUOUS(T, IdxT, 1); break; \ + case 1: DISPATCH_CONTIGUOUS(T, IdxT, 1); break; \ + case 2: DISPATCH_CONTIGUOUS(T, IdxT, 2); break; \ + case 3: DISPATCH_CONTIGUOUS(T, IdxT, 3); break; \ + case 4: DISPATCH_CONTIGUOUS(T, IdxT, 4); break; \ + case 5: DISPATCH_CONTIGUOUS(T, IdxT, 5); break; \ + case 6: DISPATCH_CONTIGUOUS(T, IdxT, 6); break; \ + case 7: DISPATCH_CONTIGUOUS(T, IdxT, 7); break; \ + default: DISPATCH_CONTIGUOUS(T, IdxT, 8); break; \ + } + + #define DISPATCH_IDX_TYPE(T) \ + if (idx.dtype() == int32 || idx.dtype() == uint32) { \ + DISPATCH_NDIM(T, int32_t); \ + } else { \ + DISPATCH_NDIM(T, int64_t); \ + } + + switch (src.dtype()) { + case float32: DISPATCH_IDX_TYPE(float); break; + case int32: DISPATCH_IDX_TYPE(int32_t); break; + case uint32: DISPATCH_IDX_TYPE(uint32_t); break; + case int64: DISPATCH_IDX_TYPE(int64_t); break; + case uint64: DISPATCH_IDX_TYPE(uint64_t); break; + case float16: DISPATCH_IDX_TYPE(__half); break; + case bfloat16: DISPATCH_IDX_TYPE(hip_bfloat16); break; + case int8: DISPATCH_IDX_TYPE(int8_t); break; + case uint8: DISPATCH_IDX_TYPE(uint8_t); break; + case int16: DISPATCH_IDX_TYPE(int16_t); break; + case uint16: DISPATCH_IDX_TYPE(uint16_t); break; + case bool_: DISPATCH_IDX_TYPE(bool); break; + default: + throw std::runtime_error("Unsupported dtype for GatherAxis"); + } + + #undef LAUNCH_GATHER_KERNEL + #undef DISPATCH_CONTIGUOUS + #undef DISPATCH_NDIM + #undef DISPATCH_IDX_TYPE +} + +void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() > 2); + const auto& src = inputs[0]; + const auto& idx = inputs[1]; + const auto& upd = inputs[2]; + + // Copy src into out + CopyType copy_type; + if (src.data_size() == 1) { + copy_type = CopyType::Scalar; + } else if (src.flags().row_contiguous) { + copy_type = CopyType::Vector; + } else { + copy_type = CopyType::General; + } + copy_gpu(src, out, copy_type); + + if (upd.size() == 0) { + return; + } + + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + encoder.set_input_array(upd); + encoder.set_input_array(idx); + encoder.set_output_array(out); + + size_t idx_size_pre = 1; + size_t idx_size_post = 1; + for (int i = 0; i < axis_; ++i) { + idx_size_pre *= idx.shape(i); + } + for (int i = axis_ + 1; i < idx.ndim(); ++i) { + idx_size_post *= idx.shape(i); + } + size_t idx_size_axis = idx.shape(axis_); + + // Create shape and strides with axis dimension removed + int ndim = idx.ndim() - 1; + if (ndim == 0) { + ndim = 1; // Ensure at least 1 dimension for elem_to_loc_nd + } + + std::vector shape_vec(ndim, 1); + std::vector upd_strides_vec(ndim, 0); + std::vector idx_strides_vec(ndim, 0); + std::vector out_strides_vec(ndim, 0); + + for (int i = 0, j = 0; i < idx.ndim(); ++i) { + if (i != axis_) { + if (j < ndim) { + shape_vec[j] = idx.shape(i); + upd_strides_vec[j] = upd.strides(i); + idx_strides_vec[j] = idx.strides(i); + out_strides_vec[j] = out.strides(i); + } + ++j; + } + } + + // Use const_param to pass shape and strides by value + auto shape_param = const_param(shape_vec); + auto upd_strides_param = const_param(upd_strides_vec); + auto idx_strides_param = const_param(idx_strides_vec); + auto out_strides_param = const_param(out_strides_vec); + + int64_t upd_stride_axis = upd.strides(axis_); + int64_t idx_stride_axis = idx.strides(axis_); + int64_t out_stride_axis = out.strides(axis_); + int32_t axis_size = out.shape(axis_); + + bool upd_contiguous = upd.flags().row_contiguous; + bool idx_contiguous = idx.flags().row_contiguous; + + int64_t total = idx_size_pre * idx_size_axis * idx_size_post; + int block_size = 256; + int num_blocks = (total + block_size - 1) / block_size; + + bool is_sum = (reduce_type_ == ScatterAxis::Sum); + + #define LAUNCH_SCATTER_KERNEL(T, IdxT, IS_SUM, NDIM, UpdC, IdxC) \ + encoder.add_kernel_node( \ + (&rocm::scatter_axis_kernel), \ + dim3(num_blocks), dim3(block_size), 0, \ + gpu_ptr(upd), gpu_ptr(idx), gpu_ptr(out), \ + idx_size_pre, idx_size_axis, idx_size_post, \ + shape_param, \ + upd_strides_param, \ + idx_strides_param, \ + out_strides_param, \ + axis_, axis_size, upd_stride_axis, idx_stride_axis, out_stride_axis) + + #define DISPATCH_CONTIGUOUS(T, IdxT, IS_SUM, NDIM) \ + if (upd_contiguous && idx_contiguous) { \ + LAUNCH_SCATTER_KERNEL(T, IdxT, IS_SUM, NDIM, true, true); \ + } else if (upd_contiguous) { \ + LAUNCH_SCATTER_KERNEL(T, IdxT, IS_SUM, NDIM, true, false); \ + } else if (idx_contiguous) { \ + LAUNCH_SCATTER_KERNEL(T, IdxT, IS_SUM, NDIM, false, true); \ + } else { \ + LAUNCH_SCATTER_KERNEL(T, IdxT, IS_SUM, NDIM, false, false); \ + } + + #define DISPATCH_NDIM(T, IdxT, IS_SUM) \ + switch (ndim) { \ + case 0: DISPATCH_CONTIGUOUS(T, IdxT, IS_SUM, 1); break; \ + case 1: DISPATCH_CONTIGUOUS(T, IdxT, IS_SUM, 1); break; \ + case 2: DISPATCH_CONTIGUOUS(T, IdxT, IS_SUM, 2); break; \ + case 3: DISPATCH_CONTIGUOUS(T, IdxT, IS_SUM, 3); break; \ + case 4: DISPATCH_CONTIGUOUS(T, IdxT, IS_SUM, 4); break; \ + case 5: DISPATCH_CONTIGUOUS(T, IdxT, IS_SUM, 5); break; \ + case 6: DISPATCH_CONTIGUOUS(T, IdxT, IS_SUM, 6); break; \ + case 7: DISPATCH_CONTIGUOUS(T, IdxT, IS_SUM, 7); break; \ + default: DISPATCH_CONTIGUOUS(T, IdxT, IS_SUM, 8); break; \ + } + + #define DISPATCH_IDX_TYPE(T, IS_SUM) \ + if (idx.dtype() == int32 || idx.dtype() == uint32) { \ + DISPATCH_NDIM(T, int32_t, IS_SUM); \ + } else { \ + DISPATCH_NDIM(T, int64_t, IS_SUM); \ + } + + if (is_sum) { + // Note: atomicAdd only supports float32 and float64 on ROCm + // float16/bfloat16 would need custom atomic implementations + switch (upd.dtype()) { + case float32: DISPATCH_IDX_TYPE(float, true); break; + default: + throw std::runtime_error("Unsupported dtype for ScatterAxis Sum (only float32 supported)"); + } + } else { + switch (upd.dtype()) { + case float32: DISPATCH_IDX_TYPE(float, false); break; + case float16: DISPATCH_IDX_TYPE(__half, false); break; + case bfloat16: DISPATCH_IDX_TYPE(hip_bfloat16, false); break; + case int32: DISPATCH_IDX_TYPE(int32_t, false); break; + case int64: DISPATCH_IDX_TYPE(int64_t, false); break; + case uint32: DISPATCH_IDX_TYPE(uint32_t, false); break; + case uint64: DISPATCH_IDX_TYPE(uint64_t, false); break; + case int8: DISPATCH_IDX_TYPE(int8_t, false); break; + case int16: DISPATCH_IDX_TYPE(int16_t, false); break; + case uint8: DISPATCH_IDX_TYPE(uint8_t, false); break; + case uint16: DISPATCH_IDX_TYPE(uint16_t, false); break; + case bool_: DISPATCH_IDX_TYPE(bool, false); break; + default: + throw std::runtime_error("Unsupported dtype for ScatterAxis Assign"); + } + } + + #undef LAUNCH_SCATTER_KERNEL + #undef DISPATCH_CONTIGUOUS + #undef DISPATCH_NDIM + #undef DISPATCH_IDX_TYPE +} + +void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + if (out.size() == 0) { + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + + if (upd.size() == 0) { + out.copy_shared_buffer(in); + return; + } + + // Donate the input buffer when uniquely owned, else copy. + bool can_donate = in.data_shared_ptr() != nullptr && + in.data_shared_ptr().use_count() == 1 && in.flags().contiguous && + in.data_size() == in.size(); + if (can_donate) { + out.copy_shared_buffer(in); + } else { + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); + } + + // Calculate out strides, initial offset + auto [data_offset, out_strides] = + prepare_slice(out, start_indices_, strides_); + + // Do copy for None reduce type + if (reduce_type_ == SliceUpdate::None) { + copy_gpu_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const Shape& data_shape = */ upd.shape(), + /* const Strides& i_strides = */ upd.strides(), + /* const Strides& o_strides = */ out_strides, + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ data_offset, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ stream()); + return; + } + + // For reduce types (Sum/Prod/Max/Min), launch a kernel + auto [shape, strides] = + collapse_contiguous_dims(upd.shape(), {upd.strides(), out_strides}); + int nwork = 1; + if (shape.back() % 4 == 0) { + nwork = 4; + } else if (shape.back() % 2 == 0) { + nwork = 2; + } + + auto [ds, rc, cc] = check_contiguity(shape, strides[1]); + bool upd_contiguous = upd.flags().row_contiguous; + bool upd_scalar = upd.data_size() == 1; + bool out_contiguous = rc; + + int ndim = shape.size(); + + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + encoder.set_input_array(upd); + encoder.set_output_array(out); + + auto shape_param = const_param(shape); + auto upd_strides_param = const_param(strides[0]); + auto out_strides_param = const_param(strides[1]); + + int64_t update_size = upd.size(); + int block_size = 256; + int64_t adjusted_size = (update_size + nwork - 1) / nwork; + int num_blocks = static_cast( + std::min((adjusted_size + block_size - 1) / block_size, (int64_t)65535)); + + // Plain local: a structured binding (data_offset) cannot be captured by the + // kernel-launch macro under C++17. + int64_t data_offset_v = data_offset; + + #define SLICE_UPDATE_LAUNCH(T, Op, OUT_C, UPD_C, UPD_S, NWORK_VAL) \ + encoder.add_kernel_node( \ + (&rocm::slice_update_op_kernel), \ + dim3(num_blocks), dim3(block_size), 0, \ + gpu_ptr(upd), gpu_ptr(out), update_size, \ + shape_param, upd_strides_param, ndim, \ + out_strides_param, data_offset_v) + + // Dispatch helper for NWORK + #define DISPATCH_NWORK(T, Op, OUT_C, UPD_C, UPD_S) \ + switch (nwork) { \ + case 4: SLICE_UPDATE_LAUNCH(T, Op, OUT_C, UPD_C, UPD_S, 4); break; \ + case 2: SLICE_UPDATE_LAUNCH(T, Op, OUT_C, UPD_C, UPD_S, 2); break; \ + default: SLICE_UPDATE_LAUNCH(T, Op, OUT_C, UPD_C, UPD_S, 1); break; \ + } + + // Dispatch helper for contiguity flags + #define DISPATCH_CONTIG(T, Op) \ + if (upd_scalar) { \ + if (out_contiguous) { \ + DISPATCH_NWORK(T, Op, true, false, true); \ + } else { \ + DISPATCH_NWORK(T, Op, false, false, true); \ + } \ + } else if (upd_contiguous && out_contiguous) { \ + DISPATCH_NWORK(T, Op, true, true, false); \ + } else if (upd_contiguous) { \ + DISPATCH_NWORK(T, Op, false, true, false); \ + } else if (out_contiguous) { \ + DISPATCH_NWORK(T, Op, true, false, false); \ + } else { \ + DISPATCH_NWORK(T, Op, false, false, false); \ + } + + // Dispatch helper for reduce type + #define DISPATCH_SLICE_OP(T) \ + switch (reduce_type_) { \ + case SliceUpdate::Max: DISPATCH_CONTIG(T, rocm::Maximum); break; \ + case SliceUpdate::Min: DISPATCH_CONTIG(T, rocm::Minimum); break; \ + case SliceUpdate::Sum: DISPATCH_CONTIG(T, rocm::Add); break; \ + case SliceUpdate::Prod: DISPATCH_CONTIG(T, rocm::Multiply); break; \ + default: \ + throw std::runtime_error("SliceUpdate: unsupported reduce type"); \ + } + + switch (out.dtype()) { + case float32: DISPATCH_SLICE_OP(float); break; + case float16: DISPATCH_SLICE_OP(__half); break; + case bfloat16: DISPATCH_SLICE_OP(hip_bfloat16); break; + case int32: DISPATCH_SLICE_OP(int32_t); break; + case int64: DISPATCH_SLICE_OP(int64_t); break; + case uint32: DISPATCH_SLICE_OP(uint32_t); break; + case uint64: DISPATCH_SLICE_OP(uint64_t); break; + case int8: DISPATCH_SLICE_OP(int8_t); break; + case int16: DISPATCH_SLICE_OP(int16_t); break; + case uint8: DISPATCH_SLICE_OP(uint8_t); break; + case uint16: DISPATCH_SLICE_OP(uint16_t); break; + case bool_: DISPATCH_SLICE_OP(bool); break; + default: + throw std::runtime_error("Unsupported dtype for SliceUpdate"); + } + + #undef DISPATCH_SLICE_OP + #undef DISPATCH_CONTIG + #undef DISPATCH_NWORK + #undef SLICE_UPDATE_LAUNCH +} + +void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 3); + + const auto& dst = inputs[0]; + const auto& mask = inputs[1]; + const auto& src = inputs[2]; + + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + const int64_t total = mask.size(); + const CopyType copy_type = (total == 1) + ? CopyType::Scalar + : (dst.flags().row_contiguous ? CopyType::Vector : CopyType::General); + copy_gpu(dst, out, copy_type, s); + if (total == 0) { + return; + } + + array mask_flat = flatten_in_eval(mask, 1, -1, s); + if (mask_flat.data() != mask.data()) { + encoder.add_temporary(mask_flat); + } + if (!mask_flat.flags().row_contiguous) { + mask_flat = contiguous_copy_gpu(mask_flat, s); + encoder.add_temporary(mask_flat); + } + + array scatter_offsets(mask_flat.shape(), uint32, nullptr, {}); + scatter_offsets.set_data(mlx::core::rocm::malloc_async(scatter_offsets.nbytes(), encoder)); + encoder.add_temporary(scatter_offsets); + + const int64_t batch_count = mask_flat.shape(0); + const int64_t mask_batch_size = total / batch_count; + const int64_t src_batch_size = src.size() / batch_count; + + std::vector src_shape(src.shape().begin(), src.shape().end()); + std::vector src_strides(src.strides().begin(), src.strides().end()); + auto src_shape_param = const_param(src_shape); + auto src_strides_param = const_param(src_strides); + const bool src_contiguous = src.flags().row_contiguous; + + constexpr int block_size = 256; + const auto offset_grid = dim3(static_cast(batch_count)); + const auto offset_block = dim3(block_size); + const int64_t num_blocks = (total + block_size - 1) / block_size; + const int32_t src_ndim_v = static_cast(src.ndim()); + + // Offsets kernel: writes scatter_offsets (registered as output so the + // following assign kernel records a graph dependency on it). + encoder.set_input_array(mask_flat); + encoder.set_output_array(scatter_offsets); + encoder.add_kernel_node( + &rocm::masked_scatter_offsets_kernel, + offset_grid, + offset_block, + 0, + gpu_ptr(mask_flat), + gpu_ptr(scatter_offsets), + mask_batch_size); + + // Assign kernel: reads mask_flat, scatter_offsets, src; writes out. + encoder.set_input_array(mask_flat); + encoder.set_input_array(scatter_offsets); + encoder.set_input_array(src); + encoder.set_output_array(out); + +#define LAUNCH_MASKED_SCATTER(T, SrcC) \ + encoder.add_kernel_node( \ + (&rocm::masked_scatter_assign_kernel), \ + dim3(static_cast(num_blocks)), \ + dim3(block_size), \ + 0, \ + gpu_ptr(mask_flat), \ + gpu_ptr(scatter_offsets), \ + gpu_ptr(src), \ + gpu_ptr(out), \ + total, \ + src_shape_param, \ + src_strides_param, \ + src_ndim_v, \ + src_batch_size, \ + mask_batch_size) + +#define DISPATCH_MASKED_SCATTER(T) \ + if (src_contiguous) { \ + LAUNCH_MASKED_SCATTER(T, true); \ + } else { \ + LAUNCH_MASKED_SCATTER(T, false); \ + } + + switch (out.dtype()) { + case bool_: + DISPATCH_MASKED_SCATTER(bool); + break; + case uint8: + DISPATCH_MASKED_SCATTER(uint8_t); + break; + case uint16: + DISPATCH_MASKED_SCATTER(uint16_t); + break; + case uint32: + DISPATCH_MASKED_SCATTER(uint32_t); + break; + case uint64: + DISPATCH_MASKED_SCATTER(uint64_t); + break; + case int8: + DISPATCH_MASKED_SCATTER(int8_t); + break; + case int16: + DISPATCH_MASKED_SCATTER(int16_t); + break; + case int32: + DISPATCH_MASKED_SCATTER(int32_t); + break; + case int64: + DISPATCH_MASKED_SCATTER(int64_t); + break; + case float16: + DISPATCH_MASKED_SCATTER(__half); + break; + case float32: + DISPATCH_MASKED_SCATTER(float); + break; + case float64: + DISPATCH_MASKED_SCATTER(double); + break; + case bfloat16: + DISPATCH_MASKED_SCATTER(hip_bfloat16); + break; + case complex64: + DISPATCH_MASKED_SCATTER(hipFloatComplex); + break; + default: + throw std::runtime_error("Unsupported dtype for MaskedScatter"); + } + +#undef DISPATCH_MASKED_SCATTER +#undef LAUNCH_MASKED_SCATTER +} + +// In-place device-position KV kernels for HIP-graph decode. + +__global__ void _kv_pos_inc(int* p, int delta) { p[0] += delta; } + +// In-place increment of an int32 [1] device scalar (the decode position). +void gpu_kv_pos_increment(array& pos, int delta) { + auto& enc = rocm::get_command_encoder(default_stream(default_device())); + int* p = gpu_ptr(pos); + enc.set_input_array(pos); + enc.set_output_array(pos); + enc.launch_kernel([p, delta](hipStream_t s) { + hipLaunchKernelGGL(_kv_pos_inc, dim3(1), dim3(1), 0, s, p, delta); + }); +} + +// In-place set of an int32 [1] device scalar to an absolute value. +__global__ void _kv_pos_set(int* p, int v) { p[0] = v; } +void gpu_kv_pos_set(array& pos, int v) { + auto& enc = rocm::get_command_encoder(default_stream(default_device())); + int* p = gpu_ptr(pos); + enc.set_output_array(pos); + enc.launch_kernel([p, v](hipStream_t s) { + hipLaunchKernelGGL(_kv_pos_set, dim3(1), dim3(1), 0, s, p, v); + }); +} + +// In-place write of one [B,H,1,D] row into a [B,H,CAP,D] KV buffer at pos[0]. +template +__global__ void _kv_row_write( + T* kv, const T* row, const int* pos, int B, int H, int CAP, int D) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; // over B*H*D + int n = B * H * D; + if (idx >= n) return; + int d = idx % D; + int h = (idx / D) % H; + int b = idx / (H * D); + int p = pos[0]; + kv[((b * H + h) * (long)CAP + p) * D + d] = row[(b * H + h) * D + d]; +} + +void gpu_kv_row_write(array& kv, const array& row, const array& pos) { + auto& enc = rocm::get_command_encoder(default_stream(default_device())); + int B = kv.shape(0), H = kv.shape(1), CAP = kv.shape(2), D = kv.shape(3); + int n = B * H * D; + int threads = 256, blocks = (n + threads - 1) / threads; + const int* pp = gpu_ptr(const_cast(pos)); + enc.set_input_array(row); + enc.set_input_array(pos); + enc.set_output_array(kv); + auto launch = [&](auto* tag) { + using T = std::remove_pointer_t; + T* kvp = gpu_ptr(kv); + const T* rp = gpu_ptr(const_cast(row)); + enc.launch_kernel([=](hipStream_t s) { + hipLaunchKernelGGL((_kv_row_write), dim3(blocks), dim3(threads), 0, s, + kvp, rp, pp, B, H, CAP, D); + }); + }; + switch (kv.dtype()) { + case float32: launch((float*)nullptr); break; + case bfloat16: launch((hip_bfloat16*)nullptr); break; + case float16: launch((__half*)nullptr); break; + default: throw std::runtime_error("gpu_kv_row_write: unsupported dtype"); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/iterators/general_iterator.hpp b/mlx/backend/rocm/iterators/general_iterator.hpp new file mode 100644 index 0000000000..ec3a844412 --- /dev/null +++ b/mlx/backend/rocm/iterators/general_iterator.hpp @@ -0,0 +1,153 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::rocm { + +template +struct GeneralIterator { + using difference_type = ptrdiff_t; + using value_type = IdxType; + using pointer = IdxType*; + using reference = IdxType&; + using iterator_category = std::random_access_iterator_tag; + + const IdxType* base_ptr; + IdxType offset; + const int* shape; + const size_t* strides; + int ndim; + size_t size; + + __device__ GeneralIterator( + const IdxType* base_ptr, + IdxType offset, + const int* shape, + const size_t* strides, + int ndim, + size_t size) + : base_ptr(base_ptr), + offset(offset), + shape(shape), + strides(strides), + ndim(ndim), + size(size) {} + + __device__ GeneralIterator operator+(difference_type n) const { + return GeneralIterator(base_ptr, offset + n, shape, strides, ndim, size); + } + + __device__ GeneralIterator operator-(difference_type n) const { + return GeneralIterator(base_ptr, offset - n, shape, strides, ndim, size); + } + + __device__ difference_type operator-(const GeneralIterator& other) const { + return offset - other.offset; + } + + __device__ GeneralIterator& operator+=(difference_type n) { + offset += n; + return *this; + } + + __device__ GeneralIterator& operator-=(difference_type n) { + offset -= n; + return *this; + } + + __device__ GeneralIterator& operator++() { + ++offset; + return *this; + } + + __device__ GeneralIterator operator++(int) { + GeneralIterator temp = *this; + ++offset; + return temp; + } + + __device__ GeneralIterator& operator--() { + --offset; + return *this; + } + + __device__ GeneralIterator operator--(int) { + GeneralIterator temp = *this; + --offset; + return temp; + } + + __device__ bool operator==(const GeneralIterator& other) const { + return offset == other.offset; + } + + __device__ bool operator!=(const GeneralIterator& other) const { + return offset != other.offset; + } + + __device__ bool operator<(const GeneralIterator& other) const { + return offset < other.offset; + } + + __device__ bool operator>(const GeneralIterator& other) const { + return offset > other.offset; + } + + __device__ bool operator<=(const GeneralIterator& other) const { + return offset <= other.offset; + } + + __device__ bool operator>=(const GeneralIterator& other) const { + return offset >= other.offset; + } + + __device__ IdxType operator*() const { + return base_ptr[elem_to_loc(offset, shape, strides, ndim)]; + } + + __device__ IdxType operator[](difference_type n) const { + return base_ptr[elem_to_loc(offset + n, shape, strides, ndim)]; + } + + private: + __device__ size_t elem_to_loc( + size_t elem, + const int* shape, + const size_t* strides, + int ndim) const { + size_t loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + auto q_and_r = div(elem, static_cast(shape[i])); + loc += q_and_r.rem * strides[i]; + elem = q_and_r.quot; + } + return loc; + } + + __device__ div_t div(size_t numer, size_t denom) const { + div_t result; + result.quot = numer / denom; + result.rem = numer % denom; + return result; + } +}; + +template +__device__ std::pair, GeneralIterator> +make_general_iterators( + const IdxType* base_ptr, + size_t size, + const int* shape, + const size_t* strides, + int ndim) { + auto begin = + GeneralIterator(base_ptr, 0, shape, strides, ndim, size); + auto end = + GeneralIterator(base_ptr, size, shape, strides, ndim, size); + return std::make_pair(begin, end); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/iterators/strided_iterator.hpp b/mlx/backend/rocm/iterators/strided_iterator.hpp new file mode 100644 index 0000000000..a4fd104a58 --- /dev/null +++ b/mlx/backend/rocm/iterators/strided_iterator.hpp @@ -0,0 +1,106 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::rocm { + +template +struct StridedIterator { + using difference_type = ptrdiff_t; + using value_type = T; + using pointer = T*; + using reference = T&; + using iterator_category = std::random_access_iterator_tag; + + T* ptr; + size_t stride; + + __device__ StridedIterator(T* ptr, size_t stride) + : ptr(ptr), stride(stride) {} + + __device__ StridedIterator operator+(difference_type n) const { + return StridedIterator(ptr + n * stride, stride); + } + + __device__ StridedIterator operator-(difference_type n) const { + return StridedIterator(ptr - n * stride, stride); + } + + __device__ difference_type operator-(const StridedIterator& other) const { + return (ptr - other.ptr) / stride; + } + + __device__ StridedIterator& operator+=(difference_type n) { + ptr += n * stride; + return *this; + } + + __device__ StridedIterator& operator-=(difference_type n) { + ptr -= n * stride; + return *this; + } + + __device__ StridedIterator& operator++() { + ptr += stride; + return *this; + } + + __device__ StridedIterator operator++(int) { + StridedIterator temp = *this; + ptr += stride; + return temp; + } + + __device__ StridedIterator& operator--() { + ptr -= stride; + return *this; + } + + __device__ StridedIterator operator--(int) { + StridedIterator temp = *this; + ptr -= stride; + return temp; + } + + __device__ bool operator==(const StridedIterator& other) const { + return ptr == other.ptr; + } + + __device__ bool operator!=(const StridedIterator& other) const { + return ptr != other.ptr; + } + + __device__ bool operator<(const StridedIterator& other) const { + return ptr < other.ptr; + } + + __device__ bool operator>(const StridedIterator& other) const { + return ptr > other.ptr; + } + + __device__ bool operator<=(const StridedIterator& other) const { + return ptr <= other.ptr; + } + + __device__ bool operator>=(const StridedIterator& other) const { + return ptr >= other.ptr; + } + + __device__ T& operator*() const { + return *ptr; + } + + __device__ T& operator[](difference_type n) const { + return *(ptr + n * stride); + } +}; + +template +__device__ StridedIterator make_strided_iterator(T* ptr, size_t stride) { + return StridedIterator(ptr, stride); +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/jit_module.cpp b/mlx/backend/rocm/jit_module.cpp new file mode 100644 index 0000000000..8fa1b99771 --- /dev/null +++ b/mlx/backend/rocm/jit_module.cpp @@ -0,0 +1,500 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/jit_module.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/utils.h" +#include "mlx/version.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace mlx::core::rocm { + +namespace { + +// RAII helper that silences stderr during hipRTC compilation. +// AMD's comgr library (used by hipRTC) unconditionally writes preprocessed +// source and internal diagnostics to fd 2. This floods the terminal with +// thousands of lines of compiler-internal defines every time a new fused +// kernel is JIT-compiled. +struct StderrSuppressor { + StderrSuppressor() { + saved_fd_ = dup(STDERR_FILENO); + if (saved_fd_ >= 0) { + int devnull = open("/dev/null", O_WRONLY); + if (devnull >= 0) { + dup2(devnull, STDERR_FILENO); + close(devnull); + active_ = true; + } else { + // Could not open /dev/null — leave stderr alone. + close(saved_fd_); + saved_fd_ = -1; + } + } + } + ~StderrSuppressor() { + restore(); + } + void restore() { + if (active_) { + fflush(stderr); + dup2(saved_fd_, STDERR_FILENO); + close(saved_fd_); + saved_fd_ = -1; + active_ = false; + } + } + StderrSuppressor(const StderrSuppressor&) = delete; + StderrSuppressor& operator=(const StderrSuppressor&) = delete; + + private: + int saved_fd_ = -1; + bool active_ = false; +}; + +// Extract the last N lines from a compiler log. AMD comgr prepends the +// entire preprocessed source to the error log, making it enormous. The +// actual compiler errors are always at the end. +std::string tail_lines(const std::string& text, size_t n = 60) { + if (text.empty()) { + return text; + } + // Walk backwards to find the start of the last `n` lines. + size_t count = 0; + size_t pos = text.size(); + while (pos > 0 && count < n) { + --pos; + if (text[pos] == '\n') { + ++count; + } + } + if (pos > 0) { + // Skip past the newline we stopped on. + return "... [preprocessed source truncated] ...\n" + text.substr(pos + 1); + } + return text; +} + +// Truncate long kernel names to avoid exceeding filesystem 255-byte limit. +// Names > 200 chars are replaced with a prefix + hash. +std::string safe_filename(const std::string& name) { + constexpr size_t kMaxLen = 200; + if (name.size() <= kMaxLen) { + return name; + } + auto h = std::hash{}(name); + std::ostringstream oss; + oss << name.substr(0, 64) << "_" << std::hex << h; + return oss.str(); +} + +#define CHECK_HIPRTC_ERROR(cmd) check_hiprtc_error(#cmd, (cmd)) + +void check_hiprtc_error(const char* name, hiprtcResult err) { + if (err != HIPRTC_SUCCESS) { + std::ostringstream oss; + oss << name << " failed: " << hiprtcGetErrorString(err); + throw std::runtime_error(oss.str()); + } +} + +// Return the location of the ROCm toolkit. +const std::string& rocm_home() { + static std::string home = []() -> std::string { + const char* home = std::getenv("ROCM_HOME"); + if (home) { + return home; + } + home = std::getenv("ROCM_PATH"); + if (home) { + return home; + } +#if defined(__linux__) + home = "/opt/rocm"; + if (std::filesystem::exists(home)) { + return home; + } +#endif + throw std::runtime_error( + "Environment variable ROCM_HOME or ROCM_PATH is not set."); + }(); + return home; +} + +std::string get_gpu_arch(); + +// Get the cache directory for storing compiled results. The GPU arch is part of +// the path so that, on a multi-GPU host (e.g. an integrated gfx1151 APU + a +// discrete gfx1201 R9700), kernels compiled for one arch are never loaded on the +// other — which fails with "no kernel image" or, worse, silently hangs. +// +// Resolve per CURRENT-device arch and memoize per arch. A single static path +// would freeze the arch to whatever device was current at the FIRST call (the +// default device 0 / APU, e.g. from a load-time static initializer), then serve +// that arch's cache dir to kernels compiled for the OTHER device — defeating the +// whole purpose on a multi-GPU host. +const std::filesystem::path& hsaco_cache_dir() { + static std::mutex mtx; + static std::map by_arch; + std::string arch = get_gpu_arch(); + std::lock_guard lk(mtx); + if (auto it = by_arch.find(arch); it != by_arch.end()) { + return it->second; + } + std::filesystem::path cache; + if (auto c = std::getenv("MLX_HSACO_CACHE_DIR"); c) { + cache = std::filesystem::path(c) / arch; + } else { + cache = std::filesystem::temp_directory_path() / "mlx" / version() / + "hsaco" / arch; + } + if (!std::filesystem::exists(cache)) { + std::error_code error; + if (!std::filesystem::create_directories(cache, error)) { + cache = std::filesystem::path(); + } + } + return by_arch.emplace(std::move(arch), std::move(cache)).first->second; +} + +// Get the path for HSACO file, splitting long names into nested directories. +// This mirrors the CUDA backend approach to handle long kernel names that +// would otherwise exceed filesystem filename limits (typically 255 chars). +std::filesystem::path get_hsaco_path( + const std::filesystem::path& cache_dir, + const std::string& module_name, + const std::string& extension) { + constexpr int max_file_name_length = 245; + if (module_name.size() <= max_file_name_length) { + return cache_dir / (module_name + extension); + } + + auto hsaco_path = cache_dir; + int offset = 0; + while (module_name.size() - offset > max_file_name_length) { + hsaco_path /= module_name.substr(offset, max_file_name_length); + offset += max_file_name_length; + } + hsaco_path /= module_name.substr(offset) + extension; + + return hsaco_path; +} + +// Try to read the cached |hsaco| and |hsaco_kernels| from |cache_dir|. +// If |expected_source| is non-null, the cached .hip source must match it +// exactly or the cache is treated as a miss (kernel source changed in place +// without a version bump — a stale binary would have a mismatched ABI). +bool read_cached_hsaco( + const std::filesystem::path& cache_dir, + const std::string& module_name, + std::string& hsaco, + std::vector>& hsaco_kernels, + const std::string* expected_source = nullptr) { + if (cache_dir.empty()) { + return false; + } + + if (expected_source) { + auto source_path = get_hsaco_path(cache_dir, module_name, ".hip"); + std::ifstream source_file(source_path, std::ios::binary); + if (!source_file.good()) { + return false; + } + std::string cached_source( + (std::istreambuf_iterator(source_file)), + std::istreambuf_iterator()); + if (cached_source != *expected_source) { + return false; + } + } + + auto hsaco_path = get_hsaco_path(cache_dir, module_name, ".hsaco"); + std::error_code error; + auto hsaco_size = std::filesystem::file_size(hsaco_path, error); + if (error) { + return false; + } + std::ifstream hsaco_file(hsaco_path, std::ios::binary); + if (!hsaco_file.good()) { + return false; + } + hsaco.resize(hsaco_size); + hsaco_file.read(hsaco.data(), hsaco_size); + + auto txt_path = get_hsaco_path(cache_dir, module_name, ".txt"); + std::ifstream txt_file(txt_path, std::ios::binary); + std::string line; + while (std::getline(txt_file, line)) { + auto tab = line.find('\t'); + if (tab != std::string::npos) { + hsaco_kernels.emplace_back(line.substr(0, tab), line.substr(tab + 1)); + } + } + return true; +} + +// Write the |hsaco| and |hsaco_kernels| to |cache_dir| with |name|. +void write_cached_hsaco( + const std::filesystem::path& cache_dir, + const std::string& module_name, + const std::string& hsaco, + const std::vector>& hsaco_kernels, + const std::string& source_code) { + if (cache_dir.empty()) { + return; + } + + auto hsaco_path = get_hsaco_path(cache_dir, module_name, ".hsaco"); + + // Create parent directories if they don't exist (for long module names) + std::error_code error; + std::filesystem::create_directories(hsaco_path.parent_path(), error); + if (error) { + return; + } + + std::ofstream hsaco_file(hsaco_path, std::ios::binary); + if (!hsaco.empty()) { + hsaco_file.write(&hsaco.front(), hsaco.size()); + } + + auto txt_path = get_hsaco_path(cache_dir, module_name, ".txt"); + std::ofstream txt_file(txt_path, std::ios::binary); + for (const auto& [name, mangled] : hsaco_kernels) { + txt_file << name << "\t" << mangled << std::endl; + } + + auto source_path = get_hsaco_path(cache_dir, module_name, ".hip"); + std::ofstream source_file(source_path); + source_file << source_code; +} + +// Get GPU architecture string for the current device +std::string get_gpu_arch() { + hipDeviceProp_t props; + int device_id; + CHECK_HIP_ERROR(hipGetDevice(&device_id)); + CHECK_HIP_ERROR(hipGetDeviceProperties(&props, device_id)); + // gcnArchName already contains the full architecture name like "gfx1011" + return std::string(props.gcnArchName); +} + +void compile( + Device& device, + const std::string& module_name, + const std::string& source, + const std::vector& kernel_names, + std::string& hsaco, + std::vector>& hsaco_kernels) { + // Create the program + // Use a hash of the module name to avoid "File name too long" errors + // from hiprtc creating temporary files with the program name. + auto program_name = "kernel_" + + std::to_string(std::hash{}(module_name)) + ".hip"; + hiprtcProgram prog; + CHECK_HIPRTC_ERROR(hiprtcCreateProgram( + &prog, source.c_str(), program_name.c_str(), 0, nullptr, nullptr)); + + std::unique_ptr prog_freer( + &prog, + [](hiprtcProgram* p) { CHECK_HIPRTC_ERROR(hiprtcDestroyProgram(p)); }); + + for (const auto& name : kernel_names) { + CHECK_HIPRTC_ERROR(hiprtcAddNameExpression(prog, name.c_str())); + } + + // Compile program. + std::vector args; + std::vector arg_strings; + + // Add standard flags + arg_strings.push_back("--std=c++17"); + arg_strings.push_back("-O3"); + arg_strings.push_back("-DMLX_USE_ROCM"); + + // Add GPU architecture + std::string gpu_arch = get_gpu_arch(); + std::string arch_flag = "--offload-arch=" + gpu_arch; + arg_strings.push_back(arch_flag); + + // Add include paths + std::string rocm_include = "-I" + rocm_home() + "/include"; + arg_strings.push_back(rocm_include); + + for (const auto& arg : arg_strings) { + args.push_back(arg.c_str()); + } + + // Suppress stderr during hipRTC compilation. AMD's comgr backend + // unconditionally dumps the entire preprocessed source to fd 2, flooding + // the terminal with thousands of lines of compiler-internal defines. + StderrSuppressor suppressor; + hiprtcResult compile_result = + hiprtcCompileProgram(prog, args.size(), args.data()); + suppressor.restore(); // restore stderr before any error reporting + + if (compile_result != HIPRTC_SUCCESS) { + size_t log_size; + CHECK_HIPRTC_ERROR(hiprtcGetProgramLogSize(prog, &log_size)); + std::vector log(log_size + 1, 0); + CHECK_HIPRTC_ERROR(hiprtcGetProgramLog(prog, log.data())); + // The comgr log prepends the entire preprocessed source before the + // actual error messages. Truncate to only the trailing error lines. + std::string truncated = tail_lines(std::string(log.data())); + std::ostringstream oss; + oss << "Failed to compile kernel '" << module_name << "': " << truncated; + throw std::runtime_error(oss.str()); + } + + // Get mangled names of kernel names. + for (const auto& name : kernel_names) { + const char* mangled; + CHECK_HIPRTC_ERROR(hiprtcGetLoweredName(prog, name.c_str(), &mangled)); + hsaco_kernels.emplace_back(name, mangled); + } + + // Get code data. + size_t code_size; + CHECK_HIPRTC_ERROR(hiprtcGetCodeSize(prog, &code_size)); + hsaco.resize(code_size); + CHECK_HIPRTC_ERROR(hiprtcGetCode(prog, hsaco.data())); +} + +void load_module( + const std::string& module_name, + const std::string& hsaco, + const std::vector>& hsaco_kernels, + hipModule_t& module_, + std::unordered_map>& kernels) { + // Load module. + hipError_t load_result = hipModuleLoadData(&module_, hsaco.data()); + if (load_result != hipSuccess) { + std::ostringstream oss; + oss << "Failed to load compiled " << module_name + << " kernel: " << hipGetErrorString(load_result) << "."; + throw std::runtime_error(oss.str()); + } + + // Load kernels. + for (const auto& [name, mangled] : hsaco_kernels) { + hipFunction_t kernel; + CHECK_HIP_ERROR(hipModuleGetFunction(&kernel, module_, mangled.c_str())); + kernels[name] = std::make_pair(kernel, false); + } +} + +} // namespace + +JitModule::JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder, + bool use_disk_cache) { + // Bind the target device before compiling/loading: hipModuleLoadData and + // hipModuleGetFunction load into the CURRENT device's context, and the kernels + // are later launched on this device's stream. If the module loaded into device + // 0's context but launches on device 1, the queue wedges. + device.make_current(); + // Will hold the actual device executable source code and kernel names + std::string hsaco; + std::vector> hsaco_kernels; + + // Use a safe filename for disk cache to avoid exceeding 255-byte limit + std::string cache_name = safe_filename(module_name); + + // Build the source first so the disk cache can be validated against it: a + // JIT kernel whose source changed in place (same module_name, no version + // bump) must invalidate the cached binary, otherwise a stale binary with a + // mismatched argument ABI is loaded and launched. + auto [precompiled, source_code, kernel_names] = builder(); + + const std::string* expected_source = precompiled ? nullptr : &source_code; + if (!read_cached_hsaco( + hsaco_cache_dir(), + cache_name, + hsaco, + hsaco_kernels, + expected_source)) { + // Get the HSACO (AMD GPU binary) + if (precompiled) { + hsaco = std::move(source_code); + for (auto& name : kernel_names) { + hsaco_kernels.emplace_back(name, name); + } + } else { + compile( + device, module_name, source_code, kernel_names, hsaco, hsaco_kernels); + } + + // If requested save them in the file cache for the next launch + if (use_disk_cache) { + write_cached_hsaco( + hsaco_cache_dir(), cache_name, hsaco, hsaco_kernels, source_code); + } + } + + // Load the module + load_module(module_name, hsaco, hsaco_kernels, module_, kernels_); +} + +JitModule::~JitModule() { + if (module_) { + (void)hipModuleUnload(module_); + } +} + +hipFunction_t JitModule::get_kernel( + const std::string& kernel_name, + std::function configure_kernel) { + auto it = kernels_.find(kernel_name); + if (it == kernels_.end()) { + throw std::runtime_error( + std::string("There is no kernel named ") + kernel_name + "."); + } + + // If it is the first time we run this kernel then configure it. Do it only + // once! + if (!it->second.second) { + if (configure_kernel) { + configure_kernel(it->second.first); + } + it->second.second = true; + } + + return it->second.first; +} + +std::unordered_map& get_jit_module_cache() { + static std::unordered_map map; + return map; +} + +JitModule& get_jit_module( + const mlx::core::Device& mlx_device, + const std::string& name, + const KernelBuilder& builder, + bool cache) { + auto& map = get_jit_module_cache(); + // Key by device too: a module compiled/loaded into one device's context is not + // valid on another. Sharing by name across devices would hand a device-1 launch + // a hipFunction_t from device 0's context and wedge the queue. + auto key = std::to_string(mlx_device.index) + ":" + name; + auto it = map.find(key); + if (it == map.end()) { + it = map.try_emplace(key, device(mlx_device), name, builder, cache).first; + } + return it->second; +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/jit_module.h b/mlx/backend/rocm/jit_module.h new file mode 100644 index 0000000000..db2064c425 --- /dev/null +++ b/mlx/backend/rocm/jit_module.h @@ -0,0 +1,125 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/kernel_utils.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace mlx::core::rocm { + +class Device; + +// Maximum number of dimensions supported for JIT kernels +// Note: device/config.h defines MAX_NDIM as a macro for device code +// We use a different name here to avoid conflicts +constexpr int JIT_MAX_NDIM = 8; + +using KernelBuilderResult = std::tuple< + /* precompiled */ bool, + /* source code */ std::string, + /* kernel names */ std::vector>; +using KernelBuilder = std::function; + +struct KernelArgs { + void** args() { + return args_.data(); + } + + void append(const array& a) { + append(reinterpret_cast(gpu_ptr(a))); + } + + template + void append(T val) { + storage_.emplace_back(val); + append_ptr(&storage_.back()); + } + + template + void append(SmallVector vec) { + storage_.emplace_back(std::move(vec)); + append_ptr(std::get>(storage_.back()).data()); + } + + template + void append(const std::vector& vec) { + append(SmallVector(vec.begin(), vec.end())); + } + + // Make sure the arg is copied to an array with size of NDIM. + template + void append_ndim(SmallVector vec) { + if (vec.size() > NDIM) { + std::ostringstream oss; + oss << "ndim can not be larger than " << NDIM << "."; + throw std::runtime_error(oss.str()); + } + vec.resize(NDIM); + append(std::move(vec)); + } + + void append_ptr(const void* v) { + args_.push_back(const_cast(v)); + } + + private: + std::vector args_; + + // The hipGraphAddKernelNode API requires passing pointers to arguments so + // store temporary values until the node is created. + using Arg = std::variant< + std::monostate, + hipDeviceptr_t, + bool, + int32_t, + uint32_t, + int64_t, + float, + SmallVector, + SmallVector, + SmallVector>; + std::deque storage_; +}; + +class JitModule { + public: + JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder, + bool cache); + ~JitModule(); + + JitModule(const JitModule&) = delete; + JitModule& operator=(const JitModule&) = delete; + + hipFunction_t get_kernel( + const std::string& kernel_name, + std::function configure_kernel = nullptr); + + private: + hipModule_t module_{nullptr}; + std::unordered_map> kernels_; +}; + +std::unordered_map& get_jit_module_cache(); + +JitModule& get_jit_module( + const mlx::core::Device& device, + const std::string& name, + const KernelBuilder& builder, + bool use_disk_cache = true); + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/kernel_utils.hip b/mlx/backend/rocm/kernel_utils.hip new file mode 100644 index 0000000000..81b3be8053 --- /dev/null +++ b/mlx/backend/rocm/kernel_utils.hip @@ -0,0 +1,29 @@ +// Copyright © 2025 Apple Inc. + +#include + +namespace mlx::core::rocm { + +// Utility functions for HIP kernels + +__device__ inline int get_global_id() { + return blockIdx.x * blockDim.x + threadIdx.x; +} + +__device__ inline int get_local_id() { + return threadIdx.x; +} + +__device__ inline int get_group_id() { + return blockIdx.x; +} + +__device__ inline int get_local_size() { + return blockDim.x; +} + +__device__ inline int get_num_groups() { + return gridDim.x; +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/kernel_utils.hpp b/mlx/backend/rocm/kernel_utils.hpp new file mode 100644 index 0000000000..a6bfd48e70 --- /dev/null +++ b/mlx/backend/rocm/kernel_utils.hpp @@ -0,0 +1,252 @@ +// Copyright © 2025 Apple Inc. + +// This file includes host-only utilities for writing HIP kernels, the +// difference from backend/rocm/device/utils.hpp is that the latter file only +// include device-only code. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/device/utils.hpp" + +#include +#include +#include +#include +#include +#include + +namespace mlx::core { + +// Get GPU pointer from array without synchronization. +// This should be used when passing pointers to GPU kernels. +// For CPU access to managed memory, use array::data() which synchronizes. +template +inline T* gpu_ptr(array& arr) { + auto* buf = static_cast(arr.buffer().ptr()); + // Discrete GPU: if the CPU wrote through the host shadow (raw_ptr), flush it + // back to VRAM before a kernel reads it. No-op on the integrated APU and for + // buffers never touched on the CPU (host_dirty stays false). + if (buf->host_dirty) { + rocm::allocator().flush_host_shadow(*buf); + } + return reinterpret_cast(static_cast(buf->data) + arr.offset()); +} + +// For const array, keep constness in pointer unless it is untyped. +template +inline std::conditional_t, void*, const T*> gpu_ptr( + const array& arr) { + return gpu_ptr(const_cast(arr)); +} + +// Note: WARP_SIZE and MAX_NDIM are defined in device/config.h + +template +void dispatch_1_2_3(int n, F&& f) { + switch (n) { + case 1: + f(std::integral_constant{}); + break; + case 2: + f(std::integral_constant{}); + break; + case 3: + f(std::integral_constant{}); + break; + } +} + +template +void dispatch_bool(bool v, F&& f) { + if (v) { + f(std::true_type{}); + } else { + f(std::false_type{}); + } +} + +template +void dispatch_block_dim(int threads, F&& f) { + if (threads <= WARP_SIZE) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 2) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 4) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 8) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 16) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); + } +} + +// Maps CPU types to HIP types. +template +struct CTypeToHipType { + using type = T; +}; + +template <> +struct CTypeToHipType { + using type = __half; +}; + +template <> +struct CTypeToHipType { + using type = hip_bfloat16; +}; + +template <> +struct CTypeToHipType { + using type = hipFloatComplex; +}; + +template +using hip_type_t = typename CTypeToHipType::type; + +// Type traits for detecting floating numbers. +template +inline constexpr bool is_floating_v = + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v; + +// Type traits for detecting complex numbers. +template +inline constexpr bool is_complex_v = + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v; + +// Type traits for detecting complex or real floating point numbers. +template +inline constexpr bool is_inexact_v = is_floating_v || is_complex_v; + +// Utility to copy data from vector to array in host. +template +inline rocm::hip_array const_param(const SmallVector& vec) { + if (vec.size() > NDIM) { + std::ostringstream oss; + oss << "ndim can not be larger than " << NDIM << "."; + throw std::runtime_error(oss.str()); + } + rocm::hip_array result; + std::copy_n(vec.begin(), vec.size(), result.data_); + return result; +} + +// Overload for std::vector +template +inline rocm::hip_array const_param(const std::vector& vec) { + if (vec.size() > NDIM) { + std::ostringstream oss; + oss << "ndim can not be larger than " << NDIM << "."; + throw std::runtime_error(oss.str()); + } + rocm::hip_array result; + std::copy_n(vec.begin(), vec.size(), result.data_); + return result; +} + +// Compute the grid and block dimensions +inline dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10) { + int block_x = 1; + int block_y = 1; + int block_z = 1; + + // Try to maximize occupancy while respecting dimension sizes + int total_threads = 1 << pow2; // Default to 1024 threads + + // Distribute threads across dimensions + while (block_x < dim0 && block_x < 32) { + block_x *= 2; + } + while (block_y < dim1 && block_x * block_y < total_threads) { + block_y *= 2; + } + while (block_z < dim2 && block_x * block_y * block_z < total_threads) { + block_z *= 2; + } + + return dim3(block_x, block_y, block_z); +} + +inline dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) { + Dims dims = get_2d_grid_dims_common(shape, strides); + return dim3(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims)); +} + +inline dim3 +get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor) { + // Compute the 2d grid dimensions such that the total size of the grid is + // divided by divisor. + size_t grid_x = 1; + size_t grid_y = 1; + for (size_t i = 0; i < shape.size(); ++i) { + if (strides[i] == 0) { + continue; + } + + // No need to add this shape we can just remove it from the divisor. + if (divisor % shape[i] == 0) { + divisor /= shape[i]; + continue; + } + + if (grid_x * shape[i] < UINT32_MAX) { + grid_x *= shape[i]; + } else { + grid_y *= shape[i]; + } + } + if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) { + throw std::runtime_error("Unable to safely factor shape."); + } + if (grid_y > grid_x) { + std::swap(grid_x, grid_y); + } + return dim3(static_cast(grid_x), static_cast(grid_y), 1); +} + +inline std::pair get_grid_and_block(int dim0, int dim1, int dim2) { + auto block_dims = get_block_dims(dim0, dim1, dim2); + dim3 grid_dims( + (dim0 + block_dims.x - 1) / block_dims.x, + (dim1 + block_dims.y - 1) / block_dims.y, + (dim2 + block_dims.z - 1) / block_dims.z); + return {grid_dims, block_dims}; +} + +// Get the num_blocks and block_dims for a kernel +inline std::tuple get_launch_args( + size_t size, + const Shape& shape, + const Strides& strides, + bool large, + int work_per_thread = 1) { + size_t adjusted_size = (size + work_per_thread - 1) / work_per_thread; + int block_size = 256; + int num_blocks = (adjusted_size + block_size - 1) / block_size; + num_blocks = std::min(num_blocks, 65535); + return {dim3(num_blocks), block_size}; +} + +inline std::tuple +get_launch_args(const array& arr, bool large, int work_per_thread = 1) { + return get_launch_args( + arr.size(), arr.shape(), arr.strides(), large, work_per_thread); +} + +// Ceil division utility +template +inline T ceildiv(T a, T b) { + return (a + b - 1) / b; +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/layer_norm.hip b/mlx/backend/rocm/layer_norm.hip new file mode 100644 index 0000000000..d695490985 --- /dev/null +++ b/mlx/backend/rocm/layer_norm.hip @@ -0,0 +1,488 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/reduce/reduce.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include + +namespace mlx::core { + +namespace rocm { + +// Warp reduce for sum +__device__ float warp_reduce_sum_f(float val) { + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val += __shfl_xor(val, offset); + } + return val; +} + +// Warp reduce for float3 (sum, sum*t, t*t) +struct float3_sum { + float x, y, z; +}; + +__device__ float3_sum warp_reduce_sum_f3(float3_sum val) { + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val.x += __shfl_xor(val.x, offset); + val.y += __shfl_xor(val.y, offset); + val.z += __shfl_xor(val.z, offset); + } + return val; +} + +template +__global__ void layer_norm_kernel( + const T* x, + const T* w, + const T* b, + T* out, + float eps, + int32_t axis_size, + int64_t w_stride, + int64_t b_stride) { + int row = blockIdx.x; + + x += row * axis_size; + out += row * axis_size; + + // Sum for mean + float sum = 0; + for (int i = threadIdx.x * N_READS; i < axis_size; i += BLOCK_DIM * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && i + j < axis_size; ++j) { + sum += static_cast(x[i + j]); + } + } + + // Block reduce for sum + __shared__ float shared_sum[BLOCK_DIM / WARP_SIZE + 1]; + + float warp_sum = warp_reduce_sum_f(sum); + int lane = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + + if (lane == 0) { + shared_sum[warp_id] = warp_sum; + } + __syncthreads(); + + if (warp_id == 0) { + sum = (lane < (BLOCK_DIM + WARP_SIZE - 1) / WARP_SIZE) ? shared_sum[lane] : 0; + sum = warp_reduce_sum_f(sum); + } + __syncthreads(); + + if (threadIdx.x == 0) { + shared_sum[0] = sum; + } + __syncthreads(); + float mean = shared_sum[0] / axis_size; + + // Compute variance + float var_sum = 0; + for (int i = threadIdx.x * N_READS; i < axis_size; i += BLOCK_DIM * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && i + j < axis_size; ++j) { + float t = static_cast(x[i + j]) - mean; + var_sum += t * t; + } + } + + // Block reduce for variance + warp_sum = warp_reduce_sum_f(var_sum); + + if (lane == 0) { + shared_sum[warp_id] = warp_sum; + } + __syncthreads(); + + if (warp_id == 0) { + var_sum = (lane < (BLOCK_DIM + WARP_SIZE - 1) / WARP_SIZE) ? shared_sum[lane] : 0; + var_sum = warp_reduce_sum_f(var_sum); + } + __syncthreads(); + + if (threadIdx.x == 0) { + shared_sum[0] = var_sum; + } + __syncthreads(); + // Use 1/sqrt instead of rsqrtf for IEEE-compliant precision + // (matches Metal's metal::precise::rsqrt behavior) + float normalizer = 1.0f / sqrtf(shared_sum[0] / axis_size + eps); + + // Write output + for (int i = threadIdx.x * N_READS; i < axis_size; i += BLOCK_DIM * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && i + j < axis_size; ++j) { + int idx = i + j; + float norm = (static_cast(x[idx]) - mean) * normalizer; + float wi = (w_stride == 0) ? static_cast(w[0]) : static_cast(w[idx * w_stride]); + float bi = (b_stride == 0) ? static_cast(b[0]) : static_cast(b[idx * b_stride]); + out[idx] = static_cast(wi * norm + bi); + } + } +} + +template +__global__ void layer_norm_vjp_kernel( + const T* x, + const T* w, + const T* g, + T* gx, + T* gw, + float eps, + int32_t axis_size, + int64_t w_stride) { + int row = blockIdx.x; + + x += row * axis_size; + g += row * axis_size; + gx += row * axis_size; + gw += row * axis_size; + + // Sum for mean + float sum = 0; + for (int i = threadIdx.x * N_READS; i < axis_size; i += BLOCK_DIM * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && i + j < axis_size; ++j) { + sum += static_cast(x[i + j]); + } + } + + // Block reduce for sum + __shared__ float shared_sum[BLOCK_DIM / WARP_SIZE + 1]; + __shared__ float3_sum shared_f3[BLOCK_DIM / WARP_SIZE + 1]; + + float warp_sum = warp_reduce_sum_f(sum); + int lane = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + + if (lane == 0) { + shared_sum[warp_id] = warp_sum; + } + __syncthreads(); + + if (warp_id == 0) { + sum = (lane < (BLOCK_DIM + WARP_SIZE - 1) / WARP_SIZE) ? shared_sum[lane] : 0; + sum = warp_reduce_sum_f(sum); + } + __syncthreads(); + + if (threadIdx.x == 0) { + shared_sum[0] = sum; + } + __syncthreads(); + float mean = shared_sum[0] / axis_size; + + // Compute factors: (wg_sum, wg*xc_sum, xc^2_sum) + float3_sum factors = {0, 0, 0}; + for (int i = threadIdx.x * N_READS; i < axis_size; i += BLOCK_DIM * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && i + j < axis_size; ++j) { + int idx = i + j; + float t = static_cast(x[idx]) - mean; + float wi = (w_stride == 0) ? static_cast(w[0]) : static_cast(w[idx * w_stride]); + float gi = static_cast(g[idx]); + float wg = wi * gi; + factors.x += wg; + factors.y += wg * t; + factors.z += t * t; + } + } + + // Block reduce for factors + float3_sum warp_f3 = warp_reduce_sum_f3(factors); + + if (lane == 0) { + shared_f3[warp_id] = warp_f3; + } + __syncthreads(); + + if (warp_id == 0) { + factors = (lane < (BLOCK_DIM + WARP_SIZE - 1) / WARP_SIZE) ? shared_f3[lane] : float3_sum{0, 0, 0}; + factors = warp_reduce_sum_f3(factors); + } + __syncthreads(); + + if (threadIdx.x == 0) { + shared_f3[0] = factors; + } + __syncthreads(); + factors = shared_f3[0]; + + float meanwg = factors.x / axis_size; + float meanwgxc = factors.y / axis_size; + float normalizer2 = 1.0f / (factors.z / axis_size + eps); + float normalizer = sqrtf(normalizer2); + + // Write outputs + for (int i = threadIdx.x * N_READS; i < axis_size; i += BLOCK_DIM * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && i + j < axis_size; ++j) { + int idx = i + j; + float xi_centered = static_cast(x[idx]) - mean; + float xi_norm = xi_centered * normalizer; + float wi = (w_stride == 0) ? static_cast(w[0]) : static_cast(w[idx * w_stride]); + float gi = static_cast(g[idx]); + + // Gradient for x + gx[idx] = static_cast(normalizer * (wi * gi - meanwg) - xi_norm * meanwgxc * normalizer2); + + // Gradient for w (per-element, will be reduced later) + if constexpr (HAS_W) { + gw[idx] = static_cast(gi * xi_norm); + } + } + } +} + +} // namespace rocm + +namespace fast { + +bool LayerNorm::use_fallback(Stream s) { + return s.device == Device::cpu; +} + +void LayerNorm::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& out = outputs[0]; + auto& encoder = rocm::get_command_encoder(s); + + // Make sure that the last dimension is contiguous. + auto set_output = [&s, &out, &encoder](const array& x) { + bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; + if (no_copy && x.ndim() > 1) { + auto s = x.strides()[x.ndim() - 2]; + no_copy &= (s == 0 || s == x.shape().back()); + } + if (no_copy) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + mlx::core::rocm::malloc_async(x.data_size() * x.itemsize(), encoder), + x.data_size(), + x.strides(), + x.flags()); + } + return x; + } else { + array x_copy = contiguous_copy_gpu(x, s); + out.copy_shared_buffer(x_copy); + return x_copy; + } + }; + + const array x = set_output(inputs[0]); + const array& w = inputs[1]; + const array& b = inputs[2]; + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + int64_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0; + + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(b); + encoder.set_output_array(out); + + constexpr int BLOCK_DIM = 256; + constexpr int N_READS = 4; + + switch (out.dtype()) { + case float32: + encoder.add_kernel_node( + &rocm::layer_norm_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr(x), gpu_ptr(w), gpu_ptr(b), gpu_ptr(out), + eps_, axis_size, w_stride, b_stride); + break; + case float16: + encoder.add_kernel_node( + &rocm::layer_norm_kernel<__half, BLOCK_DIM, N_READS>, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr<__half>(x), gpu_ptr<__half>(w), gpu_ptr<__half>(b), gpu_ptr<__half>(out), + eps_, axis_size, w_stride, b_stride); + break; + case bfloat16: + encoder.add_kernel_node( + &rocm::layer_norm_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr(x), gpu_ptr(w), gpu_ptr(b), gpu_ptr(out), + eps_, axis_size, w_stride, b_stride); + break; + default: + throw std::runtime_error("Unsupported type for layer_norm"); + } +} + +void LayerNormVJP::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + // Ensure row contiguity + auto check_input = [&s](const array& x, bool& copied) { + if (x.flags().row_contiguous) { + copied = false; + return x; + } + copied = true; + return contiguous_copy_gpu(x, s); + }; + + bool donate_x = inputs[0].is_donatable(); + bool donate_g = inputs[3].is_donatable(); + bool copied; + auto x = check_input(inputs[0], copied); + donate_x |= copied; + const array& w = inputs[1]; + const array& b = inputs[2]; + bool g_copied; + auto g = check_input(inputs[3], g_copied); + donate_g |= g_copied; + array& gx = outputs[0]; + array& gw = outputs[1]; + array& gb = outputs[2]; + + // Check whether we had a weight + bool has_w = w.ndim() != 0; + + // Allocate space for the outputs + bool g_in_gx = false; + if (donate_x) { + gx.copy_shared_buffer(x); + } else if (donate_g) { + gx.copy_shared_buffer(g); + g_in_gx = true; + } else { + gx.set_data(mlx::core::rocm::malloc_async(gx.nbytes(), encoder)); + } + if (g_copied && !g_in_gx) { + encoder.add_temporary(g); + } + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + + // Allocate a temporary to store the gradients for w + array gw_temp = + (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; + bool g_in_gw = false; + if (has_w) { + if (!g_in_gx && donate_g) { + g_in_gw = true; + gw_temp.copy_shared_buffer(g); + } else { + gw_temp.set_data(mlx::core::rocm::malloc_async(gw_temp.nbytes(), encoder)); + encoder.add_temporary(gw_temp); + } + } + + // The gradient for b in case we had a b + bool has_gb = (gb.ndim() == 1 && gb.size() == axis_size); + if (has_gb) { + // Sum reduction over rows for gb + gb.set_data(mlx::core::rocm::malloc_async(gb.nbytes(), encoder)); + // TODO: Implement proper column reduction for gb + // For now, we'll compute it in the kernel or use a simple reduction + } + + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(g); + encoder.set_output_array(gx); + encoder.set_output_array(gw_temp); + + constexpr int BLOCK_DIM = 256; + constexpr int N_READS = 4; + + if (has_w) { + switch (gx.dtype()) { + case float32: + encoder.add_kernel_node( + &rocm::layer_norm_vjp_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr(x), gpu_ptr(w), gpu_ptr(g), + gpu_ptr(gx), gpu_ptr(gw_temp), + eps_, axis_size, w_stride); + break; + case float16: + encoder.add_kernel_node( + &rocm::layer_norm_vjp_kernel<__half, true, BLOCK_DIM, N_READS>, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr<__half>(x), gpu_ptr<__half>(w), gpu_ptr<__half>(g), + gpu_ptr<__half>(gx), gpu_ptr<__half>(gw_temp), + eps_, axis_size, w_stride); + break; + case bfloat16: + encoder.add_kernel_node( + &rocm::layer_norm_vjp_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr(x), gpu_ptr(w), gpu_ptr(g), + gpu_ptr(gx), gpu_ptr(gw_temp), + eps_, axis_size, w_stride); + break; + default: + throw std::runtime_error("Unsupported type for layer_norm_vjp"); + } + } else { + switch (gx.dtype()) { + case float32: { + float* gw_null = nullptr; + encoder.add_kernel_node( + &rocm::layer_norm_vjp_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr(x), gpu_ptr(w), gpu_ptr(g), + gpu_ptr(gx), gw_null, + eps_, axis_size, w_stride); + break; + } + case float16: { + __half* gw_null = nullptr; + encoder.add_kernel_node( + &rocm::layer_norm_vjp_kernel<__half, false, BLOCK_DIM, N_READS>, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr<__half>(x), gpu_ptr<__half>(w), gpu_ptr<__half>(g), + gpu_ptr<__half>(gx), gw_null, + eps_, axis_size, w_stride); + break; + } + case bfloat16: { + hip_bfloat16* gw_null = nullptr; + encoder.add_kernel_node( + &rocm::layer_norm_vjp_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr(x), gpu_ptr(w), gpu_ptr(g), + gpu_ptr(gx), gw_null, + eps_, axis_size, w_stride); + break; + } + default: + throw std::runtime_error("Unsupported type for layer_norm_vjp"); + } + } + + // Reduce gw_temp to gw if we have weights + if (has_w) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan); + } +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/rocm/load.cpp b/mlx/backend/rocm/load.cpp new file mode 100644 index 0000000000..48a4439318 --- /dev/null +++ b/mlx/backend/rocm/load.cpp @@ -0,0 +1,89 @@ +// Copyright © 2025 Apple Inc. + +#include +#include + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/utils.h" +#include "mlx/primitives.h" + +#include + +namespace { + +template +void swap_endianness(uint8_t* data_bytes, size_t N) { + struct Elem { + uint8_t bytes[scalar_size]; + }; + + Elem* data = reinterpret_cast(data_bytes); + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < (scalar_size / 2); j++) { + std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]); + } + } +} + +void hip_host_free_callback(void* ptr) { + (void)hipHostFree(ptr); +} + +} // namespace + +namespace mlx::core { + +void Load::eval_gpu(const std::vector& inputs, array& out) { + auto& encoder = rocm::get_command_encoder(stream()); + auto size = out.size(); + auto nbytes = size * out.itemsize(); + out.set_data(mlx::core::rocm::malloc_async(nbytes, encoder)); + // Stage through PINNED host memory. An async H2D copy from pageable memory is + // unreliable on a discrete GPU over a non-coherent link (TB5 eGPU): the driver + // must internally stage it, which can stall the stream (queue stuck, GPU shows + // busy, the eval's sync never returns). Pinned memory DMAs directly and lets + // the copy actually run asynchronously. + void* out_ptr = nullptr; + if (hipHostMalloc(&out_ptr, nbytes, hipHostMallocDefault) != hipSuccess || + out_ptr == nullptr) { + // Fallback: pageable + synchronous copy (still correct, just slower). + out_ptr = malloc(nbytes); + reader_->read(static_cast(out_ptr), nbytes, offset_); + if (swap_endianness_) { + switch (out.itemsize()) { + case 2: swap_endianness<2>(reinterpret_cast(out_ptr), size); break; + case 4: swap_endianness<4>(reinterpret_cast(out_ptr), size); break; + case 8: swap_endianness<8>(reinterpret_cast(out_ptr), size); break; + } + } + (void)hipMemcpy(gpu_ptr(out), out_ptr, nbytes, hipMemcpyHostToDevice); + free(out_ptr); + return; + } + reader_->read(static_cast(out_ptr), nbytes, offset_); + if (swap_endianness_) { + switch (out.itemsize()) { + case 2: + swap_endianness<2>(reinterpret_cast(out_ptr), size); + break; + case 4: + swap_endianness<4>(reinterpret_cast(out_ptr), size); + break; + case 8: + swap_endianness<8>(reinterpret_cast(out_ptr), size); + break; + } + } + (void)hipMemcpyAsync( + gpu_ptr(out), + out_ptr, + nbytes, + hipMemcpyHostToDevice, + encoder.stream()); + (void)hipLaunchHostFunc(encoder.stream(), hip_host_free_callback, out_ptr); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/logsumexp.hip b/mlx/backend/rocm/logsumexp.hip new file mode 100644 index 0000000000..e1204badc7 --- /dev/null +++ b/mlx/backend/rocm/logsumexp.hip @@ -0,0 +1,194 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device/cast_op.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include + +#include + +namespace mlx::core { + +namespace rocm { + +template +inline __device__ T logsumexp_exp(T x) { + return __expf(x); +} + +// Warp reduce for max - use runtime warpSize +template +__device__ T warp_reduce_max_lse(T val) { + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + T other = __shfl_xor(val, offset); + val = val > other ? val : other; + } + return val; +} + +// Warp reduce for sum - use runtime warpSize +template +__device__ T warp_reduce_sum_lse(T val) { + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + val += __shfl_xor(val, offset); + } + return val; +} + +template +__global__ void logsumexp_kernel(const T* in, T* out, int axis_size) { + int row = blockIdx.x; + + in += row * axis_size; + + // Thread reduce for max + AccT prevmax; + AccT maxval = -1e38f; + AccT normalizer = 0; + + for (int r = 0; r < (axis_size + BLOCK_DIM * N_READS - 1) / (BLOCK_DIM * N_READS); r++) { + int base_idx = r * BLOCK_DIM * N_READS + threadIdx.x * N_READS; + prevmax = maxval; + + #pragma unroll + for (int j = 0; j < N_READS; ++j) { + int idx = base_idx + j; + if (idx < axis_size) { + AccT val = static_cast(in[idx]); + maxval = val > maxval ? val : maxval; + } + } + + // Online normalizer calculation + normalizer = normalizer * logsumexp_exp(prevmax - maxval); + #pragma unroll + for (int j = 0; j < N_READS; ++j) { + int idx = base_idx + j; + if (idx < axis_size) { + normalizer += logsumexp_exp(static_cast(in[idx]) - maxval); + } + } + } + + // Block reduce for max using shared memory + __shared__ AccT shared_max[32]; // Max 32 warps + __shared__ AccT shared_norm[32]; + + int lane = threadIdx.x % warpSize; + int warp_id = threadIdx.x / warpSize; + int num_warps = (BLOCK_DIM + warpSize - 1) / warpSize; + + // First warp reduce + prevmax = maxval; + maxval = warp_reduce_max_lse(maxval); + normalizer = normalizer * logsumexp_exp(prevmax - maxval); + normalizer = warp_reduce_sum_lse(normalizer); + + if (lane == 0) { + shared_max[warp_id] = maxval; + shared_norm[warp_id] = normalizer; + } + __syncthreads(); + + // Second warp reduce (only first warp) + if (warp_id == 0) { + prevmax = maxval; + maxval = (lane < num_warps) ? shared_max[lane] : -1e38f; + maxval = warp_reduce_max_lse(maxval); + + normalizer = (lane < num_warps) ? shared_norm[lane] : 0; + normalizer = normalizer * logsumexp_exp(prevmax - maxval); + normalizer = warp_reduce_sum_lse(normalizer); + } + + // Write output + if (threadIdx.x == 0) { + if (isinf(maxval)) { + out[row] = static_cast(maxval); + } else { + out[row] = static_cast(logf(normalizer) + maxval); + } + } +} + +} // namespace rocm + +void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + // Make sure that the last dimension is contiguous. + auto ensure_contiguous = [&s, &encoder](const array& x) { + if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { + return x; + } else { + array x_copy = contiguous_copy_gpu(x, s); + encoder.add_temporary(x_copy); + return x_copy; + } + }; + + auto in = ensure_contiguous(inputs[0]); + if (in.flags().row_contiguous) { + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + } else { + auto n = in.shape(-1); + auto flags = in.flags(); + auto strides = in.strides(); + for (auto& stride : strides) { + stride /= n; + } + bool col_contig = strides[0] == 1; + for (int i = 1; col_contig && i < strides.size(); ++i) { + col_contig &= + (out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]); + } + flags.col_contiguous = col_contig; + out.set_data( + mlx::core::rocm::malloc_async(in.nbytes() / n, encoder), + in.data_size() / n, + std::move(strides), + flags); + } + + int axis_size = in.shape().back(); + int n_rows = in.data_size() / axis_size; + + encoder.set_input_array(in); + encoder.set_output_array(out); + + constexpr int BLOCK_DIM = 256; + constexpr int N_READS = 4; + + switch (out.dtype()) { + case float32: + encoder.add_kernel_node( + &rocm::logsumexp_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr(in), gpu_ptr(out), axis_size); + break; + case float16: + encoder.add_kernel_node( + &rocm::logsumexp_kernel<__half, float, BLOCK_DIM, N_READS>, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr<__half>(in), gpu_ptr<__half>(out), axis_size); + break; + case bfloat16: + encoder.add_kernel_node( + &rocm::logsumexp_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr(in), gpu_ptr(out), axis_size); + break; + default: + throw std::runtime_error("Unsupported type for logsumexp"); + } +} + +} // namespace mlx::core + diff --git a/mlx/backend/rocm/lru_cache.h b/mlx/backend/rocm/lru_cache.h new file mode 100644 index 0000000000..b78d89dc74 --- /dev/null +++ b/mlx/backend/rocm/lru_cache.h @@ -0,0 +1,122 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace mlx::core::rocm { + +// LRU cache with byte-based keys +template +class LRUBytesKeyCache { + public: + LRUBytesKeyCache(const char* env_var, size_t default_capacity) + : capacity_(default_capacity) { + if (const char* env = std::getenv(env_var)) { + capacity_ = std::stoul(env); + } + } + + std::optional get(const Key& key) { + std::lock_guard lock(mutex_); + auto it = cache_map_.find(key); + if (it == cache_map_.end()) { + return std::nullopt; + } + // Move to front (most recently used) + cache_list_.splice(cache_list_.begin(), cache_list_, it->second); + return it->second->second; + } + + void put(const Key& key, const Value& value) { + std::lock_guard lock(mutex_); + auto it = cache_map_.find(key); + if (it != cache_map_.end()) { + // Update existing entry and move to front + it->second->second = value; + cache_list_.splice(cache_list_.begin(), cache_list_, it->second); + return; + } + + // Evict if at capacity + while (cache_list_.size() >= capacity_) { + auto last = cache_list_.back(); + cache_map_.erase(last.first); + cache_list_.pop_back(); + } + + // Insert new entry at front + cache_list_.emplace_front(key, value); + cache_map_[key] = cache_list_.begin(); + } + + void clear() { + std::lock_guard lock(mutex_); + cache_list_.clear(); + cache_map_.clear(); + } + + size_t size() const { + std::lock_guard lock(mutex_); + return cache_list_.size(); + } + + private: + size_t capacity_; + std::list> cache_list_; + std::unordered_map>::iterator> + cache_map_; + mutable std::mutex mutex_; +}; + +// Simple LRU cache with size_t keys +template +class LRUCache { + public: + explicit LRUCache(size_t capacity) : capacity_(capacity) {} + + std::optional get(size_t key) { + std::lock_guard lock(mutex_); + auto it = cache_map_.find(key); + if (it == cache_map_.end()) { + return std::nullopt; + } + cache_list_.splice(cache_list_.begin(), cache_list_, it->second); + return it->second->second; + } + + void put(size_t key, const Value& value) { + std::lock_guard lock(mutex_); + auto it = cache_map_.find(key); + if (it != cache_map_.end()) { + it->second->second = value; + cache_list_.splice(cache_list_.begin(), cache_list_, it->second); + return; + } + + while (cache_list_.size() >= capacity_) { + auto last = cache_list_.back(); + cache_map_.erase(last.first); + cache_list_.pop_back(); + } + + cache_list_.emplace_front(key, value); + cache_map_[key] = cache_list_.begin(); + } + + private: + size_t capacity_; + std::list> cache_list_; + std::unordered_map< + size_t, + typename std::list>::iterator> + cache_map_; + mutable std::mutex mutex_; +}; + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/matmul.cpp b/mlx/backend/rocm/matmul.cpp new file mode 100644 index 0000000000..f0e9046bfd --- /dev/null +++ b/mlx/backend/rocm/matmul.cpp @@ -0,0 +1,1124 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/matmul.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/gemms/gemv.h" +#include "mlx/backend/rocm/gemms/hipblaslt_gemm.h" +#include "mlx/backend/rocm/gemms/naive_gemm.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/primitives.h" +#include "mlx/types/half_types.h" + +#include +#include + +#include +#include +#include +#include + +namespace mlx::core { + +namespace { + +std::tuple +check_transpose(rocm::CommandEncoder& enc, const Stream& s, const array& arr) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (sty == 1 && stx == arr.shape(-1)) { + return std::make_tuple(false, stx, arr); + } else if (stx == 1 && sty == arr.shape(-2)) { + return std::make_tuple(true, sty, arr); + } else { + array arr_copy = contiguous_copy_gpu(arr, s); + enc.add_temporary(arr_copy); + return std::make_tuple(false, arr.shape(-1), arr_copy); + } +} + +std::tuple ensure_batch_contiguous( + const array& x, + rocm::CommandEncoder& encoder, + Stream s) { + if (x.flags().row_contiguous) { + return std::make_tuple(false, x.strides(-2), x); + } + + bool rc = true; + for (int i = 0; i < x.ndim() - 3; i++) { + rc &= (x.strides(i + 1) * x.shape(i)) == x.strides(i); + } + if (rc) { + return check_transpose(encoder, s, x); + } + + array x_copy = contiguous_copy_gpu(x, s); + encoder.add_temporary(x_copy); + return std::make_tuple(false, x_copy.strides(-2), x_copy); +} + +std::pair get_uniform_batch_stride( + const Shape& batch_shape, + const Strides& batch_strides) { + if (batch_shape.empty() || batch_shape.size() != batch_strides.size()) { + return {false, 0}; + } + + if (batch_shape.size() == 1) { + return {true, batch_strides.back()}; + } + + for (int i = batch_shape.size() - 2; i >= 0; --i) { + int64_t cur = batch_strides[i]; + int64_t next = batch_strides[i + 1]; + if (cur == 0 && next == 0) { + continue; + } + if (cur != next * batch_shape[i + 1]) { + return {false, 0}; + } + } + + return {true, batch_strides.back()}; +} + +int parse_non_negative_int_env(const char* env_name, int default_value) { + const char* raw = std::getenv(env_name); + if (raw == nullptr || *raw == '\0') { + return default_value; + } + + char* end = nullptr; + long value = std::strtol(raw, &end, 10); + if (end == raw || *end != '\0' || value < 0) { + return default_value; + } + return static_cast(value); +} + +int gemm_solution_index_f32(bool batched) { + static int single_index = + parse_non_negative_int_env("MLX_ROCM_GEMM_F32_SOLUTION_INDEX", 0); + static int batched_index = parse_non_negative_int_env( + "MLX_ROCM_GEMM_F32_BATCHED_SOLUTION_INDEX", -1); + if (!batched) { + return single_index; + } + return batched_index >= 0 ? batched_index : single_index; +} + +int gemm_solution_index_bf16(bool batched) { + static int single_index = + parse_non_negative_int_env("MLX_ROCM_GEMM_BF16_SOLUTION_INDEX", 0); + static int batched_index = parse_non_negative_int_env( + "MLX_ROCM_GEMM_BF16_BATCHED_SOLUTION_INDEX", -1); + if (!batched) { + return single_index; + } + return batched_index >= 0 ? batched_index : single_index; +} + +void gemm_rocblas( + rocm::CommandEncoder& encoder, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + bool b_transposed, + int64_t ldb, + array& out, + const array& a, + const array& b, + float alpha = 1.0f, + float beta = 0.0f) { + // Try hipBLASLt for bf16/fp16 GEMMs -- it often picks faster kernels than + // rocBLAS for half-precision on RDNA 3/3.5/4 and CDNA GPUs. + if ((a.dtype() == bfloat16 || a.dtype() == float16) && + rocm::is_hipblaslt_available()) { + try { + rocm::hipblaslt_gemm( + encoder, + a_transposed, + b_transposed, + M, + N, + K, + alpha, + a, + lda, + b, + ldb, + beta, + out, + N, // ldc = N for row-major output + a.dtype()); + return; + } catch (...) { + // hipBLASLt failed (unsupported config, etc.) -- fall through to rocBLAS. + } + } + + auto& device = encoder.device(); + rocblas_handle handle = device.get_rocblas_handle(); + + // rocBLAS uses column-major, so we swap A and B and compute B^T * A^T = (A * + // B)^T But since we want row-major output, we compute C = A * B by doing C^T + // = B^T * A^T + rocblas_operation trans_a = + b_transposed ? rocblas_operation_transpose : rocblas_operation_none; + rocblas_operation trans_b = + a_transposed ? rocblas_operation_transpose : rocblas_operation_none; + + // We pass B then A (swapped) to compute C^T = B^T * A^T. The leading + // dimensions come directly from check_transpose() for each operand. + const int64_t ld_b = ldb; + const int64_t ld_a = lda; + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* out_ptr = gpu_ptr(out); + + encoder.launch_kernel([&, a_ptr, b_ptr, out_ptr](hipStream_t stream) { + encoder.device().set_rocblas_stream(stream); + + switch (a.dtype()) { + case float32: { + float alpha_f = alpha; + float beta_f = beta; + int solution_index = gemm_solution_index_f32(false); + static std::atomic solution_valid{true}; + + if (solution_index > 0 && + solution_valid.load(std::memory_order_relaxed)) { + rocblas_status status = rocblas_gemm_ex( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_f32_r, + ld_b, + a_ptr, + rocblas_datatype_f32_r, + ld_a, + &beta_f, + out_ptr, + rocblas_datatype_f32_r, + N, + out_ptr, + rocblas_datatype_f32_r, + N, + rocblas_datatype_f32_r, + rocblas_gemm_algo_solution_index, + solution_index, + 0); + if (status != rocblas_status_success) { + solution_valid.store(false, std::memory_order_relaxed); + rocblas_sgemm( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + ld_b, + static_cast(a_ptr), + ld_a, + &beta_f, + static_cast(out_ptr), + N); + } + } else { + rocblas_sgemm( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + ld_b, + static_cast(a_ptr), + ld_a, + &beta_f, + static_cast(out_ptr), + N); + } + break; + } + case float64: { + double alpha_d = static_cast(alpha); + double beta_d = static_cast(beta); + rocblas_dgemm( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_d, + static_cast(b_ptr), + ld_b, + static_cast(a_ptr), + ld_a, + &beta_d, + static_cast(out_ptr), + N); + break; + } + case float16: { + rocblas_half alpha_h, beta_h; + // Convert float to rocblas_half using memcpy + float16_t alpha_f16 = static_cast(alpha); + float16_t beta_f16 = static_cast(beta); + std::memcpy(&alpha_h, &alpha_f16, sizeof(rocblas_half)); + std::memcpy(&beta_h, &beta_f16, sizeof(rocblas_half)); + rocblas_hgemm( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_h, + reinterpret_cast( + static_cast(b_ptr)), + ld_b, + reinterpret_cast( + static_cast(a_ptr)), + ld_a, + &beta_h, + reinterpret_cast(static_cast(out_ptr)), + N); + break; + } + case bfloat16: { + float alpha_f = alpha; + float beta_f = beta; + int solution_index = gemm_solution_index_bf16(false); + static std::atomic solution_valid{true}; + + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + if (solution_index > 0 && + solution_valid.load(std::memory_order_relaxed)) { + algo = rocblas_gemm_algo_solution_index; + } else { + solution_index = 0; + } + + rocblas_status status = rocblas_gemm_ex( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + rocblas_datatype_bf16_r, + ld_b, + static_cast(a_ptr), + rocblas_datatype_bf16_r, + ld_a, + &beta_f, + static_cast(out_ptr), + rocblas_datatype_bf16_r, + N, + static_cast(out_ptr), + rocblas_datatype_bf16_r, + N, + rocblas_datatype_f32_r, + algo, + solution_index, + 0); + if (status != rocblas_status_success && + algo == rocblas_gemm_algo_solution_index) { + solution_valid.store(false, std::memory_order_relaxed); + rocblas_gemm_ex( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + rocblas_datatype_bf16_r, + ld_b, + static_cast(a_ptr), + rocblas_datatype_bf16_r, + ld_a, + &beta_f, + static_cast(out_ptr), + rocblas_datatype_bf16_r, + N, + static_cast(out_ptr), + rocblas_datatype_bf16_r, + N, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, + 0, + 0); + } + break; + } + default: + throw std::runtime_error("Unsupported dtype for matmul on ROCm"); + } + }); +} + +void gemm_strided_batched_rocblas( + rocm::CommandEncoder& encoder, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + int64_t stride_a, + bool b_transposed, + int64_t ldb, + int64_t stride_b, + int64_t stride_c, + int batch_count, + array& out, + const array& a, + const array& b, + float alpha = 1.0f, + float beta = 0.0f) { + // Try hipBLASLt for bf16/fp16 batched GEMMs. + if ((a.dtype() == bfloat16 || a.dtype() == float16) && + rocm::is_hipblaslt_available()) { + try { + rocm::hipblaslt_gemm_batched( + encoder, + a_transposed, + b_transposed, + M, + N, + K, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + out, + N, // ldc = N for row-major output + stride_c, + batch_count, + a.dtype()); + return; + } catch (...) { + // hipBLASLt failed -- fall through to rocBLAS. + } + } + + auto& device = encoder.device(); + rocblas_handle handle = device.get_rocblas_handle(); + + rocblas_operation trans_a = + b_transposed ? rocblas_operation_transpose : rocblas_operation_none; + rocblas_operation trans_b = + a_transposed ? rocblas_operation_transpose : rocblas_operation_none; + + const int64_t ld_b = ldb; + const int64_t ld_a = lda; + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* out_ptr = gpu_ptr(out); + + encoder.launch_kernel([&, a_ptr, b_ptr, out_ptr](hipStream_t stream) { + encoder.device().set_rocblas_stream(stream); + + switch (a.dtype()) { + case float32: { + float alpha_f = alpha; + float beta_f = beta; + int solution_index = gemm_solution_index_f32(true); + static std::atomic solution_valid{true}; + + if (solution_index > 0 && + solution_valid.load(std::memory_order_relaxed)) { + rocblas_status status = rocblas_gemm_strided_batched_ex( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_f32_r, + ld_b, + stride_b, + a_ptr, + rocblas_datatype_f32_r, + ld_a, + stride_a, + &beta_f, + out_ptr, + rocblas_datatype_f32_r, + N, + stride_c, + out_ptr, + rocblas_datatype_f32_r, + N, + stride_c, + batch_count, + rocblas_datatype_f32_r, + rocblas_gemm_algo_solution_index, + solution_index, + 0); + if (status != rocblas_status_success) { + solution_valid.store(false, std::memory_order_relaxed); + rocblas_sgemm_strided_batched( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + ld_b, + stride_b, + static_cast(a_ptr), + ld_a, + stride_a, + &beta_f, + static_cast(out_ptr), + N, + stride_c, + batch_count); + } + } else { + rocblas_sgemm_strided_batched( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + ld_b, + stride_b, + static_cast(a_ptr), + ld_a, + stride_a, + &beta_f, + static_cast(out_ptr), + N, + stride_c, + batch_count); + } + break; + } + case float64: { + double alpha_d = static_cast(alpha); + double beta_d = static_cast(beta); + rocblas_dgemm_strided_batched( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_d, + static_cast(b_ptr), + ld_b, + stride_b, + static_cast(a_ptr), + ld_a, + stride_a, + &beta_d, + static_cast(out_ptr), + N, + stride_c, + batch_count); + break; + } + case float16: { + rocblas_half alpha_h, beta_h; + float16_t alpha_f16 = static_cast(alpha); + float16_t beta_f16 = static_cast(beta); + std::memcpy(&alpha_h, &alpha_f16, sizeof(rocblas_half)); + std::memcpy(&beta_h, &beta_f16, sizeof(rocblas_half)); + rocblas_hgemm_strided_batched( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_h, + reinterpret_cast( + static_cast(b_ptr)), + ld_b, + stride_b, + reinterpret_cast( + static_cast(a_ptr)), + ld_a, + stride_a, + &beta_h, + reinterpret_cast(static_cast(out_ptr)), + N, + stride_c, + batch_count); + break; + } + case bfloat16: { + float alpha_f = alpha; + float beta_f = beta; + int solution_index = gemm_solution_index_bf16(true); + static std::atomic solution_valid{true}; + + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + if (solution_index > 0 && + solution_valid.load(std::memory_order_relaxed)) { + algo = rocblas_gemm_algo_solution_index; + } else { + solution_index = 0; + } + + rocblas_status status = rocblas_gemm_strided_batched_ex( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + rocblas_datatype_bf16_r, + ld_b, + stride_b, + static_cast(a_ptr), + rocblas_datatype_bf16_r, + ld_a, + stride_a, + &beta_f, + static_cast(out_ptr), + rocblas_datatype_bf16_r, + N, + stride_c, + static_cast(out_ptr), + rocblas_datatype_bf16_r, + N, + stride_c, + batch_count, + rocblas_datatype_f32_r, + algo, + solution_index, + 0); + if (status != rocblas_status_success && + algo == rocblas_gemm_algo_solution_index) { + solution_valid.store(false, std::memory_order_relaxed); + rocblas_gemm_strided_batched_ex( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + rocblas_datatype_bf16_r, + ld_b, + stride_b, + static_cast(a_ptr), + rocblas_datatype_bf16_r, + ld_a, + stride_a, + &beta_f, + static_cast(out_ptr), + rocblas_datatype_bf16_r, + N, + stride_c, + static_cast(out_ptr), + rocblas_datatype_bf16_r, + N, + stride_c, + batch_count, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, + 0, + 0); + } + break; + } + default: + throw std::runtime_error( + "Unsupported dtype for batched matmul on ROCm"); + } + }); +} + +void gemm_and_bias( + rocm::CommandEncoder& encoder, + int M, + int N, + int K, + bool a_transposed, + int64_t lda, + bool b_transposed, + int64_t ldb, + array& out, + const array& a, + const array& b, + float alpha = 1.0f, + float beta = 0.0f) { + // Check and collapse batch dimensions + auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b); + + auto batch_count = out.size() / (M * N); + + // Collapse batches into M if needed + if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 && + a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K && + b_batch_strides.back() == 0) { + M *= batch_shape.back(); + batch_count = 1; + + a_batch_strides = {0}; + b_batch_strides = {0}; + batch_shape = {1}; + } + + // Use GEMV when possible + if (rocm::can_use_gemv(M, N, K, a_transposed, b_transposed)) { + rocm::gemv( + a, + b, + out, + M, + N, + K, + batch_count, + batch_shape, + a_batch_strides, + b_batch_strides, + encoder); + return; + } + + // Check if rocBLAS is available + bool use_rocblas = encoder.device().is_rocblas_available(); + auto [a_uniform_batch, a_uniform_stride] = + get_uniform_batch_stride(batch_shape, a_batch_strides); + auto [b_uniform_batch, b_uniform_stride] = + get_uniform_batch_stride(batch_shape, b_batch_strides); + + if (batch_count == 1) { + // Simple single GEMM + if (use_rocblas) { + gemm_rocblas( + encoder, + M, + N, + K, + a_transposed, + lda, + b_transposed, + ldb, + out, + a, + b, + alpha, + beta); + } else { + // Use naive GEMM fallback + rocm::naive_gemm( + encoder, + a, + b, + out, + M, + N, + K, + a_transposed, + lda, + b_transposed, + ldb, + alpha, + beta); + } + } else if (a_uniform_batch && b_uniform_batch) { + // Use strided batched GEMM for uniform batches + if (use_rocblas) { + gemm_strided_batched_rocblas( + encoder, + M, + N, + K, + a_transposed, + lda, + a_uniform_stride, + b_transposed, + ldb, + b_uniform_stride, + M * N, + batch_count, + out, + a, + b, + alpha, + beta); + } else { + // Use naive batched GEMM fallback + rocm::naive_gemm_batched( + encoder, + a, + b, + out, + M, + N, + K, + a_transposed, + lda, + a_uniform_stride, + b_transposed, + ldb, + b_uniform_stride, + M * N, + batch_count, + alpha, + beta); + } + } else { + // Fallback: loop over batches for non-uniform strides + if (use_rocblas) { + const void* a_ptr_base = gpu_ptr(a); + const void* b_ptr_base = gpu_ptr(b); + void* out_ptr_base = gpu_ptr(out); + for (int64_t batch = 0; batch < batch_count; ++batch) { + int64_t a_offset = 0, b_offset = 0; + int64_t batch_idx = batch; + for (int i = batch_shape.size() - 1; i >= 0; --i) { + int64_t idx = batch_idx % batch_shape[i]; + batch_idx /= batch_shape[i]; + a_offset += idx * a_batch_strides[i]; + b_offset += idx * b_batch_strides[i]; + } + + encoder.launch_kernel([&, + a_offset, + b_offset, + batch, + a_ptr_base, + b_ptr_base, + out_ptr_base](hipStream_t stream) { + auto& device = encoder.device(); + device.set_rocblas_stream(stream); + rocblas_handle handle = device.get_rocblas_handle(); + + rocblas_operation trans_a = b_transposed ? rocblas_operation_transpose + : rocblas_operation_none; + rocblas_operation trans_b = a_transposed ? rocblas_operation_transpose + : rocblas_operation_none; + + const int64_t ld_b = ldb; + const int64_t ld_a = lda; + + switch (a.dtype()) { + case float32: { + float alpha_f = alpha, beta_f = beta; + rocblas_sgemm( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_f, + static_cast(b_ptr_base) + b_offset, + ld_b, + static_cast(a_ptr_base) + a_offset, + ld_a, + &beta_f, + static_cast(out_ptr_base) + batch * M * N, + N); + break; + } + case float64: { + double alpha_d = static_cast(alpha); + double beta_d = static_cast(beta); + rocblas_dgemm( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_d, + static_cast(b_ptr_base) + b_offset, + ld_b, + static_cast(a_ptr_base) + a_offset, + ld_a, + &beta_d, + static_cast(out_ptr_base) + batch * M * N, + N); + break; + } + case float16: { + rocblas_half alpha_h, beta_h; + float16_t alpha_f16 = static_cast(alpha); + float16_t beta_f16 = static_cast(beta); + std::memcpy(&alpha_h, &alpha_f16, sizeof(rocblas_half)); + std::memcpy(&beta_h, &beta_f16, sizeof(rocblas_half)); + rocblas_hgemm( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_h, + reinterpret_cast( + static_cast(b_ptr_base) + b_offset), + ld_b, + reinterpret_cast( + static_cast(a_ptr_base) + a_offset), + ld_a, + &beta_h, + reinterpret_cast( + static_cast(out_ptr_base) + batch * M * N), + N); + break; + } + case bfloat16: { + float alpha_f = alpha; + float beta_f = beta; + auto* out_ptr = + static_cast(out_ptr_base) + batch * M * N; + rocblas_gemm_ex( + handle, + trans_a, + trans_b, + N, + M, + K, + &alpha_f, + static_cast(b_ptr_base) + b_offset, + rocblas_datatype_bf16_r, + ld_b, + static_cast(a_ptr_base) + a_offset, + rocblas_datatype_bf16_r, + ld_a, + &beta_f, + out_ptr, + rocblas_datatype_bf16_r, + N, + out_ptr, + rocblas_datatype_bf16_r, + N, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, + 0, + 0); + break; + } + default: + throw std::runtime_error( + "Unsupported dtype for non-uniform batched matmul on ROCm"); + } + }); + } + } else { + // Use naive GEMM for each batch when rocBLAS is not available + // This is less efficient but provides correctness + for (int64_t batch = 0; batch < batch_count; ++batch) { + int64_t a_offset = 0, b_offset = 0; + int64_t batch_idx = batch; + for (int i = batch_shape.size() - 1; i >= 0; --i) { + int64_t idx = batch_idx % batch_shape[i]; + batch_idx /= batch_shape[i]; + a_offset += idx * a_batch_strides[i]; + b_offset += idx * b_batch_strides[i]; + } + + // Use naive GEMM with explicit offsets + rocm::naive_gemm_with_offset( + encoder, + a, + b, + out, + M, + N, + K, + a_transposed, + lda, + a_offset, + b_transposed, + ldb, + b_offset, + batch * M * N, + alpha, + beta); + } + } + } +} + +} // namespace + +void Matmul::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + assert(inputs.size() == 2); + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + + // Return 0s if either input is empty. + if (a_pre.size() == 0 || b_pre.size() == 0) { + array zero(0, a_pre.dtype()); + encoder.add_temporary(zero); + fill_gpu(zero, out, s); + return; + } + + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + + int M = a_pre.shape(-2); + int N = b_pre.shape(-1); + int K = a_pre.shape(-1); + + auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); + auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); + + gemm_and_bias( + encoder, M, N, K, a_transposed, lda, b_transposed, ldb, out, a, b); +} + +void AddMM::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + assert(inputs.size() == 3); + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + auto c = inputs[2]; + + int M = a_pre.shape(-2); + int N = b_pre.shape(-1); + int K = a_pre.shape(-1); + + auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre); + auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre); + + // Copy C into out only when beta uses it. + if (beta_ != 0.0f) { + copy_gpu(c, out, CopyType::General, s); + } else { + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + } + + // Check if rocBLAS is available + if (encoder.device().is_rocblas_available()) { + // Do GEMM with alpha and beta + gemm_rocblas( + encoder, + M, + N, + K, + a_transposed, + lda, + b_transposed, + ldb, + out, + a, + b, + alpha_, + beta_); + } else { + // Use naive GEMM fallback + rocm::naive_gemm( + encoder, + a, + b, + out, + M, + N, + K, + a_transposed, + lda, + b_transposed, + ldb, + alpha_, + beta_); + } +} + +void GatherMM::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + assert(inputs.size() == 4); + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& lhs_indices = inputs[2]; + auto& rhs_indices = inputs[3]; + + // Return 0s if either input is empty. + if (a.size() == 0 || b.size() == 0) { + array zero(0, a.dtype()); + encoder.add_temporary(zero); + fill_gpu(zero, out, s); + return; + } + + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + + // Extract shapes from inputs. + int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + + auto [transposed_a, lda, a_] = check_transpose(encoder, s, a); + auto [transposed_b, ldb, b_] = check_transpose(encoder, s, b); + + auto use_gemv = rocm::can_use_gemv(M, N, K, transposed_a, transposed_b); + + if (M == 1 && use_gemv) { + rocm::gather_mv(b_, a_, rhs_indices, lhs_indices, out, N, K, encoder); + return; + } + + if (N == 1 && use_gemv) { + rocm::gather_mv(a_, b_, lhs_indices, rhs_indices, out, M, K, encoder); + return; + } + + // Keep gather indices on device and resolve per-batch matrix offsets inside + // the kernel to avoid host synchronization. + rocm::naive_gemm_gather( + encoder, + a_, + b_, + lhs_indices, + rhs_indices, + out, + M, + N, + K, + transposed_a, + lda, + transposed_b, + ldb, + 1.0f, + 0.0f); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/no_rocm.cpp b/mlx/backend/rocm/no_rocm.cpp new file mode 100644 index 0000000000..6b8628c842 --- /dev/null +++ b/mlx/backend/rocm/no_rocm.cpp @@ -0,0 +1,32 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/rocm.h" +#include "mlx/fast.h" + +namespace mlx::core { + +namespace rocm { + +bool is_available() { + return false; +} + +} // namespace rocm + +namespace fast { + +CustomKernelFunction hip_kernel( + const std::string&, + const std::vector&, + const std::vector&, + const std::string&, + const std::string&, + bool, + int, + std::vector>) { + throw std::runtime_error("[hip_kernel] No ROCm back-end."); +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/rocm/primitives.cpp b/mlx/backend/rocm/primitives.cpp new file mode 100644 index 0000000000..930e9a9cf1 --- /dev/null +++ b/mlx/backend/rocm/primitives.cpp @@ -0,0 +1,55 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/distributed/primitives.h" +#include "mlx/fast_primitives.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +#define NO_GPU_MULTI(func) \ + void func::eval_gpu( \ + const std::vector& inputs, std::vector& outputs) { \ + throw std::runtime_error(#func " has no ROCm implementation."); \ + } + +#define NO_GPU_USE_FALLBACK(func) \ + bool func::use_fallback(Stream s) { \ + return true; \ + } \ + NO_GPU_MULTI(func) + +#define NO_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + throw std::runtime_error(#func " has no ROCm implementation."); \ + } + +// Note: Convolution is now implemented in conv/conv.cpp +// Note: GatherMM is now implemented in matmul.cpp +// Note: QuantizedMatmul is now implemented in quantized/qmm.hip +// Note: GatherQMM is now implemented in quantized/qmm.hip + +NO_GPU(BlockMaskedMM) +NO_GPU(FFT) +NO_GPU(Hadamard) +NO_GPU_MULTI(LUF) +NO_GPU_MULTI(QRF) +NO_GPU(QQMatmul) +NO_GPU(SegmentedMM) +NO_GPU_MULTI(SVD) +NO_GPU(Inverse) +NO_GPU(Cholesky) +NO_GPU_MULTI(Eig) +NO_GPU_MULTI(Eigh) + +// Note: The following are now implemented in their respective files: +// - Load: load.cpp +// - CustomKernel: custom_kernel.cpp +// - ScaledDotProductAttention: scaled_dot_product_attention.cpp +// - ScaledDotProductAttentionVJP: scaled_dot_product_attention.cpp +// - Quantize: quantized/quantized.cpp +// - AffineQuantize: quantized/quantized.cpp +// - ConvertFP8: quantized/quantized.cpp +// - AllGather, AllReduce, ReduceScatter, Send, Recv: distributed.hip +// - Convolution: conv/conv.cpp + +} // namespace mlx::core diff --git a/mlx/backend/rocm/primitives.hip b/mlx/backend/rocm/primitives.hip new file mode 100644 index 0000000000..c91e36da3c --- /dev/null +++ b/mlx/backend/rocm/primitives.hip @@ -0,0 +1,21 @@ +// Copyright © 2025 Apple Inc. + +#include + +#include "mlx/backend/rocm/utils.h" +#include "mlx/backend/common/primitives.h" + +namespace mlx::core::rocm { + +// Basic kernel implementations will go here +// This is a placeholder for ROCm-specific primitive operations + +void add_hip() { + // Placeholder for HIP add operation +} + +void multiply_hip() { + // Placeholder for HIP multiply operation +} + +} // namespace mlx::core::rocm \ No newline at end of file diff --git a/mlx/backend/rocm/quantized/affine_quantize.hip b/mlx/backend/rocm/quantized/affine_quantize.hip new file mode 100644 index 0000000000..a34d1ffde2 --- /dev/null +++ b/mlx/backend/rocm/quantized/affine_quantize.hip @@ -0,0 +1,356 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/quantized/quantized.h" + +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +template +__global__ void affine_quantize_kernel( + const T* __restrict__ input, + uint8_t* __restrict__ output, + ScaleT* __restrict__ scales, + ScaleT* __restrict__ biases, + int num_groups, + int group_size) { + int group_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (group_idx >= num_groups) + return; + + const T* group_input = input + group_idx * group_size; + + // Find min and max in group + float min_val = static_cast(group_input[0]); + float max_val = static_cast(group_input[0]); + for (int i = 1; i < group_size; ++i) { + float val = static_cast(group_input[i]); + min_val = fminf(min_val, val); + max_val = fmaxf(max_val, val); + } + + // Compute scale and bias + float range = max_val - min_val; + float max_quant = static_cast((1 << BITS) - 1); + float scale = range / max_quant; + float bias = min_val; + + // Avoid division by zero + if (scale == 0.0f) { + scale = 1.0f; + } + + scales[group_idx] = static_cast(scale); + biases[group_idx] = static_cast(bias); + + // Quantize values + int output_idx = group_idx * (group_size * BITS / 8); + int group_bytes = group_size * BITS / 8; + constexpr uint8_t mask = static_cast((1u << BITS) - 1u); + + for (int i = 0; i < group_bytes; ++i) { + output[output_idx + i] = 0; + } + + for (int i = 0; i < group_size; ++i) { + float val = static_cast(group_input[i]); + int quant_val = static_cast((val - bias) / scale + 0.5f); + quant_val = max(0, min(static_cast(max_quant), quant_val)); + + int bit_index = i * BITS; + int byte_idx = output_idx + (bit_index >> 3); + int bit_offset = bit_index & 0x7; + uint32_t shifted = + static_cast(static_cast(quant_val) & mask) + << bit_offset; + + output[byte_idx] |= static_cast(shifted & 0xFF); + if (bit_offset + BITS > 8) { + output[byte_idx + 1] |= static_cast((shifted >> 8) & 0xFF); + } + } +} + +template +__global__ void affine_dequantize_kernel( + const uint8_t* __restrict__ input, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + T* __restrict__ output, + int num_groups, + int group_size) { + int group_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (group_idx >= num_groups) + return; + + float scale = static_cast(scales[group_idx]); + float bias = biases ? static_cast(biases[group_idx]) : 0.0f; + + int input_base = group_idx * (group_size * BITS / 8); + T* group_output = output + group_idx * group_size; + constexpr uint8_t mask = static_cast((1u << BITS) - 1u); + + for (int i = 0; i < group_size; ++i) { + int bit_index = i * BITS; + int byte_idx = input_base + (bit_index >> 3); + int bit_offset = bit_index & 0x7; + + uint32_t packed = static_cast(input[byte_idx]); + if (bit_offset + BITS > 8) { + packed |= static_cast(input[byte_idx + 1]) << 8; + } + + int quant_val = static_cast((packed >> bit_offset) & mask); + float dequant_val = static_cast(quant_val) * scale + bias; + group_output[i] = static_cast(dequant_val); + } +} + +// Optimized dequantize kernel for pack_factor elements at a time. +// RDNA 3.5 (gfx1151) with hipcc 7.13 / LLVM 23: Avoid #pragma unroll — the +// compiler emits incorrectly optimized vectorized stores that corrupt output. +// Use explicit scalar stores instead (same root cause as the uint4 load fix +// in qdequant.hpp). +template +__global__ void affine_dequantize_packed_kernel( + const uint8_t* __restrict__ input, + const T* __restrict__ scales, + const T* __restrict__ biases, + T* __restrict__ output, + size_t size, + int group_size) { + constexpr int pack_factor = 8 / BITS; + + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t oindex = idx * pack_factor; + + if (oindex >= size) { + return; + } + + size_t gindex = oindex / group_size; + float scale = static_cast(scales[gindex]); + float bias = biases ? static_cast(biases[gindex]) : 0.0f; + + uint8_t val = input[idx]; + + // Manual unroll with explicit scalar stores — avoids LLVM 23 codegen bug + // on RDNA 3.5 that corrupted #pragma unroll vectorized stores. + for (int i = 0; i < pack_factor; ++i) { + if (oindex + i >= size) + break; + uint8_t d; + if constexpr (BITS == 2) { + d = (val >> (BITS * i)) & 0x03; + } else if constexpr (BITS == 4) { + d = (val >> (BITS * i)) & 0x0f; + } else if constexpr (BITS == 8) { + d = val; + } + output[oindex + i] = static_cast(scale * static_cast(d) + bias); + } +} + +} // namespace rocm + +void affine_quantize( + const array& w, + array& wq, + array& scales, + array& biases, + int group_size, + int bits, + rocm::CommandEncoder& enc, + const Stream& s) { + int num_elements = w.size(); + int num_groups = num_elements / group_size; + + int block_size = 256; + int num_blocks = (num_groups + block_size - 1) / block_size; + + enc.set_input_array(w); + enc.set_output_array(wq); + enc.set_output_array(scales); + enc.set_output_array(biases); + +#define LAUNCH_QUANTIZE(T, ScaleT, BITS) \ + enc.add_kernel_node( \ + &rocm::affine_quantize_kernel, \ + dim3(num_blocks), \ + dim3(block_size), \ + 0u, \ + gpu_ptr(w), \ + gpu_ptr(wq), \ + gpu_ptr(scales), \ + gpu_ptr(biases), \ + num_groups, \ + group_size) + +#define DISPATCH_BITS(T, ScaleT) \ + switch (bits) { \ + case 2: \ + LAUNCH_QUANTIZE(T, ScaleT, 2); \ + break; \ + case 3: \ + LAUNCH_QUANTIZE(T, ScaleT, 3); \ + break; \ + case 4: \ + LAUNCH_QUANTIZE(T, ScaleT, 4); \ + break; \ + case 5: \ + LAUNCH_QUANTIZE(T, ScaleT, 5); \ + break; \ + case 6: \ + LAUNCH_QUANTIZE(T, ScaleT, 6); \ + break; \ + case 8: \ + LAUNCH_QUANTIZE(T, ScaleT, 8); \ + break; \ + default: \ + throw std::runtime_error("Unsupported bits for affine_quantize"); \ + } + switch (w.dtype()) { + case float32: + DISPATCH_BITS(float, float); + break; + case float16: + DISPATCH_BITS(__half, __half); + break; + case bfloat16: + DISPATCH_BITS(hip_bfloat16, hip_bfloat16); + break; + default: + throw std::runtime_error("Unsupported dtype for affine_quantize"); + } + +#undef DISPATCH_BITS +#undef LAUNCH_QUANTIZE +} + +void affine_dequantize( + const array& wq, + const array& scales, + const std::optional& biases, + array& w, + int group_size, + int bits, + rocm::CommandEncoder& enc, + const Stream& s) { + enc.set_input_array(wq); + enc.set_input_array(scales); + if (biases) + enc.set_input_array(*biases); + enc.set_output_array(w); + + // Use packed kernel for power-of-2 bits + if (bits == 2 || bits == 4 || bits == 8) { + int pack_factor = 8 / bits; + size_t size = w.size() / pack_factor; + + int block_size = 256; + int num_blocks = (size + block_size - 1) / block_size; + +#define LAUNCH_DEQUANTIZE_PACKED(T, BITS) \ + enc.add_kernel_node( \ + &rocm::affine_dequantize_packed_kernel, \ + dim3(num_blocks), \ + dim3(block_size), \ + 0u, \ + gpu_ptr(wq), \ + gpu_ptr(scales), \ + static_cast(biases ? gpu_ptr(*biases) : nullptr), \ + gpu_ptr(w), \ + w.size(), \ + group_size) + +#define DISPATCH_BITS_PACKED(T) \ + switch (bits) { \ + case 2: \ + LAUNCH_DEQUANTIZE_PACKED(T, 2); \ + break; \ + case 4: \ + LAUNCH_DEQUANTIZE_PACKED(T, 4); \ + break; \ + case 8: \ + LAUNCH_DEQUANTIZE_PACKED(T, 8); \ + break; \ + default: \ + break; \ + } + switch (w.dtype()) { + case float32: + DISPATCH_BITS_PACKED(float); + break; + case float16: + DISPATCH_BITS_PACKED(__half); + break; + case bfloat16: + DISPATCH_BITS_PACKED(hip_bfloat16); + break; + default: + throw std::runtime_error("Unsupported dtype for affine_dequantize"); + } + +#undef DISPATCH_BITS_PACKED +#undef LAUNCH_DEQUANTIZE_PACKED + } else { + // Fallback for non-power-of-2 bits (3, 5, 6) + int num_elements = w.size(); + int num_groups = num_elements / group_size; + + int block_size = 256; + int num_blocks = (num_groups + block_size - 1) / block_size; + +#define LAUNCH_DEQUANTIZE(T, ScaleT, BITS) \ + enc.add_kernel_node( \ + &rocm::affine_dequantize_kernel, \ + dim3(num_blocks), \ + dim3(block_size), \ + 0u, \ + gpu_ptr(wq), \ + gpu_ptr(scales), \ + static_cast(biases ? gpu_ptr(*biases) : nullptr), \ + gpu_ptr(w), \ + num_groups, \ + group_size) + +#define DISPATCH_BITS(T, ScaleT) \ + switch (bits) { \ + case 3: \ + LAUNCH_DEQUANTIZE(T, ScaleT, 3); \ + break; \ + case 5: \ + LAUNCH_DEQUANTIZE(T, ScaleT, 5); \ + break; \ + case 6: \ + LAUNCH_DEQUANTIZE(T, ScaleT, 6); \ + break; \ + default: \ + throw std::runtime_error("Unsupported bits for affine_dequantize"); \ + } + switch (w.dtype()) { + case float32: + DISPATCH_BITS(float, float); + break; + case float16: + DISPATCH_BITS(__half, __half); + break; + case bfloat16: + DISPATCH_BITS(hip_bfloat16, hip_bfloat16); + break; + default: + throw std::runtime_error("Unsupported dtype for affine_dequantize"); + } + +#undef DISPATCH_BITS +#undef LAUNCH_DEQUANTIZE + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/quantized/convert_fp8.hip b/mlx/backend/rocm/quantized/convert_fp8.hip new file mode 100644 index 0000000000..45751eade6 --- /dev/null +++ b/mlx/backend/rocm/quantized/convert_fp8.hip @@ -0,0 +1,179 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/fast_primitives.h" + +#include + +namespace mlx::core { + +namespace rocm { + +// FP8 E4M3 format: 1 sign bit, 4 exponent bits, 3 mantissa bits +// Range: [-448, 448], no inf, has NaN + +template +__device__ uint8_t float_to_fp8_e4m3(T val) { + float f = static_cast(val); + + // Handle special cases + if (isnan(f)) { + return 0x7F; // NaN in E4M3 + } + + uint32_t bits = __float_as_uint(f); + uint32_t sign = (bits >> 31) & 0x1; + int32_t exp = ((bits >> 23) & 0xFF) - 127; // Unbias from float + uint32_t mant = bits & 0x7FFFFF; + + // Clamp to E4M3 range + if (exp < -9) { // Underflow to zero + return sign << 7; + } + if (exp > 8) { // Overflow to max + return (sign << 7) | 0x7E; // Max normal value + } + + // Rebias for E4M3 (bias = 7) + int32_t new_exp = exp + 7; + + // Round mantissa to 3 bits (round to nearest, ties to even) + // We're discarding 20 bits, so add 0.5 ULP = 1 << 19 = 0x80000 + uint32_t new_mant = (mant + 0x80000) >> 20; + if (new_mant > 7) { + new_mant = 0; + new_exp++; + if (new_exp > 15) { + return (sign << 7) | 0x7E; // Overflow + } + } + + if (new_exp <= 0) { + // Denormal handling + int shift = 1 - new_exp; + new_mant = ((mant | 0x800000) >> (20 + shift)); + new_exp = 0; + } + + return (sign << 7) | ((new_exp & 0xF) << 3) | (new_mant & 0x7); +} + +template +__device__ T fp8_e4m3_to_float(uint8_t val) { + uint32_t sign = (val >> 7) & 0x1; + uint32_t exp = (val >> 3) & 0xF; + uint32_t mant = val & 0x7; + + float result; + if (exp == 0) { + if (mant == 0) { + result = 0.0f; + } else { + // Denormal: value = mant * 2^(-9) + result = ldexpf(static_cast(mant), -9); + } + } else if (exp == 15 && mant == 7) { + // NaN + result = __uint_as_float(0x7FC00000); + } else { + // Normal: value = (1 + mant/8) * 2^(exp-7) + uint32_t float_exp = exp - 7 + 127; + uint32_t float_mant = mant << 20; + uint32_t bits = (sign << 31) | (float_exp << 23) | float_mant; + result = __uint_as_float(bits); + } + + return static_cast(sign ? -fabsf(result) : result); +} + +template +__global__ void to_fp8_kernel(const InT* in, OutT* out, size_t size) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) return; + + out[idx] = float_to_fp8_e4m3(in[idx]); +} + +template +__global__ void from_fp8_kernel(const InT* in, OutT* out, size_t size) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= size) return; + + out[idx] = fp8_e4m3_to_float(in[idx]); +} + +} // namespace rocm + +void fast::ConvertFP8::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& d = rocm::device(s.device); + auto& enc = d.get_command_encoder(s); + + const auto& in = inputs[0]; + auto& out = outputs[0]; + + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), enc)); + + size_t size = in.size(); + int block_size = 256; + int num_blocks = (size + block_size - 1) / block_size; + + enc.set_input_array(in); + enc.set_output_array(out); + + if (to_fp8_) { + // Convert to FP8 + switch (in.dtype()) { + case float32: + enc.add_kernel_node( + &rocm::to_fp8_kernel, + dim3(num_blocks), dim3(block_size), 0u, + gpu_ptr(in), gpu_ptr(out), size); + break; + case float16: + enc.add_kernel_node( + &rocm::to_fp8_kernel<__half, uint8_t>, + dim3(num_blocks), dim3(block_size), 0u, + gpu_ptr<__half>(in), gpu_ptr(out), size); + break; + case bfloat16: + enc.add_kernel_node( + &rocm::to_fp8_kernel, + dim3(num_blocks), dim3(block_size), 0u, + gpu_ptr(in), gpu_ptr(out), size); + break; + default: + throw std::runtime_error("Unsupported input type for ConvertFP8 (to_fp8)"); + } + } else { + // Convert from FP8 + switch (out.dtype()) { + case float32: + enc.add_kernel_node( + &rocm::from_fp8_kernel, + dim3(num_blocks), dim3(block_size), 0u, + gpu_ptr(in), gpu_ptr(out), size); + break; + case float16: + enc.add_kernel_node( + &rocm::from_fp8_kernel, + dim3(num_blocks), dim3(block_size), 0u, + gpu_ptr(in), gpu_ptr<__half>(out), size); + break; + case bfloat16: + enc.add_kernel_node( + &rocm::from_fp8_kernel, + dim3(num_blocks), dim3(block_size), 0u, + gpu_ptr(in), gpu_ptr(out), size); + break; + default: + throw std::runtime_error("Unsupported output type for ConvertFP8 (from_fp8)"); + } + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/quantized/fp_quantize.hip b/mlx/backend/rocm/quantized/fp_quantize.hip new file mode 100644 index 0000000000..f2c076d57b --- /dev/null +++ b/mlx/backend/rocm/quantized/fp_quantize.hip @@ -0,0 +1,306 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/quantized/quantized.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/kernel_utils.hpp" + +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +template +__global__ void fp_quantize_kernel( + const T* __restrict__ input, + uint8_t* __restrict__ output, + ScaleT* __restrict__ scales, + int num_groups, + int group_size) { + int group_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (group_idx >= num_groups) return; + + const T* group_input = input + group_idx * group_size; + + // Find max absolute value in group (use float for computation) + float max_abs = fabsf(static_cast(group_input[0])); + for (int i = 1; i < group_size; ++i) { + max_abs = fmaxf(max_abs, fabsf(static_cast(group_input[i]))); + } + + // Compute scale (symmetric quantization) + float max_quant = static_cast((1 << (BITS - 1)) - 1); + float scale = max_abs / max_quant; + + // Avoid division by zero + if (scale == 0.0f) { + scale = 1.0f; + } + + scales[group_idx] = static_cast(scale); + + // Quantize values + int output_idx = group_idx * (group_size * BITS / 8); + int group_bytes = group_size * BITS / 8; + constexpr uint8_t mask = static_cast((1u << BITS) - 1u); + + for (int i = 0; i < group_bytes; ++i) { + output[output_idx + i] = 0; + } + + int8_t min_val = -(1 << (BITS - 1)); + int8_t max_val = (1 << (BITS - 1)) - 1; + + for (int i = 0; i < group_size; ++i) { + float val = static_cast(group_input[i]); + int quant_val = static_cast(roundf(val / scale)); + quant_val = max(static_cast(min_val), min(static_cast(max_val), quant_val)); + + int bit_index = i * BITS; + int byte_idx = output_idx + (bit_index >> 3); + int bit_offset = bit_index & 0x7; + + uint32_t shifted = + static_cast(static_cast(quant_val) & mask) + << bit_offset; + + output[byte_idx] |= static_cast(shifted & 0xFF); + if (bit_offset + BITS > 8) { + output[byte_idx + 1] |= static_cast((shifted >> 8) & 0xFF); + } + } +} + +template +__global__ void fp_dequantize_kernel( + const uint8_t* __restrict__ input, + const ScaleT* __restrict__ scales, + T* __restrict__ output, + int num_groups, + int group_size) { + int group_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (group_idx >= num_groups) return; + + float scale = static_cast(scales[group_idx]); + + int input_base = group_idx * (group_size * BITS / 8); + T* group_output = output + group_idx * group_size; + constexpr uint8_t mask = static_cast((1u << BITS) - 1u); + constexpr uint8_t sign_bit = static_cast(1u << (BITS - 1)); + + for (int i = 0; i < group_size; ++i) { + int bit_index = i * BITS; + int byte_idx = input_base + (bit_index >> 3); + int bit_offset = bit_index & 0x7; + + uint32_t packed = static_cast(input[byte_idx]); + if (bit_offset + BITS > 8) { + packed |= static_cast(input[byte_idx + 1]) << 8; + } + uint8_t uval = static_cast((packed >> bit_offset) & mask); + + // Convert back to signed + int8_t quant_val; + if (uval & sign_bit) { + quant_val = static_cast(uval | ~mask); + } else { + quant_val = static_cast(uval); + } + + group_output[i] = static_cast(static_cast(quant_val) * scale); + } +} + +// Optimized packed dequantize kernel +template +__global__ void fp_dequantize_packed_kernel( + const uint8_t* __restrict__ input, + const T* __restrict__ scales, + T* __restrict__ output, + size_t size, + int group_size) { + constexpr int pack_factor = 8 / BITS; + + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + size_t oindex = idx * pack_factor; + + if (oindex >= size) { + return; + } + + size_t gindex = oindex / group_size; + float scale = static_cast(scales[gindex]); + + uint8_t val = input[idx]; + uint8_t mask = (1 << BITS) - 1; + uint8_t sign_bit = static_cast(1 << (BITS - 1)); + + #pragma unroll + for (int i = 0; i < pack_factor; ++i) { + uint8_t uval = (val >> (BITS * i)) & mask; + + // Convert to signed + int8_t quant_val; + if (uval & sign_bit) { + quant_val = static_cast(uval | ~mask); + } else { + quant_val = static_cast(uval); + } + + output[oindex + i] = static_cast(static_cast(quant_val) * scale); + } +} + +} // namespace rocm + +void fp_quantize( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + rocm::CommandEncoder& enc, + const Stream& s) { + int num_elements = w.size(); + int num_groups = num_elements / group_size; + + int block_size = 256; + int num_blocks = (num_groups + block_size - 1) / block_size; + + enc.set_input_array(w); + enc.set_output_array(wq); + enc.set_output_array(scales); + + #define LAUNCH_FP_QUANTIZE(T, ScaleT, BITS) \ + enc.add_kernel_node( \ + &rocm::fp_quantize_kernel, \ + dim3(num_blocks), dim3(block_size), 0u, \ + gpu_ptr(w), gpu_ptr(wq), gpu_ptr(scales), \ + num_groups, group_size) + + #define DISPATCH_BITS(T, ScaleT) \ + switch (bits) { \ + case 2: LAUNCH_FP_QUANTIZE(T, ScaleT, 2); break; \ + case 3: LAUNCH_FP_QUANTIZE(T, ScaleT, 3); break; \ + case 4: LAUNCH_FP_QUANTIZE(T, ScaleT, 4); break; \ + case 5: LAUNCH_FP_QUANTIZE(T, ScaleT, 5); break; \ + case 6: LAUNCH_FP_QUANTIZE(T, ScaleT, 6); break; \ + case 8: LAUNCH_FP_QUANTIZE(T, ScaleT, 8); break; \ + default: throw std::runtime_error("Unsupported bits for fp_quantize"); \ + } + + switch (w.dtype()) { + case float32: + DISPATCH_BITS(float, float); + break; + case float16: + DISPATCH_BITS(__half, __half); + break; + case bfloat16: + DISPATCH_BITS(hip_bfloat16, hip_bfloat16); + break; + default: + throw std::runtime_error("Unsupported dtype for fp_quantize"); + } + + #undef DISPATCH_BITS + #undef LAUNCH_FP_QUANTIZE +} + +void fp_dequantize( + const array& wq, + const array& scales, + array& w, + int group_size, + int bits, + rocm::CommandEncoder& enc, + const Stream& s) { + + enc.set_input_array(wq); + enc.set_input_array(scales); + enc.set_output_array(w); + + // Use packed kernel for power-of-2 bits + if (bits == 2 || bits == 4 || bits == 8) { + int pack_factor = 8 / bits; + size_t size = w.size() / pack_factor; + + int block_size = 256; + int num_blocks = (size + block_size - 1) / block_size; + + #define LAUNCH_FP_DEQUANTIZE_PACKED(T, BITS) \ + enc.add_kernel_node( \ + &rocm::fp_dequantize_packed_kernel, \ + dim3(num_blocks), dim3(block_size), 0u, \ + gpu_ptr(wq), gpu_ptr(scales), gpu_ptr(w), \ + w.size(), group_size) + + #define DISPATCH_BITS_PACKED(T) \ + switch (bits) { \ + case 2: LAUNCH_FP_DEQUANTIZE_PACKED(T, 2); break; \ + case 4: LAUNCH_FP_DEQUANTIZE_PACKED(T, 4); break; \ + case 8: LAUNCH_FP_DEQUANTIZE_PACKED(T, 8); break; \ + default: break; \ + } + + switch (w.dtype()) { + case float32: + DISPATCH_BITS_PACKED(float); + break; + case float16: + DISPATCH_BITS_PACKED(__half); + break; + case bfloat16: + DISPATCH_BITS_PACKED(hip_bfloat16); + break; + default: + throw std::runtime_error("Unsupported dtype for fp_dequantize"); + } + + #undef DISPATCH_BITS_PACKED + #undef LAUNCH_FP_DEQUANTIZE_PACKED + } else { + // Fallback for non-power-of-2 bits + int num_elements = w.size(); + int num_groups = num_elements / group_size; + + int block_size = 256; + int num_blocks = (num_groups + block_size - 1) / block_size; + + #define LAUNCH_FP_DEQUANTIZE(T, ScaleT, BITS) \ + enc.add_kernel_node( \ + &rocm::fp_dequantize_kernel, \ + dim3(num_blocks), dim3(block_size), 0u, \ + gpu_ptr(wq), gpu_ptr(scales), gpu_ptr(w), \ + num_groups, group_size) + + #define DISPATCH_BITS(T, ScaleT) \ + switch (bits) { \ + case 3: LAUNCH_FP_DEQUANTIZE(T, ScaleT, 3); break; \ + case 5: LAUNCH_FP_DEQUANTIZE(T, ScaleT, 5); break; \ + case 6: LAUNCH_FP_DEQUANTIZE(T, ScaleT, 6); break; \ + default: throw std::runtime_error("Unsupported bits for fp_dequantize"); \ + } + + switch (w.dtype()) { + case float32: + DISPATCH_BITS(float, float); + break; + case float16: + DISPATCH_BITS(__half, __half); + break; + case bfloat16: + DISPATCH_BITS(hip_bfloat16, hip_bfloat16); + break; + default: + throw std::runtime_error("Unsupported dtype for fp_dequantize"); + } + + #undef DISPATCH_BITS + #undef LAUNCH_FP_DEQUANTIZE + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/quantized/qdequant.hpp b/mlx/backend/rocm/quantized/qdequant.hpp new file mode 100644 index 0000000000..32dcfb4dca --- /dev/null +++ b/mlx/backend/rocm/quantized/qdequant.hpp @@ -0,0 +1,182 @@ +// Shared dequantization utilities for optimized QMM kernels. +// Used by qmv_kernel.hip (GEMV) and qmm_kernel.hip (GEMM). + +#pragma once + +#include +#include +#include +#include "mlx/backend/rocm/device/config.h" + +namespace mlx::core::rocm { + +// --- Compile-time constants --- + +// Number of quantized values packed per uint32 word. +// 4-bit: 8 values, 2-bit: 16 values, 8-bit: 4 values. +template +inline constexpr int pack_factor_u32 = 32 / BITS; + +// Number of uint32 words each thread loads per K-iteration. +// Chosen so that values_per_thread = 16 for all bit widths. +template +inline constexpr int packs_per_thread = 16 / pack_factor_u32; +// 4-bit: 16/8=2, 2-bit: 16/16=1, 8-bit: 16/4=4 + +// Number of quantized values each thread processes per K-iteration. +template +inline constexpr int values_per_thread = 16; + +// Number of K-elements consumed per warp per iteration. +// = values_per_thread * WARP_SIZE = 16 * 32 = 512 +inline constexpr int block_size_k = values_per_thread<4> * WARP_SIZE; + +// Number of output rows computed per thread block. +inline constexpr int ROWS_PER_BLOCK = 8; + +// --- Warp reduction --- + +__device__ __forceinline__ float warp_reduce_sum(float val) { +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val += __shfl_xor(val, offset); + } + return val; +} + +// --- Dequant-and-dot: integer dot product + x-sum accumulation --- +// +// Metal-compatible accumulation: accumulates raw integer dot product and +// x-sum separately. The caller applies scale and bias ONCE per group: +// result += scale * total_qdot + bias * total_xsum +// +// This matches Metal's qdot() which returns scale * accum + sum * bias, +// where accum and sum span all values_per_thread elements at once. +// +// The naive per-element form `acc += x[i] * (scale * q[i] + bias)` is +// mathematically equivalent but produces different float32 rounding due to +// a different number of scale/bias multiply operations, causing LLM output +// to degenerate into repetitive loops after ~10 tokens. + +template +__device__ __forceinline__ void dequant_and_dot( + uint32_t packed, + const float* __restrict__ x_local, + float& qdot_acc, + float& x_sum) { + constexpr int pf = pack_factor_u32; + constexpr uint32_t mask = (1u << BITS) - 1u; + +#pragma unroll + for (int i = 0; i < pf; i++) { + float q = static_cast((packed >> (i * BITS)) & mask); + qdot_acc += x_local[i] * q; + x_sum += x_local[i]; + } +} + +// GEMV variant: 4 independent qdot partials (dual-issue-friendly). Caller reduces +// them and applies scale/bias once per group — same result as dequant_and_dot. +template +__device__ __forceinline__ void dequant_and_dot4( + uint32_t packed, + const float* __restrict__ x_local, + float (&qdot)[4], + float& x_sum) { + constexpr int pf = pack_factor_u32; + constexpr uint32_t mask = (1u << BITS) - 1u; + +#pragma unroll + for (int i = 0; i < pf; i++) { + float q = static_cast((packed >> (i * BITS)) & mask); + qdot[i & 3] += x_local[i] * q; + x_sum += x_local[i]; + } +} + +__device__ __forceinline__ float reduce_qdot4(const float (&qdot)[4]) { + return (qdot[0] + qdot[1]) + (qdot[2] + qdot[3]); +} + +// --- Vectorized weight load --- +// +// Loads PPT uint32 words in a single wide memory transaction instead of +// PPT scalar loads. For 4-bit (PPT=2), emits global_load_dwordx2 (64-bit). +// For 8-bit (PPT=4), emits global_load_dwordx4 (128-bit). +// Pointer must be naturally aligned (8-byte for uint2, 16-byte for uint4). + +template +__device__ __forceinline__ void load_weight_vec( + const uint32_t* __restrict__ ptr, + uint32_t (&out)[packs_per_thread]) { + constexpr int PPT = packs_per_thread; + if constexpr (PPT == 2) { + uint2 v = *reinterpret_cast(ptr); + out[0] = v.x; + out[1] = v.y; + } else if constexpr (PPT == 4) { + // Two uint2 loads instead of one uint4. The single-uint4 load + // (global_load_b128) miscomputes in the 8-bit affine QMV/gather paths + // (root cause: HIP_vector_type codegen on RDNA 3.5 with + // hipcc 7.13 / LLVM 23). Two paired global_load_b64 ops yield the same + // throughput on RDNA 3.5 without the miscompile. + uint2 v0 = *reinterpret_cast(ptr); + uint2 v1 = *reinterpret_cast(ptr + 2); + out[0] = v0.x; + out[1] = v0.y; + out[2] = v1.x; + out[3] = v1.y; + } else { +#pragma unroll + for (int p = 0; p < PPT; p++) { + out[p] = ptr[p]; + } + } +} + +// Non-temporal weight load for GEMV: weights are read once, so emit streaming +// (slc) loads that bypass L2, leaving it for the reused X/scales. GEMV-only. +template +__device__ __forceinline__ void load_weight_vec_streaming( + const uint32_t* __restrict__ ptr, + uint32_t (&out)[packs_per_thread]) { + constexpr int PPT = packs_per_thread; +#pragma unroll + for (int p = 0; p < PPT; p++) { + out[p] = __builtin_nontemporal_load(ptr + p); + } +} + +// --- Type conversion helpers --- + +__device__ __forceinline__ float to_float(__half x) { + return __half2float(x); +} + +__device__ __forceinline__ float to_float(hip_bfloat16 x) { + return static_cast(x); +} + +__device__ __forceinline__ float to_float(float x) { + return x; +} + +template +__device__ __forceinline__ T from_float(float x); + +template <> +__device__ __forceinline__ __half from_float<__half>(float x) { + return __float2half(x); +} + +template <> +__device__ __forceinline__ hip_bfloat16 from_float(float x) { + return hip_bfloat16(x); +} + +template <> +__device__ __forceinline__ float from_float(float x) { + return x; +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/quantized/qmm.hip b/mlx/backend/rocm/quantized/qmm.hip new file mode 100644 index 0000000000..ae2079d396 --- /dev/null +++ b/mlx/backend/rocm/quantized/qmm.hip @@ -0,0 +1,6488 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/gemms/naive_gemm.h" +#include "mlx/backend/rocm/gemms/hipblaslt_gemm.h" +#include "mlx/backend/rocm/gemms/rocblas_gemm.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/quantized/quantized.h" +#include "mlx/backend/rocm/quantized/qmv_tiled_kernel.hip" +#include "mlx/primitives.h" + +#include +#include +#include +#include +#include +#include +#include +// rocWMMA is only supported on CDNA (gfx9xx) and RDNA 3+ (gfx11xx, gfx12xx). +// Guard the include so it doesn't trigger static_assert on RDNA 1/2 (gfx10xx). +// During host compilation __HIP_DEVICE_COMPILE__ is 0 so rocwmma defines +// ROCWMMA_ARCH_HOST and compiles fine. During device compilation for +// unsupported architectures like gfx1030 the header would static_assert. +#if !defined(__HIP_DEVICE_COMPILE__) || !__HIP_DEVICE_COMPILE__ || \ + defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) || \ + defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ + defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__) || \ + defined(__gfx1200__) || defined(__gfx1201__) +#define ROCM_HAS_WMMA 1 +#include +#else +#define ROCM_HAS_WMMA 0 +#endif +#include +#include +#include +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +// Strided 2D row-copy kernel: copies rows from a source with row_stride != cols +// into a contiguous destination. +// src layout: row i starts at src + i * src_row_stride (elements contiguous within row) +// dst layout: row i starts at dst + i * cols (fully contiguous) +// +// When both row strides and cols_bytes are 4-byte aligned, uses uint32_t +// copies (one 4-byte word per thread iteration) for good throughput without +// alignment concerns. Falls back to byte-by-byte for the non-aligned tail. +__global__ void strided_row_copy_kernel( + const char* __restrict__ src, + char* __restrict__ dst, + int64_t num_rows, + int64_t cols_bytes, + int64_t src_row_stride_bytes, + int64_t dst_row_stride_bytes, + bool use_word_copy) { + int64_t tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t grid_stride = static_cast(blockDim.x) * gridDim.x; + + if (use_word_copy) { + // Fast path: 4-byte word copies. All row strides are 4-byte aligned. + constexpr int64_t WORD = 4; + int64_t cols_words = cols_bytes / WORD; + int64_t total_words = num_rows * cols_words; + for (int64_t i = tid; i < total_words; i += grid_stride) { + int64_t row = i / cols_words; + int64_t word_in_row = i % cols_words; + int64_t src_off = row * src_row_stride_bytes + word_in_row * WORD; + int64_t dst_off = row * dst_row_stride_bytes + word_in_row * WORD; + *reinterpret_cast(dst + dst_off) = + *reinterpret_cast(src + src_off); + } + // Handle remainder bytes (cols_bytes % 4) + int64_t remainder_start = cols_words * WORD; + int64_t remainder_bytes = cols_bytes - remainder_start; + if (remainder_bytes > 0) { + for (int64_t i = tid; i < num_rows * remainder_bytes; i += grid_stride) { + int64_t row = i / remainder_bytes; + int64_t byte_in_tail = i % remainder_bytes; + int64_t src_off = row * src_row_stride_bytes + remainder_start + byte_in_tail; + int64_t dst_off = row * dst_row_stride_bytes + remainder_start + byte_in_tail; + dst[dst_off] = src[src_off]; + } + } + } else { + // Slow path: byte-by-byte copy for non-aligned strides. + int64_t total_bytes = num_rows * cols_bytes; + for (int64_t i = tid; i < total_bytes; i += grid_stride) { + int64_t row = i / cols_bytes; + int64_t byte_in_row = i % cols_bytes; + int64_t src_off = row * src_row_stride_bytes + byte_in_row; + int64_t dst_off = row * dst_row_stride_bytes + byte_in_row; + dst[dst_off] = src[src_off]; + } + } +} + +// General strided copy kernel with strides passed as kernel arguments +// (by-value hip_array structs). Avoids device memory allocation + +// hipMemcpyAsync overhead that contiguous_copy_gpu -> copy_general_input +// would incur. Falls back to contiguous_copy_gpu only for ndim > MAX_NDIM. +__global__ void strided_general_copy_kernel( + const char* __restrict__ src, + char* __restrict__ dst, + int64_t total_elems, + int elem_bytes, + int ndim, + hip_array shapes, + hip_array strides_bytes) { + int64_t tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t grid_stride = static_cast(blockDim.x) * gridDim.x; + for (int64_t idx = tid; idx < total_elems; idx += grid_stride) { + // Convert linear index to strided source offset + int64_t src_offset = 0; + int64_t remaining = idx; + for (int d = ndim - 1; d >= 0; --d) { + int64_t coord = remaining % shapes[d]; + remaining /= shapes[d]; + src_offset += coord * strides_bytes[d]; + } + // Copy element bytes -- specialize for common QMM element sizes + int64_t dst_offset = idx * elem_bytes; + if (elem_bytes == 2) { + *reinterpret_cast(dst + dst_offset) = + *reinterpret_cast(src + src_offset); + } else if (elem_bytes == 4) { + *reinterpret_cast(dst + dst_offset) = + *reinterpret_cast(src + src_offset); + } else if (elem_bytes == 1) { + dst[dst_offset] = src[src_offset]; + } else if (elem_bytes == 8) { + *reinterpret_cast(dst + dst_offset) = + *reinterpret_cast(src + src_offset); + } else { + for (int b = 0; b < elem_bytes; ++b) { + dst[dst_offset + b] = src[src_offset + b]; + } + } + } +} + +} // namespace rocm + +namespace { + +template +struct local_type_identity { + using type = T; +}; + +// Fast contiguous-copy helper for QMM inputs. +// +// Design goals vs the previous implementation (which called contiguous_copy_gpu +// unconditionally when strides didn't match row-major): +// +// 1. **Already contiguous** -- return immediately (unchanged). +// +// 2. **Inner-contiguous with outer stride gap** -- the most common +// non-contiguous pattern from `take` / `gather_sort`. The inner N-1 +// dimensions are packed (stride-1 on the last dim, products match for +// the rest), but the outermost dimension has a stride larger than the +// product of inner shapes. We handle this with a single +// `strided_row_copy_kernel` launch -- no device memory allocation for +// shapes/strides, no hipMemcpyAsync. One kernel dispatch total. +// +// 3. **General non-contiguous** (rare for QMM inputs) -- uses +// `strided_general_copy_kernel` which takes shapes and strides as +// kernel arguments (up to QMM_COPY_MAX_DIMS dimensions). This avoids +// the 2x allocator::malloc + 2x hipMemcpyAsync that +// `contiguous_copy_gpu -> copy_general_input` would issue. One kernel +// dispatch total. Falls back to `contiguous_copy_gpu` only for arrays +// with more than MAX_NDIM (10) dimensions (extremely unlikely for +// QMM operands). +// +// Net effect: non-contiguous copies go from 5 GPU operations (2 allocs + +// 2 memcpy + 1 kernel) down to 1 kernel launch. +inline array ensure_row_contiguous_matrix( + const array& x, + rocm::CommandEncoder& enc, + const Stream& s) { + if (x.ndim() == 0) { + return x; + } + + // --- Fast path 1: already row-major contiguous --- + int ndim = x.ndim(); + const auto& strides = x.strides(); + bool row_major_contiguous = true; + int64_t expected_stride = 1; + // Track the innermost contiguous dimensions while checking. + // If we break at dimension i, dimensions [i+1 .. ndim-1] are packed. + int first_noncontig_dim = -1; + for (int i = ndim - 1; i >= 0; --i) { + if (x.shape(i) > 1) { + if (strides[i] != expected_stride) { + row_major_contiguous = false; + first_noncontig_dim = i; + break; + } + expected_stride *= x.shape(i); + } + } + + if (row_major_contiguous) { + return x; + } + + // Empty arrays don't need copying. + if (x.size() == 0) { + return x; + } + + size_t elem_bytes = x.itemsize(); + + // Helper: allocate a contiguous output array and return src/dst pointers. + // Deferred until we know a copy is actually needed and which path to use. + auto make_output = [&]() -> array { + array out(x.shape(), x.dtype(), nullptr, {}); + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), enc)); + enc.add_temporary(out); + return out; + }; + + // --- Fast path 2: inner-contiguous, only outermost dim has a stride gap --- + // This covers the common case where x comes from take/gather of a [E, K] + // or [B, M, K] array -- inner dims are packed, outer dim stride > product. + // We also handle the case where the gap is at any single dimension (not + // just dim 0) as long as all dimensions below it are packed. + if (first_noncontig_dim >= 0) { + // Verify that all dimensions below first_noncontig_dim are packed, + // and only first_noncontig_dim itself has a non-standard stride. + // Dimensions above first_noncontig_dim (if any) must also be consistent + // with first_noncontig_dim's layout. + bool is_simple_outer_gap = true; + // Check: first_noncontig_dim's stride must be >= expected_stride + // (i.e. the inner block is correct, just spaced further apart). + if (strides[first_noncontig_dim] < expected_stride) { + is_simple_outer_gap = false; + } + // Check dimensions above first_noncontig_dim: their strides must be + // consistent with first_noncontig_dim's stride * shape products. + if (is_simple_outer_gap) { + int64_t outer_expected = strides[first_noncontig_dim] * x.shape(first_noncontig_dim); + for (int i = first_noncontig_dim - 1; i >= 0; --i) { + if (x.shape(i) <= 1) continue; + if (strides[i] != outer_expected) { + is_simple_outer_gap = false; + break; + } + outer_expected *= x.shape(i); + } + } + + if (is_simple_outer_gap && first_noncontig_dim == 0) { + // Simplest case: only the outermost dim has extra stride. + // inner_size = product of shapes[1..ndim-1] + array x_copy = make_output(); + const char* src = reinterpret_cast(gpu_ptr(x)); + char* dst = reinterpret_cast(gpu_ptr(x_copy)); + + int64_t inner_size = 1; + for (int i = 1; i < ndim; ++i) { + inner_size *= x.shape(i); + } + int64_t num_rows = x.shape(0); + int64_t cols_bytes = inner_size * static_cast(elem_bytes); + int64_t src_row_stride_bytes = strides[0] * static_cast(elem_bytes); + int64_t dst_row_stride_bytes = cols_bytes; + bool word_copy = (cols_bytes % 4 == 0) && + (src_row_stride_bytes % 4 == 0) && + (dst_row_stride_bytes % 4 == 0); + + int block_size = 256; + int64_t work_items = word_copy + ? num_rows * (cols_bytes / 4) + : num_rows * cols_bytes; + int num_blocks = static_cast( + std::min((work_items + block_size - 1) / block_size, 65535)); + + enc.set_input_array(x); + enc.set_output_array(x_copy); + enc.add_kernel_node( + &rocm::strided_row_copy_kernel, + dim3(num_blocks), dim3(block_size), 0u, + src, dst, + num_rows, cols_bytes, + src_row_stride_bytes, dst_row_stride_bytes, + word_copy); + return x_copy; + } + + if (is_simple_outer_gap) { + // Gap at an interior dimension. batch_count == 1 is common here. + int64_t batch_count = 1; + for (int i = 0; i < first_noncontig_dim; ++i) { + batch_count *= x.shape(i); + } + if (batch_count == 1) { + array x_copy = make_output(); + const char* src = reinterpret_cast(gpu_ptr(x)); + char* dst = reinterpret_cast(gpu_ptr(x_copy)); + + int64_t inner_size = 1; + for (int i = first_noncontig_dim + 1; i < ndim; ++i) { + inner_size *= x.shape(i); + } + int64_t slab_rows = x.shape(first_noncontig_dim); + int64_t cols_bytes = inner_size * static_cast(elem_bytes); + int64_t src_row_stride_bytes = strides[first_noncontig_dim] * static_cast(elem_bytes); + int64_t dst_row_stride_bytes = cols_bytes; + bool word_copy = (cols_bytes % 4 == 0) && + (src_row_stride_bytes % 4 == 0) && + (dst_row_stride_bytes % 4 == 0); + + int block_size = 256; + int64_t work_items = word_copy + ? slab_rows * (cols_bytes / 4) + : slab_rows * cols_bytes; + int num_blocks = static_cast( + std::min((work_items + block_size - 1) / block_size, 65535)); + + enc.set_input_array(x); + enc.set_output_array(x_copy); + enc.add_kernel_node( + &rocm::strided_row_copy_kernel, + dim3(num_blocks), dim3(block_size), 0u, + src, dst, + slab_rows, cols_bytes, + src_row_stride_bytes, dst_row_stride_bytes, + word_copy); + return x_copy; + } + // batch_count > 1 with interior gap: fall through to general path + } + } + + // --- Fast path 3: general non-contiguous, strides as kernel args --- + // Handles arbitrary stride patterns with up to MAX_NDIM dimensions. + // Shapes and byte-strides are passed as hip_array structs (by value), + // so no device memory allocation or hipMemcpyAsync is needed. + // One kernel launch total. + if (ndim <= MAX_NDIM) { + array x_copy = make_output(); + const char* src = reinterpret_cast(gpu_ptr(x)); + char* dst = reinterpret_cast(gpu_ptr(x_copy)); + + int64_t total_elems = x.size(); + int eb = static_cast(elem_bytes); + + int block_size = 256; + int num_blocks = static_cast( + std::min((total_elems + block_size - 1) / block_size, 65535)); + + // Pack into hip_array structs that can be passed by value to the kernel. + rocm::hip_array shapes_arg = {}; + rocm::hip_array strides_bytes_arg = {}; + for (int i = 0; i < ndim; ++i) { + shapes_arg.data_[i] = x.shape(i); + strides_bytes_arg.data_[i] = strides[i] * static_cast(elem_bytes); + } + + enc.set_input_array(x); + enc.set_output_array(x_copy); + enc.add_kernel_node( + &rocm::strided_general_copy_kernel, + dim3(num_blocks), dim3(block_size), 0u, + src, dst, + total_elems, eb, ndim, + shapes_arg, strides_bytes_arg); + return x_copy; + } + + // --- Fallback: ndim > MAX_NDIM (extremely rare for QMM) --- + // Use the generic copy infrastructure which allocates device buffers + // for shape/strides arrays (2 allocs + 2 hipMemcpyAsync + 1 kernel). + array x_copy = contiguous_copy_gpu(x, s); + enc.add_temporary(x_copy); + return x_copy; +} + +inline int parse_cols_per_block_env(const char* env_name) { + const char* raw = std::getenv(env_name); + if (raw == nullptr || *raw == '\0') { + return 0; + } + + char* end = nullptr; + long value = std::strtol(raw, &end, 10); + if (end == raw || *end != '\0') { + return 0; + } + + return (value == 4 || value == 8 || value == 16 || value == 32 || value == 64) + ? static_cast(value) + : 0; +} + +inline int parse_threads_per_col_env(const char* env_name) { + const char* raw = std::getenv(env_name); + if (raw == nullptr || *raw == '\0') { + return 0; + } + + char* end = nullptr; + long value = std::strtol(raw, &end, 10); + if (end == raw || *end != '\0') { + return 0; + } + + return (value == 16 || value == WARP_SIZE) ? static_cast(value) : 0; +} + +inline bool parse_warp_kernel_env(const char* env_name, bool default_value) { + const char* raw = std::getenv(env_name); + if (raw == nullptr || *raw == '\0') { + return default_value; + } + + if (raw[0] == '0' && raw[1] == '\0') { + return false; + } + if (raw[0] == '1' && raw[1] == '\0') { + return true; + } + return default_value; +} + +inline int parse_positive_int_env(const char* env_name, int default_value) { + const char* raw = std::getenv(env_name); + if (raw == nullptr || *raw == '\0') { + return default_value; + } + + char* end = nullptr; + long value = std::strtol(raw, &end, 10); + if (end == raw || *end != '\0' || value <= 0) { + return default_value; + } + return static_cast(value); +} + +inline size_t parse_non_negative_size_t_env( + const char* env_name, + size_t default_value) { + const char* raw = std::getenv(env_name); + if (raw == nullptr || *raw == '\0') { + return default_value; + } + + char* end = nullptr; + unsigned long long value = std::strtoull(raw, &end, 10); + if (end == raw || *end != '\0') { + return default_value; + } + return static_cast(value); +} + +inline int parse_non_negative_int_env(const char* env_name, int default_value) { + const char* raw = std::getenv(env_name); + if (raw == nullptr || *raw == '\0') { + return default_value; + } + + char* end = nullptr; + long value = std::strtol(raw, &end, 10); + if (end == raw || *end != '\0' || value < 0) { + return default_value; + } + return static_cast(value); +} + +// Check if rocBLAS dequant fast path should be used +// Default ON +inline bool use_rocblas_dequant_path() { + static bool checked = false; + static bool enabled = true; + if (!checked) { + const char* raw = std::getenv("MLX_ROCM_QMM_DEQUANT_GEMM"); + if (raw != nullptr) { + enabled = (raw[0] == '1' && raw[1] == '\0'); + } + checked = true; + } + return enabled; +} + +inline bool has_only_singleton_batch_dims(const array& x) { + if (x.ndim() <= 2) { + return true; + } + for (int i = 0; i < x.ndim() - 2; ++i) { + if (x.shape(i) != 1) { + return false; + } + } + return true; +} + +inline int select_qmv_cols_per_block(int K, int N, int bits) { + int env_cols = parse_cols_per_block_env("MLX_ROCM_QMV_COLS_PER_BLOCK"); + if (env_cols > 0) { + return env_cols; + } + + (void)K; + + if (N < 256) { + return 4; + } + if (K <= 1024) { + return (N < 1024) ? 8 : 16; + } + if (bits == 8) { + if (N < 1024) { + return 8; + } + if (N < 4096) { + return 32; + } + return 16; + } + if (N < 1024) { + return 8; + } + return 16; +} + +inline int select_qmv_threads_per_col(int K, int N, int bits, int batch_count) { + // On RDNA 3.5 (wave32), 16 threads per column gives better occupancy + // than 32 for most LLM decode shapes. 32 threads only helps for very + // large K where the extra parallelism in the reduction outweighs the + // reduced block count. + int threads_per_col = 16; + if (WARP_SIZE == 32) { + bool quant_bits_supported = + (bits == 2 || bits == 4 || bits == 5 || bits == 6 || bits == 8); + // On RDNA 3.5 (40 CUs / 20 WGPs), 16 threads/col allows 2 columns + // per warp, increasing memory-level parallelism for decode. Only use + // full warp (32) for extreme K where reduction parallelism dominates. + bool extreme = (batch_count == 1) && (K >= 16384); + if (quant_bits_supported && extreme) { + threads_per_col = WARP_SIZE; + } + } + return threads_per_col; +} + +// Use shared arch detection from config.h (RocmArchTier, ArchTuning). +// Local aliases for backward compatibility within this file. +using RocmQmvArchTier = rocm::RocmArchTier; + +inline rocm::HWInfo detect_rocm_hw_info(rocm::Device& d) { + static std::mutex hw_mutex; + static std::unordered_map hw_cache; + + int hip_device = d.hip_device(); + { + std::lock_guard lock(hw_mutex); + auto it = hw_cache.find(hip_device); + if (it != hw_cache.end()) { + return it->second; + } + } + + hipDeviceProp_t props{}; + d.make_current(); + hipError_t err = hipGetDeviceProperties(&props, hip_device); + + rocm::HWInfo hw{}; + hw.tier = (WARP_SIZE == 32) ? RocmQmvArchTier::Rdna2 : RocmQmvArchTier::Cdna; + + if (err == hipSuccess) { + hw.num_cus = props.multiProcessorCount; + hw.max_threads_per_cu = props.maxThreadsPerMultiProcessor; + hw.shared_mem_per_cu = props.sharedMemPerBlock; + hw.l2_cache_bytes = props.l2CacheSize; + + const char* arch_name = props.gcnArchName; + if (arch_name != nullptr) { + if (std::strstr(arch_name, "gfx1200") != nullptr || + std::strstr(arch_name, "gfx1201") != nullptr) { + hw.tier = RocmQmvArchTier::Rdna4; + hw.simds_per_cu = 2; + } else if (std::strstr(arch_name, "gfx1150") != nullptr || + std::strstr(arch_name, "gfx1151") != nullptr || + std::strstr(arch_name, "gfx1152") != nullptr) { + hw.tier = RocmQmvArchTier::Rdna35; + hw.simds_per_cu = 2; + } else if (std::strstr(arch_name, "gfx11") != nullptr) { + hw.tier = RocmQmvArchTier::Rdna3; + hw.simds_per_cu = 2; + } else if (std::strstr(arch_name, "gfx10") != nullptr) { + hw.tier = RocmQmvArchTier::Rdna2; + hw.simds_per_cu = 2; + } else if (std::strstr(arch_name, "gfx9") != nullptr) { + hw.tier = RocmQmvArchTier::Cdna; + hw.simds_per_cu = 4; + } + + // rocWMMA library arch allowlist (AMD's official support matrix). + // CDNA1/2/3 use MFMA under rocwmma; RDNA3 dGPU + gfx1151 + RDNA4 use + // hardware WMMA. Excludes gfx1103/1150/1152 and all gfx10xx (RDNA1/2). + static const char* const kRocwmmaArches[] = { + "gfx908", "gfx90a", "gfx942", + "gfx1100", "gfx1101", "gfx1102", + "gfx1151", + "gfx1200", "gfx1201", + }; + for (const char* a : kRocwmmaArches) { + if (std::strstr(arch_name, a) != nullptr) { + hw.has_native_wmma = true; + break; + } + } + } + } + + { + std::lock_guard lock(hw_mutex); + hw_cache[hip_device] = hw; + } + return hw; +} + +inline RocmQmvArchTier detect_rocm_qmv_arch_tier(rocm::Device& d) { + return detect_rocm_hw_info(d).tier; +} + +inline int select_qmv_qmm_crossover_m_threshold( + int K, + int N, + int batch_count, + bool transpose, + bool can_use_batched_qmv, + rocm::Device& d) { + if (!transpose) { + return 1; + } + if ((batch_count > 1) && !can_use_batched_qmv) { + return 1; + } + + int small_shape_limit; + int medium_shape_limit; + int large_shape_limit; + + auto tuning = rocm::get_arch_tuning(detect_rocm_qmv_arch_tier(d)); + small_shape_limit = tuning.qmv_crossover_small; + medium_shape_limit = tuning.qmv_crossover_medium; + large_shape_limit = tuning.qmv_crossover_large; + + if (batch_count > 1 && can_use_batched_qmv) { + small_shape_limit += 8; + medium_shape_limit += 6; + large_shape_limit += 4; + } + + if (K <= 2048 && N <= 2048) { + return small_shape_limit; + } + if (K <= 4096 && N <= 4096) { + return medium_shape_limit; + } + return large_shape_limit; +} + +inline bool should_use_tiny_k_qmv_path( + int M, + int N, + int K, + int batch_count, + bool transpose, + bool can_use_batched_qmv, + int bits, + QuantizationMode mode) { + if (!transpose || can_use_batched_qmv || batch_count != 1) { + return false; + } + + bool bits_supported = (bits == 2 || bits == 4 || bits == 8) || + (mode == QuantizationMode::Affine && (bits == 5 || bits == 6)); + if (!bits_supported) { + return false; + } + + bool tiny_k = (K == 64 || K == 128 || K == 256); + bool decode_like = (M <= 4); + bool width_enough = (N >= 512); + return tiny_k && decode_like && width_enough; +} + +inline bool is_aligned_ptr(const void* ptr, size_t align) { + if (ptr == nullptr || align == 0) { + return false; + } + auto addr = reinterpret_cast(ptr); + return (addr % align) == 0; +} + +inline bool has_packed_layout_compatibility_for_aligned_qmv(int K, int bits) { + switch (bits) { + case 8: + return (K % 16) == 0; + case 6: + return (K % 64) == 0; + case 4: + return (K % 32) == 0; + case 2: + return (K % 64) == 0; + default: + return false; + } +} + +inline bool should_use_alignment_qmv_noshared_path( + int M, + int N, + int K, + int batch_count, + bool transpose, + bool can_use_batched_qmv, + int bits, + QuantizationMode mode, + const void* x_ptr, + const void* w_ptr, + const void* scales_ptr, + const void* biases_ptr, + bool has_bias) { + if (!transpose || can_use_batched_qmv || batch_count != 1) { + return false; + } + + bool bits_supported = (bits == 2 || bits == 4 || bits == 8) || + (mode == QuantizationMode::Affine && bits == 6); + if (!bits_supported) { + return false; + } + if (!has_packed_layout_compatibility_for_aligned_qmv(K, bits)) { + return false; + } + + bool decode_like = (M <= 8); + bool width_enough = (N >= 1024); + if (!decode_like || !width_enough) { + return false; + } + + bool pointers_aligned = is_aligned_ptr(x_ptr, 16) && + is_aligned_ptr(w_ptr, 16) && is_aligned_ptr(scales_ptr, 16); + if (has_bias) { + pointers_aligned = pointers_aligned && is_aligned_ptr(biases_ptr, 16); + } + return pointers_aligned; +} + +inline bool should_use_dequant_gemm_path( + int M, + int N, + int K, + int batch_count, + bool non_batched, + bool transpose, + bool can_use_batched_qmv, + rocm::Device& d) { + int env_threshold = + parse_positive_int_env("MLX_ROCM_QMM_DEQUANT_M_THRESHOLD", -1); + if (env_threshold > 0) { + return M >= env_threshold; + } + + if (!transpose) { + return true; + } + + if (batch_count > 1) { + if (!can_use_batched_qmv) { + return true; + } + } + + if (!non_batched) { + return M >= select_qmv_qmm_crossover_m_threshold( + K, N, batch_count, transpose, can_use_batched_qmv, d); + } + + int threshold = select_qmv_qmm_crossover_m_threshold( + K, N, batch_count, transpose, can_use_batched_qmv, d); + + if (M >= threshold) { + return true; + } + + // Favor dequant+GEMM slightly earlier on very large decode-style shapes. + if (N >= 8192 && K >= 2048) { + return M >= std::max(8, threshold - 4); + } + return false; +} + +struct DequantCacheKey { + std::uintptr_t w_ptr; + std::uintptr_t scales_ptr; + std::uintptr_t biases_ptr; + int group_size; + int bits; + int stream_index; + bool transpose; + Dtype dtype; + + bool operator==(const DequantCacheKey& other) const { + return w_ptr == other.w_ptr && scales_ptr == other.scales_ptr && + biases_ptr == other.biases_ptr && group_size == other.group_size && + bits == other.bits && stream_index == other.stream_index && + transpose == other.transpose && dtype == other.dtype; + } +}; + +struct DequantCacheKeyHasher { + size_t operator()(const DequantCacheKey& key) const { + size_t h = std::hash{}(key.w_ptr); + h ^= std::hash{}(key.scales_ptr) + 0x9e3779b9 + (h << 6) + + (h >> 2); + h ^= std::hash{}(key.biases_ptr) + 0x9e3779b9 + (h << 6) + + (h >> 2); + h ^= std::hash{}(key.group_size) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(key.bits) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(key.stream_index) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(static_cast(key.transpose)) + 0x9e3779b9 + + (h << 6) + (h >> 2); + h ^= std::hash{}(static_cast(key.dtype.val())) + 0x9e3779b9 + + (h << 6) + (h >> 2); + return h; + } +}; + +struct DequantCacheEntry { + array weight; + array w_source; + array scales_source; + std::optional biases_source; + size_t bytes; + std::list::iterator lru_it; +}; + +inline int dequant_cache_capacity() { + static int capacity = []() { + const char* raw = std::getenv("MLX_ROCM_QMM_DEQUANT_CACHE_SIZE"); + if (raw == nullptr || *raw == '\0') { + return 8; + } + char* end = nullptr; + long value = std::strtol(raw, &end, 10); + if (end == raw || *end != '\0' || value < 0) { + return 8; + } + return static_cast(value); + }(); + return capacity; +} + +inline size_t dequant_cache_max_bytes() { + static size_t max_bytes = parse_non_negative_size_t_env( + "MLX_ROCM_QMM_DEQUANT_CACHE_MAX_BYTES", 256ULL * 1024ULL * 1024ULL); + return max_bytes; +} + +inline int qmm_gemm_solution_index_f32(bool batched) { + static int single_index = + parse_non_negative_int_env("MLX_ROCM_GEMM_F32_SOLUTION_INDEX", 0); + static int batched_index = parse_non_negative_int_env( + "MLX_ROCM_GEMM_F32_BATCHED_SOLUTION_INDEX", -1); + if (!batched) { + return single_index; + } + return batched_index >= 0 ? batched_index : single_index; +} + +inline int qmm_gemm_solution_index_bf16(bool batched) { + static int single_index = + parse_non_negative_int_env("MLX_ROCM_GEMM_BF16_SOLUTION_INDEX", 0); + static int batched_index = parse_non_negative_int_env( + "MLX_ROCM_GEMM_BF16_BATCHED_SOLUTION_INDEX", -1); + if (!batched) { + return single_index; + } + return batched_index >= 0 ? batched_index : single_index; +} + +inline rocblas_operation to_rocblas_op(bool transpose) { + return transpose ? rocblas_operation_transpose : rocblas_operation_none; +} + +// --- fp8 (e4m3) GEMM path for RDNA4 ----------------------------------------- +// The half-precision activation and dequantized weight are cast to e4m3 with a +// per-tensor absmax scale, multiplied on fp8 matrix cores (~1.5-2.4x bf16), and +// descaled back to bf16 by hipBLASLt. fp8 buffers are raw bytes (no MLX dtype). + +constexpr float kE4M3Max = 448.0f; + +template +__global__ void fp8_absmax_kernel( + const T* __restrict__ src, size_t n, float* __restrict__ amax) { + __shared__ float sm[256]; + float local = 0.0f; + for (size_t j = blockIdx.x * blockDim.x + threadIdx.x; j < n; + j += static_cast(gridDim.x) * blockDim.x) { + local = fmaxf(local, fabsf(static_cast(src[j]))); + } + sm[threadIdx.x] = local; + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sm[threadIdx.x] = fmaxf(sm[threadIdx.x], sm[threadIdx.x + s]); + } + __syncthreads(); + } + if (threadIdx.x == 0) { + // Non-negative floats compare monotonically as ints. + atomicMax(reinterpret_cast(amax), __float_as_int(sm[0])); + } +} + +// Writes the descale factor (amax/448) hipBLASLt multiplies back into the +// output. Guards the all-zero case. +__global__ void fp8_descale_kernel( + const float* __restrict__ amax, float* __restrict__ descale) { + float a = *amax; + *descale = (a > 0.0f) ? (a / kE4M3Max) : 1.0f; +} + +template +__global__ void fp8_cast_kernel( + const T* __restrict__ src, + size_t n, + const float* __restrict__ amax, + __hip_fp8_e4m3* __restrict__ dst) { + float a = *amax; + float inv = (a > 0.0f) ? (kE4M3Max / a) : 0.0f; + for (size_t j = blockIdx.x * blockDim.x + threadIdx.x; j < n; + j += static_cast(gridDim.x) * blockDim.x) { + dst[j] = __hip_fp8_e4m3(static_cast(src[j]) * inv); + } +} + +template +void fp8_quantize( + hipStream_t stream, + const T* src, + size_t n, + float* amax, + float* descale, + __hip_fp8_e4m3* dst) { + int threads = 256; + int blocks = static_cast(std::min((n + threads - 1) / threads, 4096)); + fp8_absmax_kernel<<>>(src, n, amax); + fp8_descale_kernel<<<1, 1, 0, stream>>>(amax, descale); + fp8_cast_kernel<<>>(src, n, amax, dst); +} + +// Per-tensor absmax of the dequantized affine weight, computed from the quant +// params alone (dequant values are linear in q, so the extrema are at q=0 and +// q=qmax). Reads only scales/biases (group_size x fewer elements than the +// weight) — no full-weight pass. +template +__global__ void weight_absmax_kernel( + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + int num_groups, + float* __restrict__ absmax) { + __shared__ float sm[256]; + float local = 0.0f; + constexpr int qmax = (1 << BITS) - 1; + for (int g = blockIdx.x * blockDim.x + threadIdx.x; g < num_groups; + g += static_cast(gridDim.x) * blockDim.x) { + float s = static_cast(scales[g]); + float b = biases ? static_cast(biases[g]) : 0.0f; + local = fmaxf(local, fmaxf(fabsf(b), fabsf(s * qmax + b))); + } + sm[threadIdx.x] = local; + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + sm[threadIdx.x] = fmaxf(sm[threadIdx.x], sm[threadIdx.x + s]); + } + __syncthreads(); + } + if (threadIdx.x == 0) { + atomicMax(reinterpret_cast(absmax), __float_as_int(sm[0])); + } +} + +// Dequantize packed affine weights straight to e4m3 (no bf16 intermediate), +// scaled by 448/absmax. No ROCm/HIP library consumes packed quant, so this is +// the necessary hand-rolled step. +template +__global__ void dequant_to_e4m3_kernel( + const uint8_t* __restrict__ input, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + __hip_fp8_e4m3* __restrict__ output, + int num_groups, + int group_size, + const float* __restrict__ absmax) { + int group_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (group_idx >= num_groups) + return; + float inv = (*absmax > 0.0f) ? (kE4M3Max / *absmax) : 0.0f; + float scale = static_cast(scales[group_idx]); + float bias = biases ? static_cast(biases[group_idx]) : 0.0f; + int input_base = group_idx * (group_size * BITS / 8); + __hip_fp8_e4m3* group_output = output + group_idx * group_size; + constexpr uint8_t mask = static_cast((1u << BITS) - 1u); + for (int i = 0; i < group_size; ++i) { + int bit_index = i * BITS; + int byte_idx = input_base + (bit_index >> 3); + int bit_offset = bit_index & 0x7; + uint32_t packed = static_cast(input[byte_idx]); + if (bit_offset + BITS > 8) { + packed |= static_cast(input[byte_idx + 1]) << 8; + } + int qv = static_cast((packed >> bit_offset) & mask); + float dq = static_cast(qv) * scale + bias; + group_output[i] = __hip_fp8_e4m3(dq * inv); + } +} + +// out[M,N] = x[M,K] @ w^T via e4m3 hipBLASLt. The weight is dequantized +// straight from packed quant to e4m3; the activation is cast to e4m3; hipBLASLt +// descales both back to bf16. transpose mirrors the dequant_rocblas_gemm +// convention (transpose_a=false, transpose_b=transpose). +void dequant_fp8_gemm( + rocm::CommandEncoder& enc, + bool transpose, + int M, + int N, + int K, + const array& x, + const array& wq, + const array& scales, + const std::optional& biases, + int group_size, + int bits, + array& out, + int dq_rows, + int dq_cols) { + array x_fp8(Shape{M, K}, uint8, nullptr, {}); + x_fp8.set_data(mlx::core::rocm::malloc_async(x_fp8.nbytes(), enc)); + array w_fp8(Shape{dq_rows, dq_cols}, uint8, nullptr, {}); + w_fp8.set_data(mlx::core::rocm::malloc_async(w_fp8.nbytes(), enc)); + array scratch(Shape{4}, float32, nullptr, {}); + scratch.set_data(mlx::core::rocm::malloc_async(scratch.nbytes(), enc)); + enc.add_temporary(x_fp8); + enc.add_temporary(w_fp8); + enc.add_temporary(scratch); + + enc.set_input_array(x); + enc.set_input_array(wq); + enc.set_input_array(scales); + if (biases) + enc.set_input_array(*biases); + enc.set_output_array(out); + + rocblas_operation op_a = to_rocblas_op(false); + rocblas_operation op_b = to_rocblas_op(transpose); + int lda = K; + int ldb = transpose ? K : N; + Dtype xdt = x.dtype(); + Dtype sdt = scales.dtype(); + bool has_bias = biases.has_value(); + size_t nw = static_cast(dq_rows) * dq_cols; + int wgroups = static_cast(nw / group_size); + + enc.launch_kernel([=, &enc](hipStream_t stream) { + const void* xp = gpu_ptr(x); + const void* wqp = gpu_ptr(wq); + const void* sp = gpu_ptr(scales); + const void* bp = has_bias ? gpu_ptr(*biases) : nullptr; + void* op = gpu_ptr(out); + float* amax_x = reinterpret_cast(gpu_ptr(scratch)); + float* desc_x = amax_x + 1; + float* amax_w = amax_x + 2; + float* desc_w = amax_x + 3; + auto* xf = reinterpret_cast<__hip_fp8_e4m3*>(gpu_ptr(x_fp8)); + auto* wf = reinterpret_cast<__hip_fp8_e4m3*>(gpu_ptr(w_fp8)); + size_t nx = static_cast(M) * K; + (void)hipMemsetAsync(amax_x, 0, sizeof(float), stream); + (void)hipMemsetAsync(amax_w, 0, sizeof(float), stream); + + if (xdt == bfloat16) { + fp8_quantize<__hip_bfloat16>( + stream, static_cast(xp), nx, amax_x, desc_x, xf); + } else { + fp8_quantize<__half>( + stream, static_cast(xp), nx, amax_x, desc_x, xf); + } + + int threads = 256; + int blocks = std::max(1, (wgroups + threads - 1) / threads); + int absblocks = std::min(blocks, 4096); +#define LAUNCH_DEQUANT_E4M3(ScaleT, BITS) \ + do { \ + weight_absmax_kernel<<>>( \ + static_cast(sp), \ + static_cast(bp), \ + wgroups, \ + amax_w); \ + fp8_descale_kernel<<<1, 1, 0, stream>>>(amax_w, desc_w); \ + dequant_to_e4m3_kernel<<>>( \ + static_cast(wqp), \ + static_cast(sp), \ + static_cast(bp), \ + wf, \ + wgroups, \ + group_size, \ + amax_w); \ + } while (0) +#define DISPATCH_BITS_E4M3(ScaleT) \ + switch (bits) { \ + case 2: \ + LAUNCH_DEQUANT_E4M3(ScaleT, 2); \ + break; \ + case 4: \ + LAUNCH_DEQUANT_E4M3(ScaleT, 4); \ + break; \ + case 5: \ + LAUNCH_DEQUANT_E4M3(ScaleT, 5); \ + break; \ + case 6: \ + LAUNCH_DEQUANT_E4M3(ScaleT, 6); \ + break; \ + case 8: \ + LAUNCH_DEQUANT_E4M3(ScaleT, 8); \ + break; \ + default: \ + break; \ + } + if (sdt == bfloat16) { + DISPATCH_BITS_E4M3(__hip_bfloat16); + } else if (sdt == float16) { + DISPATCH_BITS_E4M3(__half); + } else { + DISPATCH_BITS_E4M3(float); + } +#undef DISPATCH_BITS_E4M3 +#undef LAUNCH_DEQUANT_E4M3 + + // Column-major swap: A<-w, B<-x, M<->N (same as the bf16 dequant path). + rocm::hipblaslt_gemm_fp8_raw( + stream, + static_cast(op_b), + static_cast(op_a), + N, + M, + K, + wf, + ldb, + xf, + lda, + op, + N, + desc_w, + desc_x); + }); +} + +void dequant_rocblas_gemm( + rocm::CommandEncoder& enc, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + const array& b, + int ldb, + float beta, + array& c, + int ldc, + Dtype dtype) { + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* c_ptr = gpu_ptr(c); + + enc.launch_kernel([&, a_ptr, b_ptr, c_ptr](hipStream_t stream) { + enc.device().set_rocblas_stream(stream); + rocblas_handle handle = enc.device().get_rocblas_handle(); + + rocblas_operation op_a = to_rocblas_op(transpose_a); + rocblas_operation op_b = to_rocblas_op(transpose_b); + + // Prefer hipBLASLt for all supported dtypes. rocBLAS/Tensile ships + // incomplete gfx1201 coverage (missing large-tile kernels such as + // MT128x128x32) and faults on large prefill GEMMs; hipBLASLt covers them. + if (rocm::is_hipblaslt_available() && + (dtype == float32 || dtype == float16 || dtype == bfloat16)) { + int dt_hint = (dtype == float16) ? 1 : (dtype == bfloat16) ? 2 : 3; + float alpha_f = alpha; + float beta_f = beta; + try { + rocm::hipblaslt_gemm_raw( + stream, + static_cast(op_b), static_cast(op_a), + N, M, K, + &alpha_f, b_ptr, ldb, a_ptr, lda, + &beta_f, c_ptr, ldc, + dt_hint, 0); + return; + } catch (...) { + // Fall through to rocBLAS below. + } + } + + switch (dtype) { + case float32: { + float alpha_f = alpha; + float beta_f = beta; + int solution_index = qmm_gemm_solution_index_f32(false); + static std::atomic solution_valid{true}; + + if (solution_index > 0 && + solution_valid.load(std::memory_order_relaxed)) { + rocblas_status status = rocblas_gemm_ex( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_f32_r, + ldb, + a_ptr, + rocblas_datatype_f32_r, + lda, + &beta_f, + c_ptr, + rocblas_datatype_f32_r, + ldc, + c_ptr, + rocblas_datatype_f32_r, + ldc, + rocblas_datatype_f32_r, + rocblas_gemm_algo_solution_index, + solution_index, + 0); + if (status != rocblas_status_success) { + solution_valid.store(false, std::memory_order_relaxed); + rocblas_sgemm( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + ldb, + static_cast(a_ptr), + lda, + &beta_f, + static_cast(c_ptr), + ldc); + } + } else { + rocblas_sgemm( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + ldb, + static_cast(a_ptr), + lda, + &beta_f, + static_cast(c_ptr), + ldc); + } + break; + } + case float16: { + rocblas_half alpha_h, beta_h; + __half alpha_f16 = static_cast<__half>(alpha); + __half beta_f16 = static_cast<__half>(beta); + std::memcpy(&alpha_h, &alpha_f16, sizeof(rocblas_half)); + std::memcpy(&beta_h, &beta_f16, sizeof(rocblas_half)); + rocblas_hgemm( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_h, + reinterpret_cast(b_ptr), + ldb, + reinterpret_cast(a_ptr), + lda, + &beta_h, + reinterpret_cast(c_ptr), + ldc); + break; + } + case bfloat16: { + float alpha_f = alpha; + float beta_f = beta; + + // Try hipBLASLt first for bf16 GEMMs — often faster on RDNA 3.5/CDNA + if (rocm::is_hipblaslt_available()) { + try { + // data_type=0 means "use bfloat16", impl maps internally + rocm::hipblaslt_gemm_raw( + stream, + static_cast(op_b), static_cast(op_a), + N, M, K, + &alpha_f, b_ptr, ldb, a_ptr, lda, + &beta_f, c_ptr, ldc, + 2, // 2 = bfloat16 (mapped in impl) + 0); // unused + break; + } catch (...) { + // Fall through to rocBLAS + } + } + + int solution_index = qmm_gemm_solution_index_bf16(false); + static std::atomic solution_valid{true}; + + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + if (solution_index > 0 && + solution_valid.load(std::memory_order_relaxed)) { + algo = rocblas_gemm_algo_solution_index; + } else { + solution_index = 0; + } + + rocblas_status status = rocblas_gemm_ex( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_bf16_r, + ldb, + a_ptr, + rocblas_datatype_bf16_r, + lda, + &beta_f, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + rocblas_datatype_f32_r, + algo, + solution_index, + 0); + + if (status != rocblas_status_success && + algo == rocblas_gemm_algo_solution_index) { + solution_valid.store(false, std::memory_order_relaxed); + rocblas_gemm_ex( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_bf16_r, + ldb, + a_ptr, + rocblas_datatype_bf16_r, + lda, + &beta_f, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, + 0, + 0); + } + break; + } + default: + throw std::runtime_error("Unsupported dtype for rocBLAS GEMM"); + } + }); +} + +void dequant_rocblas_gemm_batched( + rocm::CommandEncoder& enc, + bool transpose_a, + bool transpose_b, + int M, + int N, + int K, + float alpha, + const array& a, + int lda, + int64_t stride_a, + const array& b, + int ldb, + int64_t stride_b, + float beta, + array& c, + int ldc, + int64_t stride_c, + int batch_count, + Dtype dtype) { + const void* a_ptr = gpu_ptr(a); + const void* b_ptr = gpu_ptr(b); + void* c_ptr = gpu_ptr(c); + + enc.launch_kernel([&, a_ptr, b_ptr, c_ptr](hipStream_t stream) { + enc.device().set_rocblas_stream(stream); + rocblas_handle handle = enc.device().get_rocblas_handle(); + + rocblas_operation op_a = to_rocblas_op(transpose_a); + rocblas_operation op_b = to_rocblas_op(transpose_b); + + switch (dtype) { + case float32: { + float alpha_f = alpha; + float beta_f = beta; + int solution_index = qmm_gemm_solution_index_f32(true); + static std::atomic solution_valid{true}; + + if (solution_index > 0 && + solution_valid.load(std::memory_order_relaxed)) { + rocblas_status status = rocblas_gemm_strided_batched_ex( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_f32_r, + ldb, + stride_b, + a_ptr, + rocblas_datatype_f32_r, + lda, + stride_a, + &beta_f, + c_ptr, + rocblas_datatype_f32_r, + ldc, + stride_c, + c_ptr, + rocblas_datatype_f32_r, + ldc, + stride_c, + batch_count, + rocblas_datatype_f32_r, + rocblas_gemm_algo_solution_index, + solution_index, + 0); + if (status != rocblas_status_success) { + solution_valid.store(false, std::memory_order_relaxed); + rocblas_sgemm_strided_batched( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + ldb, + stride_b, + static_cast(a_ptr), + lda, + stride_a, + &beta_f, + static_cast(c_ptr), + ldc, + stride_c, + batch_count); + } + } else { + rocblas_sgemm_strided_batched( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + static_cast(b_ptr), + ldb, + stride_b, + static_cast(a_ptr), + lda, + stride_a, + &beta_f, + static_cast(c_ptr), + ldc, + stride_c, + batch_count); + } + break; + } + case float16: { + rocblas_half alpha_h, beta_h; + __half alpha_f16 = static_cast<__half>(alpha); + __half beta_f16 = static_cast<__half>(beta); + std::memcpy(&alpha_h, &alpha_f16, sizeof(rocblas_half)); + std::memcpy(&beta_h, &beta_f16, sizeof(rocblas_half)); + rocblas_hgemm_strided_batched( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_h, + reinterpret_cast(b_ptr), + ldb, + stride_b, + reinterpret_cast(a_ptr), + lda, + stride_a, + &beta_h, + reinterpret_cast(c_ptr), + ldc, + stride_c, + batch_count); + break; + } + case bfloat16: { + float alpha_f = alpha; + float beta_f = beta; + int solution_index = qmm_gemm_solution_index_bf16(true); + static std::atomic solution_valid{true}; + + rocblas_gemm_algo algo = rocblas_gemm_algo_standard; + if (solution_index > 0 && + solution_valid.load(std::memory_order_relaxed)) { + algo = rocblas_gemm_algo_solution_index; + } else { + solution_index = 0; + } + + rocblas_status status = rocblas_gemm_strided_batched_ex( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_bf16_r, + ldb, + stride_b, + a_ptr, + rocblas_datatype_bf16_r, + lda, + stride_a, + &beta_f, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + stride_c, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + stride_c, + batch_count, + rocblas_datatype_f32_r, + algo, + solution_index, + 0); + + if (status != rocblas_status_success && + algo == rocblas_gemm_algo_solution_index) { + solution_valid.store(false, std::memory_order_relaxed); + rocblas_gemm_strided_batched_ex( + handle, + op_b, + op_a, + N, + M, + K, + &alpha_f, + b_ptr, + rocblas_datatype_bf16_r, + ldb, + stride_b, + a_ptr, + rocblas_datatype_bf16_r, + lda, + stride_a, + &beta_f, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + stride_c, + c_ptr, + rocblas_datatype_bf16_r, + ldc, + stride_c, + batch_count, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, + 0, + 0); + } + break; + } + default: + throw std::runtime_error("Unsupported dtype for rocBLAS batched GEMM"); + } + }); +} + +} // namespace + +namespace rocm { + +template +__device__ inline uint8_t +unpack_packed_value(const uint8_t* packed_row, int k, int row_bytes) { + constexpr uint8_t mask = (1u << BITS) - 1u; + if constexpr (BITS == 2 || BITS == 4 || BITS == 8) { + constexpr int pack_factor = 8 / BITS; + int pack_idx = k / pack_factor; + int bit_offset = (k % pack_factor) * BITS; + return (packed_row[pack_idx] >> bit_offset) & mask; + } else { + int bit_index = k * BITS; + int byte_idx = bit_index >> 3; + int bit_offset = bit_index & 0x7; + + uint32_t window = static_cast(packed_row[byte_idx]); + if (byte_idx + 1 < row_bytes) { + window |= static_cast(packed_row[byte_idx + 1]) << 8; + } + return static_cast((window >> bit_offset) & mask); + } +} + +template +__device__ inline uint8_t +unpack_packed_value_fast(const uint8_t* packed_row, int k, int row_bytes) { + if constexpr (BITS == 8) { + (void)row_bytes; + return packed_row[k]; + } else if constexpr (BITS == 4) { + (void)row_bytes; + uint8_t packed = packed_row[k >> 1]; + return (k & 1) ? (packed >> 4) : (packed & 0xF); + } else if constexpr (BITS == 2) { + (void)row_bytes; + uint8_t packed = packed_row[k >> 2]; + return (packed >> ((k & 0x3) * 2)) & 0x3; + } else { + return unpack_packed_value(packed_row, k, row_bytes); + } +} + +template +__device__ __forceinline__ T subgroup_reduce_sum_qmm(T val) { + static_assert((SUBGROUP_SIZE & (SUBGROUP_SIZE - 1)) == 0); + static_assert(SUBGROUP_SIZE <= WARP_SIZE); + +#pragma unroll + for (int offset = SUBGROUP_SIZE / 2; offset > 0; offset /= 2) { + val += __shfl_xor(val, offset); + } + return val; +} + +template +__device__ __forceinline__ T warp_reduce_sum_qmm(T val) { + return subgroup_reduce_sum_qmm(val); +} + +__device__ inline float fp4_e2m1_to_float(uint8_t val) { + switch (val & 0xF) { + case 0x0: + return 0.0f; + case 0x1: + return 0.5f; + case 0x2: + return 1.0f; + case 0x3: + return 1.5f; + case 0x4: + return 2.0f; + case 0x5: + return 3.0f; + case 0x6: + return 4.0f; + case 0x7: + return 6.0f; + case 0x8: + return -0.0f; + case 0x9: + return -0.5f; + case 0xA: + return -1.0f; + case 0xB: + return -1.5f; + case 0xC: + return -2.0f; + case 0xD: + return -3.0f; + case 0xE: + return -4.0f; + case 0xF: + return -6.0f; + default: + return 0.0f; + } +} + +__device__ __forceinline__ float fp8_e4m3_to_float(uint8_t val) { + // Use a simple array lookup or bit manipulation. + // Actually, MI300 supports hardware fp8 conversion: + // But we can just use a fast bit manipulation without branches. + + uint32_t sign = (val >> 7) & 0x1; + uint32_t exp = (val >> 3) & 0xF; + uint32_t mant = val & 0x7; + + if (exp == 0 && mant == 0) { + return sign ? -0.0f : 0.0f; + } + + uint32_t float_exp = exp == 0 ? 0 : exp - 7 + 127; + // Handle subnormals approximately or cleanly if needed, + // but for performance, we can just do: + if (exp == 0) { + float subnormal = static_cast(mant) * 0.001953125f; // 2^-9 + return sign ? -subnormal : subnormal; + } + + uint32_t float_mant = mant << 20; + uint32_t bits = (sign << 31) | (float_exp << 23) | float_mant; + return __uint_as_float(bits); +} + +template +__device__ inline float fp_scale_to_float(uint8_t s) { + if constexpr (GROUP_SIZE == 16) { + return fp8_e4m3_to_float(s); + } else { + union { + uint16_t i; + hip_bfloat16 f; + } out; + out.i = (s == 0 ? 0x40 : (static_cast(s) << 7)); + return static_cast(out.f); + } +} + +template +__device__ inline float load_scale_value(ScaleT raw) { + if constexpr (AFFINE) { + return static_cast(raw); + } else { + return fp_scale_to_float(static_cast(raw)); + } +} + +template +__device__ inline float +dequantize_value(uint8_t quant_val, float scale, float bias) { + if constexpr (AFFINE) { + return static_cast(quant_val) * scale + bias; + } else { + (void)bias; + if constexpr (BITS == 8) { + return fp8_e4m3_to_float(quant_val) * scale; + } else { + return fp4_e2m1_to_float(quant_val) * scale; + } + } +} + +template < + typename T, + typename ScaleT, + int BITS, + int GROUP_SIZE, + bool AFFINE, + int THREADS_PER_COL> +__global__ void __launch_bounds__(1024) qmv_warp_shared_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + T* __restrict__ out, + int M, + int N, + int K, + bool has_bias) { + const int lane = threadIdx.x; + const int warp_idx = threadIdx.y; + const int col = blockIdx.x * blockDim.y + warp_idx; + const int row = blockIdx.y; + + const bool valid = (row < M) && (col < N); + + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int row_bytes = (K * BITS + 7) / 8; + + const T* x_row = (row < M) ? (x + row * K) : nullptr; + const uint8_t* w_row = valid ? (w + col * row_bytes) : nullptr; + const ScaleT* scales_row = valid ? (scales + col * num_groups) : nullptr; + const ScaleT* biases_row = + (valid && has_bias) ? (biases + col * num_groups) : nullptr; + + float acc = 0.0f; + + // We load a chunk of X into shared memory. + // We use a chunk size of 1024 elements. + constexpr int CHUNK_SIZE = 2048; + __shared__ float shared_x[CHUNK_SIZE]; + + for (int chunk_start = 0; chunk_start < K; chunk_start += CHUNK_SIZE) { + int chunk_end = min(chunk_start + CHUNK_SIZE, K); + int chunk_len = chunk_end - chunk_start; + + // Collaboratively load X chunk into shared memory + int tid = warp_idx * blockDim.x + lane; + for (int i = tid; i < chunk_len; i += blockDim.x * blockDim.y) { + shared_x[i] = static_cast(x_row[chunk_start + i]); + } + __syncthreads(); + + if (valid) { + int g_start = chunk_start / GROUP_SIZE; + int g_end = (chunk_end + GROUP_SIZE - 1) / GROUP_SIZE; + + for (int g = g_start; g < g_end; ++g) { + int k_start = max(g * GROUP_SIZE, chunk_start); + int k_end_g = min((g + 1) * GROUP_SIZE, chunk_end); + + float scale = + load_scale_value(scales_row[g]); + float bias_val = has_bias ? static_cast(biases_row[g]) : 0.0f; + + if constexpr (AFFINE) { + float qx_acc0 = 0.0f, qx_acc1 = 0.0f, qx_acc2 = 0.0f, qx_acc3 = 0.0f; + float x_group_sum = 0.0f; + float qx_acc = 0.0f; + + if constexpr (BITS == 8) { + int k_local = lane * 4; + int step = THREADS_PER_COL * 4; + for (; k_start + k_local + 3 < k_end_g; k_local += step) { + int k = k_start + k_local; + + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = static_cast(w_packed & 0xFF); + float w1 = static_cast((w_packed >> 8) & 0xFF); + float w2 = static_cast((w_packed >> 16) & 0xFF); + float w3 = static_cast((w_packed >> 24) & 0xFF); + + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + + qx_acc0 = fmaf(x0, w0, qx_acc0); + qx_acc1 = fmaf(x1, w1, qx_acc1); + qx_acc2 = fmaf(x2, w2, qx_acc2); + qx_acc3 = fmaf(x3, w3, qx_acc3); + + if (has_bias) + x_group_sum += x0 + x1 + x2 + x3; + } + qx_acc = qx_acc0 + qx_acc1 + qx_acc2 + qx_acc3; + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + float w_val = static_cast(w_row[k]); + qx_acc = fmaf(x_val, w_val, qx_acc); + if (has_bias) + x_group_sum += x_val; + } + } else if constexpr (BITS == 4) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + for (; k_start + k_local + 7 < k_end_g; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = + *reinterpret_cast(&w_row[k / 2]); + float w0 = static_cast(w_packed & 0xF); + float w1 = static_cast((w_packed >> 4) & 0xF); + float w2 = static_cast((w_packed >> 8) & 0xF); + float w3 = static_cast((w_packed >> 12) & 0xF); + float w4 = static_cast((w_packed >> 16) & 0xF); + float w5 = static_cast((w_packed >> 20) & 0xF); + float w6 = static_cast((w_packed >> 24) & 0xF); + float w7 = static_cast((w_packed >> 28) & 0xF); + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + float x4 = shared_x[k - chunk_start + 4]; + float x5 = shared_x[k - chunk_start + 5]; + float x6 = shared_x[k - chunk_start + 6]; + float x7 = shared_x[k - chunk_start + 7]; + // Four independent accumulators so RDNA4 can dual-issue the FMAs + // (a single serial qx_acc chain runs at half rate). Mirrors the + // 8-bit branch above; partial sums are reassociated at group end. + qx_acc0 = fmaf(x0, w0, qx_acc0); + qx_acc1 = fmaf(x1, w1, qx_acc1); + qx_acc2 = fmaf(x2, w2, qx_acc2); + qx_acc3 = fmaf(x3, w3, qx_acc3); + qx_acc0 = fmaf(x4, w4, qx_acc0); + qx_acc1 = fmaf(x5, w5, qx_acc1); + qx_acc2 = fmaf(x6, w6, qx_acc2); + qx_acc3 = fmaf(x7, w7, qx_acc3); + if (has_bias) + x_group_sum += x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7; + } + qx_acc = qx_acc0 + qx_acc1 + qx_acc2 + qx_acc3; + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) + x_group_sum += x_val; + } + } else if constexpr (BITS == 6) { + // Process 8 weights at a time (48 bits = 6 bytes) + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + // Need at least 7 bytes of room after byte_idx for safe 8-byte load + // row_bytes = (K * 6 + 7) / 8, so we need byte_idx + 7 < row_bytes + int max_safe_k = + ((row_bytes - 7) * 8) / 6; // Max k where 8-byte load is safe + for (; k_start + k_local + 7 < k_end_g && + k_start + k_local < max_safe_k; + k_local += step) { + int k = k_start + k_local; + // 8 weights * 6 bits = 48 bits, starting at bit position k*6 + int byte_idx = (k * 6) / 8; + int bit_offset = (k * 6) % 8; + // Safe to load 8 bytes (we checked bounds above) + uint64_t w_packed; + memcpy(&w_packed, &w_row[byte_idx], 8); + w_packed >>= bit_offset; + // Extract 8 6-bit weights + float w0 = static_cast(w_packed & 0x3F); + float w1 = static_cast((w_packed >> 6) & 0x3F); + float w2 = static_cast((w_packed >> 12) & 0x3F); + float w3 = static_cast((w_packed >> 18) & 0x3F); + float w4 = static_cast((w_packed >> 24) & 0x3F); + float w5 = static_cast((w_packed >> 30) & 0x3F); + float w6 = static_cast((w_packed >> 36) & 0x3F); + float w7 = static_cast((w_packed >> 42) & 0x3F); + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + float x4 = shared_x[k - chunk_start + 4]; + float x5 = shared_x[k - chunk_start + 5]; + float x6 = shared_x[k - chunk_start + 6]; + float x7 = shared_x[k - chunk_start + 7]; + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + qx_acc = fmaf(x4, w4, qx_acc); + qx_acc = fmaf(x5, w5, qx_acc); + qx_acc = fmaf(x6, w6, qx_acc); + qx_acc = fmaf(x7, w7, qx_acc); + if (has_bias) + x_group_sum += x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7; + } + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) + x_group_sum += x_val; + } + } else { + for (int k_local = lane; k_start + k_local < k_end_g; + k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) + x_group_sum += x_val; + } + } + acc += scale * qx_acc; + if (has_bias) + acc += bias_val * x_group_sum; + } else { + float qx_acc0 = 0.0f, qx_acc1 = 0.0f, qx_acc2 = 0.0f, qx_acc3 = 0.0f; + float qx_acc = 0.0f; + if constexpr (BITS == 8) { + int k_local = lane * 4; + int step = THREADS_PER_COL * 4; + for (; k_start + k_local + 3 < k_end_g; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = fp8_e4m3_to_float(w_packed & 0xFF); + float w1 = fp8_e4m3_to_float((w_packed >> 8) & 0xFF); + float w2 = fp8_e4m3_to_float((w_packed >> 16) & 0xFF); + float w3 = fp8_e4m3_to_float((w_packed >> 24) & 0xFF); + + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + + qx_acc0 = fmaf(x0, w0, qx_acc0); + qx_acc1 = fmaf(x1, w1, qx_acc1); + qx_acc2 = fmaf(x2, w2, qx_acc2); + qx_acc3 = fmaf(x3, w3, qx_acc3); + } + qx_acc = qx_acc0 + qx_acc1 + qx_acc2 + qx_acc3; + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + float w_val = fp8_e4m3_to_float(w_row[k]); + qx_acc = fmaf(x_val, w_val, qx_acc); + } + } else if constexpr (BITS == 4) { + for (int k_local = lane; k_start + k_local < k_end_g; + k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf( + x_val, + dequantize_value(quant_val, 1.0f, 0.0f), + qx_acc); + } + } else if constexpr (BITS == 6) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + int max_safe_k = ((row_bytes - 7) * 8) / 6; + for (; k_start + k_local + 7 < k_end_g && + k_start + k_local < max_safe_k; + k_local += step) { + int k = k_start + k_local; + int byte_idx = (k * 6) / 8; + int bit_offset = (k * 6) % 8; + uint64_t w_packed = 0; + memcpy(&w_packed, &w_row[byte_idx], 8); + w_packed >>= bit_offset; + float w0 = + dequantize_value<6, false>(w_packed & 0x3F, 1.0f, 0.0f); + float w1 = dequantize_value<6, false>( + (w_packed >> 6) & 0x3F, 1.0f, 0.0f); + float w2 = dequantize_value<6, false>( + (w_packed >> 12) & 0x3F, 1.0f, 0.0f); + float w3 = dequantize_value<6, false>( + (w_packed >> 18) & 0x3F, 1.0f, 0.0f); + float w4 = dequantize_value<6, false>( + (w_packed >> 24) & 0x3F, 1.0f, 0.0f); + float w5 = dequantize_value<6, false>( + (w_packed >> 30) & 0x3F, 1.0f, 0.0f); + float w6 = dequantize_value<6, false>( + (w_packed >> 36) & 0x3F, 1.0f, 0.0f); + float w7 = dequantize_value<6, false>( + (w_packed >> 42) & 0x3F, 1.0f, 0.0f); + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + float x4 = shared_x[k - chunk_start + 4]; + float x5 = shared_x[k - chunk_start + 5]; + float x6 = shared_x[k - chunk_start + 6]; + float x7 = shared_x[k - chunk_start + 7]; + qx_acc0 = fmaf(x0, w0, qx_acc0); + qx_acc1 = fmaf(x1, w1, qx_acc1); + qx_acc2 = fmaf(x2, w2, qx_acc2); + qx_acc3 = fmaf(x3, w3, qx_acc3); + qx_acc0 = fmaf(x4, w4, qx_acc0); + qx_acc1 = fmaf(x5, w5, qx_acc1); + qx_acc2 = fmaf(x6, w6, qx_acc2); + qx_acc3 = fmaf(x7, w7, qx_acc3); + } + qx_acc = qx_acc0 + qx_acc1 + qx_acc2 + qx_acc3; + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf( + x_val, + dequantize_value(quant_val, 1.0f, 0.0f), + qx_acc); + } + } else { + for (int k_local = lane; k_start + k_local < k_end_g; + k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf( + x_val, + dequantize_value(quant_val, 1.0f, 0.0f), + qx_acc); + } + } + acc += scale * qx_acc; + } + } + } + __syncthreads(); // ensure all warps are done before loading next chunk + } + + acc = subgroup_reduce_sum_qmm(acc); + if (valid && lane == 0) { + out[row * N + col] = static_cast(acc); + } +} + +template < + typename T, + typename ScaleT, + int BITS, + int GROUP_SIZE, + bool AFFINE, + int THREADS_PER_COL> +__global__ void __launch_bounds__(1024) qmv_warp_shared_batched_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + T* __restrict__ out, + int M, + int N, + int K, + int64_t x_batch_stride, + int64_t w_batch_stride, + int64_t sb_batch_stride, + int64_t out_batch_stride, + bool has_bias) { + const int lane = threadIdx.x; + const int warp_idx = threadIdx.y; + const int col = blockIdx.x * blockDim.y + warp_idx; + const int row = blockIdx.y; + const int batch = blockIdx.z; + + const bool valid = (row < M) && (col < N); + + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int row_bytes = (K * BITS + 7) / 8; + + const T* x_batch_ptr = x + static_cast(batch) * x_batch_stride; + const uint8_t* w_batch_ptr = w + static_cast(batch) * w_batch_stride; + const ScaleT* scales_batch_ptr = + scales + static_cast(batch) * sb_batch_stride; + const ScaleT* biases_batch_ptr = has_bias + ? (biases + static_cast(batch) * sb_batch_stride) + : nullptr; + T* out_batch_ptr = out + static_cast(batch) * out_batch_stride; + + const T* x_row = + (row < M) ? (x_batch_ptr + static_cast(row) * K) : nullptr; + const uint8_t* w_row = + valid ? (w_batch_ptr + static_cast(col) * row_bytes) : nullptr; + const ScaleT* scales_row = valid + ? (scales_batch_ptr + static_cast(col) * num_groups) + : nullptr; + const ScaleT* biases_row = (valid && has_bias) + ? (biases_batch_ptr + static_cast(col) * num_groups) + : nullptr; + + float acc = 0.0f; + + constexpr int CHUNK_SIZE = 2048; + __shared__ float shared_x[CHUNK_SIZE]; + + for (int chunk_start = 0; chunk_start < K; chunk_start += CHUNK_SIZE) { + int chunk_end = min(chunk_start + CHUNK_SIZE, K); + int chunk_len = chunk_end - chunk_start; + + int tid = warp_idx * blockDim.x + lane; + for (int i = tid; i < chunk_len; i += blockDim.x * blockDim.y) { + shared_x[i] = static_cast(x_row[chunk_start + i]); + } + __syncthreads(); + + if (valid) { + int g_start = chunk_start / GROUP_SIZE; + int g_end = (chunk_end + GROUP_SIZE - 1) / GROUP_SIZE; + + for (int g = g_start; g < g_end; ++g) { + int k_start = max(g * GROUP_SIZE, chunk_start); + int k_end_g = min((g + 1) * GROUP_SIZE, chunk_end); + + float scale = + load_scale_value(scales_row[g]); + float bias_val = has_bias ? static_cast(biases_row[g]) : 0.0f; + + if constexpr (AFFINE) { + float qx_acc0 = 0.0f; + float qx_acc1 = 0.0f; + float qx_acc2 = 0.0f; + float qx_acc3 = 0.0f; + float x_group_sum = 0.0f; + float qx_acc = 0.0f; + + if constexpr (BITS == 8) { + int k_local = lane * 4; + int step = THREADS_PER_COL * 4; + for (; k_start + k_local + 3 < k_end_g; k_local += step) { + int k = k_start + k_local; + + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = static_cast(w_packed & 0xFF); + float w1 = static_cast((w_packed >> 8) & 0xFF); + float w2 = static_cast((w_packed >> 16) & 0xFF); + float w3 = static_cast((w_packed >> 24) & 0xFF); + + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + + qx_acc0 = fmaf(x0, w0, qx_acc0); + qx_acc1 = fmaf(x1, w1, qx_acc1); + qx_acc2 = fmaf(x2, w2, qx_acc2); + qx_acc3 = fmaf(x3, w3, qx_acc3); + + if (has_bias) { + x_group_sum += x0 + x1 + x2 + x3; + } + } + qx_acc = qx_acc0 + qx_acc1 + qx_acc2 + qx_acc3; + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + float w_val = static_cast(w_row[k]); + qx_acc = fmaf(x_val, w_val, qx_acc); + if (has_bias) { + x_group_sum += x_val; + } + } + } else if constexpr (BITS == 4) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + for (; k_start + k_local + 7 < k_end_g; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = + *reinterpret_cast(&w_row[k / 2]); + float w0 = static_cast(w_packed & 0xF); + float w1 = static_cast((w_packed >> 4) & 0xF); + float w2 = static_cast((w_packed >> 8) & 0xF); + float w3 = static_cast((w_packed >> 12) & 0xF); + float w4 = static_cast((w_packed >> 16) & 0xF); + float w5 = static_cast((w_packed >> 20) & 0xF); + float w6 = static_cast((w_packed >> 24) & 0xF); + float w7 = static_cast((w_packed >> 28) & 0xF); + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + float x4 = shared_x[k - chunk_start + 4]; + float x5 = shared_x[k - chunk_start + 5]; + float x6 = shared_x[k - chunk_start + 6]; + float x7 = shared_x[k - chunk_start + 7]; + // Four independent accumulators for RDNA4 dual-issue (mirrors the + // 8-bit branch); reassociated at group end. + qx_acc0 = fmaf(x0, w0, qx_acc0); + qx_acc1 = fmaf(x1, w1, qx_acc1); + qx_acc2 = fmaf(x2, w2, qx_acc2); + qx_acc3 = fmaf(x3, w3, qx_acc3); + qx_acc0 = fmaf(x4, w4, qx_acc0); + qx_acc1 = fmaf(x5, w5, qx_acc1); + qx_acc2 = fmaf(x6, w6, qx_acc2); + qx_acc3 = fmaf(x7, w7, qx_acc3); + if (has_bias) { + x_group_sum += x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7; + } + } + qx_acc = qx_acc0 + qx_acc1 + qx_acc2 + qx_acc3; + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) { + x_group_sum += x_val; + } + } + } else if constexpr (BITS == 6) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + int max_safe_k = ((row_bytes - 7) * 8) / 6; + for (; k_start + k_local + 7 < k_end_g && + k_start + k_local < max_safe_k; + k_local += step) { + int k = k_start + k_local; + int byte_idx = (k * 6) / 8; + int bit_offset = (k * 6) % 8; + uint64_t w_packed = 0; + memcpy(&w_packed, &w_row[byte_idx], 8); + w_packed >>= bit_offset; + float w0 = static_cast(w_packed & 0x3F); + float w1 = static_cast((w_packed >> 6) & 0x3F); + float w2 = static_cast((w_packed >> 12) & 0x3F); + float w3 = static_cast((w_packed >> 18) & 0x3F); + float w4 = static_cast((w_packed >> 24) & 0x3F); + float w5 = static_cast((w_packed >> 30) & 0x3F); + float w6 = static_cast((w_packed >> 36) & 0x3F); + float w7 = static_cast((w_packed >> 42) & 0x3F); + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + float x4 = shared_x[k - chunk_start + 4]; + float x5 = shared_x[k - chunk_start + 5]; + float x6 = shared_x[k - chunk_start + 6]; + float x7 = shared_x[k - chunk_start + 7]; + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + qx_acc = fmaf(x4, w4, qx_acc); + qx_acc = fmaf(x5, w5, qx_acc); + qx_acc = fmaf(x6, w6, qx_acc); + qx_acc = fmaf(x7, w7, qx_acc); + if (has_bias) { + x_group_sum += x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7; + } + } + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) { + x_group_sum += x_val; + } + } + } else { + for (int k_local = lane; k_start + k_local < k_end_g; + k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) { + x_group_sum += x_val; + } + } + } + acc += scale * qx_acc; + if (has_bias) { + acc += bias_val * x_group_sum; + } + } else { + float qx_acc0 = 0.0f; + float qx_acc1 = 0.0f; + float qx_acc2 = 0.0f; + float qx_acc3 = 0.0f; + float qx_acc = 0.0f; + if constexpr (BITS == 8) { + int k_local = lane * 4; + int step = THREADS_PER_COL * 4; + for (; k_start + k_local + 3 < k_end_g; k_local += step) { + int k = k_start + k_local; + + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = fp8_e4m3_to_float(w_packed & 0xFF); + float w1 = fp8_e4m3_to_float((w_packed >> 8) & 0xFF); + float w2 = fp8_e4m3_to_float((w_packed >> 16) & 0xFF); + float w3 = fp8_e4m3_to_float((w_packed >> 24) & 0xFF); + + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + + qx_acc0 = fmaf(x0, w0, qx_acc0); + qx_acc1 = fmaf(x1, w1, qx_acc1); + qx_acc2 = fmaf(x2, w2, qx_acc2); + qx_acc3 = fmaf(x3, w3, qx_acc3); + } + qx_acc = qx_acc0 + qx_acc1 + qx_acc2 + qx_acc3; + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + float w_val = fp8_e4m3_to_float(w_row[k]); + qx_acc = fmaf(x_val, w_val, qx_acc); + } + } else if constexpr (BITS == 4) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + for (; k_start + k_local + 7 < k_end_g; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = + *reinterpret_cast(&w_row[k / 2]); + float w0 = dequantize_value<4, false>(w_packed & 0xF, 1.0f, 0.0f); + float w1 = + dequantize_value<4, false>((w_packed >> 4) & 0xF, 1.0f, 0.0f); + float w2 = + dequantize_value<4, false>((w_packed >> 8) & 0xF, 1.0f, 0.0f); + float w3 = dequantize_value<4, false>( + (w_packed >> 12) & 0xF, 1.0f, 0.0f); + float w4 = dequantize_value<4, false>( + (w_packed >> 16) & 0xF, 1.0f, 0.0f); + float w5 = dequantize_value<4, false>( + (w_packed >> 20) & 0xF, 1.0f, 0.0f); + float w6 = dequantize_value<4, false>( + (w_packed >> 24) & 0xF, 1.0f, 0.0f); + float w7 = dequantize_value<4, false>( + (w_packed >> 28) & 0xF, 1.0f, 0.0f); + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + float x4 = shared_x[k - chunk_start + 4]; + float x5 = shared_x[k - chunk_start + 5]; + float x6 = shared_x[k - chunk_start + 6]; + float x7 = shared_x[k - chunk_start + 7]; + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + qx_acc = fmaf(x4, w4, qx_acc); + qx_acc = fmaf(x5, w5, qx_acc); + qx_acc = fmaf(x6, w6, qx_acc); + qx_acc = fmaf(x7, w7, qx_acc); + } + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf( + x_val, + dequantize_value(quant_val, 1.0f, 0.0f), + qx_acc); + } + } else if constexpr (BITS == 6) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + int max_safe_k = ((row_bytes - 7) * 8) / 6; + for (; k_start + k_local + 7 < k_end_g && + k_start + k_local < max_safe_k; + k_local += step) { + int k = k_start + k_local; + int byte_idx = (k * 6) / 8; + int bit_offset = (k * 6) % 8; + uint64_t w_packed = 0; + memcpy(&w_packed, &w_row[byte_idx], 8); + w_packed >>= bit_offset; + float w0 = + dequantize_value<6, false>(w_packed & 0x3F, 1.0f, 0.0f); + float w1 = dequantize_value<6, false>( + (w_packed >> 6) & 0x3F, 1.0f, 0.0f); + float w2 = dequantize_value<6, false>( + (w_packed >> 12) & 0x3F, 1.0f, 0.0f); + float w3 = dequantize_value<6, false>( + (w_packed >> 18) & 0x3F, 1.0f, 0.0f); + float w4 = dequantize_value<6, false>( + (w_packed >> 24) & 0x3F, 1.0f, 0.0f); + float w5 = dequantize_value<6, false>( + (w_packed >> 30) & 0x3F, 1.0f, 0.0f); + float w6 = dequantize_value<6, false>( + (w_packed >> 36) & 0x3F, 1.0f, 0.0f); + float w7 = dequantize_value<6, false>( + (w_packed >> 42) & 0x3F, 1.0f, 0.0f); + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + float x4 = shared_x[k - chunk_start + 4]; + float x5 = shared_x[k - chunk_start + 5]; + float x6 = shared_x[k - chunk_start + 6]; + float x7 = shared_x[k - chunk_start + 7]; + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + qx_acc = fmaf(x4, w4, qx_acc); + qx_acc = fmaf(x5, w5, qx_acc); + qx_acc = fmaf(x6, w6, qx_acc); + qx_acc = fmaf(x7, w7, qx_acc); + } + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf( + x_val, + dequantize_value(quant_val, 1.0f, 0.0f), + qx_acc); + } + } else { + for (int k_local = lane; k_start + k_local < k_end_g; + k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf( + x_val, + dequantize_value(quant_val, 1.0f, 0.0f), + qx_acc); + } + } + acc += scale * qx_acc; + } + } + } + __syncthreads(); + } + + acc = subgroup_reduce_sum_qmm(acc); + if (valid && lane == 0) { + out_batch_ptr[static_cast(row) * N + col] = static_cast(acc); + } +} + +template < + typename T, + typename ScaleT, + int BITS, + int GROUP_SIZE, + bool AFFINE, + int THREADS_PER_COL> +__global__ void qmv_warp_noshared_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + T* __restrict__ out, + int M, + int N, + int K, + bool has_bias) { + const int lane = threadIdx.x; + const int col = blockIdx.x * blockDim.y + threadIdx.y; + const int row = blockIdx.y; + + const bool row_valid = (row < M); + const bool valid = row_valid && (col < N); + + constexpr int kThreadsPerCol = THREADS_PER_COL; + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int row_bytes = (K * BITS + 7) / 8; + + const T* x_row = row_valid ? (x + row * K) : nullptr; + const uint8_t* w_row = valid ? (w + col * row_bytes) : nullptr; + const ScaleT* scales_row = valid ? (scales + col * num_groups) : nullptr; + const ScaleT* biases_row = + (valid && has_bias) ? (biases + col * num_groups) : nullptr; + + float acc = 0.0f; + + for (int g = 0; g < num_groups; ++g) { + int k_start = g * GROUP_SIZE; + int k_end = min(k_start + GROUP_SIZE, K); + + if (valid) { + float scale = load_scale_value(scales_row[g]); + float bias = has_bias ? static_cast(biases_row[g]) : 0.0f; + + if constexpr (AFFINE) { + float qx_acc0 = 0.0f, qx_acc1 = 0.0f, qx_acc2 = 0.0f, qx_acc3 = 0.0f; + float x_group_sum = 0.0f; + float qx_acc = 0.0f; + + if constexpr (BITS == 8) { + int k_local = lane * 4; + int step = kThreadsPerCol * 4; + for (; k_start + k_local + 3 < k_end; k_local += step) { + int k = k_start + k_local; + + // Read 4 weights at once + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = static_cast(w_packed & 0xFF); + float w1 = static_cast((w_packed >> 8) & 0xFF); + float w2 = static_cast((w_packed >> 16) & 0xFF); + float w3 = static_cast((w_packed >> 24) & 0xFF); + + float x0 = static_cast(x_row[k]); + float x1 = static_cast(x_row[k + 1]); + float x2 = static_cast(x_row[k + 2]); + float x3 = static_cast(x_row[k + 3]); + + qx_acc0 = fmaf(x0, w0, qx_acc0); + qx_acc1 = fmaf(x1, w1, qx_acc1); + qx_acc2 = fmaf(x2, w2, qx_acc2); + qx_acc3 = fmaf(x3, w3, qx_acc3); + + if (has_bias) { + x_group_sum += x0 + x1 + x2 + x3; + } + } + + qx_acc = qx_acc0 + qx_acc1 + qx_acc2 + qx_acc3; + + // Tail loop + for (; k_start + k_local < k_end; k_local++) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + float w_val = static_cast(w_row[k]); + qx_acc = fmaf(x_val, w_val, qx_acc); + if (has_bias) + x_group_sum += x_val; + } + } else { + for (int k_local = lane; k_start + k_local < k_end; + k_local += kThreadsPerCol) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) + x_group_sum += x_val; + } + } + + float group_acc = scale * qx_acc; + if (has_bias) { + group_acc = fmaf(bias, x_group_sum, group_acc); + } + acc += group_acc; + } else { + float qx_acc0 = 0.0f, qx_acc1 = 0.0f, qx_acc2 = 0.0f, qx_acc3 = 0.0f; + float qx_acc = 0.0f; + + if constexpr (BITS == 8) { + int k_local = lane * 4; + int step = kThreadsPerCol * 4; + for (; k_start + k_local + 3 < k_end; k_local += step) { + int k = k_start + k_local; + + // Read 4 weights at once + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = fp8_e4m3_to_float(w_packed & 0xFF); + float w1 = fp8_e4m3_to_float((w_packed >> 8) & 0xFF); + float w2 = fp8_e4m3_to_float((w_packed >> 16) & 0xFF); + float w3 = fp8_e4m3_to_float((w_packed >> 24) & 0xFF); + + float x0 = static_cast(x_row[k]); + float x1 = static_cast(x_row[k + 1]); + float x2 = static_cast(x_row[k + 2]); + float x3 = static_cast(x_row[k + 3]); + + qx_acc0 = fmaf(x0, w0, qx_acc0); + qx_acc1 = fmaf(x1, w1, qx_acc1); + qx_acc2 = fmaf(x2, w2, qx_acc2); + qx_acc3 = fmaf(x3, w3, qx_acc3); + } + + qx_acc = qx_acc0 + qx_acc1 + qx_acc2 + qx_acc3; + + for (; k_start + k_local < k_end; k_local++) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + float w_val = fp8_e4m3_to_float(w_row[k]); + qx_acc = fmaf(x_val, w_val, qx_acc); + } + acc += scale * qx_acc; + } else { + float qx_acc = 0.0f; + for (int k_local = lane; k_start + k_local < k_end; + k_local += kThreadsPerCol) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf( + x_val, + dequantize_value(quant_val, 1.0f, 0.0f), + qx_acc); + } + acc += scale * qx_acc; + } + } + } + } + + acc = subgroup_reduce_sum_qmm(acc); + if (valid && lane == 0) { + out[row * N + col] = static_cast(acc); + } +} + +template +__global__ void qmv_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + T* __restrict__ out, + int M, + int N, + int K, + bool has_bias) { + const int row = blockIdx.x; + const int col = blockIdx.y * blockDim.x + threadIdx.x; + + if (row >= M || col >= N) + return; + + float acc = 0.0f; + int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int row_bytes = (K * BITS + 7) / 8; + const uint8_t* w_row = w + col * row_bytes; + + for (int g = 0; g < num_groups; ++g) { + float scale = load_scale_value( + scales[col * num_groups + g]); + float bias = + has_bias ? static_cast(biases[col * num_groups + g]) : 0.0f; + + int k_start = g * GROUP_SIZE; + int k_end = min(k_start + GROUP_SIZE, K); + + float qx_acc = 0.0f; + if constexpr (BITS == 8) { + int k = k_start; + for (; k + 3 < k_end; k += 4) { + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = dequantize_value<8, AFFINE>(w_packed & 0xFF, scale, bias); + float w1 = + dequantize_value<8, AFFINE>((w_packed >> 8) & 0xFF, scale, bias); + float w2 = + dequantize_value<8, AFFINE>((w_packed >> 16) & 0xFF, scale, bias); + float w3 = + dequantize_value<8, AFFINE>((w_packed >> 24) & 0xFF, scale, bias); + + qx_acc += static_cast(x[row * K + k]) * w0; + qx_acc += static_cast(x[row * K + k + 1]) * w1; + qx_acc += static_cast(x[row * K + k + 2]) * w2; + qx_acc += static_cast(x[row * K + k + 3]) * w3; + } + for (; k < k_end; ++k) { + float w_val = dequantize_value<8, AFFINE>(w_row[k], scale, bias); + qx_acc += static_cast(x[row * K + k]) * w_val; + } + } else if constexpr (BITS == 6) { + int k = k_start; + int max_safe_k = ((row_bytes - 7) * 8) / 6; + for (; k + 7 < k_end && k < max_safe_k; k += 8) { + int byte_idx = (k * 6) / 8; + int bit_offset = (k * 6) % 8; + uint64_t w_packed = 0; + memcpy(&w_packed, &w_row[byte_idx], 8); + w_packed >>= bit_offset; + + float w0 = dequantize_value<6, AFFINE>(w_packed & 0x3F, scale, bias); + float w1 = + dequantize_value<6, AFFINE>((w_packed >> 6) & 0x3F, scale, bias); + float w2 = + dequantize_value<6, AFFINE>((w_packed >> 12) & 0x3F, scale, bias); + float w3 = + dequantize_value<6, AFFINE>((w_packed >> 18) & 0x3F, scale, bias); + float w4 = + dequantize_value<6, AFFINE>((w_packed >> 24) & 0x3F, scale, bias); + float w5 = + dequantize_value<6, AFFINE>((w_packed >> 30) & 0x3F, scale, bias); + float w6 = + dequantize_value<6, AFFINE>((w_packed >> 36) & 0x3F, scale, bias); + float w7 = + dequantize_value<6, AFFINE>((w_packed >> 42) & 0x3F, scale, bias); + + qx_acc += static_cast(x[row * K + k]) * w0; + qx_acc += static_cast(x[row * K + k + 1]) * w1; + qx_acc += static_cast(x[row * K + k + 2]) * w2; + qx_acc += static_cast(x[row * K + k + 3]) * w3; + qx_acc += static_cast(x[row * K + k + 4]) * w4; + qx_acc += static_cast(x[row * K + k + 5]) * w5; + qx_acc += static_cast(x[row * K + k + 6]) * w6; + qx_acc += static_cast(x[row * K + k + 7]) * w7; + } + for (; k < k_end; ++k) { + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + float w_val = dequantize_value<6, AFFINE>(quant_val, scale, bias); + qx_acc += static_cast(x[row * K + k]) * w_val; + } + } else { + for (int k = k_start; k < k_end; ++k) { + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + float w_val = dequantize_value(quant_val, scale, bias); + qx_acc += static_cast(x[row * K + k]) * w_val; + } + } + acc += qx_acc; + } + + out[row * N + col] = static_cast(acc); +} + +template +__global__ void qmv_t_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + T* __restrict__ out, + int M, + int N, + int K, + bool has_bias) { + const int row = blockIdx.x; + const int col = blockIdx.y * blockDim.x + threadIdx.x; + + if (row >= M || col >= N) + return; + + float acc = 0.0f; + int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int row_bytes = (K * BITS + 7) / 8; + const uint8_t* w_row = w + col * row_bytes; + + for (int g = 0; g < num_groups; ++g) { + float scale = load_scale_value( + scales[col * num_groups + g]); + float bias = + has_bias ? static_cast(biases[col * num_groups + g]) : 0.0f; + + int k_start = g * GROUP_SIZE; + int k_end = min(k_start + GROUP_SIZE, K); + + float qx_acc = 0.0f; + if constexpr (BITS == 8) { + int k = k_start; + for (; k + 3 < k_end; k += 4) { + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = dequantize_value<8, AFFINE>(w_packed & 0xFF, scale, bias); + float w1 = + dequantize_value<8, AFFINE>((w_packed >> 8) & 0xFF, scale, bias); + float w2 = + dequantize_value<8, AFFINE>((w_packed >> 16) & 0xFF, scale, bias); + float w3 = + dequantize_value<8, AFFINE>((w_packed >> 24) & 0xFF, scale, bias); + + qx_acc += static_cast(x[row * K + k]) * w0; + qx_acc += static_cast(x[row * K + k + 1]) * w1; + qx_acc += static_cast(x[row * K + k + 2]) * w2; + qx_acc += static_cast(x[row * K + k + 3]) * w3; + } + for (; k < k_end; ++k) { + float w_val = dequantize_value<8, AFFINE>(w_row[k], scale, bias); + qx_acc += static_cast(x[row * K + k]) * w_val; + } + } else if constexpr (BITS == 6) { + int k = k_start; + int max_safe_k = ((row_bytes - 7) * 8) / 6; + for (; k + 7 < k_end && k < max_safe_k; k += 8) { + int byte_idx = (k * 6) / 8; + int bit_offset = (k * 6) % 8; + uint64_t w_packed = 0; + memcpy(&w_packed, &w_row[byte_idx], 8); + w_packed >>= bit_offset; + + float w0 = dequantize_value<6, AFFINE>(w_packed & 0x3F, scale, bias); + float w1 = + dequantize_value<6, AFFINE>((w_packed >> 6) & 0x3F, scale, bias); + float w2 = + dequantize_value<6, AFFINE>((w_packed >> 12) & 0x3F, scale, bias); + float w3 = + dequantize_value<6, AFFINE>((w_packed >> 18) & 0x3F, scale, bias); + float w4 = + dequantize_value<6, AFFINE>((w_packed >> 24) & 0x3F, scale, bias); + float w5 = + dequantize_value<6, AFFINE>((w_packed >> 30) & 0x3F, scale, bias); + float w6 = + dequantize_value<6, AFFINE>((w_packed >> 36) & 0x3F, scale, bias); + float w7 = + dequantize_value<6, AFFINE>((w_packed >> 42) & 0x3F, scale, bias); + + qx_acc += static_cast(x[row * K + k]) * w0; + qx_acc += static_cast(x[row * K + k + 1]) * w1; + qx_acc += static_cast(x[row * K + k + 2]) * w2; + qx_acc += static_cast(x[row * K + k + 3]) * w3; + qx_acc += static_cast(x[row * K + k + 4]) * w4; + qx_acc += static_cast(x[row * K + k + 5]) * w5; + qx_acc += static_cast(x[row * K + k + 6]) * w6; + qx_acc += static_cast(x[row * K + k + 7]) * w7; + } + for (; k < k_end; ++k) { + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + float w_val = dequantize_value<6, AFFINE>(quant_val, scale, bias); + qx_acc += static_cast(x[row * K + k]) * w_val; + } + } else { + for (int k = k_start; k < k_end; ++k) { + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + float w_val = dequantize_value(quant_val, scale, bias); + qx_acc += static_cast(x[row * K + k]) * w_val; + } + } + acc += qx_acc; + } + + out[row * N + col] = static_cast(acc); +} + +} // namespace rocm + +void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = rocm::device(s.device); + auto& enc = d.get_command_encoder(s); + + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), enc)); + + array x = ensure_row_contiguous_matrix(inputs[0], enc, s); + array w = ensure_row_contiguous_matrix(inputs[1], enc, s); + array scales = ensure_row_contiguous_matrix(inputs[2], enc, s); + std::optional biases = std::nullopt; + bool has_bias = (mode_ == QuantizationMode::Affine) && (inputs.size() == 4); + if (has_bias) { + biases = ensure_row_contiguous_matrix(inputs[3], enc, s); + } + + enc.set_input_array(x); + enc.set_input_array(w); + enc.set_input_array(scales); + if (has_bias) + enc.set_input_array(biases.value()); + enc.set_output_array(out); + + int K = x.shape(-1); + int M = out.shape(-2); + int N = out.shape(-1); + + int64_t matrix_size = static_cast(M) * N; + int batch_count = static_cast(out.size() / matrix_size); + int x_batch_count = static_cast( + x.size() / + (static_cast(x.shape(-2)) * static_cast(x.shape(-1)))); + int w_batch_count = static_cast( + w.size() / + (static_cast(w.shape(-2)) * static_cast(w.shape(-1)))); + + bool x_singleton_batch = has_only_singleton_batch_dims(x); + bool w_singleton_batch = has_only_singleton_batch_dims(w); + bool non_batched = + (batch_count == 1) && x_singleton_batch && w_singleton_batch; + + bool bits_supported_by_qmv = (bits_ == 2 || bits_ == 4 || bits_ == 8) || + (mode_ == QuantizationMode::Affine && (bits_ == 5 || bits_ == 6)); + bool valid_x_batch = (x_batch_count == 1) || (x_batch_count == batch_count); + bool valid_w_batch = (w_batch_count == 1) || (w_batch_count == batch_count); + bool can_use_batched_qmv = transpose_ && bits_supported_by_qmv && + (batch_count > 1) && valid_x_batch && valid_w_batch; + bool force_dequant_gemm = !transpose_ || !bits_supported_by_qmv || + ((batch_count > 1) && !can_use_batched_qmv) || + (w.ndim() > 2 && !w_singleton_batch && !can_use_batched_qmv); + bool dequant_gemm_supported_mode = (mode_ == QuantizationMode::Affine); + bool should_prefer_dequant = should_use_dequant_gemm_path( + M, N, K, batch_count, non_batched, transpose_, can_use_batched_qmv, d); + + // Dequant + rocBLAS GEMM path + // Disable with MLX_ROCM_QMM_DEQUANT_GEMM=0 if needed + if (dequant_gemm_supported_mode && d.is_rocblas_available() && + use_rocblas_dequant_path() && + (force_dequant_gemm || should_prefer_dequant)) { + if (!((x_batch_count == 1) || (x_batch_count == batch_count))) { + throw std::runtime_error( + "Unsupported x batch shape for dequant GEMM fallback"); + } + if (!((w_batch_count == 1) || (w_batch_count == batch_count))) { + throw std::runtime_error( + "Unsupported w batch shape for dequant GEMM fallback"); + } + + int dequant_rows = transpose_ ? N : K; + int dequant_cols = transpose_ ? K : N; + + // fp8 e4m3 path (RDNA4 prefill): dequantize the weight straight to e4m3 and + // cast the activation, then run the GEMM on fp8 matrix cores. Capability- + // gated — devices without e4m3 kernels stay on the bf16 dequant path below. + if ((mode_ == QuantizationMode::Affine) && (x.dtype() == bfloat16) && + (batch_count == 1) && (x_batch_count == 1) && (w_batch_count == 1) && + (M >= 64) && rocm::device_has_fp8_gemm(d.hip_device())) { + dequant_fp8_gemm( + enc, + transpose_, + M, + N, + K, + x, + w, + scales, + biases, + group_size_, + bits_, + out, + dequant_rows, + dequant_cols); + return; + } + + Shape w_dequant_shape = w.shape(); + w_dequant_shape[w_dequant_shape.size() - 2] = dequant_rows; + w_dequant_shape[w_dequant_shape.size() - 1] = dequant_cols; + + array w_dequant(w_dequant_shape, x.dtype(), nullptr, {}); + bool cache_hit = false; + int cache_cap = dequant_cache_capacity(); + size_t cache_max_bytes = dequant_cache_max_bytes(); + if (cache_cap > 0 && cache_max_bytes > 0) { + static std::mutex cache_mutex; + static std::list lru; + static size_t cached_bytes = 0; + static std::unordered_map< + DequantCacheKey, + DequantCacheEntry, + DequantCacheKeyHasher> + cache; + + DequantCacheKey key{ + reinterpret_cast(gpu_ptr(w)), + reinterpret_cast(gpu_ptr(scales)), + has_bias ? reinterpret_cast(gpu_ptr(*biases)) + : 0, + group_size_, + bits_, + s.index, + transpose_, + x.dtype()}; + + { + std::lock_guard lock(cache_mutex); + auto it = cache.find(key); + if (it != cache.end() && it->second.weight.shape() == w_dequant_shape) { + lru.splice(lru.begin(), lru, it->second.lru_it); + w_dequant = it->second.weight; + cache_hit = true; + } + } + + if (!cache_hit) { + w_dequant.set_data(mlx::core::rocm::malloc_async(w_dequant.nbytes(), enc)); + + if (mode_ == QuantizationMode::Affine) { + affine_dequantize( + w, scales, biases, w_dequant, group_size_, bits_, enc, s); + } else { + fp_dequantize(w, scales, w_dequant, group_size_, bits_, enc, s); + } + + std::lock_guard lock(cache_mutex); + auto it = cache.find(key); + if (it == cache.end()) { + size_t entry_bytes = w_dequant.nbytes(); + if (entry_bytes <= cache_max_bytes) { + lru.push_front(key); + cache.emplace( + key, + DequantCacheEntry{ + w_dequant, + w, + scales, + has_bias ? std::optional(*biases) : std::nullopt, + entry_bytes, + lru.begin()}); + cached_bytes += entry_bytes; + + while (static_cast(cache.size()) > cache_cap || + cached_bytes > cache_max_bytes) { + auto evict = lru.back(); + auto evict_it = cache.find(evict); + if (evict_it != cache.end()) { + cached_bytes -= evict_it->second.bytes; + cache.erase(evict_it); + } + lru.pop_back(); + } + } + } else { + size_t entry_bytes = w_dequant.nbytes(); + if (entry_bytes > cache_max_bytes) { + cached_bytes -= it->second.bytes; + lru.erase(it->second.lru_it); + cache.erase(it); + } else { + cached_bytes -= it->second.bytes; + it->second.w_source = w; + it->second.scales_source = scales; + it->second.biases_source = + has_bias ? std::optional(*biases) : std::nullopt; + it->second.weight = w_dequant; + it->second.bytes = entry_bytes; + cached_bytes += it->second.bytes; + lru.splice(lru.begin(), lru, it->second.lru_it); + + while (static_cast(cache.size()) > cache_cap || + cached_bytes > cache_max_bytes) { + auto evict = lru.back(); + auto evict_it = cache.find(evict); + if (evict_it != cache.end()) { + cached_bytes -= evict_it->second.bytes; + cache.erase(evict_it); + } + lru.pop_back(); + } + } + } + } + } else { + w_dequant.set_data(mlx::core::rocm::malloc_async(w_dequant.nbytes(), enc)); + + if (mode_ == QuantizationMode::Affine) { + affine_dequantize( + w, scales, biases, w_dequant, group_size_, bits_, enc, s); + } else { + fp_dequantize(w, scales, w_dequant, group_size_, bits_, enc, s); + } + } + + if (!cache_hit) { + enc.add_temporary(w_dequant); + } + + int lda = K; + int ldb = transpose_ ? K : N; + + if (batch_count == 1 && x_batch_count == 1 && w_batch_count == 1) { + dequant_rocblas_gemm( + enc, + false, + transpose_, + M, + N, + K, + 1.0f, + x, + lda, + w_dequant, + ldb, + 0.0f, + out, + N, + x.dtype()); + } else { + int64_t stride_a = + (x_batch_count == 1) ? 0 : static_cast(x.shape(-2)) * K; + int64_t stride_b = (w_batch_count == 1) + ? 0 + : static_cast(dequant_rows) * dequant_cols; + int64_t stride_c = static_cast(M) * N; + + dequant_rocblas_gemm_batched( + enc, + false, + transpose_, + M, + N, + K, + 1.0f, + x, + lda, + stride_a, + w_dequant, + ldb, + stride_b, + 0.0f, + out, + N, + stride_c, + batch_count, + x.dtype()); + } + return; + } + + bool use_fast_qmv = transpose_ && (non_batched || can_use_batched_qmv); + use_fast_qmv = parse_warp_kernel_env("MLX_ROCM_QMV_USE_WARP", use_fast_qmv); + if (can_use_batched_qmv) { + use_fast_qmv = true; + } + bool use_tiny_k_qmv = should_use_tiny_k_qmv_path( + M, N, K, batch_count, transpose_, can_use_batched_qmv, bits_, mode_); + + int block_size = 256; + dim3 grid(M, (N + block_size - 1) / block_size); + + int fast_threads_per_col = + select_qmv_threads_per_col(K, N, bits_, batch_count); + int fast_threads_env = + parse_threads_per_col_env("MLX_ROCM_QMV_THREADS_PER_COL"); + if (fast_threads_env > 0) + fast_threads_per_col = fast_threads_env; + + int fast_cols_per_block = select_qmv_cols_per_block(K, N, bits_); + if (use_tiny_k_qmv) { + fast_cols_per_block = std::max(fast_cols_per_block, 32); + } + if (fast_threads_per_col == 16 && bits_ == 8 && N >= 2048) { + fast_cols_per_block = std::max(fast_cols_per_block, 64); + } + int max_cols_per_block = rocm::kMaxThreadsPerBlock / fast_threads_per_col; + while (fast_cols_per_block > max_cols_per_block) + fast_cols_per_block /= 2; + while (fast_cols_per_block > 1 && (N % fast_cols_per_block) != 0 && + fast_cols_per_block > 8) { + fast_cols_per_block /= 2; + } + + dim3 fast_block(fast_threads_per_col, fast_cols_per_block); + dim3 fast_grid((N + fast_cols_per_block - 1) / fast_cols_per_block, M); + dim3 fast_grid_batched( + (N + fast_cols_per_block - 1) / fast_cols_per_block, M, batch_count); + + int64_t x_matrix_stride = + static_cast(x.shape(-2)) * static_cast(x.shape(-1)); + int64_t w_matrix_stride = static_cast(w.shape(-2)) * + static_cast(w.shape(-1)) * + static_cast(size_of(w.dtype())); + int num_groups = (K + group_size_ - 1) / group_size_; + int64_t sb_matrix_stride = + static_cast(w.shape(-2)) * static_cast(num_groups); + int64_t out_matrix_stride = static_cast(M) * N; + + int64_t x_batch_stride = (x_batch_count == 1) ? 0 : x_matrix_stride; + int64_t w_batch_stride = (w_batch_count == 1) ? 0 : w_matrix_stride; + int64_t sb_batch_stride = (w_batch_count == 1) ? 0 : sb_matrix_stride; + + const void* x_ptr = gpu_ptr(x); + const uint8_t* w_ptr = gpu_ptr(w); + const void* scales_ptr = gpu_ptr(scales); + const void* biases_ptr = has_bias ? gpu_ptr(*biases) : nullptr; + void* out_ptr = gpu_ptr(out); + + // The noshared variant reads x from global memory redundantly per warp. + // The shared variant caches x in LDS and is ~15x faster for decode shapes. + // Always prefer shared unless K is tiny (where LDS overhead isn't worth it). + bool use_noshared_qmv_variant = use_tiny_k_qmv; + + // L2-optimized tiled QMV with arch-tuned TILE_N. + // TILE_N is passed as a runtime argument — no template instantiation needed. + auto hw_info = detect_rocm_hw_info(enc.device()); + auto arch_tuning = rocm::get_arch_tuning(hw_info); + int tile_n = arch_tuning.qmv_tile_n; + // Allow env override for benchmarking + if (auto env = std::getenv("MLX_ROCM_QMV_TILE_N"); env && *env) + tile_n = std::atoi(env); + // Ensure N alignment + while (tile_n > 1 && N % tile_n != 0) tile_n /= 2; + + static bool use_tiled = (std::getenv("MLX_ROCM_QMV_NO_TILED") == nullptr); + + // Full-wave tiled 6-bit QMV; MLX_ROCM_QMV_6BIT_SLOW reverts to warp_shared. + static bool use_6bit_tiled = + (std::getenv("MLX_ROCM_QMV_6BIT_SLOW") == nullptr); + // GROUP_SIZE must be a multiple of 16 (the per-lane value count) so each + // lane's 16 weights fall in a single group → one scale/bias per lane, as the + // tiled accumulation assumes. 32/64/128 all satisfy this. + bool gs6_supported = + (group_size_ == 32 || group_size_ == 64 || group_size_ == 128); + bool x6_dtype_supported = + (x.dtype() == bfloat16 || x.dtype() == float16); + if (use_6bit_tiled && use_tiled && bits_ == 6 && (K % 64) == 0 && + gs6_supported && x6_dtype_supported && use_fast_qmv && + !can_use_batched_qmv && tile_n >= 8 && + mode_ == QuantizationMode::Affine) { + { + dim3 tiled_block(WARP_SIZE, tile_n); + const int n_tiles = (N + tile_n - 1) / tile_n; + int blocks_per_cu = (hw_info.max_threads_per_cu > 0) + ? (hw_info.max_threads_per_cu / (tile_n * WARP_SIZE)) : 4; + if (blocks_per_cu < 1) blocks_per_cu = 1; + int persistent_y = + (hw_info.num_cus > 0) ? hw_info.num_cus * blocks_per_cu : n_tiles; + int grid_y = (n_tiles < persistent_y) ? n_tiles : persistent_y; + if (grid_y < 1) grid_y = 1; + dim3 tiled_grid(M, grid_y); + + #define LAUNCH_TILED_6BIT(T, ScaleT, GS_V) \ + enc.add_kernel_node( \ + &rocm::qmv_tiled_6bit_kernel, \ + tiled_grid, tiled_block, 0u, \ + (const T*)x_ptr, (const uint8_t*)w_ptr, \ + (const ScaleT*)scales_ptr, (const ScaleT*)biases_ptr, \ + (T*)out_ptr, M, N, K, has_bias, tile_n, n_tiles) + + if (x.dtype() == bfloat16) { + if (group_size_ == 32) { LAUNCH_TILED_6BIT(hip_bfloat16, hip_bfloat16, 32); } + else if (group_size_ == 64) { LAUNCH_TILED_6BIT(hip_bfloat16, hip_bfloat16, 64); } + else if (group_size_ == 128) { LAUNCH_TILED_6BIT(hip_bfloat16, hip_bfloat16, 128); } + } else if (x.dtype() == float16) { + if (group_size_ == 32) { LAUNCH_TILED_6BIT(__half, __half, 32); } + else if (group_size_ == 64) { LAUNCH_TILED_6BIT(__half, __half, 64); } + else if (group_size_ == 128) { LAUNCH_TILED_6BIT(__half, __half, 128); } + } + #undef LAUNCH_TILED_6BIT + } + return; + } + + // The tiled QMV kernel (qdequant.hpp pack_factor_u32 = 32/BITS) only packs + // correctly for power-of-two widths and is only instantiated for 4/8-bit; + // other widths would match nothing here and leave `out` uninitialized. + // Restrict to 4/8-bit; 2/5/6-bit fall through to the warp-shared QMV kernel. + bool tiled_bits_supported = (bits_ == 4 || bits_ == 8); + if (use_tiled && tiled_bits_supported && use_fast_qmv && + !can_use_batched_qmv && tile_n >= 8 && + mode_ == QuantizationMode::Affine) { + { + dim3 tiled_block(WARP_SIZE, tile_n); + const int n_tiles = (N + tile_n - 1) / tile_n; + // Persistent grid: CU-bounded block count, kernel grid-strides the rest. + int blocks_per_cu = (hw_info.max_threads_per_cu > 0) + ? (hw_info.max_threads_per_cu / (tile_n * WARP_SIZE)) : 4; + if (blocks_per_cu < 1) blocks_per_cu = 1; + int persistent_y = + (hw_info.num_cus > 0) ? hw_info.num_cus * blocks_per_cu : n_tiles; + int grid_y = (n_tiles < persistent_y) ? n_tiles : persistent_y; + if (grid_y < 1) grid_y = 1; + dim3 tiled_grid(M, grid_y); + + #define LAUNCH_TILED(T, ScaleT, BITS_V, GS_V) \ + enc.add_kernel_node( \ + &rocm::qmv_tiled_kernel, \ + tiled_grid, tiled_block, 0u, \ + (const T*)x_ptr, (const uint32_t*)w_ptr, \ + (const ScaleT*)scales_ptr, (const ScaleT*)biases_ptr, \ + (T*)out_ptr, M, N, K, has_bias, tile_n, n_tiles) + + if (x.dtype() == bfloat16) { + if (bits_ == 4) { + if (group_size_ == 32) { LAUNCH_TILED(hip_bfloat16, hip_bfloat16, 4, 32); } + else if (group_size_ == 64) { LAUNCH_TILED(hip_bfloat16, hip_bfloat16, 4, 64); } + else if (group_size_ == 128) { LAUNCH_TILED(hip_bfloat16, hip_bfloat16, 4, 128); } + } else if (bits_ == 8) { + if (group_size_ == 32) { LAUNCH_TILED(hip_bfloat16, hip_bfloat16, 8, 32); } + else if (group_size_ == 64) { LAUNCH_TILED(hip_bfloat16, hip_bfloat16, 8, 64); } + else if (group_size_ == 128) { LAUNCH_TILED(hip_bfloat16, hip_bfloat16, 8, 128); } + } + } else if (x.dtype() == float16) { + if (bits_ == 4) { + if (group_size_ == 32) { LAUNCH_TILED(__half, __half, 4, 32); } + else if (group_size_ == 64) { LAUNCH_TILED(__half, __half, 4, 64); } + else if (group_size_ == 128) { LAUNCH_TILED(__half, __half, 4, 128); } + } else if (bits_ == 8) { + if (group_size_ == 32) { LAUNCH_TILED(__half, __half, 8, 32); } + else if (group_size_ == 64) { LAUNCH_TILED(__half, __half, 8, 64); } + else if (group_size_ == 128) { LAUNCH_TILED(__half, __half, 8, 128); } + } + } + #undef LAUNCH_TILED + } + return; + } + + // The noshared path used to increase cols_per_block for aligned data. + // Since we always use the shared variant now, no special grid adjustment needed. + + { + auto launch_qmv = + [&](auto type_tag, auto scale_tag, auto bits_tag, auto gs_tag) { + using T = typename decltype(type_tag)::type; + using ScaleT = typename decltype(scale_tag)::type; + constexpr int BITS = bits_tag.value; + constexpr int GROUP_SIZE = gs_tag.value; + + if (mode_ == QuantizationMode::Affine) { + if (use_fast_qmv) { + if (can_use_batched_qmv) { + if (fast_threads_per_col == 16) { + enc.add_kernel_node( + &rocm::qmv_warp_shared_batched_kernel< + T, + ScaleT, + BITS, + GROUP_SIZE, + true, + 16>, + fast_grid_batched, + fast_block, + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + x_batch_stride, + w_batch_stride, + sb_batch_stride, + out_matrix_stride, + has_bias); + } else { + enc.add_kernel_node( + &rocm::qmv_warp_shared_batched_kernel< + T, + ScaleT, + BITS, + GROUP_SIZE, + true, + WARP_SIZE>, + fast_grid_batched, + fast_block, + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + x_batch_stride, + w_batch_stride, + sb_batch_stride, + out_matrix_stride, + has_bias); + } + } else { + if (use_noshared_qmv_variant) { + if (fast_threads_per_col == 16) { + enc.add_kernel_node( + &rocm::qmv_warp_noshared_kernel< + T, + ScaleT, + BITS, + GROUP_SIZE, + true, + 16>, + fast_grid, + fast_block, + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + has_bias); + } else { + enc.add_kernel_node( + &rocm::qmv_warp_noshared_kernel< + T, + ScaleT, + BITS, + GROUP_SIZE, + true, + WARP_SIZE>, + fast_grid, + fast_block, + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + has_bias); + } + } else { + if (fast_threads_per_col == 16) { + enc.add_kernel_node( + &rocm::qmv_warp_shared_kernel< + T, + ScaleT, + BITS, + GROUP_SIZE, + true, + 16>, + fast_grid, + fast_block, + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + has_bias); + } else { + enc.add_kernel_node( + &rocm::qmv_warp_shared_kernel< + T, + ScaleT, + BITS, + GROUP_SIZE, + true, + WARP_SIZE>, + fast_grid, + fast_block, + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + has_bias); + } + } + } + } else if (transpose_) { + enc.add_kernel_node( + &rocm::qmv_t_kernel, + grid, + dim3(block_size), + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + has_bias); + } else { + enc.add_kernel_node( + &rocm::qmv_kernel, + grid, + dim3(block_size), + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + has_bias); + } + } else { + if (use_fast_qmv) { + if (can_use_batched_qmv) { + if (fast_threads_per_col == 16) { + enc.add_kernel_node( + &rocm::qmv_warp_shared_batched_kernel< + T, + ScaleT, + BITS, + GROUP_SIZE, + false, + 16>, + fast_grid_batched, + fast_block, + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + x_batch_stride, + w_batch_stride, + sb_batch_stride, + out_matrix_stride, + has_bias); + } else { + enc.add_kernel_node( + &rocm::qmv_warp_shared_batched_kernel< + T, + ScaleT, + BITS, + GROUP_SIZE, + false, + WARP_SIZE>, + fast_grid_batched, + fast_block, + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + x_batch_stride, + w_batch_stride, + sb_batch_stride, + out_matrix_stride, + has_bias); + } + } else { + if (use_noshared_qmv_variant) { + if (fast_threads_per_col == 16) { + enc.add_kernel_node( + &rocm::qmv_warp_noshared_kernel< + T, + ScaleT, + BITS, + GROUP_SIZE, + false, + 16>, + fast_grid, + fast_block, + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + has_bias); + } else { + enc.add_kernel_node( + &rocm::qmv_warp_noshared_kernel< + T, + ScaleT, + BITS, + GROUP_SIZE, + false, + WARP_SIZE>, + fast_grid, + fast_block, + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + has_bias); + } + } else { + if (fast_threads_per_col == 16) { + enc.add_kernel_node( + &rocm::qmv_warp_shared_kernel< + T, + ScaleT, + BITS, + GROUP_SIZE, + false, + 16>, + fast_grid, + fast_block, + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + has_bias); + } else { + enc.add_kernel_node( + &rocm::qmv_warp_shared_kernel< + T, + ScaleT, + BITS, + GROUP_SIZE, + false, + WARP_SIZE>, + fast_grid, + fast_block, + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + has_bias); + } + } + } + } else if (transpose_) { + enc.add_kernel_node( + &rocm::qmv_t_kernel, + grid, + dim3(block_size), + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + has_bias); + } else { + enc.add_kernel_node( + &rocm::qmv_kernel, + grid, + dim3(block_size), + 0u, + (const T*)x_ptr, + w_ptr, + (const ScaleT*)scales_ptr, + (const ScaleT*)biases_ptr, + (T*)out_ptr, + M, + N, + K, + has_bias); + } + } + }; + + // Type aliases to avoid template angle brackets in macro args + using float_id = local_type_identity; + using half_id = local_type_identity<__half>; + using bf16_id = local_type_identity; + using bits2 = std::integral_constant; + using bits4 = std::integral_constant; + using bits5 = std::integral_constant; + using bits6 = std::integral_constant; + using bits8 = std::integral_constant; + using gs32 = std::integral_constant; + using gs64 = std::integral_constant; + using gs128 = std::integral_constant; + +// Helper macro to dispatch group_size +#define DISPATCH_GROUP_SIZE(type_tag, scale_tag, bits_tag) \ + do { \ + switch (group_size_) { \ + case 32: \ + launch_qmv(type_tag, scale_tag, bits_tag, gs32{}); \ + break; \ + case 64: \ + launch_qmv(type_tag, scale_tag, bits_tag, gs64{}); \ + break; \ + case 128: \ + launch_qmv(type_tag, scale_tag, bits_tag, gs128{}); \ + break; \ + default: \ + throw std::runtime_error( \ + "Unsupported group_size for QuantizedMatmul: " + \ + std::to_string(group_size_)); \ + } \ + } while (0) + + if (x.dtype() == float32) { + if (bits_ == 8) + DISPATCH_GROUP_SIZE(float_id{}, float_id{}, bits8{}); + else if (bits_ == 5 && mode_ == QuantizationMode::Affine) { + DISPATCH_GROUP_SIZE(float_id{}, float_id{}, bits5{}); + } else if (bits_ == 6 && mode_ == QuantizationMode::Affine) { + DISPATCH_GROUP_SIZE(float_id{}, float_id{}, bits6{}); + } else if (bits_ == 4) + DISPATCH_GROUP_SIZE(float_id{}, float_id{}, bits4{}); + else if (bits_ == 2) + DISPATCH_GROUP_SIZE(float_id{}, float_id{}, bits2{}); + else + throw std::runtime_error( + "Unsupported bits for QuantizedMatmul float32: " + + std::to_string(bits_)); + } else if (x.dtype() == float16) { + if (bits_ == 8) + DISPATCH_GROUP_SIZE(half_id{}, half_id{}, bits8{}); + else if (bits_ == 5 && mode_ == QuantizationMode::Affine) { + DISPATCH_GROUP_SIZE(half_id{}, half_id{}, bits5{}); + } else if (bits_ == 6 && mode_ == QuantizationMode::Affine) { + DISPATCH_GROUP_SIZE(half_id{}, half_id{}, bits6{}); + } else if (bits_ == 4) + DISPATCH_GROUP_SIZE(half_id{}, half_id{}, bits4{}); + else if (bits_ == 2) + DISPATCH_GROUP_SIZE(half_id{}, half_id{}, bits2{}); + else + throw std::runtime_error( + "Unsupported bits for QuantizedMatmul float16: " + + std::to_string(bits_)); + } else if (x.dtype() == bfloat16) { + if (bits_ == 8) + DISPATCH_GROUP_SIZE(bf16_id{}, bf16_id{}, bits8{}); + else if (bits_ == 5 && mode_ == QuantizationMode::Affine) { + DISPATCH_GROUP_SIZE(bf16_id{}, bf16_id{}, bits5{}); + } else if (bits_ == 6 && mode_ == QuantizationMode::Affine) { + DISPATCH_GROUP_SIZE(bf16_id{}, bf16_id{}, bits6{}); + } else if (bits_ == 4) + DISPATCH_GROUP_SIZE(bf16_id{}, bf16_id{}, bits4{}); + else if (bits_ == 2) + DISPATCH_GROUP_SIZE(bf16_id{}, bf16_id{}, bits2{}); + else + throw std::runtime_error( + "Unsupported bits for QuantizedMatmul bfloat16: " + + std::to_string(bits_)); + } else { + throw std::runtime_error("Unsupported dtype for QuantizedMatmul"); + } + +#undef DISPATCH_GROUP_SIZE + } +} + +namespace rocm { + +// ====================================================================== +// GPU-only expert-batched gather QMV for sorted indices. +// +// Grid: (M, ceil(N/cols_per_block), max_unique_experts) +// Each block in z-dimension finds its expert by binary-searching the sorted +// rhs_indices array. No CPU-side run computation needed. +// +// The kernel reads the weight column ONCE per expert and iterates over all +// batch elements assigned to that expert, amortizing weight memory traffic. +// ====================================================================== +template < + typename T, + typename ScaleT, + int BITS, + int GROUP_SIZE, + bool AFFINE, + int THREADS_PER_COL> +__global__ void __launch_bounds__(1024) gather_qmv_expert_batched_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + const uint32_t* __restrict__ lhs_indices, + const uint32_t* __restrict__ rhs_indices, // SORTED + T* __restrict__ out, + int B, + int M, + int N, + int K, + int E, + bool has_bias, + bool implicit_lhs, + int64_t implicit_x_batch_stride) { + const int lane = threadIdx.x; + const int warp_idx = threadIdx.y; + const int col = blockIdx.y * blockDim.y + warp_idx; + const int row = blockIdx.x; + const int expert_slot = blockIdx.z; // which unique expert this block handles + + if (row >= M || col >= N) return; + + // Find this expert's token range using the expert_slot as a run index. + // Since rhs_indices is sorted, run boundaries are where values change. + // We use a parallel scan: all threads cooperate to count unique experts + // up to expert_slot, then binary-search for the run boundaries. + // + // Fast path: lane 0 does a boundary skip using binary search. + int run_start = 0, run_end = 0; + uint32_t expert_id = 0; + + if (lane == 0 && warp_idx == 0) { + // Skip to the expert_slot-th unique expert by jumping over run boundaries. + // Each boundary is where rhs_indices[i] != rhs_indices[i-1]. + int pos = 0; + for (int skip = 0; skip < expert_slot && pos < B; ++skip) { + // Binary search for end of current run (first index where value differs) + uint32_t cur_val = rhs_indices[pos]; + int lo = pos + 1, hi = B; + while (lo < hi) { + int mid = (lo + hi) >> 1; + if (rhs_indices[mid] == cur_val) lo = mid + 1; + else hi = mid; + } + pos = lo; + } + if (pos < B) { + run_start = pos; + expert_id = rhs_indices[pos]; + // Binary search for end of this expert's run + int lo = pos + 1, hi = B; + while (lo < hi) { + int mid = (lo + hi) >> 1; + if (rhs_indices[mid] == expert_id) lo = mid + 1; + else hi = mid; + } + run_end = lo; + } + } + + // Broadcast via shared memory + __shared__ int s_run_start, s_run_end; + __shared__ uint32_t s_expert_id; + if (lane == 0 && warp_idx == 0) { + s_run_start = run_start; + s_run_end = run_end; + s_expert_id = expert_id; + } + __syncthreads(); + run_start = s_run_start; + run_end = s_run_end; + expert_id = s_expert_id; + + if (run_end <= run_start) return; // this block has no work + if (expert_id >= static_cast(E)) return; + + // Weight pointers for this expert (loaded ONCE, reused for all tokens in run) + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int row_bytes = (K * BITS + 7) / 8; + int64_t w_expert_stride = static_cast(N) * row_bytes; + int64_t sb_expert_stride = static_cast(N) * num_groups; + + const uint8_t* w_row = w + static_cast(expert_id) * w_expert_stride + + static_cast(col) * row_bytes; + const ScaleT* scales_row = scales + static_cast(expert_id) * sb_expert_stride + + static_cast(col) * num_groups; + const ScaleT* biases_row = has_bias + ? (biases + static_cast(expert_id) * sb_expert_stride + + static_cast(col) * num_groups) + : nullptr; + + // Process each batch element in the run + int64_t x_batch_stride = static_cast(M) * K; + for (int b = run_start; b < run_end; ++b) { + uint32_t lhs_idx = implicit_lhs ? 0u : lhs_indices[b]; + int64_t x_offset = implicit_lhs + ? (static_cast(b) * implicit_x_batch_stride) + : (static_cast(lhs_idx) * x_batch_stride); + const T* x_row = x + x_offset + static_cast(row) * K; + + float acc = 0.0f; + + for (int g = 0; g < num_groups; ++g) { + int k_start = g * GROUP_SIZE; + int k_end = min(k_start + GROUP_SIZE, K); + + float scale = load_scale_value(scales_row[g]); + float bias_val = has_bias ? static_cast(biases_row[g]) : 0.0f; + + if constexpr (AFFINE) { + float qx_acc = 0.0f; + float x_group_sum = 0.0f; + + if constexpr (BITS == 4) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + for (; k_start + k_local + 7 < k_end; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = *reinterpret_cast(&w_row[k / 2]); + float w0 = static_cast(w_packed & 0xF); + float w1 = static_cast((w_packed >> 4) & 0xF); + float w2 = static_cast((w_packed >> 8) & 0xF); + float w3 = static_cast((w_packed >> 12) & 0xF); + float w4 = static_cast((w_packed >> 16) & 0xF); + float w5 = static_cast((w_packed >> 20) & 0xF); + float w6 = static_cast((w_packed >> 24) & 0xF); + float w7 = static_cast((w_packed >> 28) & 0xF); + float x0 = static_cast(x_row[k]); + float x1 = static_cast(x_row[k + 1]); + float x2 = static_cast(x_row[k + 2]); + float x3 = static_cast(x_row[k + 3]); + float x4 = static_cast(x_row[k + 4]); + float x5 = static_cast(x_row[k + 5]); + float x6 = static_cast(x_row[k + 6]); + float x7 = static_cast(x_row[k + 7]); + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + qx_acc = fmaf(x4, w4, qx_acc); + qx_acc = fmaf(x5, w5, qx_acc); + qx_acc = fmaf(x6, w6, qx_acc); + qx_acc = fmaf(x7, w7, qx_acc); + if (has_bias) + x_group_sum += x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7; + } + for (; k_start + k_local < k_end; k_local++) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) x_group_sum += x_val; + } + } else if constexpr (BITS == 8) { + int k_local = lane * 4; + int step = THREADS_PER_COL * 4; + for (; k_start + k_local + 3 < k_end; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = static_cast(w_packed & 0xFF); + float w1 = static_cast((w_packed >> 8) & 0xFF); + float w2 = static_cast((w_packed >> 16) & 0xFF); + float w3 = static_cast((w_packed >> 24) & 0xFF); + float x0 = static_cast(x_row[k]); + float x1 = static_cast(x_row[k + 1]); + float x2 = static_cast(x_row[k + 2]); + float x3 = static_cast(x_row[k + 3]); + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + if (has_bias) x_group_sum += x0 + x1 + x2 + x3; + } + for (; k_start + k_local < k_end; k_local++) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + float w_val = static_cast(w_row[k]); + qx_acc = fmaf(x_val, w_val, qx_acc); + if (has_bias) x_group_sum += x_val; + } + } else { + for (int k_local = lane; k_start + k_local < k_end; k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) x_group_sum += x_val; + } + } + + qx_acc = subgroup_reduce_sum_qmm(qx_acc); + x_group_sum = subgroup_reduce_sum_qmm(x_group_sum); + acc += scale * qx_acc + bias_val * x_group_sum; + } else { + float qx_acc = 0.0f; + for (int k_local = lane; k_start + k_local < k_end; k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, dequantize_value(quant_val, 1.0f, 0.0f), qx_acc); + } + acc += scale * subgroup_reduce_sum_qmm(qx_acc); + } + } + + if (lane == 0) { + out[static_cast(b) * M * N + static_cast(row) * N + col] = static_cast(acc); + } + } +} + +// ====================================================================== +// Prefill-optimized gather QMV: groups batch elements by expert. +// +// For sorted rhs_indices, consecutive batch elements hit the same expert. +// This kernel assigns blockIdx.z to contiguous runs of same-expert batches, +// so all rows for one expert share weight reads from global memory. +// Each block handles one column (via warp cooperation) and iterates over +// all M rows for each batch element in the run. +// +// Grid: (num_runs, ceil(N/cols_per_block), max_rows_per_run) +// Where num_runs = number of contiguous expert runs. +// ====================================================================== +template < + typename T, + typename ScaleT, + int BITS, + int GROUP_SIZE, + bool AFFINE, + int THREADS_PER_COL> +__global__ void __launch_bounds__(1024) gather_qmv_prefill_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + const uint32_t* __restrict__ lhs_indices, + const uint32_t* __restrict__ rhs_indices, + const int* __restrict__ run_starts, // [num_runs]: start batch idx of each run + const int* __restrict__ run_lengths, // [num_runs]: length of each run + const int* __restrict__ out_perm, // [B]: sorted batch idx → original batch idx + T* __restrict__ out, + int B, + int M, + int N, + int K, + int E, + bool has_bias, + int64_t x_batch_stride) { + const int lane = threadIdx.x; + const int warp_idx = threadIdx.y; + const int col = blockIdx.y * blockDim.y + warp_idx; + const int run_id = blockIdx.z; + const int row = blockIdx.x; + + if (row >= M || col >= N) return; + + int run_start = run_starts[run_id]; + int run_len = run_lengths[run_id]; + + // All batches in this run have the same expert + uint32_t rhs_idx = rhs_indices[run_start]; + if (rhs_idx >= static_cast(E)) return; + + // Weight pointers (same for all batches in run) + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int row_bytes = (K * BITS + 7) / 8; + int64_t w_expert_stride = static_cast(N) * row_bytes; + int64_t sb_expert_stride = static_cast(N) * num_groups; + int64_t col_w_offset = static_cast(col) * row_bytes; + int64_t col_sb_offset = static_cast(col) * num_groups; + + const uint8_t* w_row = w + static_cast(rhs_idx) * w_expert_stride + col_w_offset; + const ScaleT* scales_row = scales + static_cast(rhs_idx) * sb_expert_stride + col_sb_offset; + const ScaleT* biases_row = has_bias + ? (biases + static_cast(rhs_idx) * sb_expert_stride + col_sb_offset) + : nullptr; + + // Process each batch element in the run + for (int r = 0; r < run_len; ++r) { + int batch = run_start + r; + uint32_t lhs_idx = lhs_indices[batch]; + const T* x_row = x + static_cast(lhs_idx) * x_batch_stride + static_cast(row) * K; + + float acc = 0.0f; + + for (int g = 0; g < num_groups; ++g) { + int k_start = g * GROUP_SIZE; + int k_end = min(k_start + GROUP_SIZE, K); + + float scale = load_scale_value(scales_row[g]); + float bias_val = has_bias ? static_cast(biases_row[g]) : 0.0f; + + if constexpr (AFFINE) { + float qx_acc = 0.0f; + float x_group_sum = 0.0f; + + if constexpr (BITS == 4) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + for (; k_start + k_local + 7 < k_end; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = *reinterpret_cast(&w_row[k / 2]); + float w0 = static_cast(w_packed & 0xF); + float w1 = static_cast((w_packed >> 4) & 0xF); + float w2 = static_cast((w_packed >> 8) & 0xF); + float w3 = static_cast((w_packed >> 12) & 0xF); + float w4 = static_cast((w_packed >> 16) & 0xF); + float w5 = static_cast((w_packed >> 20) & 0xF); + float w6 = static_cast((w_packed >> 24) & 0xF); + float w7 = static_cast((w_packed >> 28) & 0xF); + float x0 = static_cast(x_row[k]); + float x1 = static_cast(x_row[k + 1]); + float x2 = static_cast(x_row[k + 2]); + float x3 = static_cast(x_row[k + 3]); + float x4 = static_cast(x_row[k + 4]); + float x5 = static_cast(x_row[k + 5]); + float x6 = static_cast(x_row[k + 6]); + float x7 = static_cast(x_row[k + 7]); + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + qx_acc = fmaf(x4, w4, qx_acc); + qx_acc = fmaf(x5, w5, qx_acc); + qx_acc = fmaf(x6, w6, qx_acc); + qx_acc = fmaf(x7, w7, qx_acc); + if (has_bias) + x_group_sum += x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7; + } + for (; k_start + k_local < k_end; k_local++) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) x_group_sum += x_val; + } + } else if constexpr (BITS == 8) { + int k_local = lane * 4; + int step = THREADS_PER_COL * 4; + for (; k_start + k_local + 3 < k_end; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = static_cast(w_packed & 0xFF); + float w1 = static_cast((w_packed >> 8) & 0xFF); + float w2 = static_cast((w_packed >> 16) & 0xFF); + float w3 = static_cast((w_packed >> 24) & 0xFF); + float x0 = static_cast(x_row[k]); + float x1 = static_cast(x_row[k + 1]); + float x2 = static_cast(x_row[k + 2]); + float x3 = static_cast(x_row[k + 3]); + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + if (has_bias) x_group_sum += x0 + x1 + x2 + x3; + } + for (; k_start + k_local < k_end; k_local++) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + float w_val = static_cast(w_row[k]); + qx_acc = fmaf(x_val, w_val, qx_acc); + if (has_bias) x_group_sum += x_val; + } + } else { + for (int k_local = lane; k_start + k_local < k_end; k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) x_group_sum += x_val; + } + } + + qx_acc = subgroup_reduce_sum_qmm(qx_acc); + x_group_sum = subgroup_reduce_sum_qmm(x_group_sum); + acc += scale * qx_acc + bias_val * x_group_sum; + } else { + float qx_acc = 0.0f; + for (int k_local = lane; k_start + k_local < k_end; k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = static_cast(x_row[k]); + uint8_t quant_val = unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, dequantize_value(quant_val, 1.0f, 0.0f), qx_acc); + } + acc += scale * subgroup_reduce_sum_qmm(qx_acc); + } + } + + if (lane == 0) { + const int orig_batch = out_perm[batch]; + out[static_cast(orig_batch) * M * N + static_cast(row) * N + col] = static_cast(acc); + } + } +} + +template < + typename T, + typename ScaleT, + int BITS, + int GROUP_SIZE, + bool AFFINE, + int THREADS_PER_COL> +__global__ void __launch_bounds__(1024) gather_qmv_warp_shared_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + const uint32_t* __restrict__ lhs_indices, + const uint32_t* __restrict__ rhs_indices, + const rocm::Shape batch_shape, + const rocm::Strides lhs_idx_strides, + const rocm::Strides rhs_idx_strides, + int batch_ndim, + T* __restrict__ out, + int B, + int M, + int N, + int K, + int E, + bool has_bias, + bool implicit_lhs = false, + int64_t implicit_x_batch_stride = 0) { + const int lane = threadIdx.x; + const int warp_idx = threadIdx.y; + const int col = blockIdx.y * blockDim.y + warp_idx; + const int row = blockIdx.x; + const int batch = blockIdx.z; + + if (batch >= B || row >= M) { + return; + } + + int64_t rhs_idx_loc = 0; + int64_t lhs_idx_loc = 0; + if (batch_ndim == 1) { + rhs_idx_loc = static_cast(batch) * rhs_idx_strides[0]; + if (!implicit_lhs) { + lhs_idx_loc = static_cast(batch) * lhs_idx_strides[0]; + } + } else if (batch_ndim > 1) { + int64_t elem = static_cast(batch); + for (int i = batch_ndim - 1; i >= 0; --i) { + int64_t coord = elem % batch_shape.data_[i]; + rhs_idx_loc += coord * rhs_idx_strides.data_[i]; + if (!implicit_lhs) { + lhs_idx_loc += coord * lhs_idx_strides.data_[i]; + } + elem /= batch_shape.data_[i]; + } + } + + uint32_t lhs_idx = implicit_lhs ? 0u : lhs_indices[lhs_idx_loc]; + uint32_t rhs_idx = rhs_indices[rhs_idx_loc]; + + const bool col_valid = col < N; + const bool expert_valid = rhs_idx < static_cast(E); + const bool valid = col_valid && expert_valid; + + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int row_bytes = (K * BITS + 7) / 8; + + int64_t x_batch_stride = static_cast(M) * K; + int64_t w_batch_stride = static_cast(N) * row_bytes; + int64_t sb_batch_stride = static_cast(N) * num_groups; + int64_t col_w_offset = static_cast(col) * row_bytes; + int64_t col_sb_offset = static_cast(col) * num_groups; + + int64_t x_batch_offset = implicit_lhs + ? (static_cast(batch) * implicit_x_batch_stride) + : (static_cast(lhs_idx) * x_batch_stride); + const T* x_row = x + x_batch_offset + static_cast(row) * K; + const uint8_t* w_row = valid + ? (w + static_cast(rhs_idx) * w_batch_stride + col_w_offset) + : nullptr; + const ScaleT* scales_row = valid + ? (scales + static_cast(rhs_idx) * sb_batch_stride + + col_sb_offset) + : nullptr; + const ScaleT* biases_row = (valid && has_bias) + ? (biases + static_cast(rhs_idx) * sb_batch_stride + + col_sb_offset) + : nullptr; + + float acc = 0.0f; + + constexpr int CHUNK_SIZE = 2048; + __shared__ float shared_x[CHUNK_SIZE]; + + for (int chunk_start = 0; chunk_start < K; chunk_start += CHUNK_SIZE) { + int chunk_end = min(chunk_start + CHUNK_SIZE, K); + int chunk_len = chunk_end - chunk_start; + + int tid = warp_idx * blockDim.x + lane; + for (int i = tid; i < chunk_len; i += blockDim.x * blockDim.y) { + shared_x[i] = static_cast(x_row[chunk_start + i]); + } + __syncthreads(); + + if (valid) { + int g_start = chunk_start / GROUP_SIZE; + int g_end = (chunk_end + GROUP_SIZE - 1) / GROUP_SIZE; + + for (int g = g_start; g < g_end; ++g) { + int k_start = max(g * GROUP_SIZE, chunk_start); + int k_end_g = min((g + 1) * GROUP_SIZE, chunk_end); + + float scale = + load_scale_value(scales_row[g]); + float bias_val = has_bias ? static_cast(biases_row[g]) : 0.0f; + + if constexpr (AFFINE) { + float qx_acc0 = 0.0f; + float qx_acc1 = 0.0f; + float qx_acc2 = 0.0f; + float qx_acc3 = 0.0f; + float x_group_sum = 0.0f; + float qx_acc = 0.0f; + + if constexpr (BITS == 8) { + int k_local = lane * 4; + int step = THREADS_PER_COL * 4; + for (; k_start + k_local + 3 < k_end_g; k_local += step) { + int k = k_start + k_local; + + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = static_cast(w_packed & 0xFF); + float w1 = static_cast((w_packed >> 8) & 0xFF); + float w2 = static_cast((w_packed >> 16) & 0xFF); + float w3 = static_cast((w_packed >> 24) & 0xFF); + + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + + qx_acc0 = fmaf(x0, w0, qx_acc0); + qx_acc1 = fmaf(x1, w1, qx_acc1); + qx_acc2 = fmaf(x2, w2, qx_acc2); + qx_acc3 = fmaf(x3, w3, qx_acc3); + + if (has_bias) { + x_group_sum += x0 + x1 + x2 + x3; + } + } + qx_acc = qx_acc0 + qx_acc1 + qx_acc2 + qx_acc3; + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + float w_val = static_cast(w_row[k]); + qx_acc = fmaf(x_val, w_val, qx_acc); + if (has_bias) { + x_group_sum += x_val; + } + } + } else if constexpr (BITS == 4) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + for (; k_start + k_local + 7 < k_end_g; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = + *reinterpret_cast(&w_row[k / 2]); + float w0 = static_cast(w_packed & 0xF); + float w1 = static_cast((w_packed >> 4) & 0xF); + float w2 = static_cast((w_packed >> 8) & 0xF); + float w3 = static_cast((w_packed >> 12) & 0xF); + float w4 = static_cast((w_packed >> 16) & 0xF); + float w5 = static_cast((w_packed >> 20) & 0xF); + float w6 = static_cast((w_packed >> 24) & 0xF); + float w7 = static_cast((w_packed >> 28) & 0xF); + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + float x4 = shared_x[k - chunk_start + 4]; + float x5 = shared_x[k - chunk_start + 5]; + float x6 = shared_x[k - chunk_start + 6]; + float x7 = shared_x[k - chunk_start + 7]; + // Four independent accumulators for RDNA4 dual-issue (mirrors the + // 8-bit branch); reassociated at group end. + qx_acc0 = fmaf(x0, w0, qx_acc0); + qx_acc1 = fmaf(x1, w1, qx_acc1); + qx_acc2 = fmaf(x2, w2, qx_acc2); + qx_acc3 = fmaf(x3, w3, qx_acc3); + qx_acc0 = fmaf(x4, w4, qx_acc0); + qx_acc1 = fmaf(x5, w5, qx_acc1); + qx_acc2 = fmaf(x6, w6, qx_acc2); + qx_acc3 = fmaf(x7, w7, qx_acc3); + if (has_bias) { + x_group_sum += x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7; + } + } + qx_acc = qx_acc0 + qx_acc1 + qx_acc2 + qx_acc3; + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) { + x_group_sum += x_val; + } + } + } else if constexpr (BITS == 6) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + int max_safe_k = ((row_bytes - 7) * 8) / 6; + for (; k_start + k_local + 7 < k_end_g && + k_start + k_local < max_safe_k; + k_local += step) { + int k = k_start + k_local; + int byte_idx = (k * 6) / 8; + int bit_offset = (k * 6) % 8; + uint64_t w_packed = 0; + memcpy(&w_packed, &w_row[byte_idx], 8); + w_packed >>= bit_offset; + float w0 = static_cast(w_packed & 0x3F); + float w1 = static_cast((w_packed >> 6) & 0x3F); + float w2 = static_cast((w_packed >> 12) & 0x3F); + float w3 = static_cast((w_packed >> 18) & 0x3F); + float w4 = static_cast((w_packed >> 24) & 0x3F); + float w5 = static_cast((w_packed >> 30) & 0x3F); + float w6 = static_cast((w_packed >> 36) & 0x3F); + float w7 = static_cast((w_packed >> 42) & 0x3F); + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + float x4 = shared_x[k - chunk_start + 4]; + float x5 = shared_x[k - chunk_start + 5]; + float x6 = shared_x[k - chunk_start + 6]; + float x7 = shared_x[k - chunk_start + 7]; + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + qx_acc = fmaf(x4, w4, qx_acc); + qx_acc = fmaf(x5, w5, qx_acc); + qx_acc = fmaf(x6, w6, qx_acc); + qx_acc = fmaf(x7, w7, qx_acc); + if (has_bias) { + x_group_sum += x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7; + } + } + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) { + x_group_sum += x_val; + } + } + } else { + for (int k_local = lane; k_start + k_local < k_end_g; + k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf(x_val, static_cast(quant_val), qx_acc); + if (has_bias) { + x_group_sum += x_val; + } + } + } + acc += scale * qx_acc; + if (has_bias) { + acc += bias_val * x_group_sum; + } + } else { + float qx_acc0 = 0.0f; + float qx_acc1 = 0.0f; + float qx_acc2 = 0.0f; + float qx_acc3 = 0.0f; + float qx_acc = 0.0f; + if constexpr (BITS == 8) { + int k_local = lane * 4; + int step = THREADS_PER_COL * 4; + for (; k_start + k_local + 3 < k_end_g; k_local += step) { + int k = k_start + k_local; + + uint32_t w_packed = *reinterpret_cast(&w_row[k]); + float w0 = fp8_e4m3_to_float(w_packed & 0xFF); + float w1 = fp8_e4m3_to_float((w_packed >> 8) & 0xFF); + float w2 = fp8_e4m3_to_float((w_packed >> 16) & 0xFF); + float w3 = fp8_e4m3_to_float((w_packed >> 24) & 0xFF); + + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + + qx_acc0 = fmaf(x0, w0, qx_acc0); + qx_acc1 = fmaf(x1, w1, qx_acc1); + qx_acc2 = fmaf(x2, w2, qx_acc2); + qx_acc3 = fmaf(x3, w3, qx_acc3); + } + qx_acc = qx_acc0 + qx_acc1 + qx_acc2 + qx_acc3; + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + float w_val = fp8_e4m3_to_float(w_row[k]); + qx_acc = fmaf(x_val, w_val, qx_acc); + } + } else if constexpr (BITS == 4) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + for (; k_start + k_local + 7 < k_end_g; k_local += step) { + int k = k_start + k_local; + uint32_t w_packed = + *reinterpret_cast(&w_row[k / 2]); + float w0 = dequantize_value<4, false>(w_packed & 0xF, 1.0f, 0.0f); + float w1 = + dequantize_value<4, false>((w_packed >> 4) & 0xF, 1.0f, 0.0f); + float w2 = + dequantize_value<4, false>((w_packed >> 8) & 0xF, 1.0f, 0.0f); + float w3 = dequantize_value<4, false>( + (w_packed >> 12) & 0xF, 1.0f, 0.0f); + float w4 = dequantize_value<4, false>( + (w_packed >> 16) & 0xF, 1.0f, 0.0f); + float w5 = dequantize_value<4, false>( + (w_packed >> 20) & 0xF, 1.0f, 0.0f); + float w6 = dequantize_value<4, false>( + (w_packed >> 24) & 0xF, 1.0f, 0.0f); + float w7 = dequantize_value<4, false>( + (w_packed >> 28) & 0xF, 1.0f, 0.0f); + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + float x4 = shared_x[k - chunk_start + 4]; + float x5 = shared_x[k - chunk_start + 5]; + float x6 = shared_x[k - chunk_start + 6]; + float x7 = shared_x[k - chunk_start + 7]; + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + qx_acc = fmaf(x4, w4, qx_acc); + qx_acc = fmaf(x5, w5, qx_acc); + qx_acc = fmaf(x6, w6, qx_acc); + qx_acc = fmaf(x7, w7, qx_acc); + } + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf( + x_val, + dequantize_value(quant_val, 1.0f, 0.0f), + qx_acc); + } + } else if constexpr (BITS == 6) { + int k_local = lane * 8; + int step = THREADS_PER_COL * 8; + int max_safe_k = ((row_bytes - 7) * 8) / 6; + for (; k_start + k_local + 7 < k_end_g && + k_start + k_local < max_safe_k; + k_local += step) { + int k = k_start + k_local; + int byte_idx = (k * 6) / 8; + int bit_offset = (k * 6) % 8; + uint64_t w_packed = 0; + memcpy(&w_packed, &w_row[byte_idx], 8); + w_packed >>= bit_offset; + float w0 = + dequantize_value<6, false>(w_packed & 0x3F, 1.0f, 0.0f); + float w1 = dequantize_value<6, false>( + (w_packed >> 6) & 0x3F, 1.0f, 0.0f); + float w2 = dequantize_value<6, false>( + (w_packed >> 12) & 0x3F, 1.0f, 0.0f); + float w3 = dequantize_value<6, false>( + (w_packed >> 18) & 0x3F, 1.0f, 0.0f); + float w4 = dequantize_value<6, false>( + (w_packed >> 24) & 0x3F, 1.0f, 0.0f); + float w5 = dequantize_value<6, false>( + (w_packed >> 30) & 0x3F, 1.0f, 0.0f); + float w6 = dequantize_value<6, false>( + (w_packed >> 36) & 0x3F, 1.0f, 0.0f); + float w7 = dequantize_value<6, false>( + (w_packed >> 42) & 0x3F, 1.0f, 0.0f); + float x0 = shared_x[k - chunk_start]; + float x1 = shared_x[k - chunk_start + 1]; + float x2 = shared_x[k - chunk_start + 2]; + float x3 = shared_x[k - chunk_start + 3]; + float x4 = shared_x[k - chunk_start + 4]; + float x5 = shared_x[k - chunk_start + 5]; + float x6 = shared_x[k - chunk_start + 6]; + float x7 = shared_x[k - chunk_start + 7]; + qx_acc = fmaf(x0, w0, qx_acc); + qx_acc = fmaf(x1, w1, qx_acc); + qx_acc = fmaf(x2, w2, qx_acc); + qx_acc = fmaf(x3, w3, qx_acc); + qx_acc = fmaf(x4, w4, qx_acc); + qx_acc = fmaf(x5, w5, qx_acc); + qx_acc = fmaf(x6, w6, qx_acc); + qx_acc = fmaf(x7, w7, qx_acc); + } + for (; k_start + k_local < k_end_g; k_local++) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf( + x_val, + dequantize_value(quant_val, 1.0f, 0.0f), + qx_acc); + } + } else { + for (int k_local = lane; k_start + k_local < k_end_g; + k_local += THREADS_PER_COL) { + int k = k_start + k_local; + float x_val = shared_x[k - chunk_start]; + uint8_t quant_val = + unpack_packed_value_fast(w_row, k, row_bytes); + qx_acc = fmaf( + x_val, + dequantize_value(quant_val, 1.0f, 0.0f), + qx_acc); + } + } + acc += scale * qx_acc; + } + } + } + __syncthreads(); + } + + acc = subgroup_reduce_sum_qmm(acc); + if (col_valid && lane == 0) { + int64_t out_offset = (static_cast(batch) * M + row) * N + col; + out[out_offset] = expert_valid ? static_cast(acc) : static_cast(0); + } +} + +template +__global__ void gather_qmv_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + const uint32_t* __restrict__ lhs_indices, + const uint32_t* __restrict__ rhs_indices, + const rocm::Shape batch_shape, + const rocm::Strides lhs_idx_strides, + const rocm::Strides rhs_idx_strides, + int batch_ndim, + T* __restrict__ out, + int B, + int M, + int N, + int K, + int E, + bool has_bias, + bool implicit_lhs = false, + int64_t implicit_x_batch_stride = 0) { + int batch = blockIdx.z; + int row = blockIdx.x; + int col = blockIdx.y * blockDim.x + threadIdx.x; + if (batch >= B || row >= M || col >= N) + return; + int64_t lhs_idx_loc = 0; + int64_t rhs_idx_loc = 0; + if (batch_ndim == 1) { + rhs_idx_loc = (int64_t)batch * rhs_idx_strides[0]; + if (!implicit_lhs) { + lhs_idx_loc = (int64_t)batch * lhs_idx_strides[0]; + } + } else if (batch_ndim > 1) { + int64_t elem = (int64_t)batch; + for (int i = batch_ndim - 1; i >= 0; --i) { + int64_t coord = elem % batch_shape.data_[i]; + rhs_idx_loc += coord * rhs_idx_strides.data_[i]; + if (!implicit_lhs) { + lhs_idx_loc += coord * lhs_idx_strides.data_[i]; + } + elem /= batch_shape.data_[i]; + } + } + uint32_t lhs_idx = implicit_lhs ? 0u : lhs_indices[lhs_idx_loc]; + uint32_t rhs_idx = rhs_indices[rhs_idx_loc]; + if (rhs_idx >= static_cast(E)) { + out[batch * M * N + row * N + col] = static_cast(0); + return; + } + + int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + int row_bytes = (K * BITS + 7) / 8; + int64_t x_batch_stride = static_cast(M) * K; + int64_t w_batch_stride = static_cast(N) * row_bytes; + int64_t sb_batch_stride = static_cast(N) * num_groups; + int64_t col_w_offset = static_cast(col) * row_bytes; + int64_t col_sb_offset = static_cast(col) * num_groups; + + int64_t x_batch_offset = implicit_lhs + ? (static_cast(batch) * implicit_x_batch_stride) + : (static_cast(lhs_idx) * x_batch_stride); + const T* x_ptr = x + x_batch_offset + static_cast(row) * K; + const uint8_t* w_ptr = + w + static_cast(rhs_idx) * w_batch_stride + col_w_offset; + const ScaleT* scales_ptr = + scales + static_cast(rhs_idx) * sb_batch_stride + col_sb_offset; + const ScaleT* biases_ptr = has_bias + ? biases + static_cast(rhs_idx) * sb_batch_stride + col_sb_offset + : nullptr; + float acc = 0.0f; + for (int g = 0; g < num_groups; ++g) { + float scale = load_scale_value(scales_ptr[g]); + float bias = has_bias ? (float)biases_ptr[g] : 0.0f; + int k_start = g * GROUP_SIZE; + int k_end = min(k_start + GROUP_SIZE, K); + + if constexpr (BITS == 8) { + int k = k_start; + for (; k + 3 < k_end; k += 4) { + uint32_t w_packed = *reinterpret_cast(&w_ptr[k]); + float w0 = dequantize_value<8, AFFINE>(w_packed & 0xFF, scale, bias); + float w1 = + dequantize_value<8, AFFINE>((w_packed >> 8) & 0xFF, scale, bias); + float w2 = + dequantize_value<8, AFFINE>((w_packed >> 16) & 0xFF, scale, bias); + float w3 = + dequantize_value<8, AFFINE>((w_packed >> 24) & 0xFF, scale, bias); + + acc += (float)x_ptr[k] * w0; + acc += (float)x_ptr[k + 1] * w1; + acc += (float)x_ptr[k + 2] * w2; + acc += (float)x_ptr[k + 3] * w3; + } + for (; k < k_end; ++k) { + float w_val = dequantize_value<8, AFFINE>(w_ptr[k], scale, bias); + acc += (float)x_ptr[k] * w_val; + } + } else { + for (int k = k_start; k < k_end; ++k) { + uint8_t qv = unpack_packed_value_fast(w_ptr, k, row_bytes); + acc += + (float)x_ptr[k] * dequantize_value(qv, scale, bias); + } + } + } + out[batch * M * N + row * N + col] = (T)acc; +} + +// ====================================================================== +// WMMA-accelerated gather QMV prefill kernel using rocwmma 16x16x16 tiles. +// +// Each wavefront (32 lanes on RDNA 3.5 / gfx1151) computes one 16x16 +// output tile. Weights are dequantized from 4-bit packed format into +// bf16 in shared memory, then loaded into rocwmma fragments for the +// matrix multiply-accumulate. Accumulation is in float32; the final +// result is converted back to bf16 on store. +// +// Grid: (ceil(M/16), ceil(N/16), num_runs) +// Block: (32, 1, 1) -- one wave32 per 16x16 output tile +// +// On architectures without WMMA support (RDNA 1/2) the kernel body is +// an empty stub; dispatch checks prevent it from being launched there. +// ====================================================================== +template +__global__ void __launch_bounds__(32) gather_qmv_wmma_prefill_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + const uint32_t* __restrict__ lhs_indices, + const uint32_t* __restrict__ rhs_indices, + const int* __restrict__ run_starts, + const int* __restrict__ run_lengths, + const int* __restrict__ out_perm, // maps sorted batch idx → original batch idx + T* __restrict__ out, + int B, int M, int N, int K, int E, + bool has_bias, int64_t x_batch_stride) { + +#if ROCM_HAS_WMMA + + static_assert(BITS == 4, "WMMA prefill kernel only supports 4-bit quantized weights"); + static_assert(AFFINE, "WMMA prefill kernel only supports affine quantization"); + + constexpr int WMMA_M = 16; + constexpr int WMMA_N = 16; + constexpr int WMMA_K = 16; + + // Tile coordinates in the output matrix + const int tile_row = blockIdx.x * WMMA_M; // starting row of this 16x16 tile + const int tile_col = blockIdx.y * WMMA_N; // starting col of this 16x16 tile + const int run_id = blockIdx.z; + + // Bounds check -- the dispatch guarantees M and N are multiples of 16, + // but guard anyway for safety. + if (tile_row >= M || tile_col >= N) return; + + const int lane = threadIdx.x; // 0..31 + + // Run info + const int run_start = run_starts[run_id]; + const int run_len = run_lengths[run_id]; + + const uint32_t rhs_idx = rhs_indices[run_start]; + if (rhs_idx >= static_cast(E)) return; + + // Weight layout constants + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int row_bytes = (K * BITS + 7) / 8; // bytes per weight row (one output col) + const int64_t w_expert_stride = static_cast(N) * row_bytes; + const int64_t sb_expert_stride = static_cast(N) * num_groups; + + // Base pointers for this expert + const uint8_t* w_expert = w + static_cast(rhs_idx) * w_expert_stride; + const ScaleT* s_expert = scales + static_cast(rhs_idx) * sb_expert_stride; + const ScaleT* b_expert = has_bias + ? (biases + static_cast(rhs_idx) * sb_expert_stride) + : nullptr; + + // Shared memory for dequantized weight tile [WMMA_K x WMMA_N] in row-major + // and for x tile [WMMA_M x WMMA_K] in row-major. + // Total: (16*16 + 16*16) * sizeof(hip_bfloat16) = 1024 bytes + __shared__ hip_bfloat16 smem_w[WMMA_K * WMMA_N]; // [16][16] row-major + __shared__ hip_bfloat16 smem_x[WMMA_M * WMMA_K]; // [16][16] row-major + + // Fragment types for bf16 input, f32 accumulation + using frag_a = rocwmma::fragment; + using frag_b = rocwmma::fragment; + using frag_acc = rocwmma::fragment; + + // Process each batch element in the run + for (int r = 0; r < run_len; ++r) { + const int batch = run_start + r; + const uint32_t lhs_idx = lhs_indices[batch]; + const T* x_base = x + static_cast(lhs_idx) * x_batch_stride + + static_cast(tile_row) * K; + + // Zero the accumulator for this batch element + frag_acc acc; + rocwmma::fill_fragment(acc, 0.0f); + + // Loop over K dimension in chunks of WMMA_K (16) + for (int k_base = 0; k_base < K; k_base += WMMA_K) { + // --- Load x tile [WMMA_M x WMMA_K] into shared memory --- + // 32 lanes load 256 elements (16x16) -> 8 elements per lane + // Pad with zero for rows beyond M (handles non-16-aligned M) + #pragma unroll + for (int i = 0; i < (WMMA_M * WMMA_K + 31) / 32; ++i) { + int idx = lane + i * 32; + if (idx < WMMA_M * WMMA_K) { + int m_local = idx / WMMA_K; + int k_local = idx % WMMA_K; + int m_global = tile_row + m_local; + int k_global = k_base + k_local; + if (m_global < M && k_global < K) { + smem_x[idx] = x_base[m_local * K + k_global]; + } else { + smem_x[idx] = static_cast(0.0f); + } + } + } + + // --- Dequantize weight tile [WMMA_K x WMMA_N] into shared memory --- + // Layout: smem_w[k][n] = dequant(w[expert, tile_col + n, k_base + k]) + // w is stored as [N, row_bytes], each row for one output column. + // We need 16 columns x 16 K values = 256 values, 8 per lane. + #pragma unroll + for (int i = 0; i < (WMMA_K * WMMA_N + 31) / 32; ++i) { + int idx = lane + i * 32; + if (idx < WMMA_K * WMMA_N) { + int k_local = idx / WMMA_N; // row in [K, N] + int n_local = idx % WMMA_N; // col in [K, N] + int k_global = k_base + k_local; + int n_global = tile_col + n_local; + + if (k_global < K) { + // Pointer to weight row for output column n_global + const uint8_t* w_row = w_expert + static_cast(n_global) * row_bytes; + + // Extract 4-bit quantized value + uint8_t packed = w_row[k_global >> 1]; + uint8_t quant_val = (k_global & 1) ? (packed >> 4) : (packed & 0xF); + + // Dequantize: val = scale * quant_val + bias + int group_idx = k_global / GROUP_SIZE; + float scale = static_cast( + s_expert[static_cast(n_global) * num_groups + group_idx]); + float bias_val = has_bias + ? static_cast( + b_expert[static_cast(n_global) * num_groups + group_idx]) + : 0.0f; + float dequant = scale * static_cast(quant_val) + bias_val; + smem_w[idx] = static_cast(dequant); + } else { + smem_w[idx] = static_cast(0.0f); + } + } + } + + __syncthreads(); + + // --- Load fragments from shared memory and perform MMA --- + frag_a a_frag; + frag_b b_frag; + + // Load A from smem_x [WMMA_M x WMMA_K], row-major, ldm = WMMA_K + rocwmma::load_matrix_sync(a_frag, smem_x, WMMA_K); + // Load B from smem_w [WMMA_K x WMMA_N], row-major, ldm = WMMA_N + rocwmma::load_matrix_sync(b_frag, smem_w, WMMA_N); + + // D = A * B + C + rocwmma::mma_sync(acc, a_frag, b_frag, acc); + + __syncthreads(); + } + + // --- Store the 16x16 result tile --- + // Store f32 accumulator to shared memory, then convert to bf16 for output. + __shared__ float smem_out_f32[WMMA_M * WMMA_N]; + + rocwmma::store_matrix_sync(smem_out_f32, acc, WMMA_N, rocwmma::mem_row_major); + __syncthreads(); + + // Convert f32 -> bf16 and write to global output (mask out-of-bounds rows) + // Use out_perm to map sorted batch position back to original output position + const int orig_batch = out_perm[batch]; + T* out_base = out + static_cast(orig_batch) * M * N + + static_cast(tile_row) * N + + tile_col; + #pragma unroll + for (int i = 0; i < (WMMA_M * WMMA_N + 31) / 32; ++i) { + int idx = lane + i * 32; + if (idx < WMMA_M * WMMA_N) { + int m_local = idx / WMMA_N; + int n_local = idx % WMMA_N; + if (tile_row + m_local < M) { + out_base[m_local * N + n_local] = static_cast(smem_out_f32[idx]); + } + } + } + __syncthreads(); + } + +#endif // ROCM_HAS_WMMA +} + +} // namespace rocm + +void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + auto& d = rocm::device(s.device); + auto& enc = d.get_command_encoder(s); + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), enc)); + array x = ensure_row_contiguous_matrix(inputs[0], enc, s); + array w = ensure_row_contiguous_matrix(inputs[1], enc, s); + array scales = ensure_row_contiguous_matrix(inputs[2], enc, s); + std::optional biases = std::nullopt; + bool has_bias = (mode_ == QuantizationMode::Affine) && (inputs.size() == 6); + if (has_bias) + biases = ensure_row_contiguous_matrix(inputs[3], enc, s); + const array& lhs_indices = inputs[inputs.size() - 2]; + const array& rhs_indices = inputs[inputs.size() - 1]; + auto [batch_shape, batch_strides] = collapse_contiguous_dims( + lhs_indices.shape(), {lhs_indices.strides(), rhs_indices.strides()}); + auto batch_shape_param = const_param(batch_shape); + auto lhs_idx_strides_param = const_param(batch_strides[0]); + auto rhs_idx_strides_param = const_param(batch_strides[1]); + int batch_ndim = batch_shape.size(); + enc.set_input_array(x); + enc.set_input_array(w); + enc.set_input_array(scales); + if (has_bias) + enc.set_input_array(biases.value()); + enc.set_input_array(lhs_indices); + enc.set_input_array(rhs_indices); + enc.set_output_array(out); + int K = x.shape(-1), M = x.shape(-2), N = out.shape(-1), + B = out.size() / M / N, E = w.size() / w.shape(-1) / w.shape(-2); + + int64_t x_batch_count = x.size() / (static_cast(M) * K); + bool use_sorted_rhs_schedule = transpose_ && right_sorted_ && (M == 1) && + (B >= 16) && (E > 0) && (B / E >= 4) && + (x_batch_count == 1 || x_batch_count == B); + int64_t implicit_x_batch_stride = + (x_batch_count == 1) ? 0 : static_cast(M) * K; + + int block_size = 256; + dim3 grid(M, (N + block_size - 1) / block_size, B); + + int fast_threads_per_col = select_qmv_threads_per_col(K, N, bits_, B); + int fast_threads_env = + parse_threads_per_col_env("MLX_ROCM_GATHER_QMV_THREADS_PER_COL"); + if (fast_threads_env <= 0) { + fast_threads_env = + parse_threads_per_col_env("MLX_ROCM_QMV_THREADS_PER_COL"); + } + if (fast_threads_env > 0) { + fast_threads_per_col = fast_threads_env; + } + + int fast_cols_per_block = select_qmv_cols_per_block(K, N, bits_); + if (fast_threads_per_col == 16 && bits_ == 8 && N >= 2048) { + fast_cols_per_block = std::max(fast_cols_per_block, 64); + } + int max_cols_per_block = rocm::kMaxThreadsPerBlock / fast_threads_per_col; + while (fast_cols_per_block > max_cols_per_block) { + fast_cols_per_block /= 2; + } + while (fast_cols_per_block > 1 && (N % fast_cols_per_block) != 0 && + fast_cols_per_block > 8) { + fast_cols_per_block /= 2; + } + + dim3 fast_block(fast_threads_per_col, fast_cols_per_block); + dim3 fast_grid(M, (N + fast_cols_per_block - 1) / fast_cols_per_block, B); + + bool bits_supported_by_fast = (bits_ == 2 || bits_ == 4 || bits_ == 8) || + (mode_ == QuantizationMode::Affine && (bits_ == 5 || bits_ == 6)); + bool use_fast_gather_qmv = transpose_ && bits_supported_by_fast; + use_fast_gather_qmv = parse_warp_kernel_env( + "MLX_ROCM_GATHER_QMV_USE_WARP", use_fast_gather_qmv); + // ---- Prefill optimization: group by expert for M>1 ---- + // Works with both sorted and unsorted rhs_indices; we sort on CPU. + // NOTE: MLX's MoE expands tokens to B individual M=1 calls, so M>1 is rare. + // The WMMA prefill kernel is used when upstream batching produces M>1. + if (M > 1 && transpose_ && E > 0 && batch_ndim == 1 && + mode_ == QuantizationMode::Affine && x.dtype() == bfloat16 && + group_size_ == 64 && (bits_ == 4 || bits_ == 8)) { + // Sort batch elements by expert to form contiguous runs. + // This allows the kernel to process all tokens for one expert together, + // sharing weight reads. We create a sorted permutation on CPU. + const auto* ri_cpu = rhs_indices.data(); + const auto* li_cpu = lhs_indices.data(); + + // Create sort permutation by expert index + std::vector perm(B); + std::iota(perm.begin(), perm.end(), 0); + std::sort(perm.begin(), perm.end(), [&](int a, int b) { + return ri_cpu[a] < ri_cpu[b]; + }); + + // Build sorted index arrays and compute runs + std::vector sorted_ri(B), sorted_li(B); + for (int i = 0; i < B; ++i) { + sorted_ri[i] = ri_cpu[perm[i]]; + sorted_li[i] = li_cpu[perm[i]]; + } + + std::vector run_starts_vec, run_lengths_vec; + run_starts_vec.reserve(E); + run_lengths_vec.reserve(E); + int run_begin = 0; + for (int b = 1; b <= B; ++b) { + if (b == B || sorted_ri[b] != sorted_ri[run_begin]) { + run_starts_vec.push_back(run_begin); + run_lengths_vec.push_back(b - run_begin); + run_begin = b; + } + } + int num_runs = static_cast(run_starts_vec.size()); + + // Upload sorted indices to GPU + array sorted_ri_arr({B}, uint32, nullptr, {}); + array sorted_li_arr({B}, uint32, nullptr, {}); + sorted_ri_arr.set_data(mlx::core::rocm::malloc_async(sorted_ri_arr.nbytes(), enc)); + sorted_li_arr.set_data(mlx::core::rocm::malloc_async(sorted_li_arr.nbytes(), enc)); + std::memcpy(sorted_ri_arr.data(), sorted_ri.data(), B * sizeof(uint32_t)); + std::memcpy(sorted_li_arr.data(), sorted_li.data(), B * sizeof(uint32_t)); + enc.set_input_array(sorted_ri_arr); + enc.set_input_array(sorted_li_arr); + + // Also need a mapping from sorted position back to original batch index for output + array perm_arr({B}, int32, nullptr, {}); + perm_arr.set_data(mlx::core::rocm::malloc_async(perm_arr.nbytes(), enc)); + std::memcpy(perm_arr.data(), perm.data(), B * sizeof(int)); + enc.set_input_array(perm_arr); + + // Upload run info to GPU + array run_starts_arr({num_runs}, int32, nullptr, {}); + array run_lengths_arr({num_runs}, int32, nullptr, {}); + run_starts_arr.set_data(mlx::core::rocm::malloc_async(run_starts_arr.nbytes(), enc)); + run_lengths_arr.set_data(mlx::core::rocm::malloc_async(run_lengths_arr.nbytes(), enc)); + std::memcpy(run_starts_arr.data(), run_starts_vec.data(), num_runs * sizeof(int)); + std::memcpy(run_lengths_arr.data(), run_lengths_vec.data(), num_runs * sizeof(int)); + enc.set_input_array(run_starts_arr); + enc.set_input_array(run_lengths_arr); + + int64_t x_bs = (x_batch_count == 1) ? 0 : static_cast(M) * K; + + // ---- WMMA path: use 16x16x16 wave matrix multiply when tiles align ---- + // WMMA tiles are 16x16; kernel handles non-aligned M with bounds masking. + // N must be 16-aligned (typical for transformer hidden dimensions). + // Gate on the device arch: a multi-arch build can compile this kernel + // for a target whose gcnArchName isn't on the rocWMMA allowlist + // (e.g. gfx1030/1103) — dispatching there would crash. + bool use_wmma = d.has_native_wmma() && (M >= 2) && (N % 16 == 0) && + (bits_ == 4); + use_wmma = parse_warp_kernel_env("MLX_ROCM_GATHER_QMV_USE_WMMA", use_wmma); + + if (use_wmma) { + // One wave32 per 16x16 output tile + dim3 wmma_block(32, 1, 1); + dim3 wmma_grid((M + 15) / 16, (N + 15) / 16, num_runs); + // Shared memory: smem_w[16*16] + smem_x[16*16] bf16 + smem_out_f32[16*16] f32 + // = 512 + 512 + 1024 = 2048 bytes + size_t wmma_smem = 0; // static shared memory, declared in-kernel + + enc.add_kernel_node_ex( + &rocm::gather_qmv_wmma_prefill_kernel, + wmma_grid, wmma_block, static_cast(wmma_smem), + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(scales), + static_cast(has_bias ? gpu_ptr(*biases) : nullptr), + gpu_ptr(sorted_li_arr), + gpu_ptr(sorted_ri_arr), + gpu_ptr(run_starts_arr), + gpu_ptr(run_lengths_arr), + gpu_ptr(perm_arr), + gpu_ptr(out), + B, M, N, K, E, has_bias, x_bs); + return; + } + + // ---- Scalar prefill fallback ---- + int fast_threads_per_col_pf = select_qmv_threads_per_col(K, N, bits_, num_runs); + int fast_cols_per_block_pf = select_qmv_cols_per_block(K, N, bits_); + int max_cpb = rocm::kMaxThreadsPerBlock / fast_threads_per_col_pf; + while (fast_cols_per_block_pf > max_cpb) fast_cols_per_block_pf /= 2; + while (fast_cols_per_block_pf > 1 && (N % fast_cols_per_block_pf) != 0 && fast_cols_per_block_pf > 8) + fast_cols_per_block_pf /= 2; + + dim3 pf_block(fast_threads_per_col_pf, fast_cols_per_block_pf); + dim3 pf_grid(M, (N + fast_cols_per_block_pf - 1) / fast_cols_per_block_pf, num_runs); + + { + auto launch_pf = [&](auto bits_tag) { + constexpr int BITS = decltype(bits_tag)::value; + enc.add_kernel_node( + &rocm::gather_qmv_prefill_kernel, + pf_grid, pf_block, 0u, + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(scales), + static_cast(has_bias ? gpu_ptr(*biases) : nullptr), + gpu_ptr(sorted_li_arr), + gpu_ptr(sorted_ri_arr), + gpu_ptr(run_starts_arr), + gpu_ptr(run_lengths_arr), + gpu_ptr(perm_arr), + gpu_ptr(out), + B, M, N, K, E, has_bias, x_bs); + }; + if (bits_ == 4) launch_pf(std::integral_constant{}); + else launch_pf(std::integral_constant{}); + } + return; + } + + const void *x_ptr = gpu_ptr(x), *w_ptr = gpu_ptr(w), + *scales_ptr = gpu_ptr(scales), + *biases_ptr = has_bias ? gpu_ptr(*biases) : nullptr; + const uint32_t *li_ptr = gpu_ptr(lhs_indices), + *ri_ptr = gpu_ptr(rhs_indices); + void* out_ptr = gpu_ptr(out); + + // GPU-only expert-batched kernel: when indices are sorted, each block finds + // its expert's token range on-GPU and processes them together. Weight data + // loaded once per expert column, reused across all tokens for that expert. + // max_unique_experts = min(B, E) is an upper bound on unique experts. + // Expert-batched kernel: beneficial when few experts have many tokens each. + // For high-expert-count models (E=512, top_k=10), most runs have 1-4 tokens, + // so the per-block run-finding overhead outweighs the shared weight benefit. + // Enable only when B/E is high enough (e.g., low expert count with long prompt). + bool use_expert_batched = transpose_ && right_sorted_ && (M == 1) && + (B >= 64) && (E > 0) && (E <= 64) && (B / E >= 4) && + mode_ == QuantizationMode::Affine && + x.dtype() == bfloat16 && group_size_ == 64 && (bits_ == 4 || bits_ == 8); + use_expert_batched = parse_warp_kernel_env( + "MLX_ROCM_GATHER_QMV_EXPERT_BATCHED", use_expert_batched); + + if (use_expert_batched) { + int max_unique_experts = std::min(B, E); + int eb_threads_per_col = select_qmv_threads_per_col(K, N, bits_, max_unique_experts); + int eb_cols_per_block = select_qmv_cols_per_block(K, N, bits_); + int eb_max_cpb = rocm::kMaxThreadsPerBlock / eb_threads_per_col; + while (eb_cols_per_block > eb_max_cpb) eb_cols_per_block /= 2; + while (eb_cols_per_block > 1 && (N % eb_cols_per_block) != 0 && eb_cols_per_block > 8) + eb_cols_per_block /= 2; + + dim3 eb_block(eb_threads_per_col, eb_cols_per_block); + dim3 eb_grid(M, (N + eb_cols_per_block - 1) / eb_cols_per_block, max_unique_experts); + + { + auto launch_eb = [&](auto bits_tag) { + constexpr int BITS = decltype(bits_tag)::value; + enc.add_kernel_node( + &rocm::gather_qmv_expert_batched_kernel< + hip_bfloat16, hip_bfloat16, BITS, 64, true, 16>, + eb_grid, eb_block, 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, ri_ptr, + (hip_bfloat16*)out_ptr, + B, M, N, K, E, has_bias, + use_sorted_rhs_schedule, implicit_x_batch_stride); + }; + if (bits_ == 4) launch_eb(std::integral_constant{}); + else launch_eb(std::integral_constant{}); + } + return; + } + + // ---- Decode MoE: optional tiled gather-QMV (env-gated A/B) ---- + // The default warp-shared gather kernel launches very few blocks at decode + // (grid z = B = tokens*top_k, small), starving GPU occupancy. The tiled gather + // kernel (mirrors the fast 2D qmv_tiled path) also fans out over N tiles, for + // far more blocks. It reads the word-packed (uint32) weight layout, so it is + // restricted to 4/8-bit affine like the 2D tiled path, and to a unit-stride 1D + // batch (the common MoE decode case). Opt in with + // MLX_ROCM_GATHER_QMV_USE_TILED=1 to benchmark / enable. + static const bool g_use_tiled_gather = + (std::getenv("MLX_ROCM_GATHER_QMV_USE_TILED") != nullptr); + bool gather_tiled_ok = g_use_tiled_gather && transpose_ && + mode_ == QuantizationMode::Affine && x.dtype() == bfloat16 && + group_size_ == 64 && (bits_ == 4 || bits_ == 8) && + batch_ndim == 1 && batch_strides[0].size() == 1 && + batch_strides[0][0] == 1 && batch_strides[1][0] == 1; + + // Full-wave tiled 6-bit gather; MLX_ROCM_QMV_6BIT_SLOW reverts to warp_shared. + static const bool g_use_6bit_tiled_gather = + (std::getenv("MLX_ROCM_QMV_6BIT_SLOW") == nullptr); + bool g_gs6_supported = + (group_size_ == 32 || group_size_ == 64 || group_size_ == 128); + bool gather_tiled_6bit_ok = g_use_6bit_tiled_gather && transpose_ && + mode_ == QuantizationMode::Affine && x.dtype() == bfloat16 && + bits_ == 6 && (K % 64) == 0 && g_gs6_supported && + !use_sorted_rhs_schedule && + batch_ndim == 1 && batch_strides[0].size() == 1 && + batch_strides[0][0] == 1 && batch_strides[1][0] == 1; + + int gather_tile_n = 0; + if (gather_tiled_ok || gather_tiled_6bit_ok) { + auto gqmm_hw = detect_rocm_hw_info(enc.device()); + gather_tile_n = rocm::get_arch_tuning(gqmm_hw).qmv_tile_n; + while (gather_tile_n > 1 && (N % gather_tile_n) != 0) gather_tile_n /= 2; + if (gather_tile_n < 1) gather_tile_n = 1; + } + if (gather_tiled_6bit_ok && gather_tile_n < 8) gather_tiled_6bit_ok = false; + + { + if (gather_tiled_6bit_ok) { + dim3 gt_grid(M, (N + gather_tile_n - 1) / gather_tile_n, B); + dim3 gt_block(WARP_SIZE, gather_tile_n); + int LHS_B = static_cast(x_batch_count); + #define LAUNCH_GATHER_TILED_6BIT(GS_V) \ + enc.add_kernel_node( \ + &rocm::gather_qmv_tiled_6bit_kernel, \ + gt_grid, gt_block, 0u, \ + (const hip_bfloat16*)x_ptr, (const uint8_t*)w_ptr, \ + (const hip_bfloat16*)scales_ptr, (const hip_bfloat16*)biases_ptr, \ + li_ptr, ri_ptr, (hip_bfloat16*)out_ptr, \ + B, M, N, K, E, LHS_B, has_bias, gather_tile_n) + if (group_size_ == 32) { LAUNCH_GATHER_TILED_6BIT(32); } + else if (group_size_ == 64) { LAUNCH_GATHER_TILED_6BIT(64); } + else { LAUNCH_GATHER_TILED_6BIT(128); } + #undef LAUNCH_GATHER_TILED_6BIT + return; + } + if (gather_tiled_ok) { + dim3 gt_grid(M, (N + gather_tile_n - 1) / gather_tile_n, B); + dim3 gt_block(WARP_SIZE, gather_tile_n); + int LHS_B = static_cast(x_batch_count); + auto launch_gt = [&](auto bits_tag) { + constexpr int BITS = decltype(bits_tag)::value; + enc.add_kernel_node( + &rocm::gather_qmv_tiled_kernel, + gt_grid, gt_block, 0u, + (const hip_bfloat16*)x_ptr, + (const uint32_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, ri_ptr, + (hip_bfloat16*)out_ptr, + B, M, N, K, E, LHS_B, has_bias, gather_tile_n); + }; + if (bits_ == 4) launch_gt(std::integral_constant{}); + else launch_gt(std::integral_constant{}); + return; + } + if (use_fast_gather_qmv && mode_ == QuantizationMode::Affine && + x.dtype() == bfloat16 && group_size_ == 64 && + (bits_ == 4 || bits_ == 6 || bits_ == 8)) { + auto launch_fast_kernel = [&](auto bits_tag) { + constexpr int BITS = decltype(bits_tag)::value; + if (fast_threads_per_col == 16) { + enc.add_kernel_node( + &rocm::gather_qmv_warp_shared_kernel< + hip_bfloat16, + hip_bfloat16, + BITS, + 64, + true, + 16>, + fast_grid, + fast_block, + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias, + use_sorted_rhs_schedule, + implicit_x_batch_stride); + } else { + enc.add_kernel_node( + &rocm::gather_qmv_warp_shared_kernel< + hip_bfloat16, + hip_bfloat16, + BITS, + 64, + true, + WARP_SIZE>, + fast_grid, + fast_block, + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias, + use_sorted_rhs_schedule, + implicit_x_batch_stride); + } + }; + + if (bits_ == 4) { + launch_fast_kernel(std::integral_constant{}); + } else if (bits_ == 6) { + launch_fast_kernel(std::integral_constant{}); + } else { + launch_fast_kernel(std::integral_constant{}); + } + return; + } + +#define has_bias has_bias, use_sorted_rhs_schedule, implicit_x_batch_stride + + if (x.dtype() == float32) { + if (bits_ == 8 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 8 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 8 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 5 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 5 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 5 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 6 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 6 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 6 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 4 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 4 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 4 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 2 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 2 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 2 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const float*)x_ptr, + (const uint8_t*)w_ptr, + (const float*)scales_ptr, + (const float*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (float*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else { + throw std::runtime_error( + "Unsupported dtype/bits/group_size combination for float32: bits=" + + std::to_string(bits_) + " gs=" + std::to_string(group_size_)); + } + } else if (x.dtype() == float16) { + if (bits_ == 8 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 8, 32, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 8 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 8, 64, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 8 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 8, 128, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 5 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 5, 32, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 5 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 5, 64, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 5 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 5, 128, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 6 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 6, 32, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 6 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 6, 64, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 6 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 6, 128, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 4 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 4, 32, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 4 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 4, 64, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 4 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 4, 128, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 2 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 2, 32, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 2 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 2, 64, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 2 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel<__half, __half, 2, 128, true>, + grid, + dim3(block_size), + 0u, + (const __half*)x_ptr, + (const uint8_t*)w_ptr, + (const __half*)scales_ptr, + (const __half*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (__half*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else { + throw std::runtime_error( + "Unsupported dtype/bits/group_size combination for float16: bits=" + + std::to_string(bits_) + " gs=" + std::to_string(group_size_)); + } + } else if (x.dtype() == bfloat16) { + if (bits_ == 8 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 8 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 8 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 5 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 5 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 5 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 6 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 6 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 6 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 4 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 4 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 4 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 2 && group_size_ == 32) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 2 && group_size_ == 64) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else if (bits_ == 2 && group_size_ == 128) { + enc.add_kernel_node( + &rocm::gather_qmv_kernel, + grid, + dim3(block_size), + 0u, + (const hip_bfloat16*)x_ptr, + (const uint8_t*)w_ptr, + (const hip_bfloat16*)scales_ptr, + (const hip_bfloat16*)biases_ptr, + li_ptr, + ri_ptr, + batch_shape_param, + lhs_idx_strides_param, + rhs_idx_strides_param, + batch_ndim, + (hip_bfloat16*)out_ptr, + B, + M, + N, + K, + E, + has_bias); + } else { + throw std::runtime_error( + "Unsupported dtype/bits/group_size combination for bfloat16: bits=" + + std::to_string(bits_) + " gs=" + std::to_string(group_size_)); + } + } + +#undef has_bias + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/quantized/qmv_kernel.hip b/mlx/backend/rocm/quantized/qmv_kernel.hip new file mode 100644 index 0000000000..c9c625d39a --- /dev/null +++ b/mlx/backend/rocm/quantized/qmv_kernel.hip @@ -0,0 +1,224 @@ +// Optimized quantized matrix-vector multiply (GEMV) kernel for RDNA 3.5. +// +// Each warp (32 threads) cooperatively computes ONE output element by +// iterating along the K dimension with coalesced uint32 loads. +// 8 warps per block → 8 output elements per block. +// +// Key optimizations vs naive kernel: +// 1. Coalesced global memory access (adjacent threads read adjacent words) +// 2. Vectorized uint32 loads (8 values per word for 4-bit) +// 3. Warp shuffle reduction (no shared memory needed for reduction) +// 4. LDS for x vector sharing across 8 warps in a block + +#include "mlx/backend/rocm/quantized/qdequant.hpp" +#include "mlx/backend/rocm/device/config.h" + +#include + +namespace mlx::core::rocm { + +// --------------------------------------------------------------------------- +// qmv_fast_kernel: Warp-cooperative quantized GEMV +// --------------------------------------------------------------------------- +// Grid: dim3(M, ceildiv(N, ROWS_PER_BLOCK)) +// Block: dim3(WARP_SIZE, ROWS_PER_BLOCK) = dim3(32, 8) = 256 threads +// +// Each warp (threadIdx.y selects the warp) computes one output element. +// All 32 lanes iterate over K together with coalesced weight loads. + +template +__global__ __launch_bounds__(256) +void qmv_fast_kernel( + const T* __restrict__ x, // [M, K] + const uint32_t* __restrict__ w, // [N, K/pack_factor_u32] as uint32 + const ScaleT* __restrict__ scales, // [N, K/GROUP_SIZE] + const ScaleT* __restrict__ biases, // [N, K/GROUP_SIZE] or nullptr + T* __restrict__ out, // [M, N] + int M, + int N, + int K, + bool has_bias) +{ + constexpr int PF = pack_factor_u32; // values per uint32 (8 for 4-bit) + constexpr int PPT = packs_per_thread; // uint32 loads per thread (2 for 4-bit) + constexpr int VPT = values_per_thread; // values per thread per step (16) + constexpr int BSK = VPT * WARP_SIZE; // K-elements per warp per step (512) + + const int m = blockIdx.x; // output row + const int n = blockIdx.y * ROWS_PER_BLOCK + threadIdx.y; // output column + const int lane = threadIdx.x; // lane within warp + const int tid = threadIdx.y * WARP_SIZE + threadIdx.x; // flat thread id + + // NOTE: Do NOT early-return here — all threads must participate in __syncthreads. + const bool valid = (m < M && n < N); + + // --- LDS for x vector (shared across all 8 warps) --- + __shared__ float x_shared[BSK]; + + // Per-warp pointers (safe even if n >= N: we just won't write output) + const int w_stride = K / PF; // number of uint32 per weight row + const int clamped_n = (n < N) ? n : 0; // clamp to avoid OOB on pointer setup + const uint32_t* w_row = w + clamped_n * w_stride; + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const ScaleT* s_row = scales + clamped_n * num_groups; + const ScaleT* b_row = has_bias ? (biases + clamped_n * num_groups) : nullptr; + const T* x_row = x + m * K; + + float acc = 0.0f; + + for (int k_base = 0; k_base < K; k_base += BSK) { + // --- Cooperative load of x into LDS --- + // All 256 threads participate (including invalid ones) to avoid barrier mismatch. + __syncthreads(); + #pragma unroll + for (int i = tid; i < BSK; i += ROWS_PER_BLOCK * WARP_SIZE) { + int k = k_base + i; + x_shared[i] = (k < K) ? to_float(x_row[k]) : 0.0f; + } + __syncthreads(); + + if (!valid) continue; // Skip compute but still participate in barriers + + // --- Each lane loads its slice of x from LDS --- + float x_local[VPT]; + #pragma unroll + for (int i = 0; i < VPT; i++) { + x_local[i] = x_shared[lane * VPT + i]; + } + + // --- Coalesced weight load + dequant + accumulate --- + // Metal-compatible accumulation: separate integer dot product from scaling. + // We accumulate dot(x, q_int) and sum(x) across ALL packs in the same + // group, then apply: acc += scale * total_qdot + bias * total_xsum. + // This matches Metal's qdot() which computes scale*accum + sum*bias + // over all values_per_thread at once. + int w_offset = k_base / PF + lane * PPT; + + // Accumulate integer dot and x-sum across all packs (same group for all) + float group_qdot = 0.0f; + float group_xsum = 0.0f; + + // All PPT packs share the same group (thread's 16 values are contiguous) + int k_val = k_base + lane * VPT; + int group_idx = k_val / GROUP_SIZE; + + #pragma unroll + for (int p = 0; p < PPT; p++) { + uint32_t packed = (w_offset + p < w_stride) ? w_row[w_offset + p] : 0u; + dequant_and_dot(packed, &x_local[p * PF], group_qdot, group_xsum); + } + + // Apply scale and bias ONCE for the whole group (matches Metal) + float scale = (group_idx < num_groups) ? to_float(s_row[group_idx]) : 0.0f; + float bias = (has_bias && group_idx < num_groups) ? to_float(b_row[group_idx]) : 0.0f; + acc += scale * group_qdot + bias * group_xsum; + } + + if (!valid) return; + + // --- Warp reduction --- + acc = warp_reduce_sum(acc); + + // --- Lane 0 writes output --- + if (lane == 0) { + out[m * N + n] = from_float(acc); + } +} + +// --------------------------------------------------------------------------- +// gather_qmv_fast_kernel: Warp-cooperative gather-based quantized GEMV +// --------------------------------------------------------------------------- +// Same as qmv_fast_kernel but with batch index indirection for MoE models. + +template +__global__ __launch_bounds__(256) +void gather_qmv_fast_kernel( + const T* __restrict__ x, // [LHS_B, M, K] + const uint32_t* __restrict__ w, // [E, N, K/pack_factor] as uint32 + const ScaleT* __restrict__ scales, // [E, N, K/GROUP_SIZE] + const ScaleT* __restrict__ biases, // [E, N, K/GROUP_SIZE] or nullptr + const uint32_t* __restrict__ lhs_indices, // [B] + const uint32_t* __restrict__ rhs_indices, // [B] + T* __restrict__ out, // [B, M, N] + int B, int M, int N, int K, int E, int LHS_B, + bool has_bias) +{ + constexpr int PF = pack_factor_u32; + constexpr int PPT = packs_per_thread; + constexpr int VPT = values_per_thread; + constexpr int BSK = VPT * WARP_SIZE; + + const int batch = blockIdx.z; + const int m = blockIdx.x; + const int n = blockIdx.y * ROWS_PER_BLOCK + threadIdx.y; + const int lane = threadIdx.x; + const int tid = threadIdx.y * WARP_SIZE + threadIdx.x; + + const bool valid = (batch < B && m < M && n < N); + + uint32_t lhs_idx = valid ? lhs_indices[batch] : 0; + uint32_t rhs_idx = valid ? rhs_indices[batch] : 0; + + // Clamp indices to valid range to prevent catastrophic OOB on corrupt data. + if (lhs_idx >= static_cast(LHS_B)) lhs_idx = 0; + if (rhs_idx >= static_cast(E)) rhs_idx = 0; + + __shared__ float x_shared[BSK]; + + const int w_stride = K / PF; + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int clamped_n = (n < N) ? n : 0; + const uint32_t* w_row = w + rhs_idx * N * w_stride + clamped_n * w_stride; + const ScaleT* s_row = scales + rhs_idx * N * num_groups + clamped_n * num_groups; + const ScaleT* b_row = has_bias ? (biases + rhs_idx * N * num_groups + clamped_n * num_groups) : nullptr; + const T* x_row = x + lhs_idx * M * K + m * K; + + float acc = 0.0f; + + for (int k_base = 0; k_base < K; k_base += BSK) { + __syncthreads(); + #pragma unroll + for (int i = tid; i < BSK; i += ROWS_PER_BLOCK * WARP_SIZE) { + int k = k_base + i; + x_shared[i] = (k < K && valid) ? to_float(x_row[k]) : 0.0f; + } + __syncthreads(); + + if (!valid) continue; + + float x_local[VPT]; + #pragma unroll + for (int i = 0; i < VPT; i++) { + x_local[i] = x_shared[lane * VPT + i]; + } + + int w_offset = k_base / PF + lane * PPT; + + // Accumulate integer dot and x-sum across all packs (same group) + float group_qdot = 0.0f; + float group_xsum = 0.0f; + + int k_val = k_base + lane * VPT; + int group_idx = k_val / GROUP_SIZE; + + #pragma unroll + for (int p = 0; p < PPT; p++) { + uint32_t packed = (w_offset + p < w_stride) ? w_row[w_offset + p] : 0u; + dequant_and_dot(packed, &x_local[p * PF], group_qdot, group_xsum); + } + + float scale = (group_idx < num_groups) ? to_float(s_row[group_idx]) : 0.0f; + float bias = (has_bias && group_idx < num_groups) ? to_float(b_row[group_idx]) : 0.0f; + acc += scale * group_qdot + bias * group_xsum; + } + + if (!valid) return; + + acc = warp_reduce_sum(acc); + + if (lane == 0) { + out[batch * M * N + m * N + n] = from_float(acc); + } +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/quantized/qmv_tiled_kernel.hip b/mlx/backend/rocm/quantized/qmv_tiled_kernel.hip new file mode 100644 index 0000000000..c49d9c1968 --- /dev/null +++ b/mlx/backend/rocm/quantized/qmv_tiled_kernel.hip @@ -0,0 +1,505 @@ +// L2 cache-optimized quantized GEMV kernel for RDNA 3/3.5. +// +// Key difference from qmv_fast_kernel: processes TILE_N output columns per +// block instead of ROWS_PER_BLOCK=8. Within each K-tile, all TILE_N columns +// read from the same K-range of the weight matrix. Because adjacent columns +// access adjacent weight rows in the same K-range, these rows are likely to +// be in L2 cache, improving L2 hit rate from ~10% to ~40-70%. +// +// Grid: dim3(M, ceildiv(N, TILE_N)) +// Block: dim3(WARP_SIZE, TILE_N) — one warp per output column +// +// Each warp computes one output element by reducing along K. +// All warps in the block share the same X chunk via LDS. + +#include "mlx/backend/rocm/quantized/qdequant.hpp" +#include "mlx/backend/rocm/device/config.h" + +#include + +namespace mlx::core::rocm { + +// TILE_N is passed as a runtime kernel argument. The host selects it from +// rocm::ArchTuning::qmv_tile_n (per-arch config) and sets block dim to +// (WARP_SIZE, tile_n). The kernel reads tile_n to compute column indices. +// Performance: the shared memory load loop runs exactly 1 iteration for +// standard configs (BSK=512, stride=tile_n*32=512), so no unrolling loss. +static constexpr int TILE_N_MAX = 32; // Max for __launch_bounds__ + +template +__global__ __launch_bounds__(TILE_N_MAX * WARP_SIZE) +void qmv_tiled_kernel( + const T* __restrict__ x, // [M, K] + const uint32_t* __restrict__ w, // [N, K/pack_factor] as uint32 + const ScaleT* __restrict__ scales, // [N, K/GROUP_SIZE] + const ScaleT* __restrict__ biases, // [N, K/GROUP_SIZE] or nullptr + T* __restrict__ out, // [M, N] + int M, + int N, + int K, + bool has_bias, + int tile_n, // Runtime TILE_N from arch config + int n_tiles) // ceil(N / tile_n) — for grid-stride +{ + constexpr int PF = pack_factor_u32; + constexpr int PPT = packs_per_thread; + constexpr int VPT = values_per_thread; + constexpr int BSK = VPT * WARP_SIZE; + + const int m = blockIdx.x; + const int lane = threadIdx.x; + const int tid = threadIdx.y * WARP_SIZE + threadIdx.x; + const int nthreads = tile_n * WARP_SIZE; + + __shared__ float x_shared[BSK]; + + const int w_stride = K / PF; + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const T* x_row = x + m * K; + + // Grid-stride over column tiles (grid is CU-bounded, not one block per tile) to + // keep the live block count steady. threadIdx.y is warp-uniform, so the + // __syncthreads below stay block-uniform. + for (int tile = blockIdx.y; tile < n_tiles; tile += gridDim.y) { + const int n = tile * tile_n + threadIdx.y; + const bool valid = (m < M && n < N); + const int clamped_n = (n < N) ? n : 0; + const uint32_t* w_row = w + clamped_n * w_stride; + const ScaleT* s_row = scales + clamped_n * num_groups; + const ScaleT* b_row = has_bias ? (biases + clamped_n * num_groups) : nullptr; + + float acc = 0.0f; + + for (int k_base = 0; k_base < K; k_base += BSK) { + __syncthreads(); + for (int i = tid; i < BSK; i += nthreads) { + int k = k_base + i; + x_shared[i] = (k < K) ? to_float(x_row[k]) : 0.0f; + } + __syncthreads(); + + if (!valid) continue; + + // Each lane loads its X slice from LDS + float x_local[VPT]; + #pragma unroll + for (int i = 0; i < VPT; i++) { + x_local[i] = x_shared[lane * VPT + i]; + } + + // Vectorized weight load + dequant + accumulate + int w_offset = k_base / PF + lane * PPT; + + float group_qdot4[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + float group_xsum = 0.0f; + + int k_val = k_base + lane * VPT; + int group_idx = k_val / GROUP_SIZE; + + uint32_t w_local[PPT]; + if (k_base + BSK <= K) { + load_weight_vec_streaming(w_row + w_offset, w_local); + } else { + #pragma unroll + for (int p = 0; p < PPT; p++) { + w_local[p] = (w_offset + p < w_stride) ? w_row[w_offset + p] : 0u; + } + } + + #pragma unroll + for (int p = 0; p < PPT; p++) { + dequant_and_dot4(w_local[p], &x_local[p * PF], group_qdot4, group_xsum); + } + + float scale = (group_idx < num_groups) ? to_float(s_row[group_idx]) : 0.0f; + float bias = (has_bias && group_idx < num_groups) ? to_float(b_row[group_idx]) : 0.0f; + acc += scale * reduce_qdot4(group_qdot4) + bias * group_xsum; + } + + if (!valid) continue; + + // Warp reduction + acc = warp_reduce_sum(acc); + + if (lane == 0) { + out[m * N + n] = from_float(acc); + } + } +} + +// 6-bit tiled QMV. +// +// 6-bit packing is non-power-of-two, so pack_factor_u32 = 32/6 is not an +// integer and the generic qmv_tiled_kernel template cannot be used. This +// specialization reproduces the TILED occupancy/column-tiling/LDS-X-sharing +// structure (full Wave32 waves, tile_n columns/block sharing X in LDS) but +// loads 6-bit weights with explicit byte-aligned math. +// +// Layout: each lane owns VPT6=16 contiguous K-values. 16 six-bit weights = +// 96 bits = 12 bytes = 3 uint32. K%64==0 is enforced upstream for 6-bit, and +// BSK6 = VPT6 * WARP_SIZE = 512 (=> 384 bytes/warp-iter), so every lane's +// 16-weight slice begins on a byte boundary: +// k0 = k_base + lane*16 => k0*6 = (k_base*6) + lane*96 bits (byte-aligned) +// We process two byte-aligned sub-blocks of 8 weights each, extracted with the +// EXACT same uint64_t memcpy + shift + 0x3F mask used by the warp_shared 6-bit +// branch in qmm.hip, so dequant values are bit-identical. +// +// Accumulation uses the SAME Metal-compatible per-group order as the 4/8-bit +// tiled kernel: raw integer qdot + x-sum accumulated separately, then +// acc += scale * reduce_qdot4(qdot) + bias * xsum once per group. +template +__device__ __forceinline__ void dequant_and_dot4_6bit( + const uint8_t* __restrict__ w_bytes, // pointer to lane's 12-byte slice + const float* __restrict__ x_local, // 16 X values for this lane + float (&qdot)[4], + float& x_sum) { + // Load exactly 12 bytes (3 uint32) — never over-reads past the lane slice, + // so the final lane of a K-tile that ends exactly at K stays in-range. + // The lane slice is byte-aligned (verified by the caller), so a uint32 + // triple load is well-defined. + uint32_t w0 = *reinterpret_cast(w_bytes + 0); + uint32_t w1 = *reinterpret_cast(w_bytes + 4); + uint32_t w2 = *reinterpret_cast(w_bytes + 8); + + // Sub-block 0: weights 0..7 (bytes 0..5) = w0 | (low16(w1) << 32). + // Sub-block 1: weights 8..15 (bytes 6..11) = high16(w1) | (w2 << 16). + // bit_offset == 0 for both (k0*6 is byte-aligned), matching warp_shared. + uint64_t sb0 = static_cast(w0) | + (static_cast(w1 & 0xFFFFu) << 32); + uint64_t sb1 = static_cast(w1 >> 16) | + (static_cast(w2) << 16); + + #pragma unroll + for (int sb = 0; sb < 2; sb++) { + uint64_t w_packed = (sb == 0) ? sb0 : sb1; + float q0 = static_cast(w_packed & 0x3F); + float q1 = static_cast((w_packed >> 6) & 0x3F); + float q2 = static_cast((w_packed >> 12) & 0x3F); + float q3 = static_cast((w_packed >> 18) & 0x3F); + float q4 = static_cast((w_packed >> 24) & 0x3F); + float q5 = static_cast((w_packed >> 30) & 0x3F); + float q6 = static_cast((w_packed >> 36) & 0x3F); + float q7 = static_cast((w_packed >> 42) & 0x3F); + const float* xb = x_local + sb * 8; + qdot[0] += xb[0] * q0; + qdot[1] += xb[1] * q1; + qdot[2] += xb[2] * q2; + qdot[3] += xb[3] * q3; + qdot[0] += xb[4] * q4; + qdot[1] += xb[5] * q5; + qdot[2] += xb[6] * q6; + qdot[3] += xb[7] * q7; + x_sum += xb[0] + xb[1] + xb[2] + xb[3] + xb[4] + xb[5] + xb[6] + xb[7]; + } +} + +template +__global__ __launch_bounds__(TILE_N_MAX * WARP_SIZE) +void qmv_tiled_6bit_kernel( + const T* __restrict__ x, // [M, K] + const uint8_t* __restrict__ w, // [N, row_bytes] raw bytes + const ScaleT* __restrict__ scales, // [N, K/GROUP_SIZE] + const ScaleT* __restrict__ biases, // [N, K/GROUP_SIZE] or nullptr + T* __restrict__ out, // [M, N] + int M, + int N, + int K, + bool has_bias, + int tile_n, + int n_tiles) { + constexpr int BITS = 6; + constexpr int VPT6 = 16; // values per thread + constexpr int BSK6 = VPT6 * WARP_SIZE; // 512 + constexpr int BYTES_PER_LANE = (VPT6 * BITS) / 8; // 12 + + const int m = blockIdx.x; + const int lane = threadIdx.x; + const int tid = threadIdx.y * WARP_SIZE + threadIdx.x; + const int nthreads = tile_n * WARP_SIZE; + + __shared__ float x_shared[BSK6]; + + const int row_bytes = (K * BITS + 7) / 8; + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const T* x_row = x + m * K; + + for (int tile = blockIdx.y; tile < n_tiles; tile += gridDim.y) { + const int n = tile * tile_n + threadIdx.y; + const bool valid = (m < M && n < N); + const int clamped_n = (n < N) ? n : 0; + const uint8_t* w_row = w + clamped_n * row_bytes; + const ScaleT* s_row = scales + clamped_n * num_groups; + const ScaleT* b_row = has_bias ? (biases + clamped_n * num_groups) : nullptr; + + float acc = 0.0f; + + for (int k_base = 0; k_base < K; k_base += BSK6) { + __syncthreads(); + for (int i = tid; i < BSK6; i += nthreads) { + int k = k_base + i; + x_shared[i] = (k < K) ? to_float(x_row[k]) : 0.0f; + } + __syncthreads(); + + if (!valid) continue; + + float x_local[VPT6]; + #pragma unroll + for (int i = 0; i < VPT6; i++) { + x_local[i] = x_shared[lane * VPT6 + i]; + } + + float group_qdot4[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + float group_xsum = 0.0f; + + const int k_val = k_base + lane * VPT6; + const int group_idx = k_val / GROUP_SIZE; + + if (k_base + BSK6 <= K) { + // Fast path: full warp tile in range, byte-aligned 12-byte lane slice. + // k_base is a multiple of BSK6=512, so (k_base*6)/8 = k_base*3/4 is an + // exact byte offset, and lane*12 keeps each lane byte-aligned. + const uint8_t* w_lane = w_row + ((k_base * BITS) / 8) + lane * BYTES_PER_LANE; + dequant_and_dot4_6bit( + w_lane, x_local, group_qdot4, group_xsum); + } else { + // Tail: extract each value with the EXACT bounded warp_shared bit math + // (unpack_packed_value general branch): read byte_idx, plus byte_idx+1 + // only if in range. + #pragma unroll + for (int i = 0; i < VPT6; i++) { + int k = k_val + i; + if (k < K) { + float xv = x_local[i]; + int bit_index = k * BITS; + int byte_idx = bit_index >> 3; + int bit_offset = bit_index & 0x7; + uint32_t window = static_cast(w_row[byte_idx]); + if (byte_idx + 1 < row_bytes) { + window |= static_cast(w_row[byte_idx + 1]) << 8; + } + float q = static_cast((window >> bit_offset) & 0x3Fu); + group_qdot4[i & 3] += xv * q; + group_xsum += xv; + } + } + } + + float scale = (group_idx < num_groups) ? to_float(s_row[group_idx]) : 0.0f; + float bias = (has_bias && group_idx < num_groups) ? to_float(b_row[group_idx]) : 0.0f; + acc += scale * reduce_qdot4(group_qdot4) + bias * group_xsum; + } + + if (!valid) continue; + + acc = warp_reduce_sum(acc); + + if (lane == 0) { + out[m * N + n] = from_float(acc); + } + } +} + +// Gather variant for MoE models +template +__global__ __launch_bounds__(TILE_N_MAX * WARP_SIZE) +void gather_qmv_tiled_kernel( + const T* __restrict__ x, + const uint32_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + const uint32_t* __restrict__ lhs_indices, + const uint32_t* __restrict__ rhs_indices, + T* __restrict__ out, + int B, int M, int N, int K, int E, int LHS_B, + bool has_bias, + int tile_n) +{ + constexpr int PF = pack_factor_u32; + constexpr int PPT = packs_per_thread; + constexpr int VPT = values_per_thread; + constexpr int BSK = VPT * WARP_SIZE; + + const int batch = blockIdx.z; + const int m = blockIdx.x; + const int n = blockIdx.y * tile_n + threadIdx.y; + const int lane = threadIdx.x; + const int tid = threadIdx.y * WARP_SIZE + threadIdx.x; + const int nthreads = tile_n * WARP_SIZE; + + const bool valid = (batch < B && m < M && n < N); + + uint32_t lhs_idx = valid ? lhs_indices[batch] : 0; + uint32_t rhs_idx = valid ? rhs_indices[batch] : 0; + if (lhs_idx >= static_cast(LHS_B)) lhs_idx = 0; + if (rhs_idx >= static_cast(E)) rhs_idx = 0; + + __shared__ float x_shared[BSK]; + + const int w_stride = K / PF; + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int clamped_n = (n < N) ? n : 0; + const uint32_t* w_row = w + rhs_idx * N * w_stride + clamped_n * w_stride; + const ScaleT* s_row = scales + rhs_idx * N * num_groups + clamped_n * num_groups; + const ScaleT* b_row = has_bias ? (biases + rhs_idx * N * num_groups + clamped_n * num_groups) : nullptr; + const T* x_row = x + lhs_idx * M * K + m * K; + + float acc = 0.0f; + + for (int k_base = 0; k_base < K; k_base += BSK) { + __syncthreads(); + for (int i = tid; i < BSK; i += nthreads) { + int k = k_base + i; + x_shared[i] = (k < K && valid) ? to_float(x_row[k]) : 0.0f; + } + __syncthreads(); + + if (!valid) continue; + + float x_local[VPT]; + #pragma unroll + for (int i = 0; i < VPT; i++) { + x_local[i] = x_shared[lane * VPT + i]; + } + + int w_offset = k_base / PF + lane * PPT; + float group_qdot4[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + float group_xsum = 0.0f; + int k_val = k_base + lane * VPT; + int group_idx = k_val / GROUP_SIZE; + + uint32_t w_local[PPT]; + if (k_base + BSK <= K) { + load_weight_vec_streaming(w_row + w_offset, w_local); + } else { + #pragma unroll + for (int p = 0; p < PPT; p++) { + w_local[p] = (w_offset + p < w_stride) ? w_row[w_offset + p] : 0u; + } + } + + #pragma unroll + for (int p = 0; p < PPT; p++) { + dequant_and_dot4(w_local[p], &x_local[p * PF], group_qdot4, group_xsum); + } + + float scale = (group_idx < num_groups) ? to_float(s_row[group_idx]) : 0.0f; + float bias = (has_bias && group_idx < num_groups) ? to_float(b_row[group_idx]) : 0.0f; + acc += scale * reduce_qdot4(group_qdot4) + bias * group_xsum; + } + + if (!valid) return; + acc = warp_reduce_sum(acc); + if (lane == 0) { + out[batch * M * N + m * N + n] = from_float(acc); + } +} + +// 6-bit gather (MoE expert-indexed) variant of qmv_tiled_6bit_kernel. +template +__global__ __launch_bounds__(TILE_N_MAX * WARP_SIZE) +void gather_qmv_tiled_6bit_kernel( + const T* __restrict__ x, + const uint8_t* __restrict__ w, + const ScaleT* __restrict__ scales, + const ScaleT* __restrict__ biases, + const uint32_t* __restrict__ lhs_indices, + const uint32_t* __restrict__ rhs_indices, + T* __restrict__ out, + int B, int M, int N, int K, int E, int LHS_B, + bool has_bias, + int tile_n) { + constexpr int BITS = 6; + constexpr int VPT6 = 16; + constexpr int BSK6 = VPT6 * WARP_SIZE; + constexpr int BYTES_PER_LANE = (VPT6 * BITS) / 8; + + const int batch = blockIdx.z; + const int m = blockIdx.x; + const int n = blockIdx.y * tile_n + threadIdx.y; + const int lane = threadIdx.x; + const int tid = threadIdx.y * WARP_SIZE + threadIdx.x; + const int nthreads = tile_n * WARP_SIZE; + + const bool valid = (batch < B && m < M && n < N); + + uint32_t lhs_idx = valid ? lhs_indices[batch] : 0; + uint32_t rhs_idx = valid ? rhs_indices[batch] : 0; + if (lhs_idx >= static_cast(LHS_B)) lhs_idx = 0; + if (rhs_idx >= static_cast(E)) rhs_idx = 0; + + __shared__ float x_shared[BSK6]; + + const int row_bytes = (K * BITS + 7) / 8; + const int num_groups = (K + GROUP_SIZE - 1) / GROUP_SIZE; + const int clamped_n = (n < N) ? n : 0; + const uint8_t* w_row = + w + static_cast(rhs_idx) * N * row_bytes + clamped_n * row_bytes; + const ScaleT* s_row = + scales + static_cast(rhs_idx) * N * num_groups + clamped_n * num_groups; + const ScaleT* b_row = has_bias + ? (biases + static_cast(rhs_idx) * N * num_groups + clamped_n * num_groups) + : nullptr; + const T* x_row = x + static_cast(lhs_idx) * M * K + m * K; + + float acc = 0.0f; + + for (int k_base = 0; k_base < K; k_base += BSK6) { + __syncthreads(); + for (int i = tid; i < BSK6; i += nthreads) { + int k = k_base + i; + x_shared[i] = (k < K && valid) ? to_float(x_row[k]) : 0.0f; + } + __syncthreads(); + + if (!valid) continue; + + float x_local[VPT6]; + #pragma unroll + for (int i = 0; i < VPT6; i++) { + x_local[i] = x_shared[lane * VPT6 + i]; + } + + float group_qdot4[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + float group_xsum = 0.0f; + + const int k_val = k_base + lane * VPT6; + const int group_idx = k_val / GROUP_SIZE; + + if (k_base + BSK6 <= K) { + const uint8_t* w_lane = w_row + ((k_base * BITS) / 8) + lane * BYTES_PER_LANE; + dequant_and_dot4_6bit( + w_lane, x_local, group_qdot4, group_xsum); + } else { + #pragma unroll + for (int i = 0; i < VPT6; i++) { + int k = k_val + i; + if (k < K) { + float xv = x_local[i]; + int bit_index = k * BITS; + int byte_idx = bit_index >> 3; + int bit_offset = bit_index & 0x7; + uint32_t window = static_cast(w_row[byte_idx]); + if (byte_idx + 1 < row_bytes) { + window |= static_cast(w_row[byte_idx + 1]) << 8; + } + float q = static_cast((window >> bit_offset) & 0x3Fu); + group_qdot4[i & 3] += xv * q; + group_xsum += xv; + } + } + } + + float scale = (group_idx < num_groups) ? to_float(s_row[group_idx]) : 0.0f; + float bias = (has_bias && group_idx < num_groups) ? to_float(b_row[group_idx]) : 0.0f; + acc += scale * reduce_qdot4(group_qdot4) + bias * group_xsum; + } + + if (!valid) return; + acc = warp_reduce_sum(acc); + if (lane == 0) { + out[batch * M * N + m * N + n] = from_float(acc); + } +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/quantized/quantized.cpp b/mlx/backend/rocm/quantized/quantized.cpp new file mode 100644 index 0000000000..1232339758 --- /dev/null +++ b/mlx/backend/rocm/quantized/quantized.cpp @@ -0,0 +1,83 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/quantized/quantized.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/fast_primitives.h" + +#include + +namespace mlx::core { + +namespace { + +inline array ensure_row_contiguous( + const array& x, + rocm::CommandEncoder& enc, + const Stream& s) { + if (!x.flags().row_contiguous) { + array x_copy = contiguous_copy_gpu(x, s); + enc.add_temporary(x_copy); + return x_copy; + } else { + return x; + } +} + +inline array +ensure_contiguous(const array& x, rocm::CommandEncoder& enc, const Stream& s) { + if (x.flags().row_contiguous || x.flags().col_contiguous) { + return x; + } + array x_copy = contiguous_copy_gpu(x, s); + enc.add_temporary(x_copy); + return x_copy; +} + +} // namespace + +// Note: affine_quantize, affine_dequantize, fp_quantize, fp_dequantize +// are implemented in affine_quantize.hip and fp_quantize.hip +// ConvertFP8 is implemented in convert_fp8.hip + +void fast::Quantize::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& d = rocm::device(s.device); + auto& enc = d.get_command_encoder(s); + + if (dequantize_) { + auto wq = ensure_row_contiguous(inputs[0], enc, s); + auto scales = ensure_row_contiguous(inputs[1], enc, s); + auto& w = outputs[0]; + + w.set_data(mlx::core::rocm::malloc_async(w.nbytes(), enc)); + + if (mode_ == QuantizationMode::Affine) { + auto biases = ensure_row_contiguous(inputs[2], enc, s); + affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s); + } else { + fp_dequantize(wq, scales, w, group_size_, bits_, enc, s); + } + } else { + auto w = ensure_contiguous(inputs[0], enc, s); + auto& wq = outputs[0]; + auto& scales = outputs[1]; + + wq.set_data(mlx::core::rocm::malloc_async(wq.nbytes(), enc)); + scales.set_data(mlx::core::rocm::malloc_async(scales.nbytes(), enc)); + if (mode_ == QuantizationMode::Affine) { + auto& biases = outputs[2]; + biases.set_data(mlx::core::rocm::malloc_async(biases.nbytes(), enc)); + affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s); + } else { + fp_quantize(w, wq, scales, group_size_, bits_, enc, s); + } + } +} + +// Note: ConvertFP8::eval_gpu is implemented in convert_fp8.hip + +} // namespace mlx::core diff --git a/mlx/backend/rocm/quantized/quantized.h b/mlx/backend/rocm/quantized/quantized.h new file mode 100644 index 0000000000..5469f216fa --- /dev/null +++ b/mlx/backend/rocm/quantized/quantized.h @@ -0,0 +1,51 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include "mlx/array.h" +#include "mlx/backend/rocm/device.h" + +namespace mlx::core { + +// Affine quantization functions +void affine_quantize( + const array& w, + array& wq, + array& scales, + array& biases, + int group_size, + int bits, + rocm::CommandEncoder& enc, + const Stream& s); + +void affine_dequantize( + const array& wq, + const array& scales, + const std::optional& biases, + array& w, + int group_size, + int bits, + rocm::CommandEncoder& enc, + const Stream& s); + +// Floating-point quantization functions +void fp_quantize( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + rocm::CommandEncoder& enc, + const Stream& s); + +void fp_dequantize( + const array& wq, + const array& scales, + array& w, + int group_size, + int bits, + rocm::CommandEncoder& enc, + const Stream& s); + +} // namespace mlx::core diff --git a/mlx/backend/rocm/random.hip b/mlx/backend/rocm/random.hip new file mode 100644 index 0000000000..04332bd33e --- /dev/null +++ b/mlx/backend/rocm/random.hip @@ -0,0 +1,213 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/primitives.h" + +#include + +#include + +namespace mlx::core { + +namespace rocm { + +__constant__ constexpr uint32_t rotations[2][4] = { + {13, 15, 26, 6}, + {17, 29, 16, 24}}; + +union rbits_union { + uint2 val; + uint8_t bytes[2][4]; +}; + +__device__ rbits_union threefry2x32_hash(uint2 key, uint2 count) { + uint32_t ks[] = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA}; + + rbits_union v; + v.val.x = count.x + ks[0]; + v.val.y = count.y + ks[1]; + + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 4; ++j) { + uint32_t r = rotations[i % 2][j]; + v.val.x += v.val.y; + v.val.y = (v.val.y << r) | (v.val.y >> (32 - r)); + v.val.y ^= v.val.x; + } + v.val.x += ks[(i + 1) % 3]; + v.val.y += ks[(i + 2) % 3] + i + 1; + } + + return v; +} + +__global__ void rbitsc_kernel( + const uint32_t* keys, + uint8_t* out, + uint32_t grid_dims_x, + uint32_t grid_dims_y, + bool odd, + uint32_t bytes_per_key) { + uint thread_index = blockIdx.x * blockDim.x + threadIdx.x; + uint index_x = thread_index % grid_dims_x; + uint index_y = thread_index / grid_dims_x; + if (index_x >= grid_dims_x || index_y >= grid_dims_y) { + return; + } + + auto kidx = 2 * index_x; + auto key = make_uint2(keys[kidx], keys[kidx + 1]); + auto half_size = grid_dims_y - odd; + out += index_x * bytes_per_key; + bool drop_last = odd && (index_y == half_size); + auto bits = threefry2x32_hash( + key, make_uint2(index_y, drop_last ? 0 : index_y + grid_dims_y)); + size_t idx = size_t(index_y) << 2; + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[0][i]; + } + if (!drop_last) { + idx = (drop_last ? 0 : size_t(index_y) + grid_dims_y) << 2; + if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) { + int edge_bytes = (bytes_per_key % 4); + for (int i = 0; i < edge_bytes; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } else { + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } + } +} + +__device__ int64_t elem_to_loc_random( + int64_t elem, + const hip_array& shape, + const hip_array& strides, + int ndim) { + int64_t loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +__global__ void rbits_kernel( + const uint32_t* keys, + uint8_t* out, + uint32_t grid_dims_x, + uint32_t grid_dims_y, + bool odd, + uint32_t bytes_per_key, + int32_t ndim, + hip_array key_shape, + hip_array key_strides) { + uint thread_index = blockIdx.x * blockDim.x + threadIdx.x; + uint index_x = thread_index % grid_dims_x; + uint index_y = thread_index / grid_dims_x; + if (index_x >= grid_dims_x || index_y >= grid_dims_y) { + return; + } + + auto kidx = 2 * index_x; + auto k1_elem = elem_to_loc_random(kidx, key_shape, key_strides, ndim); + auto k2_elem = elem_to_loc_random(kidx + 1, key_shape, key_strides, ndim); + auto key = make_uint2(keys[k1_elem], keys[k2_elem]); + auto half_size = grid_dims_y - odd; + out += size_t(index_x) * bytes_per_key; + bool drop_last = odd && (index_y == half_size); + auto bits = threefry2x32_hash( + key, make_uint2(index_y, drop_last ? 0 : index_y + grid_dims_y)); + size_t idx = size_t(index_y) << 2; + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[0][i]; + } + if (!drop_last) { + idx = (drop_last ? 0 : size_t(index_y) + grid_dims_y) << 2; + if ((index_y + 1) == half_size && (bytes_per_key % 4) > 0) { + int edge_bytes = (bytes_per_key % 4); + for (int i = 0; i < edge_bytes; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } else { + for (int i = 0; i < 4; ++i) { + out[idx + i] = bits.bytes[1][i]; + } + } + } +} + +} // namespace rocm + +void RandomBits::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + + // keys has shape (N1, ..., NK, 2) + // out has shape (N1, ..., NK, M1, M2, ...) + auto& keys = inputs[0]; + uint32_t num_keys = keys.size() / 2; + + uint32_t elems_per_key = out.size() / num_keys; + uint32_t bytes_per_key = out.itemsize() * elems_per_key; + + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + if (out.size() == 0) { + return; + } + + uint32_t out_per_key = (bytes_per_key + 4 - 1) / 4; + uint32_t half_size = out_per_key / 2; + bool odd = out_per_key % 2; + + encoder.set_input_array(keys); + encoder.set_output_array(out); + + uint32_t grid_dims_x = num_keys; + uint32_t grid_dims_y = half_size + odd; + int64_t total = static_cast(grid_dims_x) * grid_dims_y; + + int block_size = 256; + int num_blocks = (total + block_size - 1) / block_size; + num_blocks = std::min(num_blocks, 65535); + + if (keys.flags().row_contiguous) { + encoder.add_kernel_node( + &rocm::rbitsc_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(keys), + gpu_ptr(out), + grid_dims_x, + grid_dims_y, + odd, + bytes_per_key); + } else { + rocm::hip_array shape_arg = {}; + rocm::hip_array strides_arg = {}; + for (int i = 0; i < keys.ndim(); i++) { + shape_arg.data_[i] = static_cast(keys.shape()[i]); + strides_arg.data_[i] = keys.strides()[i]; + } + + encoder.add_kernel_node( + &rocm::rbits_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(keys), + gpu_ptr(out), + grid_dims_x, + grid_dims_y, + odd, + bytes_per_key, + static_cast(keys.ndim()), + shape_arg, + strides_arg); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/reduce.hip b/mlx/backend/rocm/reduce.hip new file mode 100644 index 0000000000..0895c2fca9 --- /dev/null +++ b/mlx/backend/rocm/reduce.hip @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/reduce/reduce.hpp" +#include "mlx/backend/rocm/device/utils.hpp" +#include "mlx/backend/gpu/copy.h" + +#include +#include + +namespace mlx::core { + +void Reduce::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + array in = inputs[0]; + + // Make sure no identity reductions trickle down here. + assert(!axes_.empty()); + assert(out.size() != in.size()); + + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + if (in.size() == 0) { + init_reduce(encoder, in, out, reduce_type_); + return; + } + + // Reduce. + ReductionPlan plan = get_reduction_plan(in, axes_); + + // If it is a general reduce then copy the input to a contiguous array and + // recompute the plan. + bool broadcasted = false; + for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) { + if (j < axes_.size() && axes_[j] == i) { + j++; + } else { + broadcasted = in.strides(i) == 0; + } + } + if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) { + array in_copy = contiguous_copy_gpu(in, s); + encoder.add_temporary(in_copy); + in = in_copy; + plan = get_reduction_plan(in, axes_); + } + + if (plan.type == ContiguousAllReduce) { + all_reduce(encoder, in, out, reduce_type_); + return; + } + + if (plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { + row_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + if (plan.type == ContiguousStridedReduce || + plan.type == GeneralStridedReduce) { + col_reduce(encoder, in, out, reduce_type_, axes_, plan); + return; + } + + throw std::runtime_error("No plan reached in reduce."); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/reduce/all_reduce.hip b/mlx/backend/rocm/reduce/all_reduce.hip new file mode 100644 index 0000000000..d96f4bc212 --- /dev/null +++ b/mlx/backend/rocm/reduce/all_reduce.hip @@ -0,0 +1,298 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/reduce/reduce.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/device/fp16_math.hpp" +#include "mlx/dtype_utils.h" + +#include + +namespace mlx::core { + +namespace rocm { + +// Helper to handle warp shuffle for different types +template +__device__ T warp_shfl_down_all(T val, int offset) { + return __shfl_down(val, offset); +} + +// Specialization for hip_bfloat16 - convert to float for shuffle +template <> +__device__ hip_bfloat16 warp_shfl_down_all(hip_bfloat16 val, int offset) { + float f = bf16_to_float(val); + f = __shfl_down(f, offset); + return float_to_bf16(f); +} + +// Specialization for __half - convert to float for shuffle +template <> +__device__ __half warp_shfl_down_all(__half val, int offset) { + float f = __half2float(val); + f = __shfl_down(f, offset); + return __float2half(f); +} + +// Specialization for hipFloatComplex +template <> +__device__ hipFloatComplex warp_shfl_down_all(hipFloatComplex val, int offset) { + return make_hipFloatComplex( + __shfl_down(val.x, offset), + __shfl_down(val.y, offset)); +} + +template +__device__ U warp_reduce(U val, Op op) { + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val = op(val, warp_shfl_down_all(val, offset)); + } + return val; +} + +// Helper to cast input to accumulator type +template +__device__ U cast_to_acc(T val) { + if constexpr (std::is_same_v) { + // For And/Or operations, convert to bool + if constexpr (is_complex_v) { + return val.x != 0 || val.y != 0; + } else { + return static_cast(val); + } + } else { + return static_cast(val); + } +} + +template +__global__ void all_reduce_kernel( + const T* __restrict__ in, + U* __restrict__ out, + size_t block_step, + size_t size) { + __shared__ U shared_data[32]; + + const U init = ReduceInit::value(); + Op op; + + U acc = init; + + size_t start = blockIdx.x * block_step; + size_t end = min(start + block_step, size); + + // Each thread processes multiple elements + for (size_t i = start + threadIdx.x * N; i < end; i += blockDim.x * N) { + #pragma unroll + for (int j = 0; j < N && (i + j) < end; ++j) { + acc = op(acc, cast_to_acc(in[i + j])); + } + } + + // Warp-level reduction + int lane = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + + acc = warp_reduce(acc, op); + + if (lane == 0) { + shared_data[warp_id] = acc; + } + __syncthreads(); + + // Final reduction by first warp + int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; + if (warp_id == 0) { + acc = (lane < num_warps) ? shared_data[lane] : init; + acc = warp_reduce(acc, op); + + if (lane == 0) { + out[blockIdx.x] = acc; + } + } +} + +} // namespace rocm + +// Dispatch reduce operations +template +void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) { + switch (reduce_type) { + case Reduce::Sum: + f(type_identity{}); + break; + case Reduce::Prod: + f(type_identity{}); + break; + case Reduce::Max: + f(type_identity{}); + break; + case Reduce::Min: + f(type_identity{}); + break; + case Reduce::And: + f(type_identity{}); + break; + case Reduce::Or: + f(type_identity{}); + break; + default: + throw std::runtime_error("Unsupported reduce type"); + } +} + +// ReduceResult type trait - determines output type for reduction +template +struct ReduceResult { + using type = T; +}; + +// And always produces bool +template +struct ReduceResult { + using type = bool; +}; + +// Or always produces bool +template +struct ReduceResult { + using type = bool; +}; + +// Sum on small integers produces int32 +template +struct ReduceResult { + using type = std::conditional_t< + (std::is_integral_v && sizeof(T) <= 4), + int32_t, + T>; +}; + +// Prod on small integers produces int32 +template +struct ReduceResult { + using type = std::conditional_t< + (std::is_integral_v && sizeof(T) <= 4), + int32_t, + T>; +}; + +// Check if a reduce operation is valid for a type +template +constexpr bool is_valid_reduce_op() { + // All reduce operations work on all types + // And/Or will cast to bool internally + return true; +} + +void all_reduce( + rocm::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type) { + constexpr int N_READS = 4; + + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + + auto get_args = [](size_t size, int N) { + int threads = std::min(512, static_cast((size + N - 1) / N)); + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int reductions_per_step = threads * N; + size_t steps_needed = (size + reductions_per_step - 1) / reductions_per_step; + + int blocks; + if (steps_needed < 32) { + blocks = 1; + } else if (steps_needed < 128) { + blocks = 32; + } else if (steps_needed < 512) { + blocks = 128; + } else if (steps_needed < 1024) { + blocks = 512; + } else { + blocks = 1024; + } + + size_t steps_per_block = (steps_needed + blocks - 1) / blocks; + size_t block_step = steps_per_block * reductions_per_step; + + return std::make_tuple(blocks, threads, block_step); + }; + + int blocks, threads; + size_t block_step; + size_t insize = in.size(); + Dtype dt = in.dtype(); + + std::tie(blocks, threads, block_step) = get_args(insize, N_READS); + + encoder.set_input_array(in); + + // For multi-block reduction, we need an intermediate buffer + if (blocks > 1) { + array intermediate({blocks}, out.dtype(), nullptr, {}); + intermediate.set_data(mlx::core::rocm::malloc_async(intermediate.nbytes(), encoder)); + encoder.add_temporary(intermediate); + encoder.set_output_array(intermediate); + + // First pass: reduce to intermediate + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = hip_type_t; + using U = typename ReduceResult::type; + + if constexpr (is_valid_reduce_op()) { + encoder.add_kernel_node( + &rocm::all_reduce_kernel, + dim3(blocks), dim3(threads), 0, + gpu_ptr(in), gpu_ptr(intermediate), block_step, insize); + } + }); + }); + + // Set the input for the next step and recalculate the blocks + dt = intermediate.dtype(); + insize = intermediate.size(); + std::tie(blocks, threads, block_step) = get_args(insize, N_READS); + encoder.set_input_array(intermediate); + + // Second pass: reduce intermediate to output + encoder.set_output_array(out); + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = hip_type_t; + using U = typename ReduceResult::type; + + if constexpr (is_valid_reduce_op()) { + encoder.add_kernel_node( + &rocm::all_reduce_kernel, + dim3(1), dim3(threads), 0, + gpu_ptr(intermediate), gpu_ptr(out), block_step, insize); + } + }); + }); + } else { + // Single block reduction + encoder.set_output_array(out); + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = hip_type_t; + using U = typename ReduceResult::type; + + if constexpr (is_valid_reduce_op()) { + encoder.add_kernel_node( + &rocm::all_reduce_kernel, + dim3(1), dim3(threads), 0, + gpu_ptr(in), gpu_ptr(out), block_step, insize); + } + }); + }); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/reduce/col_reduce.hip b/mlx/backend/rocm/reduce/col_reduce.hip new file mode 100644 index 0000000000..2de475cb87 --- /dev/null +++ b/mlx/backend/rocm/reduce/col_reduce.hip @@ -0,0 +1,488 @@ +// Copyright © 2025 Apple Inc. + +#include + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/reduce/reduce.hpp" +#include "mlx/backend/rocm/reduce/reduce_utils.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/device/utils.hpp" + +#include + +namespace mlx::core { + +namespace rocm { + +struct ColReduceArgs { + // The size of the contiguous column reduction. + size_t reduction_size; + int64_t reduction_stride; + + // Input shape and strides excluding the reduction axes. + Shape shape; + Strides strides; + int ndim; + + // Input shape and strides of the reduction axes (including last dimension). + Shape reduce_shape; + Strides reduce_strides; + int reduce_ndim; + + // The number of column we are reducing. Namely prod(reduce_shape). + size_t non_col_reductions; + + ColReduceArgs( + const array& in, + const ReductionPlan& plan, + const std::vector& axes) { + using ShapeVector = decltype(plan.shape); + using StridesVector = decltype(plan.strides); + + ShapeVector shape_vec; + StridesVector strides_vec; + + assert(!plan.shape.empty()); + reduction_size = plan.shape.back(); + reduction_stride = plan.strides.back(); + + int64_t stride_back = 1; + std::tie(shape_vec, strides_vec) = shapes_without_reduction_axes(in, axes); + while (!shape_vec.empty() && stride_back < reduction_stride) { + stride_back *= shape_vec.back(); + shape_vec.pop_back(); + strides_vec.pop_back(); + } + std::vector indices(shape_vec.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](int left, int right) { + return strides_vec[left] > strides_vec[right]; + }); + ShapeVector sorted_shape; + StridesVector sorted_strides; + for (auto idx : indices) { + sorted_shape.push_back(shape_vec[idx]); + sorted_strides.push_back(strides_vec[idx]); + } + std::tie(shape_vec, strides_vec) = + collapse_contiguous_dims(sorted_shape, sorted_strides); + + // Copy to fixed-size arrays + ndim = shape_vec.size(); + for (int i = 0; i < ndim && i < MAX_NDIM; i++) { + shape[i] = shape_vec[i]; + strides[i] = strides_vec[i]; + } + + reduce_ndim = plan.shape.size(); + for (int i = 0; i < reduce_ndim && i < MAX_NDIM; i++) { + reduce_shape[i] = plan.shape[i]; + reduce_strides[i] = plan.strides[i]; + } + + non_col_reductions = 1; + for (int i = 0; i < reduce_ndim - 1; i++) { + non_col_reductions *= reduce_shape[i]; + } + } +}; + +// Warp reduce helper using runtime warp size +template +__device__ T warp_reduce_col(T val, Op op) { + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + T other = __shfl_xor(val, offset); + val = op(val, other); + } + return val; +} + +// Helper to cast input to accumulator type +template +__device__ U cast_to_col(T val) { + if constexpr (std::is_same_v) { + // For And/Or operations, convert to bool + return static_cast(val); + } else { + return static_cast(val); + } +} + +template < + typename T, + typename U, + typename Op, + int NDIM, + int BM, + int BN, + int N_READS = 4, + int BLOCKS = 1> +__global__ void col_reduce_looped( + const T* in, + U* out, + ColReduceArgs args, + int64_t out_size) { + + constexpr int threads_per_row = BN / N_READS; + + // Compute the indices for the tile + size_t tile_idx = blockIdx.x + blockIdx.y * gridDim.x; + size_t tile_x = tile_idx % ((args.reduction_stride + BN - 1) / BN); + size_t tile_y = tile_idx / ((args.reduction_stride + BN - 1) / BN); + size_t tile_out = tile_y / out_size; + tile_y = tile_y % out_size; + + // Compute the indices for the thread within the tile + short thread_x = threadIdx.x % threads_per_row; + short thread_y = threadIdx.x / threads_per_row; + + // Move the input pointer + in += elem_to_loc(tile_y, args.shape.data(), args.strides.data(), args.ndim) + + tile_x * BN; + + // Initialize the running totals + Op op; + U totals[N_READS]; + for (int i = 0; i < N_READS; i++) { + totals[i] = ReduceInit::value(); + } + + size_t total = args.non_col_reductions * args.reduction_size; + size_t per_block, start, end; + if constexpr (BLOCKS > 1) { + per_block = (total + BLOCKS - 1) / BLOCKS; + start = tile_out * per_block + thread_y; + end = min((tile_out + 1) * per_block, total); + } else { + per_block = total; + start = thread_y; + end = total; + } + + LoopedElemToLoc 2)> loop(args.reduce_ndim); + loop.next(start, args.reduce_shape.data(), args.reduce_strides.data()); + + int remaining = args.reduction_stride - tile_x * BN; + int base_idx = thread_x * N_READS; + + for (size_t r = start; r < end; r += BM) { + // Load values + for (int i = 0; i < N_READS; i++) { + int idx = base_idx + i; + if (idx < remaining) { + totals[i] = op(totals[i], cast_to_col(in[loop.location() + idx])); + } + } + loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data()); + } + + // Do warp reduce for each output. + constexpr int n_outputs = BN / threads_per_row; + __shared__ U shared_vals[BM * BN]; + short s_idx = thread_y * BN + thread_x * N_READS; + for (int i = 0; i < N_READS; i++) { + shared_vals[s_idx + i] = totals[i]; + } + __syncthreads(); + + // Reduce across threads + if (thread_y == 0) { + for (int i = 0; i < N_READS; i++) { + U val = ReduceInit::value(); + for (int j = 0; j < BM; j++) { + val = op(val, shared_vals[j * BN + thread_x * N_READS + i]); + } + totals[i] = val; + } + } + __syncthreads(); + + // Write result. + if (thread_y == 0) { + if (BLOCKS > 1) { + out += tile_out * out_size * args.reduction_stride; + } + for (int i = 0; i < N_READS; i++) { + int idx = thread_x * N_READS + i; + if (tile_x * BN + idx < args.reduction_stride) { + out[tile_y * args.reduction_stride + tile_x * BN + idx] = totals[i]; + } + } + } +} + +template +__global__ void col_reduce_small( + const T* in, + U* out, + ColReduceArgs args, + size_t total) { + Op op; + + const auto idx = (blockIdx.x * blockDim.x + threadIdx.x) * N_READS; + const auto before_axis = idx / args.reduction_stride; + const auto after_axis = idx % args.reduction_stride; + const auto offset = + before_axis * args.reduction_stride * args.reduction_size + after_axis; + + if (idx >= total) { + return; + } + + in += offset; + out += idx; + + AlignedVector accumulator; + for (int i = 0; i < N_READS; i++) { + accumulator[i] = ReduceInit::value(); + } + + for (size_t i = 0; i < args.reduction_size; i++) { + auto values = load_vector(in, 0); + + for (int j = 0; j < N_READS; j++) { + accumulator[j] = op(accumulator[j], cast_to_col(values[j])); + } + + in += args.reduction_stride; + } + + store_vector(out, 0, accumulator); +} + +// Simple column reduction kernel for contiguous strided reduce +template +__global__ void col_reduce_simple_kernel( + const T* in, + U* out, + int n_rows, + int n_cols) { + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (col >= n_cols) return; + + Op op; + U val = ReduceInit::value(); + + for (int row = 0; row < n_rows; row++) { + val = op(val, cast_to_col(in[row * n_cols + col])); + } + + out[col] = val; +} + +} // namespace rocm + +inline auto output_grid_for_col_reduce( + const array& out, + const rocm::ColReduceArgs& args, + int bn, + int outer = 1) { + int gx, gy = 1; + size_t n_inner_blocks = ceildiv(args.reduction_stride, (int64_t)bn); + size_t n_outer_blocks = out.size() / args.reduction_stride; + size_t n_blocks = n_outer_blocks * n_inner_blocks * outer; + while (n_blocks / gy > INT32_MAX) { + gy *= 2; + } + gx = ceildiv(n_blocks, (size_t)gy); + + return dim3(gx, gy, 1); +} + +// Dispatch for reduce types - excludes complex64 which doesn't support most reduce ops +template +void dispatch_reduce_types(Dtype dt, Func&& func) { + switch (dt) { + case bool_: + func(type_identity{}); + break; + case uint8: + func(type_identity{}); + break; + case uint16: + func(type_identity{}); + break; + case uint32: + func(type_identity{}); + break; + case uint64: + func(type_identity{}); + break; + case int8: + func(type_identity{}); + break; + case int16: + func(type_identity{}); + break; + case int32: + func(type_identity{}); + break; + case int64: + func(type_identity{}); + break; + case float16: + func(type_identity{}); + break; + case bfloat16: + func(type_identity{}); + break; + case float32: + func(type_identity{}); + break; + case float64: + func(type_identity{}); + break; + case complex64: + throw std::runtime_error("Complex types not yet supported for reduce operations on ROCm"); + default: + throw std::runtime_error("Unsupported dtype for reduce"); + } +} + +// Dispatch helper for reduce operations - no type restrictions +// The cast_to function handles conversion to bool for And/Or +template +void dispatch_reduce_ops(Reduce::ReduceType reduce_type, Func&& func) { + switch (reduce_type) { + case Reduce::Sum: + func(type_identity{}); + break; + case Reduce::Prod: + func(type_identity{}); + break; + case Reduce::Max: + func(type_identity{}); + break; + case Reduce::Min: + func(type_identity{}); + break; + case Reduce::And: + func(type_identity{}); + break; + case Reduce::Or: + func(type_identity{}); + break; + default: + throw std::runtime_error("Unsupported reduce type"); + } +} + +// Dispatch helper for reduce ndim +template +void dispatch_reduce_ndim(int ndim, Func&& func) { + switch (ndim) { + case 1: + func(std::integral_constant{}); + break; + case 2: + func(std::integral_constant{}); + break; + case 3: + func(std::integral_constant{}); + break; + case 4: + func(std::integral_constant{}); + break; + default: + func(std::integral_constant{}); + break; + } +} + +void col_reduce_looped( + rocm::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan, + const rocm::ColReduceArgs& args) { + // Allocate data for the output + allocate_same_layout(out, in, axes, encoder); + + encoder.set_input_array(in); + encoder.set_output_array(out); + + dispatch_reduce_types(in.dtype(), [&](auto type_tag) { + using T = hip_type_t; + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + using OP = typename decltype(reduce_type_tag)::type; + using U = typename rocm::ReduceResult::type; + + constexpr int N_READS = 4; + constexpr int BM = 32; + constexpr int BN = 32; + dim3 grid = output_grid_for_col_reduce(out, args, BN); + int blocks = BM * BN / N_READS; + + encoder.add_kernel_node( + &rocm::col_reduce_looped, + grid, dim3(blocks), 0, + gpu_ptr(in), + gpu_ptr(out), + args, + out.size() / args.reduction_stride); + }); + }); + }); +} + +void col_reduce_small( + rocm::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan, + const rocm::ColReduceArgs& args) { + // Allocate data for the output + allocate_same_layout(out, in, axes, encoder); + + encoder.set_input_array(in); + encoder.set_output_array(out); + + dispatch_reduce_types(in.dtype(), [&](auto type_tag) { + using T = hip_type_t; + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = typename decltype(reduce_type_tag)::type; + using U = typename rocm::ReduceResult::type; + + constexpr int N_READS = 4; + int block_size = 256; + int num_blocks = (out.size() + block_size * N_READS - 1) / (block_size * N_READS); + + encoder.add_kernel_node( + &rocm::col_reduce_small, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(in), + gpu_ptr(out), + args, + out.size()); + }); + }); +} + +void col_reduce( + rocm::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + + // Make the args struct to help route to the best kernel + rocm::ColReduceArgs args(in, plan, axes); + + // Small col reduce with a single or contiguous reduction axis + if (args.non_col_reductions == 1 && args.reduction_size <= 32 && + args.reduction_stride % 4 == 0) { + col_reduce_small(encoder, in, out, reduce_type, axes, plan, args); + return; + } + + // Fallback col reduce + col_reduce_looped(encoder, in, out, reduce_type, axes, plan, args); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/reduce/init_reduce.hip b/mlx/backend/rocm/reduce/init_reduce.hip new file mode 100644 index 0000000000..039d9b9f93 --- /dev/null +++ b/mlx/backend/rocm/reduce/init_reduce.hip @@ -0,0 +1,81 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/reduce/reduce.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" + +#include + +namespace mlx::core { + +namespace rocm { + +template +__global__ void init_reduce_kernel(U* out, size_t size) { + size_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + out[index] = ReduceInit::value(); + } +} + +} // namespace rocm + +// Dispatch reduce operations +template +void dispatch_reduce_ops_init(Reduce::ReduceType reduce_type, F&& f) { + switch (reduce_type) { + case Reduce::Sum: + f(type_identity{}); + break; + case Reduce::Prod: + f(type_identity{}); + break; + case Reduce::Max: + f(type_identity{}); + break; + case Reduce::Min: + f(type_identity{}); + break; + case Reduce::And: + f(type_identity{}); + break; + case Reduce::Or: + f(type_identity{}); + break; + default: + throw std::runtime_error("Unsupported reduce type"); + } +} + +void init_reduce( + rocm::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type) { + // Allocate if needed + if (out.data_shared_ptr() == nullptr) { + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + } + + encoder.set_output_array(out); + + int block_size = 256; + int num_blocks = (out.size() + block_size - 1) / block_size; + + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops_init(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = hip_type_t; + using U = typename rocm::ReduceResult::type; + + encoder.add_kernel_node( + &rocm::init_reduce_kernel, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(out), out.size()); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/reduce/reduce.hpp b/mlx/backend/rocm/reduce/reduce.hpp new file mode 100644 index 0000000000..3c000dc14f --- /dev/null +++ b/mlx/backend/rocm/reduce/reduce.hpp @@ -0,0 +1,294 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/common/reduce.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/device/utils.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include + +namespace mlx::core { + +namespace rocm { + +// Reduce operations for ROCm + +// And and Or only work with bool +struct And { + __device__ bool operator()(bool a, bool b) const { + return a && b; + } +}; + +struct Or { + __device__ bool operator()(bool a, bool b) const { + return a || b; + } +}; + +struct Sum { + template + __device__ T operator()(T a, T b) const { + return a + b; + } + + // Specialization for hipFloatComplex + __device__ hipFloatComplex + operator()(hipFloatComplex a, hipFloatComplex b) const { + return make_hipFloatComplex(a.x + b.x, a.y + b.y); + } +}; + +struct Prod { + template + __device__ T operator()(T a, T b) const { + return a * b; + } + + // Specialization for hipFloatComplex (complex multiplication) + __device__ hipFloatComplex + operator()(hipFloatComplex a, hipFloatComplex b) const { + return make_hipFloatComplex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); + } +}; + +struct Max { + template < + typename T, + std::enable_if_t< + !is_complex_v && !std::is_same_v && + !std::is_same_v, + int> = 0> + __device__ T operator()(T a, T b) const { + return a > b ? a : b; + } + + // Specialization for float with NaN handling + __device__ float operator()(float a, float b) const { + if (isnan(a) || isnan(b)) { + return numeric_limits::quiet_NaN(); + } + return a > b ? a : b; + } + + // Specialization for double with NaN handling + __device__ double operator()(double a, double b) const { + if (isnan(a) || isnan(b)) { + return numeric_limits::quiet_NaN(); + } + return a > b ? a : b; + } + + // Specialization for hipFloatComplex + __device__ hipFloatComplex + operator()(hipFloatComplex a, hipFloatComplex b) const { + // Check for NaN + if (isnan(a.x) || isnan(a.y)) { + return a; + } + if (isnan(b.x) || isnan(b.y)) { + return b; + } + // Compare by magnitude (real^2 + imag^2), then by real part + float mag_a = a.x * a.x + a.y * a.y; + float mag_b = b.x * b.x + b.y * b.y; + if (mag_a != mag_b) { + return mag_a > mag_b ? a : b; + } + return a.x > b.x ? a : b; + } +}; + +struct Min { + template < + typename T, + std::enable_if_t< + !is_complex_v && !std::is_same_v && + !std::is_same_v, + int> = 0> + __device__ T operator()(T a, T b) const { + return a < b ? a : b; + } + + // Specialization for float with NaN handling + __device__ float operator()(float a, float b) const { + if (isnan(a) || isnan(b)) { + return numeric_limits::quiet_NaN(); + } + return a < b ? a : b; + } + + // Specialization for double with NaN handling + __device__ double operator()(double a, double b) const { + if (isnan(a) || isnan(b)) { + return numeric_limits::quiet_NaN(); + } + return a < b ? a : b; + } + + // Specialization for hipFloatComplex + __device__ hipFloatComplex + operator()(hipFloatComplex a, hipFloatComplex b) const { + // Check for NaN + if (isnan(a.x) || isnan(a.y)) { + return a; + } + if (isnan(b.x) || isnan(b.y)) { + return b; + } + // Compare by magnitude (real^2 + imag^2), then by real part + float mag_a = a.x * a.x + a.y * a.y; + float mag_b = b.x * b.x + b.y * b.y; + if (mag_a != mag_b) { + return mag_a < mag_b ? a : b; + } + return a.x < b.x ? a : b; + } +}; + +// Reduce result type mapping +template +struct ReduceResult { + using type = T; +}; + +// And and Or always return bool +template +struct ReduceResult { + using type = bool; +}; + +template +struct ReduceResult { + using type = bool; +}; + +// Sum and Prod promote small integers to int32_t +template +struct ReduceResult { + using type = + std::conditional_t<(std::is_integral_v && sizeof(T) <= 4), int32_t, T>; +}; + +template +struct ReduceResult { + using type = + std::conditional_t<(std::is_integral_v && sizeof(T) <= 4), int32_t, T>; +}; + +// Reduce init value +template +struct ReduceInit; + +template +struct ReduceInit { + static __device__ bool value() { + return true; + } +}; + +template +struct ReduceInit { + static __device__ bool value() { + return false; + } +}; + +template +struct ReduceInit { + static __device__ auto value() { + using ResultT = typename ReduceResult::type; + return ResultT(0); + } +}; + +// Specialization for hipFloatComplex +template <> +struct ReduceInit { + static __device__ hipFloatComplex value() { + return make_hipFloatComplex(0.0f, 0.0f); + } +}; + +template +struct ReduceInit { + static __device__ auto value() { + using ResultT = typename ReduceResult::type; + return ResultT(1); + } +}; + +// Specialization for hipFloatComplex +template <> +struct ReduceInit { + static __device__ hipFloatComplex value() { + return make_hipFloatComplex(1.0f, 0.0f); + } +}; + +template +struct ReduceInit { + static __device__ T value() { + return Limits::min(); + } +}; + +// Specialization for hipFloatComplex +template <> +struct ReduceInit { + static __device__ hipFloatComplex value() { + return make_hipFloatComplex(Limits::min(), Limits::min()); + } +}; + +template +struct ReduceInit { + static __device__ T value() { + return Limits::max(); + } +}; + +// Specialization for hipFloatComplex +template <> +struct ReduceInit { + static __device__ hipFloatComplex value() { + return make_hipFloatComplex(Limits::max(), Limits::max()); + } +}; + +} // namespace rocm + +// Column reduction function declarations +void col_reduce( + rocm::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +void all_reduce( + rocm::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type); + +void row_reduce( + rocm::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan); + +void init_reduce( + rocm::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type); + +} // namespace mlx::core diff --git a/mlx/backend/rocm/reduce/reduce_ops.hpp b/mlx/backend/rocm/reduce/reduce_ops.hpp new file mode 100644 index 0000000000..5fd1a64e06 --- /dev/null +++ b/mlx/backend/rocm/reduce/reduce_ops.hpp @@ -0,0 +1,323 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/rocm/device/atomic_ops.hpp" +#include "mlx/backend/rocm/device/cast_op.hpp" +#include "mlx/backend/rocm/device/utils.hpp" + +#include + +namespace mlx::core::rocm { + +// Reduce ops with atomic_update for col_reduce + +struct And { + __device__ __forceinline__ bool operator()(bool a, bool b) const { + return a && b; + } + + template + __device__ static constexpr T init() { + return true; + } + + __device__ void atomic_update(bool* x, bool y) { + atomic_and(x, y); + } +}; + +struct Or { + __device__ __forceinline__ bool operator()(bool a, bool b) const { + return a || b; + } + + template + __device__ static constexpr T init() { + return false; + } + + __device__ void atomic_update(bool* x, bool y) { + atomic_or(x, y); + } +}; + +struct Sum { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } + + // Specialization for hipFloatComplex + __device__ __forceinline__ hipFloatComplex + operator()(hipFloatComplex a, hipFloatComplex b) const { + return make_hipFloatComplex(a.x + b.x, a.y + b.y); + } + + template + __device__ static constexpr T init() { + return T(0); + } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } + + __device__ void atomic_update(float* x, float y) { + atomicAdd(x, y); + } + + __device__ void atomic_update(int* x, int y) { + atomicAdd(x, y); + } +}; + +struct Prod { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a * b; + } + + // Specialization for hipFloatComplex (complex multiplication) + __device__ __forceinline__ hipFloatComplex + operator()(hipFloatComplex a, hipFloatComplex b) const { + return make_hipFloatComplex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); + } + + template + __device__ static constexpr T init() { + return T(1); + } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } +}; + +struct Max { + template < + typename T, + std::enable_if_t< + !is_complex_v && !std::is_same_v && + !std::is_same_v, + int> = 0> + __device__ __forceinline__ T operator()(T a, T b) const { + return a > b ? a : b; + } + + // Specialization for float with NaN handling + __device__ __forceinline__ float operator()(float a, float b) const { + if (isnan(a) || isnan(b)) { + return a > b ? a : b; // Propagate NaN + } + return a > b ? a : b; + } + + // Specialization for double with NaN handling + __device__ __forceinline__ double operator()(double a, double b) const { + if (isnan(a) || isnan(b)) { + return a > b ? a : b; // Propagate NaN + } + return a > b ? a : b; + } + + // Specialization for hipFloatComplex + __device__ __forceinline__ hipFloatComplex + operator()(hipFloatComplex a, hipFloatComplex b) const { + // Check for NaN + if (isnan(a.x) || isnan(a.y)) { + return a; + } + if (isnan(b.x) || isnan(b.y)) { + return b; + } + // Compare by magnitude (real^2 + imag^2), then by real part + float mag_a = a.x * a.x + a.y * a.y; + float mag_b = b.x * b.x + b.y * b.y; + if (mag_a != mag_b) { + return mag_a > mag_b ? a : b; + } + return a.x > b.x ? a : b; + } + + template + __device__ static constexpr T init() { + return numeric_limits::lowest(); + } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } +}; + +struct Min { + template < + typename T, + std::enable_if_t< + !is_complex_v && !std::is_same_v && + !std::is_same_v, + int> = 0> + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? a : b; + } + + // Specialization for float with NaN handling + __device__ __forceinline__ float operator()(float a, float b) const { + if (isnan(a) || isnan(b)) { + return a < b ? a : b; // Propagate NaN + } + return a < b ? a : b; + } + + // Specialization for double with NaN handling + __device__ __forceinline__ double operator()(double a, double b) const { + if (isnan(a) || isnan(b)) { + return a < b ? a : b; // Propagate NaN + } + return a < b ? a : b; + } + + // Specialization for hipFloatComplex + __device__ __forceinline__ hipFloatComplex + operator()(hipFloatComplex a, hipFloatComplex b) const { + // Check for NaN + if (isnan(a.x) || isnan(a.y)) { + return a; + } + if (isnan(b.x) || isnan(b.y)) { + return b; + } + // Compare by magnitude (real^2 + imag^2), then by real part + float mag_a = a.x * a.x + a.y * a.y; + float mag_b = b.x * b.x + b.y * b.y; + if (mag_a != mag_b) { + return mag_a < mag_b ? a : b; + } + return a.x < b.x ? a : b; + } + + template + __device__ static constexpr T init() { + return numeric_limits::max(); + } + + template + __device__ void atomic_update(T* x, T y) { + atomic_reduce(x, y); + } +}; + +// Traits to get the result type of reduce op. +template +struct ReduceResult { + using type = T; +}; + +template +struct ReduceResult { + using type = bool; +}; + +template +struct ReduceResult { + using type = bool; +}; + +template +struct ReduceResult { + using type = + std::conditional_t<(std::is_integral_v && sizeof(T) <= 4), int32_t, T>; +}; + +template +struct ReduceResult { + using type = + std::conditional_t<(std::is_integral_v && sizeof(T) <= 4), int32_t, T>; +}; + +// Traits to get the init value of reduce op. +template +struct ReduceInit { + __device__ static T value() { + return Op::template init(); + } +}; + +template +struct ReduceInit { + __device__ static auto value() { + return typename ReduceResult::type(0); + } +}; + +// Specialization for hipFloatComplex +template <> +struct ReduceInit { + __device__ static hipFloatComplex value() { + return make_hipFloatComplex(0.0f, 0.0f); + } +}; + +template +struct ReduceInit { + __device__ static auto value() { + return typename ReduceResult::type(1); + } +}; + +// Specialization for hipFloatComplex +template <> +struct ReduceInit { + __device__ static hipFloatComplex value() { + return make_hipFloatComplex(1.0f, 0.0f); + } +}; + +template +struct ReduceInit { + __device__ static T value() { + return numeric_limits::lowest(); + } +}; + +// Specialization for hipFloatComplex +template <> +struct ReduceInit { + __device__ static hipFloatComplex value() { + return numeric_limits::lowest(); + } +}; + +template +struct ReduceInit { + __device__ static T value() { + return numeric_limits::max(); + } +}; + +// Specialization for hipFloatComplex +template <> +struct ReduceInit { + __device__ static hipFloatComplex value() { + return numeric_limits::max(); + } +}; + +template +struct ReduceInit { + __device__ static bool value() { + return true; + } +}; + +template +struct ReduceInit { + __device__ static bool value() { + return false; + } +}; + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/reduce/reduce_utils.hpp b/mlx/backend/rocm/reduce/reduce_utils.hpp new file mode 100644 index 0000000000..4b31e746a2 --- /dev/null +++ b/mlx/backend/rocm/reduce/reduce_utils.hpp @@ -0,0 +1,157 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/device/utils.hpp" + +#include + +namespace mlx::core { + +namespace rocm { + +// WARP_SIZE is defined in device/config.h based on target architecture + +template +struct uint_by_size; +template <> +struct uint_by_size<2> { + using type = uint16_t; +}; +template <> +struct uint_by_size<4> { + using type = uint32_t; +}; +template <> +struct uint_by_size<8> { + using type = unsigned long long int; +}; + +template +__device__ void atomic_reduce(T* x, T y) { + if constexpr (sizeof(T) == 1) { + using U = uint16_t; + U* x_int = (U*)((char*)x - ((size_t)x % 2)); + int shift = ((char*)x - (char*)x_int) * 8; + int mask = 0xff << shift; + U old_val, new_val; + do { + old_val = *x_int; + T result = Op{}(static_cast((old_val >> shift) & 0xff), y); + new_val = (old_val & ~mask) | (result << shift); + } while (atomicCAS(x_int, old_val, new_val) != old_val); + } else { + using U = typename uint_by_size::type; + U* x_int = (U*)(x); + U old_val, new_val; + do { + old_val = *x_int; + T result = Op{}(*((T*)&old_val), y); + new_val = *((U*)&result); + } while (atomicCAS(x_int, old_val, new_val) != old_val); + } +} + +// Warp-level reduction using shuffle +template +__device__ T warp_reduce(T val, Op op) { + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val = op(val, __shfl_down(val, offset)); + } + return val; +} + +// Block-level reduction +template +__device__ void +block_reduce(T (&vals)[N], T* smem, Op op, T init, int block_size) { + int lane = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + int num_warps = (block_size + WARP_SIZE - 1) / WARP_SIZE; + + // First reduce within each warp + for (int i = 0; i < N; i++) { + vals[i] = warp_reduce(vals[i], op); + } + + // Store warp results to shared memory + if (lane == 0) { + for (int i = 0; i < N; i++) { + smem[warp_id * N + i] = vals[i]; + } + } + __syncthreads(); + + // Final reduction by first warp + if (warp_id == 0) { + for (int i = 0; i < N; i++) { + vals[i] = (lane < num_warps) ? smem[lane * N + i] : init; + } + for (int i = 0; i < N; i++) { + vals[i] = warp_reduce(vals[i], op); + } + } +} + +} // namespace rocm + +// Allocate output with same layout as input (for reduce operations) +inline void allocate_same_layout( + array& out, + const array& in, + const std::vector& axes, + rocm::CommandEncoder& encoder) { + if (in.flags().row_contiguous) { + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + return; + } + + if (out.ndim() < in.ndim()) { + throw std::runtime_error( + "Reduction without keepdims only supported for row-contiguous inputs"); + } + + // Calculate the transpositions applied to in in order to apply them to out. + std::vector axis_order(in.ndim()); + std::iota(axis_order.begin(), axis_order.end(), 0); + std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) { + return in.strides(left) > in.strides(right); + }); + + // Transpose the shape and calculate the strides + Shape out_shape(in.ndim()); + Strides out_strides(in.ndim(), 1); + for (int i = 0; i < in.ndim(); i++) { + out_shape[i] = out.shape(axis_order[i]); + } + for (int i = in.ndim() - 2; i >= 0; i--) { + out_strides[i] = out_shape[i + 1] * out_strides[i + 1]; + } + + // Reverse the axis order to get the final strides + Strides final_strides(in.ndim()); + for (int i = 0; i < in.ndim(); i++) { + final_strides[axis_order[i]] = out_strides[i]; + } + + // Calculate the resulting contiguity and do the memory allocation + auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides); + auto fl = in.flags(); + fl.row_contiguous = rc; + fl.col_contiguous = cc; + fl.contiguous = true; + out.set_data( + mlx::core::rocm::malloc_async(out.nbytes(), encoder), + data_size, + final_strides, + fl, + allocator::free); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/reduce/row_reduce.hip b/mlx/backend/rocm/reduce/row_reduce.hip new file mode 100644 index 0000000000..680387e6a4 --- /dev/null +++ b/mlx/backend/rocm/reduce/row_reduce.hip @@ -0,0 +1,346 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/reduce/reduce.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/device/fp16_math.hpp" +#include "mlx/backend/rocm/device/utils.hpp" + +#include + +namespace mlx::core { + +namespace rocm { + +// Helper to handle warp shuffle for different types +template +__device__ T warp_shfl_down(T val, int offset) { + return __shfl_down(val, offset); +} + +// Specialization for hip_bfloat16 - convert to float for shuffle +template <> +__device__ hip_bfloat16 warp_shfl_down(hip_bfloat16 val, int offset) { + float f = bf16_to_float(val); + f = __shfl_down(f, offset); + return float_to_bf16(f); +} + +// Specialization for __half - convert to float for shuffle +template <> +__device__ __half warp_shfl_down(__half val, int offset) { + float f = __half2float(val); + f = __shfl_down(f, offset); + return __float2half(f); +} + +// Helper to cast input to accumulator type +template +__device__ U cast_to_row(T val) { + if constexpr (std::is_same_v) { + // For And/Or operations, convert to bool + return static_cast(val); + } else { + return static_cast(val); + } +} + +template +__global__ void row_reduce_simple_kernel( + const T* __restrict__ in, + U* __restrict__ out, + size_t n_rows, + int row_size) { + __shared__ U shared_data[32]; + + const U init = ReduceInit::value(); + Op op; + + size_t row = blockIdx.x; + if (row >= n_rows) return; + + const T* row_in = in + row * row_size; + U acc = init; + + // Each thread processes multiple elements + for (int i = threadIdx.x * N; i < row_size; i += blockDim.x * N) { + #pragma unroll + for (int j = 0; j < N && (i + j) < row_size; ++j) { + acc = op(acc, cast_to_row(row_in[i + j])); + } + } + + // Warp-level reduction using runtime warpSize + int lane = threadIdx.x % warpSize; + int warp_id = threadIdx.x / warpSize; + + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + acc = op(acc, warp_shfl_down(acc, offset)); + } + + if (lane == 0) { + shared_data[warp_id] = acc; + } + __syncthreads(); + + // Final reduction by first warp + int num_warps = (blockDim.x + warpSize - 1) / warpSize; + if (warp_id == 0) { + acc = (lane < num_warps) ? shared_data[lane] : init; + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + acc = op(acc, warp_shfl_down(acc, offset)); + } + + if (lane == 0) { + out[row] = acc; + } + } +} + +template +__global__ void row_reduce_looped_kernel( + const T* __restrict__ in, + U* __restrict__ out, + size_t out_size, + int row_size, + Shape shape, + Strides in_strides, + int ndim, + size_t non_row_reductions, + Shape reduce_shape, + Strides reduce_strides, + int reduce_ndim) { + __shared__ U shared_data[32]; + + const U init = ReduceInit::value(); + Op op; + + size_t out_idx = blockIdx.x; + if (out_idx >= out_size) return; + + // Compute base input offset from output index + int64_t base_offset = elem_to_loc(out_idx, shape.data(), in_strides.data(), ndim); + + U acc = init; + + // Loop over non-row reductions + LoopedElemToLoc 2)> loop(reduce_ndim); + for (size_t n = 0; n < non_row_reductions; ++n) { + const T* row_in = in + base_offset + loop.location(); + + // Reduce the row + for (int i = threadIdx.x; i < row_size; i += blockDim.x) { + acc = op(acc, cast_to_row(row_in[i])); + } + + loop.next(reduce_shape.data(), reduce_strides.data()); + } + + // Warp-level reduction + int lane = threadIdx.x % warpSize; + int warp_id = threadIdx.x / warpSize; + + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + acc = op(acc, warp_shfl_down(acc, offset)); + } + + if (lane == 0) { + shared_data[warp_id] = acc; + } + __syncthreads(); + + int num_warps = (blockDim.x + warpSize - 1) / warpSize; + if (warp_id == 0) { + acc = (lane < num_warps) ? shared_data[lane] : init; + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + acc = op(acc, warp_shfl_down(acc, offset)); + } + + if (lane == 0) { + out[out_idx] = acc; + } + } +} + +} // namespace rocm + +// Dispatch for reduce types - excludes complex64 which doesn't support most reduce ops +template +void dispatch_reduce_types_row(Dtype dt, Func&& func) { + switch (dt) { + case bool_: + func(type_identity{}); + break; + case uint8: + func(type_identity{}); + break; + case uint16: + func(type_identity{}); + break; + case uint32: + func(type_identity{}); + break; + case uint64: + func(type_identity{}); + break; + case int8: + func(type_identity{}); + break; + case int16: + func(type_identity{}); + break; + case int32: + func(type_identity{}); + break; + case int64: + func(type_identity{}); + break; + case float16: + func(type_identity{}); + break; + case bfloat16: + func(type_identity{}); + break; + case float32: + func(type_identity{}); + break; + case float64: + func(type_identity{}); + break; + case complex64: + throw std::runtime_error("Complex types not yet supported for reduce operations on ROCm"); + default: + throw std::runtime_error("Unsupported dtype for reduce"); + } +} + +// Dispatch helper for reduce operations - no type restrictions +// The cast_to function handles conversion to bool for And/Or +template +void dispatch_reduce_ops_row(Reduce::ReduceType reduce_type, Func&& func) { + switch (reduce_type) { + case Reduce::Sum: + func(type_identity{}); + break; + case Reduce::Prod: + func(type_identity{}); + break; + case Reduce::Max: + func(type_identity{}); + break; + case Reduce::Min: + func(type_identity{}); + break; + case Reduce::And: + func(type_identity{}); + break; + case Reduce::Or: + func(type_identity{}); + break; + default: + throw std::runtime_error("Unsupported reduce type"); + } +} + +// Dispatch helper for reduce ndim +template +void dispatch_reduce_ndim_row(int ndim, Func&& func) { + switch (ndim) { + case 1: + func(std::integral_constant{}); + break; + case 2: + func(std::integral_constant{}); + break; + case 3: + func(std::integral_constant{}); + break; + case 4: + func(std::integral_constant{}); + break; + default: + func(std::integral_constant{}); + break; + } +} + +void row_reduce( + rocm::CommandEncoder& encoder, + const array& in, + array& out, + Reduce::ReduceType reduce_type, + const std::vector& axes, + const ReductionPlan& plan) { + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + + int row_size = plan.shape.back(); + size_t out_size = out.size(); + + // Calculate threads based on row size + int threads = std::min(256, ((row_size + 3) / 4 + 32 - 1) / 32 * 32); + threads = std::max(threads, 32); + + encoder.set_input_array(in); + encoder.set_output_array(out); + + // Simple row reduce for single reduction axis with contiguous data + // Only use simple kernel for ContiguousReduce (row-contiguous input) + if (plan.shape.size() == 1 && plan.type == ContiguousReduce) { + dispatch_reduce_types_row(in.dtype(), [&](auto type_tag) { + using T = hip_type_t; + dispatch_reduce_ops_row(reduce_type, [&](auto reduce_type_tag) { + using OP = typename decltype(reduce_type_tag)::type; + using U = typename rocm::ReduceResult::type; + + encoder.add_kernel_node( + &rocm::row_reduce_simple_kernel, + dim3(out_size), dim3(threads), 0, + gpu_ptr(in), gpu_ptr(out), out_size, row_size); + }); + }); + } else { + // Looped row reduce for multiple reduction axes + // Build shape/strides for non-reduction axes + auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes); + + rocm::Shape shape; + rocm::Strides strides; + int ndim = shape_vec.size(); + for (int i = 0; i < ndim && i < MAX_NDIM; i++) { + shape[i] = shape_vec[i]; + strides[i] = strides_vec[i]; + } + + // Build reduce shape/strides (excluding last axis which is the row) + rocm::Shape reduce_shape; + rocm::Strides reduce_strides; + int reduce_ndim = plan.shape.size() - 1; + size_t non_row_reductions = 1; + for (int i = 0; i < reduce_ndim && i < MAX_NDIM; i++) { + reduce_shape[i] = plan.shape[i]; + reduce_strides[i] = plan.strides[i]; + non_row_reductions *= plan.shape[i]; + } + + dispatch_reduce_types_row(in.dtype(), [&](auto type_tag) { + using T = hip_type_t; + dispatch_reduce_ops_row(reduce_type, [&](auto reduce_type_tag) { + dispatch_reduce_ndim_row(reduce_ndim, [&](auto reduce_ndim_val) { + using OP = typename decltype(reduce_type_tag)::type; + using U = typename rocm::ReduceResult::type; + + encoder.add_kernel_node( + &rocm::row_reduce_looped_kernel, + dim3(out_size), dim3(threads), 0, + gpu_ptr(in), gpu_ptr(out), out_size, row_size, + shape, strides, ndim, + non_row_reductions, reduce_shape, reduce_strides, reduce_ndim); + }); + }); + }); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/rms_norm.hip b/mlx/backend/rocm/rms_norm.hip new file mode 100644 index 0000000000..842f9be8ba --- /dev/null +++ b/mlx/backend/rocm/rms_norm.hip @@ -0,0 +1,461 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/reduce/reduce.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include + +namespace mlx::core { + +namespace rocm { + +// Warp reduce for sum +__device__ float warp_reduce_sum_rms(float val) { + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val += __shfl_xor(val, offset); + } + return val; +} + +// Warp reduce for float2 (wg*x_sum, x^2_sum) +struct float2_sum { + float x, y; +}; + +__device__ float2_sum warp_reduce_sum_f2(float2_sum val) { + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + val.x += __shfl_xor(val.x, offset); + val.y += __shfl_xor(val.y, offset); + } + return val; +} + +// Per-row RMS norm body. `x`/`out` are already offset to this row's base. +template +__device__ void rms_norm_row( + const T* x, + const T* w, + T* out, + float eps, + uint32_t axis_size, + int64_t w_stride) { + // Compute sum of squares + float normalizer = 0; + for (int i = threadIdx.x * N_READS; i < axis_size; i += BLOCK_DIM * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && i + j < axis_size; ++j) { + float t = static_cast(x[i + j]); + normalizer += t * t; + } + } + + // Block reduce for normalizer + __shared__ float shared_sum[BLOCK_DIM / WARP_SIZE + 1]; + + float warp_sum = warp_reduce_sum_rms(normalizer); + int lane = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + + if (lane == 0) { + shared_sum[warp_id] = warp_sum; + } + __syncthreads(); + + if (warp_id == 0) { + normalizer = (lane < (BLOCK_DIM + WARP_SIZE - 1) / WARP_SIZE) ? shared_sum[lane] : 0; + normalizer = warp_reduce_sum_rms(normalizer); + } + __syncthreads(); + + if (threadIdx.x == 0) { + shared_sum[0] = normalizer; + } + __syncthreads(); + // Use 1/sqrt instead of rsqrtf for IEEE-compliant precision + // (matches Metal's metal::precise::rsqrt behavior) + normalizer = 1.0f / sqrtf(shared_sum[0] / axis_size + eps); + + // Write output + // Match Metal's weight application order: w * T(x * normalizer) + // Weight multiply in output type T after truncation, not in float32 + for (int i = threadIdx.x * N_READS; i < axis_size; i += BLOCK_DIM * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && i + j < axis_size; ++j) { + int idx = i + j; + T normalized = static_cast(static_cast(x[idx]) * normalizer); + T wi = (w_stride == 0) ? w[0] : w[idx * w_stride]; + out[idx] = wi * normalized; + } + } +} + +// Packed input: rows tightly packed at row * axis_size. +template +__global__ void rms_norm_kernel( + const T* x, + const T* w, + T* out, + float eps, + uint32_t axis_size, + int64_t w_stride) { + int row = blockIdx.x; + rms_norm_row( + x + (int64_t)row * axis_size, w, out + (int64_t)row * axis_size, eps, + axis_size, w_stride); +} + +// Strided input: each row's base offset is computed from the leading dims' +// shape/strides (last dim must be contiguous). Output is packed contiguous. +// Avoids the contiguous_copy_gpu the host would otherwise insert for a +// non-packed (sliced/transposed) input, e.g. per-head q/k norm. +template +__global__ void rms_norm_strided_kernel( + const T* x, + const T* w, + T* out, + float eps, + uint32_t axis_size, + int64_t w_stride, + int n_row_dims, + hip_array row_shape, + hip_array row_strides) { + int row = blockIdx.x; + int64_t x_off = 0; + int r = row; + for (int d = n_row_dims - 1; d >= 0; --d) { + x_off += (int64_t)(r % row_shape[d]) * row_strides[d]; + r /= row_shape[d]; + } + rms_norm_row( + x + x_off, w, out + (int64_t)row * axis_size, eps, axis_size, w_stride); +} + +template +__global__ void rms_norm_vjp_kernel( + const T* x, + const T* w, + const T* g, + T* gx, + T* gw, + float eps, + int32_t axis_size, + int64_t w_stride) { + int row = blockIdx.x; + + x += row * axis_size; + g += row * axis_size; + gx += row * axis_size; + gw += row * axis_size; + + // Compute factors: (wg*x_sum, x^2_sum) + float2_sum factors = {0, 0}; + for (int i = threadIdx.x * N_READS; i < axis_size; i += BLOCK_DIM * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && i + j < axis_size; ++j) { + int idx = i + j; + float t = static_cast(x[idx]); + float wi = (w_stride == 0) ? static_cast(w[0]) : static_cast(w[idx * w_stride]); + float gi = static_cast(g[idx]); + float wg = wi * gi; + factors.x += wg * t; + factors.y += t * t; + } + } + + // Block reduce for factors + __shared__ float2_sum shared_f2[BLOCK_DIM / WARP_SIZE + 1]; + + float2_sum warp_f2 = warp_reduce_sum_f2(factors); + int lane = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + + if (lane == 0) { + shared_f2[warp_id] = warp_f2; + } + __syncthreads(); + + if (warp_id == 0) { + factors = (lane < (BLOCK_DIM + WARP_SIZE - 1) / WARP_SIZE) ? shared_f2[lane] : float2_sum{0, 0}; + factors = warp_reduce_sum_f2(factors); + } + __syncthreads(); + + if (threadIdx.x == 0) { + shared_f2[0] = factors; + } + __syncthreads(); + factors = shared_f2[0]; + + float meangwx = factors.x / axis_size; + // Use 1/sqrt instead of rsqrtf for IEEE-compliant precision + // (matches Metal's metal::precise::rsqrt behavior) + float normalizer = 1.0f / sqrtf(factors.y / axis_size + eps); + float normalizer3 = normalizer * normalizer * normalizer; + + // Write outputs + for (int i = threadIdx.x * N_READS; i < axis_size; i += BLOCK_DIM * N_READS) { + #pragma unroll + for (int j = 0; j < N_READS && i + j < axis_size; ++j) { + int idx = i + j; + float xi = static_cast(x[idx]); + float wi = (w_stride == 0) ? static_cast(w[0]) : static_cast(w[idx * w_stride]); + float gi = static_cast(g[idx]); + + // Gradient for x + gx[idx] = static_cast(normalizer * wi * gi - xi * meangwx * normalizer3); + + // Gradient for w (per-element, will be reduced later) + if constexpr (HAS_W) { + gw[idx] = static_cast(gi * xi * normalizer); + } + } + } +} + +} // namespace rocm + +namespace fast { + +bool RMSNorm::use_fallback(Stream s) { + return s.device == Device::cpu; +} + +void RMSNorm::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& out = outputs[0]; + auto& encoder = rocm::get_command_encoder(s); + + const array& xin = inputs[0]; + const array& w = inputs[1]; + int ndim = xin.ndim(); + int32_t axis_size = xin.shape().back(); + + // Layout decision: + // - packed: rows tightly packed -> fast kernel, output adopts input layout. + // - strided: last dim contiguous but rows not packed (e.g. sliced per-head + // q/k norm) -> strided kernel reads the input in place and writes + // a packed output. Avoids a contiguous_copy_gpu launch. + // - else: fall back to a contiguous copy. + bool last_contig = ndim >= 1 && xin.strides()[ndim - 1] == 1; + bool packed = xin.flags().contiguous && last_contig; + if (packed && ndim > 1) { + auto s2 = xin.strides()[ndim - 2]; + packed &= (s2 == 0 || s2 == (int64_t)axis_size); + } + bool strided = !packed && last_contig && (ndim - 1) <= 4; + + array x = xin; + if (packed) { + if (xin.is_donatable()) { + out.copy_shared_buffer(xin); + } else { + out.set_data( + mlx::core::rocm::malloc_async(xin.data_size() * xin.itemsize(), encoder), + xin.data_size(), xin.strides(), xin.flags()); + } + } else if (strided) { + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); // packed contiguous output + } else { + x = contiguous_copy_gpu(xin, s); + out.copy_shared_buffer(x); + } + + const array& xk = strided ? xin : x; + int32_t n_rows = (int32_t)(out.size() / axis_size); + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + + int n_row_dims = ndim - 1; + rocm::hip_array row_shape; + rocm::hip_array row_strides; + if (strided) { + for (int d = 0; d < n_row_dims; ++d) { + row_shape[d] = (int)xin.shape()[d]; + row_strides[d] = xin.strides()[d]; + } + } + + encoder.set_input_array(xk); + encoder.set_input_array(w); + encoder.set_output_array(out); + + constexpr int BLOCK_DIM = 256; + constexpr int N_READS = 4; + + auto launch = [&](auto tag) { + using DT = decltype(tag); + if (strided) { + encoder.add_kernel_node( + &rocm::rms_norm_strided_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr
(xk), gpu_ptr
(w), gpu_ptr
(out), + eps_, axis_size, w_stride, n_row_dims, row_shape, row_strides); + } else { + encoder.add_kernel_node( + &rocm::rms_norm_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr
(xk), gpu_ptr
(w), gpu_ptr
(out), + eps_, axis_size, w_stride); + } + }; + switch (out.dtype()) { + case float32: launch(float{}); break; + case float16: launch(__half{}); break; + case bfloat16: launch(hip_bfloat16{}); break; + default: throw std::runtime_error("Unsupported type for rms_norm"); + } +} + +void RMSNormVJP::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + // Ensure row contiguity + auto check_input = [&s](const array& x, bool& copied) { + if (x.flags().row_contiguous) { + copied = false; + return x; + } + copied = true; + return contiguous_copy_gpu(x, s); + }; + + bool donate_x = inputs[0].is_donatable(); + bool donate_g = inputs[2].is_donatable(); + bool copied; + auto x = check_input(inputs[0], copied); + donate_x |= copied; + const array& w = inputs[1]; + bool g_copied; + auto g = check_input(inputs[2], g_copied); + donate_g |= g_copied; + array& gx = outputs[0]; + array& gw = outputs[1]; + + // Check whether we had a weight + bool has_w = w.ndim() != 0; + + // Allocate space for the outputs + bool g_in_gx = false; + if (donate_x) { + gx.copy_shared_buffer(x); + } else if (donate_g) { + gx.copy_shared_buffer(g); + g_in_gx = true; + } else { + gx.set_data(mlx::core::rocm::malloc_async(gx.nbytes(), encoder)); + } + if (g_copied && !g_in_gx) { + encoder.add_temporary(g); + } + + int32_t axis_size = x.shape().back(); + int32_t n_rows = x.data_size() / axis_size; + int64_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0; + + // Allocate a temporary to store the gradients for w + array gw_temp = + (has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w; + if (has_w) { + if (!g_in_gx && donate_g) { + gw_temp.copy_shared_buffer(g); + } else { + gw_temp.set_data(mlx::core::rocm::malloc_async(gw_temp.nbytes(), encoder)); + encoder.add_temporary(gw_temp); + } + } + + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(g); + encoder.set_output_array(gx); + encoder.set_output_array(gw_temp); + + constexpr int BLOCK_DIM = 256; + constexpr int N_READS = 4; + + if (has_w) { + switch (gx.dtype()) { + case float32: + encoder.add_kernel_node( + &rocm::rms_norm_vjp_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr(x), gpu_ptr(w), gpu_ptr(g), + gpu_ptr(gx), gpu_ptr(gw_temp), + eps_, axis_size, w_stride); + break; + case float16: + encoder.add_kernel_node( + &rocm::rms_norm_vjp_kernel<__half, true, BLOCK_DIM, N_READS>, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr<__half>(x), gpu_ptr<__half>(w), gpu_ptr<__half>(g), + gpu_ptr<__half>(gx), gpu_ptr<__half>(gw_temp), + eps_, axis_size, w_stride); + break; + case bfloat16: + encoder.add_kernel_node( + &rocm::rms_norm_vjp_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr(x), gpu_ptr(w), gpu_ptr(g), + gpu_ptr(gx), gpu_ptr(gw_temp), + eps_, axis_size, w_stride); + break; + default: + throw std::runtime_error("Unsupported type for rms_norm_vjp"); + } + } else { + switch (gx.dtype()) { + case float32: { + float* gw_null = nullptr; + encoder.add_kernel_node( + &rocm::rms_norm_vjp_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr(x), gpu_ptr(w), gpu_ptr(g), + gpu_ptr(gx), gw_null, + eps_, axis_size, w_stride); + break; + } + case float16: { + __half* gw_null = nullptr; + encoder.add_kernel_node( + &rocm::rms_norm_vjp_kernel<__half, false, BLOCK_DIM, N_READS>, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr<__half>(x), gpu_ptr<__half>(w), gpu_ptr<__half>(g), + gpu_ptr<__half>(gx), gw_null, + eps_, axis_size, w_stride); + break; + } + case bfloat16: { + hip_bfloat16* gw_null = nullptr; + encoder.add_kernel_node( + &rocm::rms_norm_vjp_kernel, + dim3(n_rows), dim3(BLOCK_DIM), 0, + gpu_ptr(x), gpu_ptr(w), gpu_ptr(g), + gpu_ptr(gx), gw_null, + eps_, axis_size, w_stride); + break; + } + default: + throw std::runtime_error("Unsupported type for rms_norm_vjp"); + } + } + + // Reduce gw_temp to gw if we have weights + if (has_w) { + ReductionPlan plan( + ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size}); + col_reduce(encoder, gw_temp, gw, Reduce::ReduceType::Sum, {0}, plan); + } +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/rocm/rocm.cpp b/mlx/backend/rocm/rocm.cpp new file mode 100644 index 0000000000..e042416981 --- /dev/null +++ b/mlx/backend/rocm/rocm.cpp @@ -0,0 +1,19 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/rocm.h" + +#include + +namespace mlx::core::rocm { + +bool is_available() { + static int available = -1; + if (available < 0) { + int device_count = 0; + hipError_t err = hipGetDeviceCount(&device_count); + available = (err == hipSuccess && device_count > 0) ? 1 : 0; + } + return available == 1; +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/rocm.h b/mlx/backend/rocm/rocm.h new file mode 100644 index 0000000000..2ebe88e306 --- /dev/null +++ b/mlx/backend/rocm/rocm.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/api.h" + +namespace mlx::core::rocm { + +/* Check if the ROCm backend is available. */ +MLX_API bool is_available(); + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/rope.hip b/mlx/backend/rocm/rope.hip new file mode 100644 index 0000000000..488a3ee6b0 --- /dev/null +++ b/mlx/backend/rocm/rope.hip @@ -0,0 +1,614 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/fast_primitives.h" + +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +// Single position RoPE implementation (B=1, T=1) +template +__device__ void rope_single_impl( + const T* in, + T* out, + int32_t offset, + float inv_freq, + float scale, + int64_t stride, + uint2 pos, + uint2 dims) { + float L = scale * static_cast(offset); + + // Compute costheta, sintheta using sincosf for better performance + float theta = L * inv_freq; + float sintheta, costheta; + sincosf(theta, &sintheta, &costheta); + + // Compute the input and output indices + uint32_t index_1, index_2; + if (traditional) { + index_1 = 2 * pos.x + pos.y * stride; + index_2 = index_1 + 1; + } else { + index_1 = pos.x + pos.y * stride; + index_2 = index_1 + dims.x; + } + + // Read and write the output + float x1 = static_cast(in[index_1]); + float x2 = static_cast(in[index_2]); + float rx1; + float rx2; + if (forward) { + rx1 = x1 * costheta - x2 * sintheta; + rx2 = x1 * sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[index_1] = static_cast(rx1); + out[index_2] = static_cast(rx2); +} + +template +__global__ void rope_single( + const T* in, + T* out, + const int32_t* offset, + float scale, + float base, + int64_t stride, + uint2 dims) { + uint2 pos = make_uint2( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y); + if (pos.x >= dims.x || pos.y >= dims.y) { + return; + } + + float d = static_cast(pos.x) / static_cast(dims.x); + float inv_freq = exp2f(-d * base); + rope_single_impl( + in, out, *offset, inv_freq, scale, stride, pos, dims); +} + +// Optimized 1D kernel for single-token decode case +// Uses flat indexing for better occupancy with small workloads +template +__global__ void rope_single_1d( + const T* in, + T* out, + const int32_t* offset, + float scale, + float base, + int64_t stride, + uint32_t half_dims, // dims.x = dims_ / 2 + uint32_t n_heads) { // dims.y = N + uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + uint32_t total = half_dims * n_heads; + if (tid >= total) { + return; + } + + // Convert flat index to 2D position + uint32_t pos_x = tid % half_dims; // position within dimension + uint32_t pos_y = tid / half_dims; // head index + + float d = static_cast(pos_x) / static_cast(half_dims); + float inv_freq = exp2f(-d * base); + + // Inline the implementation for better performance + float L = scale * static_cast(*offset); + float theta = L * inv_freq; + float sintheta, costheta; + sincosf(theta, &sintheta, &costheta); + + uint32_t index_1, index_2; + if (traditional) { + index_1 = 2 * pos_x + pos_y * stride; + index_2 = index_1 + 1; + } else { + index_1 = pos_x + pos_y * stride; + index_2 = index_1 + half_dims; + } + + float x1 = static_cast(in[index_1]); + float x2 = static_cast(in[index_2]); + float rx1, rx2; + if (forward) { + rx1 = x1 * costheta - x2 * sintheta; + rx2 = x1 * sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[index_1] = static_cast(rx1); + out[index_2] = static_cast(rx2); +} + +// Optimized 1D kernel for single-token decode with custom frequencies +template +__global__ void rope_single_freqs_1d( + const T* in, + T* out, + const int32_t* offset, + const float* freqs, + float scale, + int64_t stride, + uint32_t half_dims, + uint32_t n_heads, + int64_t freq_stride) { + uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + uint32_t total = half_dims * n_heads; + if (tid >= total) { + return; + } + + uint32_t pos_x = tid % half_dims; + uint32_t pos_y = tid / half_dims; + + float inv_freq = 1.0f / freqs[freq_stride * pos_x]; + + float L = scale * static_cast(*offset); + float theta = L * inv_freq; + float sintheta, costheta; + sincosf(theta, &sintheta, &costheta); + + uint32_t index_1, index_2; + if (traditional) { + index_1 = 2 * pos_x + pos_y * stride; + index_2 = index_1 + 1; + } else { + index_1 = pos_x + pos_y * stride; + index_2 = index_1 + half_dims; + } + + float x1 = static_cast(in[index_1]); + float x2 = static_cast(in[index_2]); + float rx1, rx2; + if (forward) { + rx1 = x1 * costheta - x2 * sintheta; + rx2 = x1 * sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[index_1] = static_cast(rx1); + out[index_2] = static_cast(rx2); +} + +template +__global__ void rope_single_freqs( + const T* in, + T* out, + const int32_t* offset, + const float* freqs, + float scale, + int64_t stride, + uint2 dims, + int64_t freq_stride) { + uint2 pos = make_uint2( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y); + if (pos.x >= dims.x || pos.y >= dims.y) { + return; + } + + float inv_freq = 1.0f / freqs[freq_stride * pos.x]; + rope_single_impl( + in, out, *offset, inv_freq, scale, stride, pos, dims); +} + +// General RoPE implementation with batching +template +__device__ void rope_impl( + const T* in, + T* out, + const int* offset, + float inv_freq, + float scale, + const hip_array strides, + const hip_array out_strides, + int64_t offset_stride, + int n_head, + uint3 pos, + uint3 dims) { + auto n_head_up = N * ((n_head + N - 1) / N); + auto head_idx = static_cast((pos.z * N) % n_head_up); + auto batch_idx = (pos.z * N) / n_head_up; + auto batch_offset = offset[batch_idx * offset_stride]; + float L = scale * static_cast(pos.y + batch_offset); + auto mat_idx = batch_idx * n_head + head_idx; + + // Compute costheta, sintheta using sincosf for better performance + float theta = L * inv_freq; + float sintheta, costheta; + sincosf(theta, &sintheta, &costheta); + + // Compute the input and output indices + size_t in_index_1, in_index_2; + size_t out_index_1, out_index_2; + if (traditional) { + out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + + mat_idx * out_strides[0]; + out_index_2 = out_index_1 + 1; + in_index_1 = + 2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; + in_index_2 = in_index_1 + strides[2]; + } else { + out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + + mat_idx * out_strides[0]; + out_index_2 = out_index_1 + dims.x * out_strides[2]; + in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; + in_index_2 = in_index_1 + dims.x * strides[2]; + } + for (int i = 0; i < N && head_idx + i < n_head; ++i) { + // Read and write the output + float x1 = static_cast(in[in_index_1]); + float x2 = static_cast(in[in_index_2]); + float rx1; + float rx2; + if (forward) { + rx1 = x1 * costheta - x2 * sintheta; + rx2 = x1 * sintheta + x2 * costheta; + } else { + rx1 = x2 * sintheta + x1 * costheta; + rx2 = x2 * costheta - x1 * sintheta; + } + out[out_index_1] = static_cast(rx1); + out[out_index_2] = static_cast(rx2); + in_index_1 += strides[0]; + in_index_2 += strides[0]; + out_index_1 += out_strides[0]; + out_index_2 += out_strides[0]; + } +} + +template +__global__ void rope( + const T* in, + T* out, + const int32_t* offset, + float scale, + float base, + const hip_array strides, + const hip_array out_strides, + int64_t offset_stride, + int n_head, + uint3 dims) { + uint3 pos = make_uint3( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.z * blockDim.z + threadIdx.z); + if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) { + return; + } + + float d = static_cast(pos.x) / static_cast(dims.x); + float inv_freq = exp2f(-d * base); + rope_impl( + in, + out, + offset, + inv_freq, + scale, + strides, + out_strides, + offset_stride, + n_head, + pos, + dims); +} + +template +__global__ void rope_freqs( + const T* in, + T* out, + const int32_t* offset, + const float* freqs, + float scale, + float base, + const hip_array strides, + const hip_array out_strides, + int64_t offset_stride, + int n_head, + uint3 dims, + int64_t freq_stride) { + uint3 pos = make_uint3( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.z * blockDim.z + threadIdx.z); + if (pos.x >= dims.x || pos.y >= dims.y || pos.z >= dims.z) { + return; + } + + float inv_freq = 1.0f / freqs[freq_stride * pos.x]; + rope_impl( + in, + out, + offset, + inv_freq, + scale, + strides, + out_strides, + offset_stride, + n_head, + pos, + dims); +} + +// Helper to get grid and block dimensions +inline std::pair get_grid_and_block(uint32_t x, uint32_t y, uint32_t z) { + dim3 block(16, 16, 1); + dim3 grid( + (x + block.x - 1) / block.x, + (y + block.y - 1) / block.y, + z); + return {grid, block}; +} + +// Optimized grid/block for single-token decode case +// Uses 1D blocks for better coalescing when y (n_heads) is small +inline std::pair get_grid_and_block_single(uint32_t x, uint32_t y) { + // For decode: x = dims/2 (e.g., 64), y = n_heads (e.g., 40) + // Total elements = x * y (e.g., 2560) + // Use 1D layout for better occupancy with small workloads + constexpr uint32_t BLOCK_SIZE = 256; + uint32_t total = x * y; + dim3 block(BLOCK_SIZE, 1, 1); + dim3 grid((total + BLOCK_SIZE - 1) / BLOCK_SIZE, 1, 1); + return {grid, block}; +} + +} // namespace rocm + +namespace fast { + +bool RoPE::use_fallback(Stream s) { + return s.device == Device::cpu; +} + +void RoPE::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + auto& in = inputs[0]; + auto& offset = inputs[1]; + auto& out = outputs[0]; + + rocm::hip_array strides; + rocm::hip_array out_strides; + bool donated = false; + int ndim = in.ndim(); + + int B = in.shape(0); + int T = in.shape(-2); + int D = in.shape(-1); + size_t mat_size = T * D; + int dispatch_ndim = ndim; + while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) { + dispatch_ndim--; + } + + int N = 1; + for (int i = 1; i < (ndim - 2); ++i) { + N *= in.shape(i); + } + + // We apply rope to less than the whole vector. Normally that needs a full + // copy to `out` (so the untouched [dims_:D] tail is present) followed by an + // in-place rotate. But if the input is donatable we can rotate the first + // dims_ channels IN PLACE and leave the tail untouched — no copy at all, and + // `out` simply adopts the input's (possibly strided) layout. Downstream ops + // (SDPA) accept non-contiguous q/k, so this is safe. (Port of ml-explore/mlx + // PR #3704, "RoPE without copy".) + if (dims_ < D) { + donated = true; + if (in.is_donatable() && in.flags().row_contiguous) { + out.copy_shared_buffer(in); + strides[0] = mat_size; + strides[1] = in.strides()[ndim - 2]; + strides[2] = in.strides()[ndim - 1]; + } else { + auto ctype = + (in.flags().row_contiguous) ? CopyType::Vector : CopyType::General; + copy_gpu(in, out, ctype, s); + strides[0] = mat_size; + strides[1] = out.strides()[ndim - 2]; + strides[2] = out.strides()[ndim - 1]; + } + } + + // Either copy or apply in-place + else if (in.flags().row_contiguous) { + if (in.is_donatable()) { + donated = true; + out.copy_shared_buffer(in); + } else { + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + } + strides[0] = mat_size; + strides[1] = in.strides()[ndim - 2]; + strides[2] = in.strides()[ndim - 1]; + } else if (dispatch_ndim == 3) { + // Handle non-contiguous 3D inputs + out.set_data(mlx::core::rocm::malloc_async(out.nbytes(), encoder)); + strides[0] = in.strides()[ndim - 3]; + strides[1] = in.strides()[ndim - 2]; + strides[2] = in.strides()[ndim - 1]; + } else { + // Copy non-contiguous > 3D inputs into the output and treat + // input as donated + donated = true; + copy_gpu(in, out, CopyType::General, s); + strides[0] = mat_size; + strides[1] = out.strides()[ndim - 2]; + strides[2] = out.strides()[ndim - 1]; + } + out_strides[0] = mat_size; + out_strides[1] = out.strides()[ndim - 2]; + out_strides[2] = out.strides()[ndim - 1]; + + // Some flags to help us dispatch below + bool single = in.flags().row_contiguous && B == 1 && T == 1; + bool with_freqs = inputs.size() == 3; + + encoder.set_input_array(donated ? out : in); + encoder.set_input_array(offset); + if (with_freqs) { + encoder.set_input_array(inputs[2]); + } + encoder.set_output_array(out); + + // Dispatch based on dtype + dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + using DataType = hip_type_t; + + // Get grid/block dimensions outside the lambda to avoid C++20 structured binding capture + if (single && !with_freqs) { + // Use optimized 1D kernel for single-token decode + uint32_t half_dims = dims_ / 2; + uint32_t n_heads = N; + std::pair gb = rocm::get_grid_and_block_single(half_dims, n_heads); + dim3 grid = gb.first; + dim3 block = gb.second; + + const DataType* in_ptr = gpu_ptr(donated ? out : in); + DataType* out_ptr = gpu_ptr(out); + const int32_t* off_ptr = gpu_ptr(offset); + float base_log2 = std::log2(base_); + int64_t stride_v = static_cast(mat_size); + + #define ADD_ROPE_SINGLE_1D(TRAD, FWD) \ + encoder.add_kernel_node( \ + &rocm::rope_single_1d, \ + grid, block, 0, \ + in_ptr, out_ptr, off_ptr, scale_, base_log2, stride_v, \ + half_dims, n_heads) + if (traditional_ && forward_) { + ADD_ROPE_SINGLE_1D(true, true); + } else if (traditional_ && !forward_) { + ADD_ROPE_SINGLE_1D(true, false); + } else if (!traditional_ && forward_) { + ADD_ROPE_SINGLE_1D(false, true); + } else { + ADD_ROPE_SINGLE_1D(false, false); + } + #undef ADD_ROPE_SINGLE_1D + } else if (single) { + // Use optimized 1D kernel for single-token decode with freqs + uint32_t half_dims = dims_ / 2; + uint32_t n_heads = N; + std::pair gb = rocm::get_grid_and_block_single(half_dims, n_heads); + dim3 grid = gb.first; + dim3 block = gb.second; + int64_t freq_stride = inputs[2].strides(0); + + const DataType* in_ptr = gpu_ptr(donated ? out : in); + DataType* out_ptr = gpu_ptr(out); + const int32_t* off_ptr = gpu_ptr(offset); + const float* freqs_ptr = gpu_ptr(inputs[2]); + int64_t stride_v = static_cast(mat_size); + + #define ADD_ROPE_SINGLE_FREQS_1D(TRAD, FWD) \ + encoder.add_kernel_node( \ + &rocm::rope_single_freqs_1d, \ + grid, block, 0, \ + in_ptr, out_ptr, off_ptr, freqs_ptr, scale_, stride_v, \ + half_dims, n_heads, freq_stride) + if (traditional_ && forward_) { + ADD_ROPE_SINGLE_FREQS_1D(true, true); + } else if (traditional_ && !forward_) { + ADD_ROPE_SINGLE_FREQS_1D(true, false); + } else if (!traditional_ && forward_) { + ADD_ROPE_SINGLE_FREQS_1D(false, true); + } else { + ADD_ROPE_SINGLE_FREQS_1D(false, false); + } + #undef ADD_ROPE_SINGLE_FREQS_1D + } else if (with_freqs) { + int n_per_thread = 4; + uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); + uint3 dims3 = make_uint3(dims_ / 2, T, dimz); + std::pair gb = rocm::get_grid_and_block(dims3.x, dims3.y, dims3.z); + dim3 grid = gb.first; + dim3 block = gb.second; + int64_t offset_stride = 0; + if (inputs[1].ndim() > 0) { + offset_stride = inputs[1].strides()[0]; + } + int64_t freq_stride = inputs[2].strides(0); + + const DataType* in_ptr = gpu_ptr(donated ? out : in); + DataType* out_ptr = gpu_ptr(out); + const int32_t* off_ptr = gpu_ptr(offset); + const float* freqs_ptr = gpu_ptr(inputs[2]); + float base_log2 = std::log2(base_); + + #define ADD_ROPE_FREQS(TRAD, FWD) \ + encoder.add_kernel_node( \ + &rocm::rope_freqs, \ + grid, block, 0, \ + in_ptr, out_ptr, off_ptr, freqs_ptr, scale_, base_log2, \ + strides, out_strides, offset_stride, N, dims3, freq_stride) + if (traditional_ && forward_) { + ADD_ROPE_FREQS(true, true); + } else if (traditional_ && !forward_) { + ADD_ROPE_FREQS(true, false); + } else if (!traditional_ && forward_) { + ADD_ROPE_FREQS(false, true); + } else { + ADD_ROPE_FREQS(false, false); + } + #undef ADD_ROPE_FREQS + } else { + int n_per_thread = 4; + uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); + uint3 dims3 = make_uint3(dims_ / 2, T, dimz); + std::pair gb = rocm::get_grid_and_block(dims3.x, dims3.y, dims3.z); + dim3 grid = gb.first; + dim3 block = gb.second; + int64_t offset_stride = 0; + if (inputs[1].ndim() > 0) { + offset_stride = inputs[1].strides()[0]; + } + + const DataType* in_ptr = gpu_ptr(donated ? out : in); + DataType* out_ptr = gpu_ptr(out); + const int32_t* off_ptr = gpu_ptr(offset); + float base_log2 = std::log2(base_); + + #define ADD_ROPE(TRAD, FWD) \ + encoder.add_kernel_node( \ + &rocm::rope, \ + grid, block, 0, \ + in_ptr, out_ptr, off_ptr, scale_, base_log2, \ + strides, out_strides, offset_stride, N, dims3) + if (traditional_ && forward_) { + ADD_ROPE(true, true); + } else if (traditional_ && !forward_) { + ADD_ROPE(true, false); + } else if (!traditional_ && forward_) { + ADD_ROPE(false, true); + } else { + ADD_ROPE(false, false); + } + #undef ADD_ROPE + } + }); +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/rocm/scaled_dot_product_attention.cpp b/mlx/backend/rocm/scaled_dot_product_attention.cpp new file mode 100644 index 0000000000..9f38e84ebd --- /dev/null +++ b/mlx/backend/rocm/scaled_dot_product_attention.cpp @@ -0,0 +1,226 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/fast_primitives.h" + +#include + +namespace mlx::core { + +// Defined in scaled_dot_product_attention.hip +bool supports_sdpa_vector( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + bool output_logsumexp); + +void sdpa_vector( + const array& q, + const array& k, + const array& v, + float scale, + array& o, + bool do_causal, + const std::optional& sinks, + Stream s); + +#ifdef MLX_HAS_ROCM_WMMA +// Defined in flash_attention_wmma.hip +bool supports_sdpa_flash_wmma( + const array& q, + const array& k, + const array& v, + bool has_arr_mask, + bool output_logsumexp); + +// LDS bytes the WMMA flash kernel needs for a given head dim. +int sdpa_flash_wmma_smem(int D); + +void sdpa_flash_wmma( + const array& q, + const array& k, + const array& v, + float scale, + array& o, + bool do_causal, + Stream s); +#endif + +// Defined in flash_attention.hip +bool supports_sdpa_flash( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + bool output_logsumexp); + +void sdpa_flash( + const array& q, + const array& k, + const array& v, + float scale, + array& o, + bool do_causal, + const std::optional& mask, + const std::optional& sinks, + Stream s); + +namespace { + +array prepare_sdpa_input(const array& x, Stream s) { + // SDPA kernel requirements: last dim stride be 1, pointer aligned + if (x.strides(-1) != 1) { + array x_copy = contiguous_copy_gpu(x, s); + auto& d = rocm::device(s.device); + auto& encoder = d.get_command_encoder(s); + encoder.add_temporary(x_copy); + return x_copy; + } + return x; +} + +bool prefer_flash_for_decode( + const array& q, + const array& k, + bool has_arr_mask, + bool has_sinks) { + // The flash (prefill) kernel is catastrophically slow for single-query decode + // over long contexts — profiled at ~4.7 ms/call at ~1200 keys vs the ~tens of + // microseconds the vector decode kernel needs (it parallelizes over the KV + // length). Default decode to the vector kernel; opt back into flash only via + // env for experimentation. + static const bool enable = + std::getenv("MLX_SDPA_DECODE_FLASH") != nullptr; + if (!enable) { + return false; + } + if (has_arr_mask || has_sinks) { + return false; + } + if (q.shape(2) != 1) { + return false; + } + if (k.shape(2) < 512) { + return false; + } + return q.dtype() == float16 || q.dtype() == bfloat16; +} + +} // namespace + +namespace fast { + +bool ScaledDotProductAttention::use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + bool /*is_training*/, + bool output_logsumexp, + Stream /*s*/) { + return !supports_sdpa_vector( + q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp) && + !supports_sdpa_flash( + q, k, v, has_mask, has_arr_mask, do_causal, output_logsumexp); +} + +bool ScaledDotProductAttention::supports_bool_mask() { + return false; +} + +void ScaledDotProductAttention::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + auto& s = stream(); + + array q = prepare_sdpa_input(inputs[0], s); + array k = prepare_sdpa_input(inputs[1], s); + array v = prepare_sdpa_input(inputs[2], s); + auto& out = outputs[0]; + auto& stats = outputs[1]; + bool has_mask = inputs.size() - has_sinks_ > 3; + bool has_arr_mask = has_mask && !do_causal_; + + std::optional mask_arr; + if (has_arr_mask) { + mask_arr = prepare_sdpa_input(inputs[3], s); + } + + // Prefer WMMA flash attention when available (bf16/fp16, standard dims). + // Gate on the device's runtime arch — a multi-arch wheel can include the + // WMMA kernel even when running on a non-WMMA chip (e.g. gfx1030/1103). +#ifdef MLX_HAS_ROCM_WMMA + // Gate WMMA on the LDS budget of the device actually running the op: the + // kernel's tiled footprint must fit this device's shared-memory-per-block. + bool wmma_supported = + supports_sdpa_flash_wmma(q, k, v, has_arr_mask, output_logsumexp_) && + !has_sinks_ && rocm::device(s.device).has_native_wmma() && + sdpa_flash_wmma_smem(q.shape(-1)) <= + rocm::device(s.device).max_shared_memory_per_block(); +#else + bool wmma_supported = false; +#endif + bool vector_supported = supports_sdpa_vector( + q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_); + bool flash_supported = supports_sdpa_flash( + q, k, v, has_mask, has_arr_mask, do_causal_, output_logsumexp_); + bool flash_first = flash_supported && + prefer_flash_for_decode(q, k, has_arr_mask, has_sinks_); + + if (wmma_supported && q.shape(2) > 4) { +#ifdef MLX_HAS_ROCM_WMMA + // Use WMMA kernel for prefill (qL > 4); decode still uses vector kernel + sdpa_flash_wmma(q, k, v, scale_, out, do_causal_, s); +#endif + } else if (flash_first) { + if (has_sinks_) { + sdpa_flash(q, k, v, scale_, out, do_causal_, mask_arr, inputs.back(), s); + } else { + sdpa_flash(q, k, v, scale_, out, do_causal_, mask_arr, std::nullopt, s); + } + } else if (vector_supported) { + if (has_sinks_) { + sdpa_vector(q, k, v, scale_, out, do_causal_, inputs.back(), s); + } else { + sdpa_vector(q, k, v, scale_, out, do_causal_, std::nullopt, s); + } + } else if (flash_supported) { + if (has_sinks_) { + sdpa_flash(q, k, v, scale_, out, do_causal_, mask_arr, inputs.back(), s); + } else { + sdpa_flash(q, k, v, scale_, out, do_causal_, mask_arr, std::nullopt, s); + } + } else { + // This should not be reached — use_fallback() returns true for unsupported + // configs, causing the framework to decompose SDPA into basic GPU ops + // (matmul + softmax + matmul) before this primitive is created. + throw std::runtime_error( + "[ScaledDotProductAttention::eval_gpu] Unsupported configuration reached. " + "This is a bug — use_fallback() should have returned true."); + } +} + +bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) { + // Always use fallback for VJP on ROCm for now + return true; +} + +void ScaledDotProductAttentionVJP::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + // VJP uses CPU fallback + throw std::runtime_error( + "SDPA VJP not yet implemented for ROCm. Using CPU fallback."); +} + +} // namespace fast + +} // namespace mlx::core diff --git a/mlx/backend/rocm/scaled_dot_product_attention.hip b/mlx/backend/rocm/scaled_dot_product_attention.hip new file mode 100644 index 0000000000..8c38964071 --- /dev/null +++ b/mlx/backend/rocm/scaled_dot_product_attention.hip @@ -0,0 +1,439 @@ +// Copyright © 2025 Apple Inc. + +#define _USE_MATH_DEFINES + +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device/config.h" +#include "mlx/backend/rocm/device/utils.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" + +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +// Virtual warp size for SDPA - always 32 threads for consistent behavior +// across RDNA (32-wide) and CDNA (64-wide) architectures +constexpr int SDPA_TILE_SIZE = 32; + +struct AttnParams { + int B; + int H; + int D; + int qL; + int kL; + int gqa_factor; + float scale; + int64_t Q_strides[3]; + int64_t K_strides[3]; + int64_t V_strides[3]; + int64_t O_strides[3]; +}; + +// Tile-based reduction for 32-thread groups (works on both RDNA and CDNA) +template +__device__ __forceinline__ T tile_reduce_sum_32(T val) { + // Reduce within a 32-thread tile using shuffle operations + val += __shfl_xor(val, 16); + val += __shfl_xor(val, 8); + val += __shfl_xor(val, 4); + val += __shfl_xor(val, 2); + val += __shfl_xor(val, 1); + return val; +} + +template +__device__ __forceinline__ T tile_reduce_max_32(T val) { + // Reduce within a 32-thread tile using shuffle operations + T other; + other = __shfl_xor(val, 16); + val = val > other ? val : other; + other = __shfl_xor(val, 8); + val = val > other ? val : other; + other = __shfl_xor(val, 4); + val = val > other ? val : other; + other = __shfl_xor(val, 2); + val = val > other ? val : other; + other = __shfl_xor(val, 1); + val = val > other ? val : other; + return val; +} + +// Single-pass SDPA kernel for short sequences +// Uses 32-thread tiles for consistent behavior across architectures +template +__global__ void kernel_sdpav_1pass( + const T* Q, + const T* K, + const T* V, + T* O, + const T* sinks, + const AttnParams params) { + // BN = number of 32-thread tiles, BD = tile size (32) + constexpr int BN = 32; // Number of tiles processing keys in parallel + constexpr int BD = 32; // Tile size (always 32 for consistency) + constexpr int v_per_thread = D / BD; + + const int inner_k_stride = BN * params.K_strides[2]; + const int inner_v_stride = BN * params.V_strides[2]; + + typedef float U; + + U q[v_per_thread]; + U k[v_per_thread]; + U o[v_per_thread]; + + __shared__ U outputs[BN][BD + 1]; + __shared__ U max_scores[BN]; + __shared__ U sum_exp_scores[BN]; + + const U scale_log2 = params.scale * 1.44269504089f; // M_LOG2E + + // Use virtual 32-thread tiles instead of hardware warps + const int lane_idx = threadIdx.x % SDPA_TILE_SIZE; // 0-31 within tile + const int tile_idx = threadIdx.x / SDPA_TILE_SIZE; // Which tile (0-31) + + const int batch_idx = blockIdx.z; + const int head_idx = blockIdx.x; + const int kv_head_idx = head_idx / params.gqa_factor; + const int q_seq_idx = blockIdx.y; + const int kv_seq_idx = tile_idx; + + const T* Q_ptr = Q + batch_idx * params.Q_strides[0] + + head_idx * params.Q_strides[1] + q_seq_idx * params.Q_strides[2]; + const T* K_ptr = K + batch_idx * params.K_strides[0] + + kv_head_idx * params.K_strides[1] + kv_seq_idx * params.K_strides[2]; + const T* V_ptr = V + batch_idx * params.V_strides[0] + + kv_head_idx * params.V_strides[1] + kv_seq_idx * params.V_strides[2]; + T* O_ptr = O + batch_idx * params.O_strides[0] + + head_idx * params.O_strides[1] + q_seq_idx * params.O_strides[2]; + +// Read query and initialize output +#pragma unroll + for (int i = 0; i < v_per_thread; i++) { + q[i] = scale_log2 * static_cast(Q_ptr[v_per_thread * lane_idx + i]); + o[i] = 0.f; + } + + U max_score = -__int_as_float(0x7f7fffff); // -FLT_MAX + U sum_exp_score = 0.f; + + if (sinks && tile_idx == 0) { + max_score = 1.44269504089f * static_cast(sinks[head_idx]); // M_LOG2E + sum_exp_score = 1.f; + } + + // Process keys + for (int i = kv_seq_idx; i < params.kL; i += BN) { + bool use_key = true; + if constexpr (do_causal) { + use_key = i <= (params.kL - params.qL + q_seq_idx); + } + + if (use_key) { +#pragma unroll + for (int j = 0; j < v_per_thread; j++) { + k[j] = K_ptr[v_per_thread * lane_idx + j]; + } + + U score = 0.f; +#pragma unroll + for (int j = 0; j < v_per_thread; j++) { + score += q[j] * static_cast(k[j]); + } + + // Reduce within 32-thread tile + score = tile_reduce_sum_32(score); + + U new_max = max(max_score, score); + U factor = exp2f(max_score - new_max); + U exp_score = exp2f(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + +#pragma unroll + for (int j = 0; j < v_per_thread; j++) { + o[j] = o[j] * factor + + exp_score * static_cast(V_ptr[v_per_thread * lane_idx + j]); + } + } + + K_ptr += inner_k_stride; + V_ptr += inner_v_stride; + } + + // Store per-tile results to shared memory + if (lane_idx == 0) { + max_scores[tile_idx] = max_score; + sum_exp_scores[tile_idx] = sum_exp_score; + } + __syncthreads(); + + // Cross-tile reduction + max_score = max_scores[lane_idx % BN]; + U new_max = tile_reduce_max_32(max_score); + U factor = exp2f(max_score - new_max); + sum_exp_score = tile_reduce_sum_32(sum_exp_scores[lane_idx % BN] * factor); + +// Aggregate outputs across tiles +#pragma unroll + for (int i = 0; i < v_per_thread; i++) { + outputs[lane_idx][tile_idx] = o[i]; + __syncthreads(); + U ot = outputs[tile_idx][lane_idx] * factor; + o[i] = tile_reduce_sum_32(ot); + o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); + __syncthreads(); + } + + // Write final output + if (lane_idx == 0) { +#pragma unroll + for (int i = 0; i < v_per_thread; i++) { + O_ptr[v_per_thread * tile_idx + i] = static_cast(o[i]); + } + } +} + +} // namespace rocm + +// Forward declarations +bool supports_sdpa_vector( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + bool output_logsumexp); + +void sdpa_vector( + const array& q, + const array& k, + const array& v, + float scale, + array& o, + bool do_causal, + const std::optional& sinks, + Stream s); + +bool supports_sdpa_vector( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + bool output_logsumexp) { + if (output_logsumexp) { + return false; + } + + // Check for supported dtypes + if (q.dtype() != float32 && q.dtype() != float16 && q.dtype() != bfloat16) { + return false; + } + + const int value_head_dim = v.shape(-1); + const int query_head_dim = q.shape(-1); + const int query_sequence_length = q.shape(2); + + const bool sdpa_supported_head_dim = query_head_dim == value_head_dim && + (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || + query_head_dim == 256); + + const bool supported_vector_config = + sdpa_supported_head_dim && query_sequence_length < 4; + + return supported_vector_config && !has_arr_mask; +} + +void sdpa_vector( + const array& q, + const array& k, + const array& v, + float scale, + array& o, + bool do_causal, + const std::optional& sinks, + Stream s) { + auto& d = rocm::device(s.device); + auto& encoder = d.get_command_encoder(s); + + int B = q.shape(0); + int H = q.shape(1); + int qL = q.shape(2); + int kL = k.shape(2); + int D = q.shape(3); + int gqa_factor = q.shape(1) / k.shape(1); + + // Allocate output + o.set_data(mlx::core::rocm::malloc_async(o.nbytes(), encoder)); + + // Build params struct + rocm::AttnParams params; + params.B = B; + params.H = H; + params.D = D; + params.qL = qL; + params.kL = kL; + params.gqa_factor = gqa_factor; + params.scale = scale; + params.Q_strides[0] = q.strides(0); + params.Q_strides[1] = q.strides(1); + params.Q_strides[2] = q.strides(2); + params.K_strides[0] = k.strides(0); + params.K_strides[1] = k.strides(1); + params.K_strides[2] = k.strides(2); + params.V_strides[0] = v.strides(0); + params.V_strides[1] = v.strides(1); + params.V_strides[2] = v.strides(2); + params.O_strides[0] = o.strides(0); + params.O_strides[1] = o.strides(1); + params.O_strides[2] = o.strides(2); + + bool has_sinks = sinks.has_value(); + + encoder.set_input_array(q); + encoder.set_input_array(k); + encoder.set_input_array(v); + if (sinks) { + encoder.set_input_array(*sinks); + } + encoder.set_output_array(o); + + { + dim3 grid_dim(H, qL, B); + dim3 block_dim(1024, 1, 1); // 32 tiles * 32 threads = 1024 + + auto launch_kernel = [&](auto type_tag, auto causal_tag, auto headdim_tag) { + using DataType = decltype(type_tag); + constexpr bool causal = decltype(causal_tag)::value; + constexpr int headdim = decltype(headdim_tag)::value; + + encoder.add_kernel_node( + &rocm::kernel_sdpav_1pass, + grid_dim, + block_dim, + 0, + gpu_ptr(q), + gpu_ptr(k), + gpu_ptr(v), + gpu_ptr(o), + has_sinks ? gpu_ptr(*sinks) : nullptr, + params); + }; + + // Dispatch based on dtype, causal, and head dimension + if (o.dtype() == float32) { + if (do_causal) { + if (D == 64) + launch_kernel( + float(), std::true_type(), std::integral_constant()); + else if (D == 96) + launch_kernel( + float(), std::true_type(), std::integral_constant()); + else if (D == 128) + launch_kernel( + float(), std::true_type(), std::integral_constant()); + else if (D == 256) + launch_kernel( + float(), std::true_type(), std::integral_constant()); + } else { + if (D == 64) + launch_kernel( + float(), std::false_type(), std::integral_constant()); + else if (D == 96) + launch_kernel( + float(), std::false_type(), std::integral_constant()); + else if (D == 128) + launch_kernel( + float(), std::false_type(), std::integral_constant()); + else if (D == 256) + launch_kernel( + float(), std::false_type(), std::integral_constant()); + } + } else if (o.dtype() == float16) { + if (do_causal) { + if (D == 64) + launch_kernel( + __half(), std::true_type(), std::integral_constant()); + else if (D == 96) + launch_kernel( + __half(), std::true_type(), std::integral_constant()); + else if (D == 128) + launch_kernel( + __half(), std::true_type(), std::integral_constant()); + else if (D == 256) + launch_kernel( + __half(), std::true_type(), std::integral_constant()); + } else { + if (D == 64) + launch_kernel( + __half(), std::false_type(), std::integral_constant()); + else if (D == 96) + launch_kernel( + __half(), std::false_type(), std::integral_constant()); + else if (D == 128) + launch_kernel( + __half(), std::false_type(), std::integral_constant()); + else if (D == 256) + launch_kernel( + __half(), std::false_type(), std::integral_constant()); + } + } else if (o.dtype() == bfloat16) { + if (do_causal) { + if (D == 64) + launch_kernel( + hip_bfloat16(), + std::true_type(), + std::integral_constant()); + else if (D == 96) + launch_kernel( + hip_bfloat16(), + std::true_type(), + std::integral_constant()); + else if (D == 128) + launch_kernel( + hip_bfloat16(), + std::true_type(), + std::integral_constant()); + else if (D == 256) + launch_kernel( + hip_bfloat16(), + std::true_type(), + std::integral_constant()); + } else { + if (D == 64) + launch_kernel( + hip_bfloat16(), + std::false_type(), + std::integral_constant()); + else if (D == 96) + launch_kernel( + hip_bfloat16(), + std::false_type(), + std::integral_constant()); + else if (D == 128) + launch_kernel( + hip_bfloat16(), + std::false_type(), + std::integral_constant()); + else if (D == 256) + launch_kernel( + hip_bfloat16(), + std::false_type(), + std::integral_constant()); + } + } + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/scan.hip b/mlx/backend/rocm/scan.hip new file mode 100644 index 0000000000..c21844c70e --- /dev/null +++ b/mlx/backend/rocm/scan.hip @@ -0,0 +1,620 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device/binary_ops.hpp" +#include "mlx/backend/rocm/device/cast_op.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/reduce/reduce_ops.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include + +#include + +namespace mlx::core { + +namespace rocm { + +// Scan result type trait - Sum on bool produces int32 +template +struct ScanResult { + using type = T; +}; + +template <> +struct ScanResult { + using type = int32_t; +}; + +// ReduceInit specialization for LogAddExp +template +struct ReduceInit { + __device__ static T value() { + return Limits::min(); + } +}; + +// Load values helper - handles reverse and boundary conditions +template +__device__ void +load_values(int index, const T* in, U (&values)[N_READS], int size, U init) { + int remaining = size - index * N_READS; + if constexpr (reverse) { + in += remaining - N_READS; + if (remaining < N_READS) { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + values[N_READS - i - 1] = + (N_READS - i - 1 < remaining) ? cast_to(in[i]) : init; + } + } else { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + values[N_READS - i - 1] = cast_to(in[i]); + } + } + } else { + in += index * N_READS; + if (remaining < N_READS) { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + values[i] = (i < remaining) ? cast_to(in[i]) : init; + } + } else { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + values[i] = cast_to(in[i]); + } + } + } +} + +// Store values helper - handles reverse, exclusive offset, and boundary conditions +template +__device__ void +store_values(int index, T* out, T (&values)[N_READS], int size) { + int start = index * N_READS + offset; + int remaining = size - start; + if constexpr (reverse) { + out += remaining - N_READS; + if (remaining < N_READS) { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + if (N_READS - i - 1 < remaining) { + out[i] = values[N_READS - i - 1]; + } + } + } else { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out[i] = values[N_READS - i - 1]; + } + } + } else { + out += start; + if (remaining < N_READS) { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + if (i < remaining) { + out[i] = values[i]; + } + } + } else { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out[i] = values[i]; + } + } + } +} + +// Type-safe shuffle wrappers that handle bfloat16 and half types +// For most types, __shfl_up returns the same type +template +__device__ __forceinline__ T shfl_up_safe(T val, unsigned int delta) { + return __shfl_up(val, delta); +} + +// Specialization for hip_bfloat16 - __shfl_up returns float +template <> +__device__ __forceinline__ hip_bfloat16 shfl_up_safe(hip_bfloat16 val, unsigned int delta) { + return hip_bfloat16(__shfl_up(static_cast(val), delta)); +} + +// Specialization for __half - __shfl_up returns float +template <> +__device__ __forceinline__ __half shfl_up_safe(__half val, unsigned int delta) { + return __half(__shfl_up(__half2float(val), delta)); +} + +// Specialization for hipFloatComplex (complex type) +template <> +__device__ __forceinline__ hipFloatComplex shfl_up_safe(hipFloatComplex val, unsigned int delta) { + return make_hipFloatComplex( + __shfl_up(val.x, delta), + __shfl_up(val.y, delta)); +} + +// Type-safe shfl wrapper +template +__device__ __forceinline__ T shfl_safe(T val, int src_lane) { + return __shfl(val, src_lane); +} + +// Specialization for hip_bfloat16 +template <> +__device__ __forceinline__ hip_bfloat16 shfl_safe(hip_bfloat16 val, int src_lane) { + return hip_bfloat16(__shfl(static_cast(val), src_lane)); +} + +// Specialization for __half +template <> +__device__ __forceinline__ __half shfl_safe(__half val, int src_lane) { + return __half(__shfl(__half2float(val), src_lane)); +} + +// Specialization for hipFloatComplex (complex type) +template <> +__device__ __forceinline__ hipFloatComplex shfl_safe(hipFloatComplex val, int src_lane) { + return make_hipFloatComplex( + __shfl(val.x, src_lane), + __shfl(val.y, src_lane)); +} + +// Warp-level inclusive scan using shuffle +template +__device__ T warp_inclusive_scan(T val, Op op) { + int lane = threadIdx.x % WARP_SIZE; +#pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset *= 2) { + T other = shfl_up_safe(val, offset); + if (lane >= offset) { + val = op(val, other); + } + } + return val; +} + +// Warp-level exclusive scan using shuffle +template +__device__ T warp_exclusive_scan(T val, Op op, T init) { + T inclusive = warp_inclusive_scan(val, op); + T exclusive = shfl_up_safe(inclusive, 1); + return ((threadIdx.x % WARP_SIZE) == 0) ? init : exclusive; +} + +// Contiguous scan kernel - optimized for stride=1 arrays +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> +__global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) { + // Calculate block and thread indices + int block_rank = blockIdx.x; + int thread_rank = threadIdx.x; + int block_size = blockDim.x; + int warp_id = thread_rank / WARP_SIZE; + int lane_id = thread_rank % WARP_SIZE; + int num_warps = block_size / WARP_SIZE; + + in += block_rank * axis_size; + out += block_rank * axis_size; + + __shared__ U warp_sums[WARP_SIZE]; + + Op op; + U init = ReduceInit::value(); + U prefix = init; + + // Scan per block + int num_iterations = (axis_size + block_size * N_READS - 1) / (block_size * N_READS); + for (int r = 0; r < num_iterations; ++r) { + int32_t index = r * block_size + thread_rank; + U values[N_READS]; + load_values(index, in, values, axis_size, init); + + // Compute an inclusive scan per thread +#pragma unroll + for (int i = 1; i < N_READS; ++i) { + values[i] = op(values[i], values[i - 1]); + } + + // Compute exclusive scan of thread sums within warp + U thread_sum = values[N_READS - 1]; + U prev_thread_sum = warp_exclusive_scan(thread_sum, op, init); + + // Write warp's sum to shared memory + if (lane_id == WARP_SIZE - 1) { + warp_sums[warp_id] = op(prev_thread_sum, thread_sum); + } + __syncthreads(); + + // Compute exclusive scan of warp sums (first warp only) + if (warp_id == 0) { + U warp_val = (lane_id < num_warps) ? warp_sums[lane_id] : init; + U prev_warp_sum = warp_exclusive_scan(warp_val, op, init); + if (lane_id < num_warps) { + warp_sums[lane_id] = prev_warp_sum; + } + } + __syncthreads(); + + // Compute the output + U warp_prefix = warp_sums[warp_id]; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + values[i] = op(values[i], prefix); + values[i] = op(values[i], warp_prefix); + values[i] = op(values[i], prev_thread_sum); + } + + // Write the values + if (inclusive) { + store_values(index, out, values, axis_size); + } else { + store_values(index, out, values, axis_size); + if (reverse) { + if (thread_rank == 0 && index == 0) { + out[axis_size - 1] = init; + } + } else { + if (thread_rank == 0 && index == 0) { + out[0] = init; + } + } + } + __syncthreads(); + + // Share the prefix for next iteration + if ((warp_id == num_warps - 1) && (lane_id == WARP_SIZE - 1)) { + warp_sums[0] = values[N_READS - 1]; + } + __syncthreads(); + prefix = warp_sums[0]; + } +} + +// Strided scan kernel - for non-contiguous arrays (stride > 1) +template < + typename T, + typename U, + typename Op, + int N_READS, + int BM, + int BN, + bool inclusive, + bool reverse> +__global__ void strided_scan( + const T* in, + U* out, + int32_t axis_size, + int64_t stride, + int64_t stride_blocks) { + int block_rank = blockIdx.x; + int thread_rank = threadIdx.x; + int warp_id = thread_rank / WARP_SIZE; + int lane_id = thread_rank % WARP_SIZE; + + constexpr int BN_pad = WARP_SIZE + 16 / sizeof(U); + constexpr int n_warps = BN / N_READS; + constexpr int n_scans = BN / n_warps; + + __shared__ U read_buffer[BM * BN_pad]; + + Op op; + U init = ReduceInit::value(); + U values[n_scans]; + U prefix[n_scans]; +#pragma unroll + for (int i = 0; i < n_scans; ++i) { + prefix[i] = init; + } + + // Compute offsets + int64_t offset = (block_rank / stride_blocks) * axis_size * stride; + int64_t global_index_x = (block_rank % stride_blocks) * BN; + uint32_t read_offset_y = (thread_rank * N_READS) / BN; + uint32_t read_offset_x = (thread_rank * N_READS) % BN; + uint32_t scan_offset_y = lane_id; + uint32_t scan_offset_x = warp_id * n_scans; + + uint32_t stride_limit = stride - global_index_x; + in += offset + global_index_x + read_offset_x; + out += offset + global_index_x + read_offset_x; + U* read_into = read_buffer + read_offset_y * BN_pad + read_offset_x; + U* read_from = read_buffer + scan_offset_y * BN_pad + scan_offset_x; + + for (uint32_t j = 0; j < axis_size; j += BM) { + // Calculate the indices for the current thread + uint32_t index_y = j + read_offset_y; + uint32_t check_index_y = index_y; + if (reverse) { + index_y = axis_size - 1 - index_y; + } + + // Read into shared memory + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + read_into[i] = cast_to(in[index_y * stride + i]); + } + } else { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + read_into[i] = cast_to(in[index_y * stride + i]); + } else { + read_into[i] = init; + } + } + } + __syncthreads(); + + // Read strided into registers +#pragma unroll + for (int i = 0; i < n_scans; ++i) { + values[i] = read_from[i]; + } + + // Perform the scan using warp shuffle +#pragma unroll + for (int i = 0; i < n_scans; ++i) { + values[i] = warp_inclusive_scan(values[i], op); + values[i] = op(values[i], prefix[i]); + prefix[i] = shfl_safe(values[i], WARP_SIZE - 1); + } + + // Write to shared memory +#pragma unroll + for (int i = 0; i < n_scans; ++i) { + read_from[i] = values[i]; + } + __syncthreads(); + + // Write to device memory + if (!inclusive) { + if (check_index_y == 0) { + if ((read_offset_x + N_READS) < stride_limit) { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out[index_y * stride + i] = init; + } + } else { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + if ((read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = init; + } + } + } + } + if (reverse) { + index_y -= 1; + check_index_y += 1; + } else { + index_y += 1; + check_index_y += 1; + } + } + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out[index_y * stride + i] = read_into[i]; + } + } else { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = read_into[i]; + } + } + } + } +} + +} // namespace rocm + +// Dispatch scan operations +template +void dispatch_scan_ops(Scan::ReduceType scan_op, F&& f) { + if (scan_op == Scan::ReduceType::Max) { + f(type_identity{}); + } else if (scan_op == Scan::ReduceType::Min) { + f(type_identity{}); + } else if (scan_op == Scan::ReduceType::Sum) { + f(type_identity{}); + } else if (scan_op == Scan::ReduceType::Prod) { + f(type_identity{}); + } else if (scan_op == Scan::ReduceType::LogAddExp) { + f(type_identity{}); + } else { + throw std::invalid_argument("Unknown reduce type."); + } +} + +// Get operation name for error messages +template +const char* op_to_string() { + if constexpr (std::is_same_v) { + return "Max"; + } else if constexpr (std::is_same_v) { + return "Min"; + } else if constexpr (std::is_same_v) { + return "Sum"; + } else if constexpr (std::is_same_v) { + return "Prod"; + } else if constexpr (std::is_same_v) { + return "LogAddExp"; + } else { + return "Unknown"; + } +} + +// Check if operation is supported for type +template +constexpr bool supports_scan_op() { + if constexpr (std::is_same_v) { + return is_inexact_v; + } else { + return true; + } +} + +// Dispatch scan types - excludes complex types which don't support warp shuffle +template +void dispatch_scan_types(Dtype dtype, F&& f) { + switch (dtype) { + case bool_: + f(type_identity{}); + break; + case uint8: + f(type_identity{}); + break; + case uint16: + f(type_identity{}); + break; + case uint32: + f(type_identity{}); + break; + case uint64: + f(type_identity{}); + break; + case int8: + f(type_identity{}); + break; + case int16: + f(type_identity{}); + break; + case int32: + f(type_identity{}); + break; + case int64: + f(type_identity{}); + break; + case float16: + f(type_identity{}); + break; + case float32: + f(type_identity{}); + break; + case bfloat16: + f(type_identity{}); + break; + default: + throw std::runtime_error( + "Scan operations are not supported for complex types on ROCm."); + } +} + +void Scan::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto in = inputs[0]; + auto& s = stream(); + auto& encoder = rocm::get_command_encoder(s); + + // Check for complex types early + if (in.dtype() == complex64) { + throw std::runtime_error( + "Scan operations are not supported for complex types on ROCm."); + } + + if (in.flags().contiguous && in.strides()[axis_] != 0) { + if (in.is_donatable() && in.itemsize() == out.itemsize()) { + out.copy_shared_buffer(in); + } else { + out.set_data( + mlx::core::rocm::malloc_async(in.data_size() * out.itemsize(), encoder), + in.data_size(), + in.strides(), + in.flags()); + } + } else { + in = contiguous_copy_gpu(in, s); + out.copy_shared_buffer(in); + } + + constexpr int N_READS = 4; + int32_t axis_size = in.shape(axis_); + bool contiguous = in.strides()[axis_] == 1; + + encoder.set_input_array(in); + encoder.set_output_array(out); + + dispatch_scan_types(in.dtype(), [&](auto type_tag) { + using T = hip_type_t; + dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) { + using Op = MLX_GET_TYPE(scan_op_tag); + if constexpr (supports_scan_op()) { + using U = typename rocm::ScanResult::type; + dispatch_bool(inclusive_, [&](auto inclusive) { + dispatch_bool(reverse_, [&](auto reverse) { + if (contiguous) { + int block_dim = ceildiv(axis_size, N_READS); + block_dim = ceildiv(block_dim, WARP_SIZE) * WARP_SIZE; + block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE); + int num_blocks = in.data_size() / axis_size; + encoder.add_kernel_node( + &rocm::contiguous_scan< + T, + U, + Op, + N_READS, + inclusive.value, + reverse.value>, + dim3(num_blocks), + dim3(block_dim), + 0, + gpu_ptr(in), + gpu_ptr(out), + axis_size); + } else { + constexpr int BM = WARP_SIZE; + constexpr int BN = WARP_SIZE; + int64_t stride = in.strides()[axis_]; + int64_t stride_blocks = ceildiv(stride, (int64_t)BN); + dim3 num_blocks = get_2d_grid_dims( + in.shape(), in.strides(), axis_size * stride); + if (num_blocks.x * stride_blocks <= UINT32_MAX) { + num_blocks.x *= stride_blocks; + } else { + num_blocks.y *= stride_blocks; + } + int block_dim = (BN / N_READS) * WARP_SIZE; + encoder.add_kernel_node( + &rocm::strided_scan< + T, + U, + Op, + N_READS, + BM, + BN, + inclusive.value, + reverse.value>, + num_blocks, + dim3(block_dim), + 0, + gpu_ptr(in), + gpu_ptr(out), + axis_size, + stride, + stride_blocks); + } + }); + }); + } else { + throw std::runtime_error( + std::string("Can not do scan op ") + op_to_string() + + " on inputs of " + dtype_to_string(in.dtype()) + + " with result of " + dtype_to_string(out.dtype()) + "."); + } + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/slicing.cpp b/mlx/backend/rocm/slicing.cpp new file mode 100644 index 0000000000..ed4fb1d7b6 --- /dev/null +++ b/mlx/backend/rocm/slicing.cpp @@ -0,0 +1,151 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/slicing.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/slicing.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/jit_module.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/rocm/utils.h" +#include "mlx/dtype_utils.h" + +#include +#include + +namespace mlx::core { + +void concatenate_gpu( + const std::vector& inputs, + array& out, + int axis, + const Stream& s) { + std::vector sizes; + sizes.push_back(0); + for (auto& p : inputs) { + sizes.push_back(p.shape(axis)); + } + std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); + + out.set_data(allocator::malloc(out.nbytes())); + + auto strides = out.strides(); + auto flags = out.flags(); + flags.row_contiguous = false; + flags.col_contiguous = false; + flags.contiguous = false; + for (int i = 0; i < inputs.size(); i++) { + array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); + size_t data_offset = strides[axis] * sizes[i]; + out_slice.copy_shared_buffer( + out, strides, flags, out_slice.size(), data_offset); + copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s); + } +} + +array compute_dynamic_offset( + const array& indices, + const Strides& strides, + const std::vector& axes, + const Stream& s) { + Dtype dtype = indices.dtype(); + int nidx = axes.size(); + + std::ostringstream module_name_ss; + module_name_ss << "compute_dynamic_offset_" << dtype_to_string(dtype) << "_" + << nidx; + std::string module_name = module_name_ss.str(); + + std::ostringstream kernel_name_ss; + kernel_name_ss << "mlx::core::rocm::compute_dynamic_offset<" + << dtype_to_hip_type(dtype) << ", " << nidx << ">"; + std::string kernel_name = kernel_name_ss.str(); + + rocm::JitModule& mod = rocm::get_jit_module(s.device, module_name, [&]() { + std::ostringstream source; + source << R"( + #include + + // Standard type definitions for JIT compilation + using int64_t = signed long long; + using int32_t = signed int; + + #define MAX_NDIM 10 + + namespace mlx::core::rocm { + + template + struct hip_array { + T data_[N]; + __host__ __device__ T& operator[](int i) { return data_[i]; } + __host__ __device__ const T& operator[](int i) const { + return data_[i]; + } + }; + + template + __global__ void compute_dynamic_offset( + const T* indices, + int64_t* offset, + hip_array strides, + hip_array axes) { + int64_t acc = 0; + #pragma unroll + for (int i = 0; i < NIDX; ++i) { + acc += static_cast(indices[i]) * strides[axes[i]]; + } + *offset = acc; + } + + } // namespace mlx::core::rocm + )"; + return std::make_tuple(false, source.str(), std::vector{kernel_name}); + }); + + auto& encoder = rocm::get_command_encoder(s); + // Prepare output. + array offset({1}, int64, nullptr, {}); + bool donate = indices.is_donatable() && + (indices.data_size() * indices.itemsize()) >= offset.itemsize(); + if (donate) { + offset.copy_shared_buffer(indices); + } else { + offset.set_data(mlx::core::rocm::malloc_async(offset.itemsize(), encoder)); + } + + encoder.add_temporary(offset); + encoder.set_input_array(indices); + encoder.set_output_array(offset); + + rocm::hip_array strides_arg = {}; + rocm::hip_array axes_arg = {}; + for (int i = 0; i < static_cast(strides.size()); ++i) { + strides_arg.data_[i] = static_cast(strides[i]); + } + for (int i = 0; i < static_cast(axes.size()); ++i) { + axes_arg.data_[i] = static_cast(axes[i]); + } + + // Get kernel before launching to avoid any potential issues + auto kernel = mod.get_kernel(kernel_name); + + // Get GPU pointers before lambda to avoid synchronization issues + const void* indices_ptr = gpu_ptr(indices); + void* offset_ptr = gpu_ptr(offset); + + encoder.launch_kernel( + [kernel, indices_ptr, offset_ptr, strides_arg, axes_arg]( + hipStream_t stream) { + const void* arg0 = indices_ptr; + void* arg1 = offset_ptr; + rocm::hip_array arg2 = strides_arg; + rocm::hip_array arg3 = axes_arg; + void* args[] = {&arg0, &arg1, &arg2, &arg3}; + (void)hipModuleLaunchKernel( + kernel, 1, 1, 1, 1, 1, 1, 0, stream, args, nullptr); + }); + + return offset; +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/softmax.hip b/mlx/backend/rocm/softmax.hip new file mode 100644 index 0000000000..16d7bb0170 --- /dev/null +++ b/mlx/backend/rocm/softmax.hip @@ -0,0 +1,372 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/device/cast_op.hpp" +#include "mlx/backend/rocm/device/fp16_math.hpp" +#include "mlx/backend/rocm/device/utils.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include + +#include + +namespace mlx::core { + +namespace rocm { + +// Type-safe shuffle wrappers for __shfl_xor +template +__device__ __forceinline__ T shfl_xor_safe(T val, int lane_mask) { + return __shfl_xor(val, lane_mask); +} + +// Specialization for hip_bfloat16 - __shfl_xor returns float +template <> +__device__ __forceinline__ hip_bfloat16 shfl_xor_safe(hip_bfloat16 val, int lane_mask) { + return hip_bfloat16(__shfl_xor(static_cast(val), lane_mask)); +} + +// Specialization for __half - __shfl_xor returns float +template <> +__device__ __forceinline__ __half shfl_xor_safe(__half val, int lane_mask) { + return __half(__shfl_xor(__half2float(val), lane_mask)); +} + +template +inline __device__ T softmax_exp(T x) { + // Softmax doesn't need high precision exponential cause x is gonna be in + // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). + if constexpr (std::is_same_v) { + return __expf(x); + } else if constexpr (std::is_same_v) { + return exp(x); + } else { + return T(__expf(static_cast(x))); + } +} + +// Warp reduce for max using shuffle +template +__device__ T warp_reduce_max(T val) { +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + T other = shfl_xor_safe(val, offset); + val = val > other ? val : other; + } + return val; +} + +// Warp reduce for sum using shuffle +template +__device__ T warp_reduce_sum(T val) { +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + T other = shfl_xor_safe(val, offset); + val = val + other; + } + return val; +} + +// Optimized softmax kernel using online normalizer calculation +// Reference: https://github.com/NVIDIA/online-softmax +template +__global__ void softmax_kernel(const T* in, T* out, int axis_size) { + int row = blockIdx.x; + int thread_rank = threadIdx.x; + int lane = thread_rank % WARP_SIZE; + int warp_id = thread_rank / WARP_SIZE; + int num_warps = BLOCK_DIM / WARP_SIZE; + + in += row * axis_size; + out += row * axis_size; + + // Online softmax: compute max and normalizer in a single pass + AccT prevmax; + AccT maxval = Limits::finite_min(); + AccT normalizer = AccT(0); + + int num_iterations = (axis_size + BLOCK_DIM * N_READS - 1) / (BLOCK_DIM * N_READS); + for (int r = 0; r < num_iterations; ++r) { + int index = r * BLOCK_DIM + thread_rank; + + // Load values + AccT vals[N_READS]; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + int idx = index * N_READS + i; + vals[i] = (idx < axis_size) ? static_cast(in[idx]) : Limits::min(); + } + + // Update max + prevmax = maxval; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + maxval = maxval > vals[i] ? maxval : vals[i]; + } + + // Online normalizer calculation + normalizer = normalizer * softmax_exp(prevmax - maxval); +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + normalizer = normalizer + softmax_exp(vals[i] - maxval); + } + } + + // First warp reduce + prevmax = maxval; + maxval = warp_reduce_max(maxval); + normalizer = normalizer * softmax_exp(prevmax - maxval); + normalizer = warp_reduce_sum(normalizer); + + __shared__ AccT local_max[WARP_SIZE]; + __shared__ AccT local_normalizer[WARP_SIZE]; + + // Write to shared memory and do second warp reduce + prevmax = maxval; + if (lane == 0) { + local_max[warp_id] = maxval; + } + __syncthreads(); + + maxval = (lane < num_warps) ? local_max[lane] : Limits::min(); + maxval = warp_reduce_max(maxval); + normalizer = normalizer * softmax_exp(prevmax - maxval); + + if (lane == 0) { + local_normalizer[warp_id] = normalizer; + } + __syncthreads(); + + normalizer = (lane < num_warps) ? local_normalizer[lane] : AccT(0); + normalizer = warp_reduce_sum(normalizer); + normalizer = AccT(1) / normalizer; + + // Write output + for (int r = 0; r < num_iterations; ++r) { + int index = r * BLOCK_DIM + thread_rank; + +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + int idx = index * N_READS + i; + if (idx < axis_size) { + AccT val = static_cast(in[idx]); + out[idx] = static_cast(softmax_exp(val - maxval) * normalizer); + } + } + } +} + +// Vectorized softmax kernel for better memory throughput +template +__global__ void softmax_kernel_vectorized(const T* in, T* out, int axis_size) { + int row = blockIdx.x; + int thread_rank = threadIdx.x; + int lane = thread_rank % WARP_SIZE; + int warp_id = thread_rank / WARP_SIZE; + int num_warps = BLOCK_DIM / WARP_SIZE; + + in += row * axis_size; + out += row * axis_size; + + // Online softmax: compute max and normalizer in a single pass + AccT prevmax; + AccT maxval = Limits::finite_min(); + AccT normalizer = AccT(0); + + int vec_size = axis_size / N_READS; + int num_iterations = (vec_size + BLOCK_DIM - 1) / BLOCK_DIM; + + for (int r = 0; r < num_iterations; ++r) { + int index = r * BLOCK_DIM + thread_rank; + + // Load values using vectorized load + AccT vals[N_READS]; + if (index < vec_size) { + auto vec = load_vector(in, index); +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + vals[i] = static_cast(vec[i]); + } + } else { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + int idx = index * N_READS + i; + vals[i] = (idx < axis_size) ? static_cast(in[idx]) : Limits::min(); + } + } + + // Update max + prevmax = maxval; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + maxval = maxval > vals[i] ? maxval : vals[i]; + } + + // Online normalizer calculation + normalizer = normalizer * softmax_exp(prevmax - maxval); +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + normalizer = normalizer + softmax_exp(vals[i] - maxval); + } + } + + // Handle remaining elements + int remaining_start = vec_size * N_READS; + for (int idx = remaining_start + thread_rank; idx < axis_size; idx += BLOCK_DIM) { + prevmax = maxval; + AccT val = static_cast(in[idx]); + maxval = maxval > val ? maxval : val; + normalizer = normalizer * softmax_exp(prevmax - maxval); + normalizer = normalizer + softmax_exp(val - maxval); + } + + // First warp reduce + prevmax = maxval; + maxval = warp_reduce_max(maxval); + normalizer = normalizer * softmax_exp(prevmax - maxval); + normalizer = warp_reduce_sum(normalizer); + + __shared__ AccT local_max[WARP_SIZE]; + __shared__ AccT local_normalizer[WARP_SIZE]; + + // Write to shared memory and do second warp reduce + prevmax = maxval; + if (lane == 0) { + local_max[warp_id] = maxval; + } + __syncthreads(); + + maxval = (lane < num_warps) ? local_max[lane] : Limits::min(); + maxval = warp_reduce_max(maxval); + normalizer = normalizer * softmax_exp(prevmax - maxval); + + if (lane == 0) { + local_normalizer[warp_id] = normalizer; + } + __syncthreads(); + + normalizer = (lane < num_warps) ? local_normalizer[lane] : AccT(0); + normalizer = warp_reduce_sum(normalizer); + normalizer = AccT(1) / normalizer; + + // Write output using vectorized store + for (int r = 0; r < num_iterations; ++r) { + int index = r * BLOCK_DIM + thread_rank; + + if (index < vec_size) { + auto vec = load_vector(in, index); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + AccT val = static_cast(vec[i]); + out_vec[i] = static_cast(softmax_exp(val - maxval) * normalizer); + } + store_vector(out, index, out_vec); + } else { +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + int idx = index * N_READS + i; + if (idx < axis_size) { + AccT val = static_cast(in[idx]); + out[idx] = static_cast(softmax_exp(val - maxval) * normalizer); + } + } + } + } + + // Handle remaining elements + for (int idx = remaining_start + thread_rank; idx < axis_size; idx += BLOCK_DIM) { + AccT val = static_cast(in[idx]); + out[idx] = static_cast(softmax_exp(val - maxval) * normalizer); + } +} + +} // namespace rocm + +void Softmax::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& s = stream(); + + // Make sure that the last dimension is contiguous. + auto set_output = [&s, &out](const array& x) { + if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { + if (x.is_donatable()) { + out.copy_shared_buffer(x); + } else { + out.set_data( + allocator::malloc(x.data_size() * x.itemsize()), + x.data_size(), + x.strides(), + x.flags()); + } + return x; + } else { + array x_copy = contiguous_copy_gpu(x, s); + out.copy_shared_buffer(x_copy); + return x_copy; + } + }; + + array in = set_output(inputs[0]); + bool precise = in.dtype() != float32 && precise_; + + int axis_size = in.shape().back(); + int n_rows = in.data_size() / axis_size; + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + // Choose block size based on axis size + auto launch_softmax = [&](auto type_tag, auto acc_type_tag) { + using T = typename decltype(type_tag)::type; + using AccT = typename decltype(acc_type_tag)::type; + + constexpr int N_READS = 4; + + // Choose block size based on axis size for better occupancy + if (axis_size <= 256 * N_READS) { + encoder.add_kernel_node( + &rocm::softmax_kernel, + dim3(n_rows), dim3(256), 0, + gpu_ptr(in), gpu_ptr(out), axis_size); + } else if (axis_size <= 512 * N_READS) { + encoder.add_kernel_node( + &rocm::softmax_kernel, + dim3(n_rows), dim3(512), 0, + gpu_ptr(in), gpu_ptr(out), axis_size); + } else { + encoder.add_kernel_node( + &rocm::softmax_kernel, + dim3(n_rows), dim3(1024), 0, + gpu_ptr(in), gpu_ptr(out), axis_size); + } + }; + + switch (out.dtype()) { + case float32: + launch_softmax(type_identity{}, type_identity{}); + break; + case float16: + if (precise) { + launch_softmax(type_identity<__half>{}, type_identity{}); + } else { + launch_softmax(type_identity<__half>{}, type_identity<__half>{}); + } + break; + case bfloat16: + if (precise) { + launch_softmax(type_identity{}, type_identity{}); + } else { + launch_softmax(type_identity{}, type_identity{}); + } + break; + default: + throw std::runtime_error("Unsupported type for softmax"); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/sort.hip b/mlx/backend/rocm/sort.hip new file mode 100644 index 0000000000..ac4cc9e5ea --- /dev/null +++ b/mlx/backend/rocm/sort.hip @@ -0,0 +1,657 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include + +// Workaround: rocprim headers use placement new in __device__ code, +// which requires __device__ overloads of operator new/delete. +// ROCm 7.12+ (clang 22+) already provides these via cuda_wrappers/new. +#ifdef __HIP_DEVICE_COMPILE__ +#ifndef __CLANG_CUDA_WRAPPERS_NEW +__device__ inline void* operator new(size_t, void* p) noexcept { return p; } +__device__ inline void* operator new[](size_t, void* p) noexcept { return p; } +__device__ inline void operator delete(void*, void*) noexcept {} +__device__ inline void operator delete[](void*, void*) noexcept {} +#endif +#endif + +#include +#include +#include + +namespace mlx::core { + +constexpr int N_PER_THREAD = 8; + +namespace rocm { + +template +__device__ __forceinline__ T nan_value(); + +template <> +__device__ __forceinline__ float nan_value() { + return __builtin_nanf(""); +} + +template <> +__device__ __forceinline__ double nan_value() { + return __builtin_nan(""); +} + +template <> +__device__ __forceinline__ _Float16 nan_value<_Float16>() { + return static_cast<_Float16>(__builtin_nanf("")); +} + +// __half may or may not be the same as _Float16 depending on HIP version. +// Provide explicit specialization via __float2half conversion. +template <> +__device__ __forceinline__ __half nan_value<__half>() { + return __float2half(__builtin_nanf("")); +} + +template <> +__device__ __forceinline__ hip_bfloat16 nan_value() { + return hip_bfloat16(__builtin_nanf("")); +} + +// Helper trait: true for all floating-point types including __half and hip_bfloat16. +// std::is_floating_point_v is false for __half and hip_bfloat16, which would +// cause NaN handling to be skipped and produce incorrect sort results. +template +inline constexpr bool is_sort_floating_v = + std::is_floating_point_v || + std::is_same_v || + std::is_same_v; + +template +struct InitValue { + __device__ __forceinline__ static T value() { + return rocm::Limits::max(); + } +}; + +template +struct InitValue>> { + __device__ __forceinline__ static T value() { + return nan_value(); + } +}; + +template +__device__ __forceinline__ void thread_swap(T& a, T& b) { + T w = a; + a = b; + b = w; +} + +template +struct LessThan { + __device__ __forceinline__ static T init() { + return InitValue::value(); + } + + __device__ __forceinline__ bool operator()(T a, T b) const { + if constexpr (is_sort_floating_v) { + bool an = isnan(static_cast(a)); + bool bn = isnan(static_cast(b)); + if (an | bn) { + return (!an) & bn; + } + } + return a < b; + } +}; + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + int N_PER_THREAD, + typename CompareOp> +struct ThreadSort { + __device__ __forceinline__ static void sort( + ValT (&vals)[N_PER_THREAD], + IdxT (&idxs)[N_PER_THREAD]) { + CompareOp op; +#pragma unroll + for (int i = 0; i < N_PER_THREAD; ++i) { +#pragma unroll + for (int j = i & 1; j < N_PER_THREAD - 1; j += 2) { + if (op(vals[j + 1], vals[j])) { + thread_swap(vals[j + 1], vals[j]); + if constexpr (ARG_SORT) { + thread_swap(idxs[j + 1], idxs[j]); + } + } + } + } + } +}; + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + int BLOCK_THREADS, + int N_PER_THREAD, + typename CompareOp> +struct BlockMergeSort { + using thread_sort_t = + ThreadSort; + + __device__ __forceinline__ static int merge_partition( + const ValT* As, + const ValT* Bs, + int A_sz, + int B_sz, + int sort_md) { + CompareOp op; + + int A_st = max(0, sort_md - B_sz); + int A_ed = min(sort_md, A_sz); + + while (A_st < A_ed) { + int md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if (op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + } + + __device__ __forceinline__ static void merge_step( + const ValT* As, + const ValT* Bs, + const IdxT* As_idx, + const IdxT* Bs_idx, + int A_sz, + int B_sz, + ValT (&vals)[N_PER_THREAD], + IdxT (&idxs)[N_PER_THREAD]) { + CompareOp op; + int a_idx = 0; + int b_idx = 0; + +#pragma unroll + for (int i = 0; i < N_PER_THREAD; ++i) { + auto a = (a_idx < A_sz) ? As[a_idx] : ValT(CompareOp::init()); + auto b = (b_idx < B_sz) ? Bs[b_idx] : ValT(CompareOp::init()); + bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); + + vals[i] = pred ? b : a; + if constexpr (ARG_SORT) { + if (pred) { + idxs[i] = Bs_idx[b_idx]; + } else { + idxs[i] = (a_idx < A_sz) ? As_idx[a_idx] : IdxT(0); + } + } + + b_idx += int(pred); + a_idx += int(!pred); + } + } + + __device__ __forceinline__ static void + sort(ValT* tgp_vals, IdxT* tgp_idxs, int size_sorted_axis) { + int idx = threadIdx.x * N_PER_THREAD; + + ValT thread_vals[N_PER_THREAD]; + IdxT thread_idxs[N_PER_THREAD]; +#pragma unroll + for (int i = 0; i < N_PER_THREAD; ++i) { + thread_vals[i] = tgp_vals[idx + i]; + if constexpr (ARG_SORT) { + thread_idxs[i] = tgp_idxs[idx + i]; + } + } + + if (idx < size_sorted_axis) { + thread_sort_t::sort(thread_vals, thread_idxs); + } + + for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; + merge_threads *= 2) { + __syncthreads(); +#pragma unroll + for (int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if constexpr (ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + __syncthreads(); + + int merge_group = threadIdx.x / merge_threads; + int merge_lane = threadIdx.x % merge_threads; + + int sort_sz = N_PER_THREAD * merge_threads; + int sort_st = N_PER_THREAD * merge_threads * merge_group; + + int A_st = sort_st; + int A_ed = sort_st + sort_sz / 2; + int B_st = sort_st + sort_sz / 2; + int B_ed = sort_st + sort_sz; + + const ValT* As = tgp_vals + A_st; + const ValT* Bs = tgp_vals + B_st; + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + int sort_md = N_PER_THREAD * merge_lane; + int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md); + + As += partition; + Bs += sort_md - partition; + + A_sz -= partition; + B_sz -= sort_md - partition; + + const IdxT* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr; + const IdxT* Bs_idx = + ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr; + + merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs); + } + + __syncthreads(); +#pragma unroll + for (int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if constexpr (ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + } +}; + +template < + typename T, + typename U, + bool ARG_SORT, + int BLOCK_THREADS, + int N_PER_THREAD, + typename CompareOp = LessThan> +struct KernelMergeSort { + using ValT = T; + using IdxT = uint32_t; + using block_merge_sort_t = BlockMergeSort< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + static constexpr int N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + __device__ __forceinline__ static void block_sort( + const T* inp, + U* out, + int size_sorted_axis, + int64_t in_stride_sorted_axis, + int64_t out_stride_sorted_axis, + int64_t in_stride_segment_axis, + int64_t out_stride_segment_axis, + ValT* tgp_vals, + IdxT* tgp_idxs) { + inp += blockIdx.y * in_stride_segment_axis; + out += blockIdx.y * out_stride_segment_axis; + + for (int i = threadIdx.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis] + : ValT(CompareOp::init()); + if constexpr (ARG_SORT) { + tgp_idxs[i] = i; + } + } + + __syncthreads(); + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis); + __syncthreads(); + + int out_limit = min(size_sorted_axis, N_PER_BLOCK); + for (int i = threadIdx.x; i < out_limit; i += BLOCK_THREADS) { + if constexpr (ARG_SORT) { + out[i * out_stride_sorted_axis] = tgp_idxs[i]; + } else { + out[i * out_stride_sorted_axis] = tgp_vals[i]; + } + } + } +}; + +template < + typename T, + typename U, + bool ARG_SORT, + int BLOCK_THREADS, + int N_PER_THREAD> +__global__ void block_sort_kernel( + const T* inp, + U* out, + int size_sorted_axis, + int64_t in_stride_sorted_axis, + int64_t out_stride_sorted_axis, + int64_t in_stride_segment_axis, + int64_t out_stride_segment_axis) { + using sort_kernel = + KernelMergeSort; + using ValT = typename sort_kernel::ValT; + using IdxT = typename sort_kernel::IdxT; + + if constexpr (ARG_SORT) { + __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + tgp_vals, + tgp_idxs); + } else { + __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + tgp_vals, + nullptr); + } +} + +// Simple iota kernel: fills output[i] = i for i in [0, n). +// Used to initialize index arrays on-device instead of copying from host. +__global__ void iota_kernel(uint32_t* out, int n) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + out[i] = static_cast(i); + } +} + +} // namespace rocm + +namespace { + +void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { + array out = out_; + auto& encoder = rocm::get_command_encoder(s); + if (axis < 0) { + axis += in.ndim(); + } + + int size_sorted_axis = in.shape(axis); + int n_rows = in.size() / size_sorted_axis; + int last_dim = in.ndim() - 1; + + // If we are not sorting the innermost dimension of a contiguous array, + // transpose and make a copy. + bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1; + if (!is_segmented_sort) { + array trans = swapaxes_in_eval(in, axis, last_dim); + in = contiguous_copy_gpu(trans, s); + encoder.add_temporary(in); + out = array(mlx::core::rocm::malloc_async(out.nbytes(), encoder), in.shape(), out.dtype()); + encoder.add_temporary(out); + } else { + out.set_data( + mlx::core::rocm::malloc_async(in.data_size() * out.itemsize(), encoder), + in.data_size(), + in.strides(), + in.flags()); + } + + encoder.set_input_array(in); + encoder.set_output_array(out); + + auto& stream = encoder.stream(); + + // For large arrays that exceed the block sort capacity (512 threads * 8 items = 4096), + // use rocprim radix sort which handles arbitrary sizes correctly. + constexpr int tn = N_PER_THREAD; + constexpr int max_block_sort_size = 512 * tn; // 4096 + + if (size_sorted_axis > max_block_sort_size) { + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + if constexpr (!std::is_same_v) { + using ValT = hip_type_t; + + encoder.launch_kernel([&](hipStream_t hip_stream) { + int N = size_sorted_axis; + + if (argsort) { + // Allocate all temp buffers once, outside the row loop. + uint32_t* indices_in = nullptr; + uint32_t* indices_out = nullptr; + ValT* vals_tmp = nullptr; + ValT* vals_sorted = nullptr; + CHECK_HIP_ERROR(hipMalloc(&indices_in, N * sizeof(uint32_t))); + CHECK_HIP_ERROR(hipMalloc(&indices_out, N * sizeof(uint32_t))); + CHECK_HIP_ERROR(hipMalloc(&vals_tmp, N * sizeof(ValT))); + CHECK_HIP_ERROR(hipMalloc(&vals_sorted, N * sizeof(ValT))); + + // Query temp storage size (same for all rows with same N). + size_t temp_bytes = 0; + rocprim::radix_sort_pairs( + nullptr, temp_bytes, + vals_tmp, vals_sorted, + indices_in, indices_out, + N, 0, sizeof(ValT) * 8, hip_stream); + + void* temp_storage = nullptr; + CHECK_HIP_ERROR(hipMalloc(&temp_storage, temp_bytes)); + + // Initialize iota indices on device (avoids host vector + memcpy). + { + int block = 256; + int grid = (N + block - 1) / block; + hipLaunchKernelGGL( + rocm::iota_kernel, dim3(grid), dim3(block), 0, hip_stream, + indices_in, N); + } + + for (int row = 0; row < n_rows; ++row) { + const ValT* in_row = gpu_ptr(in) + row * N; + + // Copy input values to mutable buffer for rocprim. + CHECK_HIP_ERROR(hipMemcpyAsync(vals_tmp, in_row, + N * sizeof(ValT), hipMemcpyDeviceToDevice, hip_stream)); + + // Re-initialize indices for each row (iota is idempotent so + // we can re-use the same buffer if we reset it). + if (row > 0) { + hipLaunchKernelGGL( + rocm::iota_kernel, dim3((N + 255) / 256), dim3(256), + 0, hip_stream, indices_in, N); + } + + rocprim::radix_sort_pairs( + temp_storage, temp_bytes, + vals_tmp, vals_sorted, + indices_in, indices_out, + N, 0, sizeof(ValT) * 8, hip_stream); + + // Copy result indices to output. + uint32_t* out_row = gpu_ptr(out) + row * N; + CHECK_HIP_ERROR(hipMemcpyAsync(out_row, indices_out, + N * sizeof(uint32_t), hipMemcpyDeviceToDevice, hip_stream)); + } + + CHECK_HIP_ERROR(hipFree(indices_in)); + CHECK_HIP_ERROR(hipFree(indices_out)); + CHECK_HIP_ERROR(hipFree(vals_tmp)); + CHECK_HIP_ERROR(hipFree(vals_sorted)); + CHECK_HIP_ERROR(hipFree(temp_storage)); + } else { + // Sort values only -- allocate once outside loop. + ValT* vals_in = nullptr; + ValT* vals_out_buf = nullptr; + CHECK_HIP_ERROR(hipMalloc(&vals_in, N * sizeof(ValT))); + CHECK_HIP_ERROR(hipMalloc(&vals_out_buf, N * sizeof(ValT))); + + size_t temp_bytes = 0; + rocprim::radix_sort_keys( + nullptr, temp_bytes, + vals_in, vals_out_buf, + N, 0, sizeof(ValT) * 8, hip_stream); + + void* temp_storage = nullptr; + CHECK_HIP_ERROR(hipMalloc(&temp_storage, temp_bytes)); + + for (int row = 0; row < n_rows; ++row) { + const ValT* in_row = gpu_ptr(in) + row * N; + + CHECK_HIP_ERROR(hipMemcpyAsync(vals_in, in_row, + N * sizeof(ValT), hipMemcpyDeviceToDevice, hip_stream)); + + rocprim::radix_sort_keys( + temp_storage, temp_bytes, + vals_in, vals_out_buf, + N, 0, sizeof(ValT) * 8, hip_stream); + + ValT* out_row = gpu_ptr(out) + row * N; + CHECK_HIP_ERROR(hipMemcpyAsync(out_row, vals_out_buf, + N * sizeof(ValT), hipMemcpyDeviceToDevice, hip_stream)); + } + + CHECK_HIP_ERROR(hipFree(vals_in)); + CHECK_HIP_ERROR(hipFree(vals_out_buf)); + CHECK_HIP_ERROR(hipFree(temp_storage)); + } + }); + } else { + throw std::runtime_error( + "ROCm backend does not support sorting complex numbers"); + } + }); + + if (!is_segmented_sort) { + copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s); + } + return; + } + + // Determine block size for small-array block sort + int potential_bn = (size_sorted_axis + tn - 1) / tn; + int bn; + if (potential_bn > 256) { + bn = 512; + } else if (potential_bn > 128) { + bn = 256; + } else if (potential_bn > 64) { + bn = 128; + } else if (potential_bn > 32) { + bn = 64; + } else { + bn = 32; + } + + if (bn == 512 && size_of(in.dtype()) > 4) { + bn = 256; + } + + int64_t in_stride_sorted = 1; // After transpose, always 1 + int64_t out_stride_sorted = 1; + int64_t in_stride_segment = size_sorted_axis; + int64_t out_stride_segment = size_sorted_axis; + + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + if constexpr (!std::is_same_v) { + using ValT = hip_type_t; + + dim3 grid(1, n_rows, 1); + + // Helper to add kernel node with specific template parameters + auto launch_sort = [&](auto argsort_tag, auto block_tag) { + constexpr bool ARG_SORT = decltype(argsort_tag)::value; + constexpr int BLOCK_THREADS = decltype(block_tag)::value; + using OutT = std::conditional_t; + + encoder.add_kernel_node( + &rocm::block_sort_kernel, + grid, + dim3(BLOCK_THREADS, 1, 1), + 0, + gpu_ptr(in), + gpu_ptr(out), + size_sorted_axis, + in_stride_sorted, + out_stride_sorted, + in_stride_segment, + out_stride_segment); + }; + + // Dispatch based on argsort and block size + if (argsort) { + switch (bn) { + case 32: launch_sort(std::true_type{}, std::integral_constant{}); break; + case 64: launch_sort(std::true_type{}, std::integral_constant{}); break; + case 128: launch_sort(std::true_type{}, std::integral_constant{}); break; + case 256: launch_sort(std::true_type{}, std::integral_constant{}); break; + case 512: launch_sort(std::true_type{}, std::integral_constant{}); break; + } + } else { + switch (bn) { + case 32: launch_sort(std::false_type{}, std::integral_constant{}); break; + case 64: launch_sort(std::false_type{}, std::integral_constant{}); break; + case 128: launch_sort(std::false_type{}, std::integral_constant{}); break; + case 256: launch_sort(std::false_type{}, std::integral_constant{}); break; + case 512: launch_sort(std::false_type{}, std::integral_constant{}); break; + } + } + } else { + throw std::runtime_error( + "ROCm backend does not support sorting complex numbers"); + } + }); + + if (!is_segmented_sort) { + // Swap the sorted axis back. + copy_gpu(swapaxes_in_eval(out, axis, last_dim), out_, CopyType::General, s); + } +} + +} // namespace + +void ArgSort::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + gpu_sort(stream(), inputs[0], out, axis_, true); +} + +void Sort::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + gpu_sort(stream(), inputs[0], out, axis_, false); +} + +void ArgPartition::eval_gpu(const std::vector& inputs, array& out) { + gpu_sort(stream(), inputs[0], out, axis_, true); +} + +void Partition::eval_gpu(const std::vector& inputs, array& out) { + gpu_sort(stream(), inputs[0], out, axis_, false); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/ternary.hip b/mlx/backend/rocm/ternary.hip new file mode 100644 index 0000000000..7c090f6176 --- /dev/null +++ b/mlx/backend/rocm/ternary.hip @@ -0,0 +1,221 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/ternary.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/device/ternary_ops.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include + +namespace mlx::core { + +namespace rocm { + +template +__global__ void +ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + IdxT stride = blockDim.x * gridDim.x; + + for (IdxT i = index * N_READS; i < size; i += stride * N_READS) { + if (i + N_READS <= size) { + #pragma unroll + for (int j = 0; j < N_READS; ++j) { + out[i + j] = Op{}(a[i + j], b[i + j], c[i + j]); + } + } else { + for (IdxT j = i; j < size; ++j) { + out[j] = Op{}(a[j], b[j], c[j]); + } + } + } +} + +template +__global__ void ternary_g( + const bool* a, + const T* b, + const T* c, + T* out, + IdxT size_rest, + hip_array shape, + hip_array a_strides, + hip_array b_strides, + hip_array c_strides, + int ndim) { + IdxT index_rest = blockIdx.y * blockDim.y + threadIdx.y; + if (index_rest >= size_rest) { + return; + } + + auto shape_x = shape[ndim - 1]; + auto a_stride_x = a_strides[ndim - 1]; + auto b_stride_x = b_strides[ndim - 1]; + auto c_stride_x = c_strides[ndim - 1]; + IdxT index_x = blockIdx.x * blockDim.x + threadIdx.x; + + // Compute base offsets using elem_to_loc style calculation + IdxT elem = index_rest * shape_x; + IdxT a_offset = 0; + IdxT b_offset = 0; + IdxT c_offset = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + IdxT coord = elem % shape[i]; + elem /= shape[i]; + a_offset += coord * a_strides[i]; + b_offset += coord * b_strides[i]; + c_offset += coord * c_strides[i]; + } + + IdxT out_offset = index_rest * shape_x; + + for (IdxT i = index_x * N_READS; i < shape_x; i += blockDim.x * gridDim.x * N_READS) { + if (i + N_READS <= shape_x) { + #pragma unroll + for (int j = 0; j < N_READS; ++j) { + bool cond = a[a_offset + (i + j) * a_stride_x]; + T b_val = b[b_offset + (i + j) * b_stride_x]; + T c_val = c[c_offset + (i + j) * c_stride_x]; + out[out_offset + i + j] = Op{}(cond, b_val, c_val); + } + } else { + for (IdxT j = i; j < shape_x; ++j) { + bool cond = a[a_offset + j * a_stride_x]; + T b_val = b[b_offset + j * b_stride_x]; + T c_val = c[c_offset + j * c_stride_x]; + out[out_offset + j] = Op{}(cond, b_val, c_val); + } + } + } +} + +} // namespace rocm + +template +void ternary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const Stream& s) { + const auto& a = inputs[0]; + const auto& b = inputs[1]; + const auto& c = inputs[2]; + + if (out.size() == 0) { + return; + } + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(c); + encoder.set_output_array(out); + + constexpr int N_READS = 4; + int block_size = 256; + + auto topt = get_ternary_op_type(a, b, c); + + dispatch_all_types(out.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + using DType = hip_type_t; + + if (topt == TernaryOpType::VectorVectorVector || + topt == TernaryOpType::ScalarScalarScalar) { + // Contiguous case - use ternary_v + auto size = out.data_size(); + int num_blocks = (size + block_size * N_READS - 1) / (block_size * N_READS); + num_blocks = std::min(num_blocks, 65535); + + int64_t size_arg = static_cast(size); + encoder.add_kernel_node( + &rocm::ternary_v, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(a), gpu_ptr(b), gpu_ptr(c), + gpu_ptr(out), size_arg); + } else { + // General case - use ternary_g with strided access + Shape shape_vec; + std::vector strides_vec; + std::tie(shape_vec, strides_vec) = collapse_contiguous_dims(a, b, c, out); + auto& a_strides_vec = strides_vec[0]; + auto& b_strides_vec = strides_vec[1]; + auto& c_strides_vec = strides_vec[2]; + int ndim = shape_vec.size(); + + rocm::hip_array shape_arg = {}; + rocm::hip_array a_strides_arg = {}; + rocm::hip_array b_strides_arg = {}; + rocm::hip_array c_strides_arg = {}; + for (int i = 0; i < ndim; i++) { + shape_arg.data_[i] = static_cast(shape_vec[i]); + a_strides_arg.data_[i] = a_strides_vec[i]; + b_strides_arg.data_[i] = b_strides_vec[i]; + c_strides_arg.data_[i] = c_strides_vec[i]; + } + + int dim0 = ndim > 0 ? shape_vec.back() : 1; + size_t rest = out.size() / dim0; + + int work_per_thread = (dim0 >= 4) ? 4 : 1; + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + + int block_x = std::min(dim0, 32); + int block_y = std::min(static_cast(rest), 256 / block_x); + int num_blocks_x = (dim0 + block_x - 1) / block_x; + int num_blocks_y = (rest + block_y - 1) / block_y; + + int64_t rest_arg = static_cast(rest); + if (work_per_thread == 4) { + encoder.add_kernel_node( + &rocm::ternary_g, + dim3(num_blocks_x, num_blocks_y), dim3(block_x, block_y), 0, + gpu_ptr(a), gpu_ptr(b), gpu_ptr(c), + gpu_ptr(out), + rest_arg, + shape_arg, + a_strides_arg, + b_strides_arg, + c_strides_arg, + ndim); + } else { + encoder.add_kernel_node( + &rocm::ternary_g, + dim3(num_blocks_x, num_blocks_y), dim3(block_x, block_y), 0, + gpu_ptr(a), gpu_ptr(b), gpu_ptr(c), + gpu_ptr(out), + rest_arg, + shape_arg, + a_strides_arg, + b_strides_arg, + c_strides_arg, + ndim); + } + } + }); +} + +template +void ternary_op_gpu( + const std::vector& inputs, + array& out, + const Stream& s) { + auto& a = inputs[0]; + auto& b = inputs[1]; + auto& c = inputs[2]; + auto topt = get_ternary_op_type(a, b, c); + set_ternary_op_output_data(a, b, c, out, topt); + ternary_op_gpu_inplace(inputs, out, s); +} + +void Select::eval_gpu(const std::vector& inputs, array& out) { + auto& s = stream(); + ternary_op_gpu(inputs, out, s); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/unary.hip b/mlx/backend/rocm/unary.hip new file mode 100644 index 0000000000..aaa375f7e6 --- /dev/null +++ b/mlx/backend/rocm/unary.hip @@ -0,0 +1,334 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/unary.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/rocm/allocator.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/backend/rocm/device/unary_ops.hpp" +#include "mlx/backend/rocm/kernel_utils.hpp" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace rocm { + +namespace cg = cooperative_groups; + +template +__global__ void unary_v(const In* in, Out* out, IdxT size) { + IdxT index = blockIdx.x * blockDim.x + threadIdx.x; + IdxT stride = blockDim.x * gridDim.x; + + for (IdxT i = index * N_READS; i < size; i += stride * N_READS) { + if (i + N_READS <= size) { + #pragma unroll + for (int j = 0; j < N_READS; ++j) { + out[i + j] = Op{}(in[i + j]); + } + } else { + for (IdxT j = i; j < size; ++j) { + out[j] = Op{}(in[j]); + } + } + } +} + +template +__global__ void unary_g( + const In* in, + Out* out, + IdxT size_rest, + hip_array shape, + hip_array strides, + int ndim) { + IdxT index_rest = blockIdx.y * blockDim.y + threadIdx.y; + if (index_rest >= size_rest) { + return; + } + + auto shape_x = shape[ndim - 1]; + auto stride_x = strides[ndim - 1]; + + IdxT index_x = blockIdx.x * blockDim.x + threadIdx.x; + + // Compute base offset for this row using elem_to_loc style calculation + // elem = index_rest * shape_x gives us the linear element index for the start of this row + IdxT elem = index_rest * shape_x; + IdxT idx = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + idx += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + + // Process elements in this row + for (IdxT i = index_x * N_READS; i < shape_x; i += blockDim.x * gridDim.x * N_READS) { + if (i + N_READS <= shape_x) { + #pragma unroll + for (int j = 0; j < N_READS; ++j) { + IdxT in_idx = idx + (i + j) * stride_x; + out[shape_x * index_rest + i + j] = Op{}(in[in_idx]); + } + } else { + for (IdxT j = i; j < shape_x; ++j) { + IdxT in_idx = idx + j * stride_x; + out[shape_x * index_rest + j] = Op{}(in[in_idx]); + } + } + } +} + +template +constexpr bool supports_unary_op() { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v; + } + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && is_floating_v; + } + if constexpr (std::is_same_v) { + return std::is_same_v && std::is_integral_v && + !std::is_same_v; + } + if constexpr (std::is_same_v || std::is_same_v) { + return std::is_same_v && !is_complex_v; + } + if constexpr (std::is_same_v) { + return std::is_same_v && is_complex_v; + } + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && is_inexact_v; + } + if constexpr (std::is_same_v || std::is_same_v) { + return is_complex_v && std::is_same_v; + } + if constexpr (std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + return false; +} + +} // namespace rocm + +template +void unary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s) { + auto& in = inputs[0]; + if (in.size() == 0) { + return; + } + bool contig = in.flags().contiguous; + bool large; + if (!contig) { + large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; + } else { + large = in.data_size() > UINT32_MAX; + } + + auto& encoder = rocm::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + + // Dispatch based on input and output types + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + using InType = hip_type_t; + using OutType = hip_type_t; + + if constexpr (rocm::supports_unary_op()) { + if (contig) { + // Contiguous case - use unary_v + constexpr int N_READS = 4; + int block_size = 256; + auto size = out.data_size(); + int num_blocks = (size + block_size * N_READS - 1) / (block_size * N_READS); + num_blocks = std::min(num_blocks, 65535); + + if (large) { + encoder.add_kernel_node( + &rocm::unary_v, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(in), gpu_ptr(out), static_cast(size)); + } else { + encoder.add_kernel_node( + &rocm::unary_v, + dim3(num_blocks), dim3(block_size), 0, + gpu_ptr(in), gpu_ptr(out), static_cast(size)); + } + } else { + // Non-contiguous case - use unary_g with strided access + auto [shape_vec, strides_vec] = collapse_contiguous_dims(in); + int ndim = shape_vec.size(); + + rocm::hip_array shape_arg = {}; + rocm::hip_array strides_arg = {}; + for (int i = 0; i < ndim; i++) { + shape_arg.data_[i] = static_cast(shape_vec[i]); + strides_arg.data_[i] = strides_vec[i]; + } + + int dim0 = ndim > 0 ? shape_vec.back() : 1; + size_t rest = out.size() / dim0; + + constexpr int N_READS = 4; + int work_per_thread = (dim0 >= 4) ? 4 : 1; + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + + // Calculate block and grid dimensions + int block_x = std::min(dim0, 32); + int block_y = std::min(static_cast(rest), 256 / block_x); + int num_blocks_x = (dim0 + block_x - 1) / block_x; + int num_blocks_y = (rest + block_y - 1) / block_y; + + if (large) { + int64_t rest_arg = static_cast(rest); + if (work_per_thread == 4) { + encoder.add_kernel_node( + &rocm::unary_g, + dim3(num_blocks_x, num_blocks_y), dim3(block_x, block_y), 0, + gpu_ptr(in), gpu_ptr(out), + rest_arg, + shape_arg, + strides_arg, + ndim); + } else { + encoder.add_kernel_node( + &rocm::unary_g, + dim3(num_blocks_x, num_blocks_y), dim3(block_x, block_y), 0, + gpu_ptr(in), gpu_ptr(out), + rest_arg, + shape_arg, + strides_arg, + ndim); + } + } else { + int32_t rest_arg = static_cast(rest); + if (work_per_thread == 4) { + encoder.add_kernel_node( + &rocm::unary_g, + dim3(num_blocks_x, num_blocks_y), dim3(block_x, block_y), 0, + gpu_ptr(in), gpu_ptr(out), + rest_arg, + shape_arg, + strides_arg, + ndim); + } else { + encoder.add_kernel_node( + &rocm::unary_g, + dim3(num_blocks_x, num_blocks_y), dim3(block_x, block_y), 0, + gpu_ptr(in), gpu_ptr(out), + rest_arg, + shape_arg, + strides_arg, + ndim); + } + } + } + } + }); + }); +} + +template +void unary_op_gpu( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s) { + set_unary_output_data(inputs[0], out); + unary_op_gpu_inplace(inputs, out, op, s); +} + +#define UNARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + auto& s = out.primitive().stream(); \ + unary_op_gpu(inputs, out, name(), s); \ + } + +UNARY_GPU(Abs) +UNARY_GPU(ArcCos) +UNARY_GPU(ArcCosh) +UNARY_GPU(ArcSin) +UNARY_GPU(ArcSinh) +UNARY_GPU(ArcTan) +UNARY_GPU(ArcTanh) +UNARY_GPU(BitwiseInvert) +UNARY_GPU(Ceil) +UNARY_GPU(Conjugate) +UNARY_GPU(Cos) +UNARY_GPU(Cosh) +UNARY_GPU(Erf) +UNARY_GPU(ErfInv) +UNARY_GPU(Exp) +UNARY_GPU(Expm1) +UNARY_GPU(Floor) +UNARY_GPU(Imag) +UNARY_GPU(Log1p) +UNARY_GPU(LogicalNot) +UNARY_GPU(Negative) +UNARY_GPU(Real) +UNARY_GPU(Sigmoid) +UNARY_GPU(Sign) +UNARY_GPU(Sin) +UNARY_GPU(Sinh) +UNARY_GPU(Square) +UNARY_GPU(Tan) +UNARY_GPU(Tanh) + +void Log::eval_gpu(const std::vector& inputs, array& out) { + auto& s = out.primitive().stream(); + switch (base_) { + case Base::e: + unary_op_gpu(inputs, out, name(), s); + break; + case Base::two: + unary_op_gpu(inputs, out, name(), s); + break; + case Base::ten: + unary_op_gpu(inputs, out, name(), s); + break; + } +} + +void Round::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + auto& s = out.primitive().stream(); + if (issubdtype(in.dtype(), inexact)) { + unary_op_gpu(inputs, out, name(), s); + } else { + // No-op integer types + out.copy_shared_buffer(in); + } +} + +void Sqrt::eval_gpu(const std::vector& inputs, array& out) { + auto& s = out.primitive().stream(); + if (recip_) { + unary_op_gpu(inputs, out, "Rsqrt", s); + } else { + unary_op_gpu(inputs, out, "Sqrt", s); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/utils.cpp b/mlx/backend/rocm/utils.cpp new file mode 100644 index 0000000000..e20685a4d8 --- /dev/null +++ b/mlx/backend/rocm/utils.cpp @@ -0,0 +1,82 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/utils.h" +#include "mlx/backend/rocm/device.h" +#include "mlx/dtype_utils.h" + +#include + +namespace mlx::core { + +void check_rocblas_error(const char* name, rocblas_status err) { + if (err != rocblas_status_success) { + std::ostringstream oss; + oss << name << " failed with code: " << static_cast(err) << "."; + throw std::runtime_error(oss.str()); + } +} + +void check_hip_error(const char* name, hipError_t err) { + if (err != hipSuccess) { + std::ostringstream oss; + oss << name << " failed: " << hipGetErrorString(err); + throw std::runtime_error(oss.str()); + } +} + +const char* dtype_to_hip_type(const Dtype& dtype) { + switch (dtype) { + case bool_: + return "bool"; + case int8: + return "int8_t"; + case int16: + return "int16_t"; + case int32: + return "int32_t"; + case int64: + return "int64_t"; + case uint8: + return "uint8_t"; + case uint16: + return "uint16_t"; + case uint32: + return "uint32_t"; + case uint64: + return "uint64_t"; + case float16: + return "__half"; + case bfloat16: + return "hip_bfloat16"; + case float32: + return "float"; + case float64: + return "double"; + case complex64: + return "complex64_t"; + default: + return "unknown"; + } +} + +HipGraph::HipGraph(rocm::Device& device) { + device.make_current(); + CHECK_HIP_ERROR(hipGraphCreate(&handle_, 0)); +} + +void HipGraph::end_capture(hipStream_t stream) { + assert(handle_ == nullptr); + CHECK_HIP_ERROR(hipStreamEndCapture(stream, &handle_)); +} + +void HipGraphExec::instantiate(hipGraph_t graph) { + assert(handle_ == nullptr); + CHECK_HIP_ERROR(hipGraphInstantiate(&handle_, graph, nullptr, nullptr, 0)); +} + +HipStream::HipStream(rocm::Device& device) { + device.make_current(); + CHECK_HIP_ERROR(hipStreamCreateWithFlags(&handle_, hipStreamNonBlocking)); +} + +} // namespace mlx::core diff --git a/mlx/backend/rocm/utils.h b/mlx/backend/rocm/utils.h new file mode 100644 index 0000000000..b075b96187 --- /dev/null +++ b/mlx/backend/rocm/utils.h @@ -0,0 +1,87 @@ +// Copyright © 2025 Apple Inc. + +// This file include utilities that are used by C++ code (i.e. .cpp files). + +#pragma once + +#include +#include + +namespace mlx::core { + +namespace rocm { +class Device; +} + +struct Dtype; + +// Throw exception if the HIP API does not succeed. +void check_rocblas_error(const char* name, rocblas_status err); +void check_hip_error(const char* name, hipError_t err); + +// The macro version that prints the command that failed. +#define CHECK_ROCBLAS_ERROR(cmd) check_rocblas_error(#cmd, (cmd)) +#define CHECK_HIP_ERROR(cmd) check_hip_error(#cmd, (cmd)) + +// Convert Dtype to HIP C++ types. +const char* dtype_to_hip_type(const Dtype& dtype); + +// Base class for RAII managed HIP resources. +template +class HipHandle { + public: + HipHandle(Handle handle = nullptr) : handle_(handle) {} + + HipHandle(HipHandle&& other) : handle_(other.handle_) { + assert(this != &other); + other.handle_ = nullptr; + } + + ~HipHandle() { + reset(); + } + + HipHandle(const HipHandle&) = delete; + HipHandle& operator=(const HipHandle&) = delete; + + HipHandle& operator=(HipHandle&& other) { + assert(this != &other); + reset(); + std::swap(handle_, other.handle_); + return *this; + } + + void reset() { + if (handle_ != nullptr) { + CHECK_HIP_ERROR(Destroy(handle_)); + handle_ = nullptr; + } + } + + operator Handle() const { + return handle_; + } + + protected: + Handle handle_; +}; + +// Wrappers of HIP resources. +class HipGraph : public HipHandle { + public: + using HipHandle::HipHandle; + explicit HipGraph(rocm::Device& device); + void end_capture(hipStream_t stream); +}; + +class HipGraphExec : public HipHandle { + public: + void instantiate(hipGraph_t graph); +}; + +class HipStream : public HipHandle { + public: + explicit HipStream(rocm::Device& device); +}; + +} // namespace mlx::core diff --git a/mlx/backend/rocm/worker.cpp b/mlx/backend/rocm/worker.cpp new file mode 100644 index 0000000000..0fdba7d894 --- /dev/null +++ b/mlx/backend/rocm/worker.cpp @@ -0,0 +1,80 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/rocm/worker.h" +#include "mlx/backend/rocm/utils.h" + +namespace mlx::core::rocm { + +Worker::Worker(int device) : device_(device), worker_(&Worker::thread_fn, this) {} + +Worker::~Worker() { + { + std::lock_guard lock(mtx_); + stop_ = true; + } + cond_.notify_one(); + worker_.join(); +} + +void Worker::add_task(std::function task) { + pending_tasks_.push_back(std::move(task)); +} + +void Worker::signal(void* data) { + auto w = static_cast(data); + { + std::lock_guard lock(w->mtx_); + w->signaled_batch_++; + } + w->cond_.notify_one(); +} + +void Worker::commit(hipStream_t stream) { + // Move pending tasks into tasks + if (pending_tasks_.empty()) { + return; + } + { + std::lock_guard lock(mtx_); + // Move pending tasks into ready tasks + worker_tasks_[++committed_batch_] = std::move(pending_tasks_); + } + // Use hipLaunchHostFunc to signal when stream operations complete + (void)hipLaunchHostFunc(stream, signal, this); +} + +void Worker::thread_fn() { + // Bind this thread to the encoder's device before running any task. Completion + // handlers free temporaries / return buffers to the pool and may issue HIP + // calls; they must hit the same device the stream lives on, not the default + // device 0. Without this the discrete-GPU queue wedges on a multi-GPU host. + (void)hipSetDevice(device_); + uint64_t current_batch = 0; + while (!stop_) { + Tasks tasks; + { + std::unique_lock lk(mtx_); + cond_.wait(lk, [this, current_batch] { + return this->signaled_batch_ > current_batch || this->stop_; + }); + current_batch = signaled_batch_; + auto end = worker_tasks_.upper_bound(current_batch); + for (auto it = worker_tasks_.begin(); it != end; ++it) { + if (tasks.empty()) { + tasks = std::move(it->second); + } else { + std::move( + it->second.begin(), it->second.end(), std::back_inserter(tasks)); + } + } + worker_tasks_.erase(worker_tasks_.begin(), end); + } + // Make sure tasks are cleared before the next wait + for (size_t i = 0; i < tasks.size(); ++i) { + auto task = std::move(tasks[i]); + task(); + } + } +} + +} // namespace mlx::core::rocm diff --git a/mlx/backend/rocm/worker.h b/mlx/backend/rocm/worker.h new file mode 100644 index 0000000000..d4689b0fef --- /dev/null +++ b/mlx/backend/rocm/worker.h @@ -0,0 +1,63 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace mlx::core::rocm { + +// Forward declarations +class HipEvent; + +// Run tasks in worker thread, synchronized with HIP stream. +class Worker { + public: + explicit Worker(int device); + ~Worker(); + + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + + // Add a pending |task| that will run when consumed or committed. + void add_task(std::function task); + + // Inform worker thread to run current batches after kernels in |stream| + // finish running. + void commit(hipStream_t stream); + + private: + static void signal(void*); + + void thread_fn(); + std::mutex mtx_; + std::condition_variable cond_; + + uint64_t committed_batch_{0}; + uint64_t signaled_batch_{0}; + + bool stop_{false}; + + // The HIP device this worker's stream-completion callbacks run against. The + // worker thread must hipSetDevice(device_) before running any task: HIP's + // current device is per-thread and a freshly spawned thread defaults to device + // 0. Running device-1 stream callbacks/frees from a device-0-bound thread is a + // cross-device coupling that wedges the queue on a discrete GPU. + int device_{0}; + + // Tasks are put in |pending_tasks_| first, and then moved to + // |worker_tasks_| when end_batch() is called. + using Tasks = std::vector>; + Tasks pending_tasks_; + std::map worker_tasks_; + std::thread worker_; +}; + +} // namespace mlx::core::rocm diff --git a/mlx/device.cpp b/mlx/device.cpp index f0c868f21b..13695a47bb 100644 --- a/mlx/device.cpp +++ b/mlx/device.cpp @@ -6,10 +6,23 @@ #include "mlx/backend/gpu/device_info.h" #include "mlx/device.h" +#ifdef MLX_USE_ROCM +#include "mlx/backend/rocm/rocm.h" +#endif + namespace mlx::core { Device& mutable_default_device() { - static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu}; + Device::DeviceType default_type = Device::cpu; + if (gpu::is_available()) { + default_type = Device::gpu; + } +#ifdef MLX_USE_ROCM + else if (rocm::is_available()) { + default_type = Device::gpu; // ROCm devices use the generic gpu type + } +#endif + static Device default_device{default_type}; return default_device; } @@ -30,7 +43,12 @@ bool is_available(const Device& d) { case Device::cpu: return cpu::is_available() && (d.index < cpu::device_count()); case Device::gpu: +#ifdef MLX_USE_ROCM + return (gpu::is_available() || rocm::is_available()) && + (d.index < gpu::device_count()); +#else return gpu::is_available() && (d.index < gpu::device_count()); +#endif } // appease compiler return false; diff --git a/mlx/fast.cpp b/mlx/fast.cpp index a668fe9abd..92b8a68f3c 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -726,9 +726,11 @@ array scaled_dot_product_attention( auto k = inputs[1]; auto v = inputs[2]; if (n_repeats > 1) { - q = unflatten(q, 1, {n_kv_heads, n_repeats}, s); - k = expand_dims(k, 2, s); - v = expand_dims(v, 2, s); + // Avoid high-rank broadcasted matmul for GQA in the fallback path. + // Some backends are unstable with that layout; repeating k/v heads keeps + // the computation in standard 4D matmul form. + k = repeat(k, n_repeats, 1, s); + v = repeat(v, n_repeats, 1, s); } auto scores = matmul(q, swapaxes(k, -1, -2, s), s); if (has_arr_mask || do_causal) { @@ -747,14 +749,6 @@ array scaled_dot_product_attention( return inputs[3]; }; auto mask = make_or_fetch_mask(); - - if (n_repeats > 1 && mask.ndim() >= 3) { - if (mask.shape(-3) == 1) { - mask = expand_dims(mask, -3, s); - } else { - mask = unflatten(mask, -3, {n_kv_heads, n_repeats}, s); - } - } if (mask.dtype() == bool_) { scores = where( mask, scores, array(finfo(scores.dtype()).min, scores.dtype()), s); @@ -782,9 +776,6 @@ array scaled_dot_product_attention( scores = slice(scores, std::move(start), std::move(stop), s); } auto out = matmul(scores, v, s); - if (n_repeats > 1) { - out = flatten(out, 1, 2, s); - } return std::vector{out}; }; diff --git a/mlx/fast.h b/mlx/fast.h index 1183aba8fe..e91ba9ad81 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -86,6 +86,18 @@ MLX_API CustomKernelFunction cuda_kernel( bool ensure_row_contiguous = true, int shared_memory = 0); +MLX_API CustomKernelFunction hip_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header = "", + bool ensure_row_contiguous = true, + int shared_memory = 0, + // Output index -> input index to alias (output reuses the input's buffer, + // in place). Used for recurrent-state kernels under HIP-graph capture. + std::vector> output_input_aliases = {}); + MLX_API std::vector precompiled_cuda_kernel( const std::string& name, const std::string& compiled_source, diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..c8e7e50b77 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -375,7 +375,8 @@ class CustomKernel : public Primitive { std::optional init_value, std::vector scalar_arguments, bool is_precompiled, - int shared_memory) + int shared_memory, + std::vector> output_input_aliases = {}) : Primitive(stream), name_(std::move(name)), source_(std::move(source)), @@ -386,7 +387,8 @@ class CustomKernel : public Primitive { init_value_(init_value), scalar_arguments_(std::move(scalar_arguments)), is_precompiled_(is_precompiled), - shared_memory_(shared_memory) {} + shared_memory_(shared_memory), + output_input_aliases_(std::move(output_input_aliases)) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -422,6 +424,8 @@ class CustomKernel : public Primitive { std::vector scalar_arguments_; bool is_precompiled_; int shared_memory_; + // Output index -> input index whose buffer the output reuses in-place. + std::vector> output_input_aliases_; }; } // namespace mlx::core::fast diff --git a/mlx/ops.cpp b/mlx/ops.cpp index e4ce3d750f..a4d05b05fd 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4675,6 +4675,58 @@ array qqmm( inputs.push_back(*global_scale_w); } +#if defined(MLX_USE_ROCM) + if (stream.device == Device::gpu) { + auto xq = quantize(x, group_size, bits, mode, global_scale_x, stream); + auto xhat = dequantize( + xq[0], + xq[1], + std::nullopt, + group_size, + bits, + mode, + global_scale_x, + x.dtype(), + stream); + + auto what = [&]() { + if (w.dtype() == uint32) { + return dequantize( + w, + *scales_w, + std::nullopt, + group_size, + bits, + mode, + global_scale_w, + x.dtype(), + stream); + } + auto wq = quantize(w, group_size, bits, mode, global_scale_w, stream); + return dequantize( + wq[0], + wq[1], + std::nullopt, + group_size, + bits, + mode, + global_scale_w, + x.dtype(), + stream); + }(); + + auto out = matmul(xhat, swapaxes(what, -1, -2, stream), stream); + if (in_x.ndim() > 2) { + auto orig_shape = in_x.shape(); + orig_shape.pop_back(); + out = unflatten(out, 0, std::move(orig_shape), stream); + } else if (in_x.ndim() == 1) { + out = squeeze(out, 0, stream); + } + return out; + } +#endif + auto out_shape = inputs[0].shape(); out_shape.back() = w_outer_dims; auto out = array( @@ -4897,6 +4949,12 @@ std::vector fp_quantize( return {std::move(wq), std::move(scales)}; }; +#if defined(MLX_USE_ROCM) + if (s.device == Device::gpu) { + return fallback(inputs); + } +#endif + if (s.device == Device::gpu) { auto wq_shape = w.shape(); wq_shape.back() = w.shape(-1) * bits / 32; @@ -5162,6 +5220,21 @@ array fp_dequantize( return {reshape(multiply(out, scales, s), wshape, s)}; }; +#if defined(MLX_USE_ROCM) + if (s.device == Device::gpu) { + return dequantize( + w, + scales, + std::nullopt, + group_size, + bits, + quantization_mode_to_string(mode), + global_scale, + out_type, + Device::cpu); + } +#endif + if (s.device == Device::gpu) { auto out_shape = w.shape(); out_shape.back() = out_size; diff --git a/mlx/stream.cpp b/mlx/stream.cpp index 9f09596f90..34fdefdf40 100644 --- a/mlx/stream.cpp +++ b/mlx/stream.cpp @@ -45,7 +45,7 @@ Stream default_stream(Device d) { } auto& s = default_stream_storage(d); if (!s.has_value()) { - s = new_stream(d.type); + s = new_stream(d); } return s.value(); } diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 447271500b..f1b89c80d5 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -18,6 +18,7 @@ nanobind_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/rocm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/memory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/mlx_func.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp @@ -49,6 +50,7 @@ if(MLX_BUILD_PYTHON_STUBS) OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/__init__.pyi" "${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/cuda.pyi" + "${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/rocm.pyi" "${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/distributed.pyi" "${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/fast.pyi" "${CMAKE_CURRENT_SOURCE_DIR}/../mlx/core/fft.pyi" diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 1a43d89d9b..b0a4108c9a 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -529,6 +529,120 @@ void init_fast(nb::module_& parent_module) { assert mx.allclose(b, mx.exp(a)) )pbdoc"); + m.def( + "hip_kernel", + [](const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header, + bool ensure_row_contiguous, + int shared_mem) { + auto kernel = mx::fast::hip_kernel( + name, + input_names, + output_names, + source, + header, + ensure_row_contiguous, + shared_mem); + return nb::cpp_function( + PyCustomKernelFunction(std::move(kernel), "[hip_kernel]"), + nb::kw_only(), + "inputs"_a, + "output_shapes"_a, + "output_dtypes"_a, + "grid"_a, + "threadgroup"_a, + "template"_a = nb::none(), + "init_value"_a = nb::none(), + "verbose"_a = false, + "stream"_a = nb::none(), + nb::sig( + "def __call__(self, *, inputs: List[Union[scalar, array]], output_shapes: List[Sequence[int]], output_dtypes: List[Dtype], grid: tuple[int, int, int], threadgroup: tuple[int, int, int], template: Optional[List[Tuple[str, Union[bool, int, Dtype]]]] = None, init_value: Optional[float] = None, verbose: bool = false, stream: Union[None, Stream, Device] = None)"), + R"pbdoc( + Run the kernel. + + Args: + inputs (List[array]): The inputs passed to the HIP kernel. + output_shapes (List[Sequence[int]]): The list of shapes for each output in ``output_names``. + output_dtypes (List[Dtype]): The list of data types for each output in ``output_names``. + grid (tuple[int, int, int]): 3-tuple specifying the grid to launch the kernel with. + threadgroup (tuple[int, int, int]): 3-tuple specifying the threadgroup size to use. + template (List[Tuple[str, Union[bool, int, Dtype]]], optional): Template arguments. + These will be added as template arguments to the kernel definition. Default: ``None``. + init_value (float, optional): Optional value to use to initialize all of the output arrays. + By default, output arrays are uninitialized. Default: ``None``. + verbose (bool, optional): Whether to print the full generated source code of the kernel + when it is run. Default: ``False``. + stream (mx.stream, optional): Stream to run the kernel on. Default: ``None``. + + Returns: + List[array]: The list of output arrays.)pbdoc"); + }, + "name"_a, + "input_names"_a, + "output_names"_a, + "source"_a, + "header"_a = "", + "ensure_row_contiguous"_a = true, + "shared_memory"_a = 0, + R"pbdoc( + A jit-compiled custom HIP kernel defined from a source string. + + Args: + name (str): Name for the kernel. + input_names (List[str]): The parameter names of the inputs in the + function signature. + output_names (List[str]): The parameter names of the outputs in the + function signature. + source (str): Source code. This is the body of a function in HIP, + the function signature will be automatically generated. + header (str): Header source code to include before the main function. + Useful for helper functions or includes that should live outside of + the main function body. + ensure_row_contiguous (bool): Whether to ensure the inputs are row contiguous + before the kernel runs. Default: ``True``. + shared_memory (int): The dynamic shared memory to request for the + kernel. A value of 0 means no dynamic shared memory. Default: ``0``. + + Returns: + Callable ``hip_kernel``. + + Example: + + .. code-block:: python + + def exp_elementwise(a: mx.array): + source = ''' + int elem = blockIdx.x * blockDim.x + threadIdx.x; + T tmp = inp[elem]; + out[elem] = exp(tmp); + ''' + + kernel = mx.fast.hip_kernel( + name="myexp", + input_names=["inp"], + output_names=["out"], + source=source + ) + + outputs = kernel( + inputs=[a], + template=[("T", a.dtype)], + grid=(a.size, 1, 1), + threadgroup=(256, 1, 1), + output_shapes=[a.shape], + output_dtypes=[a.dtype], + verbose=True, + ) + return outputs[0] + + a = mx.random.normal(shape=(16, 16)).astype(mx.float16) + b = exp_elementwise(a) + assert mx.allclose(b, mx.exp(a)) + )pbdoc"); + m.def( "precompiled_cuda_kernel", [](const std::string& name, diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index cb031cf78c..eae71f3ef4 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -13,6 +13,7 @@ void init_device(nb::module_&); void init_stream(nb::module_&); void init_metal(nb::module_&); void init_cuda(nb::module_&); +void init_rocm(nb::module_&); void init_memory(nb::module_&); void init_ops(nb::module_&); void init_transforms(nb::module_&); @@ -37,6 +38,7 @@ NB_MODULE(core, m) { init_array(m); init_metal(m); init_cuda(m); + init_rocm(m); init_memory(m); init_ops(m); init_transforms(m); diff --git a/python/src/rocm.cpp b/python/src/rocm.cpp new file mode 100644 index 0000000000..77a91332a5 --- /dev/null +++ b/python/src/rocm.cpp @@ -0,0 +1,19 @@ +// Copyright © 2025 Apple Inc. + +#include + +#include "mlx/backend/rocm/rocm.h" + +namespace mx = mlx::core; +namespace nb = nanobind; + +void init_rocm(nb::module_& m) { + nb::module_ rocm = m.def_submodule("rocm", "mlx.rocm"); + + rocm.def( + "is_available", + &mx::rocm::is_available, + R"pbdoc( + Check if the ROCm back-end is available. + )pbdoc"); +} diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py index ac223f0950..2143d18b1a 100644 --- a/python/tests/mlx_tests.py +++ b/python/tests/mlx_tests.py @@ -17,12 +17,60 @@ import numpy as np +def _get_backend_skip_tests(device): + if not (device == mx.gpu and not mx.metal.is_available()): + return set(), None + + if mx.cuda.is_available(): + from cuda_skip import cuda_skip + + return cuda_skip, "CUDA" + + if mx.rocm.is_available(): + from rocm_skip import rocm_skip + + return rocm_skip, "ROCm" + + return set(), None + + class MLXTestRunner(unittest.TestProgram): def __init__(self, *args, **kwargs): # Do not exit in runTests kwargs["exit"] = False super().__init__(*args, **kwargs) + def createTests(self, *args, **kwargs): + super().createTests(*args, **kwargs) + + # Check if we're running on a non-Metal GPU backend (CUDA or ROCm) + device_name = os.getenv("DEVICE", None) + if device_name is not None: + device = getattr(mx, device_name) + else: + device = mx.default_device() + + skip_tests, _ = _get_backend_skip_tests(device) + + if not skip_tests: + return + + filtered_suite = unittest.TestSuite() + + def filter_and_add(t): + if isinstance(t, unittest.TestSuite): + for sub_t in t: + filter_and_add(sub_t) + else: + t_id = ".".join(t.id().split(".")[-2:]) + if t_id in skip_tests: + print(f"Skipping {t_id}") + else: + filtered_suite.addTest(t) + + filter_and_add(self.test) + self.test = filtered_suite + def runTests(self): super().runTests() mx.clear_streams() @@ -36,9 +84,19 @@ def is_apple_silicon(self): def setUp(self): self.default = mx.default_device() - device = os.getenv("DEVICE", None) - if device is not None: - device = getattr(mx, device) + + device_name = os.getenv("DEVICE", None) + if device_name is not None: + device = getattr(mx, device_name) + else: + device = self.default + + skip_tests, backend = _get_backend_skip_tests(device) + test_id = f"{self.__class__.__name__}.{self._testMethodName}" + if test_id in skip_tests: + self.skipTest(f"Skipped on {backend} backend") + + if device_name is not None: mx.set_default_device(device) def tearDown(self): diff --git a/python/tests/rocm_skip.py b/python/tests/rocm_skip.py new file mode 100644 index 0000000000..004268f2b1 --- /dev/null +++ b/python/tests/rocm_skip.py @@ -0,0 +1,98 @@ +# Tests to skip for ROCm backend +# Based on functionality comparison with CUDA backend + +rocm_skip = { + # Same as CUDA - Block masked matmul NYI + "TestBlas.test_block_masked_matmul", + # Same as CUDA - Gather matmul NYI (ROCm throws for M > 1 and N > 1) + "TestBlas.test_gather_matmul", + "TestBlas.test_gather_matmul_grad", + "TestBlas.test_gather_mm_sorted_vjp", + # Same as CUDA - Segmented matmul NYI + "TestBlas.test_segmented_mm", + # ROCm-specific: Complex GEMM not supported in naive fallback + "TestBlas.test_complex_gemm", + "TestBlas.test_complex_gemv", + # ROCm-specific: addmm tolerance too tight for naive GEMM + "TestBlas.test_addmm", + "TestBlas.test_addmm_grad", + # ROCm-specific: empty matmul has issues on unsupported architectures + "TestBlas.test_empty_matmul", + # ROCm-specific: batched matrix-vector has precision issues on gfx1011 + "TestBlas.test_matrix_vector_batched", + # Same as CUDA - Hadamard NYI + "TestOps.test_hadamard", + "TestOps.test_hadamard_grad_vmap", + # Same as CUDA - FFTs NYI + "TestFFT.test_fft", + "TestFFT.test_fft_big_powers_of_two", + "TestFFT.test_fft_contiguity", + "TestFFT.test_fft_exhaustive", + "TestFFT.test_fft_grads", + "TestFFT.test_fft_into_ifft", + "TestFFT.test_fft_large_numbers", + "TestFFT.test_fft_shared_mem", + "TestFFT.test_fftn", + # Same as CUDA - Lapack ops NYI + "TestLinalg.test_cholesky", + "TestLinalg.test_cholesky_inv", + "TestLinalg.test_eig", + "TestLinalg.test_eigh", + "TestLinalg.test_inverse", + "TestVmap.test_vmap_inverse", + "TestLinalg.test_lu", + "TestLinalg.test_lu_factor", + "TestLinalg.test_pseudo_inverse", + "TestLinalg.test_qr_factorization", + "TestInit.test_orthogonal", + "TestLinalg.test_svd_decomposition", + "TestVmap.test_vmap_svd", + "TestLinalg.test_tri_inverse", + # Same as CUDA - Masked scatter NYI + "TestOps.test_masked_scatter", + "TestVmap.test_vmap_masked_scatter", + "TestArray.test_setitem_with_boolean_mask", + # Quantization - ROCm has different support than CUDA + "TestQuantized.test_gather_matmul_grad", + "TestQuantized.test_gather_qmm", + "TestQuantized.test_gather_qmm_sorted", + "TestQuantized.test_gather_qmm_grad", + "TestQuantized.test_non_multiples", + "TestQuantized.test_fp_qvm", + "TestQuantized.test_fp_qmv", # ROCm fp_qmv currently aborts on GPU + "TestQuantized.test_qmv_small_non_multiples", # nvfp4 qmv path unsupported + "TestQuantized.test_qvm", + "TestQuantized.test_qvm_splitk", + "TestQuantized.test_small_matrix", + "TestQuantized.test_throw", + "TestQuantized.test_vjp_scales_biases", + "TestExportImport.test_export_quantized_model", + "TestLayers.test_quantized_embedding", + # ROCm-specific: Complex power has numerical issues + "TestOps.test_complex_power", + # ROCm-specific: Complex ops (arctan) has numerical issues + "TestOps.test_complex_ops", + # ROCm-specific: Scan operations don't support complex types + "TestOps.test_logcumsumexp", + "TestOps.test_scans", + # ROCm-specific: logsumexp has numerical issues with complex types + "TestOps.test_logsumexp", + # ROCm-specific: sort has issues with multi-block sort + "TestOps.test_sort", + # ROCm-specific: Complex reduce operations not supported + "TestReduce.test_nan_propagation_complex64", + "TestReduce.test_dtypes", # Complex64 reduce not supported + # ROCm-specific: vmap matmul fails on unsupported architectures + "TestVmap.test_vmap_matmul", + # ROCm-specific: group_norm has numerical precision issues + "TestLayers.test_group_norm", + # ROCm-specific: Custom kernel tests use Metal-specific APIs + # hip_kernel is available but tests are written for metal_kernel + "TestFast.test_custom_kernel_args", + "TestFast.test_custom_kernel_attributes", + "TestFast.test_custom_kernel_basic", + "TestFast.test_custom_kernel_helper", + "TestFast.test_custom_kernel_strides", + # ROCm-specific: SDPA backward pass falls back to CPU + # These tests may be slow but should still pass +} diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index dedfa5d4fb..a11dd56aae 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -475,12 +475,20 @@ def test_matrix_vector_attn(self): o_mx = (s_mx @ v_mx_reshape) o_mx = o_mx.transpose(0, 3, 1, 2, 4).reshape(B, qsl, -1) + tol = 1e-4 + if ( + dtype == "float16" + and mx.default_device() == mx.gpu + and not mx.metal.is_available() + ): + tol = 2e-4 + # Check against np self.assertListEqual(list(s_np.shape), list(s_mx.shape)) - self.assertTrue(np.allclose(s_np, s_mx, atol=1e-4)) + self.assertTrue(np.allclose(s_np, s_mx, atol=tol)) self.assertListEqual(list(o_np.shape), list(o_mx.shape)) - self.assertTrue(np.allclose(o_np, o_mx, atol=1e-4)) + self.assertTrue(np.allclose(o_np, o_mx, atol=tol)) def test_matrix_vector_edgecases(self): for dtype in self.dtypes: diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 7bd867084e..f11e72fa3a 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -20,26 +20,18 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None): kL = k.shape[2] if n_repeats > 1: - q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1]) - k = mx.expand_dims(k, 2) - v = mx.expand_dims(v, 2) + k = mx.repeat(k, repeats=n_repeats, axis=-3) + v = mx.repeat(v, repeats=n_repeats, axis=-3) scores = q @ mx.swapaxes(k, -1, -2) is_causal = mask == "causal" if mask is not None: - if is_causal: offset = kL - L q_indices = mx.arange(L) + offset k_indices = mx.arange(kL) mask = q_indices[:, None] >= k_indices[None] - if n_repeats > 1 and mask.ndim >= 3: - if mask.shape[-3] == 1: - mask = mx.expand_dims(mask, -3) - else: - mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats)) - if mask.dtype == mx.bool_: scores = mx.where(mask, scores, mx.finfo(scores.dtype).min) else: @@ -47,8 +39,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None): if sinks is not None: sinks = mx.expand_dims(sinks, (0, 2, 3)) - if n_repeats > 1: - sinks = mx.unflatten(sinks, 1, (n_kv_heads, n_repeats)) score_shape = list(scores.shape) score_shape[-1] = 1 sinks = mx.broadcast_to(sinks, score_shape) @@ -59,8 +49,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None): scores = scores[..., 1:] out = scores @ v - if n_repeats > 1: - out = mx.reshape(out, [B, n_q_heads, L, -1]) return out diff --git a/test_qwen3_generation.py b/test_qwen3_generation.py new file mode 100644 index 0000000000..00973d0aaf --- /dev/null +++ b/test_qwen3_generation.py @@ -0,0 +1,231 @@ +"""Pytest-based generation checks for Qwen3, LFM2.5, and Qwen3-Coder-Next variants. + +Run with: + source venv/bin/activate + pytest -s test_qwen3_generation.py + +Environment overrides: + MLX_TEST_PROMPT="Your deterministic prompt" + MLX_TEST_SEED=42 + MLX_TEST_MAX_TOKENS=64 + MLX_TEST_DEVICE=gpu|cpu + MLX_TEST_OUTPUT_DIR=/path/to/save/outputs + MLX_TEST_REPEATABILITY=1 # rerun each model twice and compare text +""" + +from __future__ import annotations + +import itertools +import os +import re +import warnings +from pathlib import Path +from typing import Any, cast + +# Suppress known third-party SWIG deprecation noise seen during model/tokenizer imports. +warnings.filterwarnings( + "ignore", + message=r"builtin type SwigPyPacked has no __module__ attribute", + category=DeprecationWarning, +) +warnings.filterwarnings( + "ignore", + message=r"builtin type SwigPyObject has no __module__ attribute", + category=DeprecationWarning, +) +warnings.filterwarnings( + "ignore", + message=r"builtin type swigvarlink has no __module__ attribute", + category=DeprecationWarning, +) + +import mlx.core as mx +import pytest + +try: + from mlx_lm import load + from mlx_lm.generate import generate +except Exception as exc: # pragma: no cover + pytest.skip( + f"mlx_lm is required for this test file: {exc}", allow_module_level=True + ) + + +MODEL_FAMILIES = [ + "mlx-community/Qwen3-0.6B", + "mlx-community/LFM2.5-1.2B-Instruct", + "mlx-community/LFM2.5-1.2B-Thinking", +] +MODEL_VARIANTS = ["bf16", "3bit", "4bit", "6bit", "8bit"] +EXPLICIT_MODELS = [ + "mlx-community/Qwen3-Coder-Next-4bit", +] + +# Fixed model list used as pytest cases. +MODELS = [ + f"{model_family}-{variant}" + for model_family in MODEL_FAMILIES + for variant in MODEL_VARIANTS +] + EXPLICIT_MODELS + +DEFAULT_PROMPT = "Write exactly one short friendly greeting." +DEFAULT_SEED = 42 +DEFAULT_MAX_TOKENS = 64 +PROMPT = os.getenv("MLX_TEST_PROMPT", DEFAULT_PROMPT) +SEED = int(os.getenv("MLX_TEST_SEED", str(DEFAULT_SEED))) +MAX_TOKENS = int(os.getenv("MLX_TEST_MAX_TOKENS", str(DEFAULT_MAX_TOKENS))) +DEVICE_NAME = os.getenv("MLX_TEST_DEVICE", "gpu").strip().lower() +OUTPUT_DIR_OVERRIDE = os.getenv("MLX_TEST_OUTPUT_DIR", "").strip() +REPEATABILITY_CHECK = os.getenv("MLX_TEST_REPEATABILITY", "0").strip() == "1" + + +if DEVICE_NAME not in {"gpu", "cpu"}: + raise ValueError("MLX_TEST_DEVICE must be one of: gpu, cpu") +if not MODELS: + raise ValueError("No models configured. Update the MODELS list.") + + +DEVICE = mx.gpu if DEVICE_NAME == "gpu" else mx.cpu + + +def _greedy_sampler(logprobs: mx.array) -> mx.array: + return mx.argmax(logprobs, axis=-1) + + +def _case_id(model_id: str) -> str: + return model_id.split("/")[-1] + + +def _slug(text: str) -> str: + return re.sub(r"[^a-zA-Z0-9_.-]+", "_", text) + + +def _text_stats(text: str) -> dict[str, float | int]: + words = re.findall(r"\w+", text, flags=re.UNICODE) + word_count = len(words) + unique_words = len(set(words)) + unique_word_ratio = unique_words / word_count if word_count else 0.0 + longest_char_run = max( + (sum(1 for _ in group) for _, group in itertools.groupby(text)), default=0 + ) + return { + "chars": len(text), + "words": word_count, + "unique_words": unique_words, + "unique_word_ratio": unique_word_ratio, + "longest_char_run": longest_char_run, + } + + +def _exception_chain(exc: BaseException) -> tuple[BaseException, ...]: + chain: list[BaseException] = [] + stack = [exc] + seen: set[int] = set() + while stack: + current = stack.pop() + current_id = id(current) + if current_id in seen: + continue + seen.add(current_id) + chain.append(current) + if current.__cause__ is not None: + stack.append(current.__cause__) + if current.__context__ is not None: + stack.append(current.__context__) + return tuple(chain) + + +def _is_404_error(exc: Exception) -> bool: + for current in _exception_chain(exc): + response = getattr(current, "response", None) + if getattr(response, "status_code", None) == 404: + return True + if getattr(current, "status_code", None) == 404: + return True + message = str(current).lower() + if "404" in message and any( + token in message + for token in ( + "not found", + "does not exist", + "could not find", + "couldn't find", + ) + ): + return True + return False + + +def _generate(model_id: str) -> str: + mx.set_default_device(cast(Any, DEVICE)) + mx.random.seed(SEED) + + try: + model, tokenizer, *_ = load(model_id) + except Exception as exc: + if _is_404_error(exc): + pytest.skip(f"{model_id} is unavailable on the hub (404): {exc}") + raise + + text = generate( + model, + tokenizer, + prompt=PROMPT, + max_tokens=MAX_TOKENS, + sampler=_greedy_sampler, + verbose=False, + ) + + del model + del tokenizer + mx.clear_cache() + return text + + +@pytest.fixture(scope="session") +def output_dir(tmp_path_factory: pytest.TempPathFactory) -> Path: + if OUTPUT_DIR_OVERRIDE: + path = Path(OUTPUT_DIR_OVERRIDE) + path.mkdir(parents=True, exist_ok=True) + return path + return tmp_path_factory.mktemp("generation_outputs") + + +@pytest.mark.parametrize("model_id", MODELS, ids=_case_id) +def test_generate_and_show_output(model_id: str, output_dir: Path) -> None: + text = _generate(model_id) + stats = _text_stats(text) + + output_path = output_dir / f"{_slug(model_id)}.txt" + output_path.write_text(text, encoding="utf-8") + + print(f"\n=== MODEL: {model_id} ===") + print(f"device={DEVICE_NAME} seed={SEED} max_tokens={MAX_TOKENS} prompt={PROMPT!r}") + print( + "stats: " + f"chars={stats['chars']} " + f"words={stats['words']} " + f"unique_words={stats['unique_words']} " + f"unique_word_ratio={stats['unique_word_ratio']:.3f} " + f"longest_char_run={stats['longest_char_run']}" + ) + print("--- output start ---") + print(text) + print("--- output end ---") + print(f"saved: {output_path}") + + assert text.strip(), f"{model_id} generated empty output" + + +@pytest.mark.skipif( + not REPEATABILITY_CHECK, + reason="Set MLX_TEST_REPEATABILITY=1 to enforce exact repeatability.", +) +@pytest.mark.parametrize("model_id", MODELS, ids=_case_id) +def test_repeatability(model_id: str) -> None: + first = _generate(model_id) + second = _generate(model_id) + assert first == second, ( + f"{model_id} is not repeatable with fixed seed={SEED}, prompt={PROMPT!r}, " + f"device={DEVICE_NAME}." + )