Skip to content

Fix state_dict loading for models trained with torch.compile#263

Open
inorganicwriter wants to merge 1 commit into
google-research:mainfrom
inorganicwriter:fix/compile-state-dict-loading
Open

Fix state_dict loading for models trained with torch.compile#263
inorganicwriter wants to merge 1 commit into
google-research:mainfrom
inorganicwriter:fix/compile-state-dict-loading

Conversation

@inorganicwriter

Copy link
Copy Markdown

Problem

When a model is trained with compile: True (the current default in
config.py), torch.compile prepends the _orig_mod. prefix to every
parameter key in the saved state_dict. Loading those weights into a
non-compiled model fails with a key-mismatch error. This commonly
happens when:

  • Running run evaluate / run infer on a host where torch.compile
    is disabled (e.g. Windows without MSVC cl.exe).
  • Setting compile: False in the config for inference while the
    checkpoint was produced with compile: True.
  • Sharing checkpoints across compiled/non-compiled training runs.

Fix

In BaseTester._load_model_weights, strip the _orig_mod. prefix from
all keys before calling load_state_dict when its presence is detected.
The check is opt-in (only triggered when the prefix is actually there),
so loading non-compiled weights is unchanged.

Testing

  • Loaded a compile=True checkpoint with compile=False set in the
    eval config: previously raised RuntimeError, now loads cleanly.
  • Loaded a compile=False checkpoint: behavior unchanged.

@google-cla

google-cla Bot commented Jun 24, 2026

Copy link
Copy Markdown

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@inorganicwriter

Copy link
Copy Markdown
Author

@google-cla-bot check

When a model is trained with compile=True (the default), torch.compile prepends the '_orig_mod.' prefix to all parameter names in the saved state_dict. Loading those weights into a non-compiled model (e.g. during evaluation, or when compile is disabled on Windows) fails with a key mismatch. Strip the prefix automatically so the weights load transparently in both compiled and non-compiled contexts.
@inorganicwriter inorganicwriter force-pushed the fix/compile-state-dict-loading branch from f89a48e to 7e491bc Compare June 24, 2026 19:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant