[WS1][kernels] Batch-invariant logprob (Native, Triton)#199
Conversation
Implements batch_invariant_logp for selected-token log probabilities from materialized logits with row-local, batch-invariant semantics. - PyTorch NativeBatchInvariantLogpOp: FP32 row-wise reference with ignore_index handling and target validation. - Triton TritonBatchInvariantLogpOp: online-softmax forward with fixed vocab tiling and tile-wise backward using saved per-row lse. - Registry dispatch, PyTorch/Triton tests, and operator docs.
📝 WalkthroughWalkthroughAdds a new ChangesBatch-Invariant LogP Operator
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related issues
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 4
🧹 Nitpick comments (2)
docs/operators/batch-invariant-logp.md (1)
99-102: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low valueReword successive "Triton" sentence openings.
Three successive bullet points begin with "Triton" (lines 99–102 in the rendered document), which reads repetitively. Vary the sentence structure for better flow, e.g., "The vocab traversal...", "Forward scan...", "Backward computation...".
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@docs/operators/batch-invariant-logp.md` around lines 99 - 102, The three bullet points in the batch-invariant logp section all start with “Triton,” which reads repetitive; reword the openings in that block to vary the sentence structure while keeping the same meaning. Update the bullet text near the vocab traversal, forward scan, and backward computation statements so each begins differently, using the surrounding markdown list content as the target.tests/test_triton_batch_invariant_logp.py (1)
328-340: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winLet the CPU-tensor validation run without CUDA.
TritonBatchInvariantLogpOp.applyraises on CPU tensors before any kernel launch, so Lines 337-340 do not need a CUDA device. Keeping this test under the class-level@requires_triton_cudaskips that guard on CPU-only CI and leaves the validation path unexercised. Move this case out of the CUDA-gated class, or gate it only on Triton import.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_triton_batch_invariant_logp.py` around lines 328 - 340, The CPU-tensor validation in TritonBatchInvariantLogpOp should be exercised without requiring CUDA, since the RuntimeError is raised before any kernel launch. Move test_rejects_cpu_tensor out of the TestTritonValidation class that is decorated with `@requires_triton_cuda`, or gate only the Triton import in _get_op while leaving this test ungated so it runs on CPU-only CI.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py`:
- Around line 41-42: The target conversion in batch_invariant_logp is too
permissive because target_ids is cast to torch.long after reshaping, which can
silently truncate float or bool inputs. Add an explicit validation step in the
batch_invariant_logp path to require target_ids to already be an integer/long
tensor before any cast, and fail fast for non-integer inputs. Keep the reshape
logic for logits_2d and target_1d, but ensure the check happens before
target_ids is converted or used.
In `@rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py`:
- Around line 230-244: The `target_ids` validation in `batch_invariant_logp.py`
only checks value bounds, so float tensors can slip through and later be
truncated by `_BatchInvariantLogpFunction.apply`/`forward` when cast to int64.
Update the validation block before the existing range check to reject
non-integer dtypes for `target_ids` (using the `vocab_size`/`valid_targets` path
as the location), and raise a clear error for any float or otherwise non-integer
input.
In `@tests/test_batch_invariant_logp.py`:
- Around line 466-478: The registry test is inheriting shared singleton cache
state from kernel_registry, which can make get_op("batch_invariant_logp") return
a previously cached operator from another test. Update
test_registry_dispatches_correctly to isolate itself by using a fresh registry
instance or by saving and restoring/clearing the registry caches around the
assertion, so the test does not depend on prior backend selections. Use
kernel_registry and get_op as the main symbols to locate the fix.
- Around line 449-458: The CPU-vs-CUDA equality check in
test_cpu_gpu_cross_check is too strict for NativeBatchInvariantLogpOp, since
backend reduction order can differ even when both are correct. Update this test
to compare each backend’s output against the same reference behavior instead of
directly asserting out_cpu and out_cuda match at 1e-6, and keep the check
localized to the test_cpu_gpu_cross_check method.
---
Nitpick comments:
In `@docs/operators/batch-invariant-logp.md`:
- Around line 99-102: The three bullet points in the batch-invariant logp
section all start with “Triton,” which reads repetitive; reword the openings in
that block to vary the sentence structure while keeping the same meaning. Update
the bullet text near the vocab traversal, forward scan, and backward computation
statements so each begins differently, using the surrounding markdown list
content as the target.
In `@tests/test_triton_batch_invariant_logp.py`:
- Around line 328-340: The CPU-tensor validation in TritonBatchInvariantLogpOp
should be exercised without requiring CUDA, since the RuntimeError is raised
before any kernel launch. Move test_rejects_cpu_tensor out of the
TestTritonValidation class that is decorated with `@requires_triton_cuda`, or gate
only the Triton import in _get_op while leaving this test ungated so it runs on
CPU-only CI.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e8e119e7-3e1c-4443-b35a-f200f33948a7
📒 Files selected for processing (9)
.gitignoredocs/.nav.ymldocs/operators/README.mddocs/operators/batch-invariant-logp.mdrl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.pyrl_engine/kernels/ops/triton/loss/batch_invariant_logp.pyrl_engine/kernels/registry.pytests/test_batch_invariant_logp.pytests/test_triton_batch_invariant_logp.py
| logits_2d = logits.reshape(-1, vocab_size).float() | ||
| target_1d = target_ids.reshape(-1).to(logits.device, dtype=torch.long) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
printf '\n== file outline ==\n'
ast-grep outline rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py --view expanded || true
printf '\n== file contents (relevant slice) ==\n'
sed -n '1,220p' rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py
printf '\n== related tests search ==\n'
rg -n "batch_invariant_logp|NativeBatchInvariantLogpOp|target_ids|ignore_index|dtype" rl_engine -g '*test*' -g '*spec*' || true
printf '\n== pytorch casting probe ==\n'
python3 - <<'PY'
import torch
cases = [
torch.tensor([1.9, -0.2, 3.7], dtype=torch.float32),
torch.tensor([True, False, True], dtype=torch.bool),
torch.tensor([1, 2, 3], dtype=torch.int32),
]
for t in cases:
out = t.to(dtype=torch.long)
print(t.dtype, '->', out.dtype, out.tolist())
PYRepository: RL-Align/RL-Kernel
Length of output: 5090
🏁 Script executed:
#!/bin/bash
set -euo pipefail
printf '\n== call sites ==\n'
rg -n "NativeBatchInvariantLogpOp|batch_invariant_logp|selected_logprob|target_ids" rl_engine -g '!**/*.pyc' || true
printf '\n== tests in loss area ==\n'
fd -a "test*.py" rl_engine/tests rl_engine/kernels 2>/dev/null || true
printf '\n== nearby test files mentioning logprob ==\n'
rg -n "logprob|batch_invariant|target_ids|ignore_index" rl_engine/tests rl_engine/kernels/tests -g '*test*.py' || trueRepository: RL-Align/RL-Kernel
Length of output: 9236
Reject non-integer target_ids before the cast. .to(dtype=torch.long) truncates float/bool tensors, so invalid inputs can silently pick the wrong class instead of failing.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py` around lines 41 -
42, The target conversion in batch_invariant_logp is too permissive because
target_ids is cast to torch.long after reshaping, which can silently truncate
float or bool inputs. Add an explicit validation step in the
batch_invariant_logp path to require target_ids to already be an integer/long
tensor before any cast, and fail fast for non-integer inputs. Keep the reshape
logic for logits_2d and target_1d, but ensure the check happens before
target_ids is converted or used.
| vocab_size = logits.size(-1) | ||
| target_flat = target_ids.reshape(-1) | ||
| valid_targets = target_flat[target_flat != ignore_index] | ||
| if valid_targets.numel() > 0 and ( | ||
| (valid_targets < 0).any() or (valid_targets >= vocab_size).any() | ||
| ): | ||
| bad = valid_targets[ | ||
| (valid_targets < 0) | (valid_targets >= vocab_size) | ||
| ] | ||
| raise ValueError( | ||
| f"target_ids contains values outside [0, {vocab_size}): " | ||
| f"{bad.tolist()}" | ||
| ) | ||
|
|
||
| return _BatchInvariantLogpFunction.apply(logits, target_ids, ignore_index) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Validate target_ids dtype before the range check. Float targets can pass the current bounds test and then get truncated by _BatchInvariantLogpFunction.forward()’s int64 cast, which can compute the loss for the wrong class. Reject non-integer dtypes here.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py` around lines 230 -
244, The `target_ids` validation in `batch_invariant_logp.py` only checks value
bounds, so float tensors can slip through and later be truncated by
`_BatchInvariantLogpFunction.apply`/`forward` when cast to int64. Update the
validation block before the existing range check to reject non-integer dtypes
for `target_ids` (using the `vocab_size`/`valid_targets` path as the location),
and raise a clear error for any float or otherwise non-integer input.
| def test_cpu_gpu_cross_check(self): | ||
| """Same input on CPU vs CUDA should match within tolerance.""" | ||
| op = NativeBatchInvariantLogpOp() | ||
| logits_cpu = torch.randn(8, _V) | ||
| target_cpu = torch.randint(0, _V, (8,)) | ||
| out_cpu = op(logits_cpu, target_cpu) | ||
| out_cuda = op(logits_cpu.cuda(), target_cpu.cuda()) | ||
| assert torch.allclose(out_cpu, out_cuda.cpu(), atol=1e-6, rtol=1e-6), ( | ||
| "CPU vs CUDA result mismatch" | ||
| ) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Relax this CPU-vs-CUDA assertion.
NativeBatchInvariantLogpOp only promises per-row batch invariance; it does not promise CPU and CUDA reductions land on the same FP32 bits. Comparing out_cpu and out_cuda at 1e-6 can fail on a correct implementation because the two backends are free to reduce in different orders. Compare each backend against the common reference instead of asserting direct equality.
Suggested test change
def test_cpu_gpu_cross_check(self):
- """Same input on CPU vs CUDA should match within tolerance."""
+ """CPU and CUDA should each match the reference."""
op = NativeBatchInvariantLogpOp()
logits_cpu = torch.randn(8, _V)
target_cpu = torch.randint(0, _V, (8,))
out_cpu = op(logits_cpu, target_cpu)
+ ref = _reference_logp(logits_cpu, target_cpu)
out_cuda = op(logits_cpu.cuda(), target_cpu.cuda())
- assert torch.allclose(out_cpu, out_cuda.cpu(), atol=1e-6, rtol=1e-6), (
- "CPU vs CUDA result mismatch"
- )
+ assert torch.allclose(out_cpu, ref, atol=1e-6, rtol=1e-6)
+ assert torch.allclose(out_cuda.cpu(), ref, atol=1e-6, rtol=1e-6)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def test_cpu_gpu_cross_check(self): | |
| """Same input on CPU vs CUDA should match within tolerance.""" | |
| op = NativeBatchInvariantLogpOp() | |
| logits_cpu = torch.randn(8, _V) | |
| target_cpu = torch.randint(0, _V, (8,)) | |
| out_cpu = op(logits_cpu, target_cpu) | |
| out_cuda = op(logits_cpu.cuda(), target_cpu.cuda()) | |
| assert torch.allclose(out_cpu, out_cuda.cpu(), atol=1e-6, rtol=1e-6), ( | |
| "CPU vs CUDA result mismatch" | |
| ) | |
| def test_cpu_gpu_cross_check(self): | |
| """CPU and CUDA should each match the reference.""" | |
| op = NativeBatchInvariantLogpOp() | |
| logits_cpu = torch.randn(8, _V) | |
| target_cpu = torch.randint(0, _V, (8,)) | |
| out_cpu = op(logits_cpu, target_cpu) | |
| ref = _reference_logp(logits_cpu, target_cpu) | |
| out_cuda = op(logits_cpu.cuda(), target_cpu.cuda()) | |
| assert torch.allclose(out_cpu, ref, atol=1e-6, rtol=1e-6) | |
| assert torch.allclose(out_cuda.cpu(), ref, atol=1e-6, rtol=1e-6) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/test_batch_invariant_logp.py` around lines 449 - 458, The CPU-vs-CUDA
equality check in test_cpu_gpu_cross_check is too strict for
NativeBatchInvariantLogpOp, since backend reduction order can differ even when
both are correct. Update this test to compare each backend’s output against the
same reference behavior instead of directly asserting out_cpu and out_cuda match
at 1e-6, and keep the check localized to the test_cpu_gpu_cross_check method.
| def test_registry_dispatches_correctly(): | ||
| from rl_engine.kernels.registry import kernel_registry | ||
|
|
||
| op = kernel_registry.get_op("batch_invariant_logp") | ||
| assert ( | ||
| isinstance(op, NativeBatchInvariantLogpOp) | ||
| or type(op).__name__ == "TritonBatchInvariantLogpOp" | ||
| ) | ||
| logits = torch.randn(4, _V, device="cuda" if torch.cuda.is_available() else "cpu") | ||
| target = torch.randint(0, _V, (4,), device=logits.device) | ||
| out = op(logits, target) | ||
| ref = _reference_logp(logits, target) | ||
| assert torch.allclose(out, ref, atol=1e-6) |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
Isolate this registry test from global cache state.
kernel_registry.get_op() reuses cached instances keyed only by backend.name in rl_engine/kernels/registry.py:192-223. If an earlier test already cached some other PYTORCH_NATIVE/TRITON operator, this call can return the wrong instance and make this test suite-order dependent. Use a fresh registry instance here, or clear/restore the singleton caches around the assertion. The underlying registry bug is the cache key, but this test currently inherits that flakiness.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/test_batch_invariant_logp.py` around lines 466 - 478, The registry test
is inheriting shared singleton cache state from kernel_registry, which can make
get_op("batch_invariant_logp") return a previously cached operator from another
test. Update test_registry_dispatches_correctly to isolate itself by using a
fresh registry instance or by saving and restoring/clearing the registry caches
around the assertion, so the test does not depend on prior backend selections.
Use kernel_registry and get_op as the main symbols to locate the fix.
KJLdefeated
left a comment
There was a problem hiding this comment.
Overall correct and clean. The forward is one program per row, fixed _BLOCK_V=1024, no autotune, a sequential left-to-right online-LSE over vocab tiles, so the per-row reduction order is locked and independent of batch size, position, and neighbors. The backward is the FlashAttention pattern done right: save per-row lse in forward, recompute softmax = exp(logits − lse) in backward against that saved lse, grad = grad_out·(onehot − softmax), grid (rows, vocab_tiles), no atomics, no cross-row coupling, so the gradient is batch-invariant too, and consistent with the forward's exact lse.
Some notes:
This PR supports Batch-invariant logprob (Native / Triton), not yet CUDA (cc. me). Can you note that in PR's title? --> add Batch-invariant logprob (Native, Triton)
When this PR is done, I will make another PR to support CUDA version of this.
| valid_targets = target_flat[target_flat != ignore_index] | ||
| if valid_targets.numel() > 0 and ( | ||
| (valid_targets < 0).any() or (valid_targets >= vocab_size).any() | ||
| ): |
There was a problem hiding this comment.
On CUDA, .any() inside a Python if calls bool → .item() → a blocking sync — twice per forward — plus the boolean-mask index allocates a tensor each call. This op runs every training step, so every step eats two stream stalls just to validate, which undercuts the point of a fast batch-invariant kernel.
I would suggest it to be opt-in / debug only:
def apply(self, logits, target_ids, ignore_index=-100, *, validate=False):
...
if validate: # opt-in / debug only — this path syncs
valid = target_flat[target_flat != ignore_index]
if valid.numel() and ((valid < 0).any() or (valid >= vocab_size).any()):
raise ValueError(...)|
|
||
| """Tests for the Triton batch-invariant selected-logprob kernel (issue #148). | ||
|
|
||
| These tests validate that the Triton kernel produces results that: | ||
| 1. Match the PyTorch reference implementation (correctness). | ||
| 2. Are bitwise identical across different batch sizes / positions (batch-invariance). | ||
| 3. Support backward pass (gradient correctness). | ||
|
|
||
| All tests are skipped when Triton or CUDA is unavailable (e.g. on Windows or CPU-only). |
There was a problem hiding this comment.
Can you combine all different backend method of batch invariant logp (Triton / Torch) into one file (tests/test_batch_invariant_logp.py) for codebase cleanliness and reusable functions? For future CUDA batch-invariant logp, I will also put CUDA test in tests/test_batch_invariant_logp.py.
| out.sum().backward() | ||
| ``` | ||
|
|
||
| ## Tests |
There was a problem hiding this comment.
Just make a note that this kernel is currently lacking benchmarks, I will support it in another PR (cuda ver. for this).
There was a problem hiding this comment.
Addressed, thanks.
Changes:
- Renamed PR scope/title to Batch-invariant logprob (Native, Triton); CUDA + benchmarks will be handled in a follow-up PR.
- Made target range validation opt-in via validate=False by default to avoid CUDA stream sync in the Triton hot path.
- Merged Native and Triton tests into tests/test_batch_invariant_logp.py.
- Updated operator docs to mention validate=True and benchmark/CUDA follow-up.
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py`:
- Around line 27-30: The default path in BatchInvariantLogp currently leaves
validation off, which lets non-ignore negative target_ids fall through and index
from the vocab tail. Update the public forward/apply flow in
batch_invariant_logp.py so validation is enabled by default, and keep
validate=False only as an explicit opt-in fast path for trusted callers; make
sure the checks in the target handling logic around the validation branch are
still used before any indexing.
In `@tests/test_batch_invariant_logp.py`:
- Around line 758-770: The CPU-input guard in TritonBatchInvariantLogpOp should
be tested without depending on a CUDA-only environment. Move
test_rejects_cpu_tensor out from under the requires_triton_cuda gate, or apply a
Triton-only check instead, so it still exercises
TritonBatchInvariantLogpOp.apply’s early device validation on CPU-only runs.
Keep the test focused on the RuntimeError path that rejects non-GPU tensors.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 7ab61cb0-ea7e-42ef-99f1-6d8d047ff9b1
📒 Files selected for processing (4)
docs/operators/batch-invariant-logp.mdrl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.pyrl_engine/kernels/ops/triton/loss/batch_invariant_logp.pytests/test_batch_invariant_logp.py
✅ Files skipped from review due to trivial changes (1)
- docs/operators/batch-invariant-logp.md
🚧 Files skipped from review as they are similar to previous changes (1)
- rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py
Summary
Batch-invariant selected log-prob op (PyTorch native reference + Triton forward/backward).
This implements
batch_invariant_logpfor materialized logits:The main goal is not only correctness versus
log_softmax + gather, but stablerow-local behavior: the same row should produce the same selected log-prob when
evaluated alone, at a different batch position, or with different neighboring rows.
Refs: #148
lse.Current dispatch:
Implementation
PyTorch
NativeBatchInvariantLogpOp[*lead, V]logits to[N, V].[*lead]float32 output.[0, V).ignore_index; ignored rows output0.0.Triton
TritonBatchInvariantLogpOpForward uses one Triton program per token row.
Vocab is scanned with fixed
_BLOCK_V=1024; no autotune.Uses one-pass online logsumexp to keep a fixed left-to-right reduction order.
Saves per-row
lsefor backward.Backward uses one Triton program per
(row, vocab_tile).Computes:
Ignored rows receive zero gradient across the full vocab.
No atomic writes and no cross-row reductions.
Registry / docs
batch_invariant_logptoKernelRegistry.docs/operators/batch-invariant-logp.md.docs/operators/README.mdanddocs/.nav.yml.Batch-invariance contract
The operator is designed so each row depends only on:
It should not depend on:
The Triton path enforces this by:
grid=(num_tokens,)for forward;_BLOCK_V=1024;grid=(num_tokens, vocab_tiles)for backward;lsein backward;Tensor contract
Rules:
target_ids.shape == logits.shape[:-1]output shape equals
target_ids.shapeoutput dtype is
float32non-ignored target ids must be in
[0, vocab_size)target_ids[row] == ignore_indexmeans:Correctness / tests
Validated on WSL/Linux with CUDA + Triton:
The skipped cases are existing SM90/Hopper-only
linear_logptests:They are unrelated to this operator.
Test coverage
torch.log_softmax(...).gather(...).[B, T, V] -> [B, T].ignore_index;Follow-ups
This PR intentionally focuses on the PyTorch reference and Triton backend.
Planned follow-ups:
Summary by CodeRabbit
batch_invariant_logpoperator with native and accelerated (Triton) support.validateoption to control target-range checking (with stricter default behavior on the native path).