# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CPU tests for the OPSD fields on ``DistillationConfig``."""

from verl.workers.config.distillation import DistillationConfig


def test_opsd_defaults_are_off_and_backward_compatible():
    """A default-constructed config keeps OPSD disabled and ships usable markers."""
    default_cfg = DistillationConfig()
    assert default_cfg.self_distillation is False
    assert default_cfg.privileged_solution_key == "reward_model.ground_truth"
    # Both markers must still contain visible text once surrounding whitespace
    # (the newlines that set the solution off from the prompt) is stripped.
    assert default_cfg.privileged_prefix.strip()
    assert default_cfg.privileged_suffix.strip()


def test_self_distillation_requires_enabled():
    """Turning on self_distillation while ``enabled`` keeps its default is rejected."""
    import pytest

    # ``enabled`` is left at its default (False), so construction must fail.
    with pytest.raises(ValueError):
        DistillationConfig(self_distillation=True)
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CPU tests for the OPSD privileged-context helpers."""

import numpy as np
import pytest
import torch

from verl.trainer.distillation.privileged_context import (
    build_privileged_sequence,
    resolve_privileged_solution,
    slice_privileged_teacher_to_student,
)


def test_resolve_privileged_solution_nested_dotted_key():
    """A dotted key walks nested dicts; an undotted key is looked up directly."""
    nested_sample = {"reward_model": {"ground_truth": "72"}}
    resolved = resolve_privileged_solution(nested_sample, "reward_model.ground_truth")
    assert resolved == "72"

    flat_sample = {"sol": "x"}
    assert resolve_privileged_solution(flat_sample, "sol") == "x"


def test_resolve_privileged_solution_missing_returns_none():
    """Absent paths, a ``None`` sample and an empty string all resolve to ``None``."""
    unresolvable = [
        ({"a": 1}, "reward_model.ground_truth"),
        (None, "x"),
        ({"g": ""}, "g"),
    ]
    for sample, key in unresolvable:
        assert resolve_privileged_solution(sample, key) is None


def test_resolve_privileged_solution_normalizes_scalar_array_list():
    """0-d arrays, one-element arrays and lists are all flattened to text."""
    expectations = [
        (np.array("72"), "72"),  # 0-d array -> its scalar
        (np.array(["x"]), "x"),  # single-element array -> its element
        (["a", "b"], "a\nb"),  # list -> newline-joined
    ]
    for raw_value, expected_text in expectations:
        assert resolve_privileged_solution({"g": raw_value}, "g") == expected_text


def test_build_privileged_sequence_layout():
    """Default layout is prompt + prefix + solution + suffix + response."""
    teacher_input = build_privileged_sequence(
        prompt_ids=[1, 2],
        response_ids=[9, 10],
        solution_ids=[5, 6, 7],
        prefix_ids=[100],
        suffix_ids=[200, 201],
    )
    assert teacher_input == [1, 2, 100, 5, 6, 7, 200, 201, 9, 10]
    # The response stays at the very end, as in a plain prompt+response input.
    assert teacher_input[-2:] == [9, 10]


def test_build_privileged_sequence_empty_markers():
    """Empty prefix/suffix markers leave just prompt + solution + response."""
    teacher_input = build_privileged_sequence([1], [9], [5], [], [])
    assert teacher_input == [1, 5, 9]


def test_build_privileged_sequence_insert_before_marker():
    """The solution block is placed before the LAST occurrence of the marker."""
    # Token 8 plays the role of a marker (e.g. an assistant-turn opener) and
    # appears twice in the prompt; only the final occurrence is the anchor.
    teacher_input = build_privileged_sequence(
        prompt_ids=[1, 8, 2, 8],
        response_ids=[9],
        solution_ids=[5],
        prefix_ids=[100],
        suffix_ids=[200],
        insert_before_token_ids=[8],
    )
    assert teacher_input == [1, 8, 2, 100, 5, 200, 8, 9]


def test_build_privileged_sequence_insert_marker_not_found_falls_back():
    """A marker that never occurs in the prompt falls back to appending."""
    teacher_input = build_privileged_sequence(
        [1, 2], [9], [5], [100], [200], insert_before_token_ids=[42]
    )
    assert teacher_input == [1, 2, 100, 5, 200, 9]


def test_slice_privileged_teacher_keeps_response_pads_prompt():
    """Response rows are kept verbatim; prompt rows are filled with padding."""
    total_rows, top_k = 10, 2
    teacher_ids = torch.arange(total_rows * top_k).reshape(total_rows, top_k)
    teacher_logprobs = torch.randn(total_rows, top_k)

    sliced_ids, sliced_logprobs = slice_privileged_teacher_to_student(
        teacher_ids, teacher_logprobs, student_prompt_length=2, response_length=2, pad_token_id=0
    )

    assert sliced_ids.shape == (4, top_k)
    assert sliced_logprobs.shape == (4, top_k)
    assert torch.equal(sliced_ids[-2:], teacher_ids[-2:])
    assert torch.equal(sliced_logprobs[-2:], teacher_logprobs[-2:])
    assert torch.all(sliced_ids[:2] == 0)
    assert torch.all(sliced_logprobs[:2] == 0.0)


def test_slice_response_length_zero_is_empty_not_whole_tensor():
    """Regression: a zero-length response yields only the padded prompt rows."""
    # ``teacher_ids[-0:]`` would return the whole tensor, so the helper must not
    # slice with a negated length.
    teacher_ids = torch.arange(12).reshape(6, 2)
    teacher_logprobs = torch.randn(6, 2)
    sliced_ids, sliced_logprobs = slice_privileged_teacher_to_student(
        teacher_ids, teacher_logprobs, student_prompt_length=3, response_length=0, pad_token_id=0
    )
    assert sliced_ids.shape == (3, 2)
    assert sliced_logprobs.shape == (3, 2)
    assert torch.all(sliced_ids == 0)


def test_slice_pad_token_id_none_falls_back_to_zero():
    """``pad_token_id=None`` pads the prompt rows with 0."""
    teacher_ids = torch.arange(6).reshape(3, 2)
    teacher_logprobs = torch.randn(3, 2)
    sliced_ids, _unused_logprobs = slice_privileged_teacher_to_student(
        teacher_ids, teacher_logprobs, student_prompt_length=1, response_length=2, pad_token_id=None
    )
    first_pad_value = sliced_ids[0, 0].item()
    assert first_pad_value == 0


def test_slice_preserves_dtype_and_pad_value():
    """Output dtypes follow the inputs and the requested pad id is used."""
    teacher_ids = torch.arange(6).reshape(3, 2).long()
    teacher_logprobs = torch.randn(3, 2, dtype=torch.float32)
    sliced_ids, sliced_logprobs = slice_privileged_teacher_to_student(
        teacher_ids, teacher_logprobs, student_prompt_length=1, response_length=2, pad_token_id=7
    )
    assert sliced_ids.dtype == torch.long
    assert sliced_logprobs.dtype == torch.float32
    assert sliced_ids[0, 0].item() == 7


def test_slice_negative_response_length_raises():
    """A negative response length is rejected with ``ValueError``."""
    teacher_ids = torch.arange(6).reshape(3, 2)
    teacher_logprobs = torch.randn(3, 2)
    with pytest.raises(ValueError):
        slice_privileged_teacher_to_student(teacher_ids, teacher_logprobs, 1, -1, 0)


def test_privileged_slice_aligns_response_rows_after_padding():
    """End-to-end alignment: after the privileged slice + the same left/right pad
    that ``_pad_teacher_outputs`` applies, response token ``j`` lands at absolute
    row ``prompt_width + j`` -- the privileged teacher's score for that token. The
    pad is replicated inline (``F.pad``) so we don't import the heavy
    teacher_manager package."""
    import torch.nn.functional as F

    prompt_len, block_len, resp_len, top_k = 3, 4, 2, 2
    # Batch widths, each at least as large as the per-sample length.
    prompt_width, response_width = 5, 4
    # The teacher scores prompt + privileged block + response.
    total_rows = prompt_len + block_len + resp_len
    teacher_ids = torch.arange(total_rows * top_k).reshape(total_rows, top_k)
    teacher_logprobs = torch.randn(total_rows, top_k)

    sliced_ids, _unused_logprobs = slice_privileged_teacher_to_student(
        teacher_ids, teacher_logprobs, prompt_len, resp_len, pad_token_id=0
    )
    # Left-pad the prompt region and right-pad the response region.
    left_pad = prompt_width - prompt_len
    right_pad = response_width - resp_len
    padded = F.pad(sliced_ids, (0, 0, left_pad, right_pad), value=0)
    assert padded.shape[0] == prompt_width + response_width

    first_response_row = prompt_len + block_len
    for offset in range(resp_len):
        # Absolute row prompt_width + offset must hold the teacher's privileged
        # score for response token ``offset``.
        assert torch.equal(padded[prompt_width + offset], teacher_ids[first_response_row + offset])
config.distillation.self_distillation + self.privileged_solution_key: str = config.distillation.privileged_solution_key + self._privileged_prefix_ids: list[int] = [] + self._privileged_suffix_ids: list[int] = [] + self._privileged_insert_before_ids: Optional[list[int]] = None + if self.self_distillation: + self._privileged_prefix_ids = self.tokenizer.encode( + config.distillation.privileged_prefix, add_special_tokens=False + ) + self._privileged_suffix_ids = self.tokenizer.encode( + config.distillation.privileged_suffix, add_special_tokens=False + ) + if config.distillation.privileged_insert_before: + self._privileged_insert_before_ids = self.tokenizer.encode( + config.distillation.privileged_insert_before, add_special_tokens=False + ) self.teacher_server_manager = AsyncTeacherLLMServerManager( config=config, teacher_client=teacher_client, @@ -1013,12 +1034,36 @@ async def _compute_teacher_logprobs( if routing_value is not None: # Non-tensor batch values arrive as 0-d numpy objects / arrays; normalize to Python. routing_key = routing_value.item() if hasattr(routing_value, "item") else routing_value + if self.self_distillation: + solution = resolve_privileged_solution(sample_kwargs, self.privileged_solution_key) + if not solution: + raise ValueError( + f"self_distillation is enabled but no privileged solution resolved at " + f"'{self.privileged_solution_key}'; verl's ground truth is usually at " + f"'reward_model.ground_truth'." 
+ ) + solution_ids = self.tokenizer.encode(solution, add_special_tokens=False) + sequence_ids = build_privileged_sequence( + prompt_ids, + response_ids, + solution_ids, + self._privileged_prefix_ids, + self._privileged_suffix_ids, + self._privileged_insert_before_ids, + ) + else: + sequence_ids = prompt_ids + response_ids teacher_ids, teacher_logprobs = await self.teacher_server_manager.compute_teacher_logprobs_single( - sequence_ids=prompt_ids + response_ids, + sequence_ids=sequence_ids, multi_modal_data=output.multi_modal_data, mm_processor_kwargs=output.mm_processor_kwargs, routing_key=routing_key, ) + if self.self_distillation: + # Realign the teacher's privileged-context scores onto the student's positions. + teacher_ids, teacher_logprobs = slice_privileged_teacher_to_student( + teacher_ids, teacher_logprobs, len(prompt_ids), len(response_ids), self.tokenizer.pad_token_id + ) output.extra_fields["teacher_ids"] = teacher_ids output.extra_fields["teacher_logprobs"] = teacher_logprobs diff --git a/verl/experimental/teacher_loop/teacher_manager.py b/verl/experimental/teacher_loop/teacher_manager.py index 1a8d859840e..4431f8a7cf0 100644 --- a/verl/experimental/teacher_loop/teacher_manager.py +++ b/verl/experimental/teacher_loop/teacher_manager.py @@ -111,6 +111,13 @@ async def compute_teacher_logprobs_single( teacher_key = self._resolve_teacher_key(routing_key) teacher_model_config = self.teacher_model_configs[teacher_key] client = self.teacher_client[teacher_key] + max_model_len = teacher_model_config.inference.max_model_len + if max_model_len is not None and len(sequence_ids) + 1 > max_model_len: + raise ValueError( + f"Teacher input ({len(sequence_ids)} tokens, +1 to score) exceeds the teacher's " + f"max_model_len ({max_model_len}); with OPSD privileged context this usually means " + f"the reference solution is too long -- raise max_model_len or shorten the solution." 
+ ) teacher_output = await client.generate( request_id=uuid4().hex, prompt_ids=sequence_ids, diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 5262305272c..e53702305cf 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -895,6 +895,21 @@ distillation: response_length: ${oc.select:actor_rollout_ref.rollout.response_length} temperature: ${oc.select:actor_rollout_ref.rollout.temperature} teacher_key: data_source + self_distillation: false + privileged_solution_key: reward_model.ground_truth + privileged_prefix: ' + + + Reference solution: + + ' + privileged_suffix: ' + + + Using this as a reference, derive the answer yourself. Think step by step. + + ' + privileged_insert_before: '' trainer: balance_batch: true total_epochs: 30 diff --git a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml index 25c671291fb..11b6c2a59d6 100644 --- a/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_torchtitan_trainer.yaml @@ -814,6 +814,21 @@ distillation: response_length: ${oc.select:actor_rollout_ref.rollout.response_length} temperature: ${oc.select:actor_rollout_ref.rollout.temperature} teacher_key: data_source + self_distillation: false + privileged_solution_key: reward_model.ground_truth + privileged_prefix: ' + + + Reference solution: + + ' + privileged_suffix: ' + + + Using this as a reference, derive the answer yourself. Think step by step. 
+ + ' + privileged_insert_before: '' trainer: balance_batch: true total_epochs: 30 diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index efc49085fdc..d2d098e45ec 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -836,6 +836,21 @@ distillation: response_length: ${oc.select:actor_rollout_ref.rollout.response_length} temperature: ${oc.select:actor_rollout_ref.rollout.temperature} teacher_key: data_source + self_distillation: false + privileged_solution_key: reward_model.ground_truth + privileged_prefix: ' + + + Reference solution: + + ' + privileged_suffix: ' + + + Using this as a reference, derive the answer yourself. Think step by step. + + ' + privileged_insert_before: '' trainer: balance_batch: true total_epochs: 30 diff --git a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml index ec5b9022e8a..f8325b1d598 100644 --- a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml @@ -837,6 +837,21 @@ distillation: response_length: ${oc.select:actor_rollout_ref.rollout.response_length} temperature: ${oc.select:actor_rollout_ref.rollout.temperature} teacher_key: data_source + self_distillation: false + privileged_solution_key: reward_model.ground_truth + privileged_prefix: ' + + + Reference solution: + + ' + privileged_suffix: ' + + + Using this as a reference, derive the answer yourself. Think step by step. 
+ + ' + privileged_insert_before: '' trainer: balance_batch: true total_epochs: 30 diff --git a/verl/trainer/config/distillation/distillation.yaml b/verl/trainer/config/distillation/distillation.yaml index 84ee7932ab6..cc28e6070a0 100644 --- a/verl/trainer/config/distillation/distillation.yaml +++ b/verl/trainer/config/distillation/distillation.yaml @@ -110,4 +110,19 @@ teacher_models: temperature: ${oc.select:actor_rollout_ref.rollout.temperature} # Key to route examples to the appropriate teacher model in multi-teacher setups. Should correspond to a field in the data proto, e.g., task. -teacher_key: data_source \ No newline at end of file +teacher_key: data_source + +# On-Policy Self-Distillation (OPSD): the teacher conditions on the ground-truth +# solution (privileged context) while the student sees only the problem. Point a +# single teacher at the student checkpoint (teacher_model.model_path == student +# path) to make it a frozen self-teacher. +self_distillation: false +# Non-tensor batch field holding the ground-truth solution. verl stores it nested +# at reward_model.ground_truth; a dotted key reaches into the nested dict. +privileged_solution_key: reward_model.ground_truth +# Marker text wrapping the privileged solution in the teacher input. +privileged_prefix: "\n\nReference solution:\n" +privileged_suffix: "\n\nUsing this as a reference, derive the answer yourself. Think step by step.\n" +# Optional: insert the solution before the last occurrence of this marker text in +# the prompt (e.g. the assistant-turn opener) instead of appending. Empty = append. +privileged_insert_before: "" \ No newline at end of file diff --git a/verl/trainer/distillation/privileged_context.py b/verl/trainer/distillation/privileged_context.py new file mode 100644 index 00000000000..fc9a5ddbcc3 --- /dev/null +++ b/verl/trainer/distillation/privileged_context.py @@ -0,0 +1,154 @@ +# Copyright 2026 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Privileged-context helpers for On-Policy Self-Distillation (OPSD). + +In OPSD the teacher and student share weights but see different contexts: the +student sees only the problem, while the teacher additionally sees the +ground-truth solution. The teacher does not generate -- it scores the student's +own on-policy response conditioned on the privileged solution. These two pure +helpers build the teacher's privileged input sequence and realign the teacher's +per-token top-k outputs back onto the student's ``prompt + response`` positions, +so the rest of verl's on-policy-distillation pipeline is reused unchanged. +""" + +import torch + + +def resolve_privileged_solution(sample_kwargs: dict | None, key: str) -> str | None: + """Resolve the ground-truth solution from a (possibly nested) sample field. + + ``key`` may be dotted (e.g. ``"reward_model.ground_truth"``) to reach into + nested dicts -- verl stores the ground truth at + ``non_tensor_batch["reward_model"]["ground_truth"]``, not at a top-level key, + so a flat ``.get(key)`` silently misses it and OPSD degrades to plain OPD. + + Normalizes numpy 0-d scalars (``.item()``), arrays/lists (joined with + newlines), and dicts to a stripped string. Returns ``None`` if the field is + absent or resolves to an empty string. 
+ """ + if sample_kwargs is None: + return None + cur = sample_kwargs + for part in key.split("."): + if isinstance(cur, dict) and part in cur: + cur = cur[part] + else: + return None + if hasattr(cur, "item") and getattr(cur, "size", 1) == 1: + cur = cur.item() + elif hasattr(cur, "tolist"): + cur = cur.tolist() + if isinstance(cur, list | tuple): + cur = "\n".join(str(x) for x in cur) if len(cur) else None + if cur is None: + return None + solution = str(cur).strip() + return solution or None + + +def build_privileged_sequence( + prompt_ids: list[int], + response_ids: list[int], + solution_ids: list[int], + prefix_ids: list[int], + suffix_ids: list[int], + insert_before_token_ids: list[int] | None = None, +) -> list[int]: + """Build the OPSD teacher's input token sequence. + + Default layout: ``prompt + prefix + solution + suffix + response`` -- the + privileged solution (wrapped in ``prefix`` / ``suffix`` markers) is appended + to the student prompt, and the response is the suffix exactly as in the plain + ``prompt + response`` teacher input. + + ``prompt_ids`` is the post-``apply_chat_template`` prompt, so it ends with the + template's assistant-turn opener. Appending the solution after it puts the + solution inside the assistant turn. If a template needs the solution placed + *before* a specific marker (e.g. the assistant-turn opener), pass + ``insert_before_token_ids``: the solution block is then inserted before the + **last** occurrence of that sub-sequence in ``prompt_ids``. If the marker is + not found, this falls back to the default append. + + Args: + prompt_ids: the student's prompt (problem) token ids. + response_ids: the student's on-policy response token ids. + solution_ids: the ground-truth solution token ids. + prefix_ids: marker tokens placed before the solution. May be empty. + suffix_ids: marker tokens placed after the solution. May be empty. 
+ insert_before_token_ids: if provided and found, insert the solution block + before the last occurrence of this sub-sequence in ``prompt_ids``. + + Returns: + The concatenated teacher input token ids. + """ + block = prefix_ids + solution_ids + suffix_ids + if insert_before_token_ids: + m = len(insert_before_token_ids) + for i in range(len(prompt_ids) - m, -1, -1): + if prompt_ids[i : i + m] == insert_before_token_ids: + return prompt_ids[:i] + block + prompt_ids[i:] + response_ids + # marker not found -> fall through to the default append + return prompt_ids + block + response_ids + + +def slice_privileged_teacher_to_student( + teacher_ids: torch.Tensor, + teacher_logprobs: torch.Tensor, + student_prompt_length: int, + response_length: int, + pad_token_id: int | None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Realign privileged-context teacher outputs onto the student's positions. + + The teacher's top-k outputs are computed over the privileged sequence and are + 1:1 aligned to it. Only the final ``response_length`` rows -- the teacher's + per-token scores for the response tokens under the privileged context -- are + distillation targets. This keeps those rows and pads the + ``student_prompt_length`` prompt rows (the downstream response mask zeroes the + prompt region out anyway), so the returned tensors are aligned to the + student's ``prompt + response`` and feed the existing padding / loss path + unchanged. + + Args: + teacher_ids: ``(privileged_len + response_length, k)`` teacher top-k ids. + teacher_logprobs: ``(privileged_len + response_length, k)`` teacher + top-k log-probs. + student_prompt_length: length of the student's (non-privileged) prompt. + response_length: number of response tokens (may be 0). + pad_token_id: id used to fill the padded prompt rows; ``None`` falls back + to ``0`` (these rows are masked out downstream, so the value is inert). 
+ + Returns: + ``(ids, logprobs)``, each ``(student_prompt_length + response_length, k)`` + and aligned to the student's ``prompt + response`` sequence. + """ + if response_length < 0: + raise ValueError(f"response_length must be non-negative, got {response_length}") + if pad_token_id is None: + pad_token_id = 0 + k = teacher_ids.shape[-1] + # Index from an explicit start: ``teacher_ids[-0:]`` would return the whole + # tensor, so a 0-length response must slice from the end, not from ``-0``. + start = teacher_ids.shape[0] - response_length + response_ids = teacher_ids[start:] + response_logprobs = teacher_logprobs[start:] + prompt_ids = torch.full( + (student_prompt_length, k), int(pad_token_id), dtype=response_ids.dtype, device=response_ids.device + ) + prompt_logprobs = torch.zeros( + (student_prompt_length, k), dtype=response_logprobs.dtype, device=response_logprobs.device + ) + ids = torch.cat([prompt_ids, response_ids], dim=0) + logprobs = torch.cat([prompt_logprobs, response_logprobs], dim=0) + return ids, logprobs diff --git a/verl/workers/config/distillation.py b/verl/workers/config/distillation.py index b58f8bccffd..21f272dab7e 100644 --- a/verl/workers/config/distillation.py +++ b/verl/workers/config/distillation.py @@ -264,9 +264,33 @@ class DistillationConfig(BaseConfig): nnodes: int = 0 teacher_models: dict[str, DistillationTeacherModelConfig] = field(default_factory=dict) teacher_key: str = "data_source" + # On-Policy Self-Distillation (OPSD): the teacher shares the student's weights but + # additionally conditions on the ground-truth solution (privileged context); the + # student sees only the problem. Point a single teacher at the student checkpoint + # (teacher.model_path == student path) to make it a frozen self-teacher. OPSD is + # a supervised signal: pair with distillation_loss.use_policy_gradient=False and + # use_task_rewards=False. 
+ self_distillation: bool = False + # Non-tensor batch field with the privileged solution; dotted to reach nested + # dicts (verl stores ground truth at reward_model.ground_truth). NOTE: on some + # datasets (e.g. gsm8k) ground_truth is only the final answer -- point this at + # the full worked solution (e.g. extra_info.answer) for a stronger signal. + privileged_solution_key: str = "reward_model.ground_truth" + # Marker text wrapping the privileged solution in the teacher's input. + privileged_prefix: str = "\n\nReference solution:\n" + privileged_suffix: str = "\n\nUsing this as a reference, derive the answer yourself. Think step by step.\n" + # Optional: insert the privileged solution before the last occurrence of this + # marker text in the prompt (e.g. the assistant-turn opener) instead of appending + # it. Empty = append after the prompt. + privileged_insert_before: str = "" distillation_loss: DistillationLossConfig = field(default_factory=DistillationLossConfig) def __post_init__(self): + if self.self_distillation: + if not self.enabled: + raise ValueError("self_distillation requires distillation.enabled=True.") + if not self.privileged_solution_key: + raise ValueError("self_distillation requires a non-empty privileged_solution_key.") if not self.enabled: return