[FEAT][kernels] Add tensor-parallel linear_logp path #189
inaniloquentee wants to merge 7 commits into
📝 Walkthrough
Adds tensor-parallel
Changes
Tensor-Parallel linear_logp
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 6
🧹 Nitpick comments (1)
tests/test_linear_logp.py (1)
234-237: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Surface child tracebacks before exit-code assertions.
Workers already send tracebacks through `result_queue`, but Line 235 can fail first and hide the useful distributed failure details.
Proposed fix
-    for process in processes:
-        assert process.exitcode == 0
-    for result in sorted(results, key=lambda item: item["rank"]):
+    sorted_results = sorted(results, key=lambda item: item["rank"])
+    for result in sorted_results:
         assert result["ok"], result.get("traceback")
+    for process in processes:
+        assert process.exitcode == 0
+    for result in sorted_results:
         assert result["out"] < 1e-5
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_linear_logp.py` around lines 234 - 237, Surface child tracebacks before checking worker exit codes in the test flow. In the `test_linear_logp` assertions, inspect the collected `results` and fail on any `result["ok"]` false (using `result.get("traceback")`) before asserting `process.exitcode == 0` for each process. Keep the existing `processes` and `results` checks, but reorder them so distributed traceback details from the worker result queue are shown first.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/kernels/ops/pytorch/loss/linear_logp.py`:
- Around line 95-111: The TP validation in _validate_global_vocab_size and
_validate_global_targets is currently local, so one rank can raise while others
keep going and later hang in collectives. Add a collective error check before
either ValueError is raised by all-reducing a per-rank failure flag so every
rank agrees to raise or continue together. Keep the fix centered on these two
helpers in linear_logp.py and preserve the existing error messages, only making
the failure path collective.
- Around line 300-304: The bias validation in _chunked_local_linear_logp_stats
only checks numel(), so it can accept non-1D shapes like (V, 1) that later slice
incorrectly and may mis-broadcast when added to the local logits. Tighten the
check near the existing bias.numel() validation to require bias to be exactly
1-D for the local vocab shard (using bias.ndim/shape), and keep the existing
lm_head_weight.size(0) element-count check so only a flat bias matching the
shard size is allowed.
In `@scripts/test_linear_logp_tp.py`:
- Around line 202-207: The stress phase in `step()` is only validating gradients
because the returned `out` is discarded, so the finiteness check is incomplete.
Update `step()` and the surrounding stress loop to keep the `out` from the model
forward pass and include it in the same finite-value validation used for
gradients, so both outputs and grads are checked during the stress run.
In `@tests/test_linear_logp.py`:
- Line 5: The distributed test setup in test_linear_logp is using a free TCP
port that gets released before dist.init_process_group() starts, which can let
another process claim it and make the 4-rank Gloo test flaky. Update the
rendezvous setup used by the affected test helpers in test_linear_logp to avoid
the race by either switching to a per-test file:// init_method or keeping a
TCPStore alive until process group initialization completes. Use the existing
test setup functions around the init_process_group call to make the change
without relying on the temporary socket being closed early.
In `@杂/KJ.md`:
- Around line 61-75: The markdown file has a missing trailing newline, which
causes end-of-file-fixer to rewrite it and fail pre-commit. Update the
documentation file itself so it ends with a single newline after the release
notes content, keeping the existing text unchanged. Verify the fix against the
generated release-notes block so the file remains stable under linting and
commit hooks.
- Around line 48-58: Add language tags to the two fenced code blocks in this
markdown snippet to satisfy MD040. Update the fenced blocks in this section so
the first error example and the second template-instantiation error example both
have an explicit language identifier, using the same fenced-block formatting
already used elsewhere in the document.
---
Nitpick comments:
In `@tests/test_linear_logp.py`:
- Around line 234-237: Surface child tracebacks before checking worker exit
codes in the test flow. In the `test_linear_logp` assertions, inspect the
collected `results` and fail on any `result["ok"]` false (using
`result.get("traceback")`) before asserting `process.exitcode == 0` for each
process. Keep the existing `processes` and `results` checks, but reorder them so
distributed traceback details from the worker result queue are shown first.
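For the first inline finding above (making the TP validation failure path collective), the idea can be sketched in plain Python without a process group. This is only an illustration: `raise_collectively` and `simulate_ranks` are hypothetical names, and `all_reduce_max` is a stand-in for a MAX all-reduce of a per-rank failure flag over the TP group, not a helper from this PR.

```python
from typing import Callable, List, Optional


def raise_collectively(
    local_error: Optional[str],
    all_reduce_max: Callable[[int], int],
) -> None:
    """Raise on every rank if any rank reported a validation error.

    `all_reduce_max` is a hypothetical hook standing in for a MAX all-reduce
    of an integer failure flag across the tensor-parallel group.
    """
    local_flag = 1 if local_error is not None else 0
    # Every rank calls the reduction, so no rank skips the collective.
    failed_anywhere = all_reduce_max(local_flag)
    if failed_anywhere:
        raise ValueError(
            local_error or "validation failed on another tensor-parallel rank"
        )


def simulate_ranks(local_errors: List[Optional[str]]) -> List[bool]:
    """Run the check once per simulated rank and report which ranks raised."""
    flags = [1 if err is not None else 0 for err in local_errors]
    raised = []
    for err in local_errors:
        try:
            # The lambda plays the role of the group-wide MAX reduction.
            raise_collectively(err, lambda _flag: max(flags))
            raised.append(False)
        except ValueError:
            raised.append(True)
    return raised


# Rank 1 sees an out-of-range target; the other three ranks see valid input.
print(simulate_ranks([None, "target id out of range", None, None]))
```

With a purely local check only rank 1 would raise while the others proceed into the next collective; here all simulated ranks agree, and the rank that detected the problem keeps its original message.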
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 295838a9-5416-4f8a-a4a3-a5a2dbbc11bd
📒 Files selected for processing (11)
- benchmarks/benchmark_linear_logp.py
- docs/.nav.yml
- docs/operators/README.md
- docs/operators/linear-logp-tp-test.md
- docs/operators/linear-logp.md
- rl_engine/kernels/ops/cuda/loss/linear_logp.py
- rl_engine/kernels/ops/pytorch/loss/linear_logp.py
- rl_engine/kernels/ops/triton/loss/linear_logp.py
- scripts/test_linear_logp_tp.py
- tests/test_linear_logp.py
- 杂/KJ.md
```
fused_logp_sm90.cu(62): error: identifier "CUDART_INF_F" is undefined
```
Fix: `#include <math_constants.h>`.

2. **`TmaTypeTraits<c10::BFloat16>` is an incomplete type**: the host wrapper passes `logits.data_ptr<at::BFloat16>()` (i.e. `c10::BFloat16*`) straight into `init_tensor_map`, but `TmaTypeTraits` is only specialized for `nv_bfloat16`/`float`, so template instantiation fails.
```
tma_utils.cuh: error: incomplete type "TmaTypeTraits<c10::BFloat16>" is not allowed
```
Fix: `reinterpret_cast<const nv_bfloat16*>(logits.data_ptr<at::BFloat16>())` (matching how the linear kernel and the rest of the codebase bridge `at::BFloat16` → `nv_bfloat16`).
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Add language tags to fenced code blocks.
Two fenced blocks are missing language identifiers (MD040), which triggers markdown lint warnings.
Suggested patch
- ```
+ ```text
fused_logp_sm90.cu(62): error: identifier "CUDART_INF_F" is undefined
```
...
- ```
+ ```text
tma_utils.cuh: error: incomplete type "TmaTypeTraits<c10::BFloat16>" is not allowed
```
🧰 Tools
🪛 markdownlint-cli2 (0.22.1)
[warning] 48-48: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
[warning] 55-55: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
Source: Linters/SAST tools
a18a5e4 to 453d38f
Signed-off-by: inaniloquentee <3051000145@qq.com>
453d38f to 5e040e3
Actionable comments posted: 2
🧹 Nitpick comments (2)
setup.py (2)
56-56: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
Inconsistent handling of the shared `torch_rpath` list.
Line 56 assigns the shared `torch_rpath` reference directly, while line 114 copies via `list(torch_rpath)`. If the ROCm branch ever appends to its `extra_link_args`, it would mutate the shared list and bleed into other consumers. Use `list(torch_rpath)` in both places for consistency and to avoid aliasing surprises.
♻️ Proposed change
- extra_link_args=torch_rpath,
+ extra_link_args=list(torch_rpath),
Also applies to: 114-114
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@setup.py` at line 56, The handling of torch_rpath is inconsistent and can leak shared mutations across branches. In the setup configuration where extra_link_args is assigned, update the assignment in both the main and ROCm paths to use a copied list instead of the shared torch_rpath reference. Keep the change localized around the extra_link_args setup so both call sites use list(torch_rpath) and avoid aliasing side effects.
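The aliasing concern in this nitpick is easy to reproduce outside `setup.py`. The sketch below is illustrative only: `extension_kwargs` and `leaks_between_extensions` are hypothetical helpers, and the rpath string and the appended flag are made-up values rather than the ones the build actually uses.

```python
def extension_kwargs(shared_link_args, copy):
    """Build kwargs for one extension, optionally copying the shared list."""
    link_args = list(shared_link_args) if copy else shared_link_args
    return {"extra_link_args": link_args}


def leaks_between_extensions(copy):
    """Return True if a flag appended for one extension shows up in the other."""
    torch_rpath = ["-Wl,-rpath,$ORIGIN/../torch/lib"]  # illustrative value
    main_ext = extension_kwargs(torch_rpath, copy)
    rocm_ext = extension_kwargs(torch_rpath, copy)
    # Only the ROCm extension is meant to receive this (hypothetical) flag.
    rocm_ext["extra_link_args"].append("--hypothetical-rocm-only-flag")
    return "--hypothetical-rocm-only-flag" in main_ext["extra_link_args"]


print(leaks_between_extensions(copy=False))  # shared list object
print(leaks_between_extensions(copy=True))   # independent copies
```

Passing the shared reference means both extensions hold the same list object, so the append is visible to both; copying with `list(...)` at each call site keeps them independent.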
40-41: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚖️ Poor tradeoff
Absolute Torch lib path baked into rpath may hurt build reproducibility.
The second entry embeds the build-machine absolute path (`{torch_lib_dir}`) into the binary's rpath. The `$ORIGIN/../torch/lib` entry already provides relocatable resolution for installed wheels; the absolute path leaks build-time paths and can break reproducible/portable builds when the artifact runs on a host where Torch lives elsewhere. Consider keeping the absolute entry only as a deliberate dev-build fallback, or drop it if `$ORIGIN` resolution is sufficient.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@setup.py` around lines 40 - 41, The rpath setup in setup.py is baking a build-machine absolute Torch library path into the binary via torch_rpath, which harms portability and reproducibility. Update the rpath construction around torch_lib_dir and torch_rpath so the $ORIGIN-based entry is the default for installed wheels, and either remove the absolute {torch_lib_dir} entry or guard it behind an explicit dev-build fallback path. Keep the change localized to the torch rpath logic in setup.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@csrc/cuda/fused_linear_logp_sm90.cu`:
- Around line 116-120: The barrier priming in the SM90 load path is ordered too
late, which can race with the TMA copies. In fused_linear_logp_sm90.cu, update
the buffer setup around tma_2d_g2s and mbarrier_arrive_expect_tx so the barrier
is primed on mbar[buf] before issuing either the h_tmap or w_tmap copy, keeping
the expected byte count in place before the transfers begin.
In `@rl_engine/kernels/ops/pytorch/loss/linear_logp.py`:
- Around line 65-66: The empty-shard check in linear_logp should not raise
locally before the tensor-parallel collective validation runs, because only the
rank with an empty shard will fail and other ranks can hang in
_validate_tp_vocab_partition. Update the logic around the local_vocab_size guard
in the linear_logp path so the validation is deferred to the existing partition
check, using the gathered end <= start condition to fail consistently on every
rank for explicit multi-rank TP.
---
Nitpick comments:
In `@setup.py`:
- Line 56: The handling of torch_rpath is inconsistent and can leak shared
mutations across branches. In the setup configuration where extra_link_args is
assigned, update the assignment in both the main and ROCm paths to use a copied
list instead of the shared torch_rpath reference. Keep the change localized
around the extra_link_args setup so both call sites use list(torch_rpath) and
avoid aliasing side effects.
- Around line 40-41: The rpath setup in setup.py is baking a build-machine
absolute Torch library path into the binary via torch_rpath, which harms
portability and reproducibility. Update the rpath construction around
torch_lib_dir and torch_rpath so the $ORIGIN-based entry is the default for
installed wheels, and either remove the absolute {torch_lib_dir} entry or guard
it behind an explicit dev-build fallback path. Keep the change localized to the
torch rpath logic in setup.py.
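For the empty-shard finding on `linear_logp.py` above, the point is that the `end <= start` check should run on the gathered shard ranges, which every rank sees identically, rather than on one rank's local size. A minimal sketch of such a partition check follows; `validate_vocab_partition` and `count_owners` are hypothetical names, and the real `_validate_tp_vocab_partition` may differ in both signature and behaviour.

```python
from typing import List, Tuple


def validate_vocab_partition(
    ranges: List[Tuple[int, int]], global_vocab_size: int
) -> None:
    """Check that gathered (start, end) shards tile [0, global_vocab_size).

    Each rank would run this on the same gathered list, so all ranks reach
    the same verdict and raise (or continue) together.
    """
    expected_start = 0
    for start, end in sorted(ranges):
        if end <= start:
            raise ValueError(f"empty vocab shard [{start}, {end})")
        if start != expected_start:
            raise ValueError(
                f"shard starts at {start}, expected {expected_start}"
            )
        expected_start = end
    if expected_start != global_vocab_size:
        raise ValueError(
            f"shards end at {expected_start}, expected {global_vocab_size}"
        )


def count_owners(target_id: int, ranges: List[Tuple[int, int]]) -> int:
    """Number of shards whose half-open range contains the target id."""
    return sum(1 for start, end in ranges if start <= target_id < end)


shards = [(0, 3), (3, 6)]
validate_vocab_partition(shards, global_vocab_size=6)
print([count_owners(t, shards) for t in range(6)])
```

Once the shards are known to tile `[0, V)` without gaps or overlaps, every valid target id has exactly one owning shard, which is the exactly-one-owner property the PR description mentions.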
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2de76294-4e7f-4af4-80f1-b52dbe476ac3
📒 Files selected for processing (13)
- benchmarks/benchmark_linear_logp.py
- csrc/cuda/fused_linear_logp_sm90.cu
- csrc/utils/tma_utils.cuh
- docs/.nav.yml
- docs/operators/README.md
- docs/operators/linear-logp-tp-test.md
- docs/operators/linear-logp.md
- rl_engine/kernels/ops/cuda/loss/linear_logp.py
- rl_engine/kernels/ops/pytorch/loss/linear_logp.py
- rl_engine/kernels/ops/triton/loss/linear_logp.py
- scripts/test_linear_logp_tp.py
- setup.py
- tests/test_linear_logp.py
✅ Files skipped from review due to trivial changes (3)
- docs/.nav.yml
- docs/operators/README.md
- docs/operators/linear-logp-tp-test.md
🚧 Files skipped from review as they are similar to previous changes (6)
- benchmarks/benchmark_linear_logp.py
- tests/test_linear_logp.py
- docs/operators/linear-logp.md
- rl_engine/kernels/ops/triton/loss/linear_logp.py
- rl_engine/kernels/ops/cuda/loss/linear_logp.py
- scripts/test_linear_logp_tp.py
…GEMM Signed-off-by: inaniloquentee <3051000145@qq.com>
Actionable comments posted: 3
🧹 Nitpick comments (1)
tests/test_deepspeed_training_worker.py (1)
64-65: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Annotate these shared lists as `ClassVar`. Ruff flags these mutable class attributes, and the lists are intentionally shared test state.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_deepspeed_training_worker.py` around lines 64 - 65, The shared mutable test state in the test class is being flagged by Ruff because modifier_ranks and parameter_counts are plain class attributes; annotate these attributes as ClassVar in the test class so it is clear they are intentionally shared across instances. Update the class definition where modifier_ranks and parameter_counts are declared, using the existing test class name as the locator, and keep their behavior unchanged.
Source: Linters/SAST tools
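A small sketch of what the `ClassVar` annotation looks like in practice. Only the attribute names `modifier_ranks` and `parameter_counts` come from the review; the class name `FakeEngine`, the `record` method, and the `int` element type are assumptions made for illustration.

```python
from typing import ClassVar, List


class FakeEngine:
    """Hypothetical test double with intentionally shared class-level state."""

    # ClassVar documents that these lists belong to the class, not instances.
    modifier_ranks: ClassVar[List[int]] = []
    parameter_counts: ClassVar[List[int]] = []

    def record(self, rank: int, count: int) -> None:
        # Appending through an instance mutates the single class-level list.
        self.modifier_ranks.append(rank)
        self.parameter_counts.append(count)


FakeEngine().record(0, 10)
FakeEngine().record(1, 20)
print(FakeEngine.modifier_ranks, FakeEngine.parameter_counts)
```

The annotation does not change runtime behaviour; it only makes the sharing explicit to readers and to linters that flag mutable class attributes.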
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/executors/deepspeed_trainer.py`:
- Around line 219-226: The layout assembly in deepspeed_trainer.py currently
lets manifest_metadata overwrite authoritative fields in the resolved layout.
Update the logic around the layout dict in the DeepSpeed trainer so that kind,
zero_stage, world_size, and rank remain fixed to the values computed from
_deepspeed_zero_stage, _engine_world_size(), and _engine_rank(), while still
merging any extra caller-provided layout metadata from
manifest_metadata.get("layout", {}) only for non-reserved fields.
In `@rl_engine/kernels/ops/cuda/loss/linear_logp.py`:
- Line 150: The SM90 path in fused_linear_logp is still passing the original
bias tensor instead of the prepared contiguous bias view. Update the forward
call in the fused_linear_logp_sm90 branch to use bias_t consistently, and make
sure the same contiguous bias view is also used by any corresponding backward
path so both directions avoid relying on extension stride handling.
In `@tests/test_linear_logp.py`:
- Around line 723-738: Tighten the portable fallback test for tensor-parallel
linear logp: the current sentinel stub only checks that delegation happens, not
that tp_group, vocab_start_index, and global_vocab_size are forwarded. Update
the test around op and the monkeypatched tensor_parallel_linear_logp to capture
the forwarded args/kwargs and assert those TP metadata values are preserved when
the SM90 wrapper delegates.
---
Nitpick comments:
In `@tests/test_deepspeed_training_worker.py`:
- Around line 64-65: The shared mutable test state in the test class is being
flagged by Ruff because modifier_ranks and parameter_counts are plain class
attributes; annotate these attributes as ClassVar in the test class so it is
clear they are intentionally shared across instances. Update the class
definition where modifier_ranks and parameter_counts are declared, using the
existing test class name as the locator, and keep their behavior unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 28f08037-c31c-4cdb-84ee-12071e09970d
📒 Files selected for processing (4)
- rl_engine/executors/deepspeed_trainer.py
- rl_engine/kernels/ops/cuda/loss/linear_logp.py
- tests/test_deepspeed_training_worker.py
- tests/test_linear_logp.py
  layout = {
      "kind": "full-state",
-     "zero_stage": self.config.zero_stage,
+     "zero_stage": self._deepspeed_zero_stage,
      "world_size": self._engine_world_size(),
      "rank": self._engine_rank(),
  }
  layout.update(dict(manifest_metadata.get("layout", {})))
  manifest_metadata["layout"] = layout
🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win
Keep resolved layout fields authoritative.
layout.update(...) lets caller metadata override zero_stage, world_size, rank, or kind, which can make ZeRO-3 full-state manifests lie about the actual resolved DeepSpeed layout.
Proposed fix
- layout = {
+ layout = dict(manifest_metadata.get("layout", {}))
+ layout.update({
"kind": "full-state",
"zero_stage": self._deepspeed_zero_stage,
"world_size": self._engine_world_size(),
"rank": self._engine_rank(),
- }
- layout.update(dict(manifest_metadata.get("layout", {})))
+ })
manifest_metadata["layout"] = layout
)
kernel_target = kernel_target.to(torch.int32).contiguous()

local_logp, local_lse = _C.fused_linear_logp_sm90(hidden_2d, weight, kernel_target, bias)
🩺 Stability & Availability | 🟡 Minor | ⚡ Quick win
Pass the contiguous bias to the SM90 kernel.
bias_t is prepared for this path, but the kernel still receives the original bias. Use the same contiguous bias view for forward and backward instead of relying on extension stride handling.
Proposed fix
- local_logp, local_lse = _C.fused_linear_logp_sm90(hidden_2d, weight, kernel_target, bias)
+ kernel_bias = bias_t if bias is not None else None
+ local_logp, local_lse = _C.fused_linear_logp_sm90(
+ hidden_2d,
+ weight,
+ kernel_target,
+ kernel_bias,
+ )
monkeypatch.setattr(
    cuda_linear_logp,
    "tensor_parallel_linear_logp",
    lambda *args, **kwargs: sentinel,
)

out = op(
    hidden,
    weight,
    target,
    tp_group=object(),
    vocab_start_index=3,
    global_vocab_size=6,
)

assert out is sentinel
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
Assert TP metadata in the portable fallback test.
This test returns sentinel regardless of args, so it would still pass if the SM90 wrapper dropped tp_group, vocab_start_index, or global_vocab_size before delegating to tensor_parallel_linear_logp.
Proposed test tightening
hidden = torch.randn(2, 4)
weight = torch.randn(3, 4)
target = torch.tensor([3, 5])
sentinel = torch.full((2,), 11.0)
+ tp_group = object()
+ calls = {}
@@
- monkeypatch.setattr(
- cuda_linear_logp,
- "tensor_parallel_linear_logp",
- lambda *args, **kwargs: sentinel,
- )
+ def fake_portable_tp(hidden_arg, weight_arg, target_arg, bias_arg, **kwargs):
+ calls["portable_tp"] = (hidden_arg, weight_arg, target_arg, bias_arg, kwargs)
+ return sentinel
+
+ monkeypatch.setattr(cuda_linear_logp, "tensor_parallel_linear_logp", fake_portable_tp)
@@
hidden,
weight,
target,
- tp_group=object(),
+ tp_group=tp_group,
vocab_start_index=3,
global_vocab_size=6,
)
assert out is sentinel
+ assert calls["portable_tp"][0] is hidden
+ assert calls["portable_tp"][1] is weight
+ assert calls["portable_tp"][2] is target
+ assert calls["portable_tp"][4] == {
+ "tp_group": tp_group,
+ "vocab_start_index": 3,
+ "global_vocab_size": 6,
+ }
Signed-off-by: Codex <codex@openai.com>
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/test_op_accuracy.py (1)
17-35: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Align this helper with the runtime gate.
`_sm90_logp_available()` is Hopper-only (`cc_major == 9`), but the registry now treats `9`, `10`, and `12` as eligible for experimental fused-logp promotion. As written, these routing tests will skip on part of the supported surface.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_op_accuracy.py` around lines 17 - 35, Update _sm90_logp_available() in the test helper to match the runtime fused-logp gate used by the registry: instead of only accepting get_device_capability()[0] == 9, allow all promoted eligible major versions (9, 10, and 12) while still requiring torch.cuda.is_available(), _EXT_AVAILABLE, and hasattr(_C, "fused_logp_sm90"). Keep requires_sm90_logp using this helper so the routing tests skip only when the runtime gate would also reject the device.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/design/runtime-dispatch.md`:
- Around line 22-25: Tighten the runtime-dispatch capability wording to match
the actual registry contract: in the runtime dispatch documentation, replace the
open-ended “compute capability 9.0 or newer” phrasing with the specific
supported major versions used by the registry logic. Refer to the SM90 LogP
backend gating text so it clearly reflects that the experimental backend is only
promoted for the exact cc_major values handled by the implementation, while
keeping the fused linear logp SM90 backend description unchanged.
---
Nitpick comments:
In `@tests/test_op_accuracy.py`:
- Around line 17-35: Update _sm90_logp_available() in the test helper to match
the runtime fused-logp gate used by the registry: instead of only accepting
get_device_capability()[0] == 9, allow all promoted eligible major versions (9,
10, and 12) while still requiring torch.cuda.is_available(), _EXT_AVAILABLE, and
hasattr(_C, "fused_logp_sm90"). Keep requires_sm90_logp using this helper so the
routing tests skip only when the runtime gate would also reject the device.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 99e30768-f641-49a1-a262-e40f23105f25
📒 Files selected for processing (15)
- csrc/cuda/fused_linear_logp_sm90.cu
- docs/contributing/README.md
- docs/contributing/pr-189-test-results.md
- docs/contributing/vime-RLK.md
- docs/design/runtime-dispatch.md
- docs/operators/fused-logp.md
- rl_engine/kernels/ops/cuda/loss/linear_logp.py
- rl_engine/kernels/ops/cuda/loss/logp.py
- rl_engine/kernels/ops/pytorch/loss/linear_logp.py
- rl_engine/kernels/registry.py
- scripts/test_linear_logp_tp.py
- setup.py
- tests/test_linear_logp.py
- tests/test_op_accuracy.py
- tests/test_rl_kernel_loss_step.py
✅ Files skipped from review due to trivial changes (3)
- docs/contributing/README.md
- docs/operators/fused-logp.md
- docs/contributing/pr-189-test-results.md
🚧 Files skipped from review as they are similar to previous changes (6)
- setup.py
- csrc/cuda/fused_linear_logp_sm90.cu
- scripts/test_linear_logp_tp.py
- rl_engine/kernels/ops/cuda/loss/linear_logp.py
- tests/test_linear_logp.py
- rl_engine/kernels/ops/pytorch/loss/linear_logp.py
For CUDA devices with compute capability 9.0 or newer, the registry only inserts
the legacy SM90 LogP backend when `RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP=1` is
set. The fused linear logp SM90 backend is gated separately and remains the
default linear logp backend when the extension is built on Hopper.
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Tighten the capability wording here.
This says 9.0 or newer, but the registry currently promotes the experimental backend only for cc_major in (9, 10, 12). Documenting it as an open-ended >= 9.0 gate already diverges from the implementation contract.
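To make the difference between the two gates concrete, the sketch below compares an open-ended `>= 9` check with the exact set named in this comment. The set `(9, 10, 12)` is taken from the comment; the function names are hypothetical, and the loop ranges over integers without claiming that any particular major version corresponds to shipping hardware.

```python
PROMOTED_CC_MAJORS = frozenset({9, 10, 12})  # values quoted in the review


def open_ended_gate(cc_major: int) -> bool:
    """What the documentation wording implies: 9.0 or newer."""
    return cc_major >= 9


def exact_gate(cc_major: int) -> bool:
    """What the registry is described as doing: exact major-version match."""
    return cc_major in PROMOTED_CC_MAJORS


# Major versions on which the two readings would route differently.
disagreements = [
    major for major in range(8, 14) if open_ended_gate(major) != exact_gate(major)
]
print(disagreements)
```

Any major version in the printed list would be promoted under the documented wording but not under the exact-match gate, which is the divergence the comment asks the docs to remove.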
Signed-off-by: Codex <codex@openai.com>
Summary
Tensor-parallel `linear_logp` now supports vocab-sharded LM-head weights through
the public operator API. The TP path accepts global `target_ids`, local
`lm_head_weight[vocab_start:vocab_end]`, optional local bias, a TP process group,
and global vocab metadata.
The TP implementation computes global log-sum-exp across vocab shards, sums the
rank-local target logit contribution, all-reduces `hidden.grad`, and keeps
`weight.grad`/`bias.grad` local to each vocab shard.
Implementation
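The forward reduction described in the Summary can be sketched in a single process. This is a minimal pure-Python illustration, not the PR's kernel: the helper names are hypothetical, bias is omitted, the loops stand in for all-reduce calls over the TP group, and the max-then-rescale form of the log-sum-exp is an assumption about how the shards are combined.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def local_partials(hidden, weight_shard, vocab_start, target_id):
    """What one rank can compute from its own vocab shard for one token."""
    logits = [dot(hidden, row) for row in weight_shard]
    local_max = max(logits)
    local_sumexp = sum(math.exp(z - local_max) for z in logits)
    local_idx = target_id - vocab_start  # target_id is a global vocab index
    owns_target = 0 <= local_idx < len(weight_shard)
    target_logit = logits[local_idx] if owns_target else 0.0
    return local_max, local_sumexp, target_logit, int(owns_target)


def tp_logp(hidden, weight_shards, target_id):
    """Combine per-rank partials; each reduction stands in for an all-reduce."""
    partials, vocab_start = [], 0
    for shard in weight_shards:
        partials.append(local_partials(hidden, shard, vocab_start, target_id))
        vocab_start += len(shard)
    global_max = max(p[0] for p in partials)  # all-reduce(MAX)
    # Rescale each local sum to the global max before summing: all-reduce(SUM)
    global_sumexp = sum(p[1] * math.exp(p[0] - global_max) for p in partials)
    target_logit = sum(p[2] for p in partials)  # all-reduce(SUM)
    owners = sum(p[3] for p in partials)  # all-reduce(SUM)
    if owners != 1:
        raise ValueError(f"target {target_id} has {owners} owners, expected 1")
    return target_logit - (global_max + math.log(global_sumexp))


def full_logp(hidden, weight, target_id):
    """Unsharded reference: log_softmax(W @ h)[target]."""
    logits = [dot(hidden, row) for row in weight]
    m = max(logits)
    return logits[target_id] - (m + math.log(sum(math.exp(z - m) for z in logits)))
```

Splitting a five-row weight into uneven shards of three and two rows and comparing `tp_logp` against `full_logp` for every target is a quick way to check that the owner-count and rescaling logic agree with the unsharded result.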
- … `[0, V)` vocab partitions and exactly-one-owner target coverage.
- … local target logit, and owner count.
- `hidden.grad` is all-reduced across ranks; `weight.grad` and `bias.grad` remain shard-local.
- … bf16/fp16 matmuls in fp32 before casting returned gradients back to input dtypes.
- … `shared::cluster.global.tile` form.
- … `_C` imports without manual `LD_LIBRARY_PATH`.
Validation Environment
- `KERNEL_ALIGN_FORCE_SM90=1`, `_C.fused_linear_logp_sm90` available
Correctness / Tests
Native TP math sanity
Result: PASS.
Registry bf16 main gate
Result: PASS.
Wrapper delegation
Both wrapper paths passed the same bf16/fp32-reference TP correctness run:
- `triton`
- `sm90`
SM90 log confirms the extension was linked:
Same-dtype bf16 drift check
Result: PASS with recorded full-GEMM vs shard-GEMM drift.
This drift comes from comparing PyTorch's full bf16 GEMM against separate
vocab-shard GEMMs. The main correctness target is fp32 accumulation, matching
the fused/Triton bf16 operator semantics.
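One way a same-dtype comparison can drift while an fp32-accumulated one stays tight is sketched below. This is an illustration under assumptions, not an analysis of the PR's run: `to_bf16` is a hypothetical helper that emulates round-to-nearest-even bfloat16 rounding in pure Python, and the two inputs are hand-picked to sit on either side of a bfloat16 rounding boundary.

```python
import struct


def to_bf16(x):
    """Round x to the nearest bfloat16 value (ties to even), returned as a float.

    Finite, moderate-magnitude inputs only; NaN, infinity, and overflow are
    not handled in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000  # keep sign, exponent, and the top 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Two fp32 values that differ by 2**-17, standing in for the small differences
# that two accumulation orders could produce for the same logit.
logit_a = 8.03125 - 2.0 ** -18
logit_b = 8.03125 + 2.0 ** -18

fp32_gap = abs(logit_a - logit_b)
bf16_gap = abs(to_bf16(logit_a) - to_bf16(logit_b))
```

In this hand-built case the fp32 gap of `2**-17` becomes a gap of `0.0625` after rounding, which is one bfloat16 step at this magnitude. A reference that keeps fp32 accumulation does not count such rounding flips as operator error.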
Larger stress
Result: PASS.
Unit coverage
Result: 21 passed in 8.40s.
Latest PR 189 Test Results
Validated on 2026-06-27 UTC on the `pr-189/feat/fused-linear-logp` branch.
Runtime code was validated at `ddf65a7`; follow-up `355a2e0` only removes local
PR notes from the submitted branch and passed docs/pre-commit checks.
Environment
- `nvcc` 12.4.131
- `DS_BUILD_OPS=0`
- `KERNEL_ALIGN_FORCE_SM90=1`; `_C.fused_linear_logp_sm90` available
CI and Regression Checks
- `python -m pytest tests/test_linear_logp.py -q -rs`
- `python -m pytest tests/test_deepspeed_training_worker.py -q -rs`
- `python -m pytest tests/test_op_accuracy.py tests/test_rl_kernel_loss_step.py -q -rs --tb=short`
- `python -m pytest tests rl_engine/tests -q -rs --tb=short`
- `python -m pytest rl_engine/tests/test_dispatch.py -v`
- `PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest tests/test_attention_correctness.py -q -rs`
- `python -m mypy --ignore-missing-imports rl_engine/`
- `mkdocs build --strict -f mkdocs.yaml`
- `pre-commit run --all-files`
The attention skips were expected for this CUDA/H100 environment: external
`flash_attn`/`_C.flash_attn_forward` was unavailable, and ROCm cases were
skipped on NVIDIA CUDA.
Production-Like CUDA Checks
- `linear_logp` TP stress, bf16/fp32 reference, uneven shards
- `linear_logp` TP, bf16/fp32 reference, uneven shards
- `DeepSpeedEngine`, zero stage 0, `current_logp_backend=FusedLinearLogpSM90Op`, finite loss, published weight version 11
Direct SM90 TP correctness metrics: output max abs `3.051758e-04`, hidden grad
`6.250000e-02`, weight grad `6.250000e-02`, bias grad `1.562500e-02`.
Registry TP correctness metrics: output max abs `1.068115e-04`, hidden grad
`3.125000e-02`, weight grad `3.125000e-02`, bias grad `1.953125e-03`.
Regression Found and Fixed
A full-suite H100 run exposed that the legacy SM90 `logp` registry path selected
`FusedLogpSM90Op` for cases it did not fully support: fp16 logits, `.out(...)`,
and `.apply_fp32(...)`. The wrapper now delegates to `FusedLogpGenericOp` by
default, and the legacy SM90 TMA `logp` fast path is gated behind
`RL_KERNEL_ENABLE_EXPERIMENTAL_SM90_LOGP=1`. This does not affect the PR's fused
SM90 `linear_logp` production path, which is covered by the two-GPU checks above.
Summary by CodeRabbit
New Features
- … `linear_logp`, including new `tp_group`, `vocab_start_index`, and `global_vocab_size` options with TP-aware dispatch (including SM90 when eligible).
- … `linear_logp` with TP metadata.
Bug Fixes
Documentation
- … `linear_logp` operator docs and added an end-to-end TP test runbook.
Tests