Skip to content

[FEAT][kernels] Add tensor-parallel linear_logp path#189

Open
inaniloquentee wants to merge 7 commits into
mainfrom
feat/fused-linear-logp
Open

[FEAT][kernels] Add tensor-parallel linear_logp path#189
inaniloquentee wants to merge 7 commits into
mainfrom
feat/fused-linear-logp

Conversation

@inaniloquentee

@inaniloquentee inaniloquentee commented Jun 24, 2026

Copy link
Copy Markdown
Collaborator

Summary

Tensor-parallel linear_logp now supports vocab-sharded LM-head weights through
the public operator API. The TP path accepts global target_ids, local
lm_head_weight[vocab_start:vocab_end], optional local bias, a TP process group,
and global vocab metadata.

Path Status
PyTorch TP path Shared autograd implementation for native, Triton, and SM90 wrappers.
Registry/Triton TP wrappers TP kwargs delegate into the shared TP path instead of local non-TP kernels.
SM90 TP wrapper Constructs the SM90 op when compiled, then delegates TP kwargs to the shared TP path.
4-GPU runbook Validated on 4x H100 with NCCL.

The TP implementation computes global log-sum-exp across vocab shards, sums the
rank-local target logit contribution, all-reduces hidden.grad, and keeps
weight.grad / bias.grad local to each vocab shard.

Implementation

  • Added TP metadata validation for contiguous [0, V) vocab partitions and
    exactly-one-owner target coverage.
  • Added chunked forward stats over local vocab shards: local max, local sum,
    local target logit, and owner count.
  • Added custom TP autograd:
    • forward merges local online-softmax state with NCCL collectives;
    • backward recomputes local logits by chunks;
    • hidden.grad is all-reduced across ranks;
    • weight.grad and bias.grad remain shard-local.
  • Aligned low-precision TP math with fused/Triton bf16 semantics by computing
    bf16/fp16 matmuls in fp32 before casting returned gradients back to input
    dtypes.
  • Fixed SM90 extension build issues:
    • TMA inline PTX now uses the CUDA 12.4-supported
      shared::cluster.global.tile form.
    • TMA/mbarrier helpers use shared-address operands accepted by ptxas.
    • The extension links with a PyTorch library rpath so _C imports without
      manual LD_LIBRARY_PATH.

Validation Environment

Item Value
GPU 4x NVIDIA H100 80GB HBM3
Driver 580.126.09
CUDA driver/runtime driver CUDA 13.0, PyTorch CUDA 12.4
PyTorch 2.4.1+cu124
Distributed backend NCCL, world_size=4
Extension KERNEL_ALIGN_FORCE_SM90=1, _C.fused_linear_logp_sm90 available

Correctness / Tests

Native TP math sanity

torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \
  --op-source native \
  --dtype fp32 \
  --tokens 128 \
  --hidden-size 256 \
  --vocab-size 4096 \
  --uneven-shards

Result: PASS.

Metric max_abs max_rel
output 4.196167e-05 1.038115e-06
hidden_grad 6.890297e-05 1.170046e-01
weight_grad 5.006790e-05 8.836515e-02
bias_grad 1.573563e-05 1.523999e-02

Registry bf16 main gate

torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \
  --op-source registry \
  --dtype bf16 \
  --reference-mode fp32 \
  --tokens 256 \
  --hidden-size 512 \
  --vocab-size 8192 \
  --uneven-shards

Result: PASS.

Metric max_abs max_rel
output 5.340576e-05 1.703110e-06
hidden_grad 1.562500e-02 8.754433e+00
weight_grad 1.562500e-02 1.062290e-01
bias_grad 4.882812e-04 6.451613e-03

Wrapper delegation

Both wrapper paths passed the same bf16/fp32-reference TP correctness run:

op_source Result
triton PASS
sm90 PASS

SM90 log confirms the extension was linked:

Successfully linked to precompiled _C.fused_linear_logp_sm90 kernel.

Same-dtype bf16 drift check

torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \
  --op-source registry \
  --dtype bf16 \
  --reference-mode matching \
  --tokens 128 \
  --hidden-size 512 \
  --vocab-size 8192 \
  --atol 0.75 \
  --rtol 0.75

Result: PASS with recorded full-GEMM vs shard-GEMM drift.

Metric max_abs max_rel
output 3.042984e-01 6.301940e-03
hidden_grad 8.593750e-01 3.489747e+02
weight_grad 5.625000e-01 4.345703e+06
bias_grad 1.953125e-01 8.200000e+00

This drift comes from comparing PyTorch's full bf16 GEMM against separate
vocab-shard GEMMs. The main correctness target is fp32 accumulation, matching
the fused/Triton bf16 operator semantics.

Larger stress

torchrun --standalone --nproc_per_node=4 scripts/test_linear_logp_tp.py \
  --op-source registry \
  --dtype bf16 \
  --reference-mode fp32 \
  --tokens 256 \
  --hidden-size 512 \
  --vocab-size 8192 \
  --run-stress \
  --stress-tokens 4096 \
  --stress-hidden-size 2048 \
  --stress-vocab-size 32768

Result: PASS.

Metric Value
finite PASS
max_rank_elapsed_ms 105.494
max_rank_peak_memory_gb 0.469

Unit coverage

python -m pytest tests/test_linear_logp.py -q -rs

Result: 21 passed in 8.40s.

Latest PR 189 Test Results

Validated on 2026-06-27 UTC on the pr-189 / feat/fused-linear-logp branch.
Runtime code was validated at ddf65a7; follow-up 355a2e0 only removes local
PR notes from the submitted branch and passed docs/pre-commit checks.

Environment

Item Value
Python 3.11.10
PyTorch 2.4.1+cu124
CUDA runtime/toolkit 12.4 / nvcc 12.4.131
Driver 580.126.09
GPU 2 x NVIDIA H100 80GB HBM3, compute capability 9.0
DeepSpeed 0.19.2, installed with DS_BUILD_OPS=0
Extension Editable rebuild with KERNEL_ALIGN_FORCE_SM90=1; _C.fused_linear_logp_sm90 available

CI and Regression Checks

Area Command Result
PR-focused linear logp unit tests python -m pytest tests/test_linear_logp.py -q -rs 27 passed
DeepSpeed worker contract tests python -m pytest tests/test_deepspeed_training_worker.py -q -rs 22 passed
Fused logp fallback and CUDA loss-step regression python -m pytest tests/test_op_accuracy.py tests/test_rl_kernel_loss_step.py -q -rs --tb=short 21 passed
Full H100/SM90 pytest suite python -m pytest tests rl_engine/tests -q -rs --tb=short 283 passed, 82 skipped
CI dispatch baseline python -m pytest rl_engine/tests/test_dispatch.py -v 5 passed
CI attention baseline PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest tests/test_attention_correctness.py -q -rs 45 passed, 82 skipped
Type check python -m mypy --ignore-missing-imports rl_engine/ Success: no issues found in 50 source files
Documentation build mkdocs build --strict -f mkdocs.yaml Passed
Pre-commit pre-commit run --all-files Passed

The attention skips were expected for this CUDA/H100 environment: external
flash_attn / _C.flash_attn_forward was unavailable, and ROCm cases were
skipped on NVIDIA CUDA.

Production-Like CUDA Checks

Scenario Result Notes
2-GPU direct SM90 linear_logp TP stress, bf16/fp32 reference, uneven shards PASS Stress finite check passed; max rank elapsed 184.532 ms; peak memory 1.696 GiB
2-GPU registry-dispatched linear_logp TP, bf16/fp32 reference, uneven shards PASS Registry selected the fused SM90 TMA linear-logp path
Real DeepSpeed CUDA smoke PASS DeepSpeedEngine, zero stage 0, current_logp_backend=FusedLinearLogpSM90Op, finite loss, published weight version 11

Direct SM90 TP correctness metrics: output max abs 3.051758e-04, hidden grad
6.250000e-02, weight grad 6.250000e-02, bias grad 1.562500e-02.
Registry TP correctness metrics: output max abs 1.068115e-04, hidden grad
3.125000e-02, weight grad 3.125000e-02, bias grad 1.953125e-03.

Regression Found and Fixed

A full-suite H100 run exposed that the legacy SM90 logp registry path selected
FusedLogpSM90Op for cases it did not fully support: fp16 logits, .out(...),
and .apply_fp32(...). The wrapper now delegates to FusedLogpGenericOp by
default, and the legacy SM90 TMA logp fast path is gated behind
RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP=1. This does not affect the PR's fused
SM90 linear_logp production path, which is covered by the two-GPU checks above.

Summary by CodeRabbit

  • New Features

    • Added vocab-sharded tensor-parallel support for linear_logp, including new tp_group, vocab_start_index, and global_vocab_size options with TP-aware dispatch (including SM90 when eligible).
    • Added distributed scripts/runbooks and enhanced the training worker to compute log-probabilities via linear_logp with TP metadata.
  • Bug Fixes

    • Improved tensor-parallel metadata validation and tightened correctness for forward/backward outputs and gradients (including ignore-index handling).
  • Documentation

    • Expanded linear_logp operator docs and added an end-to-end TP test runbook.
  • Tests

    • Added comprehensive correctness, gradient, and multi-rank TP coverage (single- and distributed).

@coderabbitai

coderabbitai Bot commented Jun 24, 2026

Copy link
Copy Markdown

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds tensor-parallel linear_logp execution across Native, Triton, and SM90 paths, updates DeepSpeed training to use it, adds distributed validation and pytest coverage, and refreshes related docs, build rpaths, TMA helpers, and benchmark plumbing.

Changes

Tensor-Parallel linear_logp

Layer / File(s) Summary
TP core op and validation
rl_engine/kernels/ops/pytorch/loss/linear_logp.py, rl_engine/kernels/ops/triton/loss/linear_logp.py, rl_engine/kernels/ops/cuda/loss/linear_logp.py
Adds TP scaffolding, shard validation, chunked local log-prob computation, the TP autograd function, the public tensor-parallel entrypoint, and conditional gradient matmul handling in the PyTorch implementation.
SM90 and TMA backend routing
rl_engine/kernels/ops/cuda/loss/linear_logp.py, rl_engine/kernels/ops/cuda/loss/logp.py, rl_engine/kernels/registry.py, csrc/utils/tma_utils.cuh, csrc/cuda/fused_linear_logp_sm90.cu
Extends CUDA log-prob and fused-linear-logp entrypoints with tensor-parallel parameters, adds SM90 gating and fallback routing, updates registry priority handling, and promotes TMA shared-memory helpers to 64-bit addresses.
DeepSpeed training path and worker tests
rl_engine/executors/deepspeed_trainer.py, tests/test_deepspeed_training_worker.py
Updates DeepSpeed training to compute log-probs through the linear_logp op, handle ZeRO-3 parameter gathering and resolved config state, and adds unit tests for log-prob extraction, hidden-state extraction, routing, ignore-index handling, and gather behavior.
Distributed validation script and TP tests
scripts/test_linear_logp_tp.py, tests/test_linear_logp.py
Adds the torchrun validation script and expands pytest coverage with distributed correctness checks, layout-invariance tests, chunked backward comparisons, and SM90 TP dispatch tests.
Docs, build, and benchmark updates
docs/operators/linear-logp.md, docs/operators/linear-logp-tp-test.md, docs/.nav.yml, docs/operators/README.md, docs/contributing/README.md, docs/contributing/pr-189-test-results.md, docs/contributing/vime-RLK.md, docs/design/runtime-dispatch.md, docs/operators/fused-logp.md, setup.py, benchmarks/benchmark_linear_logp.py
Updates operator docs and navigation, adds the TP test runbook and related contribution pages, adjusts extension rpath settings in setup.py, and changes the benchmark helper signatures to pass targets explicitly.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

component: kernels

Suggested reviewers

  • EthanZero2Hero
  • maxiaosong1124

Poem

🐇 I hop through shards and lanes of light,
TP log-probs sing through the night.
SM90 twinkles, docs take flight,
Tests and trainers all run right.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 4.20% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly matches the main change: adding a tensor-parallel linear_logp path.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/fused-linear-logp

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@inaniloquentee inaniloquentee marked this pull request as draft June 24, 2026 07:56

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

🧹 Nitpick comments (1)
tests/test_linear_logp.py (1)

234-237: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Surface child tracebacks before exit-code assertions.

Workers already send tracebacks through result_queue, but Line 235 can fail first and hide the useful distributed failure details.

Proposed fix
-    for process in processes:
-        assert process.exitcode == 0
-    for result in sorted(results, key=lambda item: item["rank"]):
+    sorted_results = sorted(results, key=lambda item: item["rank"])
+    for result in sorted_results:
         assert result["ok"], result.get("traceback")
+    for process in processes:
+        assert process.exitcode == 0
+    for result in sorted_results:
         assert result["out"] < 1e-5
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_linear_logp.py` around lines 234 - 237, Surface child tracebacks
before checking worker exit codes in the test flow. In the `test_linear_logp`
assertions, inspect the collected `results` and fail on any `result["ok"]` false
(using `result.get("traceback")`) before asserting `process.exitcode == 0` for
each process. Keep the existing `processes` and `results` checks, but reorder
them so distributed traceback details from the worker result queue are shown
first.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/kernels/ops/pytorch/loss/linear_logp.py`:
- Around line 95-111: The TP validation in _validate_global_vocab_size and
_validate_global_targets is currently local, so one rank can raise while others
keep going and later hang in collectives. Add a collective error check before
either ValueError is raised by all-reducing a per-rank failure flag so every
rank agrees to raise or continue together. Keep the fix centered on these two
helpers in linear_logp.py and preserve the existing error messages, only making
the failure path collective.
- Around line 300-304: The bias validation in _chunked_local_linear_logp_stats
only checks numel(), so it can accept non-1D shapes like (V, 1) that later slice
incorrectly and may mis-broadcast when added to the local logits. Tighten the
check near the existing bias.numel() validation to require bias to be exactly
1-D for the local vocab shard (using bias.ndim/shape), and keep the existing
lm_head_weight.size(0) element-count check so only a flat bias matching the
shard size is allowed.

In `@scripts/test_linear_logp_tp.py`:
- Around line 202-207: The stress phase in `step()` is only validating gradients
because the returned `out` is discarded, so the finiteness check is incomplete.
Update `step()` and the surrounding stress loop to keep the `out` from the model
forward pass and include it in the same finite-value validation used for
gradients, so both outputs and grads are checked during the stress run.

In `@tests/test_linear_logp.py`:
- Line 5: The distributed test setup in test_linear_logp is using a free TCP
port that gets released before dist.init_process_group() starts, which can let
another process claim it and make the 4-rank Gloo test flaky. Update the
rendezvous setup used by the affected test helpers in test_linear_logp to avoid
the race by either switching to a per-test file:// init_method or keeping a
TCPStore alive until process group initialization completes. Use the existing
test setup functions around the init_process_group call to make the change
without relying on the temporary socket being closed early.

In `@杂/KJ.md`:
- Around line 61-75: The markdown file has a missing trailing newline, which
causes end-of-file-fixer to rewrite it and fail pre-commit. Update the
documentation file itself so it ends with a single newline after the release
notes content, keeping the existing text unchanged. Verify the fix against the
generated release-notes block so the file remains stable under linting and
commit hooks.
- Around line 48-58: Add language tags to the two fenced code blocks in this
markdown snippet to satisfy MD040. Update the fenced blocks in this section so
the first error example and the second template-instantiation error example both
have an explicit language identifier, using the same fenced-block formatting
already used elsewhere in the document.

---

Nitpick comments:
In `@tests/test_linear_logp.py`:
- Around line 234-237: Surface child tracebacks before checking worker exit
codes in the test flow. In the `test_linear_logp` assertions, inspect the
collected `results` and fail on any `result["ok"]` false (using
`result.get("traceback")`) before asserting `process.exitcode == 0` for each
process. Keep the existing `processes` and `results` checks, but reorder them so
distributed traceback details from the worker result queue are shown first.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 295838a9-5416-4f8a-a4a3-a5a2dbbc11bd

📥 Commits

Reviewing files that changed from the base of the PR and between be5ec9b and a18a5e4.

📒 Files selected for processing (11)
  • benchmarks/benchmark_linear_logp.py
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/linear-logp-tp-test.md
  • docs/operators/linear-logp.md
  • rl_engine/kernels/ops/cuda/loss/linear_logp.py
  • rl_engine/kernels/ops/pytorch/loss/linear_logp.py
  • rl_engine/kernels/ops/triton/loss/linear_logp.py
  • scripts/test_linear_logp_tp.py
  • tests/test_linear_logp.py
  • 杂/KJ.md

Comment thread rl_engine/kernels/ops/pytorch/loss/linear_logp.py Outdated
Comment thread rl_engine/kernels/ops/pytorch/loss/linear_logp.py
Comment thread scripts/test_linear_logp_tp.py
Comment thread tests/test_linear_logp.py Outdated
Comment thread 杂/KJ.md Outdated
Comment on lines +48 to +58
```
fused_logp_sm90.cu(62): error: identifier "CUDART_INF_F" is undefined
```
Fix: `#include <math_constants.h>`.

2. **`TmaTypeTraits<c10::BFloat16>` is an incomplete type** — the host wrapper passes `logits.data_ptr<at::BFloat16>()` (i.e. `c10::BFloat16*`) straight into `init_tensor_map`, but `TmaTypeTraits` is only specialized for `nv_bfloat16`/`float`,
so template instantiation fails.
```
tma_utils.cuh: error: incomplete type "TmaTypeTraits<c10::BFloat16>" is not allowed
```
Fix: `reinterpret_cast<const nv_bfloat16*>(logits.data_ptr<at::BFloat16>())`(matching how the linear kernel and the rest of the codebase bridge `at::BFloat16` → `nv_bfloat16`).

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Add language tags to fenced code blocks.

Two fenced blocks are missing language identifiers (MD040), which triggers markdown lint warnings.

Suggested patch
-   ```
+   ```text
    fused_logp_sm90.cu(62): error: identifier "CUDART_INF_F" is undefined
    ```
...
-   ```
+   ```text
    tma_utils.cuh: error: incomplete type "TmaTypeTraits<c10::BFloat16>" is not allowed
    ```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
```
fused_logp_sm90.cu(62): error: identifier "CUDART_INF_F" is undefined
```
Fix: `#include <math_constants.h>`.
2. **`TmaTypeTraits<c10::BFloat16>` is an incomplete type** — the host wrapper passes `logits.data_ptr<at::BFloat16>()` (i.e. `c10::BFloat16*`) straight into `init_tensor_map`, but `TmaTypeTraits` is only specialized for `nv_bfloat16`/`float`,
so template instantiation fails.
```
tma_utils.cuh: error: incomplete type "TmaTypeTraits<c10::BFloat16>" is not allowed
```
Fix: `reinterpret_cast<const nv_bfloat16*>(logits.data_ptr<at::BFloat16>())`(matching how the linear kernel and the rest of the codebase bridge `at::BFloat16``nv_bfloat16`).
🧰 Tools
🪛 markdownlint-cli2 (0.22.1)

[warning] 48-48: Fenced code blocks should have a language specified

(MD040, fenced-code-language)


[warning] 55-55: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@杂/KJ.md` around lines 48 - 58, Add language tags to the two fenced code
blocks in this markdown snippet to satisfy MD040. Update the fenced blocks in
this section so the first error example and the second template-instantiation
error example both have an explicit language identifier, using the same
fenced-block formatting already used elsewhere in the document.

Source: Linters/SAST tools

@inaniloquentee inaniloquentee force-pushed the feat/fused-linear-logp branch from a18a5e4 to 453d38f Compare June 24, 2026 10:30
Signed-off-by: inaniloquentee <3051000145@qq.com>
@inaniloquentee inaniloquentee force-pushed the feat/fused-linear-logp branch from 453d38f to 5e040e3 Compare June 24, 2026 10:41
@inaniloquentee inaniloquentee marked this pull request as ready for review June 24, 2026 10:48
@inaniloquentee inaniloquentee requested a review from bitborne as a code owner June 24, 2026 10:48

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
setup.py (2)

56-56: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Inconsistent handling of the shared torch_rpath list.

Line 56 assigns the shared torch_rpath reference directly, while line 114 copies via list(torch_rpath). If the ROCm branch ever appends to its extra_link_args, it would mutate the shared list and bleed into other consumers. Use list(torch_rpath) in both places for consistency and to avoid aliasing surprises.

♻️ Proposed change
-                extra_link_args=torch_rpath,
+                extra_link_args=list(torch_rpath),

Also applies to: 114-114

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@setup.py` at line 56, The handling of torch_rpath is inconsistent and can
leak shared mutations across branches. In the setup configuration where
extra_link_args is assigned, update the assignment in both the main and ROCm
paths to use a copied list instead of the shared torch_rpath reference. Keep the
change localized around the extra_link_args setup so both call sites use
list(torch_rpath) and avoid aliasing side effects.

40-41: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚖️ Poor tradeoff

Absolute Torch lib path baked into rpath may hurt build reproducibility.

The second entry embeds the build-machine absolute path ({torch_lib_dir}) into the binary's rpath. The $ORIGIN/../torch/lib entry already provides relocatable resolution for installed wheels; the absolute path leaks build-time paths and can break reproducible/portable builds when the artifact runs on a host where Torch lives elsewhere. Consider keeping the absolute entry only as a deliberate dev-build fallback, or drop it if $ORIGIN resolution is sufficient.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@setup.py` around lines 40 - 41, The rpath setup in setup.py is baking a
build-machine absolute Torch library path into the binary via torch_rpath, which
harms portability and reproducibility. Update the rpath construction around
torch_lib_dir and torch_rpath so the $ORIGIN-based entry is the default for
installed wheels, and either remove the absolute {torch_lib_dir} entry or guard
it behind an explicit dev-build fallback path. Keep the change localized to the
torch rpath logic in setup.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@csrc/cuda/fused_linear_logp_sm90.cu`:
- Around line 116-120: The barrier priming in the SM90 load path is ordered too
late, which can race with the TMA copies. In fused_linear_logp_sm90.cu, update
the buffer setup around tma_2d_g2s and mbarrier_arrive_expect_tx so the barrier
is primed on mbar[buf] before issuing either the h_tmap or w_tmap copy, keeping
the expected byte count in place before the transfers begin.

In `@rl_engine/kernels/ops/pytorch/loss/linear_logp.py`:
- Around line 65-66: The empty-shard check in linear_logp should not raise
locally before the tensor-parallel collective validation runs, because only the
rank with an empty shard will fail and other ranks can hang in
_validate_tp_vocab_partition. Update the logic around the local_vocab_size guard
in the linear_logp path so the validation is deferred to the existing partition
check, using the gathered end <= start condition to fail consistently on every
rank for explicit multi-rank TP.

---

Nitpick comments:
In `@setup.py`:
- Line 56: The handling of torch_rpath is inconsistent and can leak shared
mutations across branches. In the setup configuration where extra_link_args is
assigned, update the assignment in both the main and ROCm paths to use a copied
list instead of the shared torch_rpath reference. Keep the change localized
around the extra_link_args setup so both call sites use list(torch_rpath) and
avoid aliasing side effects.
- Around line 40-41: The rpath setup in setup.py is baking a build-machine
absolute Torch library path into the binary via torch_rpath, which harms
portability and reproducibility. Update the rpath construction around
torch_lib_dir and torch_rpath so the $ORIGIN-based entry is the default for
installed wheels, and either remove the absolute {torch_lib_dir} entry or guard
it behind an explicit dev-build fallback path. Keep the change localized to the
torch rpath logic in setup.py.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2de76294-4e7f-4af4-80f1-b52dbe476ac3

📥 Commits

Reviewing files that changed from the base of the PR and between a18a5e4 and 5e040e3.

📒 Files selected for processing (13)
  • benchmarks/benchmark_linear_logp.py
  • csrc/cuda/fused_linear_logp_sm90.cu
  • csrc/utils/tma_utils.cuh
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/linear-logp-tp-test.md
  • docs/operators/linear-logp.md
  • rl_engine/kernels/ops/cuda/loss/linear_logp.py
  • rl_engine/kernels/ops/pytorch/loss/linear_logp.py
  • rl_engine/kernels/ops/triton/loss/linear_logp.py
  • scripts/test_linear_logp_tp.py
  • setup.py
  • tests/test_linear_logp.py
✅ Files skipped from review due to trivial changes (3)
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/linear-logp-tp-test.md
🚧 Files skipped from review as they are similar to previous changes (6)
  • benchmarks/benchmark_linear_logp.py
  • tests/test_linear_logp.py
  • docs/operators/linear-logp.md
  • rl_engine/kernels/ops/triton/loss/linear_logp.py
  • rl_engine/kernels/ops/cuda/loss/linear_logp.py
  • scripts/test_linear_logp_tp.py

Comment thread csrc/cuda/fused_linear_logp_sm90.cu Outdated
Comment thread rl_engine/kernels/ops/pytorch/loss/linear_logp.py Outdated
@RL-Align RL-Align deleted a comment from coderabbitai Bot Jun 24, 2026
…GEMM

Signed-off-by: inaniloquentee <3051000145@qq.com>
@inaniloquentee inaniloquentee marked this pull request as draft June 27, 2026 08:13
@inaniloquentee inaniloquentee marked this pull request as ready for review June 27, 2026 08:15

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/test_deepspeed_training_worker.py (1)

64-65: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Annotate these shared lists as ClassVar. Ruff flags these mutable class attributes, and the lists are intentionally shared test state.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_deepspeed_training_worker.py` around lines 64 - 65, The shared
mutable test state in the test class is being flagged by Ruff because
modifier_ranks and parameter_counts are plain class attributes; annotate these
attributes as ClassVar in the test class so it is clear they are intentionally
shared across instances. Update the class definition where modifier_ranks and
parameter_counts are declared, using the existing test class name as the
locator, and keep their behavior unchanged.

Source: Linters/SAST tools

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/executors/deepspeed_trainer.py`:
- Around line 219-226: The layout assembly in deepspeed_trainer.py currently
lets manifest_metadata overwrite authoritative fields in the resolved layout.
Update the logic around the layout dict in the DeepSpeed trainer so that kind,
zero_stage, world_size, and rank remain fixed to the values computed from
_deepspeed_zero_stage, _engine_world_size(), and _engine_rank(), while still
merging any extra caller-provided layout metadata from
manifest_metadata.get("layout", {}) only for non-reserved fields.

In `@rl_engine/kernels/ops/cuda/loss/linear_logp.py`:
- Line 150: The SM90 path in fused_linear_logp is still passing the original
bias tensor instead of the prepared contiguous bias view. Update the forward
call in the fused_linear_logp_sm90 branch to use bias_t consistently, and make
sure the same contiguous bias view is also used by any corresponding backward
path so both directions avoid relying on extension stride handling.

In `@tests/test_linear_logp.py`:
- Around line 723-738: Tighten the portable fallback test for tensor-parallel
linear logp: the current sentinel stub only checks that delegation happens, not
that tp_group, vocab_start_index, and global_vocab_size are forwarded. Update
the test around op and the monkeypatched tensor_parallel_linear_logp to capture
the forwarded args/kwargs and assert those TP metadata values are preserved when
the SM90 wrapper delegates.

---

Nitpick comments:
In `@tests/test_deepspeed_training_worker.py`:
- Around line 64-65: The shared mutable test state in the test class is being
flagged by Ruff because modifier_ranks and parameter_counts are plain class
attributes; annotate these attributes as ClassVar in the test class so it is
clear they are intentionally shared across instances. Update the class
definition where modifier_ranks and parameter_counts are declared, using the
existing test class name as the locator, and keep their behavior unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 28f08037-c31c-4cdb-84ee-12071e09970d

📥 Commits

Reviewing files that changed from the base of the PR and between 5e040e3 and 4858bbd.

📒 Files selected for processing (4)
  • rl_engine/executors/deepspeed_trainer.py
  • rl_engine/kernels/ops/cuda/loss/linear_logp.py
  • tests/test_deepspeed_training_worker.py
  • tests/test_linear_logp.py

Comment on lines 219 to 226
layout = {
"kind": "full-state",
"zero_stage": self.config.zero_stage,
"zero_stage": self._deepspeed_zero_stage,
"world_size": self._engine_world_size(),
"rank": self._engine_rank(),
}
layout.update(dict(manifest_metadata.get("layout", {})))
manifest_metadata["layout"] = layout

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win

Keep resolved layout fields authoritative.

layout.update(...) lets caller metadata override zero_stage, world_size, rank, or kind, which can make ZeRO-3 full-state manifests lie about the actual resolved DeepSpeed layout.

Proposed fix
-        layout = {
+        layout = dict(manifest_metadata.get("layout", {}))
+        layout.update({
             "kind": "full-state",
             "zero_stage": self._deepspeed_zero_stage,
             "world_size": self._engine_world_size(),
             "rank": self._engine_rank(),
-        }
-        layout.update(dict(manifest_metadata.get("layout", {})))
+        })
         manifest_metadata["layout"] = layout
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
layout = {
"kind": "full-state",
"zero_stage": self.config.zero_stage,
"zero_stage": self._deepspeed_zero_stage,
"world_size": self._engine_world_size(),
"rank": self._engine_rank(),
}
layout.update(dict(manifest_metadata.get("layout", {})))
manifest_metadata["layout"] = layout
layout = dict(manifest_metadata.get("layout", {}))
layout.update({
"kind": "full-state",
"zero_stage": self._deepspeed_zero_stage,
"world_size": self._engine_world_size(),
"rank": self._engine_rank(),
})
manifest_metadata["layout"] = layout
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/executors/deepspeed_trainer.py` around lines 219 - 226, The layout
assembly in deepspeed_trainer.py currently lets manifest_metadata overwrite
authoritative fields in the resolved layout. Update the logic around the layout
dict in the DeepSpeed trainer so that kind, zero_stage, world_size, and rank
remain fixed to the values computed from _deepspeed_zero_stage,
_engine_world_size(), and _engine_rank(), while still merging any extra
caller-provided layout metadata from manifest_metadata.get("layout", {}) only
for non-reserved fields.

)
kernel_target = kernel_target.to(torch.int32).contiguous()

local_logp, local_lse = _C.fused_linear_logp_sm90(hidden_2d, weight, kernel_target, bias)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win

Pass the contiguous bias to the SM90 kernel.

bias_t is prepared for this path, but the kernel still receives the original bias. Use the same contiguous bias view for forward and backward instead of relying on extension stride handling.

Proposed fix
-        local_logp, local_lse = _C.fused_linear_logp_sm90(hidden_2d, weight, kernel_target, bias)
+        kernel_bias = bias_t if bias is not None else None
+        local_logp, local_lse = _C.fused_linear_logp_sm90(
+            hidden_2d,
+            weight,
+            kernel_target,
+            kernel_bias,
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
local_logp, local_lse = _C.fused_linear_logp_sm90(hidden_2d, weight, kernel_target, bias)
kernel_bias = bias_t if bias is not None else None
local_logp, local_lse = _C.fused_linear_logp_sm90(
hidden_2d,
weight,
kernel_target,
kernel_bias,
)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/cuda/loss/linear_logp.py` at line 150, The SM90 path in
fused_linear_logp is still passing the original bias tensor instead of the
prepared contiguous bias view. Update the forward call in the
fused_linear_logp_sm90 branch to use bias_t consistently, and make sure the same
contiguous bias view is also used by any corresponding backward path so both
directions avoid relying on extension stride handling.

Comment thread tests/test_linear_logp.py
Comment on lines +723 to +738
monkeypatch.setattr(
cuda_linear_logp,
"tensor_parallel_linear_logp",
lambda *args, **kwargs: sentinel,
)

out = op(
hidden,
weight,
target,
tp_group=object(),
vocab_start_index=3,
global_vocab_size=6,
)

assert out is sentinel

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

Assert TP metadata in the portable fallback test.

This test returns sentinel regardless of args, so it would still pass if the SM90 wrapper dropped tp_group, vocab_start_index, or global_vocab_size before delegating to tensor_parallel_linear_logp.

Proposed test tightening
     hidden = torch.randn(2, 4)
     weight = torch.randn(3, 4)
     target = torch.tensor([3, 5])
     sentinel = torch.full((2,), 11.0)
+    tp_group = object()
+    calls = {}
@@
-    monkeypatch.setattr(
-        cuda_linear_logp,
-        "tensor_parallel_linear_logp",
-        lambda *args, **kwargs: sentinel,
-    )
+    def fake_portable_tp(hidden_arg, weight_arg, target_arg, bias_arg, **kwargs):
+        calls["portable_tp"] = (hidden_arg, weight_arg, target_arg, bias_arg, kwargs)
+        return sentinel
+
+    monkeypatch.setattr(cuda_linear_logp, "tensor_parallel_linear_logp", fake_portable_tp)
@@
         hidden,
         weight,
         target,
-        tp_group=object(),
+        tp_group=tp_group,
         vocab_start_index=3,
         global_vocab_size=6,
     )
 
     assert out is sentinel
+    assert calls["portable_tp"][0] is hidden
+    assert calls["portable_tp"][1] is weight
+    assert calls["portable_tp"][2] is target
+    assert calls["portable_tp"][4] == {
+        "tp_group": tp_group,
+        "vocab_start_index": 3,
+        "global_vocab_size": 6,
+    }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
monkeypatch.setattr(
cuda_linear_logp,
"tensor_parallel_linear_logp",
lambda *args, **kwargs: sentinel,
)
out = op(
hidden,
weight,
target,
tp_group=object(),
vocab_start_index=3,
global_vocab_size=6,
)
assert out is sentinel
hidden = torch.randn(2, 4)
weight = torch.randn(3, 4)
target = torch.tensor([3, 5])
sentinel = torch.full((2,), 11.0)
tp_group = object()
calls = {}
def fake_portable_tp(hidden_arg, weight_arg, target_arg, bias_arg, **kwargs):
calls["portable_tp"] = (hidden_arg, weight_arg, target_arg, bias_arg, kwargs)
return sentinel
monkeypatch.setattr(cuda_linear_logp, "tensor_parallel_linear_logp", fake_portable_tp)
out = op(
hidden,
weight,
target,
tp_group=tp_group,
vocab_start_index=3,
global_vocab_size=6,
)
assert out is sentinel
assert calls["portable_tp"][0] is hidden
assert calls["portable_tp"][1] is weight
assert calls["portable_tp"][2] is target
assert calls["portable_tp"][4] == {
"tp_group": tp_group,
"vocab_start_index": 3,
"global_vocab_size": 6,
}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_linear_logp.py` around lines 723 - 738, Tighten the portable
fallback test for tensor-parallel linear logp: the current sentinel stub only
checks that delegation happens, not that tp_group, vocab_start_index, and
global_vocab_size are forwarded. Update the test around op and the monkeypatched
tensor_parallel_linear_logp to capture the forwarded args/kwargs and assert
those TP metadata values are preserved when the SM90 wrapper delegates.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/test_op_accuracy.py (1)

17-35: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Align this helper with the runtime gate.

_sm90_logp_available() is Hopper-only (cc_major == 9), but the registry now treats 9, 10, and 12 as eligible for experimental fused-logp promotion. As written, these routing tests will skip on part of the supported surface.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_op_accuracy.py` around lines 17 - 35, Update
_sm90_logp_available() in the test helper to match the runtime fused-logp gate
used by the registry: instead of only accepting get_device_capability()[0] == 9,
allow all promoted eligible major versions (9, 10, and 12) while still requiring
torch.cuda.is_available(), _EXT_AVAILABLE, and hasattr(_C, "fused_logp_sm90").
Keep requires_sm90_logp using this helper so the routing tests skip only when
the runtime gate would also reject the device.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/design/runtime-dispatch.md`:
- Around line 22-25: Tighten the runtime-dispatch capability wording to match
the actual registry contract: in the runtime dispatch documentation, replace the
open-ended “compute capability 9.0 or newer” phrasing with the specific
supported major versions used by the registry logic. Refer to the SM90 LogP
backend gating text so it clearly reflects that the experimental backend is only
promoted for the exact cc_major values handled by the implementation, while
keeping the fused linear logp SM90 backend description unchanged.

---

Nitpick comments:
In `@tests/test_op_accuracy.py`:
- Around line 17-35: Update _sm90_logp_available() in the test helper to match
the runtime fused-logp gate used by the registry: instead of only accepting
get_device_capability()[0] == 9, allow all promoted eligible major versions (9,
10, and 12) while still requiring torch.cuda.is_available(), _EXT_AVAILABLE, and
hasattr(_C, "fused_logp_sm90"). Keep requires_sm90_logp using this helper so the
routing tests skip only when the runtime gate would also reject the device.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 99e30768-f641-49a1-a262-e40f23105f25

📥 Commits

Reviewing files that changed from the base of the PR and between 81ed9b2 and ddf65a7.

📒 Files selected for processing (15)
  • csrc/cuda/fused_linear_logp_sm90.cu
  • docs/contributing/README.md
  • docs/contributing/pr-189-test-results.md
  • docs/contributing/vime-RLK.md
  • docs/design/runtime-dispatch.md
  • docs/operators/fused-logp.md
  • rl_engine/kernels/ops/cuda/loss/linear_logp.py
  • rl_engine/kernels/ops/cuda/loss/logp.py
  • rl_engine/kernels/ops/pytorch/loss/linear_logp.py
  • rl_engine/kernels/registry.py
  • scripts/test_linear_logp_tp.py
  • setup.py
  • tests/test_linear_logp.py
  • tests/test_op_accuracy.py
  • tests/test_rl_kernel_loss_step.py
✅ Files skipped from review due to trivial changes (3)
  • docs/contributing/README.md
  • docs/operators/fused-logp.md
  • docs/contributing/pr-189-test-results.md
🚧 Files skipped from review as they are similar to previous changes (6)
  • setup.py
  • csrc/cuda/fused_linear_logp_sm90.cu
  • scripts/test_linear_logp_tp.py
  • rl_engine/kernels/ops/cuda/loss/linear_logp.py
  • tests/test_linear_logp.py
  • rl_engine/kernels/ops/pytorch/loss/linear_logp.py

Comment on lines +22 to +25
For CUDA devices with compute capability 9.0 or newer, the registry only inserts
the legacy SM90 LogP backend when `RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP=1` is
set. The fused linear logp SM90 backend is gated separately and remains the
default linear logp backend when the extension is built on Hopper.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Tighten the capability wording here.

This says 9.0 or newer, but the registry currently promotes the experimental backend only for cc_major in (9, 10, 12). Documenting it as an open-ended >= 9.0 gate already diverges from the implementation contract.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/design/runtime-dispatch.md` around lines 22 - 25, Tighten the
runtime-dispatch capability wording to match the actual registry contract: in
the runtime dispatch documentation, replace the open-ended “compute capability
9.0 or newer” phrasing with the specific supported major versions used by the
registry logic. Refer to the SM90 LogP backend gating text so it clearly
reflects that the experimental backend is only promoted for the exact cc_major
values handled by the implementation, while keeping the fused linear logp SM90
backend description unchanged.

Signed-off-by: Codex <codex@openai.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants