diff --git a/padertorch/base.py b/padertorch/base.py index 4ef63fbb..dfb9d37e 100644 --- a/padertorch/base.py +++ b/padertorch/base.py @@ -6,6 +6,8 @@ """ import io import abc +import contextlib +from packaging import version from pathlib import Path import numpy as np @@ -23,6 +25,33 @@ ] +# https://github.com/huggingface/transformers/pull/34632 +def _safe_globals(): + # Starting from version 2.4 PyTorch introduces a check for the objects + # loaded with torch.load(weights_only=True). Starting from 2.6 + # weights_only=True becomes a default and requires allowlisting of objects + # being loaded. + # See: https://github.com/pytorch/pytorch/pull/137602 + # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals + # See: https://github.com/huggingface/accelerate/pull/3036 + if ( + version.parse(torch.__version__).release + < version.parse("2.6").release + ): + return contextlib.nullcontext() + + np_core = ( + np._core if version.parse(np.__version__) + >= version.parse("2.0.0") else np.core + ) + allowlist = [ + np_core.multiarray.scalar, type(np.dtype(np.float64)), np.dtype, + ] + # Additional types that were allowed by huggingface transformers + # allowlist.extend([np_core.multiarray._reconstruct, np.ndarray, type(np.dtype(np.uint32))]) + return torch.serialization.safe_globals(allowlist) + + class Module(nn.Module, Configurable, abc.ABC): """Abstract base class for configurable Modules.""" @@ -53,6 +82,7 @@ def from_config_and_checkpoint( map_location='cpu', consider_mpi=False, + weights_only=True, ) -> 'Module': """Instantiate the module from given config and checkpoint. @@ -87,6 +117,7 @@ def from_config_and_checkpoint( in_checkpoint_path=in_checkpoint_path, map_location=map_location, consider_mpi=consider_mpi, + weights_only=weights_only, ) def load_checkpoint( @@ -96,6 +127,7 @@ def load_checkpoint( map_location='cpu', consider_mpi=False, + weights_only=True, ) -> 'Module': """Update the module parameters from the given checkpoint. @@ -117,20 +149,25 @@ def load_checkpoint( assert checkpoint_path.is_file(), checkpoint_path # Load weights - if consider_mpi: - import dlp_mpi - if dlp_mpi.IS_MASTER: - checkpoint_path_content = Path(checkpoint_path).read_bytes() + with _safe_globals(): + if consider_mpi: + import dlp_mpi + if dlp_mpi.IS_MASTER: + checkpoint_path_content = Path(checkpoint_path).read_bytes() + else: + checkpoint_path_content = None + checkpoint_path_content = dlp_mpi.bcast(checkpoint_path_content) + + checkpoint = torch.load( + io.BytesIO(checkpoint_path_content), + map_location=map_location, + weights_only=weights_only, + ) else: - checkpoint_path_content = None - checkpoint_path_content = dlp_mpi.bcast(checkpoint_path_content) - - checkpoint = torch.load( - io.BytesIO(checkpoint_path_content), - map_location=map_location, - ) - else: - checkpoint = torch.load(checkpoint_path, map_location=map_location) + checkpoint = torch.load( + checkpoint_path, map_location=map_location, + weights_only=weights_only, + ) if in_checkpoint_path: for part in in_checkpoint_path.split('.'): @@ -152,6 +189,7 @@ def from_storage_dir( in_config_path: str = 'trainer.model', in_checkpoint_path: str = 'model', consider_mpi=False, + weights_only=True, ) -> 'Module': """Instantiate the module from a given storage directory. @@ -183,6 +221,7 @@ def from_storage_dir( in_config_path=in_config_path, in_checkpoint_path=in_checkpoint_path, consider_mpi=consider_mpi, + weights_only=weights_only, ) diff --git a/padertorch/train/trainer.py b/padertorch/train/trainer.py index 7def6056..7c49d1bc 100644 --- a/padertorch/train/trainer.py +++ b/padertorch/train/trainer.py @@ -11,7 +11,6 @@ from pathlib import Path import functools import collections -from packaging import version import numpy as np import torch @@ -21,6 +20,7 @@ from paderbox.utils.nested import deflatten import padertorch as pt +from padertorch.base import _safe_globals from padertorch.configurable import Configurable from padertorch.train.optimizer import Optimizer, Adam from padertorch.train.runtime_tests import test_run @@ -32,33 +32,6 @@ ] -# https://github.com/huggingface/transformers/pull/34632 -def _safe_globals(): - # Starting from version 2.4 PyTorch introduces a check for the objects - # loaded with torch.load(weights_only=True). Starting from 2.6 - # weights_only=True becomes a default and requires allowlisting of objects - # being loaded. - # See: https://github.com/pytorch/pytorch/pull/137602 - # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals - # See: https://github.com/huggingface/accelerate/pull/3036 - if ( - version.parse(torch.__version__).release - < version.parse("2.6").release - ): - return contextlib.nullcontext() - - np_core = ( - np._core if version.parse(np.__version__) - >= version.parse("2.0.0") else np.core - ) - allowlist = [ - np_core.multiarray.scalar, type(np.dtype(np.float64)), np.dtype - ] - # Additional types that were allowed by huggingface transformers - # allowlist.extend([np_core.multiarray._reconstruct, np.ndarray, type(np.dtype(np.uint32))]) - return torch.serialization.safe_globals(allowlist) - - class Trainer(Configurable): @classmethod