feat(ws1): add NativeKVCacheAttnOp pure-PyTorch KV-cache attention reference #195

Open
maxiaosong1124 wants to merge 12 commits into RL-Align:main from maxiaosong1124:feat/ws1-kv-cache-attention-pytorch-op

Conversation

@maxiaosong1124

@maxiaosong1124 maxiaosong1124 commented Jun 27, 2026

Collaborator

⚠️ Depends on #188 (attention) — do not merge first

NativeKVCacheAttnOp delegates to the same NativeAttentionOp reduction added in #188,
so prefill and decode share one numerical path. This branch is stacked on top of #188's
feat/ws1-attention-pytorch-op.

Because #188 is not merged yet, the diff below currently also contains #188's attention
files.
For this review, please look only at the kv_cache-specific changes:

  • rl_engine/kernels/ops/pytorch/attention/kv_cache.py
  • tests/test_kv_cache_attention.py
  • the kv_cache_attention entries in rl_engine/kernels/registry.py
  • the kv_cache CI step in .github/workflows/ci.yml

Once #188 merges to main, I'll rebase onto main and the diff becomes lean (attention-free).
Please do not merge this before #188.

Summary

Pure-PyTorch fp32 ground-truth reference for KV-cache (incremental/decode) attention
(WS1 / ISSUE #108). It concatenates cached and new K/V along the sequence axis and then calls
the exact same NativeAttentionOp used for full-sequence prefill; re-implementing softmax
here would defeat the purpose. This guarantees that decode (rollout) and prefill (training) take
an identical reduction, which is what keeps rollout↔training numerically consistent.

Shapes follow Qwen3-8B (synthetic tensors, no weights downloaded): 32 Q / 8 KV heads
(GQA g=4) / head_dim=128, scale = 1/√128.

What this adds

  • NativeKVCacheAttnOp (pytorch/attention/kv_cache.py):
    forward / forward_fp32(q, k_cache, v_cache, k_new, v_new, *, causal=True, scale=None, key_padding_mask=None) per the frozen contract (INTERFACES §9). cat([cache, new], dim=2)
    → delegates to NativeAttentionOp. Returns attention output only; cache update is out of
    scope for the numerical contract.
  • Registry: PYTORCH_NATIVE_KV_CACHE_ATTN enum + kv_cache_attention priority-map entries
    (cpu / cuda / rocm).
  • tests/test_kv_cache_attention.py — 14 CPU-safe tests (+ GPU/large markers).
  • CI: pytest tests/test_kv_cache_attention.py -k "not large and not gpu".
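
The concatenate-then-delegate design described above can be sketched in plain Python on a single head. This is a minimal illustration, not the code in this PR: `naive_attention` and `kv_cache_attention_sketch` are hypothetical names, tensors are stand-in lists of float vectors, and the causal rule used here (query row i sees keys up to Skv - Sq + i) is an assumption chosen to match the offset described for #188.

```python
import math


def naive_attention(q, k, v, causal=True, scale=None):
    """Single-head softmax attention on lists of float vectors (illustrative only).

    q has Sq rows; k and v have Skv rows (Skv >= Sq is assumed). With
    causal=True, query row i may attend to keys 0 .. (Skv - Sq + i), so the
    last query row sees every key.
    """
    sq, skv, head_dim = len(q), len(k), len(q[0])
    scale = 1.0 / math.sqrt(head_dim) if scale is None else scale
    out = []
    for i, q_row in enumerate(q):
        last = skv - sq + i if causal else skv - 1
        scores = [scale * sum(a * b for a, b in zip(q_row, k[j])) for j in range(last + 1)]
        peak = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append([
            sum(weights[j] * v[j][c] for j in range(last + 1)) / total
            for c in range(len(v[0]))
        ])
    return out


def kv_cache_attention_sketch(q, k_cache, v_cache, k_new, v_new, causal=True, scale=None):
    """Concatenate cached and new K/V along the sequence axis, then delegate."""
    return naive_attention(q, k_cache + k_new, v_cache + v_new, causal=causal, scale=scale)


# Four tokens with head_dim 2 (hypothetical values).
q = [[0.1, 0.2], [0.3, -0.1], [-0.2, 0.4], [0.5, 0.5]]
k = [[0.2, 0.1], [-0.3, 0.2], [0.4, 0.4], [0.1, -0.5]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 2.0]]

prefill = naive_attention(q, k, v)
# Decode the last token with the first three tokens already cached.
decode = kv_cache_attention_sketch(q[3:], k[:3], v[:3], k[3:], v[3:])
```

Because both paths run the same function, the single decode row and the last row of the prefill go through the same sequence of floating-point operations in this sketch, which is the property the delegation is meant to preserve.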

Test coverage

  • Delegation correctness: forward / forward_fp32 equal NativeAttentionOp on the
    concatenated K/V.
  • Prefill↔decode consistency: stepwise decode matches a single full-sequence prefill.
  • Cache split-point invariance: the same total KV split at different cache/new boundaries
    gives the same result.
  • Axis-A: batch slice invariance (torch.equal).
  • key_padding_mask: padded keys get zero weight; near-equality at atol=2e-6 (differing
    softmax reduction widths, the same rationale as #188's padding test, with headroom over the
    ~1.3e-6 platform-sensitive drift).
  • Axis-B: bf16 / fp16 forward tracks fp32 within a per-dtype peak-relative tolerance.
  • Empty-cache == plain attention, closed-form uniform-attention decode, output shape follows
    q, GQA divisibility guard, input purity (no mutation), registry dispatch.
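
For readers unfamiliar with the "peak-relative" wording in the Axis-B item, one plausible reading is sketched below. The definition (largest absolute error divided by the largest reference magnitude) is an assumption, not taken from the test file, and `peak_relative_error` is a hypothetical name; the per-dtype thresholds themselves are not stated in this description and are not reproduced here.

```python
def peak_relative_error(actual, reference):
    """Largest absolute error, normalised by the largest reference magnitude.

    Assumed definition for illustration. Unlike an element-wise relative
    error, it stays well-behaved when individual reference values are near zero.
    """
    peak = max(abs(r) for r in reference)
    worst = max(abs(a - r) for a, r in zip(actual, reference))
    if peak == 0.0:
        return worst  # nothing to normalise by; fall back to absolute error
    return worst / peak
```

With this metric, an error of 0.001 on a reference entry of exactly 0.0 counts as 0.00025 when the largest reference value is 4.0, rather than as an undefined per-element ratio.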

Local run

```
python -m pytest tests/test_kv_cache_attention.py -v -k "not large and not gpu"

14 passed

```

Notes

  • Pure function: no randomness, no in-place, no device moves; dtype/device follow inputs.
  • Mask convention: True = valid / keep, False = padding (consistent with the rest of the repo).

Summary by CodeRabbit

  • New Features
    • Added new pure-PyTorch reference implementations for standard attention and KV-cache (incremental decode) attention.
    • Updated kernel dispatch/registry routing so these attention backends are selected automatically across CPU, CUDA, and ROCm.
  • Documentation
    • Added an Operators guide page describing the attention operator contract, masking, dtype behavior, and test expectations.
  • Tests
    • Added deterministic validation suites for attention and KV-cache attention, including fp32 golden checks, masking/GQA correctness, determinism, gradients, and dispatch.
  • CI
    • Expanded CPU-safe unit-test attention ground-truth checks (excluding large/gpu).

WS1 ground-truth attention op for issue RL-Align#108 (Qwen3-8B GQA attention):
- NativeAttentionOp: out = softmax(Q Kᵀ * scale + masks) @ V, a hand-written
  naive softmax (NOT F.scaled_dot_product_attention / flash) so the reduction
  order is fixed for the batch-invariance contract. GQA 32/8 via
  repeat_interleave, causal offset Skv-Sq+1 (prefill + decode), key_padding_mask
  (True=valid), scale default 1/sqrt(128).
Exposes the forward / forward_fp32 dual-path contract (fp32 ground truth +
dtype-behavior path); forward_fp32 disables TF32/autocast for a strict fp32
reference. Pure function, fp32 accumulation.
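
The GQA replication and default scale mentioned above can be illustrated without tensors. The sketch below assumes the usual behaviour of repeating each KV head for a contiguous group of query heads, which is what `repeat_interleave` along the head axis produces; the function names are hypothetical and the error message is illustrative.

```python
import math


def kv_head_for_each_q_head(num_q_heads, num_kv_heads):
    """Return, for every query head, the index of the KV head it reads.

    Mirrors the effect of repeating each KV head `group` times in place:
    query heads 0..group-1 share KV head 0, the next group shares KV head 1, etc.
    """
    if num_kv_heads <= 0 or num_q_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_q_heads={num_q_heads} must be a positive multiple of "
            f"num_kv_heads={num_kv_heads}"
        )
    group = num_q_heads // num_kv_heads
    return [h // group for h in range(num_q_heads)]


def default_scale(head_dim):
    """Default softmax scale 1/sqrt(head_dim)."""
    return 1.0 / math.sqrt(head_dim)


# Qwen3-8B-style shapes from the description: 32 Q heads, 8 KV heads, head_dim 128.
mapping = kv_head_for_each_q_head(32, 8)
scale = default_scale(128)
```

Here `mapping` starts `[0, 0, 0, 0, 1, ...]` and ends at KV head 7, and a head count that does not divide evenly is rejected up front instead of producing a shape error later.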

- register PYTORCH_NATIVE_ATTENTION in OpBackend and the cuda/rocm/cpu priority
  maps under op_type "attention" (distinct from the production "attn" /
  PYTORCH_ATTN SDPA fallback)
- tests/test_attention.py: forward_fp32 vs independent fp32 reference, closed-form
  causal/decode, GQA replication + divisibility guard, scale, key-padding,
  dtype-path accuracy (Axis-B), Axis-A batch invariance (slice + chunked +
  padding), purity, gradient flow, registry dispatch, GPU-only LARGE Qwen3-8B smoke
- docs/operators/attention.md + nav/index wiring
- standard_attn: define fully key-padding-masked query rows as 0 (was NaN);
  guarded to the padding branch so the no-pad path is unchanged, row-independent
  so Axis-A holds; add test_fully_masked_query_returns_zero_not_nan
- test: drop the double 1.5x margin in _enough_gpu_memory (LARGE skip now ~50 GB
  as documented, no longer over-skips 80 GB GPUs)
- docs/attention.md: add text lang to the diagram fence (MD040); clarify that
  dispatch uses forward() input-dtype path, forward_fp32() is the explicit fp32 path
key_padding_mask drift over differing reduction widths (Skv=10 vs 6) is
~1.3e-6 and platform-sensitive; atol=1e-6 failed locally for the reviewer.
Bump the threshold to 2e-6 for headroom, and update the test-coverage doc
line so padding reads as near-equality, not part of the bitwise Axis-A claim.
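
The mask convention (True = valid) and the "fully masked row is 0, not NaN" rule from the commits above can be sketched for a single row of scores. This is an illustrative stand-in with a hypothetical name, not the code in standard_attn.py.

```python
import math


def masked_softmax_row(scores, valid):
    """Softmax over one row of scores where valid[j] is True for keys to keep.

    A naive implementation that fills masked scores with -inf yields NaN when
    every key is masked, because (-inf) - (-inf) is NaN. Here that case is
    defined to return all-zero weights instead.
    """
    kept = [s for s, keep in zip(scores, valid) if keep]
    if not kept:
        return [0.0] * len(scores)
    peak = max(kept)  # max over valid keys only
    exps = [math.exp(s - peak) if keep else 0.0 for s, keep in zip(scores, valid)]
    total = sum(exps)
    return [e / total for e in exps]
```

Padded keys receive exactly zero weight regardless of their score, and the zero-row rule depends only on that row's own mask, so it does not couple rows across the batch.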
@coderabbitai

coderabbitai Bot commented Jun 27, 2026


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 89c325f2-41fd-40ba-bbd6-6a61100c1667

📥 Commits

Reviewing files that changed from the base of the PR and between 4e6df2c and 4a23b5f.

📒 Files selected for processing (2)
  • rl_engine/kernels/ops/pytorch/attention/kv_cache.py
  • tests/test_kv_cache_attention.py
📝 Walkthrough

Walkthrough

Adds pure-PyTorch reference attention operators for standard and KV-cache flows, routes both operator names through the kernel registry, documents the standard attention contract, and expands CI plus pytest coverage for correctness, masking, dtype behavior, batch invariance, and dispatch.

Changes

Attention reference operators

| Layer | File(s) | Summary |
|---|---|---|
| Standard attention contract | docs/operators/attention.md, docs/.nav.yml, docs/operators/README.md, rl_engine/kernels/ops/pytorch/attention/standard_attn.py | The new operator page documents the WS1 softmax attention contract, and NativeAttentionOp implements the default and strict-fp32 reference paths with causal masking, key padding, GQA replication, scaling, and NaN handling. |
| KV-cache operator and routing | rl_engine/kernels/ops/pytorch/attention/kv_cache.py, rl_engine/kernels/registry.py | NativeKVCacheAttnOp concatenates cached and new K/V tensors before delegating to NativeAttentionOp, and the registry adds the new backend enums plus attention/kv_cache_attention routing on CUDA, ROCm, and CPU. |
| Attention validation and CI | tests/test_attention.py, .github/workflows/ci.yml | The attention test suite checks fp32 reference matching, masking, GQA, dtype accuracy, batch invariance, gradients, registry dispatch, and the large-shape smoke test, and CI adds CPU-safe runs for attention and KV-cache attention tests. |
| KV-cache validation | tests/test_kv_cache_attention.py | The KV-cache attention test suite checks concatenation equivalence, split-point invariance, decode versus prefill behavior, masking, dtype accuracy, output shape, immutability, and registry dispatch. |

Sequence Diagram(s)

sequenceDiagram
  participant kernel_registry
  participant NativeAttentionOp
  participant _strict_fp32_math
  kernel_registry->>NativeAttentionOp: get_op("attention")
  NativeAttentionOp->>_strict_fp32_math: forward_fp32() disables autocast and TF32
  _strict_fp32_math-->>NativeAttentionOp: restores TF32 state

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

component: kernels

Suggested reviewers

  • bitborne
  • inaniloquentee
  • Flink-ddd

Poem

Hop hop, the softmax moon is bright 🐇
Cache and query dance at night.
FP32 moonbeams keep truth in sight,
masks and heads all line up right,
Tiny paws applaud the kernel flight.

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly describes the primary change: adding the NativeKVCacheAttnOp pure-PyTorch KV-cache attention reference. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 92.31%, which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/operators/attention.md`:
- Around line 72-73: The operator contract text is out of sync with the
implementation and should be corrected in the attention docs. Update the
`forward_fp32()` contract wording so it no longer claims the original dtype is
preserved, and change the padding tolerance statement to match the shipped
`2e-6` value. Keep the wording aligned with the operator contract section and
the relevant `forward_fp32()` behavior so downstream users get the correct
guarantees.

In `@rl_engine/kernels/ops/pytorch/attention/kv_cache.py`:
- Around line 78-118: Add an explicit sequence-length validation in
`KVCacheAttention.forward` and `KVCacheAttention.forward_fp32` to reject cases
where `q` does not match the newly appended token length (`Sq != S_new`). Use
the existing `_concat_kv` flow as the location to enforce this check before
calling `_attn.forward` or `_attn.forward_fp32`, and raise an error with clear
context when the `q`/`k_new`/`v_new` shapes are misaligned.

In `@rl_engine/kernels/ops/pytorch/attention/standard_attn.py`:
- Around line 183-191: The _strict_fp32_math helper currently mutates the
process-wide torch.backends.cuda.matmul.allow_tf32 flag, which can affect
unrelated CUDA work and concurrent calls. Update the forward_fp32 path to avoid
toggling shared backend state in standard_attn.py, and instead use a safer
isolation approach within _strict_fp32_math (or an equivalent local-only
mechanism) while still disabling autocast for the true fp32 reference path.

In `@tests/test_attention.py`:
- Around line 281-285: Reformat the assert block in test_attention.py so it
matches Black’s preferred wrapping, especially the multi-line assert with the
trailing message string. Update the padding-mask check around the
masked/valid_only comparison to use Black-compatible line breaks and
parentheses, keeping the same logic but adjusting the formatting so CI no longer
rewrites it.

In `@tests/test_kv_cache_attention.py`:
- Around line 145-157: The split-point KV-cache test is using a full-length
query tensor for every split, which hides whether the suffix-only path is
correct. Update the test in the split loop so the query tensor is the suffix of
q that corresponds to each non-zero split, and keep the existing forward_fp32
comparison against the all-cache/all-new cases. Use the symbols q, k_full,
v_full, split, and op.forward_fp32 to align the test with the documented
KV-cache contract.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 527d2fb3-1020-4591-be59-ecf3322d1308

📥 Commits

Reviewing files that changed from the base of the PR and between ea196da and efdcce2.

📒 Files selected for processing (9)
  • .github/workflows/ci.yml
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/attention.md
  • rl_engine/kernels/ops/pytorch/attention/kv_cache.py
  • rl_engine/kernels/ops/pytorch/attention/standard_attn.py
  • rl_engine/kernels/registry.py
  • tests/test_attention.py
  • tests/test_kv_cache_attention.py

Comment thread docs/operators/attention.md Outdated
Comment thread rl_engine/kernels/ops/pytorch/attention/kv_cache.py
Comment on lines +183 to +191
def _strict_fp32_math(device_type: str):
    """Disable autocast and TF32 for a true fp32 path, restoring state after."""
    prev_tf32 = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    try:
        with torch.autocast(device_type=device_type, enabled=False):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = prev_tf32


🩺 Stability & Availability | 🟠 Major | 🏗️ Heavy lift

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Inspect where the public attention reference is reached from, and whether
# concurrent execution paths exist in the repo.
rg -nP 'get_op\("attention"\)|NativeAttentionOp\(|forward_fp32\s*\(' rl_engine tests
rg -nP '\b(threading|ThreadPoolExecutor|ProcessPoolExecutor|asyncio|DataLoader)\b' rl_engine tests

Repository: RL-Align/RL-Kernel

Length of output: 6173


Avoid mutating the global TF32 backend flag here. torch.backends.cuda.matmul.allow_tf32 is process-wide, so forward_fp32() can perturb unrelated CUDA work and becomes unsafe if called concurrently. If this stays as a public reference path, it needs a safer isolation strategy than toggling shared backend state.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/kernels/ops/pytorch/attention/standard_attn.py` around lines 183 -
191, The _strict_fp32_math helper currently mutates the process-wide
torch.backends.cuda.matmul.allow_tf32 flag, which can affect unrelated CUDA work
and concurrent calls. Update the forward_fp32 path to avoid toggling shared
backend state in standard_attn.py, and instead use a safer isolation approach
within _strict_fp32_math (or an equivalent local-only mechanism) while still
disabling autocast for the true fp32 reference path.

Comment thread tests/test_attention.py Outdated
Comment thread tests/test_kv_cache_attention.py Outdated
…pytorch-op

# Conflicts:
#	docs/.nav.yml
#	docs/operators/README.md
…e (KJLdefeated)

Replace isfinite-only grad check with autograd through the independent
double-precision _ref_softmax_attn under a random (seeded) cotangent. isfinite
can't tell a correct gradient from a wrong-but-finite one, and attention's
backward (softmax Jacobian + dQ/dK/dV contractions) is the most error-prone in
the stack; .sum()'s all-ones cotangent would also collapse the contraction.
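
The point about the all-ones cotangent can be made concrete for softmax alone, a simplified stand-in for the full attention backward. Since each softmax row sums to 1, the gradient of `.sum()` through softmax is identically zero, so a wrong softmax Jacobian would go unnoticed; a non-uniform cotangent exercises it. The sketch below uses plain Python with a finite-difference check instead of autograd, and all names are hypothetical.

```python
import math


def softmax(x):
    peak = max(x)
    exps = [math.exp(v - peak) for v in x]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_vjp(x, cotangent):
    """Analytic vector-Jacobian product: grad_i = p_i * (g_i - sum_j p_j * g_j)."""
    p = softmax(x)
    dot = sum(pi * gi for pi, gi in zip(p, cotangent))
    return [pi * (gi - dot) for pi, gi in zip(p, cotangent)]


def finite_difference_vjp(x, cotangent, eps=1e-6):
    """Central-difference gradient of sum_i cotangent_i * softmax(x)_i."""
    grads = []
    for i in range(len(x)):
        hi, lo = list(x), list(x)
        hi[i] += eps
        lo[i] -= eps
        f_hi = sum(g * p for g, p in zip(cotangent, softmax(hi)))
        f_lo = sum(g * p for g, p in zip(cotangent, softmax(lo)))
        grads.append((f_hi - f_lo) / (2 * eps))
    return grads


x = [0.5, -1.0, 2.0, 0.0]
# All-ones cotangent (what .sum() implies): the gradient vanishes up to rounding.
grad_from_sum = softmax_vjp(x, [1.0] * len(x))
# Non-uniform cotangent: the gradient is non-trivial and can be cross-checked.
cotangent = [1.0, -2.0, 0.5, 3.0]
grad_analytic = softmax_vjp(x, cotangent)
grad_numeric = finite_difference_vjp(x, cotangent)
```

In the full attention op the output is further multiplied by V, so the `.sum()` gradient is not exactly zero there, but the same loss of information in the softmax term is why a random cotangent is the stronger check.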
@maxiaosong1124 maxiaosong1124 force-pushed the feat/ws1-kv-cache-attention-pytorch-op branch from efdcce2 to 3d466fe Compare June 28, 2026 08:46
…eview

forward_fp32 returns fp32 (not the input dtype), and the shipped padding
tolerance is atol=2e-6, not 1e-6. Aligns the operator contract doc with the
code (CodeRabbit).
…ference

Concat cache+new K/V then delegate to NativeAttentionOp so prefill and decode
share one reduction path. Register kv_cache_attention in the registry (cpu/cuda/
rocm) and add a standalone CPU-safe test suite + CI step.
@maxiaosong1124 maxiaosong1124 force-pushed the feat/ws1-kv-cache-attention-pytorch-op branch from 3d466fe to 4e6df2c Compare June 28, 2026 08:55

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
tests/test_kv_cache_attention.py (1)

351-353: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win

Resolve Ruff B905 on the snapshot loop.

If the repo targets Python 3.10+, make this zip(..., strict=True) so mismatched tensor lists cannot truncate silently and the Ruff warning goes away.

Suggested change
-    for orig, snap in zip((q, k_cache, v_cache, k_new, v_new), snapshots):
+    for orig, snap in zip((q, k_cache, v_cache, k_new, v_new), snapshots, strict=True):
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_kv_cache_attention.py` around lines 351 - 353, The snapshot
verification loop in test_kv_cache_attention.py uses zip without strict
checking, which triggers Ruff B905. Update the loop that iterates over q,
k_cache, v_cache, k_new, and v_new to use zip(..., strict=True) so the tensor
snapshot comparison cannot silently truncate and the warning is resolved.

Source: Linters/SAST tools

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8e271f6f-2062-47ad-bec9-0183469d462d

📥 Commits

Reviewing files that changed from the base of the PR and between 3d466fe and 4e6df2c.

📒 Files selected for processing (5)
  • .github/workflows/ci.yml
  • docs/operators/attention.md
  • rl_engine/kernels/ops/pytorch/attention/kv_cache.py
  • rl_engine/kernels/registry.py
  • tests/test_kv_cache_attention.py
✅ Files skipped from review due to trivial changes (1)
  • docs/operators/attention.md
🚧 Files skipped from review as they are similar to previous changes (3)
  • rl_engine/kernels/ops/pytorch/attention/kv_cache.py
  • .github/workflows/ci.yml
  • rl_engine/kernels/registry.py

…tract-aligned split-point test

- kv_cache.py: validate q holds exactly the new positions (Sq == S_new) in
  forward/forward_fp32; a mismatch would silently use the wrong causal offset.
- test: split-point test now decodes the suffix queries and compares to the
  matching slice of full prefill (near-equal, atol=2e-6, not bitwise -- the
  score matmul's M dim differs across splits); add a misaligned-q raises test.
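
Why a mismatched q length matters can be seen from the causal bookkeeping alone. The sketch below is a shape-only illustration with a hypothetical function name and error message; it assumes the causal rule where query row i sees the first Skv - Sq + i + 1 keys.

```python
def visible_key_counts(sq, s_cache, s_new):
    """Number of keys each query row may attend to under the causal mask.

    Requires q to cover exactly the newly appended positions (sq == s_new);
    otherwise the offset Skv - Sq would place the queries at the wrong
    positions in the sequence.
    """
    if sq != s_new:
        raise ValueError(
            f"q covers {sq} positions but k_new/v_new cover {s_new}; "
            "q must hold exactly the new positions"
        )
    skv = s_cache + s_new
    return [skv - sq + i + 1 for i in range(sq)]
```

For example, one new token after a cache of three sees all four keys, and two new tokens after a cache of two see three and four keys respectively. Passing four queries with only two new positions would otherwise be interpreted as a cache-free prefill without any error being raised.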