Add frozen lm_head KL training option#852
Open
klei22 wants to merge 1 commit into
Open
Conversation
There was a problem hiding this comment.
Pull request overview
This PR adds an optional “frozen target lm_head KL” feature: load an lm_head (or embedding table fallback) from a checkpoint, compute KL(student‖target) against the frozen head, and use it either as the training loss or as a monitoring metric. It also wires the resulting metrics into the exploration/monitoring scripts.
Changes:
- Add CLI flags for configuring a frozen target lm_head checkpoint and KL settings (mode/weight/temperature/eps/dataset index).
- Extend training + evaluation to optionally return final hidden states, compute frozen-head KL, and log it (TensorBoard + metrics file).
- Update exploration tooling to recognize the new metrics keys.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
train.py |
Loads frozen target lm_head, computes KL loss/metrics, logs values, and adds NaN-aware stats helpers. |
train_args.py |
Adds CLI arguments to enable/configure frozen lm_head KL behavior. |
model.py |
Adds return_hidden option to return final hidden states for single-context forward passes. |
run_exploration_monitor.py |
Adds columns for the new KL metrics in the monitor UI. |
optimization_and_search/run_from_yaml.py |
Expands parsed metric schema to include new metrics (and more). |
optimization_and_search/run_experiments.py |
Extends metric schema parsing to include the new KL metrics. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+553
to
+559
| def _resolve_checkpoint_path(self, path_value): | ||
| expanded = os.path.expanduser(path_value) | ||
| if not os.path.exists(expanded): | ||
| candidate = os.path.join(self.args.out_dir, path_value) | ||
| if os.path.exists(candidate): | ||
| expanded = candidate | ||
| return expanded |
Comment on lines
+1631
to
+1641
| if 'target_lm_head_kl_val' in losses: | ||
| self.writer.add_scalar( | ||
| f"{target_dataset}/target_lm_head_kl_val", | ||
| losses['target_lm_head_kl_val'], | ||
| self.iter_num, | ||
| ) | ||
| self.writer.add_scalar( | ||
| f"{target_dataset}/target_lm_head_kl_train_eval", | ||
| losses['target_lm_head_kl_train'], | ||
| self.iter_num, | ||
| ) |
Comment on lines
+104
to
+108
| casts = [float, int, int, int] + [float] * (len(METRIC_KEYS) - 4) | ||
| metrics = {} | ||
| for key, typ, value in zip(METRIC_KEYS, casts, parts[:len(METRIC_KEYS)]): | ||
| metrics[key] = float("nan") if value == "" else typ(value) | ||
| return metrics |
Comment on lines
25
to
+30
| METRIC_KEYS = [ | ||
| "best_val_loss", | ||
| "best_val_iter", | ||
| "best_val_iter", | ||
| "best_tokens", | ||
| "num_params", | ||
| "better_than_chance", |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.