Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 52 additions & 13 deletions padertorch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
"""
import io
import abc
import contextlib
from packaging import version
from pathlib import Path

import numpy as np
Expand All @@ -23,6 +25,33 @@
]


# https://github.com/huggingface/transformers/pull/34632
def _safe_globals():
    """Return a context manager that allowlists NumPy types for ``torch.load``.

    Starting from version 2.4 PyTorch introduces a check for the objects
    loaded with ``torch.load(weights_only=True)``. Starting from 2.6
    ``weights_only=True`` becomes the default and requires allowlisting of
    the objects being loaded.

    The availability of ``torch.serialization.safe_globals`` is detected
    directly instead of comparing version numbers: callers may request
    ``weights_only=True`` explicitly, in which case the allowlist is needed
    on every torch build that offers the context manager, not only on those
    where ``weights_only=True`` is the default.

    See: https://github.com/pytorch/pytorch/pull/137602
    See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
    See: https://github.com/huggingface/accelerate/pull/3036

    Returns:
        ``torch.serialization.safe_globals(allowlist)`` if this torch build
        provides it, otherwise ``contextlib.nullcontext()``.
    """
    safe_globals = getattr(torch.serialization, "safe_globals", None)
    if safe_globals is None:
        # This torch build has no allowlist context manager, so there is
        # nothing to enter around torch.load.
        return contextlib.nullcontext()

    # NumPy 2.0 renamed the private ``np.core`` package to ``np._core``.
    np_core = (
        np._core if version.parse(np.__version__)
        >= version.parse("2.0.0") else np.core
    )
    allowlist = [
        np_core.multiarray.scalar, type(np.dtype(np.float64)), np.dtype,
    ]
    # NOTE: huggingface transformers additionally allowed
    # np_core.multiarray._reconstruct, np.ndarray and
    # type(np.dtype(np.uint32)); they are not allowlisted here.
    return safe_globals(allowlist)


class Module(nn.Module, Configurable, abc.ABC):
"""Abstract base class for configurable Modules."""

Expand Down Expand Up @@ -53,6 +82,7 @@ def from_config_and_checkpoint(

map_location='cpu',
consider_mpi=False,
weights_only=True,
) -> 'Module':
"""Instantiate the module from given config and checkpoint.

Expand Down Expand Up @@ -87,6 +117,7 @@ def from_config_and_checkpoint(
in_checkpoint_path=in_checkpoint_path,
map_location=map_location,
consider_mpi=consider_mpi,
weights_only=weights_only,
)

def load_checkpoint(
Expand All @@ -96,6 +127,7 @@ def load_checkpoint(

map_location='cpu',
consider_mpi=False,
weights_only=True,
) -> 'Module':
"""Update the module parameters from the given checkpoint.

Expand All @@ -117,20 +149,25 @@ def load_checkpoint(
assert checkpoint_path.is_file(), checkpoint_path

# Load weights
if consider_mpi:
import dlp_mpi
if dlp_mpi.IS_MASTER:
checkpoint_path_content = Path(checkpoint_path).read_bytes()
with _safe_globals():
if consider_mpi:
import dlp_mpi
if dlp_mpi.IS_MASTER:
checkpoint_path_content = Path(checkpoint_path).read_bytes()
else:
checkpoint_path_content = None
checkpoint_path_content = dlp_mpi.bcast(checkpoint_path_content)

checkpoint = torch.load(
io.BytesIO(checkpoint_path_content),
map_location=map_location,
weights_only=weights_only,
)
else:
checkpoint_path_content = None
checkpoint_path_content = dlp_mpi.bcast(checkpoint_path_content)

checkpoint = torch.load(
io.BytesIO(checkpoint_path_content),
map_location=map_location,
)
else:
checkpoint = torch.load(checkpoint_path, map_location=map_location)
checkpoint = torch.load(
checkpoint_path, map_location=map_location,
weights_only=weights_only,
)

if in_checkpoint_path:
for part in in_checkpoint_path.split('.'):
Expand All @@ -152,6 +189,7 @@ def from_storage_dir(
in_config_path: str = 'trainer.model',
in_checkpoint_path: str = 'model',
consider_mpi=False,
weights_only=True,
) -> 'Module':
"""Instantiate the module from a given storage directory.

Expand Down Expand Up @@ -183,6 +221,7 @@ def from_storage_dir(
in_config_path=in_config_path,
in_checkpoint_path=in_checkpoint_path,
consider_mpi=consider_mpi,
weights_only=weights_only,
)


Expand Down
29 changes: 1 addition & 28 deletions padertorch/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pathlib import Path
import functools
import collections
from packaging import version

import numpy as np
import torch
Expand All @@ -21,6 +20,7 @@

from paderbox.utils.nested import deflatten
import padertorch as pt
from padertorch.base import _safe_globals
from padertorch.configurable import Configurable
from padertorch.train.optimizer import Optimizer, Adam
from padertorch.train.runtime_tests import test_run
Expand All @@ -32,33 +32,6 @@
]


# https://github.com/huggingface/transformers/pull/34632
def _safe_globals():
    """Build the context manager that is entered around ``torch.load``.

    PyTorch 2.4 introduced a check for the objects that are loaded with
    ``torch.load(weights_only=True)``. From 2.6 on, ``weights_only=True``
    is the default and the objects being loaded have to be allowlisted.

    See: https://github.com/pytorch/pytorch/pull/137602
    See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
    See: https://github.com/huggingface/accelerate/pull/3036

    Returns:
        ``contextlib.nullcontext()`` when the installed torch release is
        older than 2.6, otherwise ``torch.serialization.safe_globals`` with
        the NumPy scalar and dtype types allowlisted.
    """
    installed_release = version.parse(torch.__version__).release
    required_release = version.parse("2.6").release

    if installed_release >= required_release:
        # Pick the NumPy core package matching the installed NumPy version.
        if version.parse(np.__version__) < version.parse("2.0.0"):
            core = np.core
        else:
            core = np._core

        # huggingface transformers additionally allowed
        # core.multiarray._reconstruct, np.ndarray and
        # type(np.dtype(np.uint32)); they are left out here.
        return torch.serialization.safe_globals([
            core.multiarray.scalar,
            type(np.dtype(np.float64)),
            np.dtype,
        ])

    return contextlib.nullcontext()


class Trainer(Configurable):

@classmethod
Expand Down
Loading