From 89abb2a76e3401c1b07ddb7ad982f22f1bafbe73 Mon Sep 17 00:00:00 2001 From: Sreesh Maheshwar Date: Wed, 8 Apr 2026 11:53:08 -0700 Subject: [PATCH 1/9] POC: Encryption read support --- pyiceberg/encryption/__init__.py | 17 +++ pyiceberg/encryption/ciphers.py | 116 ++++++++++++++++++ pyiceberg/encryption/io.py | 80 +++++++++++++ pyiceberg/encryption/key_metadata.py | 152 ++++++++++++++++++++++++ pyiceberg/encryption/kms.py | 114 ++++++++++++++++++ pyiceberg/encryption/manager.py | 162 ++++++++++++++++++++++++++ pyiceberg/io/pyarrow.py | 83 ++++++++++--- pyiceberg/manifest.py | 50 +++++++- pyiceberg/table/__init__.py | 64 +++++++++- pyiceberg/table/metadata.py | 24 ++++ pyiceberg/table/snapshots.py | 15 ++- tests/encryption/__init__.py | 16 +++ tests/encryption/test_ciphers.py | 162 ++++++++++++++++++++++++++ tests/encryption/test_io.py | 63 ++++++++++ tests/encryption/test_key_metadata.py | 80 +++++++++++++ tests/encryption/test_kms.py | 144 +++++++++++++++++++++++ tests/encryption/test_manager.py | 161 +++++++++++++++++++++++++ 17 files changed, 1480 insertions(+), 23 deletions(-) create mode 100644 pyiceberg/encryption/__init__.py create mode 100644 pyiceberg/encryption/ciphers.py create mode 100644 pyiceberg/encryption/io.py create mode 100644 pyiceberg/encryption/key_metadata.py create mode 100644 pyiceberg/encryption/kms.py create mode 100644 pyiceberg/encryption/manager.py create mode 100644 tests/encryption/__init__.py create mode 100644 tests/encryption/test_ciphers.py create mode 100644 tests/encryption/test_io.py create mode 100644 tests/encryption/test_key_metadata.py create mode 100644 tests/encryption/test_kms.py create mode 100644 tests/encryption/test_manager.py diff --git a/pyiceberg/encryption/__init__.py b/pyiceberg/encryption/__init__.py new file mode 100644 index 0000000000..9c86adbdce --- /dev/null +++ b/pyiceberg/encryption/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license 
agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Iceberg table encryption support.""" diff --git a/pyiceberg/encryption/ciphers.py b/pyiceberg/encryption/ciphers.py new file mode 100644 index 0000000000..ed023ce53e --- /dev/null +++ b/pyiceberg/encryption/ciphers.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""AES-GCM encryption/decryption primitives and AGS1 stream decryption.""" + +from __future__ import annotations + +import os +import struct + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +NONCE_LENGTH = 12 +GCM_TAG_LENGTH = 16 + + +def aes_gcm_encrypt(key: bytes, plaintext: bytes, aad: bytes | None = None) -> bytes: + """Encrypt using AES-GCM. Returns nonce || ciphertext || tag.""" + nonce = os.urandom(NONCE_LENGTH) + aesgcm = AESGCM(key) + ciphertext_with_tag = aesgcm.encrypt(nonce, plaintext, aad) + return nonce + ciphertext_with_tag + + +def aes_gcm_decrypt(key: bytes, ciphertext: bytes, aad: bytes | None = None) -> bytes: + """Decrypt AES-GCM data in format: nonce || ciphertext || tag.""" + if len(ciphertext) < NONCE_LENGTH + GCM_TAG_LENGTH: + raise ValueError(f"Ciphertext too short: {len(ciphertext)} bytes") + nonce = ciphertext[:NONCE_LENGTH] + encrypted_data = ciphertext[NONCE_LENGTH:] + aesgcm = AESGCM(key) + return aesgcm.decrypt(nonce, encrypted_data, aad) + + +# AGS1 stream constants +GCM_STREAM_MAGIC = b"AGS1" +GCM_STREAM_HEADER_LENGTH = 8 # 4 magic + 4 block size + + +def stream_block_aad(aad_prefix: bytes, block_index: int) -> bytes: + """Construct per-block AAD for AGS1 stream encryption. + + Format: aad_prefix || block_index (4 bytes, little-endian). + """ + index_bytes = struct.pack(" bytes: + """Decrypt an entire AGS1 stream and return the plaintext. 
+ + AGS1 format: + - Header: "AGS1" (4 bytes) + plain_block_size (4 bytes LE) + - Blocks: each block is nonce(12) + ciphertext(up to 1MB) + tag(16) + - Each block's AAD = aad_prefix + block_index (4 bytes LE) + + """ + if len(encrypted_data) < GCM_STREAM_HEADER_LENGTH: + raise ValueError(f"AGS1 stream too short: {len(encrypted_data)} bytes") + + magic = encrypted_data[:4] + if magic != GCM_STREAM_MAGIC: + raise ValueError(f"Invalid AGS1 magic: {magic!r}, expected {GCM_STREAM_MAGIC!r}") + + plain_block_size = struct.unpack_from("= cipher_block_size: + block_cipher_size = cipher_block_size + else: + block_cipher_size = remaining + + if block_cipher_size < NONCE_LENGTH + GCM_TAG_LENGTH: + raise ValueError( + f"Truncated AGS1 block at offset {offset}: {block_cipher_size} bytes (minimum {NONCE_LENGTH + GCM_TAG_LENGTH})" + ) + + block_data = stream_data[offset : offset + block_cipher_size] + nonce = block_data[:NONCE_LENGTH] + ciphertext_with_tag = block_data[NONCE_LENGTH:] + + aad = stream_block_aad(aad_prefix, block_index) + plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, aad) + result.extend(plaintext) + + offset += block_cipher_size + block_index += 1 + + return bytes(result) diff --git a/pyiceberg/encryption/io.py b/pyiceberg/encryption/io.py new file mode 100644 index 0000000000..310cd842a3 --- /dev/null +++ b/pyiceberg/encryption/io.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""InputFile implementation backed by in-memory bytes.""" + +from __future__ import annotations + +import io +from types import TracebackType + +from pyiceberg.io import InputFile, InputStream + + +class BytesInputStream(InputStream): + """InputStream implementation backed by a bytes buffer.""" + + def __init__(self, data: bytes) -> None: + self._buffer = io.BytesIO(data) + + def read(self, size: int = 0) -> bytes: + if size <= 0: + return self._buffer.read() + return self._buffer.read(size) + + def seek(self, offset: int, whence: int = 0) -> int: + return self._buffer.seek(offset, whence) + + def tell(self) -> int: + return self._buffer.tell() + + def close(self) -> None: + self._buffer.close() + + def __enter__(self) -> BytesInputStream: + """Enter the context manager.""" + return self + + def __exit__( + self, + exctype: type[BaseException] | None, + excinst: BaseException | None, + exctb: TracebackType | None, + ) -> None: + """Exit the context manager and close the stream.""" + self.close() + + +class BytesInputFile(InputFile): + """InputFile implementation backed by in-memory bytes. + + Used to wrap decrypted data so that it can be read by + AvroFile and other readers that expect an InputFile. 
+ """ + + def __init__(self, location: str, data: bytes) -> None: + super().__init__(location) + self._data = data + + def __len__(self) -> int: + """Return the length of the underlying data.""" + return len(self._data) + + def exists(self) -> bool: + return True + + def open(self, seekable: bool = True) -> InputStream: + return BytesInputStream(self._data) diff --git a/pyiceberg/encryption/key_metadata.py b/pyiceberg/encryption/key_metadata.py new file mode 100644 index 0000000000..4d22778d5c --- /dev/null +++ b/pyiceberg/encryption/key_metadata.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""StandardKeyMetadata Avro serialization. + +Wire format: ``0x01 version byte || Avro-encoded fields`` + +Avro schema: + - encryption_key: bytes (required) + - aad_prefix: union[null, bytes] (optional) + - file_length: union[null, long] (optional) +""" + +from __future__ import annotations + +from dataclasses import dataclass + +V1 = 0x01 + + +def _read_avro_long(data: bytes, offset: int) -> tuple[int, int]: + """Read a zigzag-encoded Avro long from data at offset. 
# Wire format: ``0x01 version byte || Avro-encoded fields``
#
# Avro fields, in order:
#   - encryption_key: bytes              (required)
#   - aad_prefix:     union[null, bytes] (optional)
#   - file_length:    union[null, long]  (optional)
V1 = 0x01

# Avro longs are signed 64-bit integers, so their zigzag varint encoding
# never needs more than ten bytes.
_INT64_MIN = -(1 << 63)
_INT64_MAX = (1 << 63) - 1
_MAX_VARINT_BYTES = 10


def _read_avro_long(data: bytes, offset: int) -> tuple[int, int]:
    """Read a zigzag-encoded Avro long from ``data`` starting at ``offset``.

    Returns:
        A ``(value, new_offset)`` tuple.

    Raises:
        ValueError: If the buffer ends in the middle of the varint, or the
            varint is longer than ten bytes / does not fit in 64 bits.
    """
    result = 0
    shift = 0
    while True:
        if offset >= len(data):
            raise ValueError("Unexpected end of Avro data reading long")
        # Without this bound a run of continuation bytes would be accumulated
        # into an arbitrarily large Python int that is not a valid Avro long.
        if shift >= 7 * _MAX_VARINT_BYTES:
            raise ValueError("Avro long varint is longer than 10 bytes")
        b = data[offset]
        offset += 1
        result |= (b & 0x7F) << shift
        if (b & 0x80) == 0:
            break
        shift += 7
    # The tenth byte can still carry bits above bit 63.
    if result >> 64:
        raise ValueError("Avro long does not fit in 64 bits")
    # Zigzag decode
    return (result >> 1) ^ -(result & 1), offset


def _read_avro_bytes(data: bytes, offset: int) -> tuple[bytes, int]:
    """Read length-prefixed Avro bytes from ``data`` starting at ``offset``.

    Returns:
        A ``(bytes_value, new_offset)`` tuple.

    Raises:
        ValueError: If the length prefix is negative or runs past the buffer.
    """
    length, offset = _read_avro_long(data, offset)
    if length < 0:
        raise ValueError(f"Negative Avro bytes length: {length}")
    end = offset + length
    if end > len(data):
        raise ValueError("Unexpected end of Avro data reading bytes")
    return data[offset:end], end


@dataclass(frozen=True)
class StandardKeyMetadata:
    """Standard key metadata for Iceberg table encryption.

    Contains the plaintext encryption key (DEK), AAD prefix, and optional file length.
    ``repr()`` is overridden so the key bytes never end up in logs or tracebacks.
    """

    encryption_key: bytes
    aad_prefix: bytes = b""
    file_length: int | None = None

    def __repr__(self) -> str:
        """Return a representation that shows lengths instead of key material."""
        return (
            f"{type(self).__name__}("
            f"encryption_key=<{len(self.encryption_key)} bytes>, "
            f"aad_prefix=<{len(self.aad_prefix)} bytes>, "
            f"file_length={self.file_length!r})"
        )

    @staticmethod
    def deserialize(data: bytes) -> "StandardKeyMetadata":
        """Deserialize from wire format: ``0x01 version || Avro-encoded fields``.

        Raises:
            ValueError: If the buffer is empty, has an unsupported version byte,
                is truncated, or contains an invalid union index / varint.
        """
        if not data:
            raise ValueError("Empty key metadata buffer")

        version = data[0]
        if version != V1:
            raise ValueError(f"Unsupported key metadata version: {version}")

        offset = 1

        # Read encryption_key (required bytes)
        encryption_key, offset = _read_avro_bytes(data, offset)

        # Read aad_prefix (optional: union[null, bytes])
        union_index, offset = _read_avro_long(data, offset)
        if union_index == 0:
            aad_prefix = b""
        elif union_index == 1:
            aad_prefix, offset = _read_avro_bytes(data, offset)
        else:
            raise ValueError(f"Invalid union index for aad_prefix: {union_index}")

        # Read file_length (optional: union[null, long]).
        # A buffer that simply stops after aad_prefix is accepted as "no file length".
        file_length = None
        if offset < len(data):
            union_index, offset = _read_avro_long(data, offset)
            if union_index == 0:
                file_length = None
            elif union_index == 1:
                file_length, offset = _read_avro_long(data, offset)
            else:
                raise ValueError(f"Invalid union index for file_length: {union_index}")

        return StandardKeyMetadata(
            encryption_key=encryption_key,
            aad_prefix=aad_prefix,
            file_length=file_length,
        )

    def serialize(self) -> bytes:
        """Serialize to wire format: ``0x01 version || Avro-encoded fields``.

        An empty ``aad_prefix`` is written as the null union branch.

        Raises:
            ValueError: If ``file_length`` does not fit in a signed 64-bit Avro long.
        """
        parts = [bytes([V1])]

        # encryption_key (required bytes)
        parts.append(_encode_avro_bytes(self.encryption_key))

        # aad_prefix (union[null, bytes])
        if self.aad_prefix:
            parts.append(_encode_avro_long(1))  # union index 1 = bytes
            parts.append(_encode_avro_bytes(self.aad_prefix))
        else:
            parts.append(_encode_avro_long(0))  # union index 0 = null

        # file_length (union[null, long])
        if self.file_length is not None:
            parts.append(_encode_avro_long(1))  # union index 1 = long
            parts.append(_encode_avro_long(self.file_length))
        else:
            parts.append(_encode_avro_long(0))  # union index 0 = null

        return b"".join(parts)


def _encode_avro_long(value: int) -> bytes:
    """Encode ``value`` as a zigzag-encoded Avro varint.

    Raises:
        ValueError: If ``value`` is outside the signed 64-bit range. The
            ``value >> 63`` zigzag step is only correct inside that range;
            larger magnitudes would be written as a different number.
    """
    if not _INT64_MIN <= value <= _INT64_MAX:
        raise ValueError(f"Value does not fit in a 64-bit Avro long: {value}")
    # Zigzag encode
    n = (value << 1) ^ (value >> 63)
    result = bytearray()
    while n & ~0x7F:
        result.append((n & 0x7F) | 0x80)
        n >>= 7
    result.append(n & 0x7F)
    return bytes(result)


def _encode_avro_bytes(data: bytes) -> bytes:
    """Encode ``data`` as Avro bytes: zigzag varint length followed by the payload."""
    return _encode_avro_long(len(data)) + data
PY_KMS_IMPL = "py-kms-impl"

# Properties of the form ``encryption.kms.key.<key-id> = <hex master key>``
# are picked up by InMemoryKms.initialize.
_IN_MEMORY_KEY_PREFIX = "encryption.kms.key."


class KeyManagementClient(ABC):
    """Abstract base class for key management operations."""

    @abstractmethod
    def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes:
        """Wrap (encrypt) a key using the master key identified by wrapping_key_id."""

    @abstractmethod
    def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes:
        """Unwrap (decrypt) a wrapped key using the master key identified by wrapping_key_id."""

    def initialize(self, properties: dict[str, str]) -> None:  # noqa: B027
        """Initialize the KMS client from catalog/table properties."""


class InMemoryKms(KeyManagementClient):
    """In-memory KMS for testing. NOT for production use."""

    def __init__(self, master_keys: dict[str, bytes] | None = None) -> None:
        # Copy so later mutation of the caller's dict does not affect this client.
        self._master_keys: dict[str, bytes] = dict(master_keys) if master_keys else {}

    def initialize(self, properties: dict[str, str]) -> None:
        """Load master keys from ``encryption.kms.key.<key-id>`` properties (hex-encoded).

        Raises:
            ValueError: If a matching property value is not valid hex. The message
                names the property but never echoes the value.
        """
        for key, value in properties.items():
            if not key.startswith(_IN_MEMORY_KEY_PREFIX):
                continue
            key_id = key[len(_IN_MEMORY_KEY_PREFIX) :]
            try:
                self._master_keys[key_id] = bytes.fromhex(value)
            except ValueError as e:
                raise ValueError(f"Invalid hex value for KMS master key property '{key}'") from e

    def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes:
        """Encrypt ``key`` with the named master key using AES-GCM (no AAD)."""
        master_key = self._master_keys.get(wrapping_key_id)
        if master_key is None:
            raise ValueError(f"Wrapping key not found: {wrapping_key_id}")
        return aes_gcm_encrypt(master_key, key, aad=None)

    def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes:
        """Decrypt ``wrapped_key`` with the named master key using AES-GCM (no AAD)."""
        master_key = self._master_keys.get(wrapping_key_id)
        if master_key is None:
            raise ValueError(f"Wrapping key not found: {wrapping_key_id}")
        return aes_gcm_decrypt(master_key, wrapped_key, aad=None)


def load_kms_client(properties: Properties) -> KeyManagementClient | None:
    """Load a KMS client from properties using py-kms-impl.

    The property 'py-kms-impl' should be a fully qualified Python class name
    (e.g., 'pyiceberg.encryption.kms.InMemoryKms'). The class must be a
    subclass of KeyManagementClient; it is instantiated without arguments and
    then given all properties through ``initialize``.

    Args:
        properties: Catalog and/or table properties.

    Returns:
        An initialized KeyManagementClient, or None if py-kms-impl is not set.

    Raises:
        ValueError: If the value is not of the form ``module.ClassName``, the
            module cannot be imported, the class does not exist, or it is not a
            KeyManagementClient subclass.
    """
    kms_impl = properties.get(PY_KMS_IMPL)
    if kms_impl is None:
        return None

    # Split on the last dot only: everything before it is the module path.
    module_name, _, class_name = kms_impl.rpartition(".")
    if not module_name:
        raise ValueError(f"py-kms-impl should be a full path (module.ClassName), got: {kms_impl}")

    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as e:
        raise ValueError(f"Could not import KMS module: {module_name}") from e

    kms_class = getattr(module, class_name, None)
    if kms_class is None:
        raise ValueError(f"KMS class {class_name} not found in module {module_name}")

    if not (isinstance(kms_class, type) and issubclass(kms_class, KeyManagementClient)):
        raise ValueError(f"{kms_impl} is not a subclass of KeyManagementClient")

    client = kms_class()
    client.initialize(dict(properties))
    return client
# Key hierarchy handled here:
#   - Master Key (in KMS) wraps the KEK
#   - KEK wraps the DEK (local AES-GCM, KEK timestamp as AAD)
#   - DEK encrypts manifest lists and manifests (AGS1 streams)
KEK_CREATED_AT_PROPERTY = "KEY_TIMESTAMP"


class EncryptedKey:
    """One encrypted-key entry taken from table metadata."""

    def __init__(
        self,
        key_id: str,
        encrypted_key_metadata: bytes,
        encrypted_by_id: str | None = None,
        properties: dict[str, str] | None = None,
    ) -> None:
        self.key_id = key_id
        self.encrypted_key_metadata = encrypted_key_metadata
        self.encrypted_by_id = encrypted_by_id
        self.properties = properties or {}

    def __repr__(self) -> str:
        """Describe the entry by ids and metadata length only (no key bytes)."""
        metadata_len = len(self.encrypted_key_metadata)
        return f"EncryptedKey(key_id={self.key_id!r}, encrypted_by_id={self.encrypted_by_id!r}, metadata_len={metadata_len})"


class EncryptionManager:
    """Two-layer envelope decryption for an Iceberg table.

    The KMS unwraps a KEK, the KEK unwraps KEK-wrapped key metadata (with the
    KEK timestamp as AAD), and the DEK inside the parsed StandardKeyMetadata
    decrypts AGS1 streams.
    """

    def __init__(
        self,
        kms_client: KeyManagementClient,
        encryption_keys: dict[str, EncryptedKey] | None = None,
    ) -> None:
        self._kms = kms_client
        self._encryption_keys = encryption_keys or {}
        # Plaintext KEKs by key id, so the KMS is asked at most once per KEK.
        self._kek_cache: dict[str, bytes] = {}

    def _unwrap_kek(self, kek: EncryptedKey) -> bytes:
        """Return the plaintext KEK, asking the KMS only on a cache miss."""
        cache = self._kek_cache
        if kek.key_id not in cache:
            if not kek.encrypted_by_id:
                raise ValueError(f"KEK '{kek.key_id}' has no encrypted_by_id")
            cache[kek.key_id] = self._kms.unwrap_key(kek.encrypted_key_metadata, kek.encrypted_by_id)
        return cache[kek.key_id]

    def _unwrap_dek(self, wrapped_dek: bytes, kek_key_id: str) -> bytes:
        """Decrypt ``wrapped_dek`` with the KEK registered under ``kek_key_id``.

        The KEK's ``KEY_TIMESTAMP`` property, when present and non-empty, is
        passed as AAD (UTF-8 encoded); otherwise no AAD is used.
        """
        kek = self._encryption_keys.get(kek_key_id)
        if kek is None:
            raise ValueError(f"KEK not found in encryption keys: {kek_key_id}")

        kek_bytes = self._unwrap_kek(kek)

        timestamp = kek.properties.get(KEK_CREATED_AT_PROPERTY)
        aad_bytes = None
        if timestamp:
            aad_bytes = timestamp.encode("utf-8")

        return aes_gcm_decrypt(kek_bytes, wrapped_dek, aad=aad_bytes)

    def unwrap_key_metadata(self, encrypted_key: EncryptedKey) -> bytes:
        """Unwrap KEK-wrapped key metadata.

        The entry's ``encrypted_by_id`` names the KEK that wrapped it.
        """
        wrapping_kek_id = encrypted_key.encrypted_by_id
        if not wrapping_kek_id:
            raise ValueError(f"EncryptedKey '{encrypted_key.key_id}' has no encrypted_by_id")
        return self._unwrap_dek(encrypted_key.encrypted_key_metadata, wrapping_kek_id)

    def _decrypt_with_key_metadata(self, encrypted_data: bytes, key_metadata_bytes: bytes) -> bytes:
        """Parse plaintext StandardKeyMetadata and decrypt an AGS1 stream with it."""
        key_metadata = StandardKeyMetadata.deserialize(key_metadata_bytes)
        return decrypt_ags1_stream(
            key=key_metadata.encryption_key,
            encrypted_data=encrypted_data,
            aad_prefix=key_metadata.aad_prefix,
        )

    def decrypt_manifest_list(self, encrypted_data: bytes, snapshot_key_id: str) -> bytes:
        """Decrypt an AGS1-encrypted manifest list.

        Looks up the snapshot's EncryptedKey, unwraps its key metadata with the
        KEK, then decrypts the stream with the DEK and AAD prefix found inside.
        """
        snapshot_key = self._encryption_keys.get(snapshot_key_id)
        if snapshot_key is None:
            raise ValueError(f"Snapshot key not found in encryption keys: {snapshot_key_id}")

        plaintext_key_metadata = self.unwrap_key_metadata(snapshot_key)
        return self._decrypt_with_key_metadata(encrypted_data, plaintext_key_metadata)

    def decrypt_manifest(self, encrypted_data: bytes, key_metadata_bytes: bytes) -> bytes:
        """Decrypt an AGS1-encrypted manifest file.

        ``key_metadata_bytes`` is parsed directly as StandardKeyMetadata; no KEK
        unwrapping is applied to it here.
        """
        return self._decrypt_with_key_metadata(encrypted_data, key_metadata_bytes)
FileDecryptionProperties from Iceberg key metadata. + + Requires a custom PyArrow build with pyarrow.parquet.encryption support. + """ + try: + import pyarrow.parquet.encryption as pe + except ImportError as e: + raise ImportError( + "Parquet Modular Encryption requires a PyArrow build with encryption support. " + "See PYARROW_ENCRYPTION_HANDOFF.md for build instructions." + ) from e + + from pyiceberg.encryption.key_metadata import StandardKeyMetadata + + key_metadata = StandardKeyMetadata.deserialize(key_metadata_bytes) + return pe.create_decryption_properties( + footer_key=key_metadata.encryption_key, + aad_prefix=key_metadata.aad_prefix if key_metadata.aad_prefix else None, + ) + + +def _read_deletes( + io: FileIO, data_file: DataFile, encryption_manager: EncryptionManager | None = None +) -> dict[str, pa.ChunkedArray]: if data_file.file_format == FileFormat.PARQUET: + arrow_format = _get_file_format( + data_file.file_format, dictionary_columns=("file_path",), pre_buffer=True, buffer_size=ONE_MEGABYTE + ) + + if data_file.key_metadata is not None and encryption_manager is not None: + decryption_properties = _get_decryption_properties(data_file.key_metadata) + scan_options = ds.ParquetFragmentScanOptions( + decryption_properties=decryption_properties, + pre_buffer=True, + buffer_size=ONE_MEGABYTE, + ) + arrow_format = ds.ParquetFileFormat(default_fragment_scan_options=scan_options) + with io.new_input(data_file.file_path).open() as fi: - delete_fragment = _get_file_format( - data_file.file_format, dictionary_columns=("file_path",), pre_buffer=True, buffer_size=ONE_MEGABYTE - ).make_fragment(fi) + delete_fragment = arrow_format.make_fragment(fi) table = ds.Scanner.from_fragment(fragment=delete_fragment).to_table() table = table.unify_dictionaries() return { @@ -1624,8 +1660,21 @@ def _task_to_record_batches( partition_spec: PartitionSpec | None = None, format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION, downcast_ns_timestamp_to_us: bool | None 
= None, + encryption_manager: EncryptionManager | None = None, ) -> Iterator[pa.RecordBatch]: arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) + + # For encrypted files, create a ParquetFileFormat with decryption properties + # so that make_fragment can read the encrypted metadata + if task.file.key_metadata is not None and encryption_manager is not None: + decryption_properties = _get_decryption_properties(task.file.key_metadata) + scan_options = ds.ParquetFragmentScanOptions( + decryption_properties=decryption_properties, + pre_buffer=True, + buffer_size=(ONE_MEGABYTE * 8), + ) + arrow_format = ds.ParquetFileFormat(default_fragment_scan_options=scan_options) + with io.new_input(task.file.file_path).open() as fin: fragment = arrow_format.make_fragment(fin) physical_schema = fragment.physical_schema @@ -1655,14 +1704,15 @@ def _task_to_record_batches( file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) - fragment_scanner = ds.Scanner.from_fragment( - fragment=fragment, - schema=physical_schema, + scanner_kwargs: dict[str, Any] = { + "fragment": fragment, + "schema": physical_schema, # This will push down the query to Arrow. 
# But in case there are positional deletes, we have to apply them first - filter=pyarrow_filter if not positional_deletes else None, - columns=[col.name for col in file_project_schema.columns], - ) + "filter": pyarrow_filter if not positional_deletes else None, + "columns": [col.name for col in file_project_schema.columns], + } + fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs) next_index = 0 batches = fragment_scanner.to_batches() @@ -1701,14 +1751,16 @@ def _task_to_record_batches( ) -def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[str, list[ChunkedArray]]: +def _read_all_delete_files( + io: FileIO, tasks: Iterable[FileScanTask], encryption_manager: EncryptionManager | None = None +) -> dict[str, list[ChunkedArray]]: deletes_per_file: dict[str, list[ChunkedArray]] = {} unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks])) if len(unique_deletes) > 0: executor = ExecutorFactory.get_or_create() deletes_per_files: Iterator[dict[str, ChunkedArray]] = executor.map( lambda args: _read_deletes(*args), - [(io, delete_file) for delete_file in unique_deletes], + [(io, delete_file, encryption_manager) for delete_file in unique_deletes], ) for delete in deletes_per_files: for file, arr in delete.items(): @@ -1728,6 +1780,7 @@ class ArrowScan: _case_sensitive: bool _limit: int | None _downcast_ns_timestamp_to_us: bool | None + _encryption_manager: EncryptionManager | None """Scan the Iceberg Table and create an Arrow construct. Attributes: @@ -1737,6 +1790,7 @@ class ArrowScan: _bound_row_filter: Schema bound row expression to filter the data with _case_sensitive: Case sensitivity when looking up column names _limit: Limit the number of records. + _encryption_manager: Optional encryption manager for decrypting data files. 
""" def __init__( @@ -1747,6 +1801,7 @@ def __init__( row_filter: BooleanExpression, case_sensitive: bool = True, limit: int | None = None, + encryption_manager: EncryptionManager | None = None, ) -> None: self._table_metadata = table_metadata self._io = io @@ -1755,6 +1810,7 @@ def __init__( self._case_sensitive = case_sensitive self._limit = limit self._downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) + self._encryption_manager = encryption_manager @property def _projected_field_ids(self) -> set[int]: @@ -1817,7 +1873,7 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record ResolveError: When a required field cannot be found in the file ValueError: When a field type in the file cannot be projected to the schema type """ - deletes_per_file = _read_all_delete_files(self._io, tasks) + deletes_per_file = _read_all_delete_files(self._io, tasks, self._encryption_manager) total_row_count = 0 executor = ExecutorFactory.get_or_create() @@ -1865,6 +1921,7 @@ def _record_batches_from_scan_tasks_and_deletes( self._table_metadata.specs().get(task.file.spec_id), self._table_metadata.format_version, self._downcast_ns_timestamp_to_us, + encryption_manager=self._encryption_manager, ) for batch in batches: if self._limit is not None: diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index 3811a9d894..d5b98c5b49 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -24,6 +24,7 @@ from enum import Enum from types import TracebackType from typing import ( + TYPE_CHECKING, Any, Literal, ) @@ -39,6 +40,9 @@ from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import Schema from pyiceberg.typedef import Record, TableVersion + +if TYPE_CHECKING: + from pyiceberg.encryption.manager import EncryptionManager from pyiceberg.types import ( BinaryType, BooleanType, @@ -858,18 +862,34 @@ def has_added_files(self) -> bool: def has_existing_files(self) -> bool: return self.existing_files_count 
is None or self.existing_files_count > 0 - def fetch_manifest_entry(self, io: FileIO, discard_deleted: bool = True) -> list[ManifestEntry]: + def fetch_manifest_entry( + self, + io: FileIO, + discard_deleted: bool = True, + encryption_manager: EncryptionManager | None = None, + ) -> list[ManifestEntry]: """ Read the manifest entries from the manifest file. Args: io: The FileIO to fetch the file. discard_deleted: Filter on live entries. + encryption_manager: Optional encryption manager for decrypting encrypted manifests. Returns: An Iterator of manifest entries. """ input_file = io.new_input(self.manifest_path) + + # If this manifest has key_metadata, it's AGS1-encrypted + if self.key_metadata is not None and encryption_manager is not None: + from pyiceberg.encryption.io import BytesInputFile + + with input_file.open() as f: + encrypted_data = f.read() + decrypted_data = encryption_manager.decrypt_manifest(encrypted_data, self.key_metadata) + input_file = BytesInputFile(self.manifest_path, decrypted_data) + with AvroFile[ManifestEntry]( input_file, MANIFEST_ENTRY_SCHEMAS[DEFAULT_READ_VERSION], @@ -900,7 +920,12 @@ def __hash__(self) -> int: _manifest_cache_lock = threading.RLock() -def _manifests(io: FileIO, manifest_list: str) -> tuple[ManifestFile, ...]: +def _manifests( + io: FileIO, + manifest_list: str, + encryption_manager: EncryptionManager | None = None, + snapshot_key_id: str | None = None, +) -> tuple[ManifestFile, ...]: """Read manifests from a manifest list, deduplicating ManifestFile objects via cache. Caches individual ManifestFile objects by manifest_path. This is memory-efficient @@ -920,12 +945,14 @@ def _manifests(io: FileIO, manifest_list: str) -> tuple[ManifestFile, ...]: Args: io: FileIO instance for reading the manifest list. manifest_list: Path to the manifest list file. + encryption_manager: Optional encryption manager for decrypting encrypted manifest lists. + snapshot_key_id: Optional key ID from snapshot for manifest list decryption. 
Returns: A tuple of ManifestFile objects. """ file = io.new_input(manifest_list) - manifest_files = list(read_manifest_list(file)) + manifest_files = list(read_manifest_list(file, encryption_manager=encryption_manager, snapshot_key_id=snapshot_key_id)) result = [] with _manifest_cache_lock: @@ -940,16 +967,31 @@ def _manifests(io: FileIO, manifest_list: str) -> tuple[ManifestFile, ...]: return tuple(result) -def read_manifest_list(input_file: InputFile) -> Iterator[ManifestFile]: +def read_manifest_list( + input_file: InputFile, + encryption_manager: EncryptionManager | None = None, + snapshot_key_id: str | None = None, +) -> Iterator[ManifestFile]: """ Read the manifests from the manifest list. Args: input_file: The input file where the stream can be read from. + encryption_manager: Optional encryption manager for decrypting encrypted manifest lists. + snapshot_key_id: Optional key ID from snapshot for manifest list decryption. Returns: An iterator of ManifestFiles that are part of the list. 
""" + # If we have encryption info, decrypt the manifest list first + if snapshot_key_id is not None and encryption_manager is not None: + from pyiceberg.encryption.io import BytesInputFile + + with input_file.open() as f: + encrypted_data = f.read() + decrypted_data = encryption_manager.decrypt_manifest_list(encrypted_data, snapshot_key_id) + input_file = BytesInputFile(input_file.location, decrypted_data) + with AvroFile[ManifestFile]( input_file, MANIFEST_LIST_FILE_SCHEMAS[DEFAULT_READ_VERSION], diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index b8d87143c9..90e7dcc1bc 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -107,6 +107,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.catalog.rest.scan_planning import RESTContentFile, RESTDeleteFile, RESTFileScanTask + from pyiceberg.encryption.manager import EncryptionManager ALWAYS_TRUE = AlwaysTrue() DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write" @@ -1891,6 +1892,7 @@ def _open_manifest( manifest: ManifestFile, partition_filter: Callable[[DataFile], bool], metrics_evaluator: Callable[[DataFile], bool], + encryption_manager: EncryptionManager | None = None, ) -> list[ManifestEntry]: """Open a manifest file and return matching manifest entries. 
@@ -1899,7 +1901,7 @@ def _open_manifest( """ return [ manifest_entry - for manifest_entry in manifest.fetch_manifest_entry(io, discard_deleted=True) + for manifest_entry in manifest.fetch_manifest_entry(io, discard_deleted=True, encryption_manager=encryption_manager) if partition_filter(manifest_entry.data_file) and metrics_evaluator(manifest_entry.data_file) ] @@ -1917,6 +1919,13 @@ def _min_sequence_number(manifests: list[ManifestFile]) -> int: class DataScan(TableScan): + @cached_property + def _merged_properties(self) -> Properties: + """Catalog properties merged with table properties (table wins).""" + if self.catalog is not None: + return {**self.catalog.properties, **self.table_metadata.properties} + return dict(self.table_metadata.properties) + def _build_partition_projection(self, spec_id: int) -> BooleanExpression: project = inclusive_projection(self.table_metadata.schema(), self.table_metadata.specs()[spec_id], self.case_sensitive) return project(self.row_filter) @@ -1988,6 +1997,38 @@ def _check_sequence_number(min_sequence_number: int, manifest: ManifestFile) -> and (manifest.sequence_number or INITIAL_SEQUENCE_NUMBER) >= min_sequence_number ) + @cached_property + def _encryption_manager(self) -> EncryptionManager | None: + """Create an EncryptionManager if the table has encryption configured.""" + if self.table_metadata.format_version < 3: + return None + + encryption_keys = self.table_metadata.encryption_keys + if not encryption_keys: + return None + + from pyiceberg.encryption.kms import load_kms_client + from pyiceberg.encryption.manager import EncryptedKey + from pyiceberg.encryption.manager import EncryptionManager as EncryptionManagerClass + + kms_client = load_kms_client(self._merged_properties) + if kms_client is None: + return None + + enc_keys_map: dict[str, EncryptedKey] = {} + for ek in encryption_keys: + enc_keys_map[ek.key_id] = EncryptedKey( + key_id=ek.key_id, + encrypted_key_metadata=ek.encrypted_key_metadata_bytes, + 
encrypted_by_id=ek.encrypted_by_id, + properties=dict(ek.properties), + ) + + return EncryptionManagerClass( + kms_client=kms_client, + encryption_keys=enc_keys_map, + ) + def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: """Filter and return manifest entries based on partition and metrics evaluators. @@ -1998,6 +2039,8 @@ def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: if not snapshot: return iter([]) + encryption_manager = self._encryption_manager + # step 1: filter manifests using partition summaries # the filter depends on the partition spec used to write the manifest file, so create a cache of filters for each spec id @@ -2005,7 +2048,7 @@ def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: manifests = [ manifest_file - for manifest_file in snapshot.manifests(self.io) + for manifest_file in snapshot.manifests(self.io, encryption_manager=encryption_manager) if manifest_evaluators[manifest_file.partition_spec_id](manifest_file) ] @@ -2026,6 +2069,7 @@ def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: manifest, partition_evaluators[manifest.partition_spec_id], self._build_metrics_evaluator(), + encryption_manager, ) for manifest in manifests if self._check_sequence_number(min_sequence_number, manifest) @@ -2114,7 +2158,13 @@ def to_arrow(self) -> pa.Table: from pyiceberg.io.pyarrow import ArrowScan return ArrowScan( - self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit + self.table_metadata, + self.io, + self.projection(), + self.row_filter, + self.case_sensitive, + self.limit, + encryption_manager=self._encryption_manager, ).to_table(self.plan_files()) def to_arrow_batch_reader(self) -> pa.RecordBatchReader: @@ -2134,7 +2184,13 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( - self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit + 
self.table_metadata, + self.io, + self.projection(), + self.row_filter, + self.case_sensitive, + self.limit, + encryption_manager=self._encryption_manager, ).to_record_batches(self.plan_files()) return pa.RecordBatchReader.from_batches( diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 26b6e3d3ad..ed07177af4 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import base64 import datetime import uuid from collections.abc import Iterable @@ -48,6 +49,25 @@ from pyiceberg.utils.config import Config from pyiceberg.utils.datetime import datetime_to_millis + +class EncryptedKeyModel(IcebergBaseModel): + """An encrypted key entry in table metadata. + + Matches the EncryptedKey schema in the REST API spec. + """ + + key_id: str = Field(alias="key-id") + encrypted_key_metadata: str = Field(alias="encrypted-key-metadata") + """Base64-encoded encrypted key metadata bytes.""" + encrypted_by_id: str | None = Field(alias="encrypted-by-id", default=None) + properties: dict[str, str] = Field(default_factory=dict) + + @property + def encrypted_key_metadata_bytes(self) -> bytes: + """Decode the base64-encoded encrypted key metadata.""" + return base64.b64decode(self.encrypted_key_metadata) + + CURRENT_SNAPSHOT_ID = "current-snapshot-id" CURRENT_SCHEMA_ID = "current-schema-id" SCHEMAS = "schemas" @@ -574,6 +594,10 @@ def construct_refs(self) -> TableMetadata: next_row_id: int | None = Field(alias="next-row-id", default=None) """A long higher than all assigned row IDs; the next snapshot's `first-row-id`.""" + encryption_keys: list[EncryptedKeyModel] = Field(alias="encryption-keys", default_factory=list) + """Encryption key entries for the two-layer envelope encryption scheme. 
+ Only valid for format version 3 and higher.""" + def model_dump_json(self, exclude_none: bool = True, exclude: Any | None = None, by_alias: bool = True, **kwargs: Any) -> str: raise NotImplementedError("Writing V3 is not yet supported, see: https://github.com/apache/iceberg-python/issues/1551") diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index 7e4c6eb1ec..52bbc54852 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -31,6 +31,7 @@ from pyiceberg.schema import Schema if TYPE_CHECKING: + from pyiceberg.encryption.manager import EncryptionManager from pyiceberg.table.metadata import TableMetadata from pyiceberg.typedef import IcebergBaseModel @@ -252,6 +253,9 @@ class Snapshot(IcebergBaseModel): added_rows: int | None = Field( alias="added-rows", default=None, description="The upper bound of the number of rows with assigned row IDs" ) + key_id: str | None = Field( + alias="key-id", default=None, description="ID of the encryption key used to encrypt this snapshot's manifest list" + ) def __str__(self) -> str: """Return the string representation of the Snapshot class.""" @@ -277,9 +281,16 @@ def __repr__(self) -> str: filtered_fields = [field for field in fields if field is not None] return f"Snapshot({', '.join(filtered_fields)})" - def manifests(self, io: FileIO) -> list[ManifestFile]: + def manifests(self, io: FileIO, encryption_manager: EncryptionManager | None = None) -> list[ManifestFile]: """Return the manifests for the given snapshot.""" - return list(_manifests(io, self.manifest_list)) + return list( + _manifests( + io, + self.manifest_list, + encryption_manager=encryption_manager, + snapshot_key_id=self.key_id, + ) + ) class MetadataLogEntry(IcebergBaseModel): diff --git a/tests/encryption/__init__.py b/tests/encryption/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/encryption/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) 
under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/encryption/test_ciphers.py b/tests/encryption/test_ciphers.py new file mode 100644 index 0000000000..ccb7d058d5 --- /dev/null +++ b/tests/encryption/test_ciphers.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import os +import struct + +import pytest + +from pyiceberg.encryption.ciphers import ( + GCM_STREAM_MAGIC, + GCM_TAG_LENGTH, + NONCE_LENGTH, + aes_gcm_decrypt, + aes_gcm_encrypt, + decrypt_ags1_stream, + stream_block_aad, +) + + +class TestAesGcm: + def test_roundtrip(self) -> None: + key = os.urandom(16) + plaintext = b"hello, encryption" + ciphertext = aes_gcm_encrypt(key, plaintext) + assert aes_gcm_decrypt(key, ciphertext) == plaintext + + def test_roundtrip_with_aad(self) -> None: + key = os.urandom(16) + plaintext = b"hello, encryption" + aad = b"additional-data" + ciphertext = aes_gcm_encrypt(key, plaintext, aad=aad) + assert aes_gcm_decrypt(key, ciphertext, aad=aad) == plaintext + + def test_wrong_key_fails(self) -> None: + from cryptography.exceptions import InvalidTag + + key = os.urandom(16) + wrong_key = os.urandom(16) + ciphertext = aes_gcm_encrypt(key, b"secret") + with pytest.raises(InvalidTag): + aes_gcm_decrypt(wrong_key, ciphertext) + + def test_wrong_aad_fails(self) -> None: + from cryptography.exceptions import InvalidTag + + key = os.urandom(16) + ciphertext = aes_gcm_encrypt(key, b"secret", aad=b"correct") + with pytest.raises(InvalidTag): + aes_gcm_decrypt(key, ciphertext, aad=b"wrong") + + def test_wire_format(self) -> None: + """Verify the wire format: nonce(12) || ciphertext || tag(16).""" + key = os.urandom(16) + plaintext = b"test" + ciphertext = aes_gcm_encrypt(key, plaintext) + # Minimum size: nonce + tag + at least len(plaintext) of ciphertext + assert len(ciphertext) == NONCE_LENGTH + len(plaintext) + GCM_TAG_LENGTH + + def test_ciphertext_too_short(self) -> None: + key = os.urandom(16) + with pytest.raises(ValueError, match="Ciphertext too short"): + aes_gcm_decrypt(key, b"short") + + +class TestStreamBlockAad: + def test_with_prefix(self) -> None: + aad = stream_block_aad(b"prefix", 0) + assert aad == b"prefix" + struct.pack(" None: + aad = stream_block_aad(b"prefix", 42) + assert aad == b"prefix" + struct.pack(" None: + aad = 
stream_block_aad(b"", 7) + assert aad == struct.pack(" bytes: + """Build an AGS1 encrypted stream for testing.""" + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + + aesgcm = AESGCM(key) + header = GCM_STREAM_MAGIC + struct.pack(" None: + key = os.urandom(16) + plaintext = b"hello AGS1 stream" + aad_prefix = b"test-aad" + encrypted = _encrypt_ags1_stream(key, plaintext, aad_prefix) + assert decrypt_ags1_stream(key, encrypted, aad_prefix) == plaintext + + def test_roundtrip_multi_block(self) -> None: + """Test with a small block size to force multiple blocks.""" + key = os.urandom(16) + plaintext = b"A" * 200 + aad_prefix = b"multi" + # Use a 64-byte block size to get multiple blocks + encrypted = _encrypt_ags1_stream(key, plaintext, aad_prefix, plain_block_size=64) + assert decrypt_ags1_stream(key, encrypted, aad_prefix) == plaintext + + def test_roundtrip_empty_payload(self) -> None: + key = os.urandom(16) + header = GCM_STREAM_MAGIC + struct.pack(" None: + data = b"XXXX" + struct.pack(" None: + with pytest.raises(ValueError, match="AGS1 stream too short"): + decrypt_ags1_stream(os.urandom(16), b"AGS1", b"") + + def test_custom_block_size(self) -> None: + """Verify the block size from the header is respected, not hardcoded.""" + key = os.urandom(16) + plaintext = b"B" * 300 + aad_prefix = b"custom" + # Encrypt with a 100-byte block size + encrypted = _encrypt_ags1_stream(key, plaintext, aad_prefix, plain_block_size=100) + # Verify the header contains 100 + assert struct.unpack_from(" None: + key = os.urandom(16) + # Header + a few bytes that are too short for even nonce+tag + data = GCM_STREAM_MAGIC + struct.pack(" None: + stream = BytesInputStream(b"hello") + assert stream.read() == b"hello" + + def test_read_partial(self) -> None: + stream = BytesInputStream(b"hello world") + assert stream.read(5) == b"hello" + assert stream.read(6) == b" world" + + def test_seek_and_tell(self) -> None: + stream = BytesInputStream(b"abcdef") + assert 
stream.tell() == 0 + stream.seek(3) + assert stream.tell() == 3 + assert stream.read(2) == b"de" + + def test_context_manager(self) -> None: + with BytesInputStream(b"data") as stream: + assert stream.read() == b"data" + + def test_implements_input_stream_protocol(self) -> None: + stream = BytesInputStream(b"test") + assert isinstance(stream, InputStream) + + +class TestBytesInputFile: + def test_len(self) -> None: + f = BytesInputFile("file://test", b"hello") + assert len(f) == 5 + + def test_exists(self) -> None: + f = BytesInputFile("file://test", b"") + assert f.exists() is True + + def test_open_and_read(self) -> None: + f = BytesInputFile("file://test", b"content") + with f.open() as stream: + assert stream.read() == b"content" + + def test_location(self) -> None: + f = BytesInputFile("s3://bucket/path", b"data") + assert f.location == "s3://bucket/path" diff --git a/tests/encryption/test_key_metadata.py b/tests/encryption/test_key_metadata.py new file mode 100644 index 0000000000..e9f05fc078 --- /dev/null +++ b/tests/encryption/test_key_metadata.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import os + +import pytest + +from pyiceberg.encryption.key_metadata import StandardKeyMetadata + + +class TestStandardKeyMetadata: + def test_roundtrip_key_only(self) -> None: + key = os.urandom(16) + original = StandardKeyMetadata(encryption_key=key) + serialized = original.serialize() + restored = StandardKeyMetadata.deserialize(serialized) + assert restored.encryption_key == key + assert restored.aad_prefix == b"" + assert restored.file_length is None + + def test_roundtrip_with_aad_prefix(self) -> None: + key = os.urandom(16) + aad = os.urandom(8) + original = StandardKeyMetadata(encryption_key=key, aad_prefix=aad) + serialized = original.serialize() + restored = StandardKeyMetadata.deserialize(serialized) + assert restored.encryption_key == key + assert restored.aad_prefix == aad + assert restored.file_length is None + + def test_roundtrip_all_fields(self) -> None: + key = os.urandom(32) + aad = os.urandom(16) + original = StandardKeyMetadata(encryption_key=key, aad_prefix=aad, file_length=12345) + serialized = original.serialize() + restored = StandardKeyMetadata.deserialize(serialized) + assert restored.encryption_key == key + assert restored.aad_prefix == aad + assert restored.file_length == 12345 + + def test_version_byte(self) -> None: + """First byte should always be 0x01.""" + key = os.urandom(16) + serialized = StandardKeyMetadata(encryption_key=key).serialize() + assert serialized[0] == 0x01 + + def test_deserialize_empty(self) -> None: + with pytest.raises(ValueError, match="Empty key metadata"): + StandardKeyMetadata.deserialize(b"") + + def test_deserialize_wrong_version(self) -> None: + with pytest.raises(ValueError, match="Unsupported key metadata version"): + StandardKeyMetadata.deserialize(b"\x02\x00") + + def test_frozen(self) -> None: + """StandardKeyMetadata is a frozen dataclass.""" + skm = StandardKeyMetadata(encryption_key=b"key") + with pytest.raises(AttributeError): + skm.encryption_key = b"other" # type: ignore[misc] + + def 
test_roundtrip_large_file_length(self) -> None: + """Zigzag encoding should handle large values correctly.""" + key = os.urandom(16) + original = StandardKeyMetadata(encryption_key=key, file_length=2**40) + serialized = original.serialize() + restored = StandardKeyMetadata.deserialize(serialized) + assert restored.file_length == 2**40 diff --git a/tests/encryption/test_kms.py b/tests/encryption/test_kms.py new file mode 100644 index 0000000000..d445f27201 --- /dev/null +++ b/tests/encryption/test_kms.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import os + +import pytest + +from pyiceberg.encryption.kms import InMemoryKms, KeyManagementClient, load_kms_client + + +class TestInMemoryKms: + def test_wrap_unwrap_roundtrip(self) -> None: + master_key = os.urandom(16) + kms = InMemoryKms(master_keys={"keyA": master_key}) + plaintext_key = os.urandom(16) + wrapped = kms.wrap_key(plaintext_key, "keyA") + unwrapped = kms.unwrap_key(wrapped, "keyA") + assert unwrapped == plaintext_key + + def test_unknown_wrapping_key(self) -> None: + kms = InMemoryKms() + with pytest.raises(ValueError, match="Wrapping key not found"): + kms.wrap_key(b"key", "nonexistent") + + def test_unknown_unwrapping_key(self) -> None: + kms = InMemoryKms() + with pytest.raises(ValueError, match="Wrapping key not found"): + kms.unwrap_key(b"wrapped", "nonexistent") + + def test_initialize_from_properties(self) -> None: + kms = InMemoryKms() + key_hex = os.urandom(16).hex() + kms.initialize({"encryption.kms.key.testKey": key_hex}) + # Should be able to wrap/unwrap with the initialized key + wrapped = kms.wrap_key(b"secret", "testKey") + assert kms.unwrap_key(wrapped, "testKey") == b"secret" + + def test_initialize_ignores_unrelated_properties(self) -> None: + kms = InMemoryKms() + kms.initialize({"some.other.prop": "value", "encryption.kms.key.k1": os.urandom(16).hex()}) + with pytest.raises(ValueError, match="Wrapping key not found"): + kms.wrap_key(b"key", "nonexistent") + + def test_wrap_unwrap_with_standard_test_keys(self) -> None: + """Wrap/unwrap with the standard Iceberg test master keys.""" + kms = InMemoryKms( + master_keys={ + "keyA": b"0123456789012345", + "keyB": b"1123456789012345", + } + ) + plaintext = os.urandom(16) + wrapped_a = kms.wrap_key(plaintext, "keyA") + assert kms.unwrap_key(wrapped_a, "keyA") == plaintext + wrapped_b = kms.wrap_key(plaintext, "keyB") + assert kms.unwrap_key(wrapped_b, "keyB") == plaintext + + def test_wrong_master_key_fails_unwrap(self) -> None: + from cryptography.exceptions import InvalidTag + + 
kms = InMemoryKms( + master_keys={ + "keyA": os.urandom(16), + "keyB": os.urandom(16), + } + ) + wrapped = kms.wrap_key(b"secret", "keyA") + with pytest.raises(InvalidTag): + kms.unwrap_key(wrapped, "keyB") + + +class TestLoadKmsClient: + def test_returns_none_when_not_configured(self) -> None: + assert load_kms_client({}) is None + + def test_loads_in_memory_kms(self) -> None: + client = load_kms_client( + { + "py-kms-impl": "pyiceberg.encryption.kms.InMemoryKms", + "encryption.kms.key.myKey": os.urandom(16).hex(), + } + ) + assert client is not None + assert isinstance(client, InMemoryKms) + # Should be initialized — the key should be usable + wrapped = client.wrap_key(b"data", "myKey") + assert client.unwrap_key(wrapped, "myKey") == b"data" + + def test_invalid_short_path(self) -> None: + with pytest.raises(ValueError, match="full path"): + load_kms_client({"py-kms-impl": "InMemoryKms"}) + + def test_nonexistent_module(self) -> None: + with pytest.raises(ValueError, match="Could not import"): + load_kms_client({"py-kms-impl": "nonexistent.module.Kms"}) + + def test_nonexistent_class(self) -> None: + with pytest.raises(ValueError, match="not found in module"): + load_kms_client({"py-kms-impl": "pyiceberg.encryption.kms.NonexistentClass"}) + + def test_not_a_subclass(self) -> None: + with pytest.raises(ValueError, match="not a subclass"): + # AESGCM is a real class but not a KeyManagementClient + load_kms_client({"py-kms-impl": "cryptography.hazmat.primitives.ciphers.aead.AESGCM"}) + + def test_custom_kms_impl(self) -> None: + """Verify that a custom KMS implementation can be loaded by module path.""" + + class _TestKms(KeyManagementClient): + initialized_with: dict[str, str] = {} + + def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes: + return key + + def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes: + return wrapped_key + + def initialize(self, properties: dict[str, str]) -> None: + _TestKms.initialized_with = properties + + # 
Register in the module namespace so importlib can find it + import pyiceberg.encryption.kms as kms_module + + kms_module._TestKms = _TestKms # type: ignore[attr-defined] + try: + client = load_kms_client({"py-kms-impl": "pyiceberg.encryption.kms._TestKms", "foo": "bar"}) + assert client is not None + assert isinstance(client, _TestKms) + assert _TestKms.initialized_with.get("foo") == "bar" + finally: + delattr(kms_module, "_TestKms") diff --git a/tests/encryption/test_manager.py b/tests/encryption/test_manager.py new file mode 100644 index 0000000000..fed30cca84 --- /dev/null +++ b/tests/encryption/test_manager.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import os +import struct + +import pytest + +from pyiceberg.encryption.ciphers import ( + GCM_STREAM_MAGIC, + NONCE_LENGTH, + aes_gcm_encrypt, + stream_block_aad, +) +from pyiceberg.encryption.key_metadata import StandardKeyMetadata +from pyiceberg.encryption.kms import InMemoryKms +from pyiceberg.encryption.manager import EncryptedKey, EncryptionManager + + +def _make_ags1_stream(key: bytes, plaintext: bytes, aad_prefix: bytes, plain_block_size: int = 1024 * 1024) -> bytes: + """Build an AGS1 encrypted stream for testing.""" + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + + aesgcm = AESGCM(key) + header = GCM_STREAM_MAGIC + struct.pack(" tuple[EncryptionManager, bytes, bytes, str]: + """Build an EncryptionManager with test keys, mimicking the REST catalog flow. + + Returns (manager, dek, aad_prefix, manifest_list_key_id). + """ + master_key = b"0123456789012345" # 16 bytes, standard Iceberg test key "keyA" + kms = InMemoryKms(master_keys={"keyA": master_key}) + + # Create a KEK (simulating what the REST catalog would provide) + kek_bytes = os.urandom(16) + kek_wrapped = kms.wrap_key(kek_bytes, "keyA") + kek_timestamp = "1234567890" + + # Create a DEK for the manifest list + dek = os.urandom(16) + aad_prefix = os.urandom(8) + key_metadata = StandardKeyMetadata(encryption_key=dek, aad_prefix=aad_prefix) + key_metadata_bytes = key_metadata.serialize() + + # Wrap the DEK key metadata with the KEK (using timestamp as AAD) + wrapped_dek = aes_gcm_encrypt(kek_bytes, key_metadata_bytes, aad=kek_timestamp.encode("utf-8")) + + encryption_keys = { + "kek-1": EncryptedKey( + key_id="kek-1", + encrypted_key_metadata=kek_wrapped, + encrypted_by_id="keyA", + properties={"KEY_TIMESTAMP": kek_timestamp}, + ), + "mlk-1": EncryptedKey( + key_id="mlk-1", + encrypted_key_metadata=wrapped_dek, + encrypted_by_id="kek-1", + ), + } + + manager = EncryptionManager(kms_client=kms, encryption_keys=encryption_keys) + return manager, dek, aad_prefix, "mlk-1" + + +class 
TestEncryptionManager: + def test_decrypt_manifest_list(self) -> None: + manager, dek, aad_prefix, mlk_id = _build_test_encryption_manager() + plaintext = b"manifest list content here" + encrypted_stream = _make_ags1_stream(dek, plaintext, aad_prefix) + result = manager.decrypt_manifest_list(encrypted_stream, mlk_id) + assert result == plaintext + + def test_decrypt_manifest(self) -> None: + dek = os.urandom(16) + aad_prefix = os.urandom(8) + plaintext = b"manifest content here" + encrypted_stream = _make_ags1_stream(dek, plaintext, aad_prefix) + key_metadata = StandardKeyMetadata(encryption_key=dek, aad_prefix=aad_prefix) + + # The manager only needs a KMS for KEK unwrapping; manifest decryption + # uses the plaintext key metadata directly (from inside the encrypted manifest list) + kms = InMemoryKms() + manager = EncryptionManager(kms_client=kms) + result = manager.decrypt_manifest(encrypted_stream, key_metadata.serialize()) + assert result == plaintext + + def test_kek_caching(self) -> None: + """KEK should be unwrapped once and cached.""" + manager, dek, aad_prefix, mlk_id = _build_test_encryption_manager() + plaintext = b"test" + encrypted = _make_ags1_stream(dek, plaintext, aad_prefix) + + # Decrypt twice + manager.decrypt_manifest_list(encrypted, mlk_id) + manager.decrypt_manifest_list(encrypted, mlk_id) + + # KEK should be cached + assert "kek-1" in manager._kek_cache + + def test_missing_snapshot_key(self) -> None: + kms = InMemoryKms() + manager = EncryptionManager(kms_client=kms) + with pytest.raises(ValueError, match="Snapshot key not found"): + manager.decrypt_manifest_list(b"data", "nonexistent-key") + + def test_missing_kek(self) -> None: + kms = InMemoryKms() + encryption_keys = { + "mlk-1": EncryptedKey( + key_id="mlk-1", + encrypted_key_metadata=b"wrapped", + encrypted_by_id="missing-kek", + ), + } + manager = EncryptionManager(kms_client=kms, encryption_keys=encryption_keys) + with pytest.raises(ValueError, match="KEK not found"): + 
manager.decrypt_manifest_list(b"data", "mlk-1") + + def test_kek_without_encrypted_by_id(self) -> None: + kms = InMemoryKms(master_keys={"keyA": os.urandom(16)}) + encryption_keys = { + "kek-1": EncryptedKey(key_id="kek-1", encrypted_key_metadata=b"data"), + "mlk-1": EncryptedKey( + key_id="mlk-1", + encrypted_key_metadata=b"wrapped", + encrypted_by_id="kek-1", + ), + } + manager = EncryptionManager(kms_client=kms, encryption_keys=encryption_keys) + with pytest.raises(ValueError, match="has no encrypted_by_id"): + manager.decrypt_manifest_list(b"data", "mlk-1") From 8427e54fae75023a4a85567b4874b03ff52ca961 Mon Sep 17 00:00:00 2001 From: Sreesh Maheshwar Date: Fri, 29 May 2026 02:18:20 +0100 Subject: [PATCH 2/9] infra: hive encryption integration test --- Makefile | 20 +++-- dev/provision.py | 10 +++ dev/spark/Dockerfile | 12 ++- dev/spark/spark-defaults.conf | 5 ++ tests/integration/test_encryption.py | 112 +++++++++++++++++++++++++++ 5 files changed, 148 insertions(+), 11 deletions(-) create mode 100644 tests/integration/test_encryption.py diff --git a/Makefile b/Makefile index 4fe761192c..2474ed0e83 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ # under the License. .PHONY: help install install-uv check-license lint \ test test-integration test-integration-setup test-integration-exec test-integration-cleanup test-integration-rebuild \ - test-s3 test-adls test-gcs test-coverage coverage-report test test-notebook\ + test-s3 test-adls test-gcs test-coverage coverage-report \ docs-serve docs-build notebook notebook-infra \ clean @@ -38,10 +38,12 @@ else PYTHON_ARG = endif +# --no-sync so that overlays applied after `make install` (e.g. install-pyarrow-nightly for +# the encryption integration test) aren't reverted by uv re-syncing the lockfile on `uv run`. 
ifeq ($(COVERAGE),1) - TEST_RUNNER = uv run $(PYTHON_ARG) python -m coverage run --parallel-mode --source=pyiceberg -m + TEST_RUNNER = uv run --no-sync $(PYTHON_ARG) python -m coverage run --parallel-mode --source=pyiceberg -m else - TEST_RUNNER = uv run $(PYTHON_ARG) python -m + TEST_RUNNER = uv run --no-sync $(PYTHON_ARG) python -m endif ifeq ($(KEEP_COMPOSE),1) @@ -108,12 +110,19 @@ test: ## Run all unit tests (excluding integration) test-integration: test-integration-setup test-integration-exec test-integration-cleanup ## Run integration tests -test-integration-setup: install ## Start Docker services for integration tests +test-integration-setup: install install-pyarrow-nightly ## Start Docker services for integration tests docker compose -f dev/docker-compose-integration.yml kill docker compose -f dev/docker-compose-integration.yml rm -f docker compose -f dev/docker-compose-integration.yml up -d --build --wait uv run $(PYTHON_ARG) python dev/provision.py +# Parquet Modular Encryption decryption (tests/integration/test_encryption.py) needs the +# pyarrow.parquet.encryption.create_decryption_properties API from apache/arrow#49667. That +# lands in pyarrow 25, which hasn't been released — pull the nightly until it is. 
+install-pyarrow-nightly: ## Overlay nightly pyarrow on top of the installed env (for PME) + uv pip install $(PYTHON_ARG) --prerelease=allow --upgrade --force-reinstall \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple pyarrow + test-integration-exec: ## Run integration tests (excluding provision) $(TEST_RUNNER) pytest tests/ -m integration $(PYTEST_ARGS) @@ -150,9 +159,6 @@ coverage-report: ## Combine and report coverage uv run $(PYTHON_ARG) coverage html uv run $(PYTHON_ARG) coverage xml -test-notebook: ## Run notebook tests (pyiceberg_example and spark_integration_example) via papermill - $(TEST_RUNNER) pytest tests/notebooks/test_pyiceberg_example.py tests/notebooks/test_spark_integration_example.py -m notebook $(PYTEST_ARGS) - # ================ # Documentation # ================ diff --git a/dev/provision.py b/dev/provision.py index 695ef9b1bf..dd5c3f931f 100644 --- a/dev/provision.py +++ b/dev/provision.py @@ -395,3 +395,13 @@ ) spark.sql(f"ALTER TABLE {catalog_name}.default.test_empty_scan_ordered_str WRITE ORDERED BY id") spark.sql(f"INSERT INTO {catalog_name}.default.test_empty_scan_ordered_str VALUES 'a', 'c'") + +# Encrypted Iceberg table written via Spark, read back via PyIceberg in tests/integration/test_encryption.py. +# Only the Hive catalog is configured with a Java-side KMS (encryption.kms-impl=UnitestKMS); the REST catalog +# image does not ship UnitestKMS so we limit this fixture to Hive. 
+spark.sql(""" + CREATE OR REPLACE TABLE hive.default.test_encrypted (id bigint, data string, value float) + USING iceberg + TBLPROPERTIES ('encryption.key-id'='keyA', 'format-version'='3') +""") +spark.sql("INSERT INTO hive.default.test_encrypted VALUES (1, 'alice', 1.0), (2, 'bob', 2.0), (3, 'charlie', 3.0)") diff --git a/dev/spark/Dockerfile b/dev/spark/Dockerfile index 0e1f29d152..4f8f063fa7 100644 --- a/dev/spark/Dockerfile +++ b/dev/spark/Dockerfile @@ -18,8 +18,10 @@ ARG BASE_IMAGE_SPARK_VERSION=4.0.1 FROM apache/spark:${BASE_IMAGE_SPARK_VERSION} # Dependency versions - keep these compatible -# Changing these will invalidate the JAR download cache layer -ARG ICEBERG_VERSION=1.10.1 +# Changing these will invalidate the JAR download cache layer. +# Iceberg 1.11.0 carries the Hive encryption integration (apache/iceberg#13066) — the prior +# 1.10.x release predates that work and silently no-ops encryption.kms-impl / encryption.key-id. +ARG ICEBERG_VERSION=1.11.0 ARG ICEBERG_SPARK_RUNTIME_VERSION=4.0_2.13 ARG HADOOP_VERSION=3.4.1 ARG AWS_SDK_VERSION=2.24.6 @@ -36,13 +38,15 @@ RUN apt-get update -qq && \ mkdir -p /home/iceberg/spark-events && \ chown -R spark:spark /home/iceberg -# Download JARs with retry logic (most cacheable - only changes when versions change) -# This is the slowest step, so we do it before copying config files +# Download JARs with retry logic (most cacheable - only changes when versions change). +# iceberg-core-${ICEBERG_VERSION}-tests.jar ships org.apache.iceberg.encryption.UnitestKMS, a +# fixed-master-key KMS used by the encryption integration test on the Spark write path. 
RUN set -e && \ cd "${SPARK_HOME}/jars" && \ for jar_path in \ "org/apache/iceberg/iceberg-spark-runtime-${ICEBERG_SPARK_RUNTIME_VERSION}/${ICEBERG_VERSION}/iceberg-spark-runtime-${ICEBERG_SPARK_RUNTIME_VERSION}-${ICEBERG_VERSION}.jar" \ "org/apache/iceberg/iceberg-aws-bundle/${ICEBERG_VERSION}/iceberg-aws-bundle-${ICEBERG_VERSION}.jar" \ + "org/apache/iceberg/iceberg-core/${ICEBERG_VERSION}/iceberg-core-${ICEBERG_VERSION}-tests.jar" \ "org/apache/hadoop/hadoop-aws/${HADOOP_VERSION}/hadoop-aws-${HADOOP_VERSION}.jar" \ "software/amazon/awssdk/bundle/${AWS_SDK_VERSION}/bundle-${AWS_SDK_VERSION}.jar"; \ do \ diff --git a/dev/spark/spark-defaults.conf b/dev/spark/spark-defaults.conf index 4e50f590c7..02bd4ac5cc 100644 --- a/dev/spark/spark-defaults.conf +++ b/dev/spark/spark-defaults.conf @@ -34,6 +34,11 @@ spark.sql.catalog.hive.io-impl org.apache.iceberg.aws.s3.S3FileIO spark.sql.catalog.hive.warehouse s3://warehouse/hive/ spark.sql.catalog.hive.s3.endpoint http://minio:9000 +# Test-only KMS so Spark can write encrypted Iceberg tables for the encryption integration test. +# UnitestKMS comes from iceberg-core-{version}-tests.jar and uses fixed master keys ("keyA", +# "keyB") that match the InMemoryKms config used on the PyIceberg side. +spark.sql.catalog.hive.encryption.kms-impl org.apache.iceberg.encryption.UnitestKMS + # Configure Spark's default session catalog (spark_catalog) to use Iceberg backed by the Hive Metastore spark.sql.catalog.spark_catalog org.apache.iceberg.spark.SparkSessionCatalog spark.sql.catalog.spark_catalog.type hive diff --git a/tests/integration/test_encryption.py b/tests/integration/test_encryption.py new file mode 100644 index 0000000000..1d069864c9 --- /dev/null +++ b/tests/integration/test_encryption.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Reads of Spark-written, Parquet-encrypted Iceberg tables via PyIceberg. + +The encrypted table (`hive.default.test_encrypted`) is provisioned by `dev/provision.py` +using Spark with `encryption.kms-impl=org.apache.iceberg.encryption.UnitestKMS`. UnitestKMS +ships hardcoded master keys (keyA=b"0123456789012345", keyB=b"1123456789012345"); we mirror +those bytes here through PyIceberg's InMemoryKms so unwrapping succeeds. + +Decryption of the data files requires PyArrow's `parquet.encryption.create_decryption_properties` +API, which is available in PyArrow >= 25 (currently shipped only via the nightly wheels). See +the Makefile target `install-pyarrow-nightly`. +""" + +from __future__ import annotations + +import pytest + +from pyiceberg.catalog import load_catalog + +# UnitestKMS master keys, hex-encoded so they can be set as catalog properties and parsed by +# InMemoryKms.initialize (`encryption.kms.key.{key-id}={hex}`). 
+_KEY_A_HEX = b"0123456789012345".hex() +_KEY_B_HEX = b"1123456789012345".hex() + + +@pytest.fixture(scope="module") +def hive_catalog_with_kms(): # type: ignore[no-untyped-def] + return load_catalog( + "local", + **{ + "type": "hive", + "uri": "thrift://localhost:9083", + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + "py-kms-impl": "pyiceberg.encryption.kms.InMemoryKms", + "encryption.kms.key.keyA": _KEY_A_HEX, + "encryption.kms.key.keyB": _KEY_B_HEX, + }, + ) + + +@pytest.mark.integration +def test_encrypted_table_metadata(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] + tbl = hive_catalog_with_kms.load_table("default.test_encrypted") + + assert tbl.metadata.format_version == 3 + assert tbl.metadata.properties.get("encryption.key-id") == "keyA" + assert tbl.metadata.encryption_keys, "expected encryption keys on table metadata" + + snapshot = tbl.current_snapshot() + assert snapshot is not None + assert snapshot.key_id is not None, "expected key_id on current snapshot" + + +@pytest.mark.integration +def test_encrypted_table_to_arrow(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] + tbl = hive_catalog_with_kms.load_table("default.test_encrypted") + + result = tbl.scan().to_arrow().sort_by("id") + + assert result.num_rows == 3 + assert result.column("id").to_pylist() == [1, 2, 3] + assert result.column("data").to_pylist() == ["alice", "bob", "charlie"] + assert result.column("value").to_pylist() == [1.0, 2.0, 3.0] + + +@pytest.mark.integration +def test_encrypted_table_to_pandas(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] + tbl = hive_catalog_with_kms.load_table("default.test_encrypted") + + df = tbl.scan().to_pandas().sort_values("id").reset_index(drop=True) + + assert list(df["id"]) == [1, 2, 3] + assert list(df["data"]) == ["alice", "bob", "charlie"] + assert list(df["value"]) == [1.0, 2.0, 3.0] + + +@pytest.mark.integration +def 
test_encrypted_table_to_duckdb(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] + tbl = hive_catalog_with_kms.load_table("default.test_encrypted") + + con = tbl.scan().to_duckdb("encrypted") + rows = con.execute("SELECT id, data, value FROM encrypted ORDER BY id").fetchall() + + assert rows == [(1, "alice", 1.0), (2, "bob", 2.0), (3, "charlie", 3.0)] + + +@pytest.mark.integration +def test_encrypted_table_to_polars(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] + tbl = hive_catalog_with_kms.load_table("default.test_encrypted") + + df = tbl.scan().to_polars().sort("id") + + assert df["id"].to_list() == [1, 2, 3] + assert df["data"].to_list() == ["alice", "bob", "charlie"] + assert df["value"].to_list() == [1.0, 2.0, 3.0] From a511f4c6efba77a054504fe882337ac24db2582f Mon Sep 17 00:00:00 2001 From: Sreesh Maheshwar Date: Fri, 29 May 2026 02:24:32 +0100 Subject: [PATCH 3/9] tests: assert direct parquet read fails, polish PME error message --- pyiceberg/io/pyarrow.py | 12 +++++++++--- tests/integration/test_encryption.py | 21 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index dd026bf005..f272498fa4 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1120,14 +1120,20 @@ def _get_file_format(file_format: FileFormat, **kwargs: dict[str, Any]) -> ds.Fi def _get_decryption_properties(key_metadata_bytes: bytes) -> Any: """Build FileDecryptionProperties from Iceberg key metadata. - Requires a custom PyArrow build with pyarrow.parquet.encryption support. + Requires PyArrow >= 25 (currently nightly-only) for the direct-key + `create_decryption_properties` API added by apache/arrow#49667. 
""" try: import pyarrow.parquet.encryption as pe + + if not hasattr(pe, "create_decryption_properties"): + raise ImportError("create_decryption_properties not available") except ImportError as e: raise ImportError( - "Parquet Modular Encryption requires a PyArrow build with encryption support. " - "See PYARROW_ENCRYPTION_HANDOFF.md for build instructions." + "Parquet Modular Encryption requires PyArrow >= 25 with the direct-key API " + "(apache/arrow#49667). Until it releases, install the nightly: " + "`make install-pyarrow-nightly` (or `uv pip install -i " + "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple pyarrow`)." ) from e from pyiceberg.encryption.key_metadata import StandardKeyMetadata diff --git a/tests/integration/test_encryption.py b/tests/integration/test_encryption.py index 1d069864c9..4f6b8b57a3 100644 --- a/tests/integration/test_encryption.py +++ b/tests/integration/test_encryption.py @@ -110,3 +110,24 @@ def test_encrypted_table_to_polars(hive_catalog_with_kms) -> None: # type: igno assert df["id"].to_list() == [1, 2, 3] assert df["data"].to_list() == ["alice", "bob", "charlie"] assert df["value"].to_list() == [1.0, 2.0, 3.0] + + +@pytest.mark.integration +def test_encrypted_table_direct_parquet_read_fails(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] + """Canary: a raw PyArrow read of a data file without decryption properties must fail. + + Mirrors iceberg-java's TestTableEncryption#testDirectDataFileRead, which proves the data + files are genuinely PME-encrypted by asserting that reading them without the keys raises + ParquetCryptoRuntimeException. Without this check, the read tests above could silently pass + on plaintext Parquet and the POC would be meaningless. 
+ """ + import pyarrow.parquet as pq + + tbl = hive_catalog_with_kms.load_table("default.test_encrypted") + + data_files = [task.file.file_path for task in tbl.scan().plan_files()] + assert data_files, "expected at least one data file in the encrypted table" + + for file_path in data_files: + with pytest.raises(OSError, match="encrypted"), tbl.io.new_input(file_path).open() as fi: + pq.read_table(fi) From 915dd8737f68ceec3fa9594ca97217bcc55ff0dd Mon Sep 17 00:00:00 2001 From: Sreesh Maheshwar Date: Fri, 29 May 2026 02:44:52 +0100 Subject: [PATCH 4/9] ci: install nightly pyarrow after provision so uv run doesn't revert it --- Makefile | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 2474ed0e83..90cffc7aad 100644 --- a/Makefile +++ b/Makefile @@ -110,15 +110,17 @@ test: ## Run all unit tests (excluding integration) test-integration: test-integration-setup test-integration-exec test-integration-cleanup ## Run integration tests -test-integration-setup: install install-pyarrow-nightly ## Start Docker services for integration tests +test-integration-setup: install ## Start Docker services for integration tests docker compose -f dev/docker-compose-integration.yml kill docker compose -f dev/docker-compose-integration.yml rm -f docker compose -f dev/docker-compose-integration.yml up -d --build --wait - uv run $(PYTHON_ARG) python dev/provision.py + uv run --no-sync $(PYTHON_ARG) python dev/provision.py + $(MAKE) install-pyarrow-nightly # Parquet Modular Encryption decryption (tests/integration/test_encryption.py) needs the # pyarrow.parquet.encryption.create_decryption_properties API from apache/arrow#49667. That -# lands in pyarrow 25, which hasn't been released — pull the nightly until it is. +# lands in pyarrow 25, which hasn't been released — pull the nightly until it is. Runs AFTER +# provision so that the implicit `uv run` sync during provision.py doesn't revert this overlay. 
install-pyarrow-nightly: ## Overlay nightly pyarrow on top of the installed env (for PME) uv pip install $(PYTHON_ARG) --prerelease=allow --upgrade --force-reinstall \ -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple pyarrow From bd5dd46b3c703920a64ee7010a06c462eb3f6d20 Mon Sep 17 00:00:00 2001 From: Sreesh Maheshwar Date: Fri, 29 May 2026 03:16:56 +0100 Subject: [PATCH 5/9] bump hadoop-aws to 3.4.3 to match iceberg 1.11.0's S3ABlockOutputStream API --- dev/spark/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/spark/Dockerfile b/dev/spark/Dockerfile index 4f8f063fa7..0ead8c4376 100644 --- a/dev/spark/Dockerfile +++ b/dev/spark/Dockerfile @@ -23,7 +23,7 @@ FROM apache/spark:${BASE_IMAGE_SPARK_VERSION} # 1.10.x release predates that work and silently no-ops encryption.kms-impl / encryption.key-id. ARG ICEBERG_VERSION=1.11.0 ARG ICEBERG_SPARK_RUNTIME_VERSION=4.0_2.13 -ARG HADOOP_VERSION=3.4.1 +ARG HADOOP_VERSION=3.4.3 ARG AWS_SDK_VERSION=2.24.6 ARG MAVEN_MIRROR=https://repo.maven.apache.org/maven2 From 807a96390a39d0da8df4e9d0423744595dffa64d Mon Sep 17 00:00:00 2001 From: Sreesh Maheshwar Date: Fri, 29 May 2026 05:29:18 +0100 Subject: [PATCH 6/9] spark: replace bundled hadoop-client 3.4.1 jars with HADOOP_VERSION so aws bundle methods resolve --- dev/spark/Dockerfile | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dev/spark/Dockerfile b/dev/spark/Dockerfile index 0ead8c4376..58c3dd61a6 100644 --- a/dev/spark/Dockerfile +++ b/dev/spark/Dockerfile @@ -43,11 +43,17 @@ RUN apt-get update -qq && \ # fixed-master-key KMS used by the encryption integration test on the Spark write path. RUN set -e && \ cd "${SPARK_HOME}/jars" && \ + # Spark 4.0.1 ships hadoop-client-{api,runtime} 3.4.1; replace them with HADOOP_VERSION so + # hadoop-aws methods like ConfigurationHelper.resolveEnum / S3ABlockOutputStream.builder() + # that iceberg-aws-bundle 1.11.0 expects are present at runtime. 
+ rm -f hadoop-client-api-*.jar hadoop-client-runtime-*.jar && \ for jar_path in \ "org/apache/iceberg/iceberg-spark-runtime-${ICEBERG_SPARK_RUNTIME_VERSION}/${ICEBERG_VERSION}/iceberg-spark-runtime-${ICEBERG_SPARK_RUNTIME_VERSION}-${ICEBERG_VERSION}.jar" \ "org/apache/iceberg/iceberg-aws-bundle/${ICEBERG_VERSION}/iceberg-aws-bundle-${ICEBERG_VERSION}.jar" \ "org/apache/iceberg/iceberg-core/${ICEBERG_VERSION}/iceberg-core-${ICEBERG_VERSION}-tests.jar" \ "org/apache/hadoop/hadoop-aws/${HADOOP_VERSION}/hadoop-aws-${HADOOP_VERSION}.jar" \ + "org/apache/hadoop/hadoop-client-api/${HADOOP_VERSION}/hadoop-client-api-${HADOOP_VERSION}.jar" \ + "org/apache/hadoop/hadoop-client-runtime/${HADOOP_VERSION}/hadoop-client-runtime-${HADOOP_VERSION}.jar" \ "software/amazon/awssdk/bundle/${AWS_SDK_VERSION}/bundle-${AWS_SDK_VERSION}.jar"; \ do \ jar_name=$(basename "${jar_path}") && \ From 48644d0905ebcae38fe0ef00897ec60a1edaf4c4 Mon Sep 17 00:00:00 2001 From: Sreesh Maheshwar Date: Fri, 29 May 2026 05:47:03 +0100 Subject: [PATCH 7/9] spark: ship UnitestKMS via slim repackaged jar (full -tests.jar shadowed hadoop-aws S3ABlockOutputStream) --- dev/spark/Dockerfile | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/dev/spark/Dockerfile b/dev/spark/Dockerfile index 58c3dd61a6..e578ee8158 100644 --- a/dev/spark/Dockerfile +++ b/dev/spark/Dockerfile @@ -50,7 +50,6 @@ RUN set -e && \ for jar_path in \ "org/apache/iceberg/iceberg-spark-runtime-${ICEBERG_SPARK_RUNTIME_VERSION}/${ICEBERG_VERSION}/iceberg-spark-runtime-${ICEBERG_SPARK_RUNTIME_VERSION}-${ICEBERG_VERSION}.jar" \ "org/apache/iceberg/iceberg-aws-bundle/${ICEBERG_VERSION}/iceberg-aws-bundle-${ICEBERG_VERSION}.jar" \ - "org/apache/iceberg/iceberg-core/${ICEBERG_VERSION}/iceberg-core-${ICEBERG_VERSION}-tests.jar" \ "org/apache/hadoop/hadoop-aws/${HADOOP_VERSION}/hadoop-aws-${HADOOP_VERSION}.jar" \ "org/apache/hadoop/hadoop-client-api/${HADOOP_VERSION}/hadoop-client-api-${HADOOP_VERSION}.jar" \ 
"org/apache/hadoop/hadoop-client-runtime/${HADOOP_VERSION}/hadoop-client-runtime-${HADOOP_VERSION}.jar" \ @@ -63,6 +62,24 @@ RUN set -e && \ chown spark:spark "${jar_name}"; \ done +# Pull UnitestKMS + MemoryMockKMS out of iceberg-core's -tests.jar into a slim jar that only +# ships those two classes. The full -tests.jar contains testing stubs for unrelated classes +# (e.g. a stub S3ABlockOutputStream that lacks builder()) which would shadow the real ones in +# hadoop-aws and break Spark's S3A writes; this avoids that classpath collision. +RUN set -e && \ + tmp="$(mktemp -d)" && cd "${tmp}" && \ + curl -fsSL --retry 3 --retry-delay 5 \ + -o iceberg-core-tests.jar \ + "${MAVEN_MIRROR}/org/apache/iceberg/iceberg-core/${ICEBERG_VERSION}/iceberg-core-${ICEBERG_VERSION}-tests.jar" && \ + /opt/java/openjdk/bin/jar xf iceberg-core-tests.jar \ + org/apache/iceberg/encryption/UnitestKMS.class \ + org/apache/iceberg/encryption/MemoryMockKMS.class && \ + /opt/java/openjdk/bin/jar cf "${SPARK_HOME}/jars/iceberg-core-${ICEBERG_VERSION}-tests-kms-only.jar" \ + org/apache/iceberg/encryption/UnitestKMS.class \ + org/apache/iceberg/encryption/MemoryMockKMS.class && \ + chown spark:spark "${SPARK_HOME}/jars/iceberg-core-${ICEBERG_VERSION}-tests-kms-only.jar" && \ + rm -rf "${tmp}" + # Copy configuration last (changes more frequently than JARs) COPY --chown=spark:spark spark-defaults.conf ${SPARK_HOME}/conf/ From 53c9bf9a3fd5a6fc1057ceccd65e569b63aa5561 Mon Sep 17 00:00:00 2001 From: Sreesh Maheshwar Date: Fri, 29 May 2026 05:58:07 +0100 Subject: [PATCH 8/9] inspect: surface key_metadata in all_manifests (iceberg 1.11.0, apache/iceberg#14750) --- pyiceberg/table/inspect.py | 3 +++ tests/integration/test_inspect_table.py | 1 + 2 files changed, 4 insertions(+) diff --git a/pyiceberg/table/inspect.py b/pyiceberg/table/inspect.py index 5da343ccb6..3b97c71682 100644 --- a/pyiceberg/table/inspect.py +++ b/pyiceberg/table/inspect.py @@ -404,6 +404,8 @@ def _get_all_manifests_schema(self) -> 
pa.Schema: all_manifests_schema = self._get_manifests_schema() all_manifests_schema = all_manifests_schema.append(pa.field("reference_snapshot_id", pa.int64(), nullable=False)) + # Iceberg 1.11.0 (apache/iceberg#14750) added key_metadata to the all_manifests table only. + all_manifests_schema = all_manifests_schema.append(pa.field("key_metadata", pa.binary(), nullable=True)) return all_manifests_schema def _generate_manifests_table(self, snapshot: Snapshot | None, is_all_manifests_table: bool = False) -> pa.Table: @@ -468,6 +470,7 @@ def _partition_summaries_to_rows( } if is_all_manifests_table: manifest_row["reference_snapshot_id"] = snapshot.snapshot_id + manifest_row["key_metadata"] = manifest.key_metadata manifests.append(manifest_row) return pa.Table.from_pylist( diff --git a/tests/integration/test_inspect_table.py b/tests/integration/test_inspect_table.py index 03d4437d18..4d8dfbe9bb 100644 --- a/tests/integration/test_inspect_table.py +++ b/tests/integration/test_inspect_table.py @@ -1012,6 +1012,7 @@ def test_inspect_all_manifests(spark: SparkSession, session_catalog: Catalog, fo "deleted_delete_files_count", "partition_summaries", "reference_snapshot_id", + "key_metadata", ] int_cols = [ From b2db501a21b1d00ab6cd524a86a9b91beb440dee Mon Sep 17 00:00:00 2001 From: Sreesh Maheshwar Date: Fri, 29 May 2026 07:16:55 +0100 Subject: [PATCH 9/9] drop REST mentions and trim verbose comments/docstrings --- dev/provision.py | 4 +- pyiceberg/encryption/ciphers.py | 66 +++++++------------------- pyiceberg/encryption/io.py | 13 +----- pyiceberg/encryption/key_metadata.py | 67 ++++++--------------------- pyiceberg/encryption/kms.py | 41 +++++----------- pyiceberg/encryption/manager.py | 64 +++---------------------- pyiceberg/io/pyarrow.py | 16 ++----- pyiceberg/manifest.py | 13 +++--- pyiceberg/table/metadata.py | 7 +-- tests/encryption/test_ciphers.py | 12 +---- tests/encryption/test_key_metadata.py | 7 +-- tests/encryption/test_kms.py | 10 +--- 
tests/encryption/test_manager.py | 25 ++-------- tests/integration/test_encryption.py | 45 +++++------------- 14 files changed, 80 insertions(+), 310 deletions(-) diff --git a/dev/provision.py b/dev/provision.py index dd5c3f931f..8587319092 100644 --- a/dev/provision.py +++ b/dev/provision.py @@ -396,9 +396,7 @@ spark.sql(f"ALTER TABLE {catalog_name}.default.test_empty_scan_ordered_str WRITE ORDERED BY id") spark.sql(f"INSERT INTO {catalog_name}.default.test_empty_scan_ordered_str VALUES 'a', 'c'") -# Encrypted Iceberg table written via Spark, read back via PyIceberg in tests/integration/test_encryption.py. -# Only the Hive catalog is configured with a Java-side KMS (encryption.kms-impl=UnitestKMS); the REST catalog -# image does not ship UnitestKMS so we limit this fixture to Hive. +# Encrypted Hive-cataloged table; read back via PyIceberg in tests/integration/test_encryption.py. spark.sql(""" CREATE OR REPLACE TABLE hive.default.test_encrypted (id bigint, data string, value float) USING iceberg diff --git a/pyiceberg/encryption/ciphers.py b/pyiceberg/encryption/ciphers.py index ed023ce53e..bf0c9f01a9 100644 --- a/pyiceberg/encryption/ciphers.py +++ b/pyiceberg/encryption/ciphers.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""AES-GCM encryption/decryption primitives and AGS1 stream decryption.""" +"""AES-GCM primitives and Iceberg AGS1 stream decryption.""" from __future__ import annotations @@ -28,58 +28,37 @@ def aes_gcm_encrypt(key: bytes, plaintext: bytes, aad: bytes | None = None) -> bytes: - """Encrypt using AES-GCM. 
Returns nonce || ciphertext || tag.""" nonce = os.urandom(NONCE_LENGTH) - aesgcm = AESGCM(key) - ciphertext_with_tag = aesgcm.encrypt(nonce, plaintext, aad) - return nonce + ciphertext_with_tag + return nonce + AESGCM(key).encrypt(nonce, plaintext, aad) def aes_gcm_decrypt(key: bytes, ciphertext: bytes, aad: bytes | None = None) -> bytes: - """Decrypt AES-GCM data in format: nonce || ciphertext || tag.""" if len(ciphertext) < NONCE_LENGTH + GCM_TAG_LENGTH: raise ValueError(f"Ciphertext too short: {len(ciphertext)} bytes") - nonce = ciphertext[:NONCE_LENGTH] - encrypted_data = ciphertext[NONCE_LENGTH:] - aesgcm = AESGCM(key) - return aesgcm.decrypt(nonce, encrypted_data, aad) + return AESGCM(key).decrypt(ciphertext[:NONCE_LENGTH], ciphertext[NONCE_LENGTH:], aad) -# AGS1 stream constants GCM_STREAM_MAGIC = b"AGS1" -GCM_STREAM_HEADER_LENGTH = 8 # 4 magic + 4 block size +GCM_STREAM_HEADER_LENGTH = 8 # 4 magic + 4 little-endian block size def stream_block_aad(aad_prefix: bytes, block_index: int) -> bytes: - """Construct per-block AAD for AGS1 stream encryption. - - Format: aad_prefix || block_index (4 bytes, little-endian). - """ - index_bytes = struct.pack(" bytes: - """Decrypt an entire AGS1 stream and return the plaintext. - - AGS1 format: - - Header: "AGS1" (4 bytes) + plain_block_size (4 bytes LE) - - Blocks: each block is nonce(12) + ciphertext(up to 1MB) + tag(16) - - Each block's AAD = aad_prefix + block_index (4 bytes LE) + """Decrypt an Iceberg AGS1 stream. + Layout: "AGS1" (4) | plain_block_size LE (4) | one or more {nonce(12) | cipher | tag(16)} blocks. + Each block's AAD is `aad_prefix || block_index_le32`. 
""" if len(encrypted_data) < GCM_STREAM_HEADER_LENGTH: raise ValueError(f"AGS1 stream too short: {len(encrypted_data)} bytes") - - magic = encrypted_data[:4] - if magic != GCM_STREAM_MAGIC: - raise ValueError(f"Invalid AGS1 magic: {magic!r}, expected {GCM_STREAM_MAGIC!r}") + if encrypted_data[:4] != GCM_STREAM_MAGIC: + raise ValueError(f"Invalid AGS1 magic: {encrypted_data[:4]!r}") plain_block_size = struct.unpack_from(" result = bytearray() offset = 0 block_index = 0 - while offset < len(stream_data): - # Determine this block's cipher size - remaining = len(stream_data) - offset - if remaining >= cipher_block_size: - block_cipher_size = cipher_block_size - else: - block_cipher_size = remaining - + block_cipher_size = min(cipher_block_size, len(stream_data) - offset) if block_cipher_size < NONCE_LENGTH + GCM_TAG_LENGTH: - raise ValueError( - f"Truncated AGS1 block at offset {offset}: {block_cipher_size} bytes (minimum {NONCE_LENGTH + GCM_TAG_LENGTH})" - ) - - block_data = stream_data[offset : offset + block_cipher_size] - nonce = block_data[:NONCE_LENGTH] - ciphertext_with_tag = block_data[NONCE_LENGTH:] - - aad = stream_block_aad(aad_prefix, block_index) - plaintext = aesgcm.decrypt(nonce, ciphertext_with_tag, aad) - result.extend(plaintext) + raise ValueError(f"Truncated AGS1 block at offset {offset}: {block_cipher_size} bytes") + block = stream_data[offset : offset + block_cipher_size] + result.extend( + aesgcm.decrypt(block[:NONCE_LENGTH], block[NONCE_LENGTH:], stream_block_aad(aad_prefix, block_index)) + ) offset += block_cipher_size block_index += 1 diff --git a/pyiceberg/encryption/io.py b/pyiceberg/encryption/io.py index 310cd842a3..cbba9fffcc 100644 --- a/pyiceberg/encryption/io.py +++ b/pyiceberg/encryption/io.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""InputFile implementation backed by in-memory bytes.""" +"""In-memory InputFile/InputStream used to wrap decrypted Avro buffers for AvroFile.""" from __future__ import annotations @@ -25,8 +25,6 @@ class BytesInputStream(InputStream): - """InputStream implementation backed by a bytes buffer.""" - def __init__(self, data: bytes) -> None: self._buffer = io.BytesIO(data) @@ -45,7 +43,6 @@ def close(self) -> None: self._buffer.close() def __enter__(self) -> BytesInputStream: - """Enter the context manager.""" return self def __exit__( @@ -54,23 +51,15 @@ def __exit__( excinst: BaseException | None, exctb: TracebackType | None, ) -> None: - """Exit the context manager and close the stream.""" self.close() class BytesInputFile(InputFile): - """InputFile implementation backed by in-memory bytes. - - Used to wrap decrypted data so that it can be read by - AvroFile and other readers that expect an InputFile. - """ - def __init__(self, location: str, data: bytes) -> None: super().__init__(location) self._data = data def __len__(self) -> int: - """Return the length of the underlying data.""" return len(self._data) def exists(self) -> bool: diff --git a/pyiceberg/encryption/key_metadata.py b/pyiceberg/encryption/key_metadata.py index 4d22778d5c..211f6a48c1 100644 --- a/pyiceberg/encryption/key_metadata.py +++ b/pyiceberg/encryption/key_metadata.py @@ -14,14 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""StandardKeyMetadata Avro serialization. +"""StandardKeyMetadata Avro codec. -Wire format: ``0x01 version byte || Avro-encoded fields`` - -Avro schema: - - encryption_key: bytes (required) - - aad_prefix: union[null, bytes] (optional) - - file_length: union[null, long] (optional) +Wire: ``0x01 version`` || encryption_key (bytes) || aad_prefix (union[null,bytes]) + || file_length (union[null,long]). 
""" from __future__ import annotations @@ -32,7 +28,6 @@ def _read_avro_long(data: bytes, offset: int) -> tuple[int, int]: - """Read a zigzag-encoded Avro long from data at offset. Returns (value, new_offset).""" result = 0 shift = 0 while True: @@ -44,12 +39,10 @@ def _read_avro_long(data: bytes, offset: int) -> tuple[int, int]: if (b & 0x80) == 0: break shift += 7 - # Zigzag decode return (result >> 1) ^ -(result & 1), offset def _read_avro_bytes(data: bytes, offset: int) -> tuple[bytes, int]: - """Read Avro bytes (length-prefixed). Returns (bytes_value, new_offset).""" length, offset = _read_avro_long(data, offset) if length < 0: raise ValueError(f"Negative Avro bytes length: {length}") @@ -61,31 +54,20 @@ def _read_avro_bytes(data: bytes, offset: int) -> tuple[bytes, int]: @dataclass(frozen=True) class StandardKeyMetadata: - """Standard key metadata for Iceberg table encryption. - - Contains the plaintext encryption key (DEK), AAD prefix, and optional file length. - """ - encryption_key: bytes aad_prefix: bytes = b"" file_length: int | None = None @staticmethod def deserialize(data: bytes) -> StandardKeyMetadata: - """Deserialize from wire format: ``0x01 version || Avro-encoded fields``.""" if not data: raise ValueError("Empty key metadata buffer") - - version = data[0] - if version != V1: - raise ValueError(f"Unsupported key metadata version: {version}") - + if data[0] != V1: + raise ValueError(f"Unsupported key metadata version: {data[0]}") offset = 1 - # Read encryption_key (required bytes) encryption_key, offset = _read_avro_bytes(data, offset) - # Read aad_prefix (optional: union[null, bytes]) union_index, offset = _read_avro_long(data, offset) if union_index == 0: aad_prefix = b"" @@ -94,50 +76,30 @@ def deserialize(data: bytes) -> StandardKeyMetadata: else: raise ValueError(f"Invalid union index for aad_prefix: {union_index}") - # Read file_length (optional: union[null, long]) - file_length = None + file_length: int | None = None if offset < len(data): 
union_index, offset = _read_avro_long(data, offset) - if union_index == 0: - file_length = None - elif union_index == 1: + if union_index == 1: file_length, offset = _read_avro_long(data, offset) - else: + elif union_index != 0: raise ValueError(f"Invalid union index for file_length: {union_index}") - return StandardKeyMetadata( - encryption_key=encryption_key, - aad_prefix=aad_prefix, - file_length=file_length, - ) + return StandardKeyMetadata(encryption_key=encryption_key, aad_prefix=aad_prefix, file_length=file_length) def serialize(self) -> bytes: - """Serialize to wire format: ``0x01 version || Avro-encoded fields``.""" - parts = [bytes([V1])] - - # encryption_key (required bytes) - parts.append(_encode_avro_bytes(self.encryption_key)) - - # aad_prefix (union[null, bytes]) + parts = [bytes([V1]), _encode_avro_bytes(self.encryption_key)] if self.aad_prefix: - parts.append(_encode_avro_long(1)) # union index 1 = bytes - parts.append(_encode_avro_bytes(self.aad_prefix)) + parts += [_encode_avro_long(1), _encode_avro_bytes(self.aad_prefix)] else: - parts.append(_encode_avro_long(0)) # union index 0 = null - - # file_length (union[null, long]) + parts.append(_encode_avro_long(0)) if self.file_length is not None: - parts.append(_encode_avro_long(1)) # union index 1 = long - parts.append(_encode_avro_long(self.file_length)) + parts += [_encode_avro_long(1), _encode_avro_long(self.file_length)] else: - parts.append(_encode_avro_long(0)) # union index 0 = null - + parts.append(_encode_avro_long(0)) return b"".join(parts) def _encode_avro_long(value: int) -> bytes: - """Encode a long as zigzag-encoded Avro varint.""" - # Zigzag encode n = (value << 1) ^ (value >> 63) result = bytearray() while n & ~0x7F: @@ -148,5 +110,4 @@ def _encode_avro_long(value: int) -> bytes: def _encode_avro_bytes(data: bytes) -> bytes: - """Encode bytes with Avro length prefix.""" return _encode_avro_long(len(data)) + data diff --git a/pyiceberg/encryption/kms.py b/pyiceberg/encryption/kms.py 
index 2ba1a5e1ac..5c7c34ddd6 100644 --- a/pyiceberg/encryption/kms.py +++ b/pyiceberg/encryption/kms.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Key Management Service interfaces and implementations.""" - from __future__ import annotations import importlib @@ -34,15 +32,11 @@ class KeyManagementClient(ABC): - """Abstract base class for key management operations.""" - @abstractmethod - def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes: - """Wrap (encrypt) a key using the master key identified by wrapping_key_id.""" + def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes: ... @abstractmethod - def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes: - """Unwrap (decrypt) a wrapped key using the master key identified by wrapping_key_id.""" + def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes: ... def initialize(self, properties: dict[str, str]) -> None: # noqa: B027 """Initialize the KMS client from catalog/table properties.""" @@ -55,39 +49,26 @@ def __init__(self, master_keys: dict[str, bytes] | None = None) -> None: self._master_keys: dict[str, bytes] = dict(master_keys) if master_keys else {} def initialize(self, properties: dict[str, str]) -> None: + prefix = "encryption.kms.key." 
for key, value in properties.items(): - if key.startswith("encryption.kms.key."): - key_id = key[len("encryption.kms.key.") :] - self._master_keys[key_id] = bytes.fromhex(value) + if key.startswith(prefix): + self._master_keys[key[len(prefix) :]] = bytes.fromhex(value) def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes: - master_key = self._master_keys.get(wrapping_key_id) - if master_key is None: - raise ValueError(f"Wrapping key not found: {wrapping_key_id}") - return aes_gcm_encrypt(master_key, key, aad=None) + return aes_gcm_encrypt(self._master(wrapping_key_id), key, aad=None) def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes: + return aes_gcm_decrypt(self._master(wrapping_key_id), wrapped_key, aad=None) + + def _master(self, wrapping_key_id: str) -> bytes: master_key = self._master_keys.get(wrapping_key_id) if master_key is None: raise ValueError(f"Wrapping key not found: {wrapping_key_id}") - return aes_gcm_decrypt(master_key, wrapped_key, aad=None) + return master_key def load_kms_client(properties: Properties) -> KeyManagementClient | None: - """Load a KMS client from properties using py-kms-impl. - - Follows the same pattern as py-io-impl for FileIO. - - The property 'py-kms-impl' should be a fully qualified Python class name - (e.g., 'pyiceberg.encryption.kms.InMemoryKms'). The class must be a - subclass of KeyManagementClient. - - Args: - properties: Catalog and/or table properties. - - Returns: - An initialized KeyManagementClient, or None if py-kms-impl is not set. - """ + """Instantiate a KeyManagementClient from a fully-qualified `py-kms-impl` (or return None).""" kms_impl = properties.get(PY_KMS_IMPL) if kms_impl is None: return None diff --git a/pyiceberg/encryption/manager.py b/pyiceberg/encryption/manager.py index 07219d35ad..1d28a2f30c 100644 --- a/pyiceberg/encryption/manager.py +++ b/pyiceberg/encryption/manager.py @@ -14,16 +14,7 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -"""Encryption manager implementing two-layer envelope key management. - -Key hierarchy: - - Master Key (in KMS) wraps KEK - - KEK wraps DEK (using local AES-GCM) - - DEK encrypts data (manifest lists, manifests, data files) - -The KEK timestamp is used as AAD when wrapping/unwrapping DEKs -to prevent timestamp tampering attacks. -""" +"""Iceberg two-layer envelope-key manager: KMS master key wraps KEK; KEK wraps the per-file DEK.""" from __future__ import annotations @@ -39,8 +30,6 @@ class EncryptedKey: - """Represents an encrypted key entry from table metadata.""" - def __init__( self, key_id: str, @@ -54,7 +43,6 @@ def __init__( self.properties = properties or {} def __repr__(self) -> str: - """Return a string representation of the EncryptedKey.""" return ( f"EncryptedKey(key_id={self.key_id!r}, " f"encrypted_by_id={self.encrypted_by_id!r}, " @@ -63,15 +51,6 @@ def __repr__(self) -> str: class EncryptionManager: - """Manages encryption/decryption for an Iceberg table. - - Orchestrates the two-layer envelope key management: - 1. Unwrap KEK via KMS using master key - 2. Use KEK to decrypt manifest list/manifest key metadata (with timestamp AAD) - 3. Parse StandardKeyMetadata to get DEK + AAD prefix - 4. Decrypt AGS1 streams or provide FileDecryptionProperties for Parquet - """ - def __init__( self, kms_client: KeyManagementClient, @@ -82,7 +61,6 @@ def __init__( self._kek_cache: dict[str, bytes] = {} def _unwrap_kek(self, kek: EncryptedKey) -> bytes: - """Unwrap a KEK using the KMS, with caching.""" if kek.key_id in self._kek_cache: return self._kek_cache[kek.key_id] @@ -94,52 +72,27 @@ def _unwrap_kek(self, kek: EncryptedKey) -> bytes: return plaintext def _unwrap_dek(self, wrapped_dek: bytes, kek_key_id: str) -> bytes: - """Unwrap a DEK using the specified KEK. - - Uses the KEK timestamp as AAD to prevent timestamp tampering. 
- """ kek = self._encryption_keys.get(kek_key_id) if kek is None: raise ValueError(f"KEK not found in encryption keys: {kek_key_id}") kek_bytes = self._unwrap_kek(kek) - - # Use KEK timestamp as AAD to prevent tampering + # KEK timestamp is bound as AAD to defeat tampering. aad = kek.properties.get(KEK_CREATED_AT_PROPERTY) aad_bytes = aad.encode("utf-8") if aad else None - return aes_gcm_decrypt(kek_bytes, wrapped_dek, aad=aad_bytes) def unwrap_key_metadata(self, encrypted_key: EncryptedKey) -> bytes: - """Unwrap key metadata that was KEK-wrapped. - - Given an EncryptedKey entry (e.g., from a snapshot's key-id mapping), - unwrap it using the KEK identified by encrypted_by_id. - """ if not encrypted_key.encrypted_by_id: raise ValueError(f"EncryptedKey '{encrypted_key.key_id}' has no encrypted_by_id") - - return self._unwrap_dek( - encrypted_key.encrypted_key_metadata, - encrypted_key.encrypted_by_id, - ) + return self._unwrap_dek(encrypted_key.encrypted_key_metadata, encrypted_key.encrypted_by_id) def decrypt_manifest_list(self, encrypted_data: bytes, snapshot_key_id: str) -> bytes: - """Decrypt an AGS1-encrypted manifest list. - - 1. Look up the EncryptedKey for the snapshot's key_id - 2. Unwrap the key metadata using the KEK - 3. Parse StandardKeyMetadata to get DEK + AAD prefix - 4. 
Decrypt the AGS1 stream - """ encrypted_key = self._encryption_keys.get(snapshot_key_id) if encrypted_key is None: raise ValueError(f"Snapshot key not found in encryption keys: {snapshot_key_id}") - # Unwrap the key metadata - key_metadata_bytes = self.unwrap_key_metadata(encrypted_key) - key_metadata = StandardKeyMetadata.deserialize(key_metadata_bytes) - + key_metadata = StandardKeyMetadata.deserialize(self.unwrap_key_metadata(encrypted_key)) return decrypt_ags1_stream( key=key_metadata.encryption_key, encrypted_data=encrypted_data, @@ -147,14 +100,9 @@ def decrypt_manifest_list(self, encrypted_data: bytes, snapshot_key_id: str) -> ) def decrypt_manifest(self, encrypted_data: bytes, key_metadata_bytes: bytes) -> bytes: - """Decrypt an AGS1-encrypted manifest file. - - The key_metadata_bytes are from ManifestFile.key_metadata -- these contain - the plaintext DEK and AAD prefix (NOT wrapped by KEK, since they're already - stored inside the encrypted manifest list). - """ + # ManifestFile.key_metadata carries the plaintext DEK + AAD; it's already + # protected by living inside the encrypted manifest list. key_metadata = StandardKeyMetadata.deserialize(key_metadata_bytes) - return decrypt_ags1_stream( key=key_metadata.encryption_key, encrypted_data=encrypted_data, diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index f272498fa4..dca8f1743b 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1118,11 +1118,7 @@ def _get_file_format(file_format: FileFormat, **kwargs: dict[str, Any]) -> ds.Fi def _get_decryption_properties(key_metadata_bytes: bytes) -> Any: - """Build FileDecryptionProperties from Iceberg key metadata. - - Requires PyArrow >= 25 (currently nightly-only) for the direct-key - `create_decryption_properties` API added by apache/arrow#49667. - """ + # Needs PyArrow >= 25 (nightly today) for `create_decryption_properties` (apache/arrow#49667). 
try: import pyarrow.parquet.encryption as pe @@ -1130,10 +1126,8 @@ def _get_decryption_properties(key_metadata_bytes: bytes) -> Any: raise ImportError("create_decryption_properties not available") except ImportError as e: raise ImportError( - "Parquet Modular Encryption requires PyArrow >= 25 with the direct-key API " - "(apache/arrow#49667). Until it releases, install the nightly: " - "`make install-pyarrow-nightly` (or `uv pip install -i " - "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple pyarrow`)." + "Parquet Modular Encryption requires PyArrow >= 25 (apache/arrow#49667). " + "Install the nightly via `make install-pyarrow-nightly`." ) from e from pyiceberg.encryption.key_metadata import StandardKeyMetadata @@ -1141,7 +1135,7 @@ def _get_decryption_properties(key_metadata_bytes: bytes) -> Any: key_metadata = StandardKeyMetadata.deserialize(key_metadata_bytes) return pe.create_decryption_properties( footer_key=key_metadata.encryption_key, - aad_prefix=key_metadata.aad_prefix if key_metadata.aad_prefix else None, + aad_prefix=key_metadata.aad_prefix or None, ) @@ -1670,8 +1664,6 @@ def _task_to_record_batches( ) -> Iterator[pa.RecordBatch]: arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) - # For encrypted files, create a ParquetFileFormat with decryption properties - # so that make_fragment can read the encrypted metadata if task.file.key_metadata is not None and encryption_manager is not None: decryption_properties = _get_decryption_properties(task.file.key_metadata) scan_options = ds.ParquetFragmentScanOptions( diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index d5b98c5b49..e4bbe745cd 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -874,14 +874,14 @@ def fetch_manifest_entry( Args: io: The FileIO to fetch the file. discard_deleted: Filter on live entries. - encryption_manager: Optional encryption manager for decrypting encrypted manifests. 
+ encryption_manager: When set, decrypt an AGS1-encrypted manifest. Returns: An Iterator of manifest entries. """ input_file = io.new_input(self.manifest_path) - # If this manifest has key_metadata, it's AGS1-encrypted + # A non-null key_metadata signals the manifest is AGS1-encrypted. if self.key_metadata is not None and encryption_manager is not None: from pyiceberg.encryption.io import BytesInputFile @@ -945,8 +945,8 @@ def _manifests( Args: io: FileIO instance for reading the manifest list. manifest_list: Path to the manifest list file. - encryption_manager: Optional encryption manager for decrypting encrypted manifest lists. - snapshot_key_id: Optional key ID from snapshot for manifest list decryption. + encryption_manager: When set together with snapshot_key_id, decrypt an AGS1-encrypted manifest list. + snapshot_key_id: Snapshot's encryption key id. Returns: A tuple of ManifestFile objects. @@ -977,13 +977,12 @@ def read_manifest_list( Args: input_file: The input file where the stream can be read from. - encryption_manager: Optional encryption manager for decrypting encrypted manifest lists. - snapshot_key_id: Optional key ID from snapshot for manifest list decryption. + encryption_manager: When set together with snapshot_key_id, decrypt an AGS1-encrypted manifest list. + snapshot_key_id: Snapshot's encryption key id. Returns: An iterator of ManifestFiles that are part of the list. """ - # If we have encryption info, decrypt the manifest list first if snapshot_key_id is not None and encryption_manager is not None: from pyiceberg.encryption.io import BytesInputFile diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index ed07177af4..7d39f1c6c1 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -51,20 +51,15 @@ class EncryptedKeyModel(IcebergBaseModel): - """An encrypted key entry in table metadata. - - Matches the EncryptedKey schema in the REST API spec. 
- """ + """An encrypted key entry in table metadata.""" key_id: str = Field(alias="key-id") encrypted_key_metadata: str = Field(alias="encrypted-key-metadata") - """Base64-encoded encrypted key metadata bytes.""" encrypted_by_id: str | None = Field(alias="encrypted-by-id", default=None) properties: dict[str, str] = Field(default_factory=dict) @property def encrypted_key_metadata_bytes(self) -> bytes: - """Decode the base64-encoded encrypted key metadata.""" return base64.b64decode(self.encrypted_key_metadata) diff --git a/tests/encryption/test_ciphers.py b/tests/encryption/test_ciphers.py index ccb7d058d5..4ceb2f01a3 100644 --- a/tests/encryption/test_ciphers.py +++ b/tests/encryption/test_ciphers.py @@ -62,11 +62,10 @@ def test_wrong_aad_fails(self) -> None: aes_gcm_decrypt(key, ciphertext, aad=b"wrong") def test_wire_format(self) -> None: - """Verify the wire format: nonce(12) || ciphertext || tag(16).""" + # nonce(12) || ciphertext || tag(16) key = os.urandom(16) plaintext = b"test" ciphertext = aes_gcm_encrypt(key, plaintext) - # Minimum size: nonce + tag + at least len(plaintext) of ciphertext assert len(ciphertext) == NONCE_LENGTH + len(plaintext) + GCM_TAG_LENGTH def test_ciphertext_too_short(self) -> None: @@ -90,7 +89,6 @@ def test_without_prefix(self) -> None: def _encrypt_ags1_stream(key: bytes, plaintext: bytes, aad_prefix: bytes, plain_block_size: int = 1024 * 1024) -> bytes: - """Build an AGS1 encrypted stream for testing.""" from cryptography.hazmat.primitives.ciphers.aead import AESGCM aesgcm = AESGCM(key) @@ -120,11 +118,9 @@ def test_roundtrip_single_block(self) -> None: assert decrypt_ags1_stream(key, encrypted, aad_prefix) == plaintext def test_roundtrip_multi_block(self) -> None: - """Test with a small block size to force multiple blocks.""" key = os.urandom(16) plaintext = b"A" * 200 aad_prefix = b"multi" - # Use a 64-byte block size to get multiple blocks encrypted = _encrypt_ags1_stream(key, plaintext, aad_prefix, plain_block_size=64) assert 
decrypt_ags1_stream(key, encrypted, aad_prefix) == plaintext @@ -143,20 +139,16 @@ def test_too_short(self) -> None: decrypt_ags1_stream(os.urandom(16), b"AGS1", b"") def test_custom_block_size(self) -> None: - """Verify the block size from the header is respected, not hardcoded.""" + # Block size comes from the header, not a constant. key = os.urandom(16) plaintext = b"B" * 300 aad_prefix = b"custom" - # Encrypt with a 100-byte block size encrypted = _encrypt_ags1_stream(key, plaintext, aad_prefix, plain_block_size=100) - # Verify the header contains 100 assert struct.unpack_from(" None: key = os.urandom(16) - # Header + a few bytes that are too short for even nonce+tag data = GCM_STREAM_MAGIC + struct.pack(" None: assert restored.file_length == 12345 def test_version_byte(self) -> None: - """First byte should always be 0x01.""" key = os.urandom(16) serialized = StandardKeyMetadata(encryption_key=key).serialize() assert serialized[0] == 0x01 @@ -66,15 +65,13 @@ def test_deserialize_wrong_version(self) -> None: StandardKeyMetadata.deserialize(b"\x02\x00") def test_frozen(self) -> None: - """StandardKeyMetadata is a frozen dataclass.""" skm = StandardKeyMetadata(encryption_key=b"key") with pytest.raises(AttributeError): skm.encryption_key = b"other" # type: ignore[misc] def test_roundtrip_large_file_length(self) -> None: - """Zigzag encoding should handle large values correctly.""" + # Exercise zigzag varint for a value beyond int32. 
key = os.urandom(16) original = StandardKeyMetadata(encryption_key=key, file_length=2**40) - serialized = original.serialize() - restored = StandardKeyMetadata.deserialize(serialized) + restored = StandardKeyMetadata.deserialize(original.serialize()) assert restored.file_length == 2**40 diff --git a/tests/encryption/test_kms.py b/tests/encryption/test_kms.py index d445f27201..ac1e82dd88 100644 --- a/tests/encryption/test_kms.py +++ b/tests/encryption/test_kms.py @@ -44,7 +44,6 @@ def test_initialize_from_properties(self) -> None: kms = InMemoryKms() key_hex = os.urandom(16).hex() kms.initialize({"encryption.kms.key.testKey": key_hex}) - # Should be able to wrap/unwrap with the initialized key wrapped = kms.wrap_key(b"secret", "testKey") assert kms.unwrap_key(wrapped, "testKey") == b"secret" @@ -55,7 +54,6 @@ def test_initialize_ignores_unrelated_properties(self) -> None: kms.wrap_key(b"key", "nonexistent") def test_wrap_unwrap_with_standard_test_keys(self) -> None: - """Wrap/unwrap with the standard Iceberg test master keys.""" kms = InMemoryKms( master_keys={ "keyA": b"0123456789012345", @@ -93,9 +91,7 @@ def test_loads_in_memory_kms(self) -> None: "encryption.kms.key.myKey": os.urandom(16).hex(), } ) - assert client is not None assert isinstance(client, InMemoryKms) - # Should be initialized — the key should be usable wrapped = client.wrap_key(b"data", "myKey") assert client.unwrap_key(wrapped, "myKey") == b"data" @@ -112,13 +108,11 @@ def test_nonexistent_class(self) -> None: load_kms_client({"py-kms-impl": "pyiceberg.encryption.kms.NonexistentClass"}) def test_not_a_subclass(self) -> None: + # AESGCM is a real class but not a KeyManagementClient with pytest.raises(ValueError, match="not a subclass"): - # AESGCM is a real class but not a KeyManagementClient load_kms_client({"py-kms-impl": "cryptography.hazmat.primitives.ciphers.aead.AESGCM"}) def test_custom_kms_impl(self) -> None: - """Verify that a custom KMS implementation can be loaded by module path.""" - 
class _TestKms(KeyManagementClient): initialized_with: dict[str, str] = {} @@ -131,13 +125,11 @@ def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes: def initialize(self, properties: dict[str, str]) -> None: _TestKms.initialized_with = properties - # Register in the module namespace so importlib can find it import pyiceberg.encryption.kms as kms_module kms_module._TestKms = _TestKms # type: ignore[attr-defined] try: client = load_kms_client({"py-kms-impl": "pyiceberg.encryption.kms._TestKms", "foo": "bar"}) - assert client is not None assert isinstance(client, _TestKms) assert _TestKms.initialized_with.get("foo") == "bar" finally: diff --git a/tests/encryption/test_manager.py b/tests/encryption/test_manager.py index fed30cca84..af83d1a9fe 100644 --- a/tests/encryption/test_manager.py +++ b/tests/encryption/test_manager.py @@ -31,7 +31,6 @@ def _make_ags1_stream(key: bytes, plaintext: bytes, aad_prefix: bytes, plain_block_size: int = 1024 * 1024) -> bytes: - """Build an AGS1 encrypted stream for testing.""" from cryptography.hazmat.primitives.ciphers.aead import AESGCM aesgcm = AESGCM(key) @@ -53,26 +52,17 @@ def _make_ags1_stream(key: bytes, plaintext: bytes, aad_prefix: bytes, plain_blo def _build_test_encryption_manager() -> tuple[EncryptionManager, bytes, bytes, str]: - """Build an EncryptionManager with test keys, mimicking the REST catalog flow. - - Returns (manager, dek, aad_prefix, manifest_list_key_id). 
- """ - master_key = b"0123456789012345" # 16 bytes, standard Iceberg test key "keyA" + master_key = b"0123456789012345" kms = InMemoryKms(master_keys={"keyA": master_key}) - # Create a KEK (simulating what the REST catalog would provide) kek_bytes = os.urandom(16) kek_wrapped = kms.wrap_key(kek_bytes, "keyA") kek_timestamp = "1234567890" - # Create a DEK for the manifest list dek = os.urandom(16) aad_prefix = os.urandom(8) key_metadata = StandardKeyMetadata(encryption_key=dek, aad_prefix=aad_prefix) - key_metadata_bytes = key_metadata.serialize() - - # Wrap the DEK key metadata with the KEK (using timestamp as AAD) - wrapped_dek = aes_gcm_encrypt(kek_bytes, key_metadata_bytes, aad=kek_timestamp.encode("utf-8")) + wrapped_dek = aes_gcm_encrypt(kek_bytes, key_metadata.serialize(), aad=kek_timestamp.encode("utf-8")) encryption_keys = { "kek-1": EncryptedKey( @@ -107,24 +97,17 @@ def test_decrypt_manifest(self) -> None: encrypted_stream = _make_ags1_stream(dek, plaintext, aad_prefix) key_metadata = StandardKeyMetadata(encryption_key=dek, aad_prefix=aad_prefix) - # The manager only needs a KMS for KEK unwrapping; manifest decryption - # uses the plaintext key metadata directly (from inside the encrypted manifest list) - kms = InMemoryKms() - manager = EncryptionManager(kms_client=kms) + manager = EncryptionManager(kms_client=InMemoryKms()) result = manager.decrypt_manifest(encrypted_stream, key_metadata.serialize()) assert result == plaintext def test_kek_caching(self) -> None: - """KEK should be unwrapped once and cached.""" manager, dek, aad_prefix, mlk_id = _build_test_encryption_manager() - plaintext = b"test" - encrypted = _make_ags1_stream(dek, plaintext, aad_prefix) + encrypted = _make_ags1_stream(dek, b"test", aad_prefix) - # Decrypt twice manager.decrypt_manifest_list(encrypted, mlk_id) manager.decrypt_manifest_list(encrypted, mlk_id) - # KEK should be cached assert "kek-1" in manager._kek_cache def test_missing_snapshot_key(self) -> None: diff --git 
a/tests/integration/test_encryption.py b/tests/integration/test_encryption.py index 4f6b8b57a3..2c546200c8 100644 --- a/tests/integration/test_encryption.py +++ b/tests/integration/test_encryption.py @@ -14,16 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Reads of Spark-written, Parquet-encrypted Iceberg tables via PyIceberg. +"""Read Spark-written, Parquet-encrypted Iceberg tables (Hive catalog) via PyIceberg. -The encrypted table (`hive.default.test_encrypted`) is provisioned by `dev/provision.py` -using Spark with `encryption.kms-impl=org.apache.iceberg.encryption.UnitestKMS`. UnitestKMS -ships hardcoded master keys (keyA=b"0123456789012345", keyB=b"1123456789012345"); we mirror -those bytes here through PyIceberg's InMemoryKms so unwrapping succeeds. - -Decryption of the data files requires PyArrow's `parquet.encryption.create_decryption_properties` -API, which is available in PyArrow >= 25 (currently shipped only via the nightly wheels). See -the Makefile target `install-pyarrow-nightly`. +`hive.default.test_encrypted` is written by Spark in `dev/provision.py` using +`encryption.kms-impl=org.apache.iceberg.encryption.UnitestKMS`. UnitestKMS uses fixed master +keys keyA/keyB which we mirror into PyIceberg's InMemoryKms below. """ from __future__ import annotations @@ -32,8 +27,6 @@ from pyiceberg.catalog import load_catalog -# UnitestKMS master keys, hex-encoded so they can be set as catalog properties and parsed by -# InMemoryKms.initialize (`encryption.kms.key.=`). 
_KEY_A_HEX = b"0123456789012345".hex() _KEY_B_HEX = b"1123456789012345".hex() @@ -61,20 +54,15 @@ def test_encrypted_table_metadata(hive_catalog_with_kms) -> None: # type: ignor assert tbl.metadata.format_version == 3 assert tbl.metadata.properties.get("encryption.key-id") == "keyA" - assert tbl.metadata.encryption_keys, "expected encryption keys on table metadata" - - snapshot = tbl.current_snapshot() - assert snapshot is not None - assert snapshot.key_id is not None, "expected key_id on current snapshot" + assert tbl.metadata.encryption_keys + assert tbl.current_snapshot() is not None + assert tbl.current_snapshot().key_id is not None @pytest.mark.integration def test_encrypted_table_to_arrow(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] tbl = hive_catalog_with_kms.load_table("default.test_encrypted") - result = tbl.scan().to_arrow().sort_by("id") - - assert result.num_rows == 3 assert result.column("id").to_pylist() == [1, 2, 3] assert result.column("data").to_pylist() == ["alice", "bob", "charlie"] assert result.column("value").to_pylist() == [1.0, 2.0, 3.0] @@ -83,9 +71,7 @@ def test_encrypted_table_to_arrow(hive_catalog_with_kms) -> None: # type: ignor @pytest.mark.integration def test_encrypted_table_to_pandas(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] tbl = hive_catalog_with_kms.load_table("default.test_encrypted") - df = tbl.scan().to_pandas().sort_values("id").reset_index(drop=True) - assert list(df["id"]) == [1, 2, 3] assert list(df["data"]) == ["alice", "bob", "charlie"] assert list(df["value"]) == [1.0, 2.0, 3.0] @@ -94,19 +80,15 @@ def test_encrypted_table_to_pandas(hive_catalog_with_kms) -> None: # type: igno @pytest.mark.integration def test_encrypted_table_to_duckdb(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] tbl = hive_catalog_with_kms.load_table("default.test_encrypted") - con = tbl.scan().to_duckdb("encrypted") rows = con.execute("SELECT id, data, value FROM encrypted ORDER BY 
id").fetchall() - assert rows == [(1, "alice", 1.0), (2, "bob", 2.0), (3, "charlie", 3.0)] @pytest.mark.integration def test_encrypted_table_to_polars(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] tbl = hive_catalog_with_kms.load_table("default.test_encrypted") - df = tbl.scan().to_polars().sort("id") - assert df["id"].to_list() == [1, 2, 3] assert df["data"].to_list() == ["alice", "bob", "charlie"] assert df["value"].to_list() == [1.0, 2.0, 3.0] @@ -114,19 +96,14 @@ def test_encrypted_table_to_polars(hive_catalog_with_kms) -> None: # type: igno @pytest.mark.integration def test_encrypted_table_direct_parquet_read_fails(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] - """Canary: a raw PyArrow read of a data file without decryption properties must fail. - - Mirrors iceberg-java's TestTableEncryption#testDirectDataFileRead, which proves the data - files are genuinely PME-encrypted by asserting that reading them without the keys raises - ParquetCryptoRuntimeException. Without this check, the read tests above could silently pass - on plaintext Parquet and the POC would be meaningless. - """ + # Mirrors iceberg-java's TestTableEncryption#testDirectDataFileRead — proves the data + # files are really PME-encrypted (raw read without keys must fail), so the above tests + # can't silently pass on plaintext. import pyarrow.parquet as pq tbl = hive_catalog_with_kms.load_table("default.test_encrypted") - data_files = [task.file.file_path for task in tbl.scan().plan_files()] - assert data_files, "expected at least one data file in the encrypted table" + assert data_files for file_path in data_files: with pytest.raises(OSError, match="encrypted"), tbl.io.new_input(file_path).open() as fi: