Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions benchmarks/benchmark_linear_logp.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ def run_benchmark(args):
for num_tokens, hidden_dim, vocab in args.configs:
hidden, weight, target = _make_inputs(num_tokens, hidden_dim, vocab, device, dtype)

def fwd(op, h=hidden, w=weight):
def fwd(op, h=hidden, w=weight, t=target):
with torch.no_grad():
op(h, w, target)
op(h, w, t)

def fwd_bwd(op):
h = hidden.clone().requires_grad_(True)
w = weight.clone().requires_grad_(True)
op(h, w, target).sum().backward()
def fwd_bwd(op, h_src=hidden, w_src=weight, t=target):
h = h_src.clone().requires_grad_(True)
w = w_src.clone().requires_grad_(True)
op(h, w, t).sum().backward()

n_fwd = _time_ms(lambda: fwd(native), args.warmup, args.iters)
t_fwd = _time_ms(lambda: fwd(triton_op), args.warmup, args.iters)
Expand Down
18 changes: 10 additions & 8 deletions csrc/cuda/fused_linear_logp_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,14 @@ __global__ void fused_linear_logp_sm90_kernel(const __grid_constant__ CUtensorMa
float *sZt = sSum + BM;
int *mbar_base = reinterpret_cast<int *>(sZt + BM); // STAGES mbarriers (8B each)

const uint32_t sH_base = static_cast<uint32_t>(__cvta_generic_to_shared(sH));
const uint32_t sW_base = static_cast<uint32_t>(__cvta_generic_to_shared(sW));
int mbar[STAGES];
const uint64_t sH_base_tma = __cvta_generic_to_shared(sH);
const uint64_t sW_base_tma = __cvta_generic_to_shared(sW);
const uint32_t sH_base = static_cast<uint32_t>(sH_base_tma);
const uint32_t sW_base = static_cast<uint32_t>(sW_base_tma);
uint64_t mbar[STAGES];
#pragma unroll
for (int s = 0; s < STAGES; ++s)
mbar[s] = static_cast<int>(__cvta_generic_to_shared(mbar_base + 2 * s));
mbar[s] = __cvta_generic_to_shared(mbar_base + 2 * s);

for (int r = tid; r < num_rows; r += WG_THREADS) {
sMax[r] = -CUDART_INF_F;
Expand All @@ -111,11 +113,11 @@ __global__ void fused_linear_logp_sm90_kernel(const __grid_constant__ CUtensorMa
auto issue_load = [&](int k, int col_base) {
const int buf = k % STAGES;
const int k_off = k * BK;
tma_2d_g2s(static_cast<int>(sH_base + buf * BM * BK * sizeof(nv_bfloat16)), &h_tmap, k_off,
row_base, mbar[buf]);
tma_2d_g2s(static_cast<int>(sW_base + buf * BN * BK * sizeof(nv_bfloat16)), &w_tmap, k_off,
col_base, mbar[buf]);
mbarrier_arrive_expect_tx(mbar[buf], tile_bytes);
tma_2d_g2s(sH_base_tma + buf * BM * BK * sizeof(nv_bfloat16), &h_tmap, k_off, row_base,
mbar[buf]);
tma_2d_g2s(sW_base_tma + buf * BN * BK * sizeof(nv_bfloat16), &w_tmap, k_off, col_base,
mbar[buf]);
};

int phase[STAGES];
Expand Down
29 changes: 18 additions & 11 deletions csrc/utils/tma_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <cuda.h>
#include <cuda_bf16.h>
#include <cudaTypedefs.h>
#include <cstdint>
#include <iostream>

// Type Traits for TMA
Expand Down Expand Up @@ -51,33 +52,39 @@ inline void init_tensor_map(
}

// Device API
__device__ inline void mbarrier_init(int addr, int count) {
asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(addr), "r"(count));
__device__ inline void mbarrier_init(uint64_t addr, int count) {
asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "l"(addr), "r"(count));
}

__device__ inline void mbarrier_arrive(int addr) {
asm volatile("mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];" :: "r"(addr) : "memory");
__device__ inline void mbarrier_arrive(uint64_t addr) {
asm volatile("mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];" :: "l"(addr) : "memory");
}

__device__ inline void mbarrier_arrive_expect_tx(int addr, int size) {
__device__ inline void mbarrier_arrive_expect_tx(uint64_t addr, int size) {
asm volatile("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 _, [%0], %1;"
:: "r"(addr), "r"(size) : "memory");
:: "l"(addr), "r"(size) : "memory");
}

__device__ inline void mbarrier_wait(int mbar_addr, int phase) {
__device__ inline void mbarrier_wait(uint64_t mbar_addr, int phase) {
int ticks = 0x989680;
asm volatile(
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\n"
"@!P1 bra.uni LAB_WAIT;\n"
"}" :: "r"(mbar_addr), "r"(phase), "r"(ticks)
"}" :: "l"(mbar_addr), "r"(phase), "r"(ticks)
);
}

__device__ inline void tma_2d_g2s(int dst_smem_addr, const void *tmap_ptr, int x, int y, int mbar_addr) {
asm volatile("cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes "
__device__ inline void tma_2d_g2s(
uint64_t dst_smem_addr,
const void *tmap_ptr,
int x,
int y,
uint64_t mbar_addr
) {
asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes "
"[%0], [%1, {%2, %3}], [%4];"
:: "r"(dst_smem_addr), "l"(tmap_ptr), "r"(x), "r"(y), "r"(mbar_addr) : "memory");
:: "l"(dst_smem_addr), "l"(tmap_ptr), "r"(x), "r"(y), "l"(mbar_addr) : "memory");
}
1 change: 1 addition & 0 deletions docs/.nav.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ nav:
- operators/README.md
- operators/fused-logp.md
- operators/linear-logp.md
- operators/linear-logp-tp-test.md
- operators/grpo-loss.md
- operators/ratio-kl.md
- operators/sampling.md
Expand Down
8 changes: 5 additions & 3 deletions docs/design/runtime-dispatch.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ logical type, and the registry selects the first available backend for the curre

| Platform | Priority |
| --- | --- |
| CUDA | SM90 fused LogP when available, CUDA generic, FlashInfer, Triton generic, PyTorch native |
| CUDA | CUDA generic LogP by default; experimental SM90 fused LogP only when explicitly enabled, FlashInfer, Triton generic, PyTorch native |
| ROCm | AITER, Triton generic, PyTorch native |
| CPU | PyTorch native |

For CUDA devices with compute capability 9.0 or newer, the registry inserts the SM90
LogP backend at the front of the CUDA priority list.
For CUDA devices with compute capability 9.0 or newer, the registry only inserts
the legacy SM90 LogP backend when `RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP=1` is
set. The fused linear logp SM90 backend is gated separately and remains the
default linear logp backend when the extension is built on Hopper.
Comment on lines +22 to +25

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Tighten the capability wording here.

This says 9.0 or newer, but the registry currently promotes the experimental backend only for cc_major in (9, 10, 12). Documenting it as an open-ended >= 9.0 gate already diverges from the implementation contract.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/design/runtime-dispatch.md` around lines 22 - 25, Tighten the
runtime-dispatch capability wording to match the actual registry contract: in
the runtime dispatch documentation, replace the open-ended “compute capability
9.0 or newer” phrasing with the specific supported major versions used by the
registry logic. Refer to the SM90 LogP backend gating text so it clearly
reflects that the experimental backend is only promoted for the exact cc_major
values handled by the implementation, while keeping the fused linear logp SM90
backend description unchanged.


## Relevant Files

Expand Down
1 change: 1 addition & 0 deletions docs/operators/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Every operator page should include:

- [Fused LogP](fused-logp.md)
- [Fused Linear LogP](linear-logp.md)
- [Fused Linear LogP TP Test Runbook](linear-logp-tp-test.md)
- [GRPO Loss](grpo-loss.md)
- [Policy Ratio + KL Penalty](ratio-kl.md)
- [Sampling](sampling.md)
Expand Down
4 changes: 2 additions & 2 deletions docs/operators/fused-logp.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ output = logp_op(logits, token_ids)

| Backend | Wrapper | Native symbol | Notes |
| --- | --- | --- | --- |
| CUDA SM90 | `FusedLogpSM90Op` | `_C.fused_logp_sm90` | TMA-oriented path for Hopper-class GPUs. |
| CUDA SM90 | `FusedLogpSM90Op` | `_C.fused_logp_sm90` | Experimental TMA-oriented path for 2D contiguous bf16 logits on Hopper-class GPUs. It is disabled by default and requires `RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP=1`; otherwise the wrapper delegates to the CUDA generic fallback. |
| CUDA generic | `FusedLogpGenericOp` | `_C.fused_logp` | Generic compiled extension fallback. |
| PyTorch native | `NativeOp` | None | Baseline fallback path. |

## Tensor Contract

| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `logits` | `[N, V]` | `bfloat16` for SM90 path | Contiguous, on the target device. |
| `logits` | `[N, V]` | `bfloat16` for the experimental SM90 fast path; fp16/fp32 use generic fallback | Contiguous, on the target device for the experimental SM90 fast path. |
| `token_ids` / `labels` | `[N]` | Converted to `int32` | Same logical device as `logits`. |
| Output | `[N]` | Backend-defined tensor dtype | One selected log probability per row. |

Expand Down
Loading
Loading