Add option to import lm head and wte embeddings #837
Open
klei22 wants to merge 3 commits into
Pull request overview
This PR adds a new training-time initialization path to seed only the token embedding table (WTE) and LM head weights from an existing nanoGPT-style checkpoint, optionally freezing those imported parameters. This supports experiments where you want to reuse learned token-space representations without resuming full training state.
Changes:
- Add CLI/config options to import WTE + LM head weights from a checkpoint and optionally freeze them.
- Implement checkpoint loading + selective weight copy into `transformer.wte.weight` and `lm_head.weight`, with handling for weight tying.
- Document usage in the README and add an exploration config to compare seeded vs unseeded runs.
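The selective import described above can be sketched roughly as follows. This is a minimal illustration, not the PR's implementation: `TinyLM` and `import_wte_lm_head` are hypothetical names, and the toy model only mirrors the `transformer.wte.weight` / `lm_head.weight` key layout mentioned in the overview.

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Hypothetical stand-in for a nanoGPT-style model (not the PR's class)."""

    def __init__(self, vocab_size=8, n_embd=4, tie_weights=True):
        super().__init__()
        self.transformer = nn.ModuleDict({"wte": nn.Embedding(vocab_size, n_embd)})
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        if tie_weights:
            # Weight tying: both modules share one Parameter object.
            self.lm_head.weight = self.transformer["wte"].weight


def import_wte_lm_head(model, source_state_dict, freeze=False):
    """Copy only the WTE and LM head weights; optionally freeze them."""
    wte = model.transformer["wte"].weight
    tied = model.lm_head.weight is wte
    with torch.no_grad():
        wte.copy_(source_state_dict["transformer.wte.weight"])
        if not tied:
            # With tied weights the copy above already updated the LM head.
            model.lm_head.weight.copy_(source_state_dict["lm_head.weight"])
    if freeze:
        wte.requires_grad_(False)
        model.lm_head.weight.requires_grad_(False)
    return model


# Seed a fresh model from a "donor" model's state dict and freeze the result.
donor = TinyLM()
fresh = TinyLM()
import_wte_lm_head(fresh, donor.state_dict(), freeze=True)
```

No optimizer state, iteration counter, or other layers are touched, which matches the stated goal of reusing token-space representations without resuming the full training state.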
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| train_args.py | Adds new CLI flags for importing/freezing WTE+LM head from a checkpoint. |
| gpt_conf.py | Extends GPTConfig with new import/freezing fields. |
| model.py | Implements the selective checkpoint import logic for WTE and LM head. |
| README.md | Documents how to use the new import options (including weight-tying guidance). |
| explorations/default_inf_wte_lm_head_import_comparison.yaml | Adds an experiment template to compare default vs seeded runs. |
```python
def _checkpoint_state_dict(self, checkpoint_path):
    """Load a checkpoint and return a prefix-normalized model state dict."""
    checkpoint_obj = torch.load(checkpoint_path, map_location="cpu")
```
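The rest of this method is not shown in the excerpt. As a rough sketch of what "prefix-normalized" could mean, the helper below assumes the nanoGPT convention of storing the weights under a `"model"` key and of `torch.compile` adding an `_orig_mod.` prefix to parameter names. The function name `normalize_state_dict_keys` is hypothetical, and plain strings stand in for tensors so the example runs without a checkpoint file.

```python
def normalize_state_dict_keys(checkpoint_obj, prefix="_orig_mod."):
    """Return a state dict whose keys have the given prefix stripped."""
    # Accept either a full checkpoint ({"model": ...}) or a bare state dict.
    state_dict = checkpoint_obj.get("model", checkpoint_obj)
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Placeholder values stand in for real weight tensors.
example = {"model": {"_orig_mod.transformer.wte.weight": "W", "lm_head.weight": "H"}}
normalized = normalize_state_dict_keys(example)
```

After normalization, lookups such as `normalized["transformer.wte.weight"]` work whether or not the source model was compiled.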
Comment on lines +346 to +360
```python
def _copy_imported_weight(self, source_state_dict, target_key):
    # Both the checkpoint and the current model must expose the key.
    if target_key not in source_state_dict:
        raise KeyError(f"Checkpoint is missing required weight '{target_key}'")
    target_state_dict = self.state_dict()
    if target_key not in target_state_dict:
        raise KeyError(f"Current model does not have a '{target_key}' parameter to import into")
    source_weight = source_state_dict[target_key].detach().float()
    target_weight = target_state_dict[target_key]
    # Refuse to copy between tensors of different shapes.
    if source_weight.shape != target_weight.shape:
        raise ValueError(
            f"Shape mismatch for {target_key}: checkpoint has {tuple(source_weight.shape)} "
            f"but current model expects {tuple(target_weight.shape)}"
        )
    # In-place copy, matching the target's device and dtype.
    target_weight.copy_(source_weight.to(device=target_weight.device, dtype=target_weight.dtype))
```
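The in-place copy above goes through `self.state_dict()` rather than the parameter itself. A small standalone sketch of why that still updates the live model: by default `state_dict()` returns detached tensors that share storage with the parameters, so `copy_` on them changes the weights in place. The helper `check_and_copy` is a hypothetical simplification of the method above, not code from the PR.

```python
import torch
import torch.nn as nn


def check_and_copy(target, source):
    """Validate shapes, then copy source into target in place."""
    if source.shape != target.shape:
        raise ValueError(
            f"Shape mismatch: source {tuple(source.shape)} vs target {tuple(target.shape)}"
        )
    target.copy_(source.to(device=target.device, dtype=target.dtype))


layer = nn.Linear(4, 8, bias=False)  # weight shape is (8, 4)
# state_dict() returns detached tensors that share storage with the parameters.
check_and_copy(layer.state_dict()["weight"], torch.ones(8, 4))
print(torch.equal(layer.weight.detach(), torch.ones(8, 4)))
```

Passing a source of shape `(4, 8)` instead would raise `ValueError` before any data is written, which is the behaviour the shape check in the excerpt is meant to guarantee.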
Comment on lines 200 to +207
```python
# import wte
if self.config.import_wte_npy:
    # Replace wte with values from numpy and retie weights
    self.import_wte(self.config.import_wte_npy)

# import full wte + lm_head from an existing nanoGPT checkpoint
if self.config.import_wte_lm_head_ckpt:
    self.import_wte_lm_head_from_ckpt(
```
No description provided.