diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index 104fdab..062e5a7 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -634,6 +634,7 @@ def from_io( zero_hashes: bool = True, check_crypt_info: bool = False, long_shape_tensors: frozenset = frozenset(), + max_header_len: Optional[int] = None, ) -> Optional["_TensorHeaderDeserializer"]: # We read the entire header into memory rather than reading # it piecewise to avoid the overhead of many small reads, @@ -643,6 +644,11 @@ def from_io( header_len: int = cls.header_len_segment.unpack(header_len_bytes)[0] if header_len == 0: return None + if max_header_len is not None and header_len > max_header_len: + raise ValueError( + "Tensor header length exceeds metadata bounds:" + f" {header_len} > {max_header_len}" + ) buffer = bytearray(header_len) buffer[:offset] = header_len_bytes with memoryview(buffer) as mv: @@ -3070,6 +3076,9 @@ def _copy_thread( tensor_sizes_by_name: Dict[_TensorPath, int] = { t.name: t.deserialized_length for t in tensor_items } + metadata_by_offset: Dict[int, TensorEntry] = { + t.offset: t for t in unsafe_self._metadata.values() + } # then for each tensor in tensor_items tensors_read = 0 @@ -3077,11 +3086,21 @@ def _copy_thread( if halt: break + header_offset = file_.tell() + metadata_entry = metadata_by_offset.get(header_offset) + if metadata_entry is None: + raise ValueError( + "Unexpected tensor header offset:" + f" {header_offset}" + ) + header = _TensorHeaderDeserializer.from_io( file_, zero_hashes=True, check_crypt_info=unsafe_self._has_crypt_info, long_shape_tensors=unsafe_self._long_shape_tensors, + max_header_len=metadata_entry.data_offset + - metadata_entry.offset, ) if header is None: diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 6df018e..1ca08dc 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -10,6 +10,7 @@ import os import re import secrets +import struct import sys import tempfile import time @@ -345,6 +346,40 @@ def test_serialization(self): finally: os.unlink(serialized_model) + def test_oversized_tensor_header_is_rejected(self): + with temporary_file("wb+") as tensorized_file: + serializer = TensorSerializer(tensorized_file) + serializer.write_state_dict({"tensor": torch.zeros(1)}) + serializer.close() + + with TensorDeserializer( + tensorized_file.name, + device="cpu", + lazy_load=True, + num_readers=1, + ) as deserializer: + entry = deserializer._metadata[("tensor",)] + + with open(tensorized_file.name, "rb") as file: + data = bytearray(file.read()) + + legal_header_len = entry.data_offset - entry.offset + struct.pack_into( + "