Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/workers/test_distillation_config_opsd_on_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CPU tests for the OPSD fields on ``DistillationConfig``."""

from verl.workers.config.distillation import DistillationConfig


def test_opsd_defaults_are_off_and_backward_compatible():
    """A default-constructed ``DistillationConfig`` leaves OPSD switched off.

    Also pins the default solution key and checks that both marker strings
    still contain visible text once surrounding whitespace is stripped.
    """
    config = DistillationConfig()

    assert config.self_distillation is False
    assert config.privileged_solution_key == "reward_model.ground_truth"

    # The markers are meant to set the solution off from the prompt, so each
    # must be more than pure whitespace.
    assert config.privileged_prefix.strip()
    assert config.privileged_suffix.strip()


def test_self_distillation_requires_enabled():
    """Requesting ``self_distillation`` on its own must raise ``ValueError``."""
    from pytest import raises

    # Only ``self_distillation`` is passed; ``enabled`` keeps its default
    # (False), which is the combination the config is expected to reject.
    with raises(ValueError):
        DistillationConfig(self_distillation=True)
160 changes: 160 additions & 0 deletions tests/workers/test_opsd_privileged_context_on_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CPU tests for the OPSD privileged-context helpers."""

import numpy as np
import pytest
import torch

from verl.trainer.distillation.privileged_context import (
build_privileged_sequence,
resolve_privileged_solution,
slice_privileged_teacher_to_student,
)


def test_resolve_privileged_solution_nested_dotted_key():
    """A dotted key reaches into a nested dict; an undotted key is a flat lookup."""
    nested_kwargs = {"reward_model": {"ground_truth": "72"}}
    nested_result = resolve_privileged_solution(nested_kwargs, "reward_model.ground_truth")
    assert nested_result == "72"

    # A key without dots must keep working as a plain top-level lookup.
    flat_result = resolve_privileged_solution({"sol": "x"}, "sol")
    assert flat_result == "x"


def test_resolve_privileged_solution_missing_returns_none():
    """Absent paths, a ``None`` container and an empty string all resolve to ``None``."""
    unresolvable_cases = (
        ({"a": 1}, "reward_model.ground_truth"),  # path not present
        (None, "x"),  # no kwargs at all
        ({"g": ""}, "g"),  # present but empty
    )
    for sample_kwargs, key in unresolvable_cases:
        assert resolve_privileged_solution(sample_kwargs, key) is None


def test_resolve_privileged_solution_normalizes_scalar_array_list():
    """0-d arrays, one-element arrays and lists are all normalized to a string."""
    normalization_cases = (
        (np.array("72"), "72"),  # 0-d array -> its item
        (np.array(["x"]), "x"),  # single-element array
        (["a", "b"], "a\nb"),  # list -> newline-joined
    )
    for raw_value, expected in normalization_cases:
        assert resolve_privileged_solution({"g": raw_value}, "g") == expected


def test_build_privileged_sequence_layout():
    """Default layout is prompt, prefix, solution, suffix, then the response."""
    built = build_privileged_sequence(
        prompt_ids=[1, 2],
        response_ids=[9, 10],
        solution_ids=[5, 6, 7],
        prefix_ids=[100],
        suffix_ids=[200, 201],
    )

    expected_layout = [1, 2] + [100] + [5, 6, 7] + [200, 201] + [9, 10]
    assert built == expected_layout
    # The response stays at the very end, just as it would in a plain
    # prompt+response teacher input.
    assert built[-2:] == [9, 10]


def test_build_privileged_sequence_empty_markers():
    """With empty prefix and suffix the solution sits directly between prompt and response."""
    built = build_privileged_sequence([1], [9], [5], [], [])
    assert built == [1, 5, 9]


def test_build_privileged_sequence_insert_before_marker():
    """The solution block is placed before the LAST occurrence of the marker.

    The marker ``[8]`` (standing in for e.g. an assistant-turn opener) occurs
    twice in the prompt. The wrapped solution must land ahead of the second
    occurrence, and the response must follow the original prompt tail.
    """
    built = build_privileged_sequence(
        prompt_ids=[1, 8, 2, 8],
        response_ids=[9],
        solution_ids=[5],
        prefix_ids=[100],
        suffix_ids=[200],
        insert_before_token_ids=[8],
    )

    expected_layout = [1, 8, 2, 100, 5, 200, 8, 9]
    assert built == expected_layout


def test_build_privileged_sequence_insert_marker_not_found_falls_back():
    """A marker absent from the prompt falls back to the append layout."""
    built = build_privileged_sequence(
        [1, 2],
        [9],
        [5],
        [100],
        [200],
        insert_before_token_ids=[42],
    )
    # Token 42 never appears in the prompt, so the block is appended after it.
    assert built == [1, 2, 100, 5, 200, 9]


def test_slice_privileged_teacher_keeps_response_pads_prompt():
    """The slice keeps the teacher's trailing response rows and pads the prompt rows."""
    priv_len, k = 10, 2
    prompt_rows = 2
    response_rows = 2
    source_ids = torch.arange(priv_len * k).reshape(priv_len, k)
    source_logprobs = torch.randn(priv_len, k)

    ids, logprobs = slice_privileged_teacher_to_student(
        source_ids,
        source_logprobs,
        student_prompt_length=prompt_rows,
        response_length=response_rows,
        pad_token_id=0,
    )

    # Output covers the student's prompt + response rows only.
    expected_shape = (prompt_rows + response_rows, k)
    assert ids.shape == expected_shape
    assert logprobs.shape == expected_shape

    # Response rows are the teacher's last rows, untouched.
    assert torch.equal(ids[-response_rows:], source_ids[-response_rows:])
    assert torch.equal(logprobs[-response_rows:], source_logprobs[-response_rows:])

    # Prompt rows carry padding: the pad id for ids, 0.0 for log-probs.
    assert torch.all(ids[:prompt_rows] == 0)
    assert torch.all(logprobs[:prompt_rows] == 0.0)


def test_slice_response_length_zero_is_empty_not_whole_tensor():
    """Regression: an empty response must yield only the padded prompt rows.

    A naive ``teacher_ids[-0:]`` slice returns the WHOLE tensor rather than an
    empty one. With ``response_length=0`` the result must therefore have
    exactly ``student_prompt_length`` rows, all of them padding, for both the
    ids and the log-probs.
    """
    teacher_ids = torch.arange(12).reshape(6, 2)
    teacher_logprobs = torch.randn(6, 2)
    ids, logprobs = slice_privileged_teacher_to_student(
        teacher_ids, teacher_logprobs, student_prompt_length=3, response_length=0, pad_token_id=0
    )
    assert ids.shape == (3, 2) and logprobs.shape == (3, 2)
    assert torch.all(ids == 0)
    # The log-prob rows must be padding too; checking only the ids would let
    # leaked teacher scores from the ``[-0:]`` bug go unnoticed.
    assert torch.all(logprobs == 0.0)


def test_slice_pad_token_id_none_falls_back_to_zero():
    """With ``pad_token_id=None`` the padded prompt row is filled with 0."""
    source_ids = torch.arange(6).reshape(3, 2)
    source_logprobs = torch.randn(3, 2)

    sliced_ids, _ = slice_privileged_teacher_to_student(
        source_ids,
        source_logprobs,
        student_prompt_length=1,
        response_length=2,
        pad_token_id=None,
    )

    # Row 0 is the single padded prompt row.
    first_pad_value = sliced_ids[0, 0].item()
    assert first_pad_value == 0


def test_slice_preserves_dtype_and_pad_value():
    """Output dtypes match the inputs and the explicit pad id is used for prompt rows."""
    pad_id = 7
    source_ids = torch.arange(6).reshape(3, 2).long()
    source_logprobs = torch.randn(3, 2, dtype=torch.float32)

    ids, logprobs = slice_privileged_teacher_to_student(
        source_ids,
        source_logprobs,
        student_prompt_length=1,
        response_length=2,
        pad_token_id=pad_id,
    )

    assert ids.dtype == torch.long
    assert logprobs.dtype == torch.float32
    # Row 0 is the padded prompt row, so it must hold the requested pad id.
    assert ids[0, 0].item() == pad_id


def test_slice_negative_response_length_raises():
    """A negative ``response_length`` is rejected with ``ValueError``."""
    source_ids = torch.arange(6).reshape(3, 2)
    source_logprobs = torch.randn(3, 2)

    # Positional arguments: student_prompt_length=1, response_length=-1, pad_token_id=0.
    with pytest.raises(ValueError):
        slice_privileged_teacher_to_student(source_ids, source_logprobs, 1, -1, 0)


def test_privileged_slice_aligns_response_rows_after_padding():
    """Check end-to-end row alignment of the privileged slice after batch padding.

    After slicing, the rows are left-padded up to the batch prompt width and
    right-padded up to the batch response width -- the same left/right pad that
    ``_pad_teacher_outputs`` applies. Response token ``j`` must then sit at
    absolute row ``prompt_width + j`` and equal the privileged teacher's row
    for that token. The pad is reproduced inline so the heavy teacher_manager
    package does not have to be imported.
    """
    from torch.nn.functional import pad

    top_k = 2
    prompt_len, block_len, resp_len = 3, 4, 2
    # Batch widths; each is >= the corresponding per-sample length.
    prompt_width, response_width = 5, 4

    # The teacher scores prompt + privileged block + response.
    total_rows = prompt_len + block_len + resp_len
    teacher_ids = torch.arange(total_rows * top_k).reshape(total_rows, top_k)
    teacher_logprobs = torch.randn(total_rows, top_k)

    sliced_ids, _ = slice_privileged_teacher_to_student(
        teacher_ids, teacher_logprobs, prompt_len, resp_len, pad_token_id=0
    )

    # Pad only along the row dimension: left for the prompt, right for the response.
    left_pad = prompt_width - prompt_len
    right_pad = response_width - resp_len
    padded = pad(sliced_ids, (0, 0, left_pad, right_pad), value=0)
    assert padded.shape[0] == prompt_width + response_width

    # In the teacher's own sequence the response starts after prompt + block.
    response_start = prompt_len + block_len
    for offset in range(resp_len):
        assert torch.equal(padded[prompt_width + offset], teacher_ids[response_start + offset])
47 changes: 46 additions & 1 deletion verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@
from verl.protocol import DataProto
from verl.tools.tool_registry import load_all_tools
from verl.trainer.distillation import is_distillation_enabled
from verl.trainer.distillation.privileged_context import (
build_privileged_sequence,
resolve_privileged_solution,
slice_privileged_teacher_to_student,
)
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.dataset.rl_dataset import RLHFDataset, get_dataset_class
from verl.utils.model import compute_position_id_with_mask
Expand Down Expand Up @@ -515,6 +520,22 @@ def __init__(
from verl.experimental.teacher_loop.teacher_manager import AsyncTeacherLLMServerManager

self.teacher_key: str = config.distillation.teacher_key
self.self_distillation: bool = config.distillation.self_distillation
self.privileged_solution_key: str = config.distillation.privileged_solution_key
self._privileged_prefix_ids: list[int] = []
self._privileged_suffix_ids: list[int] = []
self._privileged_insert_before_ids: Optional[list[int]] = None
if self.self_distillation:
self._privileged_prefix_ids = self.tokenizer.encode(
config.distillation.privileged_prefix, add_special_tokens=False
)
self._privileged_suffix_ids = self.tokenizer.encode(
config.distillation.privileged_suffix, add_special_tokens=False
)
if config.distillation.privileged_insert_before:
self._privileged_insert_before_ids = self.tokenizer.encode(
config.distillation.privileged_insert_before, add_special_tokens=False
)
self.teacher_server_manager = AsyncTeacherLLMServerManager(
config=config,
teacher_client=teacher_client,
Expand Down Expand Up @@ -1013,12 +1034,36 @@ async def _compute_teacher_logprobs(
if routing_value is not None:
# Non-tensor batch values arrive as 0-d numpy objects / arrays; normalize to Python.
routing_key = routing_value.item() if hasattr(routing_value, "item") else routing_value
if self.self_distillation:
solution = resolve_privileged_solution(sample_kwargs, self.privileged_solution_key)
if not solution:
raise ValueError(
f"self_distillation is enabled but no privileged solution resolved at "
f"'{self.privileged_solution_key}'; verl's ground truth is usually at "
f"'reward_model.ground_truth'."
)
solution_ids = self.tokenizer.encode(solution, add_special_tokens=False)
sequence_ids = build_privileged_sequence(
prompt_ids,
response_ids,
solution_ids,
self._privileged_prefix_ids,
self._privileged_suffix_ids,
self._privileged_insert_before_ids,
)
else:
sequence_ids = prompt_ids + response_ids
teacher_ids, teacher_logprobs = await self.teacher_server_manager.compute_teacher_logprobs_single(
sequence_ids=prompt_ids + response_ids,
sequence_ids=sequence_ids,
multi_modal_data=output.multi_modal_data,
mm_processor_kwargs=output.mm_processor_kwargs,
routing_key=routing_key,
)
if self.self_distillation:
# Realign the teacher's privileged-context scores onto the student's positions.
teacher_ids, teacher_logprobs = slice_privileged_teacher_to_student(
teacher_ids, teacher_logprobs, len(prompt_ids), len(response_ids), self.tokenizer.pad_token_id
)
Comment on lines +1064 to +1066

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

If self.tokenizer.pad_token_id is None (which is common for several tokenizers like LLaMA unless explicitly configured), passing it directly to slice_privileged_teacher_to_student will cause a TypeError inside torch.full. We should fall back to eos_token_id or 0 to prevent runtime crashes.

Suggested change
teacher_ids, teacher_logprobs = slice_privileged_teacher_to_student(
teacher_ids, teacher_logprobs, len(prompt_ids), len(response_ids), self.tokenizer.pad_token_id
)
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else (self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else 0)
teacher_ids, teacher_logprobs = slice_privileged_teacher_to_student(
teacher_ids, teacher_logprobs, len(prompt_ids), len(response_ids), pad_token_id
)

output.extra_fields["teacher_ids"] = teacher_ids
output.extra_fields["teacher_logprobs"] = teacher_logprobs

Expand Down
7 changes: 7 additions & 0 deletions verl/experimental/teacher_loop/teacher_manager.py

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally, filtering was based only on the length of the student's prompt, which is why this error can occur. Would it be better to define a separate filtering parameter in rl_dataset?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point -- the overflow is because the privileged sequence is longer than the student prompt the dataset filtered on. I kept the teacher-side raise as a loud safety net for now, but a dataset-level filter (capping prompt+solution upfront, like filter_overlong_prompts) is cleaner and I think fits better as a follow-up PR. Open to either.

Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ async def compute_teacher_logprobs_single(
teacher_key = self._resolve_teacher_key(routing_key)
teacher_model_config = self.teacher_model_configs[teacher_key]
client = self.teacher_client[teacher_key]
max_model_len = teacher_model_config.inference.max_model_len
if max_model_len is not None and len(sequence_ids) + 1 > max_model_len:
raise ValueError(
f"Teacher input ({len(sequence_ids)} tokens, +1 to score) exceeds the teacher's "
f"max_model_len ({max_model_len}); with OPSD privileged context this usually means "
f"the reference solution is too long -- raise max_model_len or shorten the solution."
)
teacher_output = await client.generate(
request_id=uuid4().hex,
prompt_ids=sequence_ids,
Expand Down
15 changes: 15 additions & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,21 @@ distillation:
response_length: ${oc.select:actor_rollout_ref.rollout.response_length}
temperature: ${oc.select:actor_rollout_ref.rollout.temperature}
teacher_key: data_source
self_distillation: false
privileged_solution_key: reward_model.ground_truth
privileged_prefix: '


Reference solution:

'
privileged_suffix: '


Using this as a reference, derive the answer yourself. Think step by step.

'
privileged_insert_before: ''
trainer:
balance_batch: true
total_epochs: 30
Expand Down
15 changes: 15 additions & 0 deletions verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,21 @@ distillation:
response_length: ${oc.select:actor_rollout_ref.rollout.response_length}
temperature: ${oc.select:actor_rollout_ref.rollout.temperature}
teacher_key: data_source
self_distillation: false
privileged_solution_key: reward_model.ground_truth
privileged_prefix: '


Reference solution:

'
privileged_suffix: '


Using this as a reference, derive the answer yourself. Think step by step.

'
privileged_insert_before: ''
trainer:
balance_batch: true
total_epochs: 30
Expand Down
15 changes: 15 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,21 @@ distillation:
response_length: ${oc.select:actor_rollout_ref.rollout.response_length}
temperature: ${oc.select:actor_rollout_ref.rollout.temperature}
teacher_key: data_source
self_distillation: false
privileged_solution_key: reward_model.ground_truth
privileged_prefix: '


Reference solution:

'
privileged_suffix: '


Using this as a reference, derive the answer yourself. Think step by step.

'
privileged_insert_before: ''
trainer:
balance_batch: true
total_epochs: 30
Expand Down
15 changes: 15 additions & 0 deletions verl/trainer/config/_generated_ppo_veomni_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,21 @@ distillation:
response_length: ${oc.select:actor_rollout_ref.rollout.response_length}
temperature: ${oc.select:actor_rollout_ref.rollout.temperature}
teacher_key: data_source
self_distillation: false
privileged_solution_key: reward_model.ground_truth
privileged_prefix: '


Reference solution:

'
privileged_suffix: '


Using this as a reference, derive the answer yourself. Think step by step.

'
privileged_insert_before: ''
trainer:
balance_batch: true
total_epochs: 30
Expand Down
17 changes: 16 additions & 1 deletion verl/trainer/config/distillation/distillation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,19 @@ teacher_models:
temperature: ${oc.select:actor_rollout_ref.rollout.temperature}

# Key to route examples to the appropriate teacher model in multi-teacher setups. Should correspond to a field in the data proto, e.g., task.
teacher_key: data_source
teacher_key: data_source

# On-Policy Self-Distillation (OPSD): the teacher conditions on the ground-truth
# solution (privileged context) while the student sees only the problem. Point a
# single teacher at the student checkpoint (teacher_model.model_path == student
# path) to make it a frozen self-teacher.
self_distillation: false
# Non-tensor batch field holding the ground-truth solution. verl stores it nested
# at reward_model.ground_truth; a dotted key reaches into the nested dict.
privileged_solution_key: reward_model.ground_truth
# Marker text wrapping the privileged solution in the teacher input.
privileged_prefix: "\n\nReference solution:\n"
privileged_suffix: "\n\nUsing this as a reference, derive the answer yourself. Think step by step.\n"
# Optional: insert the solution before the last occurrence of this marker text in
# the prompt (e.g. the assistant-turn opener) instead of appending. Empty = append.
privileged_insert_before: ""
Loading