diff --git a/benchmarks/benchmark_linear_logp.py b/benchmarks/benchmark_linear_logp.py index 2e8ffab..e5a3095 100644 --- a/benchmarks/benchmark_linear_logp.py +++ b/benchmarks/benchmark_linear_logp.py @@ -105,14 +105,14 @@ def run_benchmark(args): for num_tokens, hidden_dim, vocab in args.configs: hidden, weight, target = _make_inputs(num_tokens, hidden_dim, vocab, device, dtype) - def fwd(op, h=hidden, w=weight): + def fwd(op, h=hidden, w=weight, t=target): with torch.no_grad(): - op(h, w, target) + op(h, w, t) - def fwd_bwd(op): - h = hidden.clone().requires_grad_(True) - w = weight.clone().requires_grad_(True) - op(h, w, target).sum().backward() + def fwd_bwd(op, h_src=hidden, w_src=weight, t=target): + h = h_src.clone().requires_grad_(True) + w = w_src.clone().requires_grad_(True) + op(h, w, t).sum().backward() n_fwd = _time_ms(lambda: fwd(native), args.warmup, args.iters) t_fwd = _time_ms(lambda: fwd(triton_op), args.warmup, args.iters) diff --git a/csrc/cuda/fused_linear_logp_sm90.cu b/csrc/cuda/fused_linear_logp_sm90.cu index ff1ec74..18504ee 100644 --- a/csrc/cuda/fused_linear_logp_sm90.cu +++ b/csrc/cuda/fused_linear_logp_sm90.cu @@ -85,12 +85,14 @@ __global__ void fused_linear_logp_sm90_kernel(const __grid_constant__ CUtensorMa float *sZt = sSum + BM; int *mbar_base = reinterpret_cast(sZt + BM); // STAGES mbarriers (8B each) - const uint32_t sH_base = static_cast(__cvta_generic_to_shared(sH)); - const uint32_t sW_base = static_cast(__cvta_generic_to_shared(sW)); - int mbar[STAGES]; + const uint64_t sH_base_tma = __cvta_generic_to_shared(sH); + const uint64_t sW_base_tma = __cvta_generic_to_shared(sW); + const uint32_t sH_base = static_cast(sH_base_tma); + const uint32_t sW_base = static_cast(sW_base_tma); + uint64_t mbar[STAGES]; #pragma unroll for (int s = 0; s < STAGES; ++s) - mbar[s] = static_cast(__cvta_generic_to_shared(mbar_base + 2 * s)); + mbar[s] = __cvta_generic_to_shared(mbar_base + 2 * s); for (int r = tid; r < num_rows; r += WG_THREADS) { sMax[r] = -CUDART_INF_F; @@ -111,11 +113,11 @@ __global__ void fused_linear_logp_sm90_kernel(const __grid_constant__ CUtensorMa auto issue_load = [&](int k, int col_base) { const int buf = k % STAGES; const int k_off = k * BK; - tma_2d_g2s(static_cast(sH_base + buf * BM * BK * sizeof(nv_bfloat16)), &h_tmap, k_off, - row_base, mbar[buf]); - tma_2d_g2s(static_cast(sW_base + buf * BN * BK * sizeof(nv_bfloat16)), &w_tmap, k_off, - col_base, mbar[buf]); mbarrier_arrive_expect_tx(mbar[buf], tile_bytes); + tma_2d_g2s(sH_base_tma + buf * BM * BK * sizeof(nv_bfloat16), &h_tmap, k_off, row_base, + mbar[buf]); + tma_2d_g2s(sW_base_tma + buf * BN * BK * sizeof(nv_bfloat16), &w_tmap, k_off, col_base, + mbar[buf]); }; int phase[STAGES]; diff --git a/csrc/utils/tma_utils.cuh b/csrc/utils/tma_utils.cuh index f7311e2..f22d358 100644 --- a/csrc/utils/tma_utils.cuh +++ b/csrc/utils/tma_utils.cuh @@ -6,6 +6,7 @@ #include #include #include +#include #include // Type Traits for TMA @@ -51,20 +52,20 @@ inline void init_tensor_map( } // Device API -__device__ inline void mbarrier_init(int addr, int count) { - asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "r"(addr), "r"(count)); +__device__ inline void mbarrier_init(uint64_t addr, int count) { + asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;" :: "l"(addr), "r"(count)); } -__device__ inline void mbarrier_arrive(int addr) { - asm volatile("mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];" :: "r"(addr) : "memory"); +__device__ inline void mbarrier_arrive(uint64_t addr) { + asm volatile("mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];" :: "l"(addr) : "memory"); } -__device__ inline void mbarrier_arrive_expect_tx(int addr, int size) { +__device__ inline void mbarrier_arrive_expect_tx(uint64_t addr, int size) { asm volatile("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 _, [%0], %1;" - :: "r"(addr), "r"(size) : "memory"); + :: "l"(addr), "r"(size) : "memory"); } -__device__ inline void mbarrier_wait(int mbar_addr, int phase) { +__device__ inline void mbarrier_wait(uint64_t mbar_addr, int phase) { int ticks = 0x989680; asm volatile( "{\n" @@ -72,12 +73,18 @@ __device__ inline void mbarrier_wait(int mbar_addr, int phase) { "LAB_WAIT:\n" "mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 P1, [%0], %1, %2;\n" "@!P1 bra.uni LAB_WAIT;\n" - "}" :: "r"(mbar_addr), "r"(phase), "r"(ticks) + "}" :: "l"(mbar_addr), "r"(phase), "r"(ticks) ); } -__device__ inline void tma_2d_g2s(int dst_smem_addr, const void *tmap_ptr, int x, int y, int mbar_addr) { - asm volatile("cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes " +__device__ inline void tma_2d_g2s( + uint64_t dst_smem_addr, + const void *tmap_ptr, + int x, + int y, + uint64_t mbar_addr +) { + asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes " "[%0], [%1, {%2, %3}], [%4];" - :: "r"(dst_smem_addr), "l"(tmap_ptr), "r"(x), "r"(y), "r"(mbar_addr) : "memory"); + :: "l"(dst_smem_addr), "l"(tmap_ptr), "r"(x), "r"(y), "l"(mbar_addr) : "memory"); } diff --git a/docs/.nav.yml b/docs/.nav.yml index 60525c2..6b0b1e3 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -13,6 +13,7 @@ nav: - operators/README.md - operators/fused-logp.md - operators/linear-logp.md + - operators/linear-logp-tp-test.md - operators/grpo-loss.md - operators/ratio-kl.md - operators/sampling.md diff --git a/docs/design/runtime-dispatch.md b/docs/design/runtime-dispatch.md index c7e4fde..bedf347 100644 --- a/docs/design/runtime-dispatch.md +++ b/docs/design/runtime-dispatch.md @@ -15,12 +15,14 @@ logical type, and the registry selects the first available backend for the curre | Platform | Priority | | --- | --- | -| CUDA | SM90 fused LogP when available, CUDA generic, FlashInfer, Triton generic, PyTorch native | +| CUDA | CUDA generic LogP by default; experimental SM90 fused LogP only when explicitly enabled, FlashInfer, Triton generic, PyTorch native | | ROCm | AITER, Triton generic, PyTorch native | | CPU | PyTorch native | -For CUDA devices with compute capability 9.0 or newer, the registry inserts the SM90 -LogP backend at the front of the CUDA priority list. +For CUDA devices with compute capability 9.0 or newer, the registry only inserts +the legacy SM90 LogP backend when `RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP=1` is +set. The fused linear logp SM90 backend is gated separately and remains the +default linear logp backend when the extension is built on Hopper. ## Relevant Files diff --git a/docs/operators/README.md b/docs/operators/README.md index c4eae60..3c6fd95 100644 --- a/docs/operators/README.md +++ b/docs/operators/README.md @@ -20,6 +20,7 @@ Every operator page should include: - [Fused LogP](fused-logp.md) - [Fused Linear LogP](linear-logp.md) +- [Fused Linear LogP TP Test Runbook](linear-logp-tp-test.md) - [GRPO Loss](grpo-loss.md) - [Policy Ratio + KL Penalty](ratio-kl.md) - [Sampling](sampling.md) diff --git a/docs/operators/fused-logp.md b/docs/operators/fused-logp.md index d5008e5..adb5e81 100644 --- a/docs/operators/fused-logp.md +++ b/docs/operators/fused-logp.md @@ -17,7 +17,7 @@ output = logp_op(logits, token_ids) | Backend | Wrapper | Native symbol | Notes | | --- | --- | --- | --- | -| CUDA SM90 | `FusedLogpSM90Op` | `_C.fused_logp_sm90` | TMA-oriented path for Hopper-class GPUs. | +| CUDA SM90 | `FusedLogpSM90Op` | `_C.fused_logp_sm90` | Experimental TMA-oriented path for 2D contiguous bf16 logits on Hopper-class GPUs. It is disabled by default and requires `RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP=1`; otherwise the wrapper delegates to the CUDA generic fallback. | | CUDA generic | `FusedLogpGenericOp` | `_C.fused_logp` | Generic compiled extension fallback. | | PyTorch native | `NativeOp` | None | Baseline fallback path. | @@ -25,7 +25,7 @@ output = logp_op(logits, token_ids) | Argument | Shape | Dtype | Requirements | | --- | --- | --- | --- | -| `logits` | `[N, V]` | `bfloat16` for SM90 path | Contiguous, on the target device. | +| `logits` | `[N, V]` | `bfloat16` for the experimental SM90 fast path; fp16/fp32 use generic fallback | Contiguous, on the target device for the experimental SM90 fast path. | | `token_ids` / `labels` | `[N]` | Converted to `int32` | Same logical device as `logits`. | | Output | `[N]` | Backend-defined tensor dtype | One selected log probability per row. | diff --git a/docs/operators/linear-logp-tp-test.md b/docs/operators/linear-logp-tp-test.md new file mode 100644 index 0000000..e02324f --- /dev/null +++ b/docs/operators/linear-logp-tp-test.md @@ -0,0 +1,303 @@ +# Fused Linear LogP TP Test Runbook + +This runbook describes how to validate the tensor-parallel path of +`linear_logp` on a 4-GPU Hopper machine. The script exercises the public +operator API with vocab-sharded LM-head weights: + +- each rank owns `lm_head_weight[vocab_start:vocab_end]`; +- `target_ids` are global token ids and are replicated across TP ranks; +- forward merges the global log-sum-exp through TP collectives; +- backward all-reduces `hidden.grad` and keeps weight/bias gradients local. + +## Test Script + +```bash +scripts/test_linear_logp_tp.py +``` + +Launch it with `torchrun`; do not run it with plain `python` unless you are +debugging argument parsing only. + +```bash +torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py +``` + +The script has two phases: + +| Phase | Default | What it checks | +| --- | --- | --- | +| `correctness` | Always on | Builds a full materialized reference on every rank and compares TP output, `hidden.grad`, local `weight.grad`, and local `bias.grad`. | +| `stress` | `--run-stress` | Runs a larger TP-only shape without materializing full logits/reference; checks finite output/gradients, elapsed time, and peak memory. | + +## Prerequisites + +From the repository root: + +```bash +python -m pip install -e ".[dev]" +``` + +For SM90 direct-backend testing, rebuild the extension on the Hopper host: + +```bash +KERNEL_ALIGN_FORCE_SM90=1 python -m pip install -e . +``` + +Recommended environment: + +```bash +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export NCCL_DEBUG=INFO +export NCCL_ASYNC_ERROR_HANDLING=1 +export TORCH_NCCL_BLOCKING_WAIT=1 +``` + +If the machine uses a scheduler, make sure the job really owns four GPUs on the +same node before launching. + +## Quick Start + +Start with a small fp32 run. This should be numerically tight and easy to debug: + +```bash +torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \ + --dtype fp32 \ + --tokens 128 \ + --hidden-size 256 \ + --vocab-size 4096 \ + --uneven-shards +``` + +Then run the bf16 path that matches the intended Hopper use case: + +```bash +torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \ + --dtype bf16 \ + --tokens 256 \ + --hidden-size 512 \ + --vocab-size 8192 \ + --reference-mode fp32 \ + --uneven-shards +``` + +Finally run a larger TP smoke without the full reference: + +```bash +torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \ + --dtype bf16 \ + --tokens 256 \ + --hidden-size 512 \ + --vocab-size 8192 \ + --reference-mode fp32 \ + --run-stress \ + --stress-tokens 4096 \ + --stress-hidden-size 2048 \ + --stress-vocab-size 32768 +``` + +## Recommended Test Matrix + +Run these in order. Stop at the first failure and keep the full terminal log. + +### 1. Native TP math sanity + +This bypasses registry/backend selection and directly tests the shared TP +autograd path. + +```bash +torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \ + --op-source native \ + --dtype fp32 \ + --tokens 128 \ + --hidden-size 256 \ + --vocab-size 4096 \ + --uneven-shards +``` + +Expected: every metric is `PASS` with max errors around `1e-4` on CUDA/NCCL. + +### 2. Registry path, bf16 fp32 reference + +This is the main merge gate for the current TP implementation. The fused and +Triton bf16 paths accumulate matmuls in fp32, so `fp32` reference mode compares +against the same semantic target. + +```bash +torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \ + --op-source registry \ + --dtype bf16 \ + --reference-mode fp32 \ + --tokens 256 \ + --hidden-size 512 \ + --vocab-size 8192 \ + --uneven-shards +``` + +Expected: `output`, `hidden_grad`, `weight_grad`, and `bias_grad` all pass. + +### 3. Optional same-dtype drift check + +This compares bf16 TP against a materialized full-vocab `F.linear` in the input +dtype. It is useful for understanding PyTorch full-GEMM vs shard-GEMM drift, but +the main correctness target is the fp32-accumulation reference above. + +```bash +torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \ + --op-source registry \ + --dtype bf16 \ + --reference-mode matching \ + --tokens 128 \ + --hidden-size 512 \ + --vocab-size 8192 \ + --atol 0.75 \ + --rtol 0.75 +``` + +Record the drift numbers separately from the merge gate. + +### 4. Triton wrapper delegation + +Use this when Triton is installed. The TP kwargs should route through the shared +TP path instead of the local non-TP Triton kernel. + +```bash +torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \ + --op-source triton \ + --dtype bf16 \ + --reference-mode fp32 \ + --tokens 256 \ + --hidden-size 512 \ + --vocab-size 8192 +``` + +### 5. SM90 wrapper delegation + +Use this only after rebuilding with `KERNEL_ALIGN_FORCE_SM90=1`. The direct SM90 +op still delegates to the shared TP path once TP kwargs are present. + +```bash +torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \ + --op-source sm90 \ + --dtype bf16 \ + --reference-mode fp32 \ + --tokens 256 \ + --hidden-size 512 \ + --vocab-size 8192 +``` + +### 6. Larger stress + +This checks the end-to-end distributed path at a more realistic shape without +building full reference logits. + +```bash +torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \ + --op-source registry \ + --dtype bf16 \ + --reference-mode fp32 \ + --tokens 256 \ + --hidden-size 512 \ + --vocab-size 8192 \ + --run-stress \ + --stress-tokens 4096 \ + --stress-hidden-size 2048 \ + --stress-vocab-size 32768 +``` + +Expected: `finite=PASS`. Record `max_rank_elapsed_ms` and +`max_rank_peak_memory_gb`. + +## Reading Output + +Successful output looks like: + +```text +[correctness] + dtype=torch.bfloat16, reference_mode=fp32, atol=0.08, rtol=0.08 + tokens=256, hidden=512, vocab=8192 + shard_boundaries=[0, 2048, 4096, 6144, 8192] + PASS output: max_abs=... + PASS hidden_grad: max_abs=... + PASS weight_grad: max_abs=... + PASS bias_grad: max_abs=... + +[result] + PASS +``` + +If `--uneven-shards` is set, the shard boundaries will not be equal-sized. That +is intentional and validates `vocab_start_index` handling. + +## Important Flags + +| Flag | Meaning | +| --- | --- | +| `--op-source registry` | Use `kernel_registry.get_op("linear_logp")`; recommended default. | +| `--op-source native` | Directly test the shared PyTorch TP path. | +| `--op-source triton` | Test Triton wrapper TP delegation. | +| `--op-source sm90` | Test SM90 wrapper TP delegation; requires SM90 extension. | +| `--dtype bf16` | Hopper target dtype. | +| `--reference-mode fp32` | Full reference upcasts hidden/weight/bias to fp32; best for fused/Triton bf16 correctness. | +| `--reference-mode matching` | Full reference uses same-dtype `F.linear`; useful for measuring PyTorch full-GEMM vs shard-GEMM drift. | +| `--uneven-shards` | Builds non-equal vocab shards to test range handling. | +| `--run-stress` | Adds a large TP-only finite/peak-memory smoke. | +| `--no-bias` | Tests the no-bias path. | + +## Troubleshooting + +### NCCL timeout or hang + +Check that every process reaches the same code path and that `nproc_per_node` +matches the number of visible GPUs: + +```bash +echo $CUDA_VISIBLE_DEVICES +nvidia-smi +``` + +Re-run with: + +```bash +export NCCL_DEBUG=INFO +export TORCH_DISTRIBUTED_DEBUG=DETAIL +``` + +### `target_ids must be covered by exactly one TP vocab shard` + +The rank-local `vocab_start_index` or shard sizes are inconsistent. The script +prints `shard_boundaries`; verify they form a contiguous `[0, vocab_size)` +partition. + +### bf16 fp32-reference passes but matching-reference drifts + +This indicates numerical drift between PyTorch's full bf16 GEMM and the TP +vocab-sharded GEMMs. Record the reported `max_abs` and `max_rel`; use the +fp32-reference result as the main TP correctness signal. + +### CUDA OOM in correctness phase + +The correctness phase materializes full `[tokens, vocab]` logits on every rank. +Reduce `--tokens`, `--hidden-size`, or `--vocab-size`. Use `--run-stress` for +larger shapes after small correctness has passed. + +### `fused_linear_logp_sm90 is not compiled` + +This only affects `--op-source sm90`. Rebuild with: + +```bash +KERNEL_ALIGN_FORCE_SM90=1 python -m pip install -e . +``` + +The default `--op-source registry` should still work by falling back to Triton or +native. + +## What To Attach To The PR + +For the PR update, paste: + +- command lines for each run; +- the `[env]` block; +- all `[correctness]` metric lines; +- the `[stress]` finite, elapsed time, and peak memory lines; +- GPU model and CUDA/PyTorch versions; +- any matching-reference drift numbers if that optional run fails or requires loose tolerances. diff --git a/docs/operators/linear-logp.md b/docs/operators/linear-logp.md index 2d8b8ce..f45477f 100644 --- a/docs/operators/linear-logp.md +++ b/docs/operators/linear-logp.md @@ -24,6 +24,9 @@ logp = linear_logp( lm_head_weight, # [V, D] (differentiable) target_ids, # [B, S] or [N] int, in [0, V) bias=None, # [V] optional (differentiable) + tp_group=None, # optional torch.distributed ProcessGroup for vocab TP + vocab_start_index=0, # first global vocab id owned by this rank + global_vocab_size=None, # optional global/padded vocab size ) # -> [B, S] or [N], float32 logp.sum().backward() # gradients flow into hidden, lm_head_weight, bias @@ -63,10 +66,38 @@ logits and is the correctness oracle. Gradients flow into `hidden`, `lm_head_weight`, and `bias`; `target_ids` is integer and non-differentiable. +## Tensor Parallel Vocab Shards + +For tensor-parallel LM heads sharded along the vocab dimension, pass each rank's +local weight shard `[V_local, D]`, optional local bias shard `[V_local]`, the TP +process group, and the shard's global vocab start: + +```python +logp = linear_logp( + hidden, + local_lm_head_weight, + target_ids, # global token ids, replicated across TP ranks + local_bias, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + global_vocab_size=padded_vocab_size, +) +``` + +The TP path streams local vocab chunks, computes local online-softmax state, then +uses TP collectives to merge the global max, global sum, and selected target +logit. Backward recomputes local logits in chunks and all-reduces `hidden.grad` +across the TP group; `lm_head_weight.grad` and `bias.grad` remain local shards. +Shard ranges must form a contiguous `[0, global_vocab_size)` partition. + ## Reference Semantics ```python -logits = torch.nn.functional.linear(hidden.float(), weight.float(), bias) # [N, V] +logits = torch.nn.functional.linear( + hidden.float(), + weight.float(), + None if bias is None else bias.float(), +) # [N, V] logp = torch.log_softmax(logits, dim=-1) out = logp.gather(-1, target_ids.long().unsqueeze(-1)).squeeze(-1) ``` @@ -128,6 +159,9 @@ bf16) vs native, Triton backward vs native autograd (with and without bias), lea shape preservation, a large-vocab smoke test, and registry dispatch. Triton tests skip without CUDA + Triton. +For 4-GPU tensor-parallel validation, use +[Fused Linear LogP TP Test Runbook](linear-logp-tp-test.md). + ## Implementation Files - `rl_engine/kernels/ops/triton/loss/linear_logp.py` diff --git a/rl_engine/executors/deepspeed_trainer.py b/rl_engine/executors/deepspeed_trainer.py index 4a7acfb..8cfcc18 100644 --- a/rl_engine/executors/deepspeed_trainer.py +++ b/rl_engine/executors/deepspeed_trainer.py @@ -7,6 +7,7 @@ import os import sysconfig import time +from contextlib import nullcontext from dataclasses import dataclass, field from pathlib import Path from typing import Any, Mapping, Optional, TypeVar, overload @@ -26,12 +27,9 @@ TrainingStageResult, objective_reference_logps, ) -from rl_engine.testing import ( - compute_policy_ratio, - compute_reference_kl, - masked_mean, - selected_logprobs_reference, -) +from rl_engine.kernels.ops.pytorch.loss.linear_logp import NativeLinearLogpOp +from rl_engine.kernels.registry import kernel_registry +from rl_engine.testing import compute_policy_ratio, compute_reference_kl, masked_mean _TDestination = TypeVar("_TDestination", bound=dict[str, Any]) @@ -53,6 +51,25 @@ def __post_init__(self) -> None: raise ValueError("zero_stage must be >= 0") +class _EmbeddingLMHeadModel(torch.nn.Module): + def __init__( + self, + vocab_size: int, + hidden_dim: int, + *, + bias: bool = True, + tie_weights: bool = False, + ) -> None: + super().__init__() + self.embedding = torch.nn.Embedding(vocab_size, hidden_dim) + self.lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=bias) + if tie_weights: + self.lm_head.weight = self.embedding.weight + + def forward(self, token_ids: torch.Tensor) -> torch.Tensor: + return self.embedding(token_ids.long()) + + class DeepSpeedTrainingWorker(RolloutBatchMixin): """ Training worker implementation backed by a real DeepSpeed engine contract. @@ -84,43 +101,39 @@ def __init__( deepspeed = _load_deepspeed() self._deepspeed = deepspeed torch.manual_seed(self.config.seed) - self.model = torch.nn.Sequential( - torch.nn.Embedding(self.config.vocab_size, self.config.hidden_dim), - torch.nn.Linear(self.config.hidden_dim, self.config.vocab_size), + self.model = _EmbeddingLMHeadModel( + self.config.vocab_size, + self.config.hidden_dim, ).to(device=self.device) self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.lr) + self._deepspeed_config = self._resolved_deepspeed_config() + self._deepspeed_zero_stage = _resolved_zero_stage( + self._deepspeed_config, + fallback=self.config.zero_stage, + ) init_result = deepspeed.initialize( model=self.model, model_parameters=self.model.parameters(), optimizer=self.optimizer, - config=self._resolved_deepspeed_config(), + config=self._deepspeed_config, **dict(self.config.initialize_kwargs), ) self.engine = _first_initialize_result(init_result) engine_device = getattr(self.engine, "device", None) if engine_device is not None: self.device = torch.device(engine_device) + self._linear_logp = _linear_logp_op_for_device(self.device) def train(self, rollout: RolloutStageResult) -> TrainingStageResult: started_at = time.perf_counter() batch, payload_metrics = self._batch_from_rollout_or_synthetic(rollout) - - logits = _extract_logits(self.engine(batch.token_ids.long())) - current_logps = selected_logprobs_reference( - logits, + training_model = _unwrap_training_model(self.engine, self.model) + training_embedding = _embedding_layer(training_model) + _validate_model_input_token_ids( batch.token_ids, - mask=batch.completion_mask, - output_dtype=torch.float32, + vocab_size=training_embedding.num_embeddings, ) - old_logps = current_logps.detach() - 0.01 - ref_logps = objective_reference_logps(current_logps, batch) - ratio = compute_policy_ratio(current_logps, old_logps, batch.completion_mask) - unclipped = ratio * batch.advantages.float() - clipped = torch.clamp(ratio, 0.8, 1.2) * batch.advantages.float() - policy_loss = -torch.minimum(unclipped, clipped) - kl = compute_reference_kl(current_logps, ref_logps, batch.completion_mask) - loss = masked_mean(policy_loss + 0.01 * kl, batch.completion_mask) if hasattr(self.engine, "zero_grad"): try: @@ -129,7 +142,30 @@ def train(self, rollout: RolloutStageResult) -> TrainingStageResult: self.engine.zero_grad() elif hasattr(self.optimizer, "zero_grad"): self.optimizer.zero_grad(set_to_none=True) - self.engine.backward(loss) + + with _linear_logp_parameter_context( + self._deepspeed, + training_model, + zero_stage=self._deepspeed_zero_stage, + world_size=self._engine_world_size(), + ): + current_logps = _extract_logps( + self.engine(batch.token_ids.long()), + training_model, + batch.token_ids, + batch.completion_mask, + self._linear_logp, + output_dtype=torch.float32, + ) + old_logps = current_logps.detach() - 0.01 + ref_logps = objective_reference_logps(current_logps, batch) + ratio = compute_policy_ratio(current_logps, old_logps, batch.completion_mask) + unclipped = ratio * batch.advantages.float() + clipped = torch.clamp(ratio, 0.8, 1.2) * batch.advantages.float() + policy_loss = -torch.minimum(unclipped, clipped) + kl = compute_reference_kl(current_logps, ref_logps, batch.completion_mask) + loss = masked_mean(policy_loss + 0.01 * kl, batch.completion_mask) + self.engine.backward(loss) self.engine.step() finished_at = time.perf_counter() @@ -146,7 +182,9 @@ def train(self, rollout: RolloutStageResult) -> TrainingStageResult: "training_backend": "deepspeed", "training_device": str(self.device), "deepspeed_engine": type(self.engine).__name__, - "deepspeed_zero_stage": self.config.zero_stage, + "deepspeed_zero_stage": self._deepspeed_zero_stage, + "current_logp_path": "linear_logp", + "current_logp_backend": type(self._linear_logp).__name__, "active_advantage_mean_global": ( float(active_advantages.mean().detach().cpu().item()) if active_advantages.numel() @@ -180,14 +218,14 @@ def publish_weights( manifest_metadata = dict(metadata or {}) layout = { "kind": "full-state", - "zero_stage": self.config.zero_stage, + "zero_stage": self._deepspeed_zero_stage, "world_size": self._engine_world_size(), "rank": self._engine_rank(), } layout.update(dict(manifest_metadata.get("layout", {}))) manifest_metadata["layout"] = layout publish_model: torch.nn.Module = self.model - if self.config.zero_stage >= 3: + if self._deepspeed_zero_stage >= 3: publish_model, export_metadata = self._export_zero3_full_state_model() manifest_metadata["deepspeed_zero3_full_state_export"] = export_metadata return self.weight_bridge.publish( @@ -361,6 +399,15 @@ def _first_initialize_result(init_result: Any) -> Any: return init_result +def _resolved_zero_stage(config: Mapping[str, Any], *, fallback: int) -> int: + zero_config = config.get("zero_optimization") + if isinstance(zero_config, Mapping): + return int(zero_config.get("stage", fallback)) + if zero_config is False: + return 0 + return int(fallback) + + def _extract_logits(model_output: Any) -> torch.Tensor: if isinstance(model_output, torch.Tensor): return model_output @@ -374,6 +421,226 @@ def _extract_logits(model_output: Any) -> torch.Tensor: raise TypeError(f"DeepSpeed model output does not expose logits: {type(model_output)!r}") +def _extract_hidden_states( + model_output: Any, + *, + expected_hidden_dim: Optional[int] = None, +) -> torch.Tensor: + hidden = _coerce_hidden_tensor(model_output, expected_hidden_dim=expected_hidden_dim) + if hidden is None: + raise TypeError( + f"DeepSpeed model output does not expose a hidden-state tensor: {type(model_output)!r}" + ) + return hidden + + +def _linear_logp_op_for_device(device: torch.device | str) -> Any: + resolved = torch.device(device) + if resolved.type == "cpu": + return NativeLinearLogpOp() + return kernel_registry.get_op("linear_logp") + + +def _unwrap_training_model(engine: Any, fallback_model: torch.nn.Module) -> torch.nn.Module: + model = getattr(engine, "module", None) + if isinstance(model, torch.nn.Module): + return model + return fallback_model + + +def _embedding_layer(model: torch.nn.Module) -> torch.nn.Embedding: + embedding = getattr(model, "embedding", None) + if not isinstance(embedding, torch.nn.Embedding): + raise TypeError( + "DeepSpeed training model must expose an embedding torch.nn.Embedding for " + "model-input validation" + ) + return embedding + + +def _coerce_hidden_tensor( + candidate: Any, + *, + expected_hidden_dim: Optional[int] = None, +) -> Optional[torch.Tensor]: + if isinstance(candidate, torch.Tensor): + return candidate if _looks_like_hidden_tensor(candidate, expected_hidden_dim) else None + if isinstance(candidate, Mapping): + for key in ("last_hidden_state", "hidden"): + value = candidate.get(key) + hidden = _coerce_hidden_tensor(value, expected_hidden_dim=expected_hidden_dim) + if hidden is not None: + return hidden + hidden_states = candidate.get("hidden_states") + hidden = _last_hidden_state_tensor( + hidden_states, + expected_hidden_dim=expected_hidden_dim, + ) + if hidden is not None: + return hidden + return None + for attr in ("last_hidden_state", "hidden"): + if hasattr(candidate, attr): + hidden = _coerce_hidden_tensor( + getattr(candidate, attr), + expected_hidden_dim=expected_hidden_dim, + ) + if hidden is not None: + return hidden + if hasattr(candidate, "hidden_states"): + hidden = _last_hidden_state_tensor( + candidate.hidden_states, + expected_hidden_dim=expected_hidden_dim, + ) + if hidden is not None: + return hidden + if isinstance(candidate, (tuple, list)): + for item in candidate: + if _has_hidden_state_metadata(item): + hidden = _coerce_hidden_tensor(item, expected_hidden_dim=expected_hidden_dim) + if hidden is not None: + return hidden + tensor_candidates = [ + item + for item in candidate + if isinstance(item, torch.Tensor) + and _looks_like_hidden_tensor(item, expected_hidden_dim) + ] + if len(tensor_candidates) == 1: + return tensor_candidates[0] + if tensor_candidates: + max_ndim = max(tensor.ndim for tensor in tensor_candidates) + deepest = [tensor for tensor in tensor_candidates if tensor.ndim == max_ndim] + if len(deepest) == 1: + return deepest[0] + for item in candidate: + if isinstance(item, torch.Tensor): + continue + hidden = _coerce_hidden_tensor(item, expected_hidden_dim=expected_hidden_dim) + if hidden is not None: + return hidden + return None + + +def _has_hidden_state_metadata(candidate: Any) -> bool: + if isinstance(candidate, Mapping): + return any(key in candidate for key in ("last_hidden_state", "hidden", "hidden_states")) + return any( + hasattr(candidate, attr) for attr in ("last_hidden_state", "hidden", "hidden_states") + ) + + +def _looks_like_hidden_tensor( + tensor: torch.Tensor, + expected_hidden_dim: Optional[int], +) -> bool: + if tensor.ndim < 2: + return False + if expected_hidden_dim is not None and int(tensor.size(-1)) != int(expected_hidden_dim): + return False + return True + + +def _last_hidden_state_tensor( + candidate: Any, + *, + expected_hidden_dim: Optional[int] = None, +) -> Optional[torch.Tensor]: + if isinstance(candidate, torch.Tensor): + return candidate if _looks_like_hidden_tensor(candidate, expected_hidden_dim) else None + if isinstance(candidate, (tuple, list)): + for item in reversed(candidate): + hidden = _coerce_hidden_tensor(item, expected_hidden_dim=expected_hidden_dim) + if hidden is not None: + return hidden + return None + + +def _safe_token_ids(token_ids: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + safe_token_ids = token_ids.long() + if mask is None: + return safe_token_ids + active_mask = mask.to(device=safe_token_ids.device, dtype=torch.bool) + if active_mask.shape != safe_token_ids.shape: + raise ValueError( + f"mask shape {tuple(active_mask.shape)} must match token_ids shape " + f"{tuple(safe_token_ids.shape)}" + ) + return safe_token_ids.masked_fill(~active_mask, 0) + + +def _validate_model_input_token_ids(token_ids: torch.Tensor, *, vocab_size: int) -> None: + invalid = (token_ids < 0) | (token_ids >= int(vocab_size)) + if bool(invalid.any().item()): + t_min = int(token_ids.min().item()) + t_max = int(token_ids.max().item()) + raise ValueError( + f"model input token_ids must be in [0, {int(vocab_size) - 1}], got " + f"[{t_min}, {t_max}]. Keep ignore-index / padding sentinels out of the model " + "input path and apply masking only at the logprob/loss stage." + ) + + +def _linear_logp_parameter_context( + deepspeed_runtime: Any, + model: torch.nn.Module, + *, + zero_stage: int, + world_size: int, +) -> Any: + if int(zero_stage) < 3 or int(world_size) <= 1: + return nullcontext() + + lm_head = getattr(model, "lm_head", None) + if not isinstance(lm_head, torch.nn.Linear): + raise TypeError( + "DeepSpeed training model must expose an lm_head torch.nn.Linear for ZeRO-3 " + "linear_logp gathering" + ) + + gathered_parameters = getattr( + getattr(deepspeed_runtime, "zero", None), + "GatheredParameters", + None, + ) + if not callable(gathered_parameters): + raise WeightBridgeUnavailableError( + "DeepSpeed ZeRO-3 linear_logp training requires deepspeed.zero.GatheredParameters " + "or an equivalent full-parameter gather API." + ) + + parameters = [lm_head.weight] + if lm_head.bias is not None: + parameters.append(lm_head.bias) + return gathered_parameters(parameters, modifier_rank=None) + + +def _extract_logps( + model_output: Any, + model: torch.nn.Module, + token_ids: torch.Tensor, + completion_mask: Optional[torch.Tensor], + linear_logp_op: Any, + *, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + lm_head = getattr(model, "lm_head", None) + if not isinstance(lm_head, torch.nn.Linear): + raise TypeError( + "DeepSpeed training model must expose an lm_head torch.nn.Linear for linear_logp" + ) + + hidden = _extract_hidden_states( + model_output, + expected_hidden_dim=int(lm_head.in_features), + ) + targets = _safe_token_ids(token_ids.to(device=hidden.device), completion_mask) + logps = linear_logp_op(hidden, lm_head.weight, targets, lm_head.bias) + if completion_mask is not None: + logps = logps.masked_fill(~completion_mask.to(device=logps.device, dtype=torch.bool), 0.0) + return logps.to(dtype=output_dtype) + + class _StateDictModule(torch.nn.Module): def __init__(self, state_dict: Mapping[str, torch.Tensor]): super().__init__() diff --git a/rl_engine/kernels/ops/cuda/loss/linear_logp.py b/rl_engine/kernels/ops/cuda/loss/linear_logp.py index 44aa998..44f258c 100644 --- a/rl_engine/kernels/ops/cuda/loss/linear_logp.py +++ b/rl_engine/kernels/ops/cuda/loss/linear_logp.py @@ -3,23 +3,39 @@ from __future__ import annotations -from typing import Optional +from typing import Any, Optional import torch from rl_engine.kernels.ops.base import _C, _EXT_AVAILABLE -from rl_engine.kernels.ops.pytorch.loss.linear_logp import chunked_linear_logp_backward +from rl_engine.kernels.ops.pytorch.loss.linear_logp import ( + BWD_CHUNK_ELEMS, + _linear_logits, + _require_distributed_initialized, + _use_fp32_matmul, + _validate_global_targets, + _validate_tp_vocab_partition, + chunked_linear_logp_backward, + should_use_tensor_parallel_linear_logp, + tensor_parallel_linear_logp, +) from rl_engine.utils.logger import logger # Hidden-dim slice the SM90 forward streams per TMA load; D must be a multiple of # it (mirrors `constexpr int BK` in csrc/cuda/fused_linear_logp_sm90.cu). _SM90_BK = 32 +_SM90_TP_PATH_LOGGED = False def _sm90_supported(hidden: torch.Tensor, lm_head_weight: torch.Tensor) -> bool: """Whether the bf16 TMA+MMA forward can run these inputs directly.""" + if not (hidden.is_cuda and lm_head_weight.is_cuda): + return False + if hidden.device != lm_head_weight.device: + return False + cc_major, _ = torch.cuda.get_device_capability(hidden.device) return ( - hidden.is_cuda + cc_major == 9 and hidden.dtype == torch.bfloat16 and lm_head_weight.dtype == torch.bfloat16 and hidden.size(-1) % _SM90_BK == 0 @@ -83,6 +99,156 @@ def backward(ctx, grad_logp): return grad_hidden, grad_weight, grad_bias, None +class _FusedTensorParallelLinearLogpSM90Function(torch.autograd.Function): + """SM90 local-shard forward with tensor-parallel logsumexp reduction. + + Each rank runs the fused SM90 kernel over its local vocab shard to get local + log-sum-exp and the owned target logit. TP ranks then merge those states into + the global selected log-prob. Backward intentionally reuses the existing + chunked TP path so training gradients keep the same contract as the portable + implementation. + """ + + @staticmethod + def forward( + ctx, + hidden, + lm_head_weight, + bias, + target_ids, + vocab_start_index, + global_vocab_size, + tp_group, + ): + dist = _require_distributed_initialized() + + hidden_2d = hidden.reshape(-1, hidden.size(-1)).contiguous() + weight = lm_head_weight.contiguous() + target_1d = ( + target_ids.reshape(-1).to(device=hidden_2d.device, dtype=torch.long).contiguous() + ) + bias_t = bias.contiguous() if bias is not None else hidden_2d + vocab_start_index = int(vocab_start_index) + global_vocab_size = _validate_tp_vocab_partition( + tp_group=tp_group, + device=hidden_2d.device, + vocab_start_index=vocab_start_index, + local_vocab_size=weight.size(0), + global_vocab_size=global_vocab_size, + ) + _validate_global_targets(target_1d, global_vocab_size, tp_group) + + local_vocab = weight.size(0) + local_target = target_1d - vocab_start_index + owns_target = (local_target >= 0) & (local_target < local_vocab) + kernel_target = torch.where(local_target >= 0, local_target, torch.zeros_like(local_target)) + kernel_target = torch.where( + kernel_target < local_vocab, kernel_target, torch.zeros_like(kernel_target) + ) + kernel_target = kernel_target.to(torch.int32).contiguous() + + local_logp, local_lse = _C.fused_linear_logp_sm90(hidden_2d, weight, kernel_target, bias) + local_target_logit = torch.where( + owns_target, local_logp + local_lse, torch.zeros_like(local_lse) + ) + target_logit = local_target_logit.clone() + dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=tp_group) + + global_lse_max = local_lse.clone() + dist.all_reduce(global_lse_max, op=dist.ReduceOp.MAX, group=tp_group) + global_lse_sum = torch.exp(local_lse - global_lse_max) + dist.all_reduce(global_lse_sum, op=dist.ReduceOp.SUM, group=tp_group) + global_lse = global_lse_max + torch.log(global_lse_sum) + + ctx.save_for_backward(hidden_2d, weight, bias_t, target_1d, global_lse) + ctx.has_bias = bias is not None + ctx.lead_shape = hidden.shape[:-1] + ctx.hidden_dtype = hidden.dtype + ctx.weight_dtype = lm_head_weight.dtype + ctx.bias_dtype = bias.dtype if bias is not None else None + ctx.vocab_start_index = vocab_start_index + ctx.tp_group = tp_group + return (target_logit - global_lse).reshape(hidden.shape[:-1]) + + @staticmethod + def backward(ctx, grad_logp): + dist = _require_distributed_initialized() + hidden_2d, weight, bias_t, target_1d, global_lse = ctx.saved_tensors + n, d = hidden_2d.shape + local_vocab = weight.shape[0] + dt = weight.dtype + g = grad_logp.reshape(-1).to(torch.float32) + + grad_h = torch.empty_like(hidden_2d, dtype=torch.float32) + grad_w = torch.zeros(local_vocab, d, device=weight.device, dtype=torch.float32) + grad_b = ( + torch.zeros(local_vocab, device=weight.device, dtype=torch.float32) + if ctx.has_bias + else None + ) + use_fp32 = _use_fp32_matmul(hidden_2d, weight) + + chunk = max(1, min(n, BWD_CHUNK_ELEMS // local_vocab)) + for i0 in range(0, n, chunk): + i1 = min(i0 + chunk, n) + x = hidden_2d[i0:i1] + logits = _linear_logits( + x, + weight, + bias_t if ctx.has_bias else None, + use_fp32=use_fp32, + ) + + dz = -torch.exp(logits.float() - global_lse[i0:i1].unsqueeze(1)) + local_idx = target_1d[i0:i1] - ctx.vocab_start_index + owns_target = (local_idx >= 0) & (local_idx < local_vocab) + if bool(owns_target.any().item()): + rows = torch.arange(i1 - i0, device=dz.device)[owns_target] + dz[rows, local_idx[owns_target].long()] += 1.0 + dz *= g[i0:i1].unsqueeze(1) + + if use_fp32: + grad_h[i0:i1] = torch.matmul(dz, weight.float()).float() + grad_w += torch.matmul(dz.t(), x.float()).float() + else: + dz_dt = dz.to(dt) + grad_h[i0:i1] = torch.matmul(dz_dt, weight).float() + grad_w += torch.matmul(dz_dt.t(), x).float() + if grad_b is not None: + grad_b += dz.sum(0) + + dist.all_reduce(grad_h, op=dist.ReduceOp.SUM, group=ctx.tp_group) + grad_hidden = grad_h.to(ctx.hidden_dtype).reshape((*ctx.lead_shape, d)) + grad_weight = grad_w.to(ctx.weight_dtype) + grad_bias = grad_b.to(ctx.bias_dtype) if grad_b is not None else None + return grad_hidden, grad_weight, grad_bias, None, None, None, None + + +def _sm90_tensor_parallel_linear_logp( + hidden: torch.Tensor, + lm_head_weight: torch.Tensor, + target_ids: torch.Tensor, + bias: Optional[torch.Tensor], + *, + tp_group: Any, + vocab_start_index: int, + global_vocab_size: Optional[int], +) -> torch.Tensor: + global _SM90_TP_PATH_LOGGED + if not _SM90_TP_PATH_LOGGED: + logger.info("Using fused_linear_logp_sm90 tensor-parallel local-shard path.") + _SM90_TP_PATH_LOGGED = True + return _FusedTensorParallelLinearLogpSM90Function.apply( + hidden, + lm_head_weight, + bias, + target_ids, + int(vocab_start_index), + None if global_vocab_size is None else int(global_vocab_size), + tp_group, + ) + + class FusedLinearLogpSM90Op: """SM90 (Hopper) TMA+WGMMA fused linear log-prob. @@ -105,8 +271,20 @@ def __call__( lm_head_weight: torch.Tensor, target_ids: torch.Tensor, bias: Optional[torch.Tensor] = None, + *, + tp_group: Any = None, + vocab_start_index: int = 0, + global_vocab_size: Optional[int] = None, ) -> torch.Tensor: - return self.apply(hidden, lm_head_weight, target_ids, bias) + return self.apply( + hidden, + lm_head_weight, + target_ids, + bias, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + global_vocab_size=global_vocab_size, + ) def apply( self, @@ -114,6 +292,10 @@ def apply( lm_head_weight: torch.Tensor, target_ids: torch.Tensor, bias: Optional[torch.Tensor] = None, + *, + tp_group: Any = None, + vocab_start_index: int = 0, + global_vocab_size: Optional[int] = None, ) -> torch.Tensor: if lm_head_weight.size(-1) != hidden.size(-1): raise ValueError( @@ -140,6 +322,31 @@ def apply( raise ValueError( f"bias device {bias.device} must match hidden device {hidden.device}" ) + if should_use_tensor_parallel_linear_logp( + tp_group, + int(vocab_start_index), + global_vocab_size, + lm_head_weight.size(0), + ): + if _sm90_supported(hidden, lm_head_weight): + return _sm90_tensor_parallel_linear_logp( + hidden, + lm_head_weight, + target_ids, + bias, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + global_vocab_size=global_vocab_size, + ) + return tensor_parallel_linear_logp( + hidden, + lm_head_weight, + target_ids, + bias, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + global_vocab_size=global_vocab_size, + ) if not _sm90_supported(hidden, lm_head_weight): return _fallback_op()(hidden, lm_head_weight, target_ids, bias) vocab = lm_head_weight.size(0) diff --git a/rl_engine/kernels/ops/cuda/loss/logp.py b/rl_engine/kernels/ops/cuda/loss/logp.py index 1742b9d..30aaa4a 100644 --- a/rl_engine/kernels/ops/cuda/loss/logp.py +++ b/rl_engine/kernels/ops/cuda/loss/logp.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2026 RL-Kernel Contributors +import os + import torch from rl_engine.kernels.ops.base import _C, _EXT_AVAILABLE @@ -17,14 +19,75 @@ def __init__(self): "Please rebuild extension using 'pip install -e .'" ) self.op = _C.fused_logp_sm90 + self._fallback = None logger.info("Successfully linked to precompiled _C.fused_logp_sm90 kernel.") def __call__(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - assert logits.dtype == torch.bfloat16, "TMA logp currently requires bfloat16 logits" - assert logits.is_contiguous(), "Logits must be contiguous for TMA block loading" + return self.apply(logits, labels) + + def _fallback_op(self): + if self._fallback is None: + self._fallback = FusedLogpGenericOp() + return self._fallback + + def _can_use_sm90(self, logits: torch.Tensor) -> bool: + return ( + os.getenv("RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP") == "1" + and logits.dim() == 2 + and logits.dtype == torch.bfloat16 + and logits.is_contiguous() + ) + + def apply(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + if not self._can_use_sm90(logits): + return self._fallback_op().apply(logits, labels) labels_fused = labels.to(device=logits.device, dtype=torch.int32).contiguous() return self.op(logits, labels_fused) + def apply_fp32(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + return self._fallback_op().apply_fp32(logits, token_ids) + + def out( + self, logits: torch.Tensor, token_ids: torch.Tensor, output: torch.Tensor + ) -> torch.Tensor: + return self._fallback_op().out(logits, token_ids, output) + + def indexed_out( + self, + logits: torch.Tensor, + token_ids: torch.Tensor, + row_indices: torch.Tensor, + output: torch.Tensor, + ) -> torch.Tensor: + return self._fallback_op().indexed_out(logits, token_ids, row_indices, output) + + def indexed_fp32( + self, logits: torch.Tensor, token_ids: torch.Tensor, row_indices: torch.Tensor + ) -> torch.Tensor: + return self._fallback_op().indexed_fp32(logits, token_ids, row_indices) + + def online_out( + self, logits: torch.Tensor, token_ids: torch.Tensor, output: torch.Tensor + ) -> torch.Tensor: + return self._fallback_op().online_out(logits, token_ids, output) + + def online_fp32(self, logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: + return self._fallback_op().online_fp32(logits, token_ids) + + def online_indexed_out( + self, + logits: torch.Tensor, + token_ids: torch.Tensor, + row_indices: torch.Tensor, + output: torch.Tensor, + ) -> torch.Tensor: + return self._fallback_op().online_indexed_out(logits, token_ids, row_indices, output) + + def online_indexed_fp32( + self, logits: torch.Tensor, token_ids: torch.Tensor, row_indices: torch.Tensor + ) -> torch.Tensor: + return self._fallback_op().online_indexed_fp32(logits, token_ids, row_indices) + class FusedLogpGenericOp: """Generic custom CUDA fallback Fused LogP with RL variants.""" diff --git a/rl_engine/kernels/ops/pytorch/loss/linear_logp.py b/rl_engine/kernels/ops/pytorch/loss/linear_logp.py index 5e64955..9bb1704 100644 --- a/rl_engine/kernels/ops/pytorch/loss/linear_logp.py +++ b/rl_engine/kernels/ops/pytorch/loss/linear_logp.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Optional +from typing import Any, Optional import torch @@ -11,6 +11,373 @@ # elements per cuBLAS step so peak backward memory stays ~``chunk*V`` instead of # ``N*V``. BWD_CHUNK_ELEMS = 1 << 24 +_LOW_PRECISION_DTYPES = (torch.float16, torch.bfloat16) + + +def _use_fp32_matmul(*tensors: torch.Tensor) -> bool: + return any(tensor.dtype in _LOW_PRECISION_DTYPES for tensor in tensors) + + +def _matmul_operand(tensor: torch.Tensor, use_fp32: bool) -> torch.Tensor: + return tensor.float() if use_fp32 else tensor + + +def _linear_logits( + hidden_2d: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + *, + use_fp32: bool, +) -> torch.Tensor: + logits = torch.matmul( + _matmul_operand(hidden_2d, use_fp32), + _matmul_operand(weight, use_fp32).t(), + ) + if bias is not None: + logits = logits + _matmul_operand(bias, use_fp32) + return logits + + +def _require_distributed_initialized(): + import torch.distributed as dist + + if not dist.is_available(): + raise RuntimeError("tensor-parallel linear_logp requires torch.distributed.") + if not dist.is_initialized(): + raise RuntimeError("tensor-parallel linear_logp requires an initialized process group.") + return dist + + +def _tensor_parallel_world_size(tp_group: Any) -> int: + if tp_group is None: + return 1 + dist = _require_distributed_initialized() + return dist.get_world_size(group=tp_group) + + +def should_use_tensor_parallel_linear_logp( + tp_group: Any, + vocab_start_index: int, + global_vocab_size: Optional[int], + local_vocab_size: int, +) -> bool: + """Whether a linear_logp call describes a vocab-parallel weight shard.""" + explicit_tp = tp_group is not None or vocab_start_index != 0 or global_vocab_size is not None + if local_vocab_size <= 0 and not explicit_tp: + raise ValueError("lm_head_weight must contain at least one vocab row.") + if not explicit_tp: + return False + + world_size = _tensor_parallel_world_size(tp_group) + if local_vocab_size <= 0 and world_size <= 1: + raise ValueError("lm_head_weight must contain at least one vocab row.") + if world_size <= 1: + if vocab_start_index != 0: + raise ValueError("vocab_start_index requires a tensor-parallel group.") + if global_vocab_size is not None and int(global_vocab_size) != local_vocab_size: + raise ValueError( + "global_vocab_size differs from the local vocab size, but no " + "multi-rank tensor-parallel group was provided." + ) + return False + return True + + +def _validate_tp_vocab_partition( + *, + tp_group: Any, + device: torch.device, + vocab_start_index: int, + local_vocab_size: int, + global_vocab_size: Optional[int], +) -> int: + dist = _require_distributed_initialized() + local_end = vocab_start_index + local_vocab_size + local_range = torch.tensor([vocab_start_index, local_end], device=device, dtype=torch.long) + ranges_t = [torch.empty_like(local_range) for _ in range(dist.get_world_size(tp_group))] + dist.all_gather(ranges_t, local_range, group=tp_group) + + ranges = sorted((int(r[0].item()), int(r[1].item())) for r in ranges_t) + expected_start = 0 + for start, end in ranges: + if end <= start: + raise ValueError(f"invalid TP vocab shard range [{start}, {end}).") + if start != expected_start: + raise ValueError( + "TP vocab shards must form a contiguous [0, V) partition; " f"got ranges={ranges}." + ) + expected_start = end + + covered_vocab_size = expected_start + global_size = torch.tensor( + [0, 0 if global_vocab_size is None else int(global_vocab_size)], + device=device, + dtype=torch.long, + ) + global_sizes_t = [torch.empty_like(global_size) for _ in range(dist.get_world_size(tp_group))] + dist.all_gather(global_sizes_t, global_size, group=tp_group) + invalid_sizes = [ + int(value[1].item()) + for value in global_sizes_t + if int(value[0].item()) and int(value[1].item()) != covered_vocab_size + ] + if invalid_sizes: + raise ValueError( + "global_vocab_size must match the TP vocab partition size: " + f"got {invalid_sizes[0]}, covered {covered_vocab_size}." + ) + return covered_vocab_size if global_vocab_size is None else int(global_vocab_size) + + +def _validate_global_targets( + target_1d: torch.Tensor, + global_vocab_size: int, + tp_group: Any = None, +) -> None: + invalid = (target_1d < 0) | (target_1d >= global_vocab_size) + local_invalid = bool(invalid.any().item()) + if tp_group is not None: + dist = _require_distributed_initialized() + invalid_flag = torch.tensor(int(local_invalid), device=target_1d.device, dtype=torch.int32) + dist.all_reduce(invalid_flag, op=dist.ReduceOp.MAX, group=tp_group) + if target_1d.numel(): + min_target = torch.tensor( + int(target_1d.min().item()), device=target_1d.device, dtype=torch.long + ) + max_target = torch.tensor( + int(target_1d.max().item()), device=target_1d.device, dtype=torch.long + ) + else: + min_target = torch.tensor(global_vocab_size, device=target_1d.device, dtype=torch.long) + max_target = torch.tensor(-1, device=target_1d.device, dtype=torch.long) + dist.all_reduce(min_target, op=dist.ReduceOp.MIN, group=tp_group) + dist.all_reduce(max_target, op=dist.ReduceOp.MAX, group=tp_group) + local_invalid = bool(invalid_flag.item()) + t_min, t_max = int(min_target.item()), int(max_target.item()) + elif local_invalid: + t_min, t_max = int(target_1d.min().item()), int(target_1d.max().item()) + if local_invalid: + raise ValueError( + f"target_ids out of range: expected [0, {global_vocab_size - 1}], " + f"got [{t_min}, {t_max}]. Mask or filter padding / ignore-index values " + "(e.g. -100) before this op." + ) + + +def _chunked_local_linear_logp_stats( + hidden_2d: torch.Tensor, + weight: torch.Tensor, + target_1d: torch.Tensor, + bias_t: torch.Tensor, + *, + has_bias: bool, + vocab_start_index: int, + chunk_elems: int = BWD_CHUNK_ELEMS, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + n = hidden_2d.size(0) + local_vocab = weight.size(0) + device = hidden_2d.device + + local_max = torch.full((n,), -torch.inf, device=device, dtype=torch.float32) + local_sum = torch.zeros(n, device=device, dtype=torch.float32) + local_target_logit = torch.zeros(n, device=device, dtype=torch.float32) + owner_count = torch.zeros(n, device=device, dtype=torch.int32) + rows = torch.arange(n, device=device) + use_fp32 = _use_fp32_matmul(hidden_2d, weight) + + vocab_chunk = max(1, min(local_vocab, chunk_elems // max(n, 1))) + for v0 in range(0, local_vocab, vocab_chunk): + v1 = min(v0 + vocab_chunk, local_vocab) + logits = _linear_logits( + hidden_2d, + weight[v0:v1], + bias_t[v0:v1] if has_bias else None, + use_fp32=use_fp32, + ) + logits_f = logits.float() + + tile_max = logits_f.max(dim=-1).values + new_max = torch.maximum(local_max, tile_max) + local_sum = local_sum * torch.exp(local_max - new_max) + torch.exp( + logits_f - new_max.unsqueeze(1) + ).sum(dim=-1) + local_max = new_max + + global_v0 = vocab_start_index + v0 + global_v1 = vocab_start_index + v1 + owns_target = (target_1d >= global_v0) & (target_1d < global_v1) + if bool(owns_target.any().item()): + local_idx = (target_1d[owns_target] - global_v0).long() + local_target_logit[owns_target] = logits_f[rows[owns_target], local_idx] + owner_count[owns_target] += 1 + + return local_max, local_sum, local_target_logit, owner_count + + +class _TensorParallelLinearLogpFunction(torch.autograd.Function): + """Autograd path for vocab-sharded LM-head tensor parallelism.""" + + @staticmethod + def forward( + ctx, + hidden, + lm_head_weight, + bias, + target_ids, + vocab_start_index, + global_vocab_size, + tp_group, + ): + dist = _require_distributed_initialized() + hidden_2d = hidden.reshape(-1, hidden.size(-1)).contiguous() + weight = lm_head_weight.contiguous() + target_1d = ( + target_ids.reshape(-1).to(device=hidden_2d.device, dtype=torch.long).contiguous() + ) + bias_t = bias.contiguous() if bias is not None else hidden_2d + vocab_start_index = int(vocab_start_index) + global_vocab_size = _validate_tp_vocab_partition( + tp_group=tp_group, + device=hidden_2d.device, + vocab_start_index=vocab_start_index, + local_vocab_size=weight.size(0), + global_vocab_size=global_vocab_size, + ) + _validate_global_targets(target_1d, global_vocab_size, tp_group) + + local_max, local_sum, local_target_logit, owner_count = _chunked_local_linear_logp_stats( + hidden_2d, + weight, + target_1d, + bias_t, + has_bias=bias is not None, + vocab_start_index=vocab_start_index, + ) + + global_max = local_max.clone() + dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group) + global_sum = local_sum * torch.exp(local_max - global_max) + dist.all_reduce(global_sum, op=dist.ReduceOp.SUM, group=tp_group) + + target_logit = local_target_logit.clone() + dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=tp_group) + + global_owner_count = owner_count.clone() + dist.all_reduce(global_owner_count, op=dist.ReduceOp.SUM, group=tp_group) + if bool((global_owner_count != 1).any().item()): + raise ValueError( + "target_ids must be covered by exactly one TP vocab shard; check " + "vocab_start_index and global_vocab_size." + ) + + lse = global_max + torch.log(global_sum) + ctx.save_for_backward(hidden_2d, weight, bias_t, target_1d, lse) + ctx.has_bias = bias is not None + ctx.lead_shape = hidden.shape[:-1] + ctx.hidden_dtype = hidden.dtype + ctx.weight_dtype = lm_head_weight.dtype + ctx.bias_dtype = bias.dtype if bias is not None else None + ctx.vocab_start_index = vocab_start_index + ctx.tp_group = tp_group + return (target_logit - lse).reshape(hidden.shape[:-1]) + + @staticmethod + def backward(ctx, grad_logp): + dist = _require_distributed_initialized() + hidden_2d, weight, bias_t, target_1d, lse = ctx.saved_tensors + n, d = hidden_2d.shape + local_vocab = weight.shape[0] + dt = weight.dtype + g = grad_logp.reshape(-1).to(torch.float32) + + grad_h = torch.empty_like(hidden_2d, dtype=torch.float32) + grad_w = torch.zeros(local_vocab, d, device=weight.device, dtype=torch.float32) + grad_b = ( + torch.zeros(local_vocab, device=weight.device, dtype=torch.float32) + if ctx.has_bias + else None + ) + use_fp32 = _use_fp32_matmul(hidden_2d, weight) + + chunk = max(1, min(n, BWD_CHUNK_ELEMS // local_vocab)) + for i0 in range(0, n, chunk): + i1 = min(i0 + chunk, n) + x = hidden_2d[i0:i1] + logits = _linear_logits( + x, + weight, + bias_t if ctx.has_bias else None, + use_fp32=use_fp32, + ) + + dz = -torch.exp(logits.float() - lse[i0:i1].unsqueeze(1)) + local_idx = target_1d[i0:i1] - ctx.vocab_start_index + owns_target = (local_idx >= 0) & (local_idx < local_vocab) + if bool(owns_target.any().item()): + rows = torch.arange(i1 - i0, device=dz.device)[owns_target] + dz[rows, local_idx[owns_target].long()] += 1.0 + dz *= g[i0:i1].unsqueeze(1) + + if use_fp32: + grad_h[i0:i1] = torch.matmul(dz, weight.float()).float() + grad_w += torch.matmul(dz.t(), x.float()).float() + else: + dz_dt = dz.to(dt) + grad_h[i0:i1] = torch.matmul(dz_dt, weight).float() + grad_w += torch.matmul(dz_dt.t(), x).float() + if grad_b is not None: + grad_b += dz.sum(0) + + dist.all_reduce(grad_h, op=dist.ReduceOp.SUM, group=ctx.tp_group) + grad_hidden = grad_h.to(ctx.hidden_dtype).reshape((*ctx.lead_shape, d)) + grad_weight = grad_w.to(ctx.weight_dtype) + grad_bias = grad_b.to(ctx.bias_dtype) if grad_b is not None else None + return grad_hidden, grad_weight, grad_bias, None, None, None, None + + +def tensor_parallel_linear_logp( + hidden: torch.Tensor, + lm_head_weight: torch.Tensor, + target_ids: torch.Tensor, + bias: Optional[torch.Tensor] = None, + *, + tp_group: Any, + vocab_start_index: int = 0, + global_vocab_size: Optional[int] = None, +) -> torch.Tensor: + if hidden.shape[:-1] != target_ids.shape: + raise ValueError( + f"hidden leading shape {tuple(hidden.shape[:-1])} must match " + f"target_ids shape {tuple(target_ids.shape)}" + ) + if lm_head_weight.size(-1) != hidden.size(-1): + raise ValueError( + f"hidden dim {hidden.size(-1)} must match lm_head_weight dim " + f"{lm_head_weight.size(-1)}" + ) + if lm_head_weight.device != hidden.device: + raise ValueError( + f"lm_head_weight device {lm_head_weight.device} must match hidden " + f"device {hidden.device}" + ) + if bias is not None: + if bias.ndim != 1 or bias.numel() != lm_head_weight.size(0): + raise ValueError( + f"bias must be 1-D with local V={lm_head_weight.size(0)} elements, " + f"got shape {tuple(bias.shape)}" + ) + if bias.device != hidden.device: + raise ValueError(f"bias device {bias.device} must match hidden device {hidden.device}") + + return _TensorParallelLinearLogpFunction.apply( + hidden, + lm_head_weight, + bias, + target_ids, + int(vocab_start_index), + None if global_vocab_size is None else int(global_vocab_size), + tp_group, + ) def chunked_linear_logp_backward( @@ -36,14 +403,18 @@ def chunked_linear_logp_backward( grad_h = torch.empty_like(hidden_2d, dtype=torch.float32) grad_w = torch.zeros(v, d, device=weight.device, dtype=torch.float32) grad_b = torch.zeros(v, device=weight.device, dtype=torch.float32) if has_bias else None + use_fp32 = _use_fp32_matmul(hidden_2d, weight) chunk = max(1, min(n, chunk_elems // v)) for i0 in range(0, n, chunk): i1 = min(i0 + chunk, n) x = hidden_2d[i0:i1] # [C, D] - logits = torch.matmul(x, weight.t()) # [C, V] - if has_bias: - logits = logits + bias_t + logits = _linear_logits( + x, + weight, + bias_t if has_bias else None, + use_fp32=use_fp32, + ) # dz = g * (onehot - softmax(logits)), recomputed from scratch so it is # self-normalizing and independent of the forward's saved lse. @@ -52,9 +423,13 @@ def chunked_linear_logp_backward( dz[rows, target_1d[i0:i1].long()] += 1.0 dz *= g[i0:i1].unsqueeze(1) - dz_dt = dz.to(dt) - grad_h[i0:i1] = torch.matmul(dz_dt, weight).float() # [C, D] - grad_w += torch.matmul(dz_dt.t(), x).float() # [V, D] + if use_fp32: + grad_h[i0:i1] = torch.matmul(dz, weight.float()).float() # [C, D] + grad_w += torch.matmul(dz.t(), x.float()).float() # [V, D] + else: + dz_dt = dz.to(dt) + grad_h[i0:i1] = torch.matmul(dz_dt, weight).float() # [C, D] + grad_w += torch.matmul(dz_dt.t(), x).float() # [V, D] if grad_b is not None: grad_b += dz.sum(0) @@ -83,8 +458,20 @@ def __call__( lm_head_weight: torch.Tensor, target_ids: torch.Tensor, bias: Optional[torch.Tensor] = None, + *, + tp_group: Any = None, + vocab_start_index: int = 0, + global_vocab_size: Optional[int] = None, ) -> torch.Tensor: - return self.apply(hidden, lm_head_weight, target_ids, bias) + return self.apply( + hidden, + lm_head_weight, + target_ids, + bias, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + global_vocab_size=global_vocab_size, + ) def apply( self, @@ -92,6 +479,10 @@ def apply( lm_head_weight: torch.Tensor, target_ids: torch.Tensor, bias: Optional[torch.Tensor] = None, + *, + tp_group: Any = None, + vocab_start_index: int = 0, + global_vocab_size: Optional[int] = None, ) -> torch.Tensor: """Selected-token log-prob ``z[t] - logsumexp(z)``, returned in float32.""" if hidden.shape[:-1] != target_ids.shape: @@ -104,6 +495,21 @@ def apply( f"hidden dim {hidden.size(-1)} must match lm_head_weight dim " f"{lm_head_weight.size(-1)}" ) + if should_use_tensor_parallel_linear_logp( + tp_group, + int(vocab_start_index), + global_vocab_size, + lm_head_weight.size(0), + ): + return tensor_parallel_linear_logp( + hidden, + lm_head_weight, + target_ids, + bias, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + global_vocab_size=global_vocab_size, + ) lead_shape = hidden.shape[:-1] hidden_2d = hidden.reshape(-1, hidden.size(-1)) diff --git a/rl_engine/kernels/ops/triton/loss/linear_logp.py b/rl_engine/kernels/ops/triton/loss/linear_logp.py index 4ef9ce1..d3a435c 100644 --- a/rl_engine/kernels/ops/triton/loss/linear_logp.py +++ b/rl_engine/kernels/ops/triton/loss/linear_logp.py @@ -2,13 +2,17 @@ # Copyright (c) 2026 RL-Kernel Contributors from __future__ import annotations -from typing import Optional +from typing import Any, Optional import torch import triton import triton.language as tl -from rl_engine.kernels.ops.pytorch.loss.linear_logp import chunked_linear_logp_backward +from rl_engine.kernels.ops.pytorch.loss.linear_logp import ( + chunked_linear_logp_backward, + should_use_tensor_parallel_linear_logp, + tensor_parallel_linear_logp, +) # Token / vocab / hidden tile sizes (forward Triton kernel). _BLOCK_N = 32 @@ -165,8 +169,20 @@ def __call__( lm_head_weight: torch.Tensor, target_ids: torch.Tensor, bias: Optional[torch.Tensor] = None, + *, + tp_group: Any = None, + vocab_start_index: int = 0, + global_vocab_size: Optional[int] = None, ) -> torch.Tensor: - return self.apply(hidden, lm_head_weight, target_ids, bias) + return self.apply( + hidden, + lm_head_weight, + target_ids, + bias, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + global_vocab_size=global_vocab_size, + ) def apply( self, @@ -174,6 +190,10 @@ def apply( lm_head_weight: torch.Tensor, target_ids: torch.Tensor, bias: Optional[torch.Tensor] = None, + *, + tp_group: Any = None, + vocab_start_index: int = 0, + global_vocab_size: Optional[int] = None, ) -> torch.Tensor: if hidden.device.type not in ("cuda", "xpu", "hip"): raise RuntimeError( @@ -190,6 +210,21 @@ def apply( f"hidden dim {hidden.size(-1)} must match lm_head_weight dim " f"{lm_head_weight.size(-1)}" ) + if should_use_tensor_parallel_linear_logp( + tp_group, + int(vocab_start_index), + global_vocab_size, + lm_head_weight.size(0), + ): + return tensor_parallel_linear_logp( + hidden, + lm_head_weight, + target_ids, + bias, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + global_vocab_size=global_vocab_size, + ) vocab = lm_head_weight.size(0) if bool(((target_ids < 0) | (target_ids >= vocab)).any()): t_min, t_max = int(target_ids.min()), int(target_ids.max()) diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 6780157..058eb6d 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -135,8 +135,7 @@ def _adjust_priority_from_env(self): ) def _adjust_priority_for_hardware(self): - """Prioritize the fused TMA LogP kernel only when it is compiled into the - extension and the device is TMA-capable (SM90/100/120).""" + """Adjust CUDA priorities for hardware-gated experimental and production kernels.""" if device_ctx.device_type != "cuda": return try: @@ -148,10 +147,11 @@ def _adjust_priority_for_hardware(self): cc = cc_major * 10 + cc_minor tma_compiled = _EXT_AVAILABLE and hasattr(_C, "fused_logp_sm90") - if tma_compiled and cc_major in (9, 10, 12): + sm90_logp_enabled = os.getenv("RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP") == "1" + if sm90_logp_enabled and tma_compiled and cc_major in (9, 10, 12): logger.info( f"Detected TMA-capable architecture (SM{cc}); " - "prioritizing fused TMA LogP kernel." + "prioritizing experimental fused TMA LogP kernel." ) logp_list = self._priority_map["cuda"]["logp"] if OpBackend.CUDA_FUSED_LOGP_SM90 not in logp_list: @@ -166,8 +166,8 @@ def _adjust_priority_for_hardware(self): ll_list.insert(0, OpBackend.CUDA_FUSED_LINEAR_LOGP_SM90) elif cc >= 90: logger.debug( - f"SM{cc}: fused TMA LogP kernel not compiled into _C; " - "using generic fused kernel." + f"SM{cc}: fused linear-logp SM90 kernel not compiled into _C; " + "using generic linear-logp backend." ) except Exception as e: logger.warning(f"Failed to probe device capability: {e}") diff --git a/scripts/test_linear_logp_tp.py b/scripts/test_linear_logp_tp.py new file mode 100644 index 0000000..87dc379 --- /dev/null +++ b/scripts/test_linear_logp_tp.py @@ -0,0 +1,484 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +"""Distributed TP validation for the fused linear_logp operator. + +Launch with torchrun, for example: + + torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py + +The correctness phase compares the tensor-parallel path against a materialized +full-vocab reference. The optional stress phase skips the full reference and only +checks that larger vocab-sharded runs complete with finite outputs/gradients. +""" + +from __future__ import annotations + +import argparse +import os +import sys +import time +from pathlib import Path +from typing import Any, Optional + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + + +def _parse_dtype(name: str) -> torch.dtype: + lowered = name.lower() + if lowered in {"bf16", "bfloat16"}: + return torch.bfloat16 + if lowered in {"fp16", "float16", "half"}: + return torch.float16 + if lowered in {"fp32", "float32", "float"}: + return torch.float32 + raise ValueError(f"unsupported dtype: {name}") + + +def _dtype_default_atol(dtype: torch.dtype, reference_mode: str) -> float: + if dtype == torch.float32: + return 1e-4 + if reference_mode == "fp32": + return 8e-2 + return 3e-2 + + +def _dtype_default_rtol(dtype: torch.dtype, reference_mode: str) -> float: + if dtype == torch.float32: + return 1e-4 + if reference_mode == "fp32": + return 8e-2 + return 3e-2 + + +def _rank_env() -> tuple[int, int, int]: + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + return rank, local_rank, world_size + + +def _init_distributed() -> tuple[int, int, int, torch.device, str]: + if not dist.is_available(): + raise RuntimeError("torch.distributed is not available in this PyTorch build.") + + rank, local_rank, world_size = _rank_env() + if world_size < 2: + raise RuntimeError("Run this script with torchrun and at least 2 processes.") + + if torch.cuda.is_available(): + if local_rank >= torch.cuda.device_count(): + raise RuntimeError( + f"LOCAL_RANK={local_rank} but only {torch.cuda.device_count()} CUDA devices exist." + ) + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + backend = "nccl" + else: + device = torch.device("cpu") + backend = "gloo" + + if not dist.is_initialized(): + dist.init_process_group(backend=backend) + return rank, local_rank, world_size, device, backend + + +def _print_rank0(rank: int, message: str) -> None: + if rank == 0: + print(message, flush=True) + + +def _generator(device: torch.device, seed: int) -> torch.Generator: + return torch.Generator(device=device).manual_seed(seed) + + +def _make_boundaries(vocab_size: int, world_size: int, uneven: bool) -> list[int]: + if vocab_size < world_size: + raise ValueError("vocab_size must be >= world_size so every rank has a shard.") + + sizes = [vocab_size // world_size for _ in range(world_size)] + sizes[-1] += vocab_size % world_size + + if uneven and world_size > 1: + for rank in range(world_size - 1): + move = min(rank + 1, sizes[rank] - 1) + sizes[rank] -= move + sizes[-1] += move + + boundaries = [0] + for size in sizes: + boundaries.append(boundaries[-1] + size) + if boundaries[-1] != vocab_size: + raise AssertionError("internal shard boundary construction failed") + return boundaries + + +def _materialized_logp( + hidden: torch.Tensor, + weight: torch.Tensor, + target: torch.Tensor, + bias: Optional[torch.Tensor], + *, + reference_mode: str, +) -> torch.Tensor: + if reference_mode == "fp32": + logits = F.linear(hidden.float(), weight.float(), None if bias is None else bias.float()) + else: + logits = F.linear(hidden, weight, bias).float() + target_1d = target.reshape(-1).to(device=logits.device, dtype=torch.long) + selected = torch.gather( + torch.log_softmax(logits.float(), dim=-1), + dim=-1, + index=target_1d.unsqueeze(1), + ).squeeze(1) + return selected.reshape(target.shape) + + +def _load_op(source: str) -> Any: + if source == "registry": + from rl_engine.kernels.registry import kernel_registry + + return kernel_registry.get_op("linear_logp") + if source == "native": + from rl_engine.kernels.ops.pytorch.loss.linear_logp import NativeLinearLogpOp + + return NativeLinearLogpOp() + if source == "triton": + from rl_engine.kernels.ops.triton.loss.linear_logp import TritonLinearLogpOp + + return TritonLinearLogpOp() + if source == "sm90": + from rl_engine.kernels.ops.cuda.loss.linear_logp import FusedLinearLogpSM90Op + + return FusedLinearLogpSM90Op() + raise ValueError(f"unknown op source: {source}") + + +def _max_abs(actual: torch.Tensor, expected: torch.Tensor) -> float: + return float((actual.float() - expected.float()).abs().max().item()) + + +def _max_rel(actual: torch.Tensor, expected: torch.Tensor) -> float: + diff = (actual.float() - expected.float()).abs() + denom = expected.float().abs().clamp_min(1e-8) + return float((diff / denom).max().item()) + + +def _reduce_max(value: float, device: torch.device) -> float: + tensor = torch.tensor(value, device=device, dtype=torch.float64) + dist.all_reduce(tensor, op=dist.ReduceOp.MAX) + return float(tensor.item()) + + +def _reduce_min_int(value: bool, device: torch.device) -> bool: + tensor = torch.tensor(1 if value else 0, device=device, dtype=torch.int32) + dist.all_reduce(tensor, op=dist.ReduceOp.MIN) + return bool(tensor.item()) + + +def _synchronize(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.synchronize(device) + + +def _reset_peak_memory(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + + +def _peak_memory_gb(device: torch.device) -> float: + if device.type != "cuda": + return 0.0 + return torch.cuda.max_memory_allocated(device) / (1024**3) + + +def _time_block(device: torch.device, fn) -> float: + _synchronize(device) + start = time.perf_counter() + fn() + _synchronize(device) + return (time.perf_counter() - start) * 1000.0 + + +def _check_metric( + *, + name: str, + actual: torch.Tensor, + expected: torch.Tensor, + atol: float, + rtol: float, + device: torch.device, +) -> tuple[bool, str]: + local_abs = _max_abs(actual, expected) + local_rel = _max_rel(actual, expected) + local_ok = bool(torch.allclose(actual.float(), expected.float(), atol=atol, rtol=rtol)) + max_abs = _reduce_max(local_abs, device) + max_rel = _reduce_max(local_rel, device) + ok = _reduce_min_int(local_ok, device) + return ok, f"{name}: max_abs={max_abs:.6e}, max_rel={max_rel:.6e}" + + +def run_correctness(args, rank: int, world_size: int, device: torch.device, op: Any) -> bool: + dtype = _parse_dtype(args.dtype) + boundaries = _make_boundaries(args.vocab_size, world_size, args.uneven_shards) + start, end = boundaries[rank], boundaries[rank + 1] + + gen = _generator(device, args.seed) + hidden = torch.randn( + args.tokens, + args.hidden_size, + generator=gen, + device=device, + dtype=dtype, + ) + weight = torch.randn( + args.vocab_size, + args.hidden_size, + generator=gen, + device=device, + dtype=dtype, + ) + bias = ( + torch.randn(args.vocab_size, generator=gen, device=device, dtype=dtype) + if not args.no_bias + else None + ) + target = torch.randint(0, args.vocab_size, (args.tokens,), generator=gen, device=device) + grad_out = torch.randn(args.tokens, generator=gen, device=device, dtype=torch.float32) + + ref_hidden = hidden.detach().clone().requires_grad_(True) + ref_weight = weight.detach().clone().requires_grad_(True) + ref_bias = bias.detach().clone().requires_grad_(True) if bias is not None else None + ref_out = _materialized_logp( + ref_hidden, + ref_weight, + target, + ref_bias, + reference_mode=args.reference_mode, + ) + (ref_out * grad_out).sum().backward() + + tp_hidden = hidden.detach().clone().requires_grad_(True) + local_weight = weight[start:end].detach().clone().requires_grad_(True) + local_bias = bias[start:end].detach().clone().requires_grad_(True) if bias is not None else None + + tp_out = op( + tp_hidden, + local_weight, + target, + local_bias, + tp_group=dist.group.WORLD, + vocab_start_index=start, + global_vocab_size=args.vocab_size, + ) + (tp_out * grad_out).sum().backward() + + atol = args.atol if args.atol is not None else _dtype_default_atol(dtype, args.reference_mode) + rtol = args.rtol if args.rtol is not None else _dtype_default_rtol(dtype, args.reference_mode) + local_bias_ref = ref_bias.grad[start:end] if ref_bias is not None else None + + checks = [ + _check_metric( + name="output", + actual=tp_out, + expected=ref_out, + atol=atol, + rtol=rtol, + device=device, + ), + _check_metric( + name="hidden_grad", + actual=tp_hidden.grad, + expected=ref_hidden.grad, + atol=atol, + rtol=rtol, + device=device, + ), + _check_metric( + name="weight_grad", + actual=local_weight.grad, + expected=ref_weight.grad[start:end], + atol=atol, + rtol=rtol, + device=device, + ), + ] + if local_bias is not None and local_bias_ref is not None: + checks.append( + _check_metric( + name="bias_grad", + actual=local_bias.grad, + expected=local_bias_ref, + atol=atol, + rtol=rtol, + device=device, + ) + ) + + if rank == 0: + print("\n[correctness]") + print(f" dtype={dtype}, reference_mode={args.reference_mode}, atol={atol}, rtol={rtol}") + print(f" tokens={args.tokens}, hidden={args.hidden_size}, vocab={args.vocab_size}") + print(f" shard_boundaries={boundaries}") + for ok, line in checks: + print(f" {'PASS' if ok else 'FAIL'} {line}") + + return all(ok for ok, _ in checks) + + +def run_stress(args, rank: int, world_size: int, device: torch.device, op: Any) -> bool: + dtype = _parse_dtype(args.dtype) + boundaries = _make_boundaries(args.stress_vocab_size, world_size, args.uneven_shards) + start, end = boundaries[rank], boundaries[rank + 1] + + hidden_gen = _generator(device, args.seed + 1000) + local_gen = _generator(device, args.seed + 2000 + rank) + target_gen = _generator(device, args.seed + 3000) + + hidden = torch.randn( + args.stress_tokens, + args.stress_hidden_size, + generator=hidden_gen, + device=device, + dtype=dtype, + requires_grad=True, + ) + local_weight = torch.randn( + end - start, + args.stress_hidden_size, + generator=local_gen, + device=device, + dtype=dtype, + requires_grad=True, + ) + local_bias = None + if not args.no_bias: + local_bias = torch.randn( + end - start, + generator=local_gen, + device=device, + dtype=dtype, + requires_grad=True, + ) + target = torch.randint( + 0, + args.stress_vocab_size, + (args.stress_tokens,), + generator=target_gen, + device=device, + ) + + def step() -> torch.Tensor: + out = op( + hidden, + local_weight, + target, + local_bias, + tp_group=dist.group.WORLD, + vocab_start_index=start, + global_vocab_size=args.stress_vocab_size, + ) + loss = out.float().mean() + loss.backward() + return out + + stress_out = None + + def timed_step() -> None: + nonlocal stress_out + stress_out = step() + + _reset_peak_memory(device) + elapsed_ms = _time_block(device, timed_step) + + finite_tensor = ( + torch.isfinite(stress_out).all() + & torch.isfinite(hidden.grad).all() + & torch.isfinite(local_weight.grad).all() + ) + if local_bias is not None: + finite_tensor = finite_tensor & torch.isfinite(local_bias.grad).all() + finite = _reduce_min_int(bool(finite_tensor.item()), device) + + peak_gb = _reduce_max(_peak_memory_gb(device), device) + elapsed_ms = _reduce_max(elapsed_ms, device) + if rank == 0: + print("\n[stress]") + print( + " tokens=%d, hidden=%d, vocab=%d, dtype=%s" + % (args.stress_tokens, args.stress_hidden_size, args.stress_vocab_size, dtype) + ) + print(f" shard_boundaries={boundaries}") + print(f" finite={'PASS' if finite else 'FAIL'}") + print(f" max_rank_elapsed_ms={elapsed_ms:.3f}") + if device.type == "cuda": + print(f" max_rank_peak_memory_gb={peak_gb:.3f}") + return finite + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--op-source", + choices=["registry", "native", "triton", "sm90"], + default="registry", + ) + parser.add_argument("--dtype", default="bf16", help="bf16, fp16, or fp32") + parser.add_argument("--reference-mode", choices=["matching", "fp32"], default="matching") + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--tokens", type=int, default=128) + parser.add_argument("--hidden-size", type=int, default=256) + parser.add_argument("--vocab-size", type=int, default=4096) + parser.add_argument("--no-bias", action="store_true") + parser.add_argument("--uneven-shards", action="store_true") + parser.add_argument("--atol", type=float, default=None) + parser.add_argument("--rtol", type=float, default=None) + parser.add_argument("--run-stress", action="store_true") + parser.add_argument("--stress-tokens", type=int, default=4096) + parser.add_argument("--stress-hidden-size", type=int, default=2048) + parser.add_argument("--stress-vocab-size", type=int, default=32768) + return parser + + +def main() -> None: + args = build_parser().parse_args() + rank, local_rank, world_size, device, backend = _init_distributed() + try: + _print_rank0(rank, "[env]") + _print_rank0(rank, f" backend={backend}, world_size={world_size}") + _print_rank0(rank, f" torch={torch.__version__}, cuda={torch.version.cuda}") + if device.type == "cuda": + name = torch.cuda.get_device_name(device) + cc = torch.cuda.get_device_capability(device) + _print_rank0(rank, f" rank0_device={name}, capability=sm_{cc[0]}{cc[1]}") + _print_rank0(rank, f" op_source={args.op_source}") + + op = _load_op(args.op_source) + dist.barrier() + ok = run_correctness(args, rank, world_size, device, op) + if args.run_stress: + dist.barrier() + ok = run_stress(args, rank, world_size, device, op) and ok + + ok = _reduce_min_int(ok, device) + dist.barrier() + if rank == 0: + print("\n[result]") + print(" PASS" if ok else " FAIL") + if not ok: + raise SystemExit(1) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + _ = local_rank + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index d5ddb89..8aca85d 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,10 @@ def get_extensions(): return [] extensions = [] + torch_lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib") + torch_rpath = ["-Wl,-rpath,$ORIGIN/../torch/lib"] + if os.environ.get("KERNEL_ALIGN_DEV_RPATH") == "1": + torch_rpath.append(f"-Wl,-rpath,{torch_lib_dir}") is_rocm = torch.version.hip is not None if is_rocm and ROCMExtension is not None: @@ -51,6 +55,7 @@ def get_extensions(): "cxx": ["-O3", "-std=c++17"], "hipcc": ["-O3", "--use_fast_math", "-Xhipcc", "-compress-all"], }, + extra_link_args=list(torch_rpath), ) ) elif torch.cuda.is_available(): @@ -108,7 +113,7 @@ def get_extensions(): nvcc_flags.append("-lineinfo") cxx_flags = ["-O3", "-std=c++17", "-DKERNEL_ALIGN_WITH_CUDA"] - extra_link_args = [] + extra_link_args = list(torch_rpath) sm90_srcs = [ "csrc/cuda/fused_logp_sm90.cu", diff --git a/tests/test_deepspeed_training_worker.py b/tests/test_deepspeed_training_worker.py index e093444..35caefa 100644 --- a/tests/test_deepspeed_training_worker.py +++ b/tests/test_deepspeed_training_worker.py @@ -8,12 +8,15 @@ import os import sys import time +from dataclasses import replace import pytest import torch from rl_engine.executors.bridge import LocalTensorCopyBridge, WeightBridgeUnavailableError from rl_engine.executors.training_contract import RolloutStageResult +from rl_engine.kernels.ops.pytorch.loss.linear_logp import NativeLinearLogpOp +from rl_engine.testing import make_synthetic_rl_kernel_batch, selected_logprobs_reference class FakeDeepSpeedEngine: @@ -56,16 +59,25 @@ def initialize(self, **kwargs): class FakeGatheredParameters: calls = 0 + active = 0 + max_active = 0 + modifier_ranks = [] + parameter_counts = [] def __init__(self, parameters, modifier_rank=0): self.parameters = list(parameters) self.modifier_rank = modifier_rank + type(self).modifier_ranks.append(modifier_rank) + type(self).parameter_counts.append(len(self.parameters)) def __enter__(self): type(self).calls += 1 + type(self).active += 1 + type(self).max_active = max(type(self).max_active, type(self).active) return self.parameters def __exit__(self, exc_type, exc, traceback): + type(self).active -= 1 return False @@ -78,6 +90,10 @@ def _install_fake_deepspeed(monkeypatch): def _install_fake_deepspeed_with_gather(monkeypatch): fake = FakeDeepSpeedModule() FakeGatheredParameters.calls = 0 + FakeGatheredParameters.active = 0 + FakeGatheredParameters.max_active = 0 + FakeGatheredParameters.modifier_ranks = [] + FakeGatheredParameters.parameter_counts = [] fake.zero = type("FakeZeroNamespace", (), {"GatheredParameters": FakeGatheredParameters})() monkeypatch.setitem(sys.modules, "deepspeed", fake) return fake @@ -98,6 +114,24 @@ def _rollout(iteration=2, weight_version=9): ) +class SpyLinearLogpOp: + def __init__(self): + self.calls = [] + self._delegate = NativeLinearLogpOp() + + def __call__(self, hidden, lm_head_weight, target_ids, bias=None, **kwargs): + self.calls.append( + { + "hidden": hidden.detach().clone(), + "lm_head_weight": lm_head_weight.detach().clone(), + "target_ids": target_ids.detach().clone(), + "bias": None if bias is None else bias.detach().clone(), + "kwargs": dict(kwargs), + } + ) + return self._delegate(hidden, lm_head_weight, target_ids, bias, **kwargs) + + def test_importing_module_does_not_import_deepspeed(monkeypatch): monkeypatch.delitem(sys.modules, "deepspeed", raising=False) @@ -218,9 +252,12 @@ def test_deepspeed_training_worker_uses_engine_backward_and_step(monkeypatch): assert result.consumed_weight_version == 9 assert result.published_weight_version == 10 assert result.metrics["training_backend"] == "deepspeed" + assert result.metrics["deepspeed_zero_stage"] == 1 assert result.metrics["training_data_source"] == "rollout_payload" assert result.metrics["rollout_sequences"] == 2 assert result.metrics["rollout_tokens"] == 6 + assert result.metrics["current_logp_path"] == "linear_logp" + assert result.metrics["current_logp_backend"] == "NativeLinearLogpOp" assert math.isfinite(result.metrics["loss"]) assert "advantage_mean" not in result.metrics assert "advantage_std" not in result.metrics @@ -228,6 +265,309 @@ def test_deepspeed_training_worker_uses_engine_backward_and_step(monkeypatch): assert result.metrics["active_advantage_std_global"] >= 0.0 +def test_extract_logps_matches_masked_reference_with_ignore_index(): + from rl_engine.executors.deepspeed_trainer import _EmbeddingLMHeadModel, _extract_logps + + torch.manual_seed(2026) + model = _EmbeddingLMHeadModel(vocab_size=13, hidden_dim=7) + input_ids = torch.tensor([[4, 3, 2], [1, 0, 5]], dtype=torch.long) + token_ids = torch.tensor([[6, -100, 2], [-100, 1, 4]], dtype=torch.long) + mask = token_ids.ne(-100) + + hidden = model(input_ids) + actual = _extract_logps( + hidden, + model, + token_ids, + mask, + NativeLinearLogpOp(), + output_dtype=torch.float32, + ) + logits = torch.nn.functional.linear( + hidden.float(), + model.lm_head.weight.float(), + model.lm_head.bias.float(), + ) + expected = selected_logprobs_reference(logits, token_ids, mask=mask) + + assert torch.allclose(actual, expected, atol=1e-5) + assert actual[~mask].eq(0.0).all() + + +def test_extract_logps_uses_hidden_dim_to_disambiguate_tuple_logits(): + from rl_engine.executors.deepspeed_trainer import _EmbeddingLMHeadModel, _extract_logps + + torch.manual_seed(2027) + model = _EmbeddingLMHeadModel(vocab_size=13, hidden_dim=5) + input_ids = torch.tensor([[4, 3, 2]], dtype=torch.long) + token_ids = torch.tensor([[6, 1, 4]], dtype=torch.long) + mask = torch.ones_like(token_ids, dtype=torch.bool) + + hidden = model(input_ids) + logits = torch.randn(1, 3, model.lm_head.out_features) + actual = _extract_logps( + (torch.tensor(1.0), logits, hidden), + model, + token_ids, + mask, + NativeLinearLogpOp(), + output_dtype=torch.float32, + ) + expected_logits = torch.nn.functional.linear( + hidden.float(), + model.lm_head.weight.float(), + model.lm_head.bias.float(), + ) + expected = selected_logprobs_reference(expected_logits, token_ids, mask=mask) + + assert torch.allclose(actual, expected, atol=1e-5) + + +def test_extract_hidden_states_prefers_last_hidden_state_over_hidden_state_stack(): + from rl_engine.executors.deepspeed_trainer import _extract_hidden_states + + expected = torch.randn(2, 3, 5) + output = { + "hidden_states": ( + torch.randn(2, 3, 5), + torch.randn(2, 3, 5), + ), + "last_hidden_state": expected, + } + + actual = _extract_hidden_states(output) + + assert actual is expected + + +def test_extract_hidden_states_uses_last_tensor_from_hidden_state_stack(): + from rl_engine.executors.deepspeed_trainer import _extract_hidden_states + + layers = ( + torch.randn(2, 3, 5), + torch.randn(2, 3, 5), + torch.randn(2, 3, 5), + ) + + actual = _extract_hidden_states({"hidden_states": layers}) + + assert actual is layers[-1] + + +def test_extract_hidden_states_prefers_structured_hidden_over_tuple_logits(): + from rl_engine.executors.deepspeed_trainer import _extract_hidden_states + + logits = torch.randn(2, 3, 11) + expected = torch.randn(2, 3, 5) + output = (torch.tensor(1.0), logits, {"last_hidden_state": expected}) + + actual = _extract_hidden_states(output) + + assert actual is expected + + +def test_extract_hidden_states_rejects_ambiguous_multi_tensor_tuple(): + from rl_engine.executors.deepspeed_trainer import _extract_hidden_states + + with pytest.raises(TypeError, match="hidden-state tensor"): + _extract_hidden_states((torch.randn(2, 3, 11), torch.randn(2, 3, 5))) + + +def test_deepspeed_training_worker_routes_linear_logp_and_zeroes_masked_targets(monkeypatch): + _install_fake_deepspeed(monkeypatch) + from rl_engine.executors import deepspeed_trainer + + spy = SpyLinearLogpOp() + monkeypatch.setattr(deepspeed_trainer, "_linear_logp_op_for_device", lambda device: spy) + + worker = deepspeed_trainer.DeepSpeedTrainingWorker( + deepspeed_trainer.DeepSpeedTrainingConfig( + num_prompts=1, + samples_per_prompt=2, + prompt_len=1, + completion_len=4, + vocab_size=23, + hidden_dim=8, + seed=31, + ) + ) + batch = make_synthetic_rl_kernel_batch( + num_prompts=1, + samples_per_prompt=2, + prompt_len=1, + completion_len=4, + vocab_size=23, + valid_density=1.0, + device="cpu", + seed=32, + ) + completion_mask = torch.tensor( + [[True, False, True, False], [False, True, True, False]], + dtype=torch.bool, + ) + patched_batch = replace( + batch, + completion_mask=completion_mask, + valid_indices=completion_mask.reshape(-1).nonzero(as_tuple=False).squeeze(-1), + metadata={ + **batch.metadata, + "valid_density": float(completion_mask.float().mean().item()), + "valid_tokens": int(completion_mask.sum().item()), + }, + ) + monkeypatch.setattr( + worker, + "_batch_from_rollout_or_synthetic", + lambda rollout: ( + patched_batch, + { + "training_data_source": "patched_fixture", + "rollout_sequences": patched_batch.batch_size, + "rollout_tokens": int(completion_mask.sum().item()), + }, + ), + ) + + result = worker.train(_rollout()) + + assert len(spy.calls) == 1 + recorded_targets = spy.calls[0]["target_ids"] + assert torch.equal(recorded_targets[completion_mask], patched_batch.token_ids[completion_mask]) + assert torch.equal( + recorded_targets[~completion_mask], + torch.zeros_like(recorded_targets[~completion_mask]), + ) + assert result.metrics["training_data_source"] == "patched_fixture" + assert result.metrics["current_logp_path"] == "linear_logp" + assert result.metrics["current_logp_backend"] == "SpyLinearLogpOp" + assert math.isfinite(result.metrics["loss"]) + + +def test_deepspeed_training_worker_rejects_ignore_index_in_model_inputs(monkeypatch): + _install_fake_deepspeed(monkeypatch) + from rl_engine.executors import deepspeed_trainer + + worker = deepspeed_trainer.DeepSpeedTrainingWorker( + deepspeed_trainer.DeepSpeedTrainingConfig( + num_prompts=1, + samples_per_prompt=1, + prompt_len=1, + completion_len=3, + vocab_size=17, + hidden_dim=8, + seed=33, + ) + ) + batch = make_synthetic_rl_kernel_batch( + num_prompts=1, + samples_per_prompt=1, + prompt_len=1, + completion_len=3, + vocab_size=17, + valid_density=1.0, + device="cpu", + seed=34, + ) + broken_batch = replace( + batch, + token_ids=batch.token_ids.clone(), + ) + broken_batch.token_ids[0, 1] = -100 + monkeypatch.setattr( + worker, + "_batch_from_rollout_or_synthetic", + lambda rollout: (broken_batch, {"training_data_source": "patched_fixture"}), + ) + + with pytest.raises(ValueError, match="ignore-index"): + worker.train(_rollout()) + + +def test_deepspeed_zero3_training_gathers_lm_head_parameters_during_backward(monkeypatch): + _install_fake_deepspeed_with_gather(monkeypatch) + from rl_engine.executors import deepspeed_trainer + + worker = deepspeed_trainer.DeepSpeedTrainingWorker( + deepspeed_trainer.DeepSpeedTrainingConfig( + vocab_size=19, + hidden_dim=8, + zero_stage=3, + seed=35, + ) + ) + worker.engine.world_size = 2 + + active_during_backward = {"value": False} + original_backward = worker.engine.backward + + def wrapped_backward(loss): + active_during_backward["value"] = FakeGatheredParameters.active > 0 + return original_backward(loss) + + worker.engine.backward = wrapped_backward + + result = worker.train(_rollout()) + + assert result.metrics["current_logp_path"] == "linear_logp" + assert FakeGatheredParameters.calls == 1 + assert FakeGatheredParameters.parameter_counts == [2] + assert FakeGatheredParameters.modifier_ranks == [None] + assert FakeGatheredParameters.max_active == 1 + assert active_during_backward["value"] is True + assert FakeGatheredParameters.active == 0 + + +def test_deepspeed_zero3_training_without_gather_api_is_blocked(monkeypatch): + _install_fake_deepspeed(monkeypatch) + from rl_engine.executors import deepspeed_trainer + + worker = deepspeed_trainer.DeepSpeedTrainingWorker( + deepspeed_trainer.DeepSpeedTrainingConfig( + vocab_size=19, + hidden_dim=8, + zero_stage=3, + seed=36, + ) + ) + worker.engine.world_size = 2 + + with pytest.raises(WeightBridgeUnavailableError, match="linear_logp training requires"): + worker.train(_rollout()) + + +def test_deepspeed_config_zero3_override_controls_training_and_publish(monkeypatch): + fake = _install_fake_deepspeed_with_gather(monkeypatch) + from rl_engine.executors import deepspeed_trainer + + bridge = LocalTensorCopyBridge(source_worker="training", source_rank=0) + worker = deepspeed_trainer.DeepSpeedTrainingWorker( + deepspeed_trainer.DeepSpeedTrainingConfig( + vocab_size=19, + hidden_dim=8, + zero_stage=0, + deepspeed_config={"zero_optimization": {"stage": 3}}, + seed=37, + ), + weight_bridge=bridge, + ) + worker.engine.world_size = 2 + + result = worker.train(_rollout()) + manifest = worker.publish_weights(weight_version=41) + + assert fake.initialize_calls[0]["config"]["zero_optimization"]["stage"] == 3 + assert result.metrics["deepspeed_zero_stage"] == 3 + assert manifest.metadata["layout"]["zero_stage"] == 3 + assert manifest.metadata["deepspeed_zero3_full_state_export"]["method"] == ( + "deepspeed.zero.GatheredParameters" + ) + assert FakeGatheredParameters.calls == 2 + assert FakeGatheredParameters.parameter_counts == [2, 3] + assert FakeGatheredParameters.modifier_ranks == [None, 0] + + bridge.release(manifest.update_id) + + def test_deepspeed_training_worker_synthetic_fallback(monkeypatch): _install_fake_deepspeed(monkeypatch) from rl_engine.executors.deepspeed_trainer import ( diff --git a/tests/test_linear_logp.py b/tests/test_linear_logp.py index a6ce1b3..4936cd6 100644 --- a/tests/test_linear_logp.py +++ b/tests/test_linear_logp.py @@ -1,10 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2026 RL-Kernel Contributors +import queue +import tempfile +import traceback +from pathlib import Path + import pytest import torch +import torch.multiprocessing as mp -from rl_engine.kernels.ops.pytorch.loss.linear_logp import NativeLinearLogpOp +from rl_engine.executors.deepspeed_trainer import _EmbeddingLMHeadModel, _safe_token_ids +from rl_engine.kernels.ops.pytorch.loss.linear_logp import ( + NativeLinearLogpOp, + chunked_linear_logp_backward, +) +from rl_engine.testing import selected_logprobs_reference try: import triton # noqa: F401 @@ -39,6 +50,82 @@ def _sm90_available(): "extension built KERNEL_ALIGN_FORCE_SM90=1.", ) + +def _gloo_available(): + return torch.distributed.is_available() and torch.distributed.is_gloo_available() + + +requires_gloo = pytest.mark.skipif( + not _gloo_available(), + reason="tensor-parallel linear_logp CPU test requires torch.distributed Gloo.", +) + + +def _tp_linear_logp_gloo_worker(rank, world_size, init_method, result_queue): + try: + import torch.distributed as dist + + torch.set_num_threads(1) + dist.init_process_group( + backend="gloo", + init_method=init_method, + rank=rank, + world_size=world_size, + ) + + torch.manual_seed(2026) + n, d, vocab = 8, 5, 16 + boundaries = [0, 3, 7, 12, vocab] + start = boundaries[rank] + end = boundaries[rank + 1] + + hidden_base = torch.randn(n, d) + weight_full = torch.randn(vocab, d) + bias_full = torch.randn(vocab) + target = torch.tensor([0, 2, 3, 6, 7, 11, 12, 15], dtype=torch.long) + grad_out = torch.randn(n) + op = NativeLinearLogpOp() + + ref_hidden = hidden_base.detach().clone().requires_grad_(True) + ref_weight = weight_full.detach().clone().requires_grad_(True) + ref_bias = bias_full.detach().clone().requires_grad_(True) + ref_out = op(ref_hidden, ref_weight, target, ref_bias) + ref_out.backward(grad_out) + + tp_hidden = hidden_base.detach().clone().requires_grad_(True) + local_weight = weight_full[start:end].detach().clone().requires_grad_(True) + local_bias = bias_full[start:end].detach().clone().requires_grad_(True) + tp_out = op( + tp_hidden, + local_weight, + target, + local_bias, + tp_group=dist.group.WORLD, + vocab_start_index=start, + global_vocab_size=vocab, + ) + tp_out.backward(grad_out) + + result_queue.put( + { + "ok": True, + "rank": rank, + "out": float((tp_out - ref_out).abs().max().item()), + "hidden_grad": float((tp_hidden.grad - ref_hidden.grad).abs().max().item()), + "weight_grad": float( + (local_weight.grad - ref_weight.grad[start:end]).abs().max().item() + ), + "bias_grad": float((local_bias.grad - ref_bias.grad[start:end]).abs().max().item()), + } + ) + except Exception: # pragma: no cover - forwarded to parent process + result_queue.put({"ok": False, "rank": rank, "traceback": traceback.format_exc()}) + raise + finally: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + # SM90 forward needs bf16 and a hidden dim that is a multiple of the kernel's K # slice (32); N / V are deliberately left unaligned to the 64-wide tiles. _SM90_N = 96 @@ -83,6 +170,48 @@ def _manual_reference(hidden, weight, target, bias): return sel.reshape(target.shape) +def _layout_inputs(base_hidden, base_target, base_mask, order, lead_shape): + order_t = torch.tensor(order, dtype=torch.long) + hidden = base_hidden.index_select(0, order_t).reshape(*lead_shape, base_hidden.size(-1)) + target = base_target.index_select(0, order_t).reshape(*lead_shape) + mask = base_mask.index_select(0, order_t).reshape(*lead_shape) + masked_target = target.masked_fill(~mask, -100) + return hidden, masked_target, mask + + +def _recover_canonical_rows(layout_values, order): + flat = layout_values.reshape( + layout_values.shape[0] * layout_values.shape[1], *layout_values.shape[2:] + ) + recovered = torch.empty_like(flat) + recovered[torch.tensor(order, dtype=torch.long)] = flat + return recovered + + +def _run_chunked_backward(hidden, weight, target, bias, grad_out, *, chunk_elems): + return chunked_linear_logp_backward( + grad_out, + hidden.reshape(-1, hidden.size(-1)).contiguous(), + weight, + target.reshape(-1).contiguous(), + hidden.reshape(-1, hidden.size(-1)).contiguous() if bias is None else bias, + has_bias=bias is not None, + lead_shape=target.shape, + hidden_dtype=hidden.dtype, + weight_dtype=weight.dtype, + bias_dtype=None if bias is None else bias.dtype, + chunk_elems=chunk_elems, + ) + + +def _run_autograd_linear_logp(hidden, weight, target, bias, grad_out): + h = hidden.detach().clone().requires_grad_(True) + w = weight.detach().clone().requires_grad_(True) + b = bias.detach().clone().requires_grad_(True) if bias is not None else None + NativeLinearLogpOp()(h, w, target, b).backward(grad_out) + return h.grad, w.grad, (None if b is None else b.grad) + + def test_native_matches_manual_reference(): native = NativeLinearLogpOp() hidden, weight, target, bias = _inputs(0, device="cpu") @@ -92,6 +221,146 @@ def test_native_matches_manual_reference(): assert torch.allclose(out, ref, atol=1e-5) +def test_linear_logp_handoff_matches_masked_reference_across_layouts(): + torch.manual_seed(2026) + op = NativeLinearLogpOp() + base_hidden = torch.randn(6, 5) + weight = torch.randn(17, 5) + bias = torch.randn(17) + base_target = torch.tensor([3, 7, 1, 9, 4, 6], dtype=torch.long) + base_mask = torch.tensor([True, False, True, True, False, True], dtype=torch.bool) + layouts = [ + ((2, 3), [0, 1, 2, 3, 4, 5]), + ((3, 2), [5, 1, 3, 0, 4, 2]), + ((1, 6), [2, 4, 1, 5, 0, 3]), + ] + + canonical = None + for lead_shape, order in layouts: + hidden, target, mask = _layout_inputs( + base_hidden, base_target, base_mask, order, lead_shape + ) + actual = op(hidden, weight, _safe_token_ids(target, mask), bias).masked_fill(~mask, 0.0) + logits = torch.nn.functional.linear(hidden.float(), weight.float(), bias.float()) + expected = selected_logprobs_reference(logits, target, mask=mask) + recovered = _recover_canonical_rows(actual.unsqueeze(-1), order).squeeze(-1) + + assert torch.allclose(actual, expected, atol=1e-5) + if canonical is None: + canonical = recovered + else: + assert torch.allclose(recovered, canonical, atol=1e-6) + + +@pytest.mark.parametrize("use_bias", [True, False]) +def test_chunked_linear_logp_backward_matches_autograd_and_layout_invariance(use_bias): + torch.manual_seed(2027) + weight = torch.randn(19, 7) + bias = torch.randn(19) if use_bias else None + base_hidden = torch.randn(6, 7) + base_target = torch.tensor([1, 7, 3, 5, 0, 9], dtype=torch.long) + base_mask = torch.tensor([True, False, True, True, False, True], dtype=torch.bool) + base_grad = torch.tensor([0.5, 0.0, -1.25, 0.75, 0.0, 1.5], dtype=torch.float32) + layouts = [ + ((2, 3), [0, 1, 2, 3, 4, 5]), + ((3, 2), [5, 2, 1, 0, 4, 3]), + ] + + canonical_hidden_grad = None + canonical_weight_grad = None + canonical_bias_grad = None + chunk_elems = weight.size(0) * 2 + + for lead_shape, order in layouts: + hidden, target, mask = _layout_inputs( + base_hidden, base_target, base_mask, order, lead_shape + ) + safe_target = _safe_token_ids(target, mask) + grad_out = base_grad[torch.tensor(order, dtype=torch.long)].reshape(lead_shape) + grad_out = grad_out.masked_fill(~mask, 0.0) + + grad_hidden, grad_weight, grad_bias = _run_chunked_backward( + hidden, + weight, + safe_target, + bias, + grad_out, + chunk_elems=chunk_elems, + ) + ref_hidden, ref_weight, ref_bias = _run_autograd_linear_logp( + hidden, + weight, + safe_target, + bias, + grad_out, + ) + recovered_hidden = _recover_canonical_rows(grad_hidden, order) + + assert torch.allclose(grad_hidden, ref_hidden, atol=1e-5) + assert torch.allclose(grad_weight, ref_weight, atol=1e-5) + if use_bias: + assert torch.allclose(grad_bias, ref_bias, atol=1e-5) + + if canonical_hidden_grad is None: + canonical_hidden_grad = recovered_hidden + canonical_weight_grad = grad_weight + canonical_bias_grad = grad_bias + else: + assert torch.allclose(recovered_hidden, canonical_hidden_grad, atol=1e-6) + assert torch.allclose(grad_weight, canonical_weight_grad, atol=1e-6) + if use_bias: + assert torch.allclose(grad_bias, canonical_bias_grad, atol=1e-6) + + +def test_tied_embedding_lm_head_shared_gradient_is_layout_invariant(): + torch.manual_seed(2028) + model = _EmbeddingLMHeadModel(vocab_size=13, hidden_dim=6, bias=False, tie_weights=True) + op = NativeLinearLogpOp() + base_input_ids = torch.tensor([2, 5, 1, 5, 2, 3], dtype=torch.long) + base_target = torch.tensor([4, 1, 0, 2, 6, 3], dtype=torch.long) + base_mask = torch.tensor([True, False, True, True, False, True], dtype=torch.bool) + base_upstream = torch.tensor([0.75, 0.0, -1.25, 0.5, 0.0, 1.0], dtype=torch.float32) + layouts = [ + ((2, 3), [0, 1, 2, 3, 4, 5]), + ((3, 2), [5, 2, 1, 0, 4, 3]), + ] + + assert model.lm_head.weight is model.embedding.weight + canonical_logps = None + canonical_grad = None + + for lead_shape, order in layouts: + order_t = torch.tensor(order, dtype=torch.long) + input_ids = base_input_ids.index_select(0, order_t).reshape(lead_shape) + target = base_target.index_select(0, order_t).reshape(lead_shape) + mask = base_mask.index_select(0, order_t).reshape(lead_shape) + masked_target = target.masked_fill(~mask, -100) + upstream = ( + base_upstream.index_select(0, order_t).reshape(lead_shape).masked_fill(~mask, 0.0) + ) + + model.zero_grad(set_to_none=True) + hidden = model(input_ids) + logps = op( + hidden, model.lm_head.weight, _safe_token_ids(masked_target, mask), model.lm_head.bias + ) + logps = logps.masked_fill(~mask, 0.0) + logits = torch.nn.functional.linear(hidden.float(), model.lm_head.weight.float(), None) + expected = selected_logprobs_reference(logits, masked_target, mask=mask) + (logps * upstream).sum().backward() + + recovered_logps = _recover_canonical_rows(logps.unsqueeze(-1), order).squeeze(-1) + shared_grad = model.embedding.weight.grad.detach().clone() + + assert torch.allclose(logps, expected, atol=1e-5) + if canonical_logps is None: + canonical_logps = recovered_logps + canonical_grad = shared_grad + else: + assert torch.allclose(recovered_logps, canonical_logps, atol=1e-6) + assert torch.allclose(shared_grad, canonical_grad, atol=1e-6) + + def test_native_rejects_shape_mismatch(): native = NativeLinearLogpOp() hidden, weight, _, bias = _inputs(0, device="cpu") @@ -99,6 +368,60 @@ def test_native_rejects_shape_mismatch(): native(hidden, weight, torch.zeros(_N + 1, dtype=torch.long), bias) +def test_tensor_parallel_metadata_requires_multi_rank_group(): + native = NativeLinearLogpOp() + hidden, weight, target, bias = _inputs(0, device="cpu") + with pytest.raises(ValueError, match="vocab_start_index requires"): + native(hidden, weight, target, bias, vocab_start_index=4) + with pytest.raises(ValueError, match="global_vocab_size differs"): + native(hidden, weight, target, bias, global_vocab_size=weight.size(0) + 1) + + +@requires_gloo +def test_native_tensor_parallel_matches_full_reference_cpu_gloo_4_ranks(): + ctx = mp.get_context("spawn") + world_size = 4 + with tempfile.TemporaryDirectory() as tmpdir: + init_method = (Path(tmpdir) / "gloo_init").as_uri() + result_queue = ctx.Queue() + processes = [ + ctx.Process( + target=_tp_linear_logp_gloo_worker, + args=(rank, world_size, init_method, result_queue), + ) + for rank in range(world_size) + ] + + for process in processes: + process.start() + + results = [] + try: + for _ in processes: + results.append(result_queue.get(timeout=45)) + except queue.Empty: + for process in processes: + if process.is_alive(): + process.terminate() + pytest.fail("timed out waiting for tensor-parallel Gloo workers") + finally: + for process in processes: + process.join(timeout=10) + if process.is_alive(): + process.terminate() + + sorted_results = sorted(results, key=lambda item: item["rank"]) + for result in sorted_results: + assert result["ok"], result.get("traceback") + for process in processes: + assert process.exitcode == 0 + for result in sorted_results: + assert result["out"] < 1e-5 + assert result["hidden_grad"] < 1e-5 + assert result["weight_grad"] < 1e-5 + assert result["bias_grad"] < 1e-5 + + @requires_triton_cuda def test_triton_forward_matches_native_fp32(): from rl_engine.kernels.ops.triton.loss.linear_logp import TritonLinearLogpOp @@ -325,6 +648,94 @@ def test_sm90_rejects_out_of_range_target(): assert out.shape == target.shape and torch.isfinite(out).all() +def test_sm90_tp_metadata_prefers_sm90_tp_helper(monkeypatch): + from rl_engine.kernels.ops.cuda.loss import linear_logp as cuda_linear_logp + + op = object.__new__(cuda_linear_logp.FusedLinearLogpSM90Op) + hidden = torch.randn(2, 4) + weight = torch.randn(3, 4) + target = torch.tensor([3, 5]) + sentinel = torch.full((2,), 7.0) + tp_group = object() + calls = {} + + monkeypatch.setattr( + cuda_linear_logp, + "should_use_tensor_parallel_linear_logp", + lambda *args, **kwargs: True, + ) + monkeypatch.setattr(cuda_linear_logp, "_sm90_supported", lambda h, w: True) + + def fake_sm90_tp(hidden_arg, weight_arg, target_arg, bias_arg, **kwargs): + calls["sm90_tp"] = (hidden_arg, weight_arg, target_arg, bias_arg, kwargs) + return sentinel + + def forbidden_portable_tp(*args, **kwargs): + raise AssertionError("portable TP path should not run when SM90 TP is available") + + monkeypatch.setattr(cuda_linear_logp, "_sm90_tensor_parallel_linear_logp", fake_sm90_tp) + monkeypatch.setattr(cuda_linear_logp, "tensor_parallel_linear_logp", forbidden_portable_tp) + + out = op( + hidden, + weight, + target, + tp_group=tp_group, + vocab_start_index=3, + global_vocab_size=6, + ) + + assert out is sentinel + assert calls["sm90_tp"][0] is hidden + assert calls["sm90_tp"][1] is weight + assert calls["sm90_tp"][2] is target + assert calls["sm90_tp"][4] == { + "tp_group": tp_group, + "vocab_start_index": 3, + "global_vocab_size": 6, + } + + +def test_sm90_tp_metadata_falls_back_to_portable_tp_when_sm90_unsupported(monkeypatch): + from rl_engine.kernels.ops.cuda.loss import linear_logp as cuda_linear_logp + + op = object.__new__(cuda_linear_logp.FusedLinearLogpSM90Op) + hidden = torch.randn(2, 4) + weight = torch.randn(3, 4) + target = torch.tensor([3, 5]) + sentinel = torch.full((2,), 11.0) + + monkeypatch.setattr( + cuda_linear_logp, + "should_use_tensor_parallel_linear_logp", + lambda *args, **kwargs: True, + ) + monkeypatch.setattr(cuda_linear_logp, "_sm90_supported", lambda h, w: False) + monkeypatch.setattr( + cuda_linear_logp, + "_sm90_tensor_parallel_linear_logp", + lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError("SM90 TP helper should not run for unsupported inputs") + ), + ) + monkeypatch.setattr( + cuda_linear_logp, + "tensor_parallel_linear_logp", + lambda *args, **kwargs: sentinel, + ) + + out = op( + hidden, + weight, + target, + tp_group=object(), + vocab_start_index=3, + global_vocab_size=6, + ) + + assert out is sentinel + + def test_registry_dispatch_matches_native(): from rl_engine.kernels.registry import kernel_registry from rl_engine.platforms.device import device_ctx diff --git a/tests/test_op_accuracy.py b/tests/test_op_accuracy.py index 5a73e26..56f7ab6 100644 --- a/tests/test_op_accuracy.py +++ b/tests/test_op_accuracy.py @@ -14,6 +14,27 @@ def _fused_logp_op(op_type: str = "logp"): return kernel_registry.get_op(op_type) +def _sm90_logp_available(): + if not torch.cuda.is_available(): + return False + try: + from rl_engine.kernels.ops.base import _C, _EXT_AVAILABLE + + return ( + _EXT_AVAILABLE + and hasattr(_C, "fused_logp_sm90") + and torch.cuda.get_device_capability()[0] == 9 + ) + except Exception: # pragma: no cover + return False + + +requires_sm90_logp = pytest.mark.skipif( + not _sm90_logp_available(), + reason="SM90 fused logp requires a Hopper GPU and the extension built with SM90 enabled.", +) + + def _reference_selected_logp(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: ref_logp = torch.log_softmax(logits.float(), dim=-1) return torch.gather(ref_logp, dim=-1, index=token_ids.long().unsqueeze(-1)).squeeze(-1) @@ -166,6 +187,44 @@ def test_accuracy(): assert diff < threshold +@requires_sm90_logp +def test_sm90_logp_registry_uses_generic_default_on_sm90(monkeypatch): + from rl_engine.kernels.registry import KernelRegistry + + monkeypatch.delenv("RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP", raising=False) + + op = KernelRegistry().get_op("logp") + + assert op.__class__.__name__ == "FusedLogpGenericOp" + + +@requires_sm90_logp +def test_sm90_logp_bf16_defaults_to_generic_fallback(monkeypatch): + from rl_engine.kernels.ops.cuda.loss.logp import FusedLogpSM90Op + + op = FusedLogpSM90Op() + monkeypatch.delenv("RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP", raising=False) + calls = {} + + class Fallback: + def apply(self, logits_arg, token_ids_arg): + calls["apply"] = True + return _reference_selected_logp(logits_arg, token_ids_arg) + + monkeypatch.setattr(op, "_fallback_op", lambda: Fallback()) + + generator = torch.Generator(device="cuda").manual_seed(189) + logits = torch.randn(128, 4096, device="cuda", dtype=torch.bfloat16, generator=generator) + token_ids = torch.randint(0, logits.size(-1), (logits.size(0),), device="cuda") + + result = op(logits.contiguous(), token_ids) + expected = _reference_selected_logp(logits, token_ids) + + assert calls["apply"] is True + assert result.shape == token_ids.shape + assert torch.equal(result, expected) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") def test_fused_logp_out_reuses_output_storage(): device = torch.device("cuda") diff --git a/tests/test_rl_kernel_loss_step.py b/tests/test_rl_kernel_loss_step.py index 934b241..597f48a 100644 --- a/tests/test_rl_kernel_loss_step.py +++ b/tests/test_rl_kernel_loss_step.py @@ -146,7 +146,7 @@ def test_minimal_rl_loss_step_fused_logp_candidate_cuda(): old_logps = reference_logps - 0.01 ref_logps = reference_logps - 0.02 candidate_op = kernel_registry.get_op("logp") - if candidate_op.__class__.__name__ != "FusedLogpGenericOp": + if candidate_op.__class__.__name__ not in {"FusedLogpGenericOp", "FusedLogpSM90Op"}: pytest.skip("fused logp CUDA backend is unavailable") candidate_logps = candidate_op(logits, batch.token_ids).float()