Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions googlehydrology/evaluation/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,16 @@ def _load_weights(self, epoch: int = None):
weight_file = self._get_weight_file(epoch)

LOGGER.info(f'Using the model weights from {weight_file}')
self.model.load_state_dict(
torch.load(weight_file, map_location=self.device, weights_only=True)
state_dict = torch.load(
weight_file, map_location=self.device, weights_only=True
)
# Drop `_orig_mod.` prefix introduced by torch.compile when the model
# was trained with compile=True but is now loaded without compilation.
if any(k.startswith('_orig_mod.') for k in state_dict):
state_dict = {
k[len('_orig_mod.'):]: v for k, v in state_dict.items()
}
self.model.load_state_dict(state_dict)

def _get_dataset_all(self) -> Dataset:
"""Get dataset for all basins."""
Expand Down