diff --git a/Makefile b/Makefile index 4fe761192c..90cffc7aad 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ # under the License. .PHONY: help install install-uv check-license lint \ test test-integration test-integration-setup test-integration-exec test-integration-cleanup test-integration-rebuild \ - test-s3 test-adls test-gcs test-coverage coverage-report test test-notebook\ + test-s3 test-adls test-gcs test-coverage coverage-report \ docs-serve docs-build notebook notebook-infra \ clean @@ -38,10 +38,12 @@ else PYTHON_ARG = endif +# --no-sync so that overlays applied after `make install` (e.g. install-pyarrow-nightly for +# the encryption integration test) aren't reverted by uv re-syncing the lockfile on `uv run`. ifeq ($(COVERAGE),1) - TEST_RUNNER = uv run $(PYTHON_ARG) python -m coverage run --parallel-mode --source=pyiceberg -m + TEST_RUNNER = uv run --no-sync $(PYTHON_ARG) python -m coverage run --parallel-mode --source=pyiceberg -m else - TEST_RUNNER = uv run $(PYTHON_ARG) python -m + TEST_RUNNER = uv run --no-sync $(PYTHON_ARG) python -m endif ifeq ($(KEEP_COMPOSE),1) @@ -112,7 +114,16 @@ test-integration-setup: install ## Start Docker services for integration tests docker compose -f dev/docker-compose-integration.yml kill docker compose -f dev/docker-compose-integration.yml rm -f docker compose -f dev/docker-compose-integration.yml up -d --build --wait - uv run $(PYTHON_ARG) python dev/provision.py + uv run --no-sync $(PYTHON_ARG) python dev/provision.py + $(MAKE) install-pyarrow-nightly + +# Parquet Modular Encryption decryption (tests/integration/test_encryption.py) needs the +# pyarrow.parquet.encryption.create_decryption_properties API from apache/arrow#49667. That +# lands in pyarrow 25, which hasn't been released — pull the nightly until it is. Runs AFTER +# provision so that the implicit `uv run` sync during provision.py doesn't revert this overlay. +install-pyarrow-nightly: ## Overlay nightly pyarrow on top of the installed env (for PME) + uv pip install $(PYTHON_ARG) --prerelease=allow --upgrade --force-reinstall \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple pyarrow test-integration-exec: ## Run integration tests (excluding provision) $(TEST_RUNNER) pytest tests/ -m integration $(PYTEST_ARGS) @@ -150,9 +161,6 @@ coverage-report: ## Combine and report coverage uv run $(PYTHON_ARG) coverage html uv run $(PYTHON_ARG) coverage xml -test-notebook: ## Run notebook tests (pyiceberg_example and spark_integration_example) via papermill - $(TEST_RUNNER) pytest tests/notebooks/test_pyiceberg_example.py tests/notebooks/test_spark_integration_example.py -m notebook $(PYTEST_ARGS) - # ================ # Documentation # ================ diff --git a/dev/provision.py b/dev/provision.py index 695ef9b1bf..8587319092 100644 --- a/dev/provision.py +++ b/dev/provision.py @@ -395,3 +395,11 @@ ) spark.sql(f"ALTER TABLE {catalog_name}.default.test_empty_scan_ordered_str WRITE ORDERED BY id") spark.sql(f"INSERT INTO {catalog_name}.default.test_empty_scan_ordered_str VALUES 'a', 'c'") + +# Encrypted Hive-cataloged table; read back via PyIceberg in tests/integration/test_encryption.py. +spark.sql(""" + CREATE OR REPLACE TABLE hive.default.test_encrypted (id bigint, data string, value float) + USING iceberg + TBLPROPERTIES ('encryption.key-id'='keyA', 'format-version'='3') +""") +spark.sql("INSERT INTO hive.default.test_encrypted VALUES (1, 'alice', 1.0), (2, 'bob', 2.0), (3, 'charlie', 3.0)") diff --git a/dev/spark/Dockerfile b/dev/spark/Dockerfile index 0e1f29d152..e578ee8158 100644 --- a/dev/spark/Dockerfile +++ b/dev/spark/Dockerfile @@ -18,10 +18,12 @@ ARG BASE_IMAGE_SPARK_VERSION=4.0.1 FROM apache/spark:${BASE_IMAGE_SPARK_VERSION} # Dependency versions - keep these compatible -# Changing these will invalidate the JAR download cache layer -ARG ICEBERG_VERSION=1.10.1 +# Changing these will invalidate the JAR download cache layer. +# Iceberg 1.11.0 carries the Hive encryption integration (apache/iceberg#13066) — the prior +# 1.10.x release predates that work and silently no-ops encryption.kms-impl / encryption.key-id. +ARG ICEBERG_VERSION=1.11.0 ARG ICEBERG_SPARK_RUNTIME_VERSION=4.0_2.13 -ARG HADOOP_VERSION=3.4.1 +ARG HADOOP_VERSION=3.4.3 ARG AWS_SDK_VERSION=2.24.6 ARG MAVEN_MIRROR=https://repo.maven.apache.org/maven2 @@ -36,14 +38,21 @@ RUN apt-get update -qq && \ mkdir -p /home/iceberg/spark-events && \ chown -R spark:spark /home/iceberg -# Download JARs with retry logic (most cacheable - only changes when versions change) -# This is the slowest step, so we do it before copying config files +# Download JARs with retry logic (most cacheable - only changes when versions change). +# iceberg-core-${ICEBERG_VERSION}-tests.jar ships org.apache.iceberg.encryption.UnitestKMS, a +# fixed-master-key KMS used by the encryption integration test on the Spark write path. RUN set -e && \ cd "${SPARK_HOME}/jars" && \ + # Spark 4.0.1 ships hadoop-client-{api,runtime} 3.4.1; replace them with HADOOP_VERSION so + # hadoop-aws methods like ConfigurationHelper.resolveEnum / S3ABlockOutputStream.builder() + # that iceberg-aws-bundle 1.11.0 expects are present at runtime. + rm -f hadoop-client-api-*.jar hadoop-client-runtime-*.jar && \ for jar_path in \ "org/apache/iceberg/iceberg-spark-runtime-${ICEBERG_SPARK_RUNTIME_VERSION}/${ICEBERG_VERSION}/iceberg-spark-runtime-${ICEBERG_SPARK_RUNTIME_VERSION}-${ICEBERG_VERSION}.jar" \ "org/apache/iceberg/iceberg-aws-bundle/${ICEBERG_VERSION}/iceberg-aws-bundle-${ICEBERG_VERSION}.jar" \ "org/apache/hadoop/hadoop-aws/${HADOOP_VERSION}/hadoop-aws-${HADOOP_VERSION}.jar" \ + "org/apache/hadoop/hadoop-client-api/${HADOOP_VERSION}/hadoop-client-api-${HADOOP_VERSION}.jar" \ + "org/apache/hadoop/hadoop-client-runtime/${HADOOP_VERSION}/hadoop-client-runtime-${HADOOP_VERSION}.jar" \ "software/amazon/awssdk/bundle/${AWS_SDK_VERSION}/bundle-${AWS_SDK_VERSION}.jar"; \ do \ jar_name=$(basename "${jar_path}") && \ @@ -53,6 +62,24 @@ RUN set -e && \ chown spark:spark "${jar_name}"; \ done +# Pull UnitestKMS + MemoryMockKMS out of iceberg-core's -tests.jar into a slim jar that only +# ships those two classes. The full -tests.jar contains testing stubs for unrelated classes +# (e.g. a stub S3ABlockOutputStream that lacks builder()) which would shadow the real ones in +# hadoop-aws and break Spark's S3A writes; this avoids that classpath collision. +RUN set -e && \ + tmp="$(mktemp -d)" && cd "${tmp}" && \ + curl -fsSL --retry 3 --retry-delay 5 \ + -o iceberg-core-tests.jar \ + "${MAVEN_MIRROR}/org/apache/iceberg/iceberg-core/${ICEBERG_VERSION}/iceberg-core-${ICEBERG_VERSION}-tests.jar" && \ + /opt/java/openjdk/bin/jar xf iceberg-core-tests.jar \ + org/apache/iceberg/encryption/UnitestKMS.class \ + org/apache/iceberg/encryption/MemoryMockKMS.class && \ + /opt/java/openjdk/bin/jar cf "${SPARK_HOME}/jars/iceberg-core-${ICEBERG_VERSION}-tests-kms-only.jar" \ + org/apache/iceberg/encryption/UnitestKMS.class \ + org/apache/iceberg/encryption/MemoryMockKMS.class && \ + chown spark:spark "${SPARK_HOME}/jars/iceberg-core-${ICEBERG_VERSION}-tests-kms-only.jar" && \ + rm -rf "${tmp}" + # Copy configuration last (changes more frequently than JARs) COPY --chown=spark:spark spark-defaults.conf ${SPARK_HOME}/conf/ diff --git a/dev/spark/spark-defaults.conf b/dev/spark/spark-defaults.conf index 4e50f590c7..02bd4ac5cc 100644 --- a/dev/spark/spark-defaults.conf +++ b/dev/spark/spark-defaults.conf @@ -34,6 +34,11 @@ spark.sql.catalog.hive.io-impl org.apache.iceberg.aws.s3.S3FileIO spark.sql.catalog.hive.warehouse s3://warehouse/hive/ spark.sql.catalog.hive.s3.endpoint http://minio:9000 +# Test-only KMS so Spark can write encrypted Iceberg tables for the encryption integration test. +# UnitestKMS comes from iceberg-core--tests.jar and uses fixed master keys ("keyA", +# "keyB") that match the InMemoryKms config used on the PyIceberg side. +spark.sql.catalog.hive.encryption.kms-impl org.apache.iceberg.encryption.UnitestKMS + # Configure Spark's default session catalog (spark_catalog) to use Iceberg backed by the Hive Metastore spark.sql.catalog.spark_catalog org.apache.iceberg.spark.SparkSessionCatalog spark.sql.catalog.spark_catalog.type hive diff --git a/pyiceberg/encryption/__init__.py b/pyiceberg/encryption/__init__.py new file mode 100644 index 0000000000..9c86adbdce --- /dev/null +++ b/pyiceberg/encryption/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Iceberg table encryption support.""" diff --git a/pyiceberg/encryption/ciphers.py b/pyiceberg/encryption/ciphers.py new file mode 100644 index 0000000000..bf0c9f01a9 --- /dev/null +++ b/pyiceberg/encryption/ciphers.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""AES-GCM primitives and Iceberg AGS1 stream decryption.""" + +from __future__ import annotations + +import os +import struct + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +NONCE_LENGTH = 12 +GCM_TAG_LENGTH = 16 + + +def aes_gcm_encrypt(key: bytes, plaintext: bytes, aad: bytes | None = None) -> bytes: + nonce = os.urandom(NONCE_LENGTH) + return nonce + AESGCM(key).encrypt(nonce, plaintext, aad) + + +def aes_gcm_decrypt(key: bytes, ciphertext: bytes, aad: bytes | None = None) -> bytes: + if len(ciphertext) < NONCE_LENGTH + GCM_TAG_LENGTH: + raise ValueError(f"Ciphertext too short: {len(ciphertext)} bytes") + return AESGCM(key).decrypt(ciphertext[:NONCE_LENGTH], ciphertext[NONCE_LENGTH:], aad) + + +GCM_STREAM_MAGIC = b"AGS1" +GCM_STREAM_HEADER_LENGTH = 8 # 4 magic + 4 little-endian block size + + +def stream_block_aad(aad_prefix: bytes, block_index: int) -> bytes: + return aad_prefix + struct.pack(" bytes: + """Decrypt an Iceberg AGS1 stream. + + Layout: "AGS1" (4) | plain_block_size LE (4) | one or more {nonce(12) | cipher | tag(16)} blocks. + Each block's AAD is `aad_prefix || block_index_le32`. + """ + if len(encrypted_data) < GCM_STREAM_HEADER_LENGTH: + raise ValueError(f"AGS1 stream too short: {len(encrypted_data)} bytes") + if encrypted_data[:4] != GCM_STREAM_MAGIC: + raise ValueError(f"Invalid AGS1 magic: {encrypted_data[:4]!r}") + + plain_block_size = struct.unpack_from(" None: + self._buffer = io.BytesIO(data) + + def read(self, size: int = 0) -> bytes: + if size <= 0: + return self._buffer.read() + return self._buffer.read(size) + + def seek(self, offset: int, whence: int = 0) -> int: + return self._buffer.seek(offset, whence) + + def tell(self) -> int: + return self._buffer.tell() + + def close(self) -> None: + self._buffer.close() + + def __enter__(self) -> BytesInputStream: + return self + + def __exit__( + self, + exctype: type[BaseException] | None, + excinst: BaseException | None, + exctb: TracebackType | None, + ) -> None: + self.close() + + +class BytesInputFile(InputFile): + def __init__(self, location: str, data: bytes) -> None: + super().__init__(location) + self._data = data + + def __len__(self) -> int: + return len(self._data) + + def exists(self) -> bool: + return True + + def open(self, seekable: bool = True) -> InputStream: + return BytesInputStream(self._data) diff --git a/pyiceberg/encryption/key_metadata.py b/pyiceberg/encryption/key_metadata.py new file mode 100644 index 0000000000..211f6a48c1 --- /dev/null +++ b/pyiceberg/encryption/key_metadata.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""StandardKeyMetadata Avro codec. + +Wire: ``0x01 version`` || encryption_key (bytes) || aad_prefix (union[null,bytes]) + || file_length (union[null,long]). +""" + +from __future__ import annotations + +from dataclasses import dataclass + +V1 = 0x01 + + +def _read_avro_long(data: bytes, offset: int) -> tuple[int, int]: + result = 0 + shift = 0 + while True: + if offset >= len(data): + raise ValueError("Unexpected end of Avro data reading long") + b = data[offset] + offset += 1 + result |= (b & 0x7F) << shift + if (b & 0x80) == 0: + break + shift += 7 + return (result >> 1) ^ -(result & 1), offset + + +def _read_avro_bytes(data: bytes, offset: int) -> tuple[bytes, int]: + length, offset = _read_avro_long(data, offset) + if length < 0: + raise ValueError(f"Negative Avro bytes length: {length}") + end = offset + length + if end > len(data): + raise ValueError("Unexpected end of Avro data reading bytes") + return data[offset:end], end + + +@dataclass(frozen=True) +class StandardKeyMetadata: + encryption_key: bytes + aad_prefix: bytes = b"" + file_length: int | None = None + + @staticmethod + def deserialize(data: bytes) -> StandardKeyMetadata: + if not data: + raise ValueError("Empty key metadata buffer") + if data[0] != V1: + raise ValueError(f"Unsupported key metadata version: {data[0]}") + offset = 1 + + encryption_key, offset = _read_avro_bytes(data, offset) + + union_index, offset = _read_avro_long(data, offset) + if union_index == 0: + aad_prefix = b"" + elif union_index == 1: + aad_prefix, offset = _read_avro_bytes(data, offset) + else: + raise ValueError(f"Invalid union index for aad_prefix: {union_index}") + + file_length: int | None = None + if offset < len(data): + union_index, offset = _read_avro_long(data, offset) + if union_index == 1: + file_length, offset = _read_avro_long(data, offset) + elif union_index != 0: + raise ValueError(f"Invalid union index for file_length: {union_index}") + + return StandardKeyMetadata(encryption_key=encryption_key, aad_prefix=aad_prefix, file_length=file_length) + + def serialize(self) -> bytes: + parts = [bytes([V1]), _encode_avro_bytes(self.encryption_key)] + if self.aad_prefix: + parts += [_encode_avro_long(1), _encode_avro_bytes(self.aad_prefix)] + else: + parts.append(_encode_avro_long(0)) + if self.file_length is not None: + parts += [_encode_avro_long(1), _encode_avro_long(self.file_length)] + else: + parts.append(_encode_avro_long(0)) + return b"".join(parts) + + +def _encode_avro_long(value: int) -> bytes: + n = (value << 1) ^ (value >> 63) + result = bytearray() + while n & ~0x7F: + result.append((n & 0x7F) | 0x80) + n >>= 7 + result.append(n & 0x7F) + return bytes(result) + + +def _encode_avro_bytes(data: bytes) -> bytes: + return _encode_avro_long(len(data)) + data diff --git a/pyiceberg/encryption/kms.py b/pyiceberg/encryption/kms.py new file mode 100644 index 0000000000..5c7c34ddd6 --- /dev/null +++ b/pyiceberg/encryption/kms.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import importlib +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from pyiceberg.encryption.ciphers import aes_gcm_decrypt, aes_gcm_encrypt + +if TYPE_CHECKING: + from pyiceberg.typedef import Properties + +logger = logging.getLogger(__name__) + +PY_KMS_IMPL = "py-kms-impl" + + +class KeyManagementClient(ABC): + @abstractmethod + def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes: ... + + @abstractmethod + def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes: ... + + def initialize(self, properties: dict[str, str]) -> None: # noqa: B027 + """Initialize the KMS client from catalog/table properties.""" + + +class InMemoryKms(KeyManagementClient): + """In-memory KMS for testing. NOT for production use.""" + + def __init__(self, master_keys: dict[str, bytes] | None = None) -> None: + self._master_keys: dict[str, bytes] = dict(master_keys) if master_keys else {} + + def initialize(self, properties: dict[str, str]) -> None: + prefix = "encryption.kms.key." + for key, value in properties.items(): + if key.startswith(prefix): + self._master_keys[key[len(prefix) :]] = bytes.fromhex(value) + + def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes: + return aes_gcm_encrypt(self._master(wrapping_key_id), key, aad=None) + + def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes: + return aes_gcm_decrypt(self._master(wrapping_key_id), wrapped_key, aad=None) + + def _master(self, wrapping_key_id: str) -> bytes: + master_key = self._master_keys.get(wrapping_key_id) + if master_key is None: + raise ValueError(f"Wrapping key not found: {wrapping_key_id}") + return master_key + + +def load_kms_client(properties: Properties) -> KeyManagementClient | None: + """Instantiate a KeyManagementClient from a fully-qualified `py-kms-impl` (or return None).""" + kms_impl = properties.get(PY_KMS_IMPL) + if kms_impl is None: + return None + + path_parts = kms_impl.split(".") + if len(path_parts) < 2: + raise ValueError(f"py-kms-impl should be a full path (module.ClassName), got: {kms_impl}") + + module_name, class_name = ".".join(path_parts[:-1]), path_parts[-1] + try: + module = importlib.import_module(module_name) + except ModuleNotFoundError as e: + raise ValueError(f"Could not import KMS module: {module_name}") from e + + kms_class = getattr(module, class_name, None) + if kms_class is None: + raise ValueError(f"KMS class {class_name} not found in module {module_name}") + + if not (isinstance(kms_class, type) and issubclass(kms_class, KeyManagementClient)): + raise ValueError(f"{kms_impl} is not a subclass of KeyManagementClient") + + client = kms_class() + client.initialize(dict(properties)) + return client diff --git a/pyiceberg/encryption/manager.py b/pyiceberg/encryption/manager.py new file mode 100644 index 0000000000..1d28a2f30c --- /dev/null +++ b/pyiceberg/encryption/manager.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Iceberg two-layer envelope-key manager: KMS master key wraps KEK; KEK wraps the per-file DEK.""" + +from __future__ import annotations + +import logging + +from pyiceberg.encryption.ciphers import aes_gcm_decrypt, decrypt_ags1_stream +from pyiceberg.encryption.key_metadata import StandardKeyMetadata +from pyiceberg.encryption.kms import KeyManagementClient + +logger = logging.getLogger(__name__) + +KEK_CREATED_AT_PROPERTY = "KEY_TIMESTAMP" + + +class EncryptedKey: + def __init__( + self, + key_id: str, + encrypted_key_metadata: bytes, + encrypted_by_id: str | None = None, + properties: dict[str, str] | None = None, + ) -> None: + self.key_id = key_id + self.encrypted_key_metadata = encrypted_key_metadata + self.encrypted_by_id = encrypted_by_id + self.properties = properties or {} + + def __repr__(self) -> str: + return ( + f"EncryptedKey(key_id={self.key_id!r}, " + f"encrypted_by_id={self.encrypted_by_id!r}, " + f"metadata_len={len(self.encrypted_key_metadata)})" + ) + + +class EncryptionManager: + def __init__( + self, + kms_client: KeyManagementClient, + encryption_keys: dict[str, EncryptedKey] | None = None, + ) -> None: + self._kms = kms_client + self._encryption_keys = encryption_keys or {} + self._kek_cache: dict[str, bytes] = {} + + def _unwrap_kek(self, kek: EncryptedKey) -> bytes: + if kek.key_id in self._kek_cache: + return self._kek_cache[kek.key_id] + + if not kek.encrypted_by_id: + raise ValueError(f"KEK '{kek.key_id}' has no encrypted_by_id") + + plaintext = self._kms.unwrap_key(kek.encrypted_key_metadata, kek.encrypted_by_id) + self._kek_cache[kek.key_id] = plaintext + return plaintext + + def _unwrap_dek(self, wrapped_dek: bytes, kek_key_id: str) -> bytes: + kek = self._encryption_keys.get(kek_key_id) + if kek is None: + raise ValueError(f"KEK not found in encryption keys: {kek_key_id}") + + kek_bytes = self._unwrap_kek(kek) + # KEK timestamp is bound as AAD to defeat tampering. + aad = kek.properties.get(KEK_CREATED_AT_PROPERTY) + aad_bytes = aad.encode("utf-8") if aad else None + return aes_gcm_decrypt(kek_bytes, wrapped_dek, aad=aad_bytes) + + def unwrap_key_metadata(self, encrypted_key: EncryptedKey) -> bytes: + if not encrypted_key.encrypted_by_id: + raise ValueError(f"EncryptedKey '{encrypted_key.key_id}' has no encrypted_by_id") + return self._unwrap_dek(encrypted_key.encrypted_key_metadata, encrypted_key.encrypted_by_id) + + def decrypt_manifest_list(self, encrypted_data: bytes, snapshot_key_id: str) -> bytes: + encrypted_key = self._encryption_keys.get(snapshot_key_id) + if encrypted_key is None: + raise ValueError(f"Snapshot key not found in encryption keys: {snapshot_key_id}") + + key_metadata = StandardKeyMetadata.deserialize(self.unwrap_key_metadata(encrypted_key)) + return decrypt_ags1_stream( + key=key_metadata.encryption_key, + encrypted_data=encrypted_data, + aad_prefix=key_metadata.aad_prefix, + ) + + def decrypt_manifest(self, encrypted_data: bytes, key_metadata_bytes: bytes) -> bytes: + # ManifestFile.key_metadata carries the plaintext DEK + AAD; it's already + # protected by living inside the encrypted manifest list. + key_metadata = StandardKeyMetadata.deserialize(key_metadata_bytes) + return decrypt_ags1_stream( + key=key_metadata.encryption_key, + encrypted_data=encrypted_data, + aad_prefix=key_metadata.aad_prefix, + ) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index d4414c7c52..dca8f1743b 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -187,6 +187,7 @@ from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string if TYPE_CHECKING: + from pyiceberg.encryption.manager import EncryptionManager from pyiceberg.table import FileScanTask, WriteTask logger = logging.getLogger(__name__) @@ -1116,12 +1117,47 @@ def _get_file_format(file_format: FileFormat, **kwargs: dict[str, Any]) -> ds.Fi raise ValueError(f"Unsupported file format: {file_format}") -def _read_deletes(io: FileIO, data_file: DataFile) -> dict[str, pa.ChunkedArray]: +def _get_decryption_properties(key_metadata_bytes: bytes) -> Any: + # Needs PyArrow >= 25 (nightly today) for `create_decryption_properties` (apache/arrow#49667). + try: + import pyarrow.parquet.encryption as pe + + if not hasattr(pe, "create_decryption_properties"): + raise ImportError("create_decryption_properties not available") + except ImportError as e: + raise ImportError( + "Parquet Modular Encryption requires PyArrow >= 25 (apache/arrow#49667). " + "Install the nightly via `make install-pyarrow-nightly`." + ) from e + + from pyiceberg.encryption.key_metadata import StandardKeyMetadata + + key_metadata = StandardKeyMetadata.deserialize(key_metadata_bytes) + return pe.create_decryption_properties( + footer_key=key_metadata.encryption_key, + aad_prefix=key_metadata.aad_prefix or None, + ) + + +def _read_deletes( + io: FileIO, data_file: DataFile, encryption_manager: EncryptionManager | None = None +) -> dict[str, pa.ChunkedArray]: if data_file.file_format == FileFormat.PARQUET: + arrow_format = _get_file_format( + data_file.file_format, dictionary_columns=("file_path",), pre_buffer=True, buffer_size=ONE_MEGABYTE + ) + + if data_file.key_metadata is not None and encryption_manager is not None: + decryption_properties = _get_decryption_properties(data_file.key_metadata) + scan_options = ds.ParquetFragmentScanOptions( + decryption_properties=decryption_properties, + pre_buffer=True, + buffer_size=ONE_MEGABYTE, + ) + arrow_format = ds.ParquetFileFormat(default_fragment_scan_options=scan_options) + with io.new_input(data_file.file_path).open() as fi: - delete_fragment = _get_file_format( - data_file.file_format, dictionary_columns=("file_path",), pre_buffer=True, buffer_size=ONE_MEGABYTE - ).make_fragment(fi) + delete_fragment = arrow_format.make_fragment(fi) table = ds.Scanner.from_fragment(fragment=delete_fragment).to_table() table = table.unify_dictionaries() return { @@ -1624,8 +1660,19 @@ def _task_to_record_batches( partition_spec: PartitionSpec | None = None, format_version: TableVersion = TableProperties.DEFAULT_FORMAT_VERSION, downcast_ns_timestamp_to_us: bool | None = None, + encryption_manager: EncryptionManager | None = None, ) -> Iterator[pa.RecordBatch]: arrow_format = _get_file_format(task.file.file_format, pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) + + if task.file.key_metadata is not None and encryption_manager is not None: + decryption_properties = _get_decryption_properties(task.file.key_metadata) + scan_options = ds.ParquetFragmentScanOptions( + decryption_properties=decryption_properties, + pre_buffer=True, + buffer_size=(ONE_MEGABYTE * 8), + ) + arrow_format = ds.ParquetFileFormat(default_fragment_scan_options=scan_options) + with io.new_input(task.file.file_path).open() as fin: fragment = arrow_format.make_fragment(fin) physical_schema = fragment.physical_schema @@ -1655,14 +1702,15 @@ def _task_to_record_batches( file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) - fragment_scanner = ds.Scanner.from_fragment( - fragment=fragment, - schema=physical_schema, + scanner_kwargs: dict[str, Any] = { + "fragment": fragment, + "schema": physical_schema, # This will push down the query to Arrow. # But in case there are positional deletes, we have to apply them first - filter=pyarrow_filter if not positional_deletes else None, - columns=[col.name for col in file_project_schema.columns], - ) + "filter": pyarrow_filter if not positional_deletes else None, + "columns": [col.name for col in file_project_schema.columns], + } + fragment_scanner = ds.Scanner.from_fragment(**scanner_kwargs) next_index = 0 batches = fragment_scanner.to_batches() @@ -1701,14 +1749,16 @@ def _task_to_record_batches( ) -def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[str, list[ChunkedArray]]: +def _read_all_delete_files( + io: FileIO, tasks: Iterable[FileScanTask], encryption_manager: EncryptionManager | None = None +) -> dict[str, list[ChunkedArray]]: deletes_per_file: dict[str, list[ChunkedArray]] = {} unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks])) if len(unique_deletes) > 0: executor = ExecutorFactory.get_or_create() deletes_per_files: Iterator[dict[str, ChunkedArray]] = executor.map( lambda args: _read_deletes(*args), - [(io, delete_file) for delete_file in unique_deletes], + [(io, delete_file, encryption_manager) for delete_file in unique_deletes], ) for delete in deletes_per_files: for file, arr in delete.items(): @@ -1728,6 +1778,7 @@ class ArrowScan: _case_sensitive: bool _limit: int | None _downcast_ns_timestamp_to_us: bool | None + _encryption_manager: EncryptionManager | None """Scan the Iceberg Table and create an Arrow construct. Attributes: @@ -1737,6 +1788,7 @@ class ArrowScan: _bound_row_filter: Schema bound row expression to filter the data with _case_sensitive: Case sensitivity when looking up column names _limit: Limit the number of records. + _encryption_manager: Optional encryption manager for decrypting data files. """ def __init__( @@ -1747,6 +1799,7 @@ def __init__( row_filter: BooleanExpression, case_sensitive: bool = True, limit: int | None = None, + encryption_manager: EncryptionManager | None = None, ) -> None: self._table_metadata = table_metadata self._io = io @@ -1755,6 +1808,7 @@ def __init__( self._case_sensitive = case_sensitive self._limit = limit self._downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) + self._encryption_manager = encryption_manager @property def _projected_field_ids(self) -> set[int]: @@ -1817,7 +1871,7 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record ResolveError: When a required field cannot be found in the file ValueError: When a field type in the file cannot be projected to the schema type """ - deletes_per_file = _read_all_delete_files(self._io, tasks) + deletes_per_file = _read_all_delete_files(self._io, tasks, self._encryption_manager) total_row_count = 0 executor = ExecutorFactory.get_or_create() @@ -1865,6 +1919,7 @@ def _record_batches_from_scan_tasks_and_deletes( self._table_metadata.specs().get(task.file.spec_id), self._table_metadata.format_version, self._downcast_ns_timestamp_to_us, + encryption_manager=self._encryption_manager, ) for batch in batches: if self._limit is not None: diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index 3811a9d894..e4bbe745cd 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -24,6 +24,7 @@ from enum import Enum from types import TracebackType from typing import ( + TYPE_CHECKING, Any, Literal, ) @@ -39,6 +40,9 @@ from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import Schema from pyiceberg.typedef import Record, TableVersion + +if TYPE_CHECKING: + from pyiceberg.encryption.manager import EncryptionManager from pyiceberg.types import ( BinaryType, BooleanType, @@ -858,18 +862,34 @@ def has_added_files(self) -> bool: def has_existing_files(self) -> bool: return self.existing_files_count is None or self.existing_files_count > 0 - def fetch_manifest_entry(self, io: FileIO, discard_deleted: bool = True) -> list[ManifestEntry]: + def fetch_manifest_entry( + self, + io: FileIO, + discard_deleted: bool = True, + encryption_manager: EncryptionManager | None = None, + ) -> list[ManifestEntry]: """ Read the manifest entries from the manifest file. Args: io: The FileIO to fetch the file. discard_deleted: Filter on live entries. + encryption_manager: When set, decrypt an AGS1-encrypted manifest. Returns: An Iterator of manifest entries. """ input_file = io.new_input(self.manifest_path) + + # A non-null key_metadata signals the manifest is AGS1-encrypted. + if self.key_metadata is not None and encryption_manager is not None: + from pyiceberg.encryption.io import BytesInputFile + + with input_file.open() as f: + encrypted_data = f.read() + decrypted_data = encryption_manager.decrypt_manifest(encrypted_data, self.key_metadata) + input_file = BytesInputFile(self.manifest_path, decrypted_data) + with AvroFile[ManifestEntry]( input_file, MANIFEST_ENTRY_SCHEMAS[DEFAULT_READ_VERSION], @@ -900,7 +920,12 @@ def __hash__(self) -> int: _manifest_cache_lock = threading.RLock() -def _manifests(io: FileIO, manifest_list: str) -> tuple[ManifestFile, ...]: +def _manifests( + io: FileIO, + manifest_list: str, + encryption_manager: EncryptionManager | None = None, + snapshot_key_id: str | None = None, +) -> tuple[ManifestFile, ...]: """Read manifests from a manifest list, deduplicating ManifestFile objects via cache. Caches individual ManifestFile objects by manifest_path. This is memory-efficient @@ -920,12 +945,14 @@ def _manifests(io: FileIO, manifest_list: str) -> tuple[ManifestFile, ...]: Args: io: FileIO instance for reading the manifest list. manifest_list: Path to the manifest list file. + encryption_manager: When set together with snapshot_key_id, decrypt an AGS1-encrypted manifest list. + snapshot_key_id: Snapshot's encryption key id. Returns: A tuple of ManifestFile objects. """ file = io.new_input(manifest_list) - manifest_files = list(read_manifest_list(file)) + manifest_files = list(read_manifest_list(file, encryption_manager=encryption_manager, snapshot_key_id=snapshot_key_id)) result = [] with _manifest_cache_lock: @@ -940,16 +967,30 @@ def _manifests(io: FileIO, manifest_list: str) -> tuple[ManifestFile, ...]: return tuple(result) -def read_manifest_list(input_file: InputFile) -> Iterator[ManifestFile]: +def read_manifest_list( + input_file: InputFile, + encryption_manager: EncryptionManager | None = None, + snapshot_key_id: str | None = None, +) -> Iterator[ManifestFile]: """ Read the manifests from the manifest list. Args: input_file: The input file where the stream can be read from. + encryption_manager: When set together with snapshot_key_id, decrypt an AGS1-encrypted manifest list. + snapshot_key_id: Snapshot's encryption key id. Returns: An iterator of ManifestFiles that are part of the list. """ + if snapshot_key_id is not None and encryption_manager is not None: + from pyiceberg.encryption.io import BytesInputFile + + with input_file.open() as f: + encrypted_data = f.read() + decrypted_data = encryption_manager.decrypt_manifest_list(encrypted_data, snapshot_key_id) + input_file = BytesInputFile(input_file.location, decrypted_data) + with AvroFile[ManifestFile]( input_file, MANIFEST_LIST_FILE_SCHEMAS[DEFAULT_READ_VERSION], diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index b8d87143c9..90e7dcc1bc 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -107,6 +107,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.catalog.rest.scan_planning import RESTContentFile, RESTDeleteFile, RESTFileScanTask + from pyiceberg.encryption.manager import EncryptionManager ALWAYS_TRUE = AlwaysTrue() DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write" @@ -1891,6 +1892,7 @@ def _open_manifest( manifest: ManifestFile, partition_filter: Callable[[DataFile], bool], metrics_evaluator: Callable[[DataFile], bool], + encryption_manager: EncryptionManager | None = None, ) -> list[ManifestEntry]: """Open a manifest file and return matching manifest entries. @@ -1899,7 +1901,7 @@ def _open_manifest( """ return [ manifest_entry - for manifest_entry in manifest.fetch_manifest_entry(io, discard_deleted=True) + for manifest_entry in manifest.fetch_manifest_entry(io, discard_deleted=True, encryption_manager=encryption_manager) if partition_filter(manifest_entry.data_file) and metrics_evaluator(manifest_entry.data_file) ] @@ -1917,6 +1919,13 @@ def _min_sequence_number(manifests: list[ManifestFile]) -> int: class DataScan(TableScan): + @cached_property + def _merged_properties(self) -> Properties: + """Catalog properties merged with table properties (table wins).""" + if self.catalog is not None: + return {**self.catalog.properties, **self.table_metadata.properties} + return dict(self.table_metadata.properties) + def _build_partition_projection(self, spec_id: int) -> BooleanExpression: project = inclusive_projection(self.table_metadata.schema(), self.table_metadata.specs()[spec_id], self.case_sensitive) return project(self.row_filter) @@ -1988,6 +1997,38 @@ def _check_sequence_number(min_sequence_number: int, manifest: ManifestFile) -> and (manifest.sequence_number or INITIAL_SEQUENCE_NUMBER) >= min_sequence_number ) + @cached_property + def _encryption_manager(self) -> EncryptionManager | None: + """Create an EncryptionManager if the table has encryption configured.""" + if self.table_metadata.format_version < 3: + return None + + encryption_keys = self.table_metadata.encryption_keys + if not encryption_keys: + return None + + from pyiceberg.encryption.kms import load_kms_client + from pyiceberg.encryption.manager import EncryptedKey + from pyiceberg.encryption.manager import EncryptionManager as EncryptionManagerClass + + kms_client = load_kms_client(self._merged_properties) + if kms_client is None: + return None + + enc_keys_map: dict[str, EncryptedKey] = {} + for ek in encryption_keys: + enc_keys_map[ek.key_id] = EncryptedKey( + key_id=ek.key_id, + encrypted_key_metadata=ek.encrypted_key_metadata_bytes, + encrypted_by_id=ek.encrypted_by_id, + properties=dict(ek.properties), + ) + + return EncryptionManagerClass( + kms_client=kms_client, + encryption_keys=enc_keys_map, + ) + def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: """Filter and return manifest entries based on partition and metrics evaluators. @@ -1998,6 +2039,8 @@ def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: if not snapshot: return iter([]) + encryption_manager = self._encryption_manager + # step 1: filter manifests using partition summaries # the filter depends on the partition spec used to write the manifest file, so create a cache of filters for each spec id @@ -2005,7 +2048,7 @@ def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: manifests = [ manifest_file - for manifest_file in snapshot.manifests(self.io) + for manifest_file in snapshot.manifests(self.io, encryption_manager=encryption_manager) if manifest_evaluators[manifest_file.partition_spec_id](manifest_file) ] @@ -2026,6 +2069,7 @@ def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: manifest, partition_evaluators[manifest.partition_spec_id], self._build_metrics_evaluator(), + encryption_manager, ) for manifest in manifests if self._check_sequence_number(min_sequence_number, manifest) @@ -2114,7 +2158,13 @@ def to_arrow(self) -> pa.Table: from pyiceberg.io.pyarrow import ArrowScan return ArrowScan( - self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit + self.table_metadata, + self.io, + self.projection(), + self.row_filter, + self.case_sensitive, + self.limit, + encryption_manager=self._encryption_manager, ).to_table(self.plan_files()) def to_arrow_batch_reader(self) -> pa.RecordBatchReader: @@ -2134,7 +2184,13 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: target_schema = schema_to_pyarrow(self.projection()) batches = ArrowScan( - self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit + self.table_metadata, + self.io, + self.projection(), + self.row_filter, + self.case_sensitive, + self.limit, + encryption_manager=self._encryption_manager, ).to_record_batches(self.plan_files()) return pa.RecordBatchReader.from_batches( diff --git a/pyiceberg/table/inspect.py b/pyiceberg/table/inspect.py index 5da343ccb6..3b97c71682 100644 --- a/pyiceberg/table/inspect.py +++ b/pyiceberg/table/inspect.py @@ -404,6 +404,8 @@ def _get_all_manifests_schema(self) -> pa.Schema: all_manifests_schema = self._get_manifests_schema() all_manifests_schema = all_manifests_schema.append(pa.field("reference_snapshot_id", pa.int64(), nullable=False)) + # Iceberg 1.11.0 (apache/iceberg#14750) added key_metadata to the all_manifests table only. + all_manifests_schema = all_manifests_schema.append(pa.field("key_metadata", pa.binary(), nullable=True)) return all_manifests_schema def _generate_manifests_table(self, snapshot: Snapshot | None, is_all_manifests_table: bool = False) -> pa.Table: @@ -468,6 +470,7 @@ def _partition_summaries_to_rows( } if is_all_manifests_table: manifest_row["reference_snapshot_id"] = snapshot.snapshot_id + manifest_row["key_metadata"] = manifest.key_metadata manifests.append(manifest_row) return pa.Table.from_pylist( diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 26b6e3d3ad..7d39f1c6c1 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import base64 import datetime import uuid from collections.abc import Iterable @@ -48,6 +49,20 @@ from pyiceberg.utils.config import Config from pyiceberg.utils.datetime import datetime_to_millis + +class EncryptedKeyModel(IcebergBaseModel): + """An encrypted key entry in table metadata.""" + + key_id: str = Field(alias="key-id") + encrypted_key_metadata: str = Field(alias="encrypted-key-metadata") + encrypted_by_id: str | None = Field(alias="encrypted-by-id", default=None) + properties: dict[str, str] = Field(default_factory=dict) + + @property + def encrypted_key_metadata_bytes(self) -> bytes: + return base64.b64decode(self.encrypted_key_metadata) + + CURRENT_SNAPSHOT_ID = "current-snapshot-id" CURRENT_SCHEMA_ID = "current-schema-id" SCHEMAS = "schemas" @@ -574,6 +589,10 @@ def construct_refs(self) -> TableMetadata: next_row_id: int | None = Field(alias="next-row-id", default=None) """A long higher than all assigned row IDs; the next snapshot's `first-row-id`.""" + encryption_keys: list[EncryptedKeyModel] = Field(alias="encryption-keys", default_factory=list) + """Encryption key entries for the two-layer envelope encryption scheme. + Only valid for format version 3 and higher.""" + def model_dump_json(self, exclude_none: bool = True, exclude: Any | None = None, by_alias: bool = True, **kwargs: Any) -> str: raise NotImplementedError("Writing V3 is not yet supported, see: https://github.com/apache/iceberg-python/issues/1551") diff --git a/pyiceberg/table/snapshots.py b/pyiceberg/table/snapshots.py index 7e4c6eb1ec..52bbc54852 100644 --- a/pyiceberg/table/snapshots.py +++ b/pyiceberg/table/snapshots.py @@ -31,6 +31,7 @@ from pyiceberg.schema import Schema if TYPE_CHECKING: + from pyiceberg.encryption.manager import EncryptionManager from pyiceberg.table.metadata import TableMetadata from pyiceberg.typedef import IcebergBaseModel @@ -252,6 +253,9 @@ class Snapshot(IcebergBaseModel): added_rows: int | None = Field( alias="added-rows", default=None, description="The upper bound of the number of rows with assigned row IDs" ) + key_id: str | None = Field( + alias="key-id", default=None, description="ID of the encryption key used to encrypt this snapshot's manifest list" + ) def __str__(self) -> str: """Return the string representation of the Snapshot class.""" @@ -277,9 +281,16 @@ def __repr__(self) -> str: filtered_fields = [field for field in fields if field is not None] return f"Snapshot({', '.join(filtered_fields)})" - def manifests(self, io: FileIO) -> list[ManifestFile]: + def manifests(self, io: FileIO, encryption_manager: EncryptionManager | None = None) -> list[ManifestFile]: """Return the manifests for the given snapshot.""" - return list(_manifests(io, self.manifest_list)) + return list( + _manifests( + io, + self.manifest_list, + encryption_manager=encryption_manager, + snapshot_key_id=self.key_id, + ) + ) class MetadataLogEntry(IcebergBaseModel): diff --git a/tests/encryption/__init__.py b/tests/encryption/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/encryption/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/encryption/test_ciphers.py b/tests/encryption/test_ciphers.py new file mode 100644 index 0000000000..4ceb2f01a3 --- /dev/null +++ b/tests/encryption/test_ciphers.py @@ -0,0 +1,154 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import struct + +import pytest + +from pyiceberg.encryption.ciphers import ( + GCM_STREAM_MAGIC, + GCM_TAG_LENGTH, + NONCE_LENGTH, + aes_gcm_decrypt, + aes_gcm_encrypt, + decrypt_ags1_stream, + stream_block_aad, +) + + +class TestAesGcm: + def test_roundtrip(self) -> None: + key = os.urandom(16) + plaintext = b"hello, encryption" + ciphertext = aes_gcm_encrypt(key, plaintext) + assert aes_gcm_decrypt(key, ciphertext) == plaintext + + def test_roundtrip_with_aad(self) -> None: + key = os.urandom(16) + plaintext = b"hello, encryption" + aad = b"additional-data" + ciphertext = aes_gcm_encrypt(key, plaintext, aad=aad) + assert aes_gcm_decrypt(key, ciphertext, aad=aad) == plaintext + + def test_wrong_key_fails(self) -> None: + from cryptography.exceptions import InvalidTag + + key = os.urandom(16) + wrong_key = os.urandom(16) + ciphertext = aes_gcm_encrypt(key, b"secret") + with pytest.raises(InvalidTag): + aes_gcm_decrypt(wrong_key, ciphertext) + + def test_wrong_aad_fails(self) -> None: + from cryptography.exceptions import InvalidTag + + key = os.urandom(16) + ciphertext = aes_gcm_encrypt(key, b"secret", aad=b"correct") + with pytest.raises(InvalidTag): + aes_gcm_decrypt(key, ciphertext, aad=b"wrong") + + def test_wire_format(self) -> None: + # nonce(12) || ciphertext || tag(16) + key = os.urandom(16) + plaintext = b"test" + ciphertext = aes_gcm_encrypt(key, plaintext) + assert len(ciphertext) == NONCE_LENGTH + len(plaintext) + GCM_TAG_LENGTH + + def test_ciphertext_too_short(self) -> None: + key = os.urandom(16) + with pytest.raises(ValueError, match="Ciphertext too short"): + aes_gcm_decrypt(key, b"short") + + +class TestStreamBlockAad: + def test_with_prefix(self) -> None: + aad = stream_block_aad(b"prefix", 0) + assert aad == b"prefix" + struct.pack(" None: + aad = stream_block_aad(b"prefix", 42) + assert aad == b"prefix" + struct.pack(" None: + aad = stream_block_aad(b"", 7) + assert aad == struct.pack(" bytes: + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + + aesgcm = AESGCM(key) + header = GCM_STREAM_MAGIC + struct.pack(" None: + key = os.urandom(16) + plaintext = b"hello AGS1 stream" + aad_prefix = b"test-aad" + encrypted = _encrypt_ags1_stream(key, plaintext, aad_prefix) + assert decrypt_ags1_stream(key, encrypted, aad_prefix) == plaintext + + def test_roundtrip_multi_block(self) -> None: + key = os.urandom(16) + plaintext = b"A" * 200 + aad_prefix = b"multi" + encrypted = _encrypt_ags1_stream(key, plaintext, aad_prefix, plain_block_size=64) + assert decrypt_ags1_stream(key, encrypted, aad_prefix) == plaintext + + def test_roundtrip_empty_payload(self) -> None: + key = os.urandom(16) + header = GCM_STREAM_MAGIC + struct.pack(" None: + data = b"XXXX" + struct.pack(" None: + with pytest.raises(ValueError, match="AGS1 stream too short"): + decrypt_ags1_stream(os.urandom(16), b"AGS1", b"") + + def test_custom_block_size(self) -> None: + # Block size comes from the header, not a constant. + key = os.urandom(16) + plaintext = b"B" * 300 + aad_prefix = b"custom" + encrypted = _encrypt_ags1_stream(key, plaintext, aad_prefix, plain_block_size=100) + assert struct.unpack_from(" None: + key = os.urandom(16) + data = GCM_STREAM_MAGIC + struct.pack(" None: + stream = BytesInputStream(b"hello") + assert stream.read() == b"hello" + + def test_read_partial(self) -> None: + stream = BytesInputStream(b"hello world") + assert stream.read(5) == b"hello" + assert stream.read(6) == b" world" + + def test_seek_and_tell(self) -> None: + stream = BytesInputStream(b"abcdef") + assert stream.tell() == 0 + stream.seek(3) + assert stream.tell() == 3 + assert stream.read(2) == b"de" + + def test_context_manager(self) -> None: + with BytesInputStream(b"data") as stream: + assert stream.read() == b"data" + + def test_implements_input_stream_protocol(self) -> None: + stream = BytesInputStream(b"test") + assert isinstance(stream, InputStream) + + +class TestBytesInputFile: + def test_len(self) -> None: + f = BytesInputFile("file://test", b"hello") + assert len(f) == 5 + + def test_exists(self) -> None: + f = BytesInputFile("file://test", b"") + assert f.exists() is True + + def test_open_and_read(self) -> None: + f = BytesInputFile("file://test", b"content") + with f.open() as stream: + assert stream.read() == b"content" + + def test_location(self) -> None: + f = BytesInputFile("s3://bucket/path", b"data") + assert f.location == "s3://bucket/path" diff --git a/tests/encryption/test_key_metadata.py b/tests/encryption/test_key_metadata.py new file mode 100644 index 0000000000..53b878e34a --- /dev/null +++ b/tests/encryption/test_key_metadata.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os + +import pytest + +from pyiceberg.encryption.key_metadata import StandardKeyMetadata + + +class TestStandardKeyMetadata: + def test_roundtrip_key_only(self) -> None: + key = os.urandom(16) + original = StandardKeyMetadata(encryption_key=key) + serialized = original.serialize() + restored = StandardKeyMetadata.deserialize(serialized) + assert restored.encryption_key == key + assert restored.aad_prefix == b"" + assert restored.file_length is None + + def test_roundtrip_with_aad_prefix(self) -> None: + key = os.urandom(16) + aad = os.urandom(8) + original = StandardKeyMetadata(encryption_key=key, aad_prefix=aad) + serialized = original.serialize() + restored = StandardKeyMetadata.deserialize(serialized) + assert restored.encryption_key == key + assert restored.aad_prefix == aad + assert restored.file_length is None + + def test_roundtrip_all_fields(self) -> None: + key = os.urandom(32) + aad = os.urandom(16) + original = StandardKeyMetadata(encryption_key=key, aad_prefix=aad, file_length=12345) + serialized = original.serialize() + restored = StandardKeyMetadata.deserialize(serialized) + assert restored.encryption_key == key + assert restored.aad_prefix == aad + assert restored.file_length == 12345 + + def test_version_byte(self) -> None: + key = os.urandom(16) + serialized = StandardKeyMetadata(encryption_key=key).serialize() + assert serialized[0] == 0x01 + + def test_deserialize_empty(self) -> None: + with pytest.raises(ValueError, match="Empty key metadata"): + StandardKeyMetadata.deserialize(b"") + + def test_deserialize_wrong_version(self) -> None: + with pytest.raises(ValueError, match="Unsupported key metadata version"): + StandardKeyMetadata.deserialize(b"\x02\x00") + + def test_frozen(self) -> None: + skm = StandardKeyMetadata(encryption_key=b"key") + with pytest.raises(AttributeError): + skm.encryption_key = b"other" # type: ignore[misc] + + def test_roundtrip_large_file_length(self) -> None: + # Exercise zigzag varint for a value beyond int32. + key = os.urandom(16) + original = StandardKeyMetadata(encryption_key=key, file_length=2**40) + restored = StandardKeyMetadata.deserialize(original.serialize()) + assert restored.file_length == 2**40 diff --git a/tests/encryption/test_kms.py b/tests/encryption/test_kms.py new file mode 100644 index 0000000000..ac1e82dd88 --- /dev/null +++ b/tests/encryption/test_kms.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os + +import pytest + +from pyiceberg.encryption.kms import InMemoryKms, KeyManagementClient, load_kms_client + + +class TestInMemoryKms: + def test_wrap_unwrap_roundtrip(self) -> None: + master_key = os.urandom(16) + kms = InMemoryKms(master_keys={"keyA": master_key}) + plaintext_key = os.urandom(16) + wrapped = kms.wrap_key(plaintext_key, "keyA") + unwrapped = kms.unwrap_key(wrapped, "keyA") + assert unwrapped == plaintext_key + + def test_unknown_wrapping_key(self) -> None: + kms = InMemoryKms() + with pytest.raises(ValueError, match="Wrapping key not found"): + kms.wrap_key(b"key", "nonexistent") + + def test_unknown_unwrapping_key(self) -> None: + kms = InMemoryKms() + with pytest.raises(ValueError, match="Wrapping key not found"): + kms.unwrap_key(b"wrapped", "nonexistent") + + def test_initialize_from_properties(self) -> None: + kms = InMemoryKms() + key_hex = os.urandom(16).hex() + kms.initialize({"encryption.kms.key.testKey": key_hex}) + wrapped = kms.wrap_key(b"secret", "testKey") + assert kms.unwrap_key(wrapped, "testKey") == b"secret" + + def test_initialize_ignores_unrelated_properties(self) -> None: + kms = InMemoryKms() + kms.initialize({"some.other.prop": "value", "encryption.kms.key.k1": os.urandom(16).hex()}) + with pytest.raises(ValueError, match="Wrapping key not found"): + kms.wrap_key(b"key", "nonexistent") + + def test_wrap_unwrap_with_standard_test_keys(self) -> None: + kms = InMemoryKms( + master_keys={ + "keyA": b"0123456789012345", + "keyB": b"1123456789012345", + } + ) + plaintext = os.urandom(16) + wrapped_a = kms.wrap_key(plaintext, "keyA") + assert kms.unwrap_key(wrapped_a, "keyA") == plaintext + wrapped_b = kms.wrap_key(plaintext, "keyB") + assert kms.unwrap_key(wrapped_b, "keyB") == plaintext + + def test_wrong_master_key_fails_unwrap(self) -> None: + from cryptography.exceptions import InvalidTag + + kms = InMemoryKms( + master_keys={ + "keyA": os.urandom(16), + "keyB": os.urandom(16), + } + ) + wrapped = kms.wrap_key(b"secret", "keyA") + with pytest.raises(InvalidTag): + kms.unwrap_key(wrapped, "keyB") + + +class TestLoadKmsClient: + def test_returns_none_when_not_configured(self) -> None: + assert load_kms_client({}) is None + + def test_loads_in_memory_kms(self) -> None: + client = load_kms_client( + { + "py-kms-impl": "pyiceberg.encryption.kms.InMemoryKms", + "encryption.kms.key.myKey": os.urandom(16).hex(), + } + ) + assert isinstance(client, InMemoryKms) + wrapped = client.wrap_key(b"data", "myKey") + assert client.unwrap_key(wrapped, "myKey") == b"data" + + def test_invalid_short_path(self) -> None: + with pytest.raises(ValueError, match="full path"): + load_kms_client({"py-kms-impl": "InMemoryKms"}) + + def test_nonexistent_module(self) -> None: + with pytest.raises(ValueError, match="Could not import"): + load_kms_client({"py-kms-impl": "nonexistent.module.Kms"}) + + def test_nonexistent_class(self) -> None: + with pytest.raises(ValueError, match="not found in module"): + load_kms_client({"py-kms-impl": "pyiceberg.encryption.kms.NonexistentClass"}) + + def test_not_a_subclass(self) -> None: + # AESGCM is a real class but not a KeyManagementClient + with pytest.raises(ValueError, match="not a subclass"): + load_kms_client({"py-kms-impl": "cryptography.hazmat.primitives.ciphers.aead.AESGCM"}) + + def test_custom_kms_impl(self) -> None: + class _TestKms(KeyManagementClient): + initialized_with: dict[str, str] = {} + + def wrap_key(self, key: bytes, wrapping_key_id: str) -> bytes: + return key + + def unwrap_key(self, wrapped_key: bytes, wrapping_key_id: str) -> bytes: + return wrapped_key + + def initialize(self, properties: dict[str, str]) -> None: + _TestKms.initialized_with = properties + + import pyiceberg.encryption.kms as kms_module + + kms_module._TestKms = _TestKms # type: ignore[attr-defined] + try: + client = load_kms_client({"py-kms-impl": "pyiceberg.encryption.kms._TestKms", "foo": "bar"}) + assert isinstance(client, _TestKms) + assert _TestKms.initialized_with.get("foo") == "bar" + finally: + delattr(kms_module, "_TestKms") diff --git a/tests/encryption/test_manager.py b/tests/encryption/test_manager.py new file mode 100644 index 0000000000..af83d1a9fe --- /dev/null +++ b/tests/encryption/test_manager.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import struct + +import pytest + +from pyiceberg.encryption.ciphers import ( + GCM_STREAM_MAGIC, + NONCE_LENGTH, + aes_gcm_encrypt, + stream_block_aad, +) +from pyiceberg.encryption.key_metadata import StandardKeyMetadata +from pyiceberg.encryption.kms import InMemoryKms +from pyiceberg.encryption.manager import EncryptedKey, EncryptionManager + + +def _make_ags1_stream(key: bytes, plaintext: bytes, aad_prefix: bytes, plain_block_size: int = 1024 * 1024) -> bytes: + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + + aesgcm = AESGCM(key) + header = GCM_STREAM_MAGIC + struct.pack(" tuple[EncryptionManager, bytes, bytes, str]: + master_key = b"0123456789012345" + kms = InMemoryKms(master_keys={"keyA": master_key}) + + kek_bytes = os.urandom(16) + kek_wrapped = kms.wrap_key(kek_bytes, "keyA") + kek_timestamp = "1234567890" + + dek = os.urandom(16) + aad_prefix = os.urandom(8) + key_metadata = StandardKeyMetadata(encryption_key=dek, aad_prefix=aad_prefix) + wrapped_dek = aes_gcm_encrypt(kek_bytes, key_metadata.serialize(), aad=kek_timestamp.encode("utf-8")) + + encryption_keys = { + "kek-1": EncryptedKey( + key_id="kek-1", + encrypted_key_metadata=kek_wrapped, + encrypted_by_id="keyA", + properties={"KEY_TIMESTAMP": kek_timestamp}, + ), + "mlk-1": EncryptedKey( + key_id="mlk-1", + encrypted_key_metadata=wrapped_dek, + encrypted_by_id="kek-1", + ), + } + + manager = EncryptionManager(kms_client=kms, encryption_keys=encryption_keys) + return manager, dek, aad_prefix, "mlk-1" + + +class TestEncryptionManager: + def test_decrypt_manifest_list(self) -> None: + manager, dek, aad_prefix, mlk_id = _build_test_encryption_manager() + plaintext = b"manifest list content here" + encrypted_stream = _make_ags1_stream(dek, plaintext, aad_prefix) + result = manager.decrypt_manifest_list(encrypted_stream, mlk_id) + assert result == plaintext + + def test_decrypt_manifest(self) -> None: + dek = os.urandom(16) + aad_prefix = os.urandom(8) + plaintext = b"manifest content here" + encrypted_stream = _make_ags1_stream(dek, plaintext, aad_prefix) + key_metadata = StandardKeyMetadata(encryption_key=dek, aad_prefix=aad_prefix) + + manager = EncryptionManager(kms_client=InMemoryKms()) + result = manager.decrypt_manifest(encrypted_stream, key_metadata.serialize()) + assert result == plaintext + + def test_kek_caching(self) -> None: + manager, dek, aad_prefix, mlk_id = _build_test_encryption_manager() + encrypted = _make_ags1_stream(dek, b"test", aad_prefix) + + manager.decrypt_manifest_list(encrypted, mlk_id) + manager.decrypt_manifest_list(encrypted, mlk_id) + + assert "kek-1" in manager._kek_cache + + def test_missing_snapshot_key(self) -> None: + kms = InMemoryKms() + manager = EncryptionManager(kms_client=kms) + with pytest.raises(ValueError, match="Snapshot key not found"): + manager.decrypt_manifest_list(b"data", "nonexistent-key") + + def test_missing_kek(self) -> None: + kms = InMemoryKms() + encryption_keys = { + "mlk-1": EncryptedKey( + key_id="mlk-1", + encrypted_key_metadata=b"wrapped", + encrypted_by_id="missing-kek", + ), + } + manager = EncryptionManager(kms_client=kms, encryption_keys=encryption_keys) + with pytest.raises(ValueError, match="KEK not found"): + manager.decrypt_manifest_list(b"data", "mlk-1") + + def test_kek_without_encrypted_by_id(self) -> None: + kms = InMemoryKms(master_keys={"keyA": os.urandom(16)}) + encryption_keys = { + "kek-1": EncryptedKey(key_id="kek-1", encrypted_key_metadata=b"data"), + "mlk-1": EncryptedKey( + key_id="mlk-1", + encrypted_key_metadata=b"wrapped", + encrypted_by_id="kek-1", + ), + } + manager = EncryptionManager(kms_client=kms, encryption_keys=encryption_keys) + with pytest.raises(ValueError, match="has no encrypted_by_id"): + manager.decrypt_manifest_list(b"data", "mlk-1") diff --git a/tests/integration/test_encryption.py b/tests/integration/test_encryption.py new file mode 100644 index 0000000000..2c546200c8 --- /dev/null +++ b/tests/integration/test_encryption.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Read Spark-written, Parquet-encrypted Iceberg tables (Hive catalog) via PyIceberg. + +`hive.default.test_encrypted` is written by Spark in `dev/provision.py` using +`encryption.kms-impl=org.apache.iceberg.encryption.UnitestKMS`. UnitestKMS uses fixed master +keys keyA/keyB which we mirror into PyIceberg's InMemoryKms below. +""" + +from __future__ import annotations + +import pytest + +from pyiceberg.catalog import load_catalog + +_KEY_A_HEX = b"0123456789012345".hex() +_KEY_B_HEX = b"1123456789012345".hex() + + +@pytest.fixture(scope="module") +def hive_catalog_with_kms(): # type: ignore[no-untyped-def] + return load_catalog( + "local", + **{ + "type": "hive", + "uri": "thrift://localhost:9083", + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + "py-kms-impl": "pyiceberg.encryption.kms.InMemoryKms", + "encryption.kms.key.keyA": _KEY_A_HEX, + "encryption.kms.key.keyB": _KEY_B_HEX, + }, + ) + + +@pytest.mark.integration +def test_encrypted_table_metadata(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] + tbl = hive_catalog_with_kms.load_table("default.test_encrypted") + + assert tbl.metadata.format_version == 3 + assert tbl.metadata.properties.get("encryption.key-id") == "keyA" + assert tbl.metadata.encryption_keys + assert tbl.current_snapshot() is not None + assert tbl.current_snapshot().key_id is not None + + +@pytest.mark.integration +def test_encrypted_table_to_arrow(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] + tbl = hive_catalog_with_kms.load_table("default.test_encrypted") + result = tbl.scan().to_arrow().sort_by("id") + assert result.column("id").to_pylist() == [1, 2, 3] + assert result.column("data").to_pylist() == ["alice", "bob", "charlie"] + assert result.column("value").to_pylist() == [1.0, 2.0, 3.0] + + +@pytest.mark.integration +def test_encrypted_table_to_pandas(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] + tbl = hive_catalog_with_kms.load_table("default.test_encrypted") + df = tbl.scan().to_pandas().sort_values("id").reset_index(drop=True) + assert list(df["id"]) == [1, 2, 3] + assert list(df["data"]) == ["alice", "bob", "charlie"] + assert list(df["value"]) == [1.0, 2.0, 3.0] + + +@pytest.mark.integration +def test_encrypted_table_to_duckdb(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] + tbl = hive_catalog_with_kms.load_table("default.test_encrypted") + con = tbl.scan().to_duckdb("encrypted") + rows = con.execute("SELECT id, data, value FROM encrypted ORDER BY id").fetchall() + assert rows == [(1, "alice", 1.0), (2, "bob", 2.0), (3, "charlie", 3.0)] + + +@pytest.mark.integration +def test_encrypted_table_to_polars(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] + tbl = hive_catalog_with_kms.load_table("default.test_encrypted") + df = tbl.scan().to_polars().sort("id") + assert df["id"].to_list() == [1, 2, 3] + assert df["data"].to_list() == ["alice", "bob", "charlie"] + assert df["value"].to_list() == [1.0, 2.0, 3.0] + + +@pytest.mark.integration +def test_encrypted_table_direct_parquet_read_fails(hive_catalog_with_kms) -> None: # type: ignore[no-untyped-def] + # Mirrors iceberg-java's TestTableEncryption#testDirectDataFileRead — proves the data + # files are really PME-encrypted (raw read without keys must fail), so the above tests + # can't silently pass on plaintext. + import pyarrow.parquet as pq + + tbl = hive_catalog_with_kms.load_table("default.test_encrypted") + data_files = [task.file.file_path for task in tbl.scan().plan_files()] + assert data_files + + for file_path in data_files: + with pytest.raises(OSError, match="encrypted"), tbl.io.new_input(file_path).open() as fi: + pq.read_table(fi) diff --git a/tests/integration/test_inspect_table.py b/tests/integration/test_inspect_table.py index 03d4437d18..4d8dfbe9bb 100644 --- a/tests/integration/test_inspect_table.py +++ b/tests/integration/test_inspect_table.py @@ -1012,6 +1012,7 @@ def test_inspect_all_manifests(spark: SparkSession, session_catalog: Catalog, fo "deleted_delete_files_count", "partition_summaries", "reference_snapshot_id", + "key_metadata", ] int_cols = [