Fix state_dict loading for models trained with torch.compile #263
Open
inorganicwriter wants to merge 1 commit into
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Author
@google-cla-bot check
When a model is trained with compile=True (the default), torch.compile prepends the '_orig_mod.' prefix to all parameter names in the saved state_dict. Loading those weights into a non-compiled model (e.g. during evaluation, or when compile is disabled on Windows) fails with a key mismatch. Strip the prefix automatically so the weights load transparently in both compiled and non-compiled contexts.
f89a48e to 7e491bc
Problem
When a model is trained with compile: True (the current default in config.py), torch.compile prepends the '_orig_mod.' prefix to every parameter key in the saved state_dict. Loading those weights into a non-compiled model fails with a key-mismatch error. This commonly happens when:
- run evaluate/run infer is executed on a host where torch.compile is disabled (e.g. Windows without MSVC cl.exe).
- compile: False is set in the config for inference while the checkpoint was produced with compile: True.
Fix
In BaseTester._load_model_weights, strip the '_orig_mod.' prefix from all keys before calling load_state_dict when its presence is detected. The check is conditional (it only triggers when the prefix is actually present), so loading non-compiled weights is unchanged.
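The diff itself is not shown here, so the following is only a minimal sketch of the prefix-stripping step described above, not the code from this PR. The names strip_compile_prefix and COMPILE_PREFIX are hypothetical, and plain dicts with float values stand in for a real state_dict of tensors so the example runs without PyTorch.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

T = TypeVar("T")

# Prefix that the torch.compile wrapper adds to state_dict keys.
COMPILE_PREFIX = "_orig_mod."


def strip_compile_prefix(state_dict: Mapping[str, T]) -> "OrderedDict[str, T]":
    """Return a copy of state_dict with a leading '_orig_mod.' removed from each key.

    Hypothetical helper: the real change lives in BaseTester._load_model_weights.
    """
    # Checkpoints saved from a non-compiled model have no prefixed keys,
    # so they are copied through unchanged.
    if not any(key.startswith(COMPILE_PREFIX) for key in state_dict):
        return OrderedDict(state_dict)

    # Remove only a leading prefix; keys without it are kept as they are.
    return OrderedDict(
        (key[len(COMPILE_PREFIX):] if key.startswith(COMPILE_PREFIX) else key, value)
        for key, value in state_dict.items()
    )


# Stand-in checkpoints (values would be tensors in practice).
compiled_ckpt = {"_orig_mod.encoder.weight": 1.0, "_orig_mod.encoder.bias": 0.0}
plain_ckpt = {"encoder.weight": 1.0, "encoder.bias": 0.0}

print(list(strip_compile_prefix(compiled_ckpt)))
print(list(strip_compile_prefix(plain_ckpt)))
```

In a loader, the result of such a helper would be what gets passed to load_state_dict. Returning a new mapping rather than mutating the input keeps the loaded checkpoint object intact for any other consumer.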
Testing
- compile=True checkpoint with compile=False set in the eval config: previously raised RuntimeError, now loads cleanly.
- compile=False checkpoint: behavior unchanged.
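The two scenarios above could be turned into assertions roughly as follows. This is an illustrative sketch, not a test from the PR: key_mismatch is a hypothetical stand-in for the strict key comparison that load_state_dict performs, normalized is a hypothetical stand-in for the fix, and lists of key names are used in place of real checkpoints so that it runs without PyTorch.

```python
PREFIX = "_orig_mod."


def key_mismatch(model_keys, checkpoint_keys):
    """Mimic a strict key comparison: return (missing, unexpected) key lists."""
    missing = sorted(set(model_keys) - set(checkpoint_keys))
    unexpected = sorted(set(checkpoint_keys) - set(model_keys))
    return missing, unexpected


def normalized(checkpoint_keys):
    """Strip the compile prefix from each key; keys without it pass through."""
    return [k[len(PREFIX):] if k.startswith(PREFIX) else k for k in checkpoint_keys]


model_keys = ["head.weight", "head.bias"]         # keys of the non-compiled eval model
compiled_ckpt = [PREFIX + k for k in model_keys]  # keys as saved with compile=True
plain_ckpt = list(model_keys)                     # keys as saved with compile=False

# Scenario 1: without the fix every key mismatches; with it, none do.
before = key_mismatch(model_keys, compiled_ckpt)
after = key_mismatch(model_keys, normalized(compiled_ckpt))

# Scenario 2: a checkpoint from a non-compiled model is unaffected.
unchanged = key_mismatch(model_keys, normalized(plain_ckpt))

print(before, after, unchanged)
```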