test: add selected logprob parity harness#198
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (3)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 WalkthroughWalkthroughAdds a new testing module for selected-logprob layout parity, re-exports its helpers from the testing package, and expands the test suite with permutation, padding, invalid-input, and CUDA dtype checks. ChangesLogprob Parity Utilities and Tests
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/testing/logprob_parity.py`:
- Around line 89-90: Validate candidate_rows before using it in logprob_parity’s
indexing path: in the logic that builds restored from candidate[rows], add the
same row-id checks used by make_padded_batch_layout so each reference batch
entry maps to exactly one candidate row, with no negative or duplicate indices.
Also bounds-check the remapped rows against candidate.shape[0] before indexing,
and fail fast if the layout is invalid.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ff5ebd1f-7893-473e-a0b0-c33ea78094ba
📒 Files selected for processing (3)
rl_engine/testing/__init__.pyrl_engine/testing/logprob_parity.pytests/test_logprob_parity.py
f713279 to
7b31bf7
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@rl_engine/testing/logprob_parity.py`:
- Line 20: The `pad_token_id` used in `selected_logprobs_reference` and the
related test setup must be validated against the logits vocabulary before any
indexing occurs. Add an explicit check in the `logprob_parity` test helpers to
ensure `pad_token_id` is within the valid token range derived from the logits
shape, and fail fast with a clear assertion if it is not. Apply the same
validation wherever `pad_token_id` is passed through the parity test path so
fully masked rows never rely on an out-of-range index.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 48ce688f-85e0-4cdf-983c-d78932299446
📒 Files selected for processing (3)
rl_engine/testing/__init__.pyrl_engine/testing/logprob_parity.pytests/test_logprob_parity.py
🚧 Files skipped from review as they are similar to previous changes (1)
- rl_engine/testing/init.py
7b31bf7 to
92b9799
Compare
92b9799 to
b208efc
Compare
|
Fixed the pre-commit EOF issue and re-ran local checks:
|
|
cc @maxiaosong1124 @a-kaa PTAL |
Summary
Adds a small selected-logprob parity harness for validating that the same effective completion rows produce identical selected logprobs across batch layout changes.
This is intended as a first step toward the rollout-vs-training numerical parity work discussed in #148, #152, and #154.
What is covered
summarize_kernel_driftWhat is intentionally out of scope
Tests
Validated on AutoDL RTX 4090:
python -m pytest tests/test_logprob_parity.py -vpython -m pytest tests/test_reference_ops.py tests/test_logprob_parity.py -vResult:
4 passed11 passedSummary by CodeRabbit
float32andbfloat16, validating consistent parity in a chosen output dtype.