feat(ws1): add NativeKVCacheAttnOp pure-PyTorch KV-cache attention reference #195
maxiaosong1124 wants to merge 12 commits into
WS1 ground-truth attention op for issue RL-Align#108 (Qwen3-8B GQA attention):
- NativeAttentionOp: out = softmax(Q Kᵀ * scale + masks) @ V, a hand-written naive softmax (NOT F.scaled_dot_product_attention / flash) so the reduction order is fixed for the batch-invariance contract. GQA 32/8 via repeat_interleave, causal offset Skv-Sq+1 (prefill + decode), key_padding_mask (True=valid), scale default 1/sqrt(128). Exposes the forward / forward_fp32 dual-path contract (fp32 ground truth + dtype-behavior path); forward_fp32 disables TF32/autocast for a strict fp32 reference. Pure function, fp32 accumulation.
- register PYTORCH_NATIVE_ATTENTION in OpBackend and the cuda/rocm/cpu priority maps under op_type "attention" (distinct from the production "attn" / PYTORCH_ATTN SDPA fallback)
- tests/test_attention.py: forward_fp32 vs independent fp32 reference, closed-form causal/decode, GQA replication + divisibility guard, scale, key-padding, dtype-path accuracy (Axis-B), Axis-A batch invariance (slice + chunked + padding), purity, gradient flow, registry dispatch, GPU-only LARGE Qwen3-8B smoke
- docs/operators/attention.md + nav/index wiring
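The causal offset Skv-Sq+1 mentioned above can be read as a rule on index pairs. A minimal pure-Python sketch, assuming query row i sits at absolute position i + (Skv - Sq), so key j is masked exactly when j >= i + (Skv - Sq + 1). The name `causal_allowed` is illustrative and not part of this PR:

```python
def causal_allowed(sq: int, skv: int) -> list[list[bool]]:
    """Return allowed[i][j] == True when query row i may attend to key j."""
    if sq > skv:
        raise ValueError(f"Sq={sq} cannot exceed Skv={skv}")
    # Keys at or beyond i + first_masked_offset are in the "future" of row i.
    first_masked_offset = skv - sq + 1
    return [
        [j < i + first_masked_offset for j in range(skv)]
        for i in range(sq)
    ]


# Prefill (Sq == Skv): the usual lower-triangular pattern.
prefill = causal_allowed(3, 3)
# Decode (Sq == 1): the single new query sees every cached key plus itself.
decode = causal_allowed(1, 4)
```

With one rule covering both cases, prefill and decode do not need separate masking code paths.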
- standard_attn: define fully key-padding-masked query rows as 0 (was NaN); guarded to the padding branch so the no-pad path is unchanged, row-independent so Axis-A holds; add test_fully_masked_query_returns_zero_not_nan
- test: drop the double 1.5x margin in _enough_gpu_memory (LARGE skip now ~50 GB as documented, no longer over-skips 80 GB GPUs)
- docs/attention.md: add text lang to the diagram fence (MD040); clarify that dispatch uses forward() input-dtype path, forward_fp32() is the explicit fp32 path
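Why a fully masked row turns into NaN, and what defining it as 0 looks like, can be sketched on a single row in plain Python. This is an illustration under assumptions, not the code in standard_attn.py: both function names are hypothetical, and `valid` follows the True=valid convention used in this PR.

```python
import math


def naive_softmax_row(scores, valid):
    """Unguarded variant: padding is masked with -inf before the softmax."""
    masked = [s if ok else float("-inf") for s, ok in zip(scores, valid)]
    m = max(masked)
    # If every key is padding, m is -inf and (-inf) - (-inf) is NaN.
    exps = [math.exp(s - m) for s in masked]
    total = 0.0
    for e in exps:
        total += e
    return [e / total for e in exps]


def masked_softmax_row(scores, valid):
    """Guarded variant: a row with no valid key is defined as all zeros."""
    if not any(valid):
        return [0.0] * len(scores)
    m = max(s for s, ok in zip(scores, valid) if ok)
    exps = [math.exp(s - m) if ok else 0.0 for s, ok in zip(scores, valid)]
    total = 0.0
    for e in exps:
        total += e
    return [e / total for e in exps]
```

The guard depends only on the row's own mask, which is consistent with the commit's note that the fix is row-independent.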
key_padding_mask drift over differing reduction widths (Skv=10 vs 6) is ~1.3e-6 and platform-sensitive; atol=1e-6 failed locally for the reviewer. Bump the threshold to 2e-6 for headroom, and update the test-coverage doc line so padding reads as near-equality, not part of the bitwise Axis-A claim.
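The reason padding is a near-equality claim rather than a bitwise one comes down to floating-point addition not being associative. A small sketch with Python doubles (not fp32 tensors, and not the kernel's actual reduction) shows the same three values summed in two orders:

```python
import math
from functools import reduce

values = [0.1, 0.2, 0.3]

# Left-to-right grouping: (0.1 + 0.2) + 0.3
left = reduce(lambda acc, x: acc + x, values)
# Right-to-left grouping: 0.1 + (0.2 + 0.3)
right = reduce(lambda acc, x: x + acc, reversed(values))

# The two groupings disagree in the last bit, so exact equality fails...
bitwise_equal = left == right
# ...while an absolute tolerance like the one used in the test passes.
near_equal = math.isclose(left, right, rel_tol=0.0, abs_tol=2e-6)
```

A reduction over Skv=10 and one over Skv=6 may group their additions differently in the same way, which is why an atol is the appropriate check there.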
📝 Walkthrough
Adds pure-PyTorch reference attention operators for standard and KV-cache flows, routes both operator names through the kernel registry, documents the standard attention contract, and expands CI plus pytest coverage for correctness, masking, dtype behavior, batch invariance, and dispatch.
Changes
Attention reference operators
Sequence Diagram(s)
sequenceDiagram
participant kernel_registry
participant NativeAttentionOp
participant _strict_fp32_math
kernel_registry->>NativeAttentionOp: get_op("attention")
NativeAttentionOp->>_strict_fp32_math: forward_fp32() disables autocast and TF32
_strict_fp32_math-->>NativeAttentionOp: restores TF32 state
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 5 passed
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/operators/attention.md`:
- Around line 72-73: The operator contract text is out of sync with the
implementation and should be corrected in the attention docs. Update the
`forward_fp32()` contract wording so it no longer claims the original dtype is
preserved, and change the padding tolerance statement to match the shipped
`2e-6` value. Keep the wording aligned with the operator contract section and
the relevant `forward_fp32()` behavior so downstream users get the correct
guarantees.
In `@rl_engine/kernels/ops/pytorch/attention/kv_cache.py`:
- Around line 78-118: Add an explicit sequence-length validation in
`KVCacheAttention.forward` and `KVCacheAttention.forward_fp32` to reject cases
where `q` does not match the newly appended token length (`Sq != S_new`). Use
the existing `_concat_kv` flow as the location to enforce this check before
calling `_attn.forward` or `_attn.forward_fp32`, and raise an error with clear
context when the `q`/`k_new`/`v_new` shapes are misaligned.
In `@rl_engine/kernels/ops/pytorch/attention/standard_attn.py`:
- Around line 183-191: The _strict_fp32_math helper currently mutates the
process-wide torch.backends.cuda.matmul.allow_tf32 flag, which can affect
unrelated CUDA work and concurrent calls. Update the forward_fp32 path to avoid
toggling shared backend state in standard_attn.py, and instead use a safer
isolation approach within _strict_fp32_math (or an equivalent local-only
mechanism) while still disabling autocast for the true fp32 reference path.
In `@tests/test_attention.py`:
- Around line 281-285: Reformat the assert block in test_attention.py so it
matches Black’s preferred wrapping, especially the multi-line assert with the
trailing message string. Update the padding-mask check around the
masked/valid_only comparison to use Black-compatible line breaks and
parentheses, keeping the same logic but adjusting the formatting so CI no longer
rewrites it.
In `@tests/test_kv_cache_attention.py`:
- Around line 145-157: The split-point KV-cache test is using a full-length
query tensor for every split, which hides whether the suffix-only path is
correct. Update the test in the split loop so the query tensor is the suffix of
q that corresponds to each non-zero split, and keep the existing forward_fp32
comparison against the all-cache/all-new cases. Use the symbols q, k_full,
v_full, split, and op.forward_fp32 to align the test with the documented
KV-cache contract.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 527d2fb3-1020-4591-be59-ecf3322d1308
📒 Files selected for processing (9)
- .github/workflows/ci.yml
- docs/.nav.yml
- docs/operators/README.md
- docs/operators/attention.md
- rl_engine/kernels/ops/pytorch/attention/kv_cache.py
- rl_engine/kernels/ops/pytorch/attention/standard_attn.py
- rl_engine/kernels/registry.py
- tests/test_attention.py
- tests/test_kv_cache_attention.py
    def _strict_fp32_math(device_type: str):
        """Disable autocast and TF32 for a true fp32 path, restoring state after."""
        prev_tf32 = torch.backends.cuda.matmul.allow_tf32
        torch.backends.cuda.matmul.allow_tf32 = False
        try:
            with torch.autocast(device_type=device_type, enabled=False):
                yield
        finally:
            torch.backends.cuda.matmul.allow_tf32 = prev_tf32
🩺 Stability & Availability | 🟠 Major | 🏗️ Heavy lift
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Inspect where the public attention reference is reached from, and whether
# concurrent execution paths exist in the repo.
rg -nP 'get_op\("attention"\)|NativeAttentionOp\(|forward_fp32\s*\(' rl_engine tests
rg -nP '\b(threading|ThreadPoolExecutor|ProcessPoolExecutor|asyncio|DataLoader)\b' rl_engine tests
Repository: RL-Align/RL-Kernel
Length of output: 6173
Avoid mutating the global TF32 backend flag here. torch.backends.cuda.matmul.allow_tf32 is process-wide, so forward_fp32() can perturb unrelated CUDA work and becomes unsafe if called concurrently. If this stays as a public reference path, it needs a safer isolation strategy than toggling shared backend state.
…pytorch-op # Conflicts: # docs/.nav.yml # docs/operators/README.md
…e (KJLdefeated) Replace isfinite-only grad check with autograd through the independent double-precision _ref_softmax_attn under a random (seeded) cotangent. isfinite can't tell a correct gradient from a wrong-but-finite one, and attention's backward (softmax Jacobian + dQ/dK/dV contractions) is the most error-prone in the stack; .sum()'s all-ones cotangent would also collapse the contraction.
efdcce2 to 3d466fe
…eview forward_fp32 returns fp32 (not the input dtype), and the shipped padding tolerance is atol=2e-6, not 1e-6. Aligns the operator contract doc with the code (CodeRabbit).
…ference Concat cache+new K/V then delegate to NativeAttentionOp so prefill and decode share one reduction path. Register kv_cache_attention in the registry (cpu/cuda/ rocm) and add a standalone CPU-safe test suite + CI step.
3d466fe to 4e6df2c
🧹 Nitpick comments (1)
tests/test_kv_cache_attention.py (1)
351-353: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Resolve Ruff B905 on the snapshot loop.
If the repo targets Python 3.10+, make this `zip(..., strict=True)` so mismatched tensor lists cannot truncate silently and the Ruff warning goes away.
Suggested change

    - for orig, snap in zip((q, k_cache, v_cache, k_new, v_new), snapshots):
    + for orig, snap in zip((q, k_cache, v_cache, k_new, v_new), snapshots, strict=True):

Source: Linters/SAST tools
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@tests/test_kv_cache_attention.py`:
- Around line 351-353: The snapshot verification loop in
test_kv_cache_attention.py uses zip without strict checking, which triggers Ruff
B905. Update the loop that iterates over q, k_cache, v_cache, k_new, and v_new
to use zip(..., strict=True) so the tensor snapshot comparison cannot silently
truncate and the warning is resolved.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8e271f6f-2062-47ad-bec9-0183469d462d
📒 Files selected for processing (5)
- .github/workflows/ci.yml
- docs/operators/attention.md
- rl_engine/kernels/ops/pytorch/attention/kv_cache.py
- rl_engine/kernels/registry.py
- tests/test_kv_cache_attention.py
✅ Files skipped from review due to trivial changes (1)
- docs/operators/attention.md
🚧 Files skipped from review as they are similar to previous changes (3)
- rl_engine/kernels/ops/pytorch/attention/kv_cache.py
- .github/workflows/ci.yml
- rl_engine/kernels/registry.py
…tract-aligned split-point test
- kv_cache.py: validate q holds exactly the new positions (Sq == S_new) in forward/forward_fp32; a mismatch would silently use the wrong causal offset.
- test: split-point test now decodes the suffix queries and compares to the matching slice of full prefill (near-equal, atol=2e-6, not bitwise -- the score matmul's M dim differs across splits); add a misaligned-q raises test.
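The Sq == S_new guard from this commit could look roughly like the following. It is a sketch on shape tuples only, assuming a (batch, heads, seq, head_dim) layout with the sequence on dim 2 (consistent with the `cat(..., dim=2)` in the PR description); the function name and error wording are hypothetical.

```python
def check_q_matches_new(q_shape, k_new_shape, v_new_shape, seq_dim=2):
    """Raise ValueError unless q covers exactly the newly appended positions."""
    sq = q_shape[seq_dim]
    s_new_k = k_new_shape[seq_dim]
    s_new_v = v_new_shape[seq_dim]
    if s_new_k != s_new_v:
        raise ValueError(
            f"k_new and v_new disagree on sequence length: {s_new_k} vs {s_new_v}"
        )
    if sq != s_new_k:
        # A mismatch here would shift the causal offset without any error.
        raise ValueError(
            f"q has Sq={sq} but k_new/v_new append S_new={s_new_k} positions"
        )
    return sq


# Decode step: one new token for 32 query heads and 8 KV heads.
s_new = check_q_matches_new((2, 32, 1, 128), (2, 8, 1, 128), (2, 8, 1, 128))
```

Running the check before concatenation keeps the failure message in terms of the caller's inputs rather than the concatenated K/V.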
`NativeKVCacheAttnOp` delegates to the same `NativeAttentionOp` reduction added in #188, so prefill and decode share one numerical path. This branch is stacked on top of #188's `feat/ws1-attention-pytorch-op`.

Because #188 is not merged yet, the diff below currently also contains #188's attention files. For this review, please look only at the kv_cache-specific changes:
- `rl_engine/kernels/ops/pytorch/attention/kv_cache.py`
- `tests/test_kv_cache_attention.py`
- `kv_cache_attention` entries in `rl_engine/kernels/registry.py`
- `.github/workflows/ci.yml`

Once #188 merges to `main`, I'll rebase onto `main` and the diff becomes lean (attention-free).

Please do not merge this before #188.
Summary
Pure-PyTorch fp32 ground-truth reference for KV-cache (incremental/decode) attention (WS1 / ISSUE #108). It concatenates cached and new K/V along the sequence axis and then calls the exact same `NativeAttentionOp` used for full-sequence prefill; re-implementing softmax here would defeat the purpose. This guarantees decode (rollout) and prefill (training) take an identical reduction, which is what keeps rollout↔training numerically consistent.

Shapes follow Qwen3-8B (synthetic tensors, no weights downloaded): 32 Q / 8 KV heads (GQA g=4) / head_dim=128, scale = 1/√128.
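To make the shape numbers concrete: with 32 query heads and 8 KV heads, each KV head serves a group of g=4 consecutive query heads, which is the pairing a `repeat_interleave` on the head axis produces. A small sketch, with `kv_head_for_q_head` as an illustrative name rather than anything in the PR:

```python
import math


def kv_head_for_q_head(q_head, num_q_heads=32, num_kv_heads=8):
    """Map a query head index to the KV head it shares under GQA."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_q_heads={num_q_heads} is not divisible by num_kv_heads={num_kv_heads}"
        )
    if not 0 <= q_head < num_q_heads:
        raise IndexError(f"q_head={q_head} out of range")
    group = num_q_heads // num_kv_heads  # g = 4 for the 32/8 case
    return q_head // group


mapping = [kv_head_for_q_head(h) for h in range(32)]
# Default softmax scale for head_dim = 128.
default_scale = 1.0 / math.sqrt(128)
```

Query heads 0 to 3 share KV head 0, heads 4 to 7 share KV head 1, and so on up to KV head 7.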
What this adds
- `NativeKVCacheAttnOp` (`pytorch/attention/kv_cache.py`): `forward` / `forward_fp32(q, k_cache, v_cache, k_new, v_new, *, causal=True, scale=None, key_padding_mask=None)` per the frozen contract (INTERFACES §9). `cat([cache, new], dim=2)` → delegates to `NativeAttentionOp`. Returns attention output only; cache update is out of scope for the numerical contract.
- `PYTORCH_NATIVE_KV_CACHE_ATTN` enum + `kv_cache_attention` priority-map entries (cpu / cuda / rocm).
- `tests/test_kv_cache_attention.py`: 14 CPU-safe tests (+ GPU/large markers).
- `pytest tests/test_kv_cache_attention.py -k "not large and not gpu"`.

Test coverage
- `forward` / `forward_fp32` equal `NativeAttentionOp` on the concatenated K/V.
- … gives the same result.
- … `torch.equal`).
- … `atol=2e-6` (differing softmax reduction widths, same rationale as #188's padding test, with headroom over the ~1.3e-6 platform-sensitive drift).
- … `q`, GQA divisibility guard, input purity (no mutation), registry dispatch.

Local run
```
python -m pytest tests/test_kv_cache_attention.py -v -k "not large and not gpu"
14 passed
```
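The split-point property behind these tests can be illustrated without PyTorch: decoding the suffix queries against cache + new K/V should reproduce the matching rows of a full causal prefill. The sketch below is a single-head, list-based stand-in with made-up values, not the PR's test; `attend` and `dot` are hypothetical helpers.

```python
import math


def dot(a, b):
    total = 0.0
    for x, y in zip(a, b):
        total += x * y
    return total


def attend(q, k, v, scale):
    """Causal single-head attention; query row i sits at i + (len(k) - len(q))."""
    offset = len(k) - len(q)
    out = []
    for i, q_row in enumerate(q):
        visible = i + offset + 1  # number of keys this query may see
        scores = [scale * dot(q_row, k[j]) for j in range(visible)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        denom = 0.0
        for w in weights:
            denom += w
        probs = [w / denom for w in weights]
        out.append([
            dot(probs, [v[j][d] for j in range(visible)])
            for d in range(len(v[0]))
        ])
    return out


q = [[0.1, 0.2], [0.3, -0.1], [0.5, 0.4], [-0.2, 0.6]]
k = [[0.2, 0.1], [-0.3, 0.4], [0.6, 0.0], [0.1, 0.1]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0]]
scale = 1.0 / math.sqrt(2)

full = attend(q, k, v, scale)  # full-sequence prefill
split = 3
k_cache, k_new = k[:split], k[split:]
v_cache, v_new = v[:split], v[split:]
# Decode: only the suffix queries, against concatenated cache + new K/V.
decoded = attend(q[split:], k_cache + k_new, v_cache + v_new, scale)
```

In this scalar sketch each row is computed by the same sequence of operations either way. The tensor version compares with a tolerance because the matmul shapes differ across splits.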
Notes
- … `True = valid / keep`, `False = padding` (consistent with the rest of the repo).

Summary by CodeRabbit