
[WS1][kernels] Batch-invariant logprob (Native, Triton) #199

Open
hihaluemen wants to merge 3 commits into RL-Align:main from hihaluemen:feat/batch-invariant-logp


@hihaluemen hihaluemen commented Jun 28, 2026


Summary

Batch-invariant selected log-prob op (PyTorch native reference + Triton forward/backward).

This implements batch_invariant_logp for materialized logits:

selected_logp[row] = logits[row, target_ids[row]] - logsumexp(logits[row, :])

The main goal is not only correctness versus log_softmax + gather, but stable
row-local behavior: the same row should produce the same selected log-prob when
evaluated alone, at a different batch position, or with different neighboring rows.
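
As a scalar illustration of the formula above, here is a minimal pure-Python sketch. It is not the PR's PyTorch or Triton code: the helper names `logsumexp` and `selected_logp` are hypothetical, and a row is a plain Python list rather than a tensor.

```python
import math


def logsumexp(row):
    """Numerically stable log(sum(exp(x))) over one row of logits."""
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))


def selected_logp(row, target):
    """selected_logp = logits[target] - logsumexp(logits), for a single row."""
    return row[target] - logsumexp(row)


row = [2.0, 1.0, 0.1]
logp = [selected_logp(row, t) for t in range(len(row))]

# Exponentiating every per-class log-prob should recover a distribution.
total = sum(math.exp(v) for v in logp)
```

Because the function only ever reads `row` and `target`, nothing about other rows can influence the result, which is the row-local property the operator aims to preserve on the GPU.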

Refs: #148

Backend | Status
CUDA / ROCm (Triton) | Online-softmax forward with fixed vocab tiling and locked row-local reduction order; tile-wise Triton backward using saved per-row lse.
PyTorch native | FP32 row-wise reference implementation; CPU fallback and Triton-less fallback.

Current dispatch:

CUDA / ROCm: Triton -> PyTorch
CPU:         PyTorch

Implementation

  • PyTorch NativeBatchInvariantLogpOp

    • Reshapes [*lead, V] logits to [N, V].
    • Computes selected log-probs row-wise in FP32.
    • Returns [*lead] float32 output.
    • Supports CPU and CUDA tensors.
    • Validates non-ignored target ids are in [0, V).
    • Supports ignore_index; ignored rows output 0.0.
  • Triton TritonBatchInvariantLogpOp

    • Forward uses one Triton program per token row.

    • Vocab is scanned with fixed _BLOCK_V=1024; no autotune.

    • Uses one-pass online logsumexp to keep a fixed left-to-right reduction order.

    • Saves per-row lse for backward.

    • Backward uses one Triton program per (row, vocab_tile).

    • Computes:

      grad_logits[row, j] =
          grad_out[row] * (1[j == target] - exp(logits[row, j] - lse[row]))
      
    • Ignored rows receive zero gradient across the full vocab.

    • No atomic writes and no cross-row reductions.

  • Registry / docs

    • Adds batch_invariant_logp to KernelRegistry.
    • Adds operator documentation under docs/operators/batch-invariant-logp.md.
    • Adds the operator to docs/operators/README.md and docs/.nav.yml.
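
The one-pass online logsumexp mentioned in the Triton bullets can be sketched in pure Python as follows. This is an assumption-laden illustration, not the kernel itself: `online_logsumexp` and `BLOCK_V` are hypothetical names, the tile width is shrunk to 4 for readability (the PR fixes `_BLOCK_V=1024`), and the row is assumed non-empty with finite logits.

```python
import math

BLOCK_V = 4  # illustrative tile width only; the PR uses a fixed _BLOCK_V=1024


def online_logsumexp(row, block_v=BLOCK_V):
    """One-pass logsumexp that scans fixed-size vocab tiles left to right."""
    running_max = -math.inf
    running_sum = 0.0  # sum of exp(x - running_max) over the tiles seen so far
    for start in range(0, len(row), block_v):
        tile = row[start:start + block_v]
        new_max = max(running_max, max(tile))
        # Rescale the previous partial sum to the new max, then add this tile.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in tile
        )
        running_max = new_max
    return running_max + math.log(running_sum)


row = [0.5 * i - 3.0 for i in range(10)]
two_pass = max(row) + math.log(sum(math.exp(x - max(row)) for x in row))
one_pass = online_logsumexp(row)
```

The tile boundaries depend only on the vocab length and the fixed tile width, so the sequence of floating-point operations for a given row is the same no matter how many other rows are in the batch.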

Batch-invariance contract

The operator is designed so each row depends only on:

logits[row, :]
target_ids[row]
ignore_index
grad_out[row]    # backward only

It should not depend on:

batch size
row position
neighboring rows
padding layout
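
One reason these factors can leak into results at all is that floating-point addition is not associative, so a reduction whose order changes with batch size or launch configuration can change the low bits of a row's result. A minimal pure-Python illustration, using float64 scalars rather than the FP32 tensors the operator works on:

```python
# The same three numbers, summed in two different orders.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

# The two orders disagree in the last bit.
orders_agree = left_to_right == right_to_left
```

Both sums are correct to within rounding error, yet they are not bitwise identical, which is why a batch-invariant kernel has to pin the reduction order rather than rely on tolerance-based agreement.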

The Triton path enforces this by:

  • using grid=(num_tokens,) for forward;
  • using fixed _BLOCK_V=1024;
  • scanning vocab tiles in a deterministic left-to-right order;
  • using grid=(num_tokens, vocab_tiles) for backward;
  • reusing saved per-row lse in backward;
  • avoiding atomics and shared cross-row writes.
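
The backward side of this list (per-tile work, saved per-row lse, no atomics) can be sketched for a single row in pure Python. This is a hedged illustration rather than the Triton kernel: `grad_logits_row` is a hypothetical name, the tile width of 4 is illustrative, and the outer loop stands in for the separate (row, vocab_tile) programs.

```python
import math


def grad_logits_row(row, target, lse, grad_out, block_v=4):
    """Tile-wise gradient of the selected log-prob for one row, given a saved lse."""
    grad = [0.0] * len(row)
    for start in range(0, len(row), block_v):  # one (row, vocab_tile) unit of work
        for j in range(start, min(start + block_v, len(row))):
            softmax_j = math.exp(row[j] - lse)  # recomputed from the saved lse
            onehot_j = 1.0 if j == target else 0.0
            grad[j] = grad_out * (onehot_j - softmax_j)  # each element written once
    return grad


row = [2.0, 1.0, 0.1, -0.5, 0.3]
lse = max(row) + math.log(sum(math.exp(x - max(row)) for x in row))
grad = grad_logits_row(row, target=1, lse=lse, grad_out=1.0)
```

Every output element is written by exactly one tile and depends only on that row's logits, target, lse, and grad_out, so no accumulation across tiles or rows is needed.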

Tensor contract

from rl_engine.kernels.registry import kernel_registry

batch_invariant_logp = kernel_registry.get_op("batch_invariant_logp")

out = batch_invariant_logp(
    logits,       # [B, T, V] or [N, V], differentiable
    target_ids,   # [B, T] or [N], int
    ignore_index=-100,
)                # -> [B, T] or [N], float32

Rules:

  • target_ids.shape == logits.shape[:-1]

  • output shape equals target_ids.shape

  • output dtype is float32

  • non-ignored target ids must be in [0, vocab_size)

  • target_ids[row] == ignore_index means:

    out[row] = 0.0
    grad_logits[row, :] = 0.0
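
A small pure-Python sketch of the ignore_index rule above, for one row at a time. The function name `forward_backward_row` is hypothetical and the code only mirrors the stated contract; it is not taken from the PR.

```python
import math

IGNORE_INDEX = -100


def forward_backward_row(row, target, grad_out=1.0, ignore_index=IGNORE_INDEX):
    """Return (out, grad_logits) for one row, honoring ignore_index."""
    if target == ignore_index:
        # Ignored rows contribute a 0.0 output and a zero gradient over the vocab.
        return 0.0, [0.0] * len(row)
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    out = row[target] - lse
    grad = [
        grad_out * ((1.0 if j == target else 0.0) - math.exp(x - lse))
        for j, x in enumerate(row)
    ]
    return out, grad


ignored_out, ignored_grad = forward_backward_row([1.0, 2.0, 3.0], IGNORE_INDEX)
valid_out, valid_grad = forward_backward_row([1.0, 2.0, 3.0], 2)
```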
    

Correctness / tests

Validated on WSL/Linux with CUDA + Triton:

python -m pytest ./tests/test_triton_batch_invariant_logp.py -q -rs
# 21 passed

python -m pytest ./tests/test_batch_invariant_logp.py ./tests/test_triton_batch_invariant_logp.py -q -rs
# 53 passed

python -m pytest ./tests/test_linear_logp.py ./tests/test_op_accuracy.py ./tests/test_batch_invariant_logp.py ./tests/test_triton_batch_invariant_logp.py -q -rs
# 79 passed, 9 skipped

The skipped cases are existing SM90/Hopper-only linear_logp tests:

Fused linear log-prob SM90 kernel requires a Hopper (sm_90) GPU with the extension built KERNEL_ALIGN_FORCE_SM90=1.

They are unrelated to this operator.

Test coverage

  • Correctness vs torch.log_softmax(...).gather(...).
  • PyTorch native vs existing native logp reference.
  • Triton vs PyTorch native.
  • fp32 / fp16 / bf16 inputs.
  • float32 output dtype.
  • leading-shape preservation, including [B, T, V] -> [B, T].
  • large vocab smoke tests.
  • batch size 1 vs batch N invariance.
  • same row at different batch positions.
  • mixed neighboring batch contents.
  • repeated-run determinism.
  • padding / ignore-index layout invariance.
  • backward correctness.
  • gradient batch-invariance.
  • ignored-row zero gradient.
  • invalid target validation:
    • negative target not equal to ignore_index;
    • target id equal to or greater than vocab size.
  • registry dispatch.
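
The shape of the batch-invariance checks in this list can be sketched as below. It is a toy harness under stated assumptions: `selected_logp_batch` is a hypothetical pure-Python stand-in for the operator, so the equality holds trivially here, whereas the real tests exercise the PyTorch and Triton backends. The point is the comparison pattern: exact equality for the same row evaluated alone, at a different position, and with different neighbors.

```python
import math


def selected_logp_batch(batch, targets):
    """Row-wise selected log-prob over a list of rows (stand-in for the op)."""
    out = []
    for row, t in zip(batch, targets):
        m = max(row)
        out.append(row[t] - (m + math.log(sum(math.exp(x - m) for x in row))))
    return out


probe_row, probe_target = [0.3, -1.2, 2.5, 0.0], 2

# Batch size 1.
alone = selected_logp_batch([probe_row], [probe_target])[0]

# Same row at position 1 of a batch of 3, then at position 3 of a batch of 4,
# each time surrounded by different neighbors.
batch_a = [[9.0, 9.0, 9.0, 9.0], probe_row, [-5.0, 0.0, 5.0, 1.0]]
batch_b = [[0.0, 0.1, 0.2, 0.3], [7.0, -7.0, 7.0, -7.0], [1.0, 1.0, 1.0, 1.0], probe_row]
in_batch_a = selected_logp_batch(batch_a, [0, probe_target, 1])[1]
in_batch_b = selected_logp_batch(batch_b, [3, 0, 2, probe_target])[3]
```

Note that the comparison is `==` rather than a tolerance check; a tolerance would hide exactly the low-bit drift that batch invariance is meant to rule out.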

Follow-ups

This PR intentionally focuses on the PyTorch reference and Triton backend.

Planned follow-ups:

  • compiled CUDA backend;
  • CUDA-specific tests;
  • benchmark script and performance table;
  • docs update with benchmark results.

Summary by CodeRabbit

  • New Features
    • Added a new batch_invariant_logp operator with native and accelerated (Triton) support.
    • Added a validate option to control target-range checking (with stricter default behavior on the native path).
  • Documentation
    • Added a new operator guide and linked it from the operator docs index and site navigation.
  • Tests
    • Added extensive correctness, batch-invariance, ignore-index, validation, backward-gradient, and registry-dispatch coverage for supported backends.
  • Chores
    • Updated version control ignore rules to exclude local development notes.

Implements batch_invariant_logp for selected-token log probabilities from materialized logits with row-local, batch-invariant semantics.

- PyTorch NativeBatchInvariantLogpOp: FP32 row-wise reference with ignore_index handling and target validation.

- Triton TritonBatchInvariantLogpOp: online-softmax forward with fixed vocab tiling and tile-wise backward using saved per-row lse.

- Registry dispatch, PyTorch/Triton tests, and operator docs.

coderabbitai Bot commented Jun 28, 2026


📝 Walkthrough


Adds a new batch_invariant_logp operator with PyTorch and Triton backends, registry dispatch, tests for correctness and invariance, and documentation updates. Also ignores a local _dev_notes/ directory in .gitignore.

Changes

Batch-Invariant LogP Operator

Layer | File(s) | Summary
PyTorch backend implementation | rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py | NativeBatchInvariantLogpOp adds kw-only validate, flattens inputs for row-wise computation, performs conditional range checks, and preserves ignore-index masking with FP32 log-sum-exp and gather semantics.
Triton backend kernels and autograd wrapper | rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py | Forward and backward Triton kernels compute selected log-probabilities and logits gradients, and the public API adds device validation and conditional target-range checks.
Registry wiring for batch_invariant_logp | rl_engine/kernels/registry.py | OpBackend gains Triton and PyTorch entries for batch_invariant_logp, and platform priority lists are extended for cuda, rocm, and cpu dispatch.
PyTorch and Triton test suites | tests/test_batch_invariant_logp.py | Tests cover native and Triton correctness, batch invariance, validation, backward gradients, ignore_index behavior, CUDA agreement, and registry dispatch.
Operator documentation and nav updates | docs/operators/batch-invariant-logp.md, docs/.nav.yml, docs/operators/README.md | New operator documentation describes the tensor contract, reference semantics, batch-invariance constraints, accuracy expectations, usage example, test coverage, and file references; navigation and README entries add the page.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

needs-gpu-ci, component: kernels, platform: rocm

Suggested reviewers

  • EthanZero2Hero
  • inaniloquentee
  • Flink-ddd

🐇 I hopped through logits bright,
Batch-invariant, row by row tonight.
Triton sings and PyTorch chimes,
Same log-prob through all the times.
Zeroed ignores, gradients true—
A tidy kernel, crisp and new!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 22.33%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

Check name | Status | Explanation
Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled.
Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request.
Title check | ✅ Passed | The title clearly reflects the new batch-invariant logprob operator and mentions both native and Triton backends.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (2)
docs/operators/batch-invariant-logp.md (1)

99-102: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Reword successive "Triton" sentence openings.

Three successive bullet points begin with "Triton" (lines 99–102 in the rendered document), which reads repetitively. Vary the sentence structure for better flow, e.g., "The vocab traversal...", "Forward scan...", "Backward computation...".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/operators/batch-invariant-logp.md` around lines 99 - 102, The three
bullet points in the batch-invariant logp section all start with “Triton,” which
reads repetitive; reword the openings in that block to vary the sentence
structure while keeping the same meaning. Update the bullet text near the vocab
traversal, forward scan, and backward computation statements so each begins
differently, using the surrounding markdown list content as the target.
tests/test_triton_batch_invariant_logp.py (1)

328-340: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Let the CPU-tensor validation run without CUDA.

TritonBatchInvariantLogpOp.apply raises on CPU tensors before any kernel launch, so Lines 337-340 do not need a CUDA device. Keeping this test under the class-level @requires_triton_cuda skips that guard on CPU-only CI and leaves the validation path unexercised. Move this case out of the CUDA-gated class, or gate it only on Triton import.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_triton_batch_invariant_logp.py` around lines 328 - 340, The
CPU-tensor validation in TritonBatchInvariantLogpOp should be exercised without
requiring CUDA, since the RuntimeError is raised before any kernel launch. Move
test_rejects_cpu_tensor out of the TestTritonValidation class that is decorated
with `@requires_triton_cuda`, or gate only the Triton import in _get_op while
leaving this test ungated so it runs on CPU-only CI.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py`:
- Around line 41-42: The target conversion in batch_invariant_logp is too
permissive because target_ids is cast to torch.long after reshaping, which can
silently truncate float or bool inputs. Add an explicit validation step in the
batch_invariant_logp path to require target_ids to already be an integer/long
tensor before any cast, and fail fast for non-integer inputs. Keep the reshape
logic for logits_2d and target_1d, but ensure the check happens before
target_ids is converted or used.

In `@rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py`:
- Around line 230-244: The `target_ids` validation in `batch_invariant_logp.py`
only checks value bounds, so float tensors can slip through and later be
truncated by `_BatchInvariantLogpFunction.apply`/`forward` when cast to int64.
Update the validation block before the existing range check to reject
non-integer dtypes for `target_ids` (using the `vocab_size`/`valid_targets` path
as the location), and raise a clear error for any float or otherwise non-integer
input.

In `@tests/test_batch_invariant_logp.py`:
- Around line 466-478: The registry test is inheriting shared singleton cache
state from kernel_registry, which can make get_op("batch_invariant_logp") return
a previously cached operator from another test. Update
test_registry_dispatches_correctly to isolate itself by using a fresh registry
instance or by saving and restoring/clearing the registry caches around the
assertion, so the test does not depend on prior backend selections. Use
kernel_registry and get_op as the main symbols to locate the fix.
- Around line 449-458: The CPU-vs-CUDA equality check in
test_cpu_gpu_cross_check is too strict for NativeBatchInvariantLogpOp, since
backend reduction order can differ even when both are correct. Update this test
to compare each backend’s output against the same reference behavior instead of
directly asserting out_cpu and out_cuda match at 1e-6, and keep the check
localized to the test_cpu_gpu_cross_check method.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e8e119e7-3e1c-4443-b35a-f200f33948a7

📥 Commits

Reviewing files that changed from the base of the PR and between 9480500 and d829be2.

📒 Files selected for processing (9)
  • .gitignore
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/batch-invariant-logp.md
  • rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py
  • rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py
  • rl_engine/kernels/registry.py
  • tests/test_batch_invariant_logp.py
  • tests/test_triton_batch_invariant_logp.py

Comment on lines +41 to +42
logits_2d = logits.reshape(-1, vocab_size).float()
target_1d = target_ids.reshape(-1).to(logits.device, dtype=torch.long)


🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '\n== file outline ==\n'
ast-grep outline rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py --view expanded || true

printf '\n== file contents (relevant slice) ==\n'
sed -n '1,220p' rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py

printf '\n== related tests search ==\n'
rg -n "batch_invariant_logp|NativeBatchInvariantLogpOp|target_ids|ignore_index|dtype" rl_engine -g '*test*' -g '*spec*' || true

printf '\n== pytorch casting probe ==\n'
python3 - <<'PY'
import torch

cases = [
    torch.tensor([1.9, -0.2, 3.7], dtype=torch.float32),
    torch.tensor([True, False, True], dtype=torch.bool),
    torch.tensor([1, 2, 3], dtype=torch.int32),
]
for t in cases:
    out = t.to(dtype=torch.long)
    print(t.dtype, '->', out.dtype, out.tolist())
PY

Repository: RL-Align/RL-Kernel

Length of output: 5090


🏁 Script executed:

#!/bin/bash
set -euo pipefail

printf '\n== call sites ==\n'
rg -n "NativeBatchInvariantLogpOp|batch_invariant_logp|selected_logprob|target_ids" rl_engine -g '!**/*.pyc' || true

printf '\n== tests in loss area ==\n'
fd -a "test*.py" rl_engine/tests rl_engine/kernels 2>/dev/null || true

printf '\n== nearby test files mentioning logprob ==\n'
rg -n "logprob|batch_invariant|target_ids|ignore_index" rl_engine/tests rl_engine/kernels/tests -g '*test*.py' || true

Repository: RL-Align/RL-Kernel

Length of output: 9236


Reject non-integer target_ids before the cast. .to(dtype=torch.long) truncates float/bool tensors, so invalid inputs can silently pick the wrong class instead of failing.
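
The truncation hazard and a fail-fast check can be sketched in pure Python. This is an analogue under stated assumptions, not the PR's fix: `require_integer_targets` is a hypothetical helper that works on Python scalars, whereas the real check would inspect the tensor dtype before the cast to torch.long.

```python
def require_integer_targets(target_ids):
    """Reject float and bool targets instead of silently casting them."""
    for t in target_ids:
        # bool is a subclass of int in Python, so it needs an explicit check.
        if isinstance(t, bool) or not isinstance(t, int):
            raise TypeError(
                f"target_ids must be integers, got {type(t).__name__}: {t!r}"
            )
    return list(target_ids)


# What a blind integer cast does to float targets: truncation toward zero,
# so 1.9 selects class 1 and -0.2 selects class 0 without any error.
silently_truncated = [int(t) for t in [1.9, -0.2, 3.7]]
```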

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py` around lines 41 -
42, The target conversion in batch_invariant_logp is too permissive because
target_ids is cast to torch.long after reshaping, which can silently truncate
float or bool inputs. Add an explicit validation step in the
batch_invariant_logp path to require target_ids to already be an integer/long
tensor before any cast, and fail fast for non-integer inputs. Keep the reshape
logic for logits_2d and target_1d, but ensure the check happens before
target_ids is converted or used.

Comment on lines +230 to +244
vocab_size = logits.size(-1)
target_flat = target_ids.reshape(-1)
valid_targets = target_flat[target_flat != ignore_index]
if valid_targets.numel() > 0 and (
    (valid_targets < 0).any() or (valid_targets >= vocab_size).any()
):
    bad = valid_targets[
        (valid_targets < 0) | (valid_targets >= vocab_size)
    ]
    raise ValueError(
        f"target_ids contains values outside [0, {vocab_size}): "
        f"{bad.tolist()}"
    )

return _BatchInvariantLogpFunction.apply(logits, target_ids, ignore_index)


🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Validate target_ids dtype before the range check. Float targets can pass the current bounds test and then get truncated by _BatchInvariantLogpFunction.forward()’s int64 cast, which can compute the loss for the wrong class. Reject non-integer dtypes here.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py` around lines 230 -
244, The `target_ids` validation in `batch_invariant_logp.py` only checks value
bounds, so float tensors can slip through and later be truncated by
`_BatchInvariantLogpFunction.apply`/`forward` when cast to int64. Update the
validation block before the existing range check to reject non-integer dtypes
for `target_ids` (using the `vocab_size`/`valid_targets` path as the location),
and raise a clear error for any float or otherwise non-integer input.

Comment on lines +449 to +458
def test_cpu_gpu_cross_check(self):
    """Same input on CPU vs CUDA should match within tolerance."""
    op = NativeBatchInvariantLogpOp()
    logits_cpu = torch.randn(8, _V)
    target_cpu = torch.randint(0, _V, (8,))
    out_cpu = op(logits_cpu, target_cpu)
    out_cuda = op(logits_cpu.cuda(), target_cpu.cuda())
    assert torch.allclose(out_cpu, out_cuda.cpu(), atol=1e-6, rtol=1e-6), (
        "CPU vs CUDA result mismatch"
    )


🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Relax this CPU-vs-CUDA assertion.

NativeBatchInvariantLogpOp only promises per-row batch invariance; it does not promise CPU and CUDA reductions land on the same FP32 bits. Comparing out_cpu and out_cuda at 1e-6 can fail on a correct implementation because the two backends are free to reduce in different orders. Compare each backend against the common reference instead of asserting direct equality.

Suggested test change
     def test_cpu_gpu_cross_check(self):
-        """Same input on CPU vs CUDA should match within tolerance."""
+        """CPU and CUDA should each match the reference."""
         op = NativeBatchInvariantLogpOp()
         logits_cpu = torch.randn(8, _V)
         target_cpu = torch.randint(0, _V, (8,))
         out_cpu = op(logits_cpu, target_cpu)
+        ref = _reference_logp(logits_cpu, target_cpu)
         out_cuda = op(logits_cpu.cuda(), target_cpu.cuda())
-        assert torch.allclose(out_cpu, out_cuda.cpu(), atol=1e-6, rtol=1e-6), (
-            "CPU vs CUDA result mismatch"
-        )
+        assert torch.allclose(out_cpu, ref, atol=1e-6, rtol=1e-6)
+        assert torch.allclose(out_cuda.cpu(), ref, atol=1e-6, rtol=1e-6)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    def test_cpu_gpu_cross_check(self):
-        """Same input on CPU vs CUDA should match within tolerance."""
-        op = NativeBatchInvariantLogpOp()
-        logits_cpu = torch.randn(8, _V)
-        target_cpu = torch.randint(0, _V, (8,))
-        out_cpu = op(logits_cpu, target_cpu)
-        out_cuda = op(logits_cpu.cuda(), target_cpu.cuda())
-        assert torch.allclose(out_cpu, out_cuda.cpu(), atol=1e-6, rtol=1e-6), (
-            "CPU vs CUDA result mismatch"
-        )
+    def test_cpu_gpu_cross_check(self):
+        """CPU and CUDA should each match the reference."""
+        op = NativeBatchInvariantLogpOp()
+        logits_cpu = torch.randn(8, _V)
+        target_cpu = torch.randint(0, _V, (8,))
+        out_cpu = op(logits_cpu, target_cpu)
+        ref = _reference_logp(logits_cpu, target_cpu)
+        out_cuda = op(logits_cpu.cuda(), target_cpu.cuda())
+        assert torch.allclose(out_cpu, ref, atol=1e-6, rtol=1e-6)
+        assert torch.allclose(out_cuda.cpu(), ref, atol=1e-6, rtol=1e-6)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_batch_invariant_logp.py` around lines 449 - 458, The CPU-vs-CUDA
equality check in test_cpu_gpu_cross_check is too strict for
NativeBatchInvariantLogpOp, since backend reduction order can differ even when
both are correct. Update this test to compare each backend’s output against the
same reference behavior instead of directly asserting out_cpu and out_cuda match
at 1e-6, and keep the check localized to the test_cpu_gpu_cross_check method.

Comment on lines +466 to +478
def test_registry_dispatches_correctly():
    from rl_engine.kernels.registry import kernel_registry

    op = kernel_registry.get_op("batch_invariant_logp")
    assert (
        isinstance(op, NativeBatchInvariantLogpOp)
        or type(op).__name__ == "TritonBatchInvariantLogpOp"
    )
    logits = torch.randn(4, _V, device="cuda" if torch.cuda.is_available() else "cpu")
    target = torch.randint(0, _V, (4,), device=logits.device)
    out = op(logits, target)
    ref = _reference_logp(logits, target)
    assert torch.allclose(out, ref, atol=1e-6)


🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Isolate this registry test from global cache state.

kernel_registry.get_op() reuses cached instances keyed only by backend.name in rl_engine/kernels/registry.py:192-223. If an earlier test already cached some other PYTORCH_NATIVE/TRITON operator, this call can return the wrong instance and make this test suite-order dependent. Use a fresh registry instance here, or clear/restore the singleton caches around the assertion. The underlying registry bug is the cache key, but this test currently inherits that flakiness.
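
The failure mode described here can be reproduced with a toy model. Everything below is hypothetical: the thread does not show the real registry code, so `BuggyRegistry`, `FixedRegistry`, the placeholder op classes, and the two-argument `get_op` signature are assumptions made only to illustrate a cache keyed by backend name versus one keyed by (op name, backend name).

```python
class BuggyRegistry:
    """Toy registry that caches instances by backend name only."""

    def __init__(self, factories):
        self._factories = factories  # {(op_name, backend_name): constructor}
        self._cache = {}

    def get_op(self, op_name, backend_name):
        if backend_name not in self._cache:
            self._cache[backend_name] = self._factories[(op_name, backend_name)]()
        return self._cache[backend_name]


class FixedRegistry(BuggyRegistry):
    """Same toy registry, but the cache key includes the op name."""

    def get_op(self, op_name, backend_name):
        key = (op_name, backend_name)
        if key not in self._cache:
            self._cache[key] = self._factories[key]()
        return self._cache[key]


class LinearLogp:  # placeholder op for illustration
    pass


class BatchInvariantLogp:  # placeholder op for illustration
    pass


factories = {
    ("linear_logp", "PYTORCH_NATIVE"): LinearLogp,
    ("batch_invariant_logp", "PYTORCH_NATIVE"): BatchInvariantLogp,
}

# An earlier "test" warms the cache with a different op on the same backend.
buggy = BuggyRegistry(factories)
buggy.get_op("linear_logp", "PYTORCH_NATIVE")
wrong = buggy.get_op("batch_invariant_logp", "PYTORCH_NATIVE")

fixed = FixedRegistry(factories)
fixed.get_op("linear_logp", "PYTORCH_NATIVE")
right = fixed.get_op("batch_invariant_logp", "PYTORCH_NATIVE")
```

In the buggy variant the second lookup returns whichever op was cached first for that backend, which is the suite-order dependence the comment warns about.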

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_batch_invariant_logp.py` around lines 466 - 478, The registry test
is inheriting shared singleton cache state from kernel_registry, which can make
get_op("batch_invariant_logp") return a previously cached operator from another
test. Update test_registry_dispatches_correctly to isolate itself by using a
fresh registry instance or by saving and restoring/clearing the registry caches
around the assertion, so the test does not depend on prior backend selections.
Use kernel_registry and get_op as the main symbols to locate the fix.

@KJLdefeated KJLdefeated left a comment

Collaborator


Overall correct and clean. The forward is one program per row, fixed _BLOCK_V=1024, no autotune, a sequential left-to-right online-LSE over vocab tiles, so the per-row reduction order is locked and independent of batch size, position, and neighbors. The backward is the FlashAttention pattern done right: save per-row lse in forward, recompute softmax = exp(logits − lse) in backward against that saved lse, grad = grad_out·(onehot − softmax), grid (rows, vocab_tiles), no atomics, no cross-row coupling, so the gradient is batch-invariant too, and consistent with the forward's exact lse.

Some notes:
This PR supports batch-invariant logprob for the Native and Triton backends, but not yet CUDA (cc. me). Can you note that in the PR's title, e.g. by adding "Batch-invariant logprob (Native, Triton)"?
When this PR is done, I will open another PR to add the CUDA version.

Comment on lines +232 to +235
valid_targets = target_flat[target_flat != ignore_index]
if valid_targets.numel() > 0 and (
    (valid_targets < 0).any() or (valid_targets >= vocab_size).any()
):

Collaborator


On CUDA, .any() inside a Python if calls bool → .item() → a blocking sync, twice per forward, and the boolean-mask index allocates a tensor on each call. This op runs every training step, so every step eats two stream stalls just to validate, which undercuts the point of a fast batch-invariant kernel.
I would suggest making it opt-in / debug only:

def apply(self, logits, target_ids, ignore_index=-100, *, validate=False):
    ...
    if validate:  # opt-in / debug only — this path syncs
        valid = target_flat[target_flat != ignore_index]
        if valid.numel() and ((valid < 0).any() or (valid >= vocab_size).any()):
            raise ValueError(...)

Comment on lines +3 to +11

"""Tests for the Triton batch-invariant selected-logprob kernel (issue #148).

These tests validate that the Triton kernel produces results that:
1. Match the PyTorch reference implementation (correctness).
2. Are bitwise identical across different batch sizes / positions (batch-invariance).
3. Support backward pass (gradient correctness).

All tests are skipped when Triton or CUDA is unavailable (e.g. on Windows or CPU-only).

Collaborator


Can you combine the tests for all batch-invariant logp backends (Triton / Torch) into one file (tests/test_batch_invariant_logp.py), for codebase cleanliness and reusable helper functions? For the future CUDA batch-invariant logp, I will also put the CUDA tests in tests/test_batch_invariant_logp.py.

out.sum().backward()
```

## Tests

Collaborator


Just add a note that this kernel currently lacks benchmarks; I will add them in another PR (the CUDA version of this).

Author


Addressed, thanks.
Changes:

  • Renamed PR scope/title to Batch-invariant logprob (Native, Triton); CUDA + benchmarks will be handled in a follow-up PR.
  • Made target range validation opt-in via validate=False by default to avoid CUDA stream sync in the Triton hot path.
  • Merged Native and Triton tests into tests/test_batch_invariant_logp.py.
  • Updated operator docs to mention validate=True and benchmark/CUDA follow-up.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py`:
- Around line 27-30: The default path in BatchInvariantLogp currently leaves
validation off, which lets non-ignore negative target_ids fall through and index
from the vocab tail. Update the public forward/apply flow in
batch_invariant_logp.py so validation is enabled by default, and keep
validate=False only as an explicit opt-in fast path for trusted callers; make
sure the checks in the target handling logic around the validation branch are
still used before any indexing.

In `@tests/test_batch_invariant_logp.py`:
- Around line 758-770: The CPU-input guard in TritonBatchInvariantLogpOp should
be tested without depending on a CUDA-only environment. Move
test_rejects_cpu_tensor out from under the requires_triton_cuda gate, or apply a
Triton-only check instead, so it still exercises
TritonBatchInvariantLogpOp.apply’s early device validation on CPU-only runs.
Keep the test focused on the RuntimeError path that rejects non-GPU tensors.
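
The first inline comment above concerns negative, non-ignored target ids reaching an indexing step unchecked. A pure-Python analogue of that hazard and of a range check is sketched below. It rests on assumptions: Python list indexing wraps negative indices, while whether the PR's native path wraps or raises depends on the indexing call it uses, which the thread does not show; `check_target` is a hypothetical helper.

```python
vocab_row = [0.1, 0.2, 0.3, 0.4]
bad_target = -1  # negative, but not equal to ignore_index (-100)

# Python-style negative indexing silently reads from the tail of the row.
silently_read = vocab_row[bad_target]


def check_target(target, vocab_size, ignore_index=-100):
    """Raise unless target is ignore_index or lies in [0, vocab_size)."""
    if target != ignore_index and not (0 <= target < vocab_size):
        raise ValueError(f"target id {target} outside [0, {vocab_size})")
    return target
```

With the check in front of the lookup, `check_target(-1, 4)` raises instead of returning the last class, while `check_target(-100, 4)` passes through so the ignore_index path still works.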

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7ab61cb0-ea7e-42ef-99f1-6d8d047ff9b1

📥 Commits

Reviewing files that changed from the base of the PR and between d829be2 and d748a8f.

📒 Files selected for processing (4)
  • docs/operators/batch-invariant-logp.md
  • rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py
  • rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py
  • tests/test_batch_invariant_logp.py
✅ Files skipped from review due to trivial changes (1)
  • docs/operators/batch-invariant-logp.md
🚧 Files skipped from review as they are similar to previous changes (1)
  • rl_engine/kernels/ops/triton/loss/batch_invariant_logp.py

Comment thread rl_engine/kernels/ops/pytorch/loss/batch_invariant_logp.py
Comment thread tests/test_batch_invariant_logp.py Outdated
@hihaluemen hihaluemen changed the title from "[WS1][kernels] Batch-invariant logprob (selected, locked reduction)" to "[WS1][kernels] Batch-invariant logprob (Native, Triton)" on Jun 29, 2026
