diff --git a/src/adapter/mcp.rs b/src/adapter/mcp.rs index 87fbd4a..adc4638 100644 --- a/src/adapter/mcp.rs +++ b/src/adapter/mcp.rs @@ -22,54 +22,7 @@ impl super::Adapter for McpAdapter { } fn detect(&self, root: &Path) -> bool { - // Check package.json for MCP SDK - let pkg_json = root.join("package.json"); - if pkg_json.exists() { - if let Ok(content) = std::fs::read_to_string(&pkg_json) { - if content.contains("@modelcontextprotocol/sdk") || content.contains("mcp-server") { - return true; - } - } - } - - // Check pyproject.toml for mcp dependency - let pyproject = root.join("pyproject.toml"); - if pyproject.exists() { - if let Ok(content) = std::fs::read_to_string(&pyproject) { - if content.contains("mcp") { - return true; - } - } - } - - // Check for Python files importing mcp - if let Ok(entries) = std::fs::read_dir(root) { - for entry in entries.flatten() { - let path = entry.path(); - if path.extension().is_some_and(|e| e == "py") { - if let Ok(content) = std::fs::read_to_string(&path) { - if content.contains("from mcp") - || content.contains("import mcp") - || content.contains("@server.tool") - { - return true; - } - } - } - } - } - - // Check requirements.txt - let requirements = root.join("requirements.txt"); - if requirements.exists() { - if let Ok(content) = std::fs::read_to_string(&requirements) { - if content.lines().any(|l| l.trim().starts_with("mcp")) { - return true; - } - } - } - - false + super::mcp_metadata::metadata_root_for_scan(root).is_some() } fn load(&self, root: &Path, ignore_tests: bool) -> Result> { @@ -78,6 +31,8 @@ impl super::Adapter for McpAdapter { } fn load_with_filter(&self, root: &Path, filter: &ScanPathFilter) -> Result> { + let metadata_root = + super::mcp_metadata::metadata_root_for_scan(root).unwrap_or_else(|| root.to_path_buf()); let name = root .file_name() .map(|n| n.to_string_lossy().to_string()) @@ -137,18 +92,24 @@ impl super::Adapter for McpAdapter { } } - // Parse dependencies (metadata files honor the path filter) - let dependencies = parse_dependencies(root, filter); - - // Parse provenance from package.json or pyproject.toml (filtered) - let provenance = parse_provenance(root, filter); + let (dependencies, provenance) = if super::mcp_metadata::same_path(root, &metadata_root) { + ( + parse_dependencies(root, filter), + parse_provenance(root, filter), + ) + } else { + ( + parse_dependencies(&metadata_root, filter), + parse_provenance(&metadata_root, filter), + ) + }; let data = build_data_surface(&tools, &execution); Ok(vec![ScanTarget { name, framework: Framework::Mcp, - root_path: root.to_path_buf(), + root_path: metadata_root, tools, execution, data, diff --git a/src/adapter/mcp_metadata.rs b/src/adapter/mcp_metadata.rs new file mode 100644 index 0000000..ab58067 --- /dev/null +++ b/src/adapter/mcp_metadata.rs @@ -0,0 +1,122 @@ +use std::path::{Path, PathBuf}; + +pub(super) fn metadata_root_for_scan(scan_root: &Path) -> Option { + if has_mcp_metadata(scan_root) { + return Some(scan_root.to_path_buf()); + } + + if let Some(metadata_root) = ancestor_metadata_root(scan_root) { + if contains_mcp_tool_source(scan_root) { + return Some(metadata_root); + } + } + + contains_mcp_sdk_source(scan_root).then(|| scan_root.to_path_buf()) +} + +pub(super) fn same_path(left: &Path, right: &Path) -> bool { + let normalized_left = left.canonicalize().unwrap_or_else(|_| left.to_path_buf()); + let normalized_right = right.canonicalize().unwrap_or_else(|_| right.to_path_buf()); + normalized_left == normalized_right +} + +fn ancestor_metadata_root(scan_root: &Path) -> Option { + scan_root + .ancestors() + .skip(1) + .find(|ancestor| has_mcp_metadata(ancestor)) + .map(Path::to_path_buf) +} + +fn has_mcp_metadata(root: &Path) -> bool { + package_json_declares_mcp(root) + || pyproject_declares_mcp(root) + || requirements_declare_mcp(root) + || root.join("mcp.json").exists() + || root.join("mcp-config.json").exists() +} + +fn package_json_declares_mcp(root: &Path) -> bool { + let path = root.join("package.json"); + std::fs::read_to_string(path).is_ok_and(|content| { + content.contains("@modelcontextprotocol/sdk") || content.contains("mcp-server") + }) +} + +fn pyproject_declares_mcp(root: &Path) -> bool { + std::fs::read_to_string(root.join("pyproject.toml")) + .is_ok_and(|content| content.contains("mcp")) +} + +fn requirements_declare_mcp(root: &Path) -> bool { + std::fs::read_to_string(root.join("requirements.txt")).is_ok_and(|content| { + content + .lines() + .map(str::trim) + .any(|line| line.starts_with("mcp")) + }) +} + +fn contains_mcp_tool_source(root: &Path) -> bool { + contains_mcp_source(root, SourceDetectionMode::ToolSurface) +} + +fn contains_mcp_sdk_source(root: &Path) -> bool { + contains_mcp_source(root, SourceDetectionMode::SdkUsage) +} + +#[derive(Debug, Clone, Copy)] +enum SourceDetectionMode { + SdkUsage, + ToolSurface, +} + +fn contains_mcp_source(root: &Path, mode: SourceDetectionMode) -> bool { + let walker = ignore::WalkBuilder::new(root) + .hidden(true) + .git_ignore(true) + .max_depth(Some(5)) + .build(); + + for entry in walker.flatten() { + let path = entry.path(); + if !path.is_file() { + continue; + } + if !is_mcp_source_candidate(path) { + continue; + } + if std::fs::read_to_string(path) + .is_ok_and(|content| source_mentions_mcp(path, &content, mode)) + { + return true; + } + } + + false +} + +fn is_mcp_source_candidate(path: &Path) -> bool { + matches!( + path.extension().and_then(|extension| extension.to_str()), + Some("py" | "ts" | "tsx" | "js" | "jsx" | "mjs" | "cjs") + ) +} + +fn source_mentions_mcp(path: &Path, content: &str, mode: SourceDetectionMode) -> bool { + match path.extension().and_then(|extension| extension.to_str()) { + Some("py") => { + content.contains("from mcp") + || content.contains("import mcp") + || matches!(mode, SourceDetectionMode::ToolSurface) + && content.contains("@server.tool") + } + Some("ts" | "tsx" | "js" | "jsx" | "mjs" | "cjs") => { + content.contains("@modelcontextprotocol/sdk") + || content.contains("McpServer") + || matches!(mode, SourceDetectionMode::ToolSurface) + && (content.contains(".registerTool(") || content.contains(".tool(")) + } + Some(_) | None => false, + } +} diff --git a/src/adapter/mod.rs b/src/adapter/mod.rs index 5446e26..6e0b8a5 100644 --- a/src/adapter/mod.rs +++ b/src/adapter/mod.rs @@ -4,6 +4,7 @@ pub mod gpt_actions; pub mod hermes; pub mod langchain; pub mod mcp; +pub(super) mod mcp_metadata; pub mod openclaw; use std::path::Path; diff --git a/src/ux.rs b/src/ux.rs index 35bdf2f..a462e88 100644 --- a/src/ux.rs +++ b/src/ux.rs @@ -1,3 +1,6 @@ +mod hotspots; +mod roots; + use std::collections::{BTreeMap, BTreeSet}; use std::path::Path; @@ -134,6 +137,7 @@ pub fn render_explain(report: &ScanReport, options: &ExplainOptions) -> String { "- Adapters: {}\n", display_list(&coverage.frameworks, "none") )); + output.push_str(&roots::render(report)); output.push_str(&format!("- Targets: {}\n", coverage.targets)); output.push_str(&format!( "- Source files parsed: {} ({})\n", @@ -173,6 +177,8 @@ pub fn render_explain(report: &ScanReport, options: &ExplainOptions) -> String { severity_counts(&report.findings) )); + output.push_str(&hotspots::render(report)); + output.push_str("Next actions:\n"); for action in next_actions(report) { output.push_str(&format!("- {action}\n")); diff --git a/src/ux/hotspots.rs b/src/ux/hotspots.rs new file mode 100644 index 0000000..fcf53e8 --- /dev/null +++ b/src/ux/hotspots.rs @@ -0,0 +1,252 @@ +use std::collections::{BTreeMap, BTreeSet}; +use std::path::{Path, PathBuf}; + +use crate::rules::{AttackCategory, Finding, Severity}; +use crate::ScanReport; + +const MAX_ITEMS: usize = 3; + +pub(super) fn render(report: &ScanReport) -> String { + let blocking_findings: Vec<&Finding> = report + .findings + .iter() + .filter(|finding| finding.severity >= report.verdict.fail_threshold) + .collect(); + + let mut output = String::new(); + output.push_str("Hotspots:\n"); + + if blocking_findings.is_empty() { + output.push_str("- Blocking findings: none\n\n"); + return output; + } + + let runtime_findings: Vec<&Finding> = blocking_findings + .iter() + .copied() + .filter(|finding| finding.attack_category != AttackCategory::SupplyChain) + .collect(); + let supply_chain_findings: Vec<&Finding> = blocking_findings + .iter() + .copied() + .filter(|finding| finding.attack_category == AttackCategory::SupplyChain) + .collect(); + + output.push_str(&format!( + "- Runtime-risk concentration: {}\n", + concentration_summary( + &runtime_findings, + &report.scan_root, + GroupingMode::RuntimeDirectory, + ) + )); + output.push_str(&format!( + "- Supply-chain concentration: {}\n", + concentration_summary( + &supply_chain_findings, + &report.scan_root, + GroupingMode::SupplyChainFile, + ) + )); + output.push_str(&format!( + "- Rule concentration: {}\n", + rule_summary(&blocking_findings) + )); + + if mostly_outside_tool_sources(report, &blocking_findings) { + output.push_str( + "- Most blocking findings are outside discovered tool source files; consider `[scan] include/exclude` or a baseline.\n", + ); + } + + output.push('\n'); + output +} + +#[derive(Debug, Clone, Copy)] +enum GroupingMode { + RuntimeDirectory, + SupplyChainFile, +} + +#[derive(Debug)] +struct GroupCount { + label: String, + total: usize, + severity_counts: BTreeMap, +} + +fn concentration_summary( + findings: &[&Finding], + scan_root: &Path, + grouping_mode: GroupingMode, +) -> String { + if findings.is_empty() { + return "none".into(); + } + + let mut groups: BTreeMap = BTreeMap::new(); + for finding in findings { + let label = group_label(finding, scan_root, grouping_mode); + let entry = groups.entry(label.clone()).or_insert_with(|| GroupCount { + label, + total: 0, + severity_counts: BTreeMap::new(), + }); + entry.total += 1; + *entry.severity_counts.entry(finding.severity).or_default() += 1; + } + + let mut ranked: Vec = groups.into_values().collect(); + ranked.sort_by(|left, right| { + right + .total + .cmp(&left.total) + .then_with(|| left.label.cmp(&right.label)) + }); + + ranked + .into_iter() + .take(MAX_ITEMS) + .map(|group| { + format!( + "{} ({})", + group.label, + severity_summary(&group.severity_counts) + ) + }) + .collect::>() + .join(", ") +} + +fn rule_summary(findings: &[&Finding]) -> String { + let mut counts: BTreeMap<&str, usize> = BTreeMap::new(); + for finding in findings { + *counts.entry(&finding.rule_id).or_default() += 1; + } + + let mut ranked: Vec<(&str, usize)> = counts.into_iter().collect(); + ranked.sort_by(|(left_rule, left_count), (right_rule, right_count)| { + right_count + .cmp(left_count) + .then_with(|| left_rule.cmp(right_rule)) + }); + + ranked + .into_iter() + .take(MAX_ITEMS) + .map(|(rule_id, count)| format!("{rule_id} ({count})")) + .collect::>() + .join(", ") +} + +fn group_label(finding: &Finding, scan_root: &Path, grouping_mode: GroupingMode) -> String { + let Some(location) = &finding.location else { + return "unknown location".into(); + }; + let relative = relative_path(scan_root, &location.file); + + match grouping_mode { + GroupingMode::RuntimeDirectory => directory_label(&relative), + GroupingMode::SupplyChainFile => relative, + } +} + +fn directory_label(relative_path: &str) -> String { + let path = Path::new(relative_path); + if let Some(first_component) = path.components().find_map(|component| match component { + std::path::Component::Normal(part) => Some(part.to_string_lossy().into_owned()), + std::path::Component::CurDir + | std::path::Component::ParentDir + | std::path::Component::RootDir + | std::path::Component::Prefix(_) => None, + }) { + if path + .components() + .filter(|component| matches!(component, std::path::Component::Normal(_))) + .count() + > 1 + { + return format!("{first_component}/"); + } + } + + let Some(parent) = path.parent() else { + return relative_path.into(); + }; + if parent.as_os_str().is_empty() { + relative_path.into() + } else { + format!("{}/", parent.to_string_lossy().replace('\\', "/")) + } +} + +fn severity_summary(counts: &BTreeMap) -> String { + [ + Severity::Critical, + Severity::High, + Severity::Medium, + Severity::Low, + Severity::Info, + ] + .into_iter() + .filter_map(|severity| { + counts + .get(&severity) + .map(|count| format!("{count} {severity}")) + }) + .collect::>() + .join(", ") +} + +fn mostly_outside_tool_sources(report: &ScanReport, blocking_findings: &[&Finding]) -> bool { + let tool_files = tool_source_files(report); + if tool_files.is_empty() { + return false; + } + + let outside_count = blocking_findings + .iter() + .filter(|finding| { + finding + .location + .as_ref() + .is_none_or(|location| !tool_files.contains(&normalized_path(&location.file))) + }) + .count(); + + outside_count > blocking_findings.len().saturating_sub(outside_count) +} + +fn tool_source_files(report: &ScanReport) -> BTreeSet { + report + .targets + .iter() + .flat_map(|target| target.tools.iter()) + .filter_map(|tool| tool.defined_at.as_ref()) + .map(|location| normalized_path(&location.file)) + .collect() +} + +fn relative_path(root: &Path, path: &Path) -> String { + let normalized_root = normalized_path(root); + let normalized_path = normalized_path(path); + let relative = normalized_path + .strip_prefix(&normalized_root) + .unwrap_or(&normalized_path); + + relative + .components() + .filter_map(|component| match component { + std::path::Component::Normal(part) => Some(part.to_string_lossy().into_owned()), + std::path::Component::CurDir => None, + std::path::Component::ParentDir => Some("..".to_string()), + std::path::Component::RootDir | std::path::Component::Prefix(_) => None, + }) + .collect::>() + .join("/") +} + +fn normalized_path(path: &Path) -> PathBuf { + path.canonicalize().unwrap_or_else(|_| path.to_path_buf()) +} diff --git a/src/ux/roots.rs b/src/ux/roots.rs new file mode 100644 index 0000000..2ce11bc --- /dev/null +++ b/src/ux/roots.rs @@ -0,0 +1,30 @@ +use std::collections::BTreeSet; +use std::path::{Path, PathBuf}; + +use crate::ScanReport; + +pub(super) fn render(report: &ScanReport) -> String { + let mut output = format!("- Scan root: {}\n", report.scan_root.display()); + let roots = metadata_roots(report); + if !roots.is_empty() { + output.push_str(&format!("- Metadata root: {}\n", roots.join(", "))); + } + output +} + +fn metadata_roots(report: &ScanReport) -> Vec { + let scan_root = normalized_path(&report.scan_root); + let roots: BTreeSet = report + .targets + .iter() + .map(|target| normalized_path(&target.root_path)) + .filter(|root| *root != scan_root) + .map(|root| root.display().to_string()) + .collect(); + + roots.into_iter().collect() +} + +fn normalized_path(path: &Path) -> PathBuf { + path.canonicalize().unwrap_or_else(|_| path.to_path_buf()) +} diff --git a/tests/explain_hotspots.rs b/tests/explain_hotspots.rs new file mode 100644 index 0000000..01be20a --- /dev/null +++ b/tests/explain_hotspots.rs @@ -0,0 +1,143 @@ +use std::path::PathBuf; + +use agentshield::config::ScanPathFilterSummary; +use agentshield::ir::{Framework, Language, ScanTarget, SourceFile, SourceLocation, ToolSurface}; +use agentshield::rules::policy::PolicyVerdict; +use agentshield::rules::{AttackCategory, Confidence, Evidence, Finding, Severity}; +use agentshield::ux::{render_explain, ExplainOptions}; +use agentshield::ScanReport; + +#[test] +fn explain_shows_concentrated_blocking_hotspots() { + let report = report(vec![ + finding( + "SHIELD-001", + Severity::High, + AttackCategory::CommandInjection, + "scripts/setup.py", + ), + finding( + "SHIELD-001", + Severity::High, + AttackCategory::CommandInjection, + "scripts/raw-lake/import.py", + ), + finding( + "SHIELD-009", + Severity::High, + AttackCategory::SupplyChain, + "package.json", + ), + finding( + "SHIELD-003", + Severity::Medium, + AttackCategory::Ssrf, + "src/mcp/server.ts", + ), + ]); + + let output = render_explain( + &report, + &ExplainOptions { + ignore_tests: false, + }, + ); + + assert!(output.contains("Hotspots:")); + assert!(output.contains("- Runtime-risk concentration: scripts/ (2 high)")); + assert!(output.contains("- Supply-chain concentration: package.json (1 high)")); + assert!(output.contains("- Rule concentration: SHIELD-001 (2), SHIELD-009 (1)")); + assert!(output.contains("consider `[scan] include/exclude` or a baseline")); +} + +#[test] +fn explain_reports_no_hotspots_when_there_are_no_findings() { + let output = render_explain( + &report(Vec::new()), + &ExplainOptions { + ignore_tests: false, + }, + ); + + assert!(output.contains("Hotspots:\n- Blocking findings: none")); +} + +fn report(findings: Vec) -> ScanReport { + let pass = findings + .iter() + .all(|finding| finding.severity < Severity::High); + let highest_severity = findings.iter().map(|finding| finding.severity).max(); + + ScanReport { + target_name: "fixture".into(), + findings, + verdict: PolicyVerdict { + pass, + total_findings: 0, + effective_findings: 0, + highest_severity, + fail_threshold: Severity::High, + }, + scan_root: PathBuf::from("/repo"), + targets: vec![ScanTarget { + name: "fixture".into(), + framework: Framework::Mcp, + root_path: PathBuf::from("/repo"), + tools: vec![ToolSurface { + name: "server_tool".into(), + description: None, + input_schema: None, + output_schema: None, + declared_permissions: Vec::new(), + defined_at: Some(location("src/mcp/server.ts")), + }], + execution: Default::default(), + data: Default::default(), + dependencies: Default::default(), + provenance: Default::default(), + source_files: vec![SourceFile { + path: PathBuf::from("/repo/src/mcp/server.ts"), + language: Language::TypeScript, + content: String::new(), + size_bytes: 0, + content_hash: String::new(), + }], + }], + path_filter_summary: ScanPathFilterSummary::default(), + } +} + +fn finding( + rule_id: &str, + severity: Severity, + attack_category: AttackCategory, + relative_file: &str, +) -> Finding { + Finding { + rule_id: rule_id.into(), + rule_name: "Rule".into(), + severity, + confidence: Confidence::High, + attack_category, + message: "finding".into(), + location: Some(location(relative_file)), + evidence: vec![Evidence { + description: "evidence".into(), + location: None, + snippet: None, + }], + taint_path: None, + remediation: None, + cwe_id: None, + } +} + +fn location(relative_file: &str) -> SourceLocation { + SourceLocation { + file: PathBuf::from("/repo").join(relative_file), + line: 1, + column: 0, + end_line: None, + end_column: None, + } +} diff --git a/tests/mcp_subdirectory.rs b/tests/mcp_subdirectory.rs new file mode 100644 index 0000000..fb48afd --- /dev/null +++ b/tests/mcp_subdirectory.rs @@ -0,0 +1,224 @@ +use std::path::Path; + +use agentshield::ir::Framework; +use agentshield::{scan, ScanOptions}; +use tempfile::TempDir; + +const ROOT_PACKAGE_JSON: &str = r#"{"dependencies":{"@modelcontextprotocol/sdk":"1.0.0"}}"#; +const VULNERABLE_PACKAGE_JSON: &str = + r#"{"dependencies":{"@modelcontextprotocol/sdk":"1.0.0","event-stream":"3.3.6"}}"#; + +const MCP_SERVER_TS: &str = r#" +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; + +const server = new McpServer({ name: "demo" }); + +server.tool("echo", "Echo input", {}, async () => ({ content: [] })); +"#; + +const MCP_SERVER_PY: &str = r#" +from mcp import Server + +server = Server("demo") + +@server.tool("echo") +def echo(value: str) -> str: + return "ok" +"#; + +#[test] +fn subdirectory_scan_uses_ancestor_mcp_metadata_without_expanding_source_boundary() { + let fixture = Fixture::new(); + fixture.write("package.json", ROOT_PACKAGE_JSON); + fixture.write("src/mcp/server.ts", MCP_SERVER_TS); + fixture.write("src/outside.ts", MCP_SERVER_TS); + + let scan_root = fixture.path().join("src/mcp"); + let report = scan(&scan_root, &ScanOptions::default()).unwrap(); + + assert_eq!(report.targets.len(), 1); + let target = &report.targets[0]; + assert_eq!(target.framework, Framework::Mcp); + assert_eq!( + target.root_path.canonicalize().unwrap(), + fixture.canonical_root() + ); + assert!(target + .dependencies + .dependencies + .iter() + .any(|dep| dep.name == "@modelcontextprotocol/sdk")); + assert_eq!(source_paths(&report), vec!["server.ts"]); + + let rendered = agentshield::ux::render_explain( + &report, + &agentshield::ux::ExplainOptions { + ignore_tests: false, + }, + ); + + assert!(rendered.contains("- Scan root:")); + assert!(rendered.contains("- Metadata root:")); +} + +#[test] +fn subdirectory_scan_detects_typescript_mcp_source_without_root_package() { + let fixture = Fixture::new(); + fixture.write("src/mcp/server.ts", MCP_SERVER_TS); + + let scan_root = fixture.path().join("src/mcp"); + let report = scan(&scan_root, &ScanOptions::default()).unwrap(); + + assert_eq!(report.targets.len(), 1); + assert_eq!(report.targets[0].framework, Framework::Mcp); + assert_eq!( + report.targets[0].root_path.canonicalize().unwrap(), + scan_root.canonicalize().unwrap() + ); + assert_eq!(source_paths(&report), vec!["server.ts"]); +} + +#[test] +fn subdirectory_scan_honors_metadata_root_exclude_for_package_json() { + let fixture = Fixture::new(); + fixture.write("package.json", VULNERABLE_PACKAGE_JSON); + fixture.write( + ".agentshield.toml", + "[scan]\nexclude = [\"package.json\"]\n", + ); + fixture.write("src/mcp/server.ts", MCP_SERVER_TS); + + let report = scan( + &fixture.path().join("src/mcp"), + &ScanOptions { + config_path: Some(fixture.path().join(".agentshield.toml")), + ..ScanOptions::default() + }, + ) + .unwrap(); + + assert_eq!(source_paths(&report), vec!["server.ts"]); + assert_no_finding_from(&report, "package.json"); + assert!(report.targets[0].dependencies.dependencies.is_empty()); +} + +#[test] +fn subdirectory_scan_honors_metadata_root_exclude_for_requirements_txt() { + let fixture = Fixture::new(); + fixture.write("requirements.txt", "mcp==1.0.0\nevent-stream==3.3.6\n"); + fixture.write( + ".agentshield.toml", + "[scan]\nexclude = [\"requirements.txt\"]\n", + ); + fixture.write("src/mcp/server.py", MCP_SERVER_PY); + + let report = scan( + &fixture.path().join("src/mcp"), + &ScanOptions { + config_path: Some(fixture.path().join(".agentshield.toml")), + ..ScanOptions::default() + }, + ) + .unwrap(); + + assert_eq!(source_paths(&report), vec!["server.py"]); + assert_no_finding_from(&report, "requirements.txt"); + assert!(report.targets[0].dependencies.dependencies.is_empty()); +} + +#[test] +fn subdirectory_scan_honors_metadata_root_include_for_package_json() { + let fixture = Fixture::new(); + fixture.write("package.json", VULNERABLE_PACKAGE_JSON); + fixture.write(".agentshield.toml", "[scan]\ninclude = [\"server.ts\"]\n"); + fixture.write("src/mcp/server.ts", MCP_SERVER_TS); + + let report = scan( + &fixture.path().join("src/mcp"), + &ScanOptions { + config_path: Some(fixture.path().join(".agentshield.toml")), + ..ScanOptions::default() + }, + ) + .unwrap(); + + assert_eq!(source_paths(&report), vec!["server.ts"]); + assert_no_finding_from(&report, "package.json"); + assert!(report.targets[0].dependencies.dependencies.is_empty()); +} + +struct Fixture { + dir: TempDir, +} + +impl Fixture { + fn new() -> Self { + Self { + dir: TempDir::new().unwrap(), + } + } + + fn path(&self) -> &Path { + self.dir.path() + } + + fn canonical_root(&self) -> std::path::PathBuf { + self.dir.path().canonicalize().unwrap() + } + + fn write(&self, relative_path: &str, content: &str) { + let path = self.dir.path().join(relative_path); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).unwrap(); + } + std::fs::write(path, content).unwrap(); + } +} + +fn source_paths(report: &agentshield::ScanReport) -> Vec { + report + .targets + .iter() + .flat_map(|target| target.source_files.iter()) + .map(|source| relative_path(&report.scan_root, &source.path)) + .collect() +} + +fn assert_no_finding_from(report: &agentshield::ScanReport, file_name: &str) { + assert!( + !report.findings.iter().any(|finding| finding + .location + .as_ref() + .is_some_and(|location| location.file.ends_with(file_name))), + "expected no findings from {file_name}, got: {:?}", + report + .findings + .iter() + .map(|finding| ( + finding.rule_id.as_str(), + finding + .location + .as_ref() + .map(|location| location.file.clone()) + )) + .collect::>() + ); +} + +fn relative_path(root: &Path, path: &Path) -> String { + let canonical_root = root.canonicalize().unwrap_or_else(|_| root.to_path_buf()); + let canonical_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf()); + + canonical_path + .strip_prefix(&canonical_root) + .unwrap_or(path) + .components() + .filter_map(|component| match component { + std::path::Component::Normal(part) => Some(part.to_string_lossy().into_owned()), + std::path::Component::CurDir => None, + std::path::Component::ParentDir => Some("..".to_string()), + std::path::Component::RootDir | std::path::Component::Prefix(_) => None, + }) + .collect::>() + .join("/") +}