Skip to content

Add option to import lm head and wte embeddings#837

Open
klei22 wants to merge 3 commits into
ReaLLMASIC:masterfrom
klei22:add-option-to-import-lm-head-and-wte-embeddings
Open

Add option to import lm head and wte embeddings#837
klei22 wants to merge 3 commits into
ReaLLMASIC:masterfrom
klei22:add-option-to-import-lm-head-and-wte-embeddings

Conversation

@klei22

@klei22 klei22 commented Jun 14, 2026

Copy link
Copy Markdown
Collaborator

No description provided.

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds a new training-time initialization path to seed only the token embedding table (WTE) and LM head weights from an existing nanoGPT-style checkpoint, optionally freezing those imported parameters. This supports experiments where you want to reuse learned token-space representations without resuming full training state.

Changes:

  • Add CLI/config options to import WTE + LM head weights from a checkpoint and optionally freeze them.
  • Implement checkpoint loading + selective weight copy into transformer.wte.weight and lm_head.weight, with handling for weight tying.
  • Document usage in the README and add an exploration config to compare seeded vs unseeded runs.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
train_args.py Adds new CLI flags for importing/freezing WTE+LM head from a checkpoint.
gpt_conf.py Extends GPTConfig with new import/freezing fields.
model.py Implements the selective checkpoint import logic for WTE and LM head.
README.md Documents how to use the new import options (including weight-tying guidance).
explorations/default_inf_wte_lm_head_import_comparison.yaml Adds an experiment template to compare default vs seeded runs.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread model.py

def _checkpoint_state_dict(self, checkpoint_path):
"""Load a checkpoint and return a prefix-normalized model state dict."""
checkpoint_obj = torch.load(checkpoint_path, map_location="cpu")
Comment thread model.py
Comment on lines +346 to +360
def _copy_imported_weight(self, source_state_dict, target_key):
if target_key not in source_state_dict:
raise KeyError(f"Checkpoint is missing required weight '{target_key}'")
target_state_dict = self.state_dict()
if target_key not in target_state_dict:
raise KeyError(f"Current model does not have a '{target_key}' parameter to import into")
source_weight = source_state_dict[target_key].detach().float()
target_weight = target_state_dict[target_key]
if source_weight.shape != target_weight.shape:
raise ValueError(
f"Shape mismatch for {target_key}: checkpoint has {tuple(source_weight.shape)} "
f"but current model expects {tuple(target_weight.shape)}"
)
target_weight.copy_(source_weight.to(device=target_weight.device, dtype=target_weight.dtype))

Comment thread model.py
Comment on lines 200 to +207
# import wte
if self.config.import_wte_npy:
# Replace wte with values from numpy and retie weights
self.import_wte(self.config.import_wte_npy)

# import full wte + lm_head from an existing nanoGPT checkpoint
if self.config.import_wte_lm_head_ckpt:
self.import_wte_lm_head_from_ckpt(
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants