From 1e8c39afde7e06c419f34fdd17169395925f91d5 Mon Sep 17 00:00:00 2001 From: Michael Kuhlmann Date: Wed, 13 Aug 2025 10:15:28 +0200 Subject: [PATCH 1/2] Support safe_globals in padertorch.Module Safely call padertorch.Module.from_storage without having to set weights_only=False --- padertorch/base.py | 58 ++++++++++++++++++++++++++++--------- padertorch/train/trainer.py | 29 +------------------ 2 files changed, 46 insertions(+), 41 deletions(-) diff --git a/padertorch/base.py b/padertorch/base.py index 4ef63fbb..d2b2d714 100644 --- a/padertorch/base.py +++ b/padertorch/base.py @@ -6,6 +6,8 @@ """ import io import abc +import contextlib +from packaging import version from pathlib import Path import numpy as np @@ -23,6 +25,33 @@ ] +# https://github.com/huggingface/transformers/pull/34632 +def _safe_globals(): + # Starting from version 2.4 PyTorch introduces a check for the objects + # loaded with torch.load(weights_only=True). Starting from 2.6 + # weights_only=True becomes a default and requires allowlisting of objects + # being loaded. + # See: https://github.com/pytorch/pytorch/pull/137602 + # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals + # See: https://github.com/huggingface/accelerate/pull/3036 + if ( + version.parse(torch.__version__).release + < version.parse("2.6").release + ): + return contextlib.nullcontext() + + np_core = ( + np._core if version.parse(np.__version__) + >= version.parse("2.0.0") else np.core + ) + allowlist = [ + np_core.multiarray.scalar, type(np.dtype(np.float64)), np.dtype, + ] + # Additional types that were allowed by huggingface transformers + # allowlist.extend([np_core.multiarray._reconstruct, np.ndarray, type(np.dtype(np.uint32))]) + return torch.serialization.safe_globals(allowlist) + + class Module(nn.Module, Configurable, abc.ABC): """Abstract base class for configurable Modules.""" @@ -117,20 +146,23 @@ def load_checkpoint( assert checkpoint_path.is_file(), checkpoint_path # Load weights - if consider_mpi: - import dlp_mpi - if dlp_mpi.IS_MASTER: - checkpoint_path_content = Path(checkpoint_path).read_bytes() + with _safe_globals(): + if consider_mpi: + import dlp_mpi + if dlp_mpi.IS_MASTER: + checkpoint_path_content = Path(checkpoint_path).read_bytes() + else: + checkpoint_path_content = None + checkpoint_path_content = dlp_mpi.bcast(checkpoint_path_content) + + checkpoint = torch.load( + io.BytesIO(checkpoint_path_content), + map_location=map_location, + ) else: - checkpoint_path_content = None - checkpoint_path_content = dlp_mpi.bcast(checkpoint_path_content) - - checkpoint = torch.load( - io.BytesIO(checkpoint_path_content), - map_location=map_location, - ) - else: - checkpoint = torch.load(checkpoint_path, map_location=map_location) + checkpoint = torch.load( + checkpoint_path, map_location=map_location, + ) if in_checkpoint_path: for part in in_checkpoint_path.split('.'): diff --git a/padertorch/train/trainer.py b/padertorch/train/trainer.py index 7def6056..7c49d1bc 100644 --- a/padertorch/train/trainer.py +++ b/padertorch/train/trainer.py @@ -11,7 +11,6 @@ from pathlib import Path import functools import collections -from packaging import version import numpy as np import torch @@ -21,6 +20,7 @@ from paderbox.utils.nested import deflatten import padertorch as pt +from padertorch.base import _safe_globals from padertorch.configurable import Configurable from padertorch.train.optimizer import Optimizer, Adam from padertorch.train.runtime_tests import test_run @@ -32,33 +32,6 @@ ] -# https://github.com/huggingface/transformers/pull/34632 -def _safe_globals(): - # Starting from version 2.4 PyTorch introduces a check for the objects - # loaded with torch.load(weights_only=True). Starting from 2.6 - # weights_only=True becomes a default and requires allowlisting of objects - # being loaded. - # See: https://github.com/pytorch/pytorch/pull/137602 - # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals - # See: https://github.com/huggingface/accelerate/pull/3036 - if ( - version.parse(torch.__version__).release - < version.parse("2.6").release - ): - return contextlib.nullcontext() - - np_core = ( - np._core if version.parse(np.__version__) - >= version.parse("2.0.0") else np.core - ) - allowlist = [ - np_core.multiarray.scalar, type(np.dtype(np.float64)), np.dtype - ] - # Additional types that were allowed by huggingface transformers - # allowlist.extend([np_core.multiarray._reconstruct, np.ndarray, type(np.dtype(np.uint32))]) - return torch.serialization.safe_globals(allowlist) - - class Trainer(Configurable): @classmethod From 854b2f5457ca33ef6f92a5fc8f719b611b7584ce Mon Sep 17 00:00:00 2001 From: Michael Kuhlmann Date: Wed, 13 Aug 2025 10:23:33 +0200 Subject: [PATCH 2/2] Add weights_only option to checkpoint loading In some cases, loading checkpoints with safe_globals will not work. Example: np.core was renamed to np._core in numpy>2.0. Checkpoints that were saved with numpy<2.0 cannot be loaded in an environment with numpy>=2.0. In that case, the only way to load the checkpoint is to set weights_only=False. --- padertorch/base.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/padertorch/base.py b/padertorch/base.py index d2b2d714..dfb9d37e 100644 --- a/padertorch/base.py +++ b/padertorch/base.py @@ -82,6 +82,7 @@ def from_config_and_checkpoint( map_location='cpu', consider_mpi=False, + weights_only=True, ) -> 'Module': """Instantiate the module from given config and checkpoint. @@ -116,6 +117,7 @@ def from_config_and_checkpoint( in_checkpoint_path=in_checkpoint_path, map_location=map_location, consider_mpi=consider_mpi, + weights_only=weights_only, ) def load_checkpoint( @@ -125,6 +127,7 @@ def load_checkpoint( map_location='cpu', consider_mpi=False, + weights_only=True, ) -> 'Module': """Update the module parameters from the given checkpoint. @@ -158,10 +161,12 @@ def load_checkpoint( checkpoint = torch.load( io.BytesIO(checkpoint_path_content), map_location=map_location, + weights_only=weights_only, ) else: checkpoint = torch.load( checkpoint_path, map_location=map_location, + weights_only=weights_only, ) if in_checkpoint_path: @@ -184,6 +189,7 @@ def from_storage_dir( in_config_path: str = 'trainer.model', in_checkpoint_path: str = 'model', consider_mpi=False, + weights_only=True, ) -> 'Module': """Instantiate the module from a given storage directory. @@ -215,6 +221,7 @@ def from_storage_dir( in_config_path=in_config_path, in_checkpoint_path=in_checkpoint_path, consider_mpi=consider_mpi, + weights_only=weights_only, )