diff --git a/googlehydrology/evaluation/tester.py b/googlehydrology/evaluation/tester.py index 52a4c292..9a9f87c4 100644 --- a/googlehydrology/evaluation/tester.py +++ b/googlehydrology/evaluation/tester.py @@ -163,9 +163,16 @@ def _load_weights(self, epoch: int = None): weight_file = self._get_weight_file(epoch) LOGGER.info(f'Using the model weights from {weight_file}') - self.model.load_state_dict( - torch.load(weight_file, map_location=self.device, weights_only=True) + state_dict = torch.load( + weight_file, map_location=self.device, weights_only=True ) + # Drop `_orig_mod.` prefix introduced by torch.compile when the model + # was trained with compile=True but is now loaded without compilation. + if any(k.startswith('_orig_mod.') for k in state_dict): + state_dict = { + k[len('_orig_mod.'):]: v for k, v in state_dict.items() + } + self.model.load_state_dict(state_dict) def _get_dataset_all(self) -> Dataset: """Get dataset for all basins."""