diff --git a/test/_helpers.py b/test/_helpers.py new file mode 100644 index 0000000..d1422bf --- /dev/null +++ b/test/_helpers.py @@ -0,0 +1,71 @@ +"""Shared test helpers for fickling test suite.""" + +from __future__ import annotations + +import pickle +from typing import Any + +from fickling.analysis import check_safety +from fickling.fickle import Pickled + + +def make_malicious_pickle( + module: str, func: str, args: tuple[Any, ...] = (), protocol: int = 4 +) -> bytes: + """Create a malicious pickle that calls module.func(*args). + + Uses __reduce__ to serialize a payload that, when unpickled, would call + the specified function. Note: Python resolves the module at pickle time + via importlib, so dotted module paths like "os.path" resolve to their + actual implementation (e.g., posixpath on Unix). + """ + + class Payload: + def __reduce__(self) -> tuple[Any, tuple[Any, ...]]: + import importlib + + mod = importlib.import_module(module) + fn = getattr(mod, func) + return (fn, args) + + return pickle.dumps(Payload(), protocol=protocol) + + +def make_benign_pickle(data: Any | None = None, protocol: int = 4) -> bytes: + """Create a benign pickle with safe data.""" + if data is None: + data = [1, 2, 3] + return pickle.dumps(data, protocol=protocol) + + +def make_pickle(obj: Any, protocol: int = 4) -> bytes: + """Create a pickle from a Python object.""" + return pickle.dumps(obj, protocol=protocol) + + +def assert_not_malicious(data: bytes) -> None: + """Assert that a pickle is not flagged as overtly malicious. + + Standard library imports may be flagged as SUSPICIOUS (unused variable) + or LIKELY_UNSAFE (non-standard imports) but should NEVER be flagged as + OVERTLY_MALICIOUS unless they're actually dangerous. + """ + from fickling.analysis import Severity + + pickled = Pickled.load(data) + result = check_safety(pickled) + assert result.severity < Severity.LIKELY_OVERTLY_MALICIOUS, ( + f"Safe object incorrectly flagged as malicious. 
" + f"Severity: {result.severity.name}. Results: {result.to_string()}" + ) + + +def assert_likely_safe(data: bytes) -> None: + """Assert that a pickle is LIKELY_SAFE (pure data, no imports).""" + from fickling.analysis import Severity + + pickled = Pickled.load(data) + result = check_safety(pickled) + assert result.severity == Severity.LIKELY_SAFE, ( + f"Expected LIKELY_SAFE, got {result.severity.name}. Results: {result.to_string()}" + ) diff --git a/test/test_archive_scanning.py b/test/test_archive_scanning.py new file mode 100644 index 0000000..333ce28 --- /dev/null +++ b/test/test_archive_scanning.py @@ -0,0 +1,336 @@ +"""Archive-based attack tests for fickling. + +These tests verify that fickling can detect malicious pickles embedded +in various archive formats (ZIP, TAR, etc.) that are commonly used to +distribute ML models. + +Key patterns tested: +- Malicious pickle inside ZIP archives +- Malicious pickle inside TAR archives +- Nested archives with malicious content +- PyTorch-style ZIP structures with malicious pickles +""" + +from __future__ import annotations + +import io +import tarfile +import zipfile +from pathlib import Path + +import pytest + +from fickling.analysis import Severity, check_safety +from fickling.fickle import Pickled +from test._helpers import make_benign_pickle, make_malicious_pickle + +# ============================================================================= +# ZIP Archive Tests +# ============================================================================= + + +def test_malicious_pickle_in_zip() -> None: + """Malicious pickle inside a ZIP archive should be detected.""" + malicious = make_malicious_pickle("os", "system", ("id",)) + + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w", zipfile.ZIP_STORED) as zf: + zf.writestr("data.pkl", malicious) + + buffer.seek(0) + with zipfile.ZipFile(buffer, "r") as zf: + pkl_data = zf.read("data.pkl") + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert 
result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect malicious pickle in ZIP" + ) + + +def test_malicious_pickle_in_nested_zip_path() -> None: + """Malicious pickle in nested ZIP path should be detected.""" + malicious = make_malicious_pickle("subprocess", "call", (["id"],)) + + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w", zipfile.ZIP_DEFLATED) as zf: + # Mimic PyTorch model structure + zf.writestr("model/archive/data.pkl", malicious) + zf.writestr("model/version", "1") + + buffer.seek(0) + with zipfile.ZipFile(buffer, "r") as zf: + pkl_data = zf.read("model/archive/data.pkl") + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect malicious pickle in nested ZIP path" + ) + + +def test_zip_with_multiple_pickles_mixed() -> None: + """ZIP with mixed benign and malicious pickles should detect malicious ones.""" + benign = make_benign_pickle([1, 2, 3]) + malicious = make_malicious_pickle("builtins", "eval", ("1+1",)) + + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w") as zf: + zf.writestr("data/safe_model.pkl", benign) + zf.writestr("data/payload.pkl", malicious) + zf.writestr("data/another_safe.pkl", benign) + + buffer.seek(0) + with zipfile.ZipFile(buffer, "r") as zf: + # Check benign pickles + for safe_name in ["data/safe_model.pkl", "data/another_safe.pkl"]: + pkl_data = zf.read(safe_name) + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity == Severity.LIKELY_SAFE, ( + f"Safe pickle {safe_name} incorrectly flagged" + ) + + # Check malicious pickle + pkl_data = zf.read("data/payload.pkl") + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect malicious pickle among safe ones" + ) + + +def test_pytorch_style_zip_with_malicious_data() -> None: + """PyTorch-style ZIP with malicious data.pkl 
should be detected.""" + malicious = make_malicious_pickle("socket", "socket", ()) + + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w") as zf: + # Mimic PyTorch v1.3+ format + zf.writestr("archive/data.pkl", malicious) + zf.writestr("archive/version", "3") + zf.writestr("archive/data/0", b"\x00" * 100) # Fake tensor data + + buffer.seek(0) + with zipfile.ZipFile(buffer, "r") as zf: + pkl_data = zf.read("archive/data.pkl") + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect malicious PyTorch-style pickle" + ) + + +# ============================================================================= +# TAR Archive Tests +# ============================================================================= + + +def test_malicious_pickle_in_tar() -> None: + """Malicious pickle inside a TAR archive should be detected.""" + malicious = make_malicious_pickle("os", "system", ("id",)) + + buffer = io.BytesIO() + with tarfile.open(fileobj=buffer, mode="w") as tf: + # Add malicious pickle + pkl_io = io.BytesIO(malicious) + info = tarfile.TarInfo(name="model.pkl") + info.size = len(malicious) + tf.addfile(info, pkl_io) + + buffer.seek(0) + with tarfile.open(fileobj=buffer, mode="r") as tf: + member = tf.getmember("model.pkl") + f = tf.extractfile(member) + assert f is not None + pkl_data = f.read() + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect malicious pickle in TAR" + ) + + +def test_malicious_pickle_in_tar_gz() -> None: + """Malicious pickle inside a .tar.gz archive should be detected.""" + malicious = make_malicious_pickle("pty", "spawn", ("/bin/sh",)) + + buffer = io.BytesIO() + with tarfile.open(fileobj=buffer, mode="w:gz") as tf: + pkl_io = io.BytesIO(malicious) + info = tarfile.TarInfo(name="payload.pkl") + info.size = len(malicious) + tf.addfile(info, pkl_io) + + 
buffer.seek(0) + with tarfile.open(fileobj=buffer, mode="r:gz") as tf: + member = tf.getmember("payload.pkl") + f = tf.extractfile(member) + assert f is not None + pkl_data = f.read() + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect malicious pickle in .tar.gz" + ) + + +def test_legacy_pytorch_tar_with_malicious_pickle() -> None: + """Legacy PyTorch TAR format with malicious pickle should be detected.""" + malicious = make_malicious_pickle("builtins", "exec", ("import os",)) + + buffer = io.BytesIO() + with tarfile.open(fileobj=buffer, mode="w") as tf: + # Mimic PyTorch v0.1.1 format + pkl_io = io.BytesIO(malicious) + info = tarfile.TarInfo(name="pickle") + info.size = len(malicious) + tf.addfile(info, pkl_io) + + # Add empty storages and tensors directories + info = tarfile.TarInfo(name="storages/") + info.type = tarfile.DIRTYPE + tf.addfile(info) + + info = tarfile.TarInfo(name="tensors/") + info.type = tarfile.DIRTYPE + tf.addfile(info) + + buffer.seek(0) + with tarfile.open(fileobj=buffer, mode="r") as tf: + member = tf.getmember("pickle") + f = tf.extractfile(member) + assert f is not None + pkl_data = f.read() + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect malicious pickle in legacy PyTorch TAR" + ) + + +# ============================================================================= +# File-based Archive Tests +# ============================================================================= + + +def test_zip_file_on_disk(tmp_path: Path) -> None: + """Malicious pickle in ZIP file on disk should be detected.""" + malicious = make_malicious_pickle("os", "popen", ("id",)) + zip_path = tmp_path / "model.zip" + + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("model.pkl", malicious) + + with zipfile.ZipFile(zip_path, "r") as zf: + pkl_data = 
zf.read("model.pkl") + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS + + +def test_tar_file_on_disk(tmp_path: Path) -> None: + """Malicious pickle in TAR file on disk should be detected.""" + malicious = make_malicious_pickle("runpy", "run_path", ("/tmp/evil.py",)) + tar_path = tmp_path / "model.tar" + + with tarfile.open(tar_path, "w") as tf: + pkl_io = io.BytesIO(malicious) + info = tarfile.TarInfo(name="weights.pkl") + info.size = len(malicious) + tf.addfile(info, pkl_io) + + with tarfile.open(tar_path, "r") as tf: + member = tf.getmember("weights.pkl") + f_extracted = tf.extractfile(member) + assert f_extracted is not None + pkl_data = f_extracted.read() + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS + + +# ============================================================================= +# Edge Cases +# ============================================================================= + + +def test_zip_with_non_pickle_binary() -> None: + """ZIP with non-pickle binary data should not be confused for pickle.""" + # Create some random binary data that's not a pickle + random_data = b"\x00\x01\x02\x03\xff\xfe\xfd" * 100 + + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w") as zf: + zf.writestr("model.bin", random_data) + + buffer.seek(0) + with zipfile.ZipFile(buffer, "r") as zf: + bin_data = zf.read("model.bin") + with pytest.raises(ValueError): + Pickled.load(bin_data) + + +def test_deeply_nested_malicious_pickle() -> None: + """Deeply nested malicious pickle should still be detected.""" + malicious = make_malicious_pickle("ctypes", "CDLL", ("libc.so.6",)) + + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w") as zf: + zf.writestr("level1/level2/level3/level4/deep_model.pkl", malicious) + + buffer.seek(0) + with zipfile.ZipFile(buffer, "r") as zf: + pkl_data = 
zf.read("level1/level2/level3/level4/deep_model.pkl") + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect deeply nested malicious pickle" + ) + + +# ============================================================================= +# Protocol Version Tests in Archives +# ============================================================================= + + +@pytest.mark.parametrize("protocol", [0, 1, 2, 3, 4, 5]) +def test_all_protocols_in_zip(protocol: int) -> None: + """All pickle protocols should be detected in ZIP archives.""" + malicious = make_malicious_pickle("os", "system", ("id",), protocol=protocol) + + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w") as zf: + zf.writestr(f"model_proto{protocol}.pkl", malicious) + + buffer.seek(0) + with zipfile.ZipFile(buffer, "r") as zf: + pkl_data = zf.read(f"model_proto{protocol}.pkl") + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + f"Failed to detect malicious pickle at protocol {protocol} in ZIP" + ) + + +@pytest.mark.parametrize("protocol", [0, 1, 2, 3, 4, 5]) +def test_all_protocols_in_tar(protocol: int) -> None: + """All pickle protocols should be detected in TAR archives.""" + malicious = make_malicious_pickle("os", "system", ("id",), protocol=protocol) + + buffer = io.BytesIO() + with tarfile.open(fileobj=buffer, mode="w") as tf: + pkl_io = io.BytesIO(malicious) + info = tarfile.TarInfo(name=f"model_proto{protocol}.pkl") + info.size = len(malicious) + tf.addfile(info, pkl_io) + + buffer.seek(0) + with tarfile.open(fileobj=buffer, mode="r") as tf: + member = tf.getmember(f"model_proto{protocol}.pkl") + f = tf.extractfile(member) + assert f is not None + pkl_data = f.read() + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + f"Failed to detect 
malicious pickle at protocol {protocol} in TAR" + ) diff --git a/test/test_attack_vectors.py b/test/test_attack_vectors.py new file mode 100644 index 0000000..0acaa41 --- /dev/null +++ b/test/test_attack_vectors.py @@ -0,0 +1,139 @@ +"""Comprehensive attack vector tests for fickling. + +These tests verify that fickling detects various malicious pickle patterns +across all pickle protocols (0-5). Each test generates malicious pickles +programmatically using __reduce__ and verifies detection. + +Inspired by picklescan's malicious test samples. +""" + +from __future__ import annotations + +import pytest + +from fickling.analysis import Severity, check_safety +from fickling.fickle import Pickled +from test._helpers import make_malicious_pickle + +PROTOCOLS = [0, 1, 2, 3, 4, 5] + +# Each entry: (module, func, args, test_id) +ATTACK_VECTORS = [ + pytest.param("os", "system", ("id",), id="os_system"), + pytest.param("os", "popen", ("id",), id="os_popen"), + pytest.param("os", "execv", ("/bin/sh", ["/bin/sh", "-c", "id"]), id="os_execv"), + pytest.param("subprocess", "call", (["id"],), id="subprocess_call"), + pytest.param("subprocess", "Popen", (["id"],), id="subprocess_popen"), + pytest.param("subprocess", "run", (["id"],), id="subprocess_run"), + pytest.param("subprocess", "check_output", (["id"],), id="subprocess_check_output"), + pytest.param( + "builtins", + "eval", + ("__import__('os').system('id')",), + id="builtins_eval", + ), + pytest.param( + "builtins", + "exec", + ("import os; os.system('id')",), + id="builtins_exec", + ), + pytest.param( + "builtins", + "compile", + ("import os; os.system('id')", "", "exec"), + id="builtins_compile", + ), + pytest.param("builtins", "__import__", ("os",), id="builtins_import"), + pytest.param("builtins", "getattr", (object, "__class__"), id="builtins_getattr"), + pytest.param( + "socket", + "create_connection", + (("evil.com", 4444),), + id="socket_create_connection", + ), + pytest.param("socket", "socket", (), 
id="socket_socket"), + pytest.param("runpy", "run_path", ("/tmp/malicious.py",), id="runpy_run_path"), + pytest.param("runpy", "run_module", ("os",), id="runpy_run_module"), + pytest.param( + "cProfile", + "run", + ("import os; os.system('id')",), + id="cprofile_run", + ), + pytest.param("code", "InteractiveInterpreter", (), id="code_interactiveinterpreter"), + pytest.param("importlib", "import_module", ("os",), id="importlib_import_module"), + pytest.param("ctypes", "CDLL", ("libc.so.6",), id="ctypes_cdll"), + pytest.param("pty", "spawn", ("/bin/sh",), id="pty_spawn"), + pytest.param("pydoc", "locate", ("os.system",), id="pydoc_locate"), + pytest.param( + "multiprocessing.util", + "spawnv_passfds", + (b"/bin/sh", [b"/bin/sh", b"-c", b"id"], ()), + id="multiprocessing_util_spawnv_passfds", + ), + pytest.param("sys", "exit", (0,), id="sys_exit"), + pytest.param("posix", "system", ("id",), id="posix_system"), +] + + +@pytest.mark.parametrize("protocol", PROTOCOLS) +@pytest.mark.parametrize("module,func,args", ATTACK_VECTORS) +def test_attack_vector(module: str, func: str, args: tuple, protocol: int) -> None: + """Verify fickling detects malicious module.func across all protocols.""" + data = make_malicious_pickle(module, func, args, protocol) + pickled = Pickled.load(data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + f"Failed to detect {module}.{func} at protocol {protocol}" + ) + + +# ============================================================================= +# Special cases that need non-standard pickle construction +# ============================================================================= + + +@pytest.mark.parametrize("protocol", PROTOCOLS) +def test_marshal_loads(protocol: int) -> None: + """marshal.loads can deserialize code objects.""" + import marshal + + code = compile('import os; os.system("id")', "", "exec") + code_bytes = marshal.dumps(code) + data = make_malicious_pickle("marshal", "loads", 
(code_bytes,), protocol) + pickled = Pickled.load(data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + f"Failed to detect marshal.loads at protocol {protocol}" + ) + + +def test_types_functiontype() -> None: + """types.FunctionType can create executable functions. + + Uses raw opcode construction because types.FunctionType cannot be + pickled directly using pickle.dumps with __reduce__. + """ + import fickling.fickle as op + + pickled = Pickled( + [ + op.Proto.create(4), + op.Frame(30), + op.ShortBinUnicode("types"), + op.Memoize(), + op.ShortBinUnicode("FunctionType"), + op.Memoize(), + op.StackGlobal(), + op.Memoize(), + op.EmptyTuple(), + op.Reduce(), + op.Memoize(), + op.Stop(), + ] + ) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect types.FunctionType" + ) diff --git a/test/test_benign_edge_cases.py b/test/test_benign_edge_cases.py new file mode 100644 index 0000000..188d37e --- /dev/null +++ b/test/test_benign_edge_cases.py @@ -0,0 +1,533 @@ +"""Benign edge case tests for fickling. + +These tests verify that fickling correctly identifies safe patterns and prevents +false positives. The key distinction: + +- LIKELY_SAFE: Pure data structures with no function calls +- SUSPICIOUS: Safe imports but unused variables (by design - fickling flags this) +- LIKELY_UNSAFE: Non-standard imports or unsafe calls +- OVERTLY_MALICIOUS: Known dangerous operations + +Known Limitations (Documented False Positives): +- `builtins` module imports (range, slice, set, frozenset) are flagged as LIKELY_OVERTLY_MALICIOUS + because builtins contains dangerous functions like eval/exec. This is a conservative approach. +- Protocol 0 uses GLOBAL opcode which may trigger different detection paths. 
+ +Key patterns tested: +- Standard library safe imports should NOT be OVERTLY_MALICIOUS (most cases) +- Pure data structures (no imports) should be LIKELY_SAFE +- Custom class serialization +- NumPy arrays and dtypes (optional) +- False positive prevention for strings containing dangerous keywords +""" + +from __future__ import annotations + +from collections import Counter, OrderedDict +from dataclasses import dataclass +from datetime import date, datetime, time, timedelta, timezone +from decimal import Decimal +from enum import Enum, IntEnum +from fractions import Fraction +from pathlib import PurePosixPath +from typing import Any, NamedTuple +from uuid import UUID + +import pytest + +from fickling.analysis import Severity, check_safety +from fickling.fickle import Pickled +from test._helpers import assert_likely_safe, assert_not_malicious, make_pickle + +# Higher protocols use more modern opcodes with fewer false positives +# Note: Protocol 5 with NumPy can cause parsing issues in fickling +HIGHER_PROTOCOLS = [2, 3, 4, 5] +HIGHER_PROTOCOLS_NUMPY = [2, 3, 4] # Exclude protocol 5 for NumPy due to out-of-band data +ALL_PROTOCOLS = [0, 1, 2, 3, 4, 5] + + +# ============================================================================= +# Pure Data Structures (No Imports) - Should be LIKELY_SAFE +# Higher protocols (2+) are more likely to produce clean pickles. 
+# ============================================================================= + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_primitive_types_safe(protocol: int) -> None: + """Primitive types (int, float, str, bool, None) should be LIKELY_SAFE.""" + for obj in [42, 3.14, "hello", True, False, None]: + data = make_pickle(obj, protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_list_safe(protocol: int) -> None: + """Lists should be LIKELY_SAFE.""" + data = make_pickle([1, 2, 3, "a", "b", "c"], protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_dict_safe(protocol: int) -> None: + """Dicts should be LIKELY_SAFE.""" + data = make_pickle({"key": "value", "num": 42}, protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_tuple_with_many_elements_safe(protocol: int) -> None: + """Tuples with many elements should be LIKELY_SAFE.""" + data = make_pickle(tuple(range(100)), protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_deeply_nested_list_safe(protocol: int) -> None: + """Deeply nested lists should be LIKELY_SAFE.""" + nested = [[[[[[1, 2, 3]]]]]] + data = make_pickle(nested, protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_deeply_nested_dict_safe(protocol: int) -> None: + """Deeply nested dicts should be LIKELY_SAFE.""" + nested = {"a": {"b": {"c": {"d": {"e": 1}}}}} + data = make_pickle(nested, protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_large_list_safe(protocol: int) -> None: + """Large lists should be LIKELY_SAFE.""" + data = make_pickle(list(range(10000)), protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_large_dict_safe(protocol: int) -> None: + 
"""Large dicts should be LIKELY_SAFE.""" + d = {f"key_{i}": i for i in range(1000)} + data = make_pickle(d, protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_unicode_strings_safe(protocol: int) -> None: + """Unicode strings with special characters should be LIKELY_SAFE.""" + data = make_pickle("Hello 世界 🌍 αβγ", protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_empty_containers_safe(protocol: int) -> None: + """Empty containers should be LIKELY_SAFE.""" + for obj in [[], {}, ()]: + data = make_pickle(obj, protocol) + assert_likely_safe(data) + + +# ============================================================================= +# Standard Library Objects - Should NOT be OVERTLY_MALICIOUS +# +# These objects use __reduce__ which creates function calls, so fickling +# correctly marks them as SUSPICIOUS (unused variable). But they should +# NEVER be flagged as OVERTLY_MALICIOUS since they're safe stdlib objects. 
+# ============================================================================= + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_datetime_not_malicious(protocol: int) -> None: + """datetime objects should not be flagged as malicious.""" + data = make_pickle(datetime.now(), protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_date_not_malicious(protocol: int) -> None: + """date objects should not be flagged as malicious.""" + data = make_pickle(date.today(), protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_time_not_malicious(protocol: int) -> None: + """time objects should not be flagged as malicious.""" + data = make_pickle(time(12, 30, 45), protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_timedelta_not_malicious(protocol: int) -> None: + """timedelta objects should not be flagged as malicious.""" + data = make_pickle(timedelta(days=1, hours=2), protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_timezone_not_malicious(protocol: int) -> None: + """timezone objects should not be flagged as malicious.""" + data = make_pickle(timezone.utc, protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_uuid_not_malicious(protocol: int) -> None: + """UUID objects should not be flagged as malicious.""" + data = make_pickle(UUID("12345678-1234-5678-1234-567812345678"), protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_decimal_not_malicious(protocol: int) -> None: + """Decimal objects should not be flagged as malicious.""" + data = make_pickle(Decimal("123.456"), protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_fraction_not_malicious(protocol: int) -> None: + 
"""Fraction objects should not be flagged as malicious.""" + data = make_pickle(Fraction(1, 3), protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_pathlib_not_malicious(protocol: int) -> None: + """PurePath objects should not be flagged as malicious.""" + data = make_pickle(PurePosixPath("/usr/local/bin"), protocol) + assert_not_malicious(data) + + +# ============================================================================= +# Collections Module - Should NOT be OVERTLY_MALICIOUS +# ============================================================================= + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_ordered_dict_not_malicious(protocol: int) -> None: + """OrderedDict should not be flagged as malicious.""" + data = make_pickle(OrderedDict([("a", 1), ("b", 2)]), protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_counter_not_malicious(protocol: int) -> None: + """Counter should not be flagged as malicious.""" + data = make_pickle(Counter(["a", "b", "a", "c"]), protocol) + assert_not_malicious(data) + + +# ============================================================================= +# Custom Class Serialization - Should NOT be OVERTLY_MALICIOUS +# ============================================================================= + + +class SimpleEnum(Enum): + """Simple enum for testing.""" + + VALUE_A = "a" + VALUE_B = "b" + + +class SimpleIntEnum(IntEnum): + """Simple int enum for testing.""" + + ONE = 1 + TWO = 2 + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_enum_not_malicious(protocol: int) -> None: + """Enum subclasses should not be flagged as malicious.""" + data = make_pickle(SimpleEnum.VALUE_A, protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_int_enum_not_malicious(protocol: int) -> None: + """IntEnum subclasses should not be flagged as 
malicious.""" + data = make_pickle(SimpleIntEnum.ONE, protocol) + assert_not_malicious(data) + + +class SimpleNamedTuple(NamedTuple): + """Simple named tuple for testing.""" + + x: int + y: str + z: float + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_namedtuple_not_malicious(protocol: int) -> None: + """NamedTuple instances should not be flagged as malicious.""" + data = make_pickle(SimpleNamedTuple(1, "hello", 3.14), protocol) + assert_not_malicious(data) + + +@dataclass +class SimpleDataclass: + """Simple dataclass for testing.""" + + name: str + value: int + items: list[int] + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_dataclass_not_malicious(protocol: int) -> None: + """Simple dataclasses without callables should not be flagged as malicious.""" + data = make_pickle(SimpleDataclass("test", 42, [1, 2, 3]), protocol) + assert_not_malicious(data) + + +class CustomGetState: + """Class with __getstate__/__setstate__ for data only.""" + + def __init__(self, value: int) -> None: + self.value = value + + def __getstate__(self) -> dict[str, Any]: + return {"value": self.value} + + def __setstate__(self, state: dict[str, Any]) -> None: + self.value = state["value"] + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_custom_getstate_setstate_not_malicious(protocol: int) -> None: + """Classes with __getstate__/__setstate__ should not be flagged as malicious.""" + data = make_pickle(CustomGetState(42), protocol) + assert_not_malicious(data) + + +# ============================================================================= +# NumPy Edge Cases (Optional - skip if numpy not installed) +# Should NOT be OVERTLY_MALICIOUS +# ============================================================================= + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS_NUMPY) +def test_numpy_array_not_malicious(protocol: int) -> None: + """NumPy arrays should not be flagged as malicious.""" + np = 
pytest.importorskip("numpy") + arr = np.array([1, 2, 3, 4, 5]) + data = make_pickle(arr, protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS_NUMPY) +def test_numpy_multidimensional_not_malicious(protocol: int) -> None: + """Multi-dimensional NumPy arrays should not be flagged as malicious.""" + np = pytest.importorskip("numpy") + arr = np.array([[1, 2, 3], [4, 5, 6]]) + data = make_pickle(arr, protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS_NUMPY) +def test_numpy_scalar_not_malicious(protocol: int) -> None: + """NumPy scalars (int32, float64) should not be flagged as malicious.""" + np = pytest.importorskip("numpy") + data = make_pickle(np.int32(42), protocol) + assert_not_malicious(data) + data = make_pickle(np.float64(3.14), protocol) + assert_not_malicious(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS_NUMPY) +def test_numpy_structured_array_not_malicious(protocol: int) -> None: + """Structured NumPy arrays with named fields should not be flagged as malicious.""" + np = pytest.importorskip("numpy") + dt = np.dtype([("name", "U10"), ("age", "i4"), ("weight", "f8")]) + arr = np.array([("Alice", 25, 55.0), ("Bob", 30, 75.5)], dtype=dt) + data = make_pickle(arr, protocol) + assert_not_malicious(data) + + +# ============================================================================= +# False Positive Prevention - Strings containing dangerous keywords +# These should ALL be LIKELY_SAFE (just strings, no imports) +# ============================================================================= + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_string_containing_exec_safe(protocol: int) -> None: + """String containing 'exec' should be LIKELY_SAFE.""" + data = make_pickle("You must exec this command manually", protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def 
test_string_containing_eval_safe(protocol: int) -> None: + """String containing 'eval' should be LIKELY_SAFE.""" + data = make_pickle("Please eval the results carefully", protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_string_containing_import_safe(protocol: int) -> None: + """String containing 'import' should be LIKELY_SAFE.""" + data = make_pickle("import os is dangerous in pickles", protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_string_containing_os_safe(protocol: int) -> None: + """String containing 'os' should be LIKELY_SAFE.""" + data = make_pickle("macOS and Windows are operating systems", protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_dict_with_suspicious_looking_keys_safe(protocol: int) -> None: + """Dict with suspicious-looking keys should be LIKELY_SAFE.""" + data = make_pickle( + { + "exec": "value", + "eval": "value", + "__reduce__": "value", + "os": "value", + "subprocess": "value", + }, + protocol, + ) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_string_with_code_snippet_safe(protocol: int) -> None: + """String containing code snippet should be LIKELY_SAFE.""" + code_str = """ + import os + os.system('rm -rf /') + eval('malicious') + exec('code') + """ + data = make_pickle(code_str, protocol) + assert_likely_safe(data) + + +# ============================================================================= +# ML Framework Patterns (Optional - skip if dependencies not installed) +# Should NOT be OVERTLY_MALICIOUS +# ============================================================================= + + +def test_sklearn_model_not_malicious() -> None: + """Scikit-learn model serialization should not be flagged as malicious.""" + np = pytest.importorskip("numpy") + sklearn_lm = pytest.importorskip("sklearn.linear_model") + model = 
sklearn_lm.LinearRegression() + x_train = np.array([[1], [2], [3]]) + y_train = np.array([1, 2, 3]) + model.fit(x_train, y_train) + data = make_pickle(model, protocol=4) + # sklearn models use non-standard imports but should not be OVERTLY_MALICIOUS + pickled = Pickled.load(data) + result = check_safety(pickled) + assert result.severity < Severity.LIKELY_OVERTLY_MALICIOUS, ( + f"sklearn model incorrectly flagged as malicious: {result.to_string()}" + ) + + +# ============================================================================= +# Severity Level Validation +# These tests verify specific severity levels for different patterns +# ============================================================================= + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_pure_data_is_likely_safe(protocol: int) -> None: + """Pure data (no imports) should be exactly LIKELY_SAFE.""" + data = make_pickle({"name": "test", "values": [1, 2, 3]}, protocol) + pickled = Pickled.load(data) + result = check_safety(pickled) + assert result.severity == Severity.LIKELY_SAFE, ( + f"Pure data should be LIKELY_SAFE, got {result.severity.name}" + ) + + +@pytest.mark.parametrize("protocol", HIGHER_PROTOCOLS) +def test_stdlib_with_reduce_is_at_most_suspicious(protocol: int) -> None: + """Standard library objects using __reduce__ should be at most SUSPICIOUS. + + These objects create unused variables which is flagged by design, but + they should never reach LIKELY_UNSAFE or above. 
+ """ + data = make_pickle(datetime.now(), protocol) + pickled = Pickled.load(data) + result = check_safety(pickled) + # datetime uses __reduce__ which creates an unused variable + # This is correctly flagged as SUSPICIOUS but should not be higher + assert result.severity <= Severity.SUSPICIOUS, ( + f"stdlib datetime should be at most SUSPICIOUS, got {result.severity.name}" + ) + + +# ============================================================================= +# Known Limitations Documentation +# +# These tests document known false positives in fickling where safe stdlib +# types are flagged as malicious because they use builtins module. +# ============================================================================= + + +@pytest.mark.parametrize("protocol", ALL_PROTOCOLS) +def test_builtins_range_is_flagged(protocol: int) -> None: + """Document: range() is flagged because it imports from builtins. + + Protocols 0-2 use GLOBAL opcode -> LIKELY_OVERTLY_MALICIOUS. + Protocols 3-5 use STACK_GLOBAL with safe builtins allowlist -> SUSPICIOUS. + """ + data = make_pickle(range(10), protocol) + pickled = Pickled.load(data) + result = check_safety(pickled) + if protocol <= 2: + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + f"Expected LIKELY_OVERTLY_MALICIOUS at protocol {protocol}" + ) + else: + assert result.severity >= Severity.SUSPICIOUS, f"Expected SUSPICIOUS at protocol {protocol}" + + +@pytest.mark.parametrize("protocol", ALL_PROTOCOLS) +def test_builtins_slice_is_flagged(protocol: int) -> None: + """Document: slice() is flagged because it imports from builtins. + + Protocols 0-2 use GLOBAL opcode -> LIKELY_UNSAFE. + Protocols 3-5 use STACK_GLOBAL with safe builtins allowlist -> SUSPICIOUS. 
+ """ + data = make_pickle(slice(1, 10), protocol) + pickled = Pickled.load(data) + result = check_safety(pickled) + if protocol <= 2: + assert result.severity >= Severity.LIKELY_UNSAFE, ( + f"Expected LIKELY_UNSAFE at protocol {protocol}" + ) + else: + assert result.severity >= Severity.SUSPICIOUS, f"Expected SUSPICIOUS at protocol {protocol}" + + +@pytest.mark.parametrize("protocol", [4, 5]) +def test_set_at_high_protocols_is_safe(protocol: int) -> None: + """Sets at protocols 4-5 use EMPTY_SET/ADDITEMS and are safe.""" + data = make_pickle({1, 2, 3}, protocol) + assert_likely_safe(data) + + +@pytest.mark.parametrize("protocol", [4, 5]) +def test_frozenset_is_likely_safe_at_high_protocols(protocol: int) -> None: + """Document: frozenset() is LIKELY_SAFE at high protocols. + + Unlike lower protocols, protocols 4-5 serialize frozensets using the + FROZENSET opcode, which doesn't trigger builtins import detection. + """ + data = make_pickle(frozenset([1, 2, 3]), protocol) + pickled = Pickled.load(data) + result = check_safety(pickled) + # At high protocols, frozensets use special opcodes and are safe + assert result.severity == Severity.LIKELY_SAFE, ( + f"Frozenset should be LIKELY_SAFE at protocol {protocol}, got {result.severity.name}" + ) diff --git a/test/test_cve_patterns.py b/test/test_cve_patterns.py new file mode 100644 index 0000000..57e8cb4 --- /dev/null +++ b/test/test_cve_patterns.py @@ -0,0 +1,296 @@ +"""CVE-based attack pattern tests for fickling. + +These tests verify that fickling detects attack patterns identified in CVEs +for pickle scanning tools. The patterns are based on vulnerabilities found +in picklescan. 
+ +CVE-2025-10157: Submodule Import Bypass +CVE-2025-10156: ZIP CRC Bypass +CVE-2025-10155: File Extension Bypass +""" + +from __future__ import annotations + +import io +import zipfile +from pathlib import Path + +import pytest + +import fickling.fickle as op +from fickling.analysis import Severity, check_safety +from fickling.fickle import Pickled +from test._helpers import make_malicious_pickle + +# ============================================================================= +# CVE-2025-10157: Submodule Import Bypass +# +# Attack: Using submodule paths like asyncio.unix_events bypasses scanners +# that only do exact string matching on module names. +# +# Fickling should detect parent modules when a submodule is imported. +# ============================================================================= + + +def test_submodule_bypass_os_path() -> None: + """os.path submodule should trigger os detection via hierarchical matching. + + Uses raw opcodes because make_malicious_pickle resolves os.path to + posixpath at pickle time, losing the dotted module path. 
+ """ + pickled = Pickled( + [ + op.Proto.create(4), + op.Frame(50), + op.ShortBinUnicode("os.path"), + op.Memoize(), + op.ShortBinUnicode("join"), + op.Memoize(), + op.StackGlobal(), + op.Memoize(), + op.Stop(), + ] + ) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect os.path submodule" + ) + + +def test_submodule_bypass_subprocess_internal() -> None: + """Deep submodule in subprocess should trigger subprocess detection.""" + # Construct a pickle that references subprocess submodule + pickled = Pickled( + [ + op.Proto.create(4), + op.Frame(50), + op.ShortBinUnicode("subprocess"), + op.Memoize(), + op.ShortBinUnicode("Popen"), + op.Memoize(), + op.StackGlobal(), + op.Memoize(), + op.Stop(), + ] + ) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, "Failed to detect subprocess" + + +def test_submodule_bypass_multiprocessing_util() -> None: + """multiprocessing.util should trigger multiprocessing detection.""" + data = make_malicious_pickle("multiprocessing.util", "spawnv_passfds", (), protocol=4) + pickled = Pickled.load(data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect multiprocessing.util submodule" + ) + + +def test_submodule_bypass_ctypes_util() -> None: + """ctypes.util should trigger ctypes detection.""" + data = make_malicious_pickle("ctypes.util", "find_library", ("c",), protocol=4) + pickled = Pickled.load(data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect ctypes.util submodule" + ) + + +def test_submodule_bypass_importlib_util() -> None: + """importlib.util should trigger importlib detection.""" + data = make_malicious_pickle("importlib.util", "find_spec", ("os",), protocol=4) + pickled = Pickled.load(data) + result = check_safety(pickled) + assert result.severity >= 
Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect importlib.util submodule" + ) + + +def test_hierarchical_module_detection() -> None: + """Verify hierarchical module matching works correctly. + + When a pickle imports 'a.b.c', detection should trigger if any of: + - 'a' is in UNSAFE_IMPORTS/UNSAFE_MODULES + - 'a.b' is in UNSAFE_IMPORTS/UNSAFE_MODULES + - 'a.b.c' is in UNSAFE_IMPORTS/UNSAFE_MODULES + """ + # Test with multiprocessing.util which should match 'multiprocessing' + pickled = Pickled( + [ + op.Proto.create(4), + op.Frame(50), + op.ShortBinUnicode("multiprocessing.util"), + op.Memoize(), + op.ShortBinUnicode("spawnv_passfds"), + op.Memoize(), + op.StackGlobal(), + op.Memoize(), + op.Stop(), + ] + ) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Hierarchical module matching failed for multiprocessing.util" + ) + + +# ============================================================================= +# CVE-2025-10156: ZIP CRC Bypass +# +# Attack: Corrupt CRC in ZIP Central Directory causes Python's zipfile to +# fail validation, but PyTorch still loads the file. Scanner fails to scan. +# +# Fickling should gracefully handle corrupted ZIPs and still scan content. +# ============================================================================= + + +def create_corrupted_zip_with_pickle(malicious_pickle: bytes) -> bytes: + """Create a ZIP file with corrupted CRC containing a malicious pickle. + + The CRC in the Central Directory is corrupted, but the local file + header CRC remains valid. Some parsers fail on this, but the pickle + data is still extractable. 
+ """ + # Create a valid ZIP first + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w", zipfile.ZIP_STORED) as zf: + zf.writestr("data.pkl", malicious_pickle) + + zip_bytes = bytearray(buffer.getvalue()) + + # Find the Central Directory and corrupt the CRC there + # The Central Directory File Header starts with signature 0x02014b50 + cd_sig = b"\x50\x4b\x01\x02" + cd_offset = zip_bytes.find(cd_sig) + + assert cd_offset != -1, "Central Directory signature not found in ZIP" + if cd_offset != -1: + # CRC-32 is at offset 16 from the start of the central directory entry + crc_offset = cd_offset + 16 + if crc_offset + 4 <= len(zip_bytes): + # Corrupt the CRC by XORing with 0xFF + for i in range(4): + zip_bytes[crc_offset + i] ^= 0xFF + + return bytes(zip_bytes) + + +def test_corrupted_zip_still_scanned() -> None: + """Malicious pickle in corrupted ZIP should still be detected. + + This tests the CVE-2025-10156 pattern where a corrupted CRC in the + ZIP Central Directory might cause scanners to skip the file. 
+ """ + malicious_pickle = make_malicious_pickle("os", "system", ("id",)) + corrupted_zip = create_corrupted_zip_with_pickle(malicious_pickle) + + try: + with zipfile.ZipFile(io.BytesIO(corrupted_zip), "r") as zf: + pkl_data = zf.read("data.pkl") + except zipfile.BadZipFile: + pkl_offset = corrupted_zip.find(malicious_pickle) + assert pkl_offset != -1, "Pickle data not found in corrupted ZIP" + pkl_data = malicious_pickle + + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS + + +def test_valid_zip_with_malicious_pickle() -> None: + """Valid ZIP containing malicious pickle should be detected.""" + malicious_pickle = make_malicious_pickle("os", "system", ("id",)) + + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, "w", zipfile.ZIP_STORED) as zf: + zf.writestr("model/data.pkl", malicious_pickle) + + buffer.seek(0) + with zipfile.ZipFile(buffer, "r") as zf: + pkl_data = zf.read("model/data.pkl") + pickled = Pickled.load(pkl_data) + result = check_safety(pickled) + assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, ( + "Failed to detect malicious pickle in valid ZIP" + ) + + +# ============================================================================= +# CVE-2025-10155: File Extension Bypass +# +# Attack: Rename .pkl to .bin or .pt causes parser confusion. +# Detection should work regardless of file extension. +# +# Fickling should detect based on content, not file extension. 
# =============================================================================


# Each case: (file extension, module, function, call args) for the payload.
EXTENSION_CASES = [
    pytest.param(".bin", "os", "system", ("id",), id="bin"),
    pytest.param(".pt", "subprocess", "call", (["id"],), id="pt"),
    pytest.param(".pth", "builtins", "eval", ("1+1",), id="pth"),
    pytest.param("", "socket", "socket", (), id="no_extension"),
    pytest.param(".txt", "pty", "spawn", ("/bin/sh",), id="misleading_txt"),
]


@pytest.mark.parametrize("ext,module,func,args", EXTENSION_CASES)
def test_extension_agnostic_detection(
    tmp_path: Path, ext: str, module: str, func: str, args: tuple[object, ...]
) -> None:
    """Detection should work regardless of file extension.

    The payload is written to ``model<ext>`` under tmp_path and analysed from
    the bytes read back, so the verdict can only depend on file content.
    """
    # Building the payload imports the target module; some of them (pty) are
    # not available on every platform. Skip such a case instead of erroring.
    pytest.importorskip(module)

    malicious_pickle = make_malicious_pickle(module, func, args)
    file_path = tmp_path / f"model{ext}"
    file_path.write_bytes(malicious_pickle)

    pickled = Pickled.load(file_path.read_bytes())
    result = check_safety(pickled)
    assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, (
        f"Failed to detect malicious pickle with {ext!r} extension"
    )


# =============================================================================
# Additional CVE-related patterns
# =============================================================================


def test_nested_module_in_unsafe_namespace() -> None:
    """A bare reference to sys.modules should be detected.

    The opcode stream only resolves the global (STACK_GLOBAL) and never
    calls anything.
    """
    # sys.modules allows accessing any loaded module
    pickled = Pickled(
        [
            op.Proto.create(4),
            op.Frame(30),
            op.ShortBinUnicode("sys"),
            op.Memoize(),
            op.ShortBinUnicode("modules"),
            op.Memoize(),
            op.StackGlobal(),
            op.Memoize(),
            op.Stop(),
        ]
    )
    result = check_safety(pickled)
    assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, (
        "Failed to detect sys.modules access"
    )


def test_code_module_submodules() -> None:
    """A payload calling code.InteractiveConsole should be detected.

    Only the top-level ``code`` module is exercised here.
    """
    data = make_malicious_pickle("code", "InteractiveConsole", (), protocol=4)
    pickled = Pickled.load(data)
    result = check_safety(pickled)
    assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, (
        "Failed to detect code.InteractiveConsole"
    )


def test_runpy_submodules() -> None:
    """A payload calling runpy._run_code should be detected.

    Only the top-level ``runpy`` module is exercised here.
    """
    data = make_malicious_pickle("runpy", "_run_code", (), protocol=4)
    pickled = Pickled.load(data)
    result = check_safety(pickled)
    assert result.severity >= Severity.LIKELY_OVERTLY_MALICIOUS, "Failed to detect runpy._run_code"