test(distillation): add value-level CPU tests for forward_kl_topk OPD loss #6822

Draft
HaozheZhang6 wants to merge 1 commit into verl-project:main from HaozheZhang6:test/forward-kl-topk-loss-values
@HaozheZhang6

Contributor

What does this PR do?

Adds value-level CPU unit tests for compute_forward_kl_topk (FSDP backend). The existing CPU test (test_distillation_topk_symmetry_on_cpu) only asserts the two diagnostic outputs (overlap_count / overlap_token_advantage); the load-bearing outputs of the GKD-OPD loss had no value-level coverage:

  • distillation_losses (the objective), student_mass, teacher_mass -- pinned against a from-scratch forward-KL reference over the teacher top-k support.
  • the log_prob_min_clamp branch (on by default at -10.0) -- asserted to match a hand-clamped reference and to materially change the loss.
  • the use_chunked_topk path -- asserted to match the default path on all outputs.
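For orientation, a from-scratch reference of the kind described in the list above might look like the sketch below. It is written in plain Python rather than PyTorch so that it runs anywhere, and the names `log_softmax` and `forward_kl_topk_reference` are hypothetical. The exact definition used by compute_forward_kl_topk is not shown in this PR, so summing over the top-k support without renormalizing, and applying the clamp to the student log-probs only, are both assumptions.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def forward_kl_topk_reference(student_logits, teacher_logits, k,
                              log_prob_min_clamp=None):
    """Forward KL(teacher || student) restricted to the teacher's top-k tokens.

    Assumptions (not confirmed by the PR): no renormalization over the
    top-k support, and the optional clamp floors the student log-probs only.
    """
    s_logp = log_softmax(student_logits)
    t_logp = log_softmax(teacher_logits)

    # Indices of the k most likely tokens under the teacher.
    topk = sorted(range(len(t_logp)), key=lambda i: t_logp[i], reverse=True)[:k]

    # Probability mass each model places on the teacher's top-k support,
    # measured before any clamping.
    teacher_mass = sum(math.exp(t_logp[i]) for i in topk)
    student_mass = sum(math.exp(s_logp[i]) for i in topk)

    loss = 0.0
    for i in topk:
        s = s_logp[i]
        if log_prob_min_clamp is not None:
            s = max(s, log_prob_min_clamp)  # floor very negative log-probs
        loss += math.exp(t_logp[i]) * (t_logp[i] - s)
    return loss, student_mass, teacher_mass


# A student that nearly zeroes out the teacher's favourite token:
# the clamp at -10.0 caps that token's contribution to the loss.
unclamped, _, _ = forward_kl_topk_reference([0.0, 0.0, -50.0], [0.0, 0.0, 5.0], k=2)
clamped, _, _ = forward_kl_topk_reference(
    [0.0, 0.0, -50.0], [0.0, 0.0, 5.0], k=2, log_prob_min_clamp=-10.0
)
print(unclamped > clamped)
```

With identical student and teacher logits the loss is exactly zero and the two masses coincide, which makes a convenient sanity check before comparing against the real function.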

A sign regression in kl_divergence, or a broken clamp/chunk branch, would pass CI today; these tests catch it (verified by flipping the kl_divergence sign locally -- the two value tests fail, the parity test correctly still passes).
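The reasoning in the paragraph above can be illustrated without the real code. A parity check compares two paths that share the same divergence helper, so a sign flip moves both in lockstep and the check stays green, while a comparison against an independent reference fails. The sketch below uses hypothetical stand-ins (`kl_terms`, `default_path`, `chunked_path`) and a `sign` parameter to model the regression; none of these names come from the project.

```python
import math


def kl_terms(p_teacher, p_student, sign=1.0):
    """Sum of p_t * (log p_t - log p_s); sign=-1.0 models the regression."""
    return sign * sum(
        pt * (math.log(pt) - math.log(ps))
        for pt, ps in zip(p_teacher, p_student)
    )


def default_path(p_t, p_s, sign=1.0):
    return kl_terms(p_t, p_s, sign)


def chunked_path(p_t, p_s, sign=1.0, chunk_size=2):
    # Same math, evaluated chunk by chunk and summed.
    return sum(
        kl_terms(p_t[i:i + chunk_size], p_s[i:i + chunk_size], sign)
        for i in range(0, len(p_t), chunk_size)
    )


p_t = [0.5, 0.3, 0.2]
p_s = [0.2, 0.3, 0.5]
reference = kl_terms(p_t, p_s)  # independent, known-good value

for sign in (1.0, -1.0):
    value_ok = math.isclose(default_path(p_t, p_s, sign), reference)
    parity_ok = math.isclose(
        default_path(p_t, p_s, sign), chunked_path(p_t, p_s, sign)
    )
    print(f"sign={sign:+.0f} value_test={value_ok} parity_test={parity_ok}")
```

Only the value comparison changes its verdict when the sign is flipped, which is the behaviour the PR description reports for the real tests.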

Test

pytest tests/workers/test_forward_kl_topk_loss_values_on_cpu.py -> 3 passed. CPU-only, tiny tensors, pure PyTorch -- no GPU.

… loss

The existing CPU coverage only asserts the overlap diagnostics; the loss
value, student/teacher mass, the default-on log_prob_min_clamp branch, and
the use_chunked_topk path had no value-level assertions. Pin them against a
from-scratch forward-KL reference on tiny CPU tensors.

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds value-level CPU tests for the compute_forward_kl_topk function to validate the loss math, clamp logic, and chunked top-k path against a reference implementation. The feedback suggests reducing the default chunked_topk_chunk_size in the test configuration helper from 4096 to a smaller value (e.g., 2) to ensure that the chunking loop is actually executed across multiple iterations during testing.


Comment on lines +33 to +40
def _cfg(log_prob_min_clamp=None, use_chunked_topk=False):
    return SimpleNamespace(
        distillation_loss=SimpleNamespace(
            log_prob_min_clamp=log_prob_min_clamp,
            use_chunked_topk=use_chunked_topk,
            chunked_topk_chunk_size=4096,
        )
    )


high

The current _cfg helper sets chunked_topk_chunk_size=4096 by default. In test_chunked_topk_matches_default_path, the sequence length is 6, which is much smaller than 4096. As a result, the chunking loop in _chunked_topk_log_probs only executes a single iteration, leaving the multi-chunk boundary logic and loop state transitions completely untested.

By changing the default chunked_topk_chunk_size to a small value (e.g., 2), the test with seqlen=6 will naturally partition the input into multiple chunks, thereby thoroughly exercising the chunking loop and ensuring its correctness.

Suggested change
-def _cfg(log_prob_min_clamp=None, use_chunked_topk=False):
-    return SimpleNamespace(
-        distillation_loss=SimpleNamespace(
-            log_prob_min_clamp=log_prob_min_clamp,
-            use_chunked_topk=use_chunked_topk,
-            chunked_topk_chunk_size=4096,
-        )
-    )
+def _cfg(log_prob_min_clamp=None, use_chunked_topk=False, chunked_topk_chunk_size=2):
+    return SimpleNamespace(
+        distillation_loss=SimpleNamespace(
+            log_prob_min_clamp=log_prob_min_clamp,
+            use_chunked_topk=use_chunked_topk,
+            chunked_topk_chunk_size=chunked_topk_chunk_size,
+        )
+    )
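To see why the chunk size matters for coverage, here is a minimal sketch of a loop that walks the sequence axis in chunks. It is not the project's `_chunked_topk_log_probs`; the function names, the per-position work, and the iteration counter are all hypothetical and exist only to make the number of loop iterations visible.

```python
def chunked_map(rows, chunk_size, fn):
    """Apply fn to rows chunk by chunk; return the results and the chunk count."""
    out, n_chunks = [], 0
    for start in range(0, len(rows), chunk_size):
        chunk = rows[start:start + chunk_size]
        out.extend(fn(r) for r in chunk)
        n_chunks += 1
    return out, n_chunks


def top1(row):
    """Stand-in for per-position top-k work: index of the largest logit."""
    return max(range(len(row)), key=row.__getitem__)


# Six positions, matching the seqlen of 6 mentioned in the review comment.
seq = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4], [0.5, 0.55], [0.9, 0.1]]

full_large, iters_large = chunked_map(seq, 4096, top1)
full_small, iters_small = chunked_map(seq, 2, top1)
print(iters_large, iters_small)
```

With a chunk size of 4096 the loop body runs once, so any bug in how results are stitched across chunks cannot surface. A chunk size of 2 gives three iterations; a size that does not divide the sequence length, such as 4, would additionally exercise a ragged final chunk.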
