Skip to content

feat(ws1): NativeLMHeadOp pure-PyTorch ground-truth reference + numerical contract tests#170

Open
maxiaosong1124 wants to merge 7 commits into
RL-Align:mainfrom
maxiaosong1124:feat/ws1-lm-head-pytorch-op
Open

feat(ws1): NativeLMHeadOp pure-PyTorch ground-truth reference + numerical contract tests#170
maxiaosong1124 wants to merge 7 commits into
RL-Align:mainfrom
maxiaosong1124:feat/ws1-lm-head-pytorch-op

Conversation

@maxiaosong1124

@maxiaosong1124 maxiaosong1124 commented Jun 22, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds the pure-PyTorch ground-truth reference op for the language-model head — the
output layer of the WS1 batch-invariant forward chain — built on top of the numerical
contract defined in #108. Ships the
op, its registry wiring, docs, and a 15-case test suite that pins down both alignment
axes (Axis-A bitwise batch invariance, Axis-B per-dtype accuracy), plus a GPU-only smoke
test at the real Qwen3-8B projection dims.

Refs #108

Terminology

This PR uses the WS1 alignment vocabulary from #108:

  • Axis-A — batch invariance (reproducibility). A row's logits must not depend on how
    many rows share the batch (batch size, slicing, padding). Asserted bitwise
    (torch.equal). This is what keeps train-time (large batch) and sample-time (small
    batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
  • Axis-B — accuracy. The low-precision (bf16 / fp16) forward path. Unlike the
    lossless embedding gather, lm_head is a reduction over hidden, so low-precision
    accumulation drifts from the fp32 reference and is checked against a tolerance
    window (not bitwise).

Motivation / Context

#108 establishes the ground-truth
harness and numerical contract for the WS1 batch-invariant forward chain. The final stage
of the Qwen3-8B stack projects hidden states back to vocabulary logits:

logits = hidden @ weight.t() # weight is HF [out, in]

This PR provides the deterministic fp32 reference path that downstream kernels (Triton /
CUDA / ROCm) will be validated against. For Qwen3-8B the weight is the output projection
[vocab=151936, hidden=4096] in the HF nn.Linear [out, in] convention (transposed
internally), is independent from the embedding table (tie_word_embeddings=false),
and has no bias.

Changes

  • rl_engine/kernels/ops/pytorch/linear/lm_head.pyNativeLMHeadOp
    • forward() — project in the input dtype, output the input dtype (Axis-B path)
    • forward_fp32() — upcast to fp32, accumulate in fp32, forced fp32 output
      (ground-truth / backward golden source)
    • Formula: out = hidden @ weight.t() (+ bias)
    • Weight is HF [out, in] and transposed internally — the one difference from the bare
      matmul op; do not use interchangeably
    • Pure function — inputs never mutated in place; output dtype follows hidden
  • rl_engine/kernels/registry.py — register PYTORCH_NATIVE_LM_HEAD in OpBackend
    and add lm_head dispatch to the cuda / rocm / cpu priority maps
  • tests/test_lm_head.py — 15 tests (details below)
  • docs/operators/lm_head.md + nav / index wiring

How this satisfies the #108 contract

#108 requirement How it's met here
Deterministic reference path forward_fp32() accumulates in fp32; tests use fixed-seed torch.Generator so outputs are reproducible
Per-dtype tolerance policy (bitwise vs tight-tolerance) Axis-A asserted bitwise (torch.equal); Axis-B reduction drift checked against a per-dtype tolerance measured relative to the output peak (bf16 ~0.37%, fp16 ~0.05%)
Batch-config sweep / validation helper Batch-invariance checks compute on the full batch, then assert sliced/padded rows are bitwise identical to their full-batch counterparts, across fp32 / bf16 / fp16
Realistic shapes covered GPU-only smoke test at the real Qwen3-8B dims (vocab=151936, hidden=4096); CPU tests keep the real hidden=4096 reduction length (only vocab is shrunk) so the drift is representative; skips when CUDA / GPU memory is unavailable

Reduction-specific note (Axis-A reduction order)

A single torch.matmul is not bitwise batch-invariant by default: multi-threaded CPU
GEMM splits the hidden (K) reduction across threads by the M = batch*seq dimension, so
"compute full then slice" ≠ "compute slice" once hidden is large. The tests pin a single
thread to fix the reduction order (a local stand-in for the planned
testing/determinism.py::deterministic_context). On GPU cuBLAS likewise splits K by M,
so a batch-invariant GEMM is a downstream kernel concern — the GPU smoke test validates the
full-vocab shape and fp32 correctness, not Axis-A bitwise.

Test Environment

OS Ubuntu (kernel 5.15.0-124-generic)
Python 3.12.3
PyTorch 2.8.0+cu128
CUDA / cuDNN 12.8 / 9.10.02

Test Results

python -m pytest tests/test_lm_head.py -v
17 passed
image

The 17 tests cover:

  • fp32 correctness vs naive matmul, asserted bitwise (torch.equal); fp32 forward
    path bitwise-equal to the ground truth
  • bf16 / fp16 dtype-path accuracy — max abs error relative to output peak (bf16 ~0.37%,
    fp16 ~0.05%), with error stats printed
  • output shape (hidden.shape[:-1] + (vocab,))
  • bias semantics (None default == no bias; provided [vocab] bias added)
  • Axis-A batch invariance — slice + padding variants across fp32 / bf16 / fp16, asserted
    bitwise under a pinned single-thread reduction
  • purity (neither hidden, weight, nor bias mutated in place)
  • gradient flow to hidden / weight (fp32 autograd = backward golden source), verified
    against the closed-form grads
  • registry dispatch resolves lm_headNativeLMHeadOp
  • GPU-only real-shape smoke test (Qwen3-8B vocab=151936, hidden=4096)

Checklist

  • ✅ Pure-PyTorch reference, no custom extension required
  • ✅ Covered at the real Qwen3-8B projection dims (vocab=151936, hidden=4096)
  • ✅ Axis-A bitwise batch invariance enforced (fixed single-thread reduction order)
  • ✅ Axis-B per-dtype tolerance calibrated (relative-to-peak, stats in PR)
  • ✅ Registered in OpBackend + cuda/rocm/cpu priority maps
  • ✅ All 17 tests pass locally

Summary by CodeRabbit

  • New Features
    • Added documentation for the LM Head operator, including usage, supported shapes/dtypes, and precision behavior.
    • Registered LM Head as a supported operator so it can be selected automatically.
  • Bug Fixes
    • Improved numeric consistency for LM Head inference and reference outputs across precision modes.
  • Tests
    • Added coverage for output shape, bias handling, gradients, batch/padding invariance, and backend dispatch.

WS1 ground-truth language-model-head op for issue RL-Align#108 (Qwen3-8B output
projection, vocab=151936 x hidden=4096, tie_word_embeddings=false, no bias):
- NativeLMHeadOp: out = hidden @ weight.t() (+ bias), a reduction over
  hidden exposing the forward / forward_fp32 dual-path contract (fp32
  ground truth + dtype-behavior path); weight is HF [out, in] and
  transposed internally (the one difference from the bare matmul op);
  pure function, no in-place mutation.
- register PYTORCH_NATIVE_LM_HEAD in OpBackend and the cuda/rocm/cpu
  priority maps.
- tests/test_lm_head.py: fp32 correctness vs naive matmul (bitwise),
  bf16/fp16 dtype-path accuracy (relative-to-peak tolerance, bf16 ~0.37%
  / fp16 ~0.05% of output peak), bias semantics, Axis-A batch invariance
  (slice + padding, all dtypes) under a pinned single-thread reduction so
  the CPU GEMM K-split is M-independent, purity, closed-form gradient flow
  to hidden/weight, registry dispatch, and a GPU-only real-shape smoke
  test (vocab=151936, hidden=4096).
- docs/operators/lm_head.md + nav/index wiring.
@coderabbitai

coderabbitai Bot commented Jun 22, 2026

Copy link
Copy Markdown

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a9d52900-4867-4133-a6ff-1582f95d2e6c

📥 Commits

Reviewing files that changed from the base of the PR and between 517b12d and f9db98e.

📒 Files selected for processing (4)
  • docs/.nav.yml
  • docs/operators/README.md
  • rl_engine/kernels/registry.py
  • tests/test_lm_head.py
✅ Files skipped from review due to trivial changes (2)
  • docs/.nav.yml
  • docs/operators/README.md
🚧 Files skipped from review as they are similar to previous changes (2)
  • rl_engine/kernels/registry.py
  • tests/test_lm_head.py

📝 Walkthrough

Walkthrough

Adds a NativeLMHeadOp class implementing hidden-to-vocab logit projection via hidden @ weight.t() + bias with dual forward (input dtype) and forward_fp32 (strict fp32, autocast/TF32-disabled) paths. The operator is registered in the kernel registry for cuda/rocm/cpu platforms. A full pytest suite and operator documentation page are included.

Changes

LM Head Operator

Layer / File(s) Summary
NativeLMHeadOp implementation and registry wiring
rl_engine/kernels/ops/pytorch/linear/lm_head.py, rl_engine/kernels/registry.py
Adds NativeLMHeadOp with __call__/forward/forward_fp32 methods, the _lm_head shared static helper (dtype casting, hidden @ weight.t(), optional bias, strict fp32 context via _strict_fp32_matmul), the PYTORCH_NATIVE_LM_HEAD enum value in OpBackend, and lm_head entries in the cuda/rocm/cpu dispatch priority maps.
pytest suite
tests/test_lm_head.py
Covers fp32 bitwise correctness against naive matmul, autocast/TF32 isolation, dtype-path accuracy bounds for bf16/fp16, output shape, bias semantics, batch invariance (slicing and padding under _single_thread() CPU pinning), input purity, autograd gradient correctness, registry dispatch, and a memory-gated CUDA smoke test at Qwen3-8B dimensions.
Operator documentation and navigation
docs/operators/lm_head.md, docs/operators/README.md, docs/.nav.yml
Adds the full lm_head operator doc specifying the dual-path API contract, tensor shapes/dtypes, dispatch behavior, accuracy characterization, performance notes, and known limitations; registers the page in the nav YAML and operator README index.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested reviewers

  • inaniloquentee
  • KJLdefeated
  • Flink-ddd

Poem

🐇 A hidden state hops through the weight,
transposed and matmul'd — oh, how great!
In fp32 we trust, no TF32 drift,
each logit lands with a bitwise gift.
The registry knows just where to go —
lm_head's wired, ready to flow! 🥕

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title is specific and accurately reflects the main change: a native LM-head reference implementation plus numerical contract tests.
Docstring Coverage ✅ Passed Docstring coverage is 95.83% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
docs/.nav.yml (1)

13-17: 🧹 Nitpick | 🔵 Trivial | ⚡ Quick win

Maintain alphabetical ordering of operator entries in navigation.

The new operators/lm_head.md entry should be inserted in alphabetical order between operators/grpo-loss.md and operators/ratio-kl.md, not appended at the end. This keeps navigation consistent and easier to scan.

📖 Proposed ordering fix
   - Operators:
     - operators/README.md
     - operators/fused-logp.md
     - operators/grpo-loss.md
+    - operators/lm_head.md
     - operators/ratio-kl.md
-    - operators/sampling.md
+    - operators/sampling.md
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/.nav.yml` around lines 13 - 17, The operators list in the navigation
file has an alphabetically misplaced entry. Move the operators/lm_head.md entry
to its correct alphabetical position between operators/grpo-loss.md and
operators/ratio-kl.md, as it should come after "grpo-loss" and before "ratio-kl"
when entries are sorted alphabetically. Remove it from its current position at
the end of the operators list and insert it in the proper alphabetical order to
maintain consistency and readability of the navigation structure.
docs/operators/README.md (1)

21-26: 🧹 Nitpick | 🔵 Trivial | ⚡ Quick win

Maintain alphabetical ordering of operator index entries.

The new [LM Head] entry should be inserted in alphabetical order between [GRPO Loss] and [Policy Ratio + KL Penalty], not after [Sampling]. This keeps the index consistent and easier to navigate.

📖 Proposed ordering fix
 - [Fused LogP](fused-logp.md)
 - [GRPO Loss](grpo-loss.md)
+- [LM Head](lm_head.md)
 - [Policy Ratio + KL Penalty](ratio-kl.md)
 - [Sampling](sampling.md)
-- [LM Head](lm_head.md)
 - [Operator Doc Template](../contributing/operator-doc-template.md)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/operators/README.md` around lines 21 - 26, Reorder the operator index
entries in the README.md file to maintain alphabetical ordering. Move the [LM
Head] link from its current position after [Sampling] to its correct
alphabetical position between [GRPO Loss] and [Policy Ratio + KL Penalty].
Ensure all entries in the list are ordered alphabetically by their display names
to keep the index consistent and easy to navigate.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/test_lm_head.py`:
- Around line 196-200: The _enough_gpu_memory function can fail test collection
when torch.cuda.mem_get_info() raises a RuntimeError in partially configured
CUDA environments. Wrap the torch.cuda.mem_get_info() call in a try-except block
that catches RuntimeError and returns False when the error is caught, allowing
tests to be skipped gracefully instead of failing during collection.
- Around line 86-87: The test functions test_native_lm_head_dtype_path_accuracy,
and the two other similar test functions at lines 130 and 141 unconditionally
parametrize with torch.float16, which can cause failures on CPU hardware where
half-precision matmul is not guaranteed to be supported. Extract the dtype
tuples used in the parametrize decorators into module-level constants, then
replace the direct parametrization with pytest.param calls that include
conditional runtime checks to skip torch.float16 on CPU backends. Apply this
pattern consistently across all three affected test functions to prevent
backend-dependent test failures.

---

Nitpick comments:
In `@docs/.nav.yml`:
- Around line 13-17: The operators list in the navigation file has an
alphabetically misplaced entry. Move the operators/lm_head.md entry to its
correct alphabetical position between operators/grpo-loss.md and
operators/ratio-kl.md, as it should come after "grpo-loss" and before "ratio-kl"
when entries are sorted alphabetically. Remove it from its current position at
the end of the operators list and insert it in the proper alphabetical order to
maintain consistency and readability of the navigation structure.

In `@docs/operators/README.md`:
- Around line 21-26: Reorder the operator index entries in the README.md file to
maintain alphabetical ordering. Move the [LM Head] link from its current
position after [Sampling] to its correct alphabetical position between [GRPO
Loss] and [Policy Ratio + KL Penalty]. Ensure all entries in the list are
ordered alphabetically by their display names to keep the index consistent and
easy to navigate.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 18702bfa-053c-4402-b236-414de2b74d14

📥 Commits

Reviewing files that changed from the base of the PR and between cd0ca43 and 5ae1c4b.

📒 Files selected for processing (7)
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/lm_head.md
  • rl_engine/kernels/ops/pytorch/linear/__init__.py
  • rl_engine/kernels/ops/pytorch/linear/lm_head.py
  • rl_engine/kernels/registry.py
  • tests/test_lm_head.py

Comment thread tests/test_lm_head.py Outdated
Comment thread tests/test_lm_head.py
- Gate CPU float16 matmul parametrizations behind a runtime support
  probe so unsupported backends skip rather than fail collection.
- Harden _enough_gpu_memory against RuntimeError from mem_get_info in
  partially-configured CUDA environments.
- Add docstrings across the op and test suite to meet coverage.
- Sort lm_head entries alphabetically in operator nav/README.
Keep consistent with other PyTorch native ops.
output_dtype: torch.dtype,
) -> torch.Tensor:
"""Core matmul: cast to ``compute_dtype``, project, optionally add bias, cast out."""
h = hidden.to(compute_dtype)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One precision nit: forward_fp32() casts the inputs to fp32, but the matmul can still run under autocast or CUDA TF32 settings, so it may not be a true fp32 golden reference. Since downstream kernels will compare against this path, could we explicitly disable autocast/TF32 around the matmul, or document that required precision context?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @inaniloquentee .

forward_fp32() now disables autocast and CUDA TF32 around the matmul, while saving/restoring the previous TF32 setting so global state does not leak. The regular forward() path is unchanged and still
follows the ambient precision context.

I also added regression coverage for this:

  • CPU autocast case: forward_fp32 remains equal to the fp32 reference and restores the TF32 flag.
  • CUDA TF32 case: forward_fp32 is checked against a higher-precision reference when CUDA is available.

Docs were updated to note the precision-context behavior.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @inaniloquentee .

forward_fp32() now disables autocast and CUDA TF32 around the matmul, while saving/restoring the previous TF32 setting so global state does not leak. The regular forward() path is unchanged and still follows the ambient precision context.

I also added regression coverage for this:

  • CPU autocast case: forward_fp32 remains equal to the fp32 reference and restores the TF32 flag.
  • CUDA TF32 case: forward_fp32 is checked against a higher-precision reference when CUDA is available.

Docs were updated to note the precision-context behavior.

LGTM!

Wrap the forward_fp32 matmul to disable autocast and CUDA TF32 (saving and
restoring the global allow_tf32 flag) so the fp32 golden path is not silently
downcast by the caller's ambient precision context. The dtype-behavior forward
path is left to follow ambient precision intentionally.

Add tests: forward_fp32 stays true fp32 under ambient autocast and restores
the TF32 flag (CPU); numerically beats a TF32 matmul (GPU). Pin TF32 off in the
fp32-vs-naive bitwise test. Sync docs accordingly.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
docs/operators/lm_head.md (1)

108-115: 🧹 Nitpick | 🔵 Trivial | ⚡ Quick win

Break the test coverage list into separate sentences for readability.

The test coverage section is accurate and comprehensive, but lines 108–115 condense nine distinct test concerns into a single long sentence that impairs readability. Consider splitting into 2–3 sentences by major category (e.g., "Covers: [precision & dtype behavior]...." then "Also covers: [batch invariance & purity]..." then "GPU-only smoke test...").

📝 Suggested restructuring
- Covers: fp32 correctness vs naive matmul (bitwise, with ambient TF32 pinned off),
- `forward_fp32` precision-context safety (true fp32 under ambient autocast + restores the
- global TF32 flag on CPU; numerically beats a TF32 matmul on GPU), bf16/fp16 dtype-path
- accuracy (relative-to-peak tolerance, with `bias`), output shape, bias semantics, Axis-A
- batch invariance (slice + padding, single-thread reduction, all dtypes), input purity,
- gradient flow to `hidden`/`weight` (closed-form check), registry dispatch, and a GPU-only
- smoke test at the real Qwen3-8B dims (`vocab=151936, hidden=4096`) that skips when CUDA or
- GPU memory is unavailable.
+ Covers: fp32 correctness vs naive matmul (bitwise, with ambient TF32 pinned off) and
+ `forward_fp32` precision-context safety (true fp32 under ambient autocast, restores the
+ global TF32 flag on CPU, numerically beats a TF32 matmul on GPU).
+
+ Also covers: bf16/fp16 dtype-path accuracy (relative-to-peak tolerance, with `bias`),
+ output shape, bias semantics, Axis-A batch invariance (slice + padding, single-thread
+ reduction, all dtypes), input purity, and gradient flow to `hidden`/`weight` (closed-form).
+
+ Registry dispatch and a GPU-only smoke test at the real Qwen3-8B dims (`vocab=151936,
+ hidden=4096`) round out coverage; the smoke test skips when CUDA or GPU memory is unavailable.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/operators/lm_head.md` around lines 108 - 115, The test coverage
description in the lm_head.md documentation is a single long sentence that lists
nine distinct test concerns, making it difficult to read and parse. Break this
single sentence into 2-3 shorter sentences organized by major category: group
precision and dtype-related tests together (fp32 correctness, forward_fp32
safety, bf16/fp16 accuracy, output shape and bias semantics), then create a
second sentence for batch invariance and purity tests (Axis-A batch invariance,
input purity, gradient flow), and finally add a third sentence for the GPU-only
smoke test at Qwen3-8B dimensions. This restructuring will improve readability
while maintaining all the technical details.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@docs/operators/lm_head.md`:
- Around line 108-115: The test coverage description in the lm_head.md
documentation is a single long sentence that lists nine distinct test concerns,
making it difficult to read and parse. Break this single sentence into 2-3
shorter sentences organized by major category: group precision and dtype-related
tests together (fp32 correctness, forward_fp32 safety, bf16/fp16 accuracy,
output shape and bias semantics), then create a second sentence for batch
invariance and purity tests (Axis-A batch invariance, input purity, gradient
flow), and finally add a third sentence for the GPU-only smoke test at Qwen3-8B
dimensions. This restructuring will improve readability while maintaining all
the technical details.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 75685450-1855-40ee-b44d-c14e47328e11

📥 Commits

Reviewing files that changed from the base of the PR and between a3fb370 and 517b12d.

📒 Files selected for processing (3)
  • docs/operators/lm_head.md
  • rl_engine/kernels/ops/pytorch/linear/lm_head.py
  • tests/test_lm_head.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • rl_engine/kernels/ops/pytorch/linear/lm_head.py
  • tests/test_lm_head.py

@KJLdefeated KJLdefeated left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good and clean. I am happy to approve once the requests are resolved.

Comment thread tests/test_lm_head.py Outdated
Comment on lines +263 to +276
# hidden over (batch, seq) per vocab row.
def test_lm_head_gradient_flows():
"""fp32 autograd matches the closed-form grads of out.sum()."""
op = NativeLMHeadOp()
hidden = _rand_hidden(2, 4, seed=8).requires_grad_(True)
weight = _rand_weight(seed=8).requires_grad_(True)
op.forward_fp32(hidden, weight).sum().backward()

assert torch.isfinite(hidden.grad).all() and torch.isfinite(weight.grad).all()
assert hidden.grad.shape == hidden.shape and weight.grad.shape == weight.shape
exp_h = weight.detach().sum(dim=0).expand_as(hidden.grad)
exp_w = hidden.detach().sum(dim=(0, 1)).expand_as(weight.grad)
assert torch.allclose(hidden.grad, exp_h, atol=1e-4, rtol=1e-4)
assert torch.allclose(weight.grad, exp_w, atol=1e-4, rtol=1e-4)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Closed-form check is correct and a real improvement over finiteness, but the out.sum() cotangent is the least discriminating one possible: it makes the upstream gradient all-ones, so both expected grads collapse to column sums (every row of hidden.grad identical, every row of weight.grad identical). A backward that transposed or mis-contracted something can still pass under that symmetry, and the real backward never sees an all-ones cotangent.
Recommend backpropping a random cotangent so the full contraction is exercised:

dy = torch.randn_like(out)
out.backward(dy)
exp_h = dy @ weight                                   # [.., V] @ [V, K] -> [.., K]
exp_w = dy.reshape(-1, V).t() @ hidden.reshape(-1, K) # [V, K]
torch.testing.assert_close(hidden.grad, exp_h, rtol=1e-5, atol=1e-5)
torch.testing.assert_close(weight.grad, exp_w, rtol=1e-5, atol=1e-5)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@KJLdefeated Fixed in 407200a — switched to a random (fixed-seed) cotangent and assert against the exact closed forms (dy @ weight / dy^T @ hidden), so the full contraction is exercised instead of collapsing to column sums. Seeded for reproducibility; all 17 tests pass at 1e-5. Also rebased on latest main and got CI linting back to green.

…torch-op

# Conflicts:
#	rl_engine/kernels/registry.py
… check (KJLdefeated)

Backprop a random cotangent instead of out.sum() so the full hidden@weight.t()
contraction is exercised; an all-ones cotangent collapses both grads to column
sums and could mask a transposed/mis-contracted backward. Also wrap two
over-length lines flagged by flake8 (pre-existing linting failure).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants