Skip to content

feat(ws1): Add PyTorch RoPE reference operator#167

Open
a-kaa wants to merge 3 commits into
RL-Align:mainfrom
a-kaa:dev-kernel
Open

feat(ws1): Add PyTorch RoPE reference operator#167
a-kaa wants to merge 3 commits into
RL-Align:mainfrom
a-kaa:dev-kernel

Conversation

@a-kaa

@a-kaa a-kaa commented Jun 21, 2026

Copy link
Copy Markdown
Collaborator

Summary

Implements NativeRoPEOp, the fp32 ground-truth RoPE reference operator for ISSUE #108.

  • Shapes: Qwen3-8B defaults — theta=1e6 (not the 1e4 default), head_dim=128, full-dimension rotation (half=64).
  • Dtype contract: forward() follows input dtype; forward_fp32() is the fp32 gold standard with internal fp32 accumulation. Both implemented (fp16 and fp32 paths tested).
  • Purity: no in-place ops, no random state, no global state. cos/sin are computed internally in fp32 from positions/theta — no external cache accepted or returned.

Changes

  • rl_engine/kernels/ops/pytorch/pos/rope.pyNativeRoPEOp (op_class = "elementwise")
  • docs/operators/rope.md — operator reference doc (shapes, convention, tolerance contract)
  • tests/ops/test_rope.py — full test suite, all green locally:
    • Axis A (bitwise batch invariance): torch.equal between batch=1 slice and batch=N slice, including padded batches and per-row 2D positions ([B, S])
    • Axis B (accuracy): forward vs forward_fp32 under tolerance contract (fp32: 1e-5, bf16: 2e-2, fp16: 1e-3)
    • Bitwise match against the frozen ISSUE_108_OPS_DEV.md §5 reference formula
    • Shape/dtype correctness, position broadcasting equivalence ([S] vs [B,S]), pure-function checks
    • Qwen3-8B specific shapes (Q heads=32, KV heads=8, head_dim=128)
image

Summary by CodeRabbit

  • New Features
    • Added a RoPE (Rotary Position Embeddings) operator with a PyTorch native backend, available for CPU, CUDA, and ROCm.
  • Documentation
    • Updated the operator index and added detailed RoPE operator documentation, including usage and reference behavior.
  • Tests
    • Added a comprehensive RoPE test suite validating correctness against an independent fp32 reference, plus output shape/dtype and batch/position semantics.

@coderabbitai

coderabbitai Bot commented Jun 21, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a48baef5-3b58-49a7-9348-d11921fa42a4

📥 Commits

Reviewing files that changed from the base of the PR and between 23b1156 and 4bda0dd.

📒 Files selected for processing (2)
  • docs/operators/README.md
  • rl_engine/kernels/registry.py
✅ Files skipped from review due to trivial changes (1)
  • docs/operators/README.md

📝 Walkthrough

Walkthrough

Adds NativeRoPEOp, a pure-PyTorch HF rotate-half RoPE operator. It is implemented with forward, forward_fp32, and _compute_cos_sin, registered for cpu/cuda/rocm, covered by pytest, and documented in docs/operators/rope.md.

Changes

NativeRoPEOp implementation, registration, tests, and documentation

Layer / File(s) Summary
OpBackend enum and KernelRegistry dispatch for rope
rl_engine/kernels/registry.py, rl_engine/kernels/ops/pytorch/rotary_embedding/__init__.py
Adds PYTORCH_NATIVE_ROPE to OpBackend wired to rope.NativeRoPEOp. Extends KernelRegistry._priority_map to route rope on cuda/rocm/cpu to that backend. __init__.py imports and exports NativeRoPEOp.
NativeRoPEOp: forward, forward_fp32, and _compute_cos_sin
rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py
forward applies rotate-half with fp32 cos/sin and casts the result back to input dtype. forward_fp32 keeps the output in fp32. _compute_cos_sin builds inv_freq from theta, computes frequency products from positions, and reshapes cos/sin for 1D [S] or 2D [B,S] broadcasting.
Correctness, batch invariance, accuracy, and shape tests
tests/test_rope.py
Four test classes: bitwise match against HF fp32 reference and 1D/2D position equivalence; batch invariance with padded rows and bf16 inputs; dtype-parameterized tolerance checks between forward and forward_fp32; Qwen3-8B-like shapes including the n_kv_heads=8 K-path.
Operator documentation and README index update
docs/operators/rope.md, docs/operators/README.md
rope.md documents usage, backend dispatch, tensor contract, rotate-half semantics, and accuracy expectations. README Current Pages section gains a link for rope.md.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~15 minutes

Poem

🐇 I twirled the rope with cosine light,
Half the vector spun just right.
fp32 hops stayed clear and true,
Tests lined up in tidy view,
The registry found its shiny way—
A little bunny kernel dance today!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 35.71% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly matches the main change: adding a PyTorch RoPE reference operator.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/test_rope.py (1)

53-126: ⚡ Quick win

Add boundary-contract tests for invalid inputs.

Given the operator now has strict shape/device contracts, add negative tests for odd head_dim and unsupported positions rank to prevent future regressions.

Suggested tests
 class TestNativeRoPEOpCorrectness:
@@
     def test_position_zero_is_identity_for_cos(self):
         ...
         assert torch.allclose(out, x.float(), atol=1e-7)
+
+    def test_odd_head_dim_raises(self):
+        op = NativeRoPEOp()
+        x = torch.randn(1, 1, 4, 127)
+        pos = torch.arange(4, dtype=torch.long)
+        with pytest.raises(ValueError, match="even head_dim"):
+            op.forward_fp32(x, pos)
+
+    def test_positions_rank_validation(self):
+        op = NativeRoPEOp()
+        x = torch.randn(1, 1, 4, 128)
+        bad_pos = torch.zeros(1, 4, 1, dtype=torch.long)
+        with pytest.raises(ValueError, match="positions"):
+            op.forward_fp32(x, bad_pos)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_rope.py` around lines 53 - 126, Add two new test methods to the
TestNativeRoPEOpCorrectness class to validate the operator's input contracts:
create a test_odd_head_dim_raises_error method that verifies
NativeRoPEOp.forward_fp32 raises an appropriate error when given tensors with
odd head_dim, and create a test_unsupported_positions_rank_raises_error method
that verifies the operator rejects positions tensors with unsupported rank
(e.g., rank > 2). Use pytest.raises or equivalent assertion to confirm the
expected exceptions are raised for each invalid input scenario.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/operators/rope.md`:
- Around line 38-43: In the documentation table for NativeRoPEOp, the `x`
argument (Query or key tensor with shape [B, H, S, D]) is missing an explicit
precondition. Since NativeRoPEOp splits D into two halves internally, add a
requirement stating that D must be even to the Requirements column for the x
argument. This ensures users understand that odd-width dimension inputs are not
supported and clarifies the expected behavior of the rotate-half semantics.

In `@rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py`:
- Around line 69-71: Add explicit validation to ensure the head dimension D is
even before proceeding with RoPE table construction. After extracting D from
x.shape[-1] (around line 69), insert a check to verify that D is even using
modulo operator, and raise a clear ValueError if D is odd to fail fast at the
contract boundary. Apply the same validation pattern at the second location
mentioned (lines 85-86) where similar dimension extraction occurs.
- Around line 80-83: In the `_compute_cos_sin` function, the `positions` tensor
may reside on a different device than `x` (e.g., `x` on GPU, `positions` on
CPU), which causes device mismatch errors during the frequency computation.
Before converting `positions` to float and computing the frequency values in the
line where `pos_float = positions.float().unsqueeze(-1)` occurs, first ensure
`positions` is moved to the same device as `x` using the appropriate device
transfer operation, then convert to float32. Additionally, validate the rank of
the `positions` tensor to ensure it has the expected shape for the downstream
operations.

---

Nitpick comments:
In `@tests/test_rope.py`:
- Around line 53-126: Add two new test methods to the
TestNativeRoPEOpCorrectness class to validate the operator's input contracts:
create a test_odd_head_dim_raises_error method that verifies
NativeRoPEOp.forward_fp32 raises an appropriate error when given tensors with
odd head_dim, and create a test_unsupported_positions_rank_raises_error method
that verifies the operator rejects positions tensors with unsupported rank
(e.g., rank > 2). Use pytest.raises or equivalent assertion to confirm the
expected exceptions are raised for each invalid input scenario.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a4031d4f-94d1-4365-abe7-6f45f2958521

📥 Commits

Reviewing files that changed from the base of the PR and between a302be4 and 78d01b1.

📒 Files selected for processing (6)
  • docs/operators/README.md
  • docs/operators/rope.md
  • rl_engine/kernels/ops/pytorch/rotary_embedding/__init__.py
  • rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py
  • rl_engine/kernels/registry.py
  • tests/test_rope.py

Comment thread docs/operators/rope.md
Comment on lines +38 to +43
| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; Qwen3 uses `D=128`. |
| `positions` | `[S]` or `[B, S]` | Integer | Absolute token positions. |
| `theta` | scalar | float | Defaults to `1_000_000.0` for Qwen3. |
| Output | `[B, H, S, D]` | See below | Same shape as `x`. |

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

State the even-head-dim precondition.

NativeRoPEOp splits D into two halves, so the contract should explicitly say D must be even. Without that, odd-width inputs look supported here but won’t preserve the intended rotate-half semantics.

📌 Suggested doc tweak
-| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; Qwen3 uses `D=128`. |
+| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; `D` must be even; Qwen3 uses `D=128`. |
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; Qwen3 uses `D=128`. |
| `positions` | `[S]` or `[B, S]` | Integer | Absolute token positions. |
| `theta` | scalar | float | Defaults to `1_000_000.0` for Qwen3. |
| Output | `[B, H, S, D]` | See below | Same shape as `x`. |
| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `x` | `[B, H, S, D]` | `float32`, `bfloat16`, or `float16` | Query or key tensor; `D` must be even; Qwen3 uses `D=128`. |
| `positions` | `[S]` or `[B, S]` | Integer | Absolute token positions. |
| `theta` | scalar | float | Defaults to `1_000_000.0` for Qwen3. |
| Output | `[B, H, S, D]` | See below | Same shape as `x`. |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/operators/rope.md` around lines 38 - 43, In the documentation table for
NativeRoPEOp, the `x` argument (Query or key tensor with shape [B, H, S, D]) is
missing an explicit precondition. Since NativeRoPEOp splits D into two halves
internally, add a requirement stating that D must be even to the Requirements
column for the x argument. This ensures users understand that odd-width
dimension inputs are not supported and clarifies the expected behavior of the
rotate-half semantics.

Comment on lines +69 to +71
D = x.shape[-1]
half = D // 2

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add explicit even head_dim validation before building RoPE tables.

Odd D currently fails later via shape mismatch; fail fast with a clear error at the contract boundary.

Proposed fix
         D = x.shape[-1]
+        if D % 2 != 0:
+            raise ValueError(f"RoPE requires even head_dim, got {D}")
         half = D // 2

Also applies to: 85-86

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py` around lines 69 - 71,
Add explicit validation to ensure the head dimension D is even before proceeding
with RoPE table construction. After extracting D from x.shape[-1] (around line
69), insert a check to verify that D is even using modulo operator, and raise a
clear ValueError if D is odd to fail fast at the contract boundary. Apply the
same validation pattern at the second location mentioned (lines 85-86) where
similar dimension extraction occurs.

Comment on lines +80 to +83
pos_float = positions.float().unsqueeze(-1)

# freqs: [S, half] or [B, S, half]
freqs = pos_float * inv_freq

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Normalize positions to x.device (and validate rank) in _compute_cos_sin.

Current code can crash on mixed-device inputs (x on accelerator, positions on CPU). Convert positions to fp32 on x.device before frequency math.

Proposed fix
-        pos_float = positions.float().unsqueeze(-1)
+        if positions.dim() not in (1, 2):
+            raise ValueError(f"positions must have shape [S] or [B,S], got dim={positions.dim()}")
+        pos_float = positions.to(device=x.device, dtype=torch.float32).unsqueeze(-1)

Also applies to: 92-99

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/rotary_embedding/rope.py` around lines 80 - 83,
In the `_compute_cos_sin` function, the `positions` tensor may reside on a
different device than `x` (e.g., `x` on GPU, `positions` on CPU), which causes
device mismatch errors during the frequency computation. Before converting
`positions` to float and computing the frequency values in the line where
`pos_float = positions.float().unsqueeze(-1)` occurs, first ensure `positions`
is moved to the same device as `x` using the appropriate device transfer
operation, then convert to float32. Additionally, validate the rank of the
`positions` tensor to ensure it has the expected shape for the downstream
operations.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@rl_engine/kernels/ops/pytorch/linear/matmul.py`:
- Around line 25-27: The forward method accepts mixed input dtypes without
validation and silently casts the result to a.dtype, which can hide upstream
mistakes. Add an explicit dtype compatibility check at the beginning of the
forward method that validates tensors a and b have compatible or identical
dtypes, raising a clear error if they don't match. Only proceed with the
forward_fp32 computation and dtype casting after this validation passes.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d71dc47d-9508-47c9-a83c-852d642341a6

📥 Commits

Reviewing files that changed from the base of the PR and between 78d01b1 and ba5a4cd.

📒 Files selected for processing (6)
  • docs/operators/README.md
  • docs/operators/matmul.md
  • rl_engine/kernels/ops/pytorch/linear/__init__.py
  • rl_engine/kernels/ops/pytorch/linear/matmul.py
  • rl_engine/kernels/registry.py
  • tests/test_matmul.py
✅ Files skipped from review due to trivial changes (2)
  • docs/operators/README.md
  • docs/operators/matmul.md
🚧 Files skipped from review as they are similar to previous changes (1)
  • rl_engine/kernels/registry.py

Comment on lines +25 to +27
def forward(self, a: Tensor, b: Tensor) -> Tensor:
"""Compute `a @ b` and return the input dtype."""
return self.forward_fp32(a, b).to(dtype=a.dtype)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate dtype compatibility before casting output to a.dtype.

On Line 27, mixed input dtypes are silently accepted and the result is forced to a.dtype, which can hide upstream mistakes and unintentionally downcast results. Add an explicit dtype check (or documented promotion rule) before the cast.

Suggested patch
     def forward(self, a: Tensor, b: Tensor) -> Tensor:
         """Compute `a @ b` and return the input dtype."""
+        if a.dtype != b.dtype:
+            raise TypeError(f"NativeMatmulOp.forward expects matching dtypes, got {a.dtype} and {b.dtype}")
         return self.forward_fp32(a, b).to(dtype=a.dtype)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/linear/matmul.py` around lines 25 - 27, The
forward method accepts mixed input dtypes without validation and silently casts
the result to a.dtype, which can hide upstream mistakes. Add an explicit dtype
compatibility check at the beginning of the forward method that validates
tensors a and b have compatible or identical dtypes, raising a clear error if
they don't match. Only proceed with the forward_fp32 computation and dtype
casting after this validation passes.

@a-kaa a-kaa changed the title Add PyTorch RoPE reference operator feat(ws1): Add PyTorch RoPE reference operator Jun 22, 2026
@Flink-ddd Flink-ddd assigned Flink-ddd and unassigned Flink-ddd Jun 22, 2026
@Flink-ddd

Copy link
Copy Markdown
Collaborator

please resolve conflicts and CI error first, Thanks.

@KJLdefeated KJLdefeated left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR looks good to me. After resolve the requests, I am happy to approve.

Comment on lines +70 to +71
# positions: [S] -> [S, 1] or [B, S] -> [B, S, 1]
pos_float = positions.float().unsqueeze(-1)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The positions may not have same device with inv_freq. Maybe you should convert it to device=x.device

@Flink-ddd Flink-ddd left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Focusing strictly on the remaining architectural blind spots to ensure this PR meets the exact same WS1 standards as #160 and #166, here is the review:

Comment thread tests/test_rope.py
assert torch.allclose(out_typed, out_fp32, atol=atol, rtol=rtol), (
f"dtype={dtype}, max_abs_error={diff:.3e} exceeds " f"atol={atol}, rtol={rtol}"
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unlike the RMSNorm and SwiGLU PRs, this PR currently has zero tests for backward(). Even though RoPE has no trainable weights, gradients must flow back through it accurately to the Query and Key projection layers.

You need to verify that x.grad is correctly computed and is strictly Batch-Invariant (Axis-A), otherwise upstream training drift is guaranteed. Please add a slice test to assert that computing the backward pass on a full batch yields bitwise identical gradients to computing it on a sliced subset (batch size = 1).

@Flink-ddd

Copy link
Copy Markdown
Collaborator

Please resolve the code conflicts first, Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants