Add frozen lm_head KL training option #852

Open
klei22 wants to merge 1 commit into ReaLLMASIC:master from klei22:add-passive-distilation-loss-monitor


klei22 (Collaborator) commented Jun 17, 2026

No description provided.

Copilot AI left a comment


Pull request overview

This PR adds an optional “frozen target lm_head KL” feature: load an lm_head (or embedding table fallback) from a checkpoint, compute KL(student‖target) against the frozen head, and use it either as the training loss or as a monitoring metric. It also wires the resulting metrics into the exploration/monitoring scripts.

Changes:

  • Add CLI flags for configuring a frozen target lm_head checkpoint and KL settings (mode/weight/temperature/eps/dataset index).
  • Extend training + evaluation to optionally return final hidden states, compute frozen-head KL, and log it (TensorBoard + metrics file).
  • Update exploration tooling to recognize the new metrics keys.
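
The KL term described in the overview can be illustrated with a small, dependency-free sketch. It assumes the student and frozen-target logits are both produced from the same final hidden state, that the temperature divides the logits before the softmax, and that eps clamps probabilities before the log. The excerpt does not show how the PR applies temperature or eps, so treat those choices, and the names `softmax`, `kl_student_target`, and `project`, as hypothetical.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_student_target(student_logits, target_logits, temperature=1.0, eps=1e-8):
    """KL(student || target) = sum_i p_i * (log p_i - log q_i)."""
    p = softmax(student_logits, temperature)
    q = softmax(target_logits, temperature)
    # Assumption: eps clamps probabilities so log() never sees zero.
    return sum(
        pi * (math.log(max(pi, eps)) - math.log(max(qi, eps)))
        for pi, qi in zip(p, q)
    )

def project(hidden, head):
    """Logits from one hidden vector and a head matrix (one row per vocab entry)."""
    return [sum(h * w for h, w in zip(hidden, row)) for row in head]

hidden = [0.5, -1.0]                                  # final hidden state (toy values)
student_head = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # trainable lm_head (toy values)
frozen_head = [[0.9, 0.1], [0.0, 1.2], [1.0, 0.8]]    # hypothetical frozen target lm_head

kl = kl_student_target(project(hidden, student_head), project(hidden, frozen_head))
```

In a training loop this value would either replace or be added (scaled by the configured weight) to the usual loss, or simply be logged when the feature is used as a monitor.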

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| train.py | Loads frozen target lm_head, computes KL loss/metrics, logs values, and adds NaN-aware stats helpers. |
| train_args.py | Adds CLI arguments to enable/configure frozen lm_head KL behavior. |
| model.py | Adds return_hidden option to return final hidden states for single-context forward passes. |
| run_exploration_monitor.py | Adds columns for the new KL metrics in the monitor UI. |
| optimization_and_search/run_from_yaml.py | Expands parsed metric schema to include new metrics (and more). |
| optimization_and_search/run_experiments.py | Extends metric schema parsing to include the new KL metrics. |


Comment thread train.py
Comment on lines +553 to +559
def _resolve_checkpoint_path(self, path_value):
    expanded = os.path.expanduser(path_value)
    if not os.path.exists(expanded):
        candidate = os.path.join(self.args.out_dir, path_value)
        if os.path.exists(candidate):
            expanded = candidate
    return expanded
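
To make the fallback behavior of the excerpt above concrete, here is a standalone sketch of the same lookup order as a free function. The name `resolve_checkpoint_path`, the explicit `out_dir` parameter, and the file names are hypothetical, chosen only so the example runs on its own.

```python
import os
import tempfile

def resolve_checkpoint_path(path_value, out_dir):
    """Expand ~, then fall back to out_dir/path_value if the path does not exist."""
    expanded = os.path.expanduser(path_value)
    if not os.path.exists(expanded):
        candidate = os.path.join(out_dir, path_value)
        if os.path.exists(candidate):
            expanded = candidate
    return expanded

with tempfile.TemporaryDirectory() as out_dir:
    # Create an empty stand-in checkpoint inside the output directory.
    ckpt = os.path.join(out_dir, "hypothetical_ckpt_for_sketch.pt")
    open(ckpt, "w").close()

    # A bare file name is resolved relative to out_dir when it is not found directly.
    found = resolve_checkpoint_path("hypothetical_ckpt_for_sketch.pt", out_dir)
    # A path that exists nowhere is returned as given, not rejected.
    missing = resolve_checkpoint_path("no_such_file_for_sketch.pt", out_dir)
```

Because a missing path is returned as given rather than raising, the caller still has to check that the result exists before loading it.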
Comment thread train.py
Comment on lines +1631 to +1641
if 'target_lm_head_kl_val' in losses:
    self.writer.add_scalar(
        f"{target_dataset}/target_lm_head_kl_val",
        losses['target_lm_head_kl_val'],
        self.iter_num,
    )
    self.writer.add_scalar(
        f"{target_dataset}/target_lm_head_kl_train_eval",
        losses['target_lm_head_kl_train'],
        self.iter_num,
    )
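
The excerpt above tests only for the `target_lm_head_kl_val` key before reading `target_lm_head_kl_train`. Whether both keys are always set together is not visible in this excerpt. One defensive variant, offered as a sketch and not as the PR's code, checks each key on its own. `RecordingWriter` is a hypothetical stand-in for the TensorBoard writer, and `log_target_kl` is an illustrative helper name.

```python
class RecordingWriter:
    """Stand-in for a TensorBoard SummaryWriter; records add_scalar calls."""

    def __init__(self):
        self.calls = []

    def add_scalar(self, tag, value, step):
        self.calls.append((tag, value, step))

def log_target_kl(writer, losses, target_dataset, iter_num):
    """Log each KL metric only if its own key is present in losses."""
    tag_by_key = {
        "target_lm_head_kl_val": "target_lm_head_kl_val",
        "target_lm_head_kl_train": "target_lm_head_kl_train_eval",
    }
    for key, tag in tag_by_key.items():
        if key in losses:
            writer.add_scalar(f"{target_dataset}/{tag}", losses[key], iter_num)

writer = RecordingWriter()
# Only the validation metric is present, so only one scalar is written.
log_target_kl(writer, {"target_lm_head_kl_val": 0.25}, "my_dataset", 100)
```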
Comment on lines +104 to +108
casts = [float, int, int, int] + [float] * (len(METRIC_KEYS) - 4)
metrics = {}
for key, typ, value in zip(METRIC_KEYS, casts, parts[:len(METRIC_KEYS)]):
    metrics[key] = float("nan") if value == "" else typ(value)
return metrics
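
The parsing logic in the excerpt above can be exercised in isolation. The sketch below assumes the metrics line is comma-separated and uses a shortened, hypothetical `METRIC_KEYS`. The excerpt shows neither how `parts` is produced nor the full key list, so both are assumptions.

```python
import math

# Hypothetical, shortened key list for illustration only.
METRIC_KEYS = [
    "best_val_loss",
    "best_val_iter",
    "best_tokens",
    "num_params",
    "target_lm_head_kl_val",
]

def parse_metrics(line):
    """Parse one metrics line; empty fields become NaN."""
    parts = [p.strip() for p in line.split(",")]  # assumption: comma-separated
    casts = [float, int, int, int] + [float] * (len(METRIC_KEYS) - 4)
    metrics = {}
    for key, typ, value in zip(METRIC_KEYS, casts, parts[:len(METRIC_KEYS)]):
        metrics[key] = float("nan") if value == "" else typ(value)
    return metrics

full = parse_metrics("1.5,200,4096,1000000,0.03")
blank = parse_metrics("1.5,200,4096,1000000,")   # trailing field left empty
short = parse_metrics("1.5,200")                 # older file with fewer columns
```

Because `zip` stops at the shortest input, a line with fewer fields than keys yields a dictionary without the later keys, so a consumer may want `metrics.get(key, float("nan"))` when reading them.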
Comment on lines 25 to +30
METRIC_KEYS = [
    "best_val_loss",
    "best_val_iter",
    "best_val_iter",
    "best_tokens",
    "num_params",
    "better_than_chance",

2 participants