From 7e491bc87c0293ed2653774bba9faa539a62b2c7 Mon Sep 17 00:00:00 2001 From: inorganicwriter Date: Thu, 25 Jun 2026 03:29:44 +0800 Subject: [PATCH] Fixed state_dict loading for models trained with torch.compile. When a model is trained with compile=True (the default), torch.compile prepends the '_orig_mod.' prefix to all parameter names in the saved state_dict. Loading those weights into a non-compiled model (e.g. during evaluation, or when compile is disabled on Windows) fails with a key mismatch. Strip the prefix automatically so the weights load transparently in both compiled and non-compiled contexts. --- googlehydrology/evaluation/tester.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/googlehydrology/evaluation/tester.py b/googlehydrology/evaluation/tester.py index 52a4c292..9a9f87c4 100644 --- a/googlehydrology/evaluation/tester.py +++ b/googlehydrology/evaluation/tester.py @@ -163,9 +163,16 @@ def _load_weights(self, epoch: int = None): weight_file = self._get_weight_file(epoch) LOGGER.info(f'Using the model weights from {weight_file}') - self.model.load_state_dict( - torch.load(weight_file, map_location=self.device, weights_only=True) + state_dict = torch.load( + weight_file, map_location=self.device, weights_only=True ) + # Drop `_orig_mod.` prefix introduced by torch.compile when the model + # was trained with compile=True but is now loaded without compilation. + if any(k.startswith('_orig_mod.') for k in state_dict): + state_dict = { + k[len('_orig_mod.'):]: v for k, v in state_dict.items() + } + self.model.load_state_dict(state_dict) def _get_dataset_all(self) -> Dataset: """Get dataset for all basins."""