Fix ContextBuilder checkpoint loading for non-default architectures #13
Open
harens wants to merge 1 commit into
Save ContextBuilder architecture settings alongside the state_dict so non-default num_layers, bidirectional, and LSTM configurations can be restored. Keep loading older raw state_dict checkpoints by inferring constructor settings from stored tensor shapes and recurrent layer keys where possible.
This PR updates `ContextBuilder` persistence so checkpoints store the architecture metadata needed to reconstruct the model reliably.

Previously, `save()` wrote only the raw `state_dict`, and `load()` inferred constructor settings from tensor shapes. That worked for some default models, but was unreliable for non-default configurations such as custom `num_layers`, bidirectional encoders, or LSTM-based models.

Changes:

- Save `ContextBuilder` constructor metadata alongside the `state_dict`.
- Restore `input_size`, `output_size`, `hidden_size`, `num_layers`, `max_length`, `bidirectional`, and `LSTM` values on load.
- Keep loading older raw `state_dict` checkpoints.
- Infer `num_layers`, `bidirectional`, and `LSTM` from recurrent tensor keys/shapes where older checkpoints do not include metadata.

This should make checkpoint round-tripping reliable for non-default architectures while keeping existing saved models loadable.
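For reviewers, the new checkpoint layout can be pictured roughly as below. This is a minimal sketch, not the code in this PR: the helper names `pack_checkpoint` and `unpack_checkpoint` and the `"config"` / `"state_dict"` dictionary keys are assumptions made for illustration. Only the seven constructor settings are taken from the description above.

```python
from typing import Any, Dict, Mapping, Optional, Tuple

# Constructor settings named in the PR description.
CONFIG_KEYS = (
    "input_size", "output_size", "hidden_size", "num_layers",
    "max_length", "bidirectional", "LSTM",
)


def pack_checkpoint(state_dict: Mapping[str, Any],
                    config: Mapping[str, Any]) -> Dict[str, Any]:
    """Bundle weights with the settings needed to rebuild the model.

    Hypothetical helper: the dictionary layout is an assumption.
    """
    missing = [k for k in CONFIG_KEYS if k not in config]
    if missing:
        raise ValueError(f"missing constructor settings: {missing}")
    return {
        "config": {k: config[k] for k in CONFIG_KEYS},
        "state_dict": dict(state_dict),
    }


def unpack_checkpoint(obj: Any) -> Tuple[Optional[Dict[str, Any]], Any]:
    """Return (config, state_dict); config is None for a legacy raw state_dict."""
    if isinstance(obj, Mapping) and "config" in obj and "state_dict" in obj:
        return dict(obj["config"]), obj["state_dict"]
    # Legacy format: the object itself is the raw state_dict.
    return None, obj
```

In this sketch the packed dictionary would be handed to `torch.save`, and the object returned by `torch.load` would go through `unpack_checkpoint`. A `None` config signals that the settings have to be inferred from the tensors instead.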
For legacy checkpoints, architecture parameters are inferred from PyTorch's recurrent weight naming conventions and tensor shapes (e.g. the layer index in `weight_ih_l{k}`, the `_reverse` suffix, and gate dimensionality). This is a best-effort heuristic.
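The heuristic can be sketched as follows. It is an illustration under stated assumptions rather than the PR's implementation: `infer_recurrent_settings` is a hypothetical name, the function works on a plain mapping of parameter names to shapes (for a real checkpoint, something like `{k: tuple(v.shape) for k, v in state_dict.items()}`), and it assumes the model holds a single recurrent module without LSTM projections. It relies on PyTorch storing `weight_hh_l{k}` with shape `(gates * hidden_size, hidden_size)`, where `gates` is 4 for an LSTM and 3 for a GRU.

```python
import re
from typing import Any, Dict, Mapping, Tuple

# Matches e.g. "encoder.rnn.weight_hh_l1" or "weight_hh_l0_reverse".
_HH_KEY = re.compile(r"(?:^|\.)weight_hh_l(\d+)(_reverse)?$")


def infer_recurrent_settings(
        shapes: Mapping[str, Tuple[int, ...]]) -> Dict[str, Any]:
    """Best-effort guess of recurrent settings from parameter names and shapes."""
    layers = set()
    bidirectional = False
    gates = None
    hidden_size = None
    for key, shape in shapes.items():
        match = _HH_KEY.search(key)
        if match is None:
            continue
        layers.add(int(match.group(1)))
        if match.group(2) is not None:
            # A "_reverse" weight only exists for bidirectional modules.
            bidirectional = True
        elif match.group(1) == "0":
            # Forward weight of layer 0: (gates * hidden_size, hidden_size).
            rows, hidden_size = shape
            gates = rows // hidden_size
    if not layers or gates is None:
        raise ValueError("no recurrent weights found; cannot infer settings")
    return {
        "num_layers": max(layers) + 1,
        "bidirectional": bidirectional,
        "LSTM": gates == 4,  # 3 would indicate a GRU, 1 a plain RNN
        "hidden_size": hidden_size,
    }
```

Because the guess depends on key names and shape ratios alone, it can be wrong for models with several recurrent modules or renamed parameters, which is why stored metadata should take precedence whenever it is present.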