diff --git a/fickling/analysis.py b/fickling/analysis.py index d24afc8..9e213e2 100644 --- a/fickling/analysis.py +++ b/fickling/analysis.py @@ -78,19 +78,25 @@ def analyze(self, analysis: Analysis) -> list[AnalysisResult]: def results(self) -> AnalysisResults: return AnalysisResults(pickled=self.pickled, results=self.previous_results) - def shorten_code(self, ast_node) -> tuple[str, bool]: + def shorten_code(self, ast_node) -> str: + """Return a short, human-readable form of an AST node for use in + analysis messages. Pure formatter — does not touch dedup state. + """ code = unparse(ast_node).strip() if len(code) > 32: cutoff = code.find("(") - if code[cutoff] == "(": - shortened_code = f"{code[: code.find('(')].strip()}(...)" - else: - shortened_code = code - else: - shortened_code = code - was_already_reported = shortened_code in self.reported_shortened_code - self.reported_shortened_code.add(shortened_code) - return shortened_code, was_already_reported + if cutoff >= 0: + return f"{code[:cutoff].strip()}(...)" + return code + + def mark_reported(self, shortened: str) -> bool: + """Mark a shortened code fragment as reported. Returns True if + this was the first mark, False if a prior call already marked it. + """ + if shortened in self.reported_shortened_code: + return False + self.reported_shortened_code.add(shortened) + return True class Analyzer(metaclass=AnalyzerMeta): @@ -164,8 +170,10 @@ def __str__(self): class Analysis(ABC): ALL: list[Analysis] = [] - def __init_subclass__(cls, **kwargs): - Analysis.ALL.append(cls()) + def __init_subclass__(cls, *, register: bool = True, **kwargs): + super().__init_subclass__(**kwargs) + if register: + Analysis.ALL.append(cls()) @abstractmethod def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: @@ -255,8 +263,8 @@ def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: class NonStandardImports(Analysis): def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: for node in context.pickled.non_standard_imports(): - shortened, already_reported = context.shorten_code(node) - if not already_reported: + shortened = context.shorten_code(node) + if context.mark_reported(shortened): yield AnalysisResult( Severity.LIKELY_UNSAFE, f"`{shortened}` imports a Python module that is not a part of " @@ -324,7 +332,7 @@ class UnsafeImportsML(Analysis): def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: for node in context.pickled.properties.imports: - shortened, _ = context.shorten_code(node) + shortened = context.shorten_code(node) all_modules = [ node.module.rsplit(".", i)[0] for i in range(0, node.module.count(".") + 1) ] @@ -377,7 +385,7 @@ class BadCalls(Analysis): def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: for node in context.pickled.properties.calls: - shortened, _already_reported = context.shorten_code(node) + shortened = context.shorten_code(node) if any(shortened.startswith(f"{c}(") for c in self.BAD_CALLS): yield AnalysisResult( Severity.OVERTLY_MALICIOUS, @@ -397,7 +405,7 @@ def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: # if the call is to a constructor of an object imported from the Python # standard library, it's probably okay continue - shortened, already_reported = context.shorten_code(node) + shortened = context.shorten_code(node) if ( shortened.startswith("eval(") or shortened.startswith("exec(") @@ -413,7 +421,7 @@ def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: "OvertlyBadEval", trigger=shortened, ) - elif not already_reported: + elif context.mark_reported(shortened): yield AnalysisResult( Severity.LIKELY_UNSAFE, f"Call to `{shortened}` can execute arbitrary code and is inherently unsafe", @@ -429,7 +437,7 @@ def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: n.name in SAFE_BUILTINS for n in node.names ): continue - shortened, _ = context.shorten_code(node) + shortened = context.shorten_code(node) yield AnalysisResult( Severity.LIKELY_OVERTLY_MALICIOUS, f"`{shortened}` is suspicious and indicative of an overtly malicious pickle file", @@ -447,7 +455,7 @@ def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: # Malformed pickle or resource exhaustion - dedicated analyses will report this return for varname, asmt in unused.items(): - shortened, _ = context.shorten_code(asmt.value) + shortened = context.shorten_code(asmt.value) yield AnalysisResult( Severity.SUSPICIOUS, f"Variable `{varname}` is assigned value `{shortened}` but unused afterward; " diff --git a/fickling/ml.py b/fickling/ml.py index 12b415e..fda18b8 100644 --- a/fickling/ml.py +++ b/fickling/ml.py @@ -288,16 +288,14 @@ } -class MLAllowlist(Analysis): +class MLAllowlist(Analysis, register=False): def __init__(self): super().__init__() self.allowlist = ML_ALLOWLIST def analyze(self, context: AnalysisContext) -> Iterator[AnalysisResult]: for node in context.pickled.properties.imports: - shortened, already_reported = context.shorten_code(node) - if already_reported: - continue + shortened = context.shorten_code(node) if isinstance(node, ast.ImportFrom): # from module import x diff --git a/test/test_analysis.py b/test/test_analysis.py new file mode 100644 index 0000000..3b5ab2c --- /dev/null +++ b/test/test_analysis.py @@ -0,0 +1,19 @@ +from unittest import TestCase + +import fickling.fickle as op +from fickling.analysis import Severity, check_safety +from fickling.fickle import Pickled + + +class TestAnalysis(TestCase): + def test_benign_pickle(self): + pickled = Pickled( + [ + op.Proto.create(4), + op.ShortBinUnicode("collections"), + op.ShortBinUnicode("deque"), + op.StackGlobal(), + op.Stop(), + ] + ) + self.assertEqual(check_safety(pickled).severity, Severity.LIKELY_SAFE) diff --git a/test/test_bypasses.py b/test/test_bypasses.py index 745e793..e24e2d4 100644 --- a/test/test_bypasses.py +++ b/test/test_bypasses.py @@ -2,8 +2,9 @@ from unittest import TestCase import fickling.fickle as op -from fickling.analysis import Severity, UnsafeImportsML, check_safety +from fickling.analysis import Analyzer, Severity, UnsafeImportsML, check_safety from fickling.fickle import Pickled +from fickling.ml import MLAllowlist class TestBypasses(TestCase): @@ -683,6 +684,29 @@ def test_missing_osx_support(self): "from _osx_support import _find_build_tool", ) + # https://github.com/trailofbits/fickling/security/advisories/GHSA-cffv-grgg-g429 + def test_ml_allowlist_not_shadowed_by_unsafe_imports_ml(self): + """MLAllowlist must flag imports outside ML_ALLOWLIST even when another + analysis (UnsafeImportsML) has already iterated the same import. + """ + pickled = Pickled( + [ + op.Proto.create(4), + op.ShortBinUnicode("ast"), + op.ShortBinUnicode("parse"), + op.StackGlobal(), + op.Stop(), + ] + ) + analyzer = Analyzer([UnsafeImportsML(), MLAllowlist()]) + res = check_safety(pickled, analyzer=analyzer) + self.assertGreater(res.severity, Severity.LIKELY_SAFE) + detailed = res.detailed_results().get("AnalysisResult", {}) + self.assertIsNotNone( + detailed.get("MLAllowlist"), + "MLAllowlist did not produce a finding for `from ast import parse`", + ) + class TestUnsafeModuleCoverage(TestCase): """Verify every entry in UNSAFE_MODULES and UNSAFE_IMPORTS triggers detection."""