Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 76 additions & 28 deletions core/src/bin/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::env;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::process;

Expand All @@ -8,21 +9,10 @@ fn main() {
let repo_root = find_repo_root_from(&env::current_dir().unwrap())
.unwrap_or_else(|| env::current_dir().unwrap());

let ph_event = match polyhook::read() {
Ok(e) => e,
let response = match run_app(std::io::stdin(), &repo_root) {
Ok(r) => r,
Err(e) => {
eprintln!("steplock: failed to read hook input: {e}");
process::exit(2);
}
};

let event = polyhook_to_hook_event(ph_event);

let response = match run(&event, &repo_root) {
Ok(HookResponse::Approve) => polyhook::HookResponse::approve(),
Ok(HookResponse::Block { message }) => polyhook::HookResponse::block(&message),
Err(e) => {
eprintln!("steplock: error: {e}");
eprintln!("{e}");
process::exit(2);
}
};
Expand All @@ -33,6 +23,26 @@ fn main() {
}
}

/// Parse the hook event from `reader`, run the gate, and return the polyhook response.
/// Returns `Err(message)` when input is unreadable or the gate engine fails.
fn run_app(mut reader: impl Read, repo_root: &Path) -> Result<polyhook::HookResponse, String> {
let mut bytes = Vec::new();
reader
.read_to_end(&mut bytes)
.map_err(|e| format!("steplock: failed to read hook input: {e}"))?;

let ph_event = polyhook::parse::parse_event(&bytes)
.map_err(|e| format!("steplock: failed to read hook input: {e}"))?;

let event = polyhook_to_hook_event(ph_event);

match run(&event, repo_root) {
Ok(HookResponse::Approve) => Ok(polyhook::HookResponse::approve()),
Ok(HookResponse::Block { message }) => Ok(polyhook::HookResponse::block(&message)),
Err(e) => Err(format!("steplock: error: {e}")),
}
}

fn polyhook_to_hook_event(e: polyhook::HookEvent) -> HookEvent {
HookEvent {
event: e.event.to_string(),
Expand Down Expand Up @@ -66,7 +76,7 @@ mod tests {
use std::fs;
use tempfile::TempDir;

fn make_claude_stdin(cmd: &str, session: &str) -> String {
fn claude_stdin(cmd: &str, session: &str) -> String {
serde_json::json!({
"hook_event_name": "PreToolUse",
"tool_name": "Bash",
Expand Down Expand Up @@ -102,7 +112,7 @@ reset = "session"

#[test]
fn polyhook_event_maps_correctly() {
let stdin = make_claude_stdin("git push origin main", "s1");
let stdin = claude_stdin("git push origin main", "s1");
let ph_event = polyhook::parse::parse_event(stdin.as_bytes()).unwrap();
let event = polyhook_to_hook_event(ph_event);
assert_eq!(event.event, "tool:before");
Expand All @@ -116,25 +126,56 @@ reset = "session"
}

#[test]
fn approves_non_matching_command() {
fn run_app_approves_non_matching_command() {
let tmp = TempDir::new().unwrap();
setup_checklist(tmp.path());
let stdin = make_claude_stdin("ls -la", "s1");
let ph_event = polyhook::parse::parse_event(stdin.as_bytes()).unwrap();
let event = polyhook_to_hook_event(ph_event);
let resp = run(&event, tmp.path()).unwrap();
assert!(matches!(resp, HookResponse::Approve));
let stdin = claude_stdin("ls -la", "s1");
let resp = run_app(stdin.as_bytes(), tmp.path()).unwrap();
assert!(matches!(resp, polyhook::HookResponse::ApproveResponse(_)));
}

#[test]
fn blocks_matching_command() {
fn run_app_blocks_matching_command() {
let tmp = TempDir::new().unwrap();
setup_checklist(tmp.path());
let stdin = make_claude_stdin("git push origin main", "s1");
let ph_event = polyhook::parse::parse_event(stdin.as_bytes()).unwrap();
let event = polyhook_to_hook_event(ph_event);
let resp = run(&event, tmp.path()).unwrap();
assert!(matches!(resp, HookResponse::Block { .. }));
let stdin = claude_stdin("git push origin main", "s1");
let resp = run_app(stdin.as_bytes(), tmp.path()).unwrap();
assert!(matches!(resp, polyhook::HookResponse::BlockResponse(_)));
}

#[test]
fn run_app_error_on_invalid_input() {
let tmp = TempDir::new().unwrap();
let err = run_app(b"not valid json".as_ref(), tmp.path());
assert!(err.is_err());
assert!(err
.unwrap_err()
.contains("steplock: failed to read hook input"));
}

#[test]
fn run_app_error_on_invalid_cel_expression() {
let tmp = TempDir::new().unwrap();
let cl_dir = tmp.path().join(".steplock/checklists/bad-gate");
fs::create_dir_all(&cl_dir).unwrap();
fs::write(
cl_dir.join("config.toml"),
r#"on_event = "tool:before"
on_tool = "bash"
match_input = "!!!invalid cel!!!"
reset = "session"
"#,
)
.unwrap();
fs::write(
cl_dir.join("flow.mmd"),
"stateDiagram-v2\n [*] --> check\n check --> [*]\n check: Check\n",
)
.unwrap();
let stdin = claude_stdin("anything", "s1");
let err = run_app(stdin.as_bytes(), tmp.path());
assert!(err.is_err());
assert!(err.unwrap_err().contains("steplock: error:"));
}

#[test]
Expand All @@ -154,4 +195,11 @@ reset = "session"
let root = find_repo_root_from(&subdir).unwrap();
assert_eq!(root, tmp.path());
}

#[test]
fn find_repo_root_returns_none_when_not_found() {
let tmp = TempDir::new().unwrap();
let result = find_repo_root_from(tmp.path());
assert!(result.is_none());
}
}
2 changes: 1 addition & 1 deletion core/src/cel_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ mod tests {
#[test]
fn json_float_input_converts() {
let mut input = HashMap::new();
input.insert("val".to_owned(), json!(3.14f64));
input.insert("val".to_owned(), json!(1.23f64));
let ev = HookEvent {
event: "tool:before".to_owned(),
tool: "bash".to_owned(),
Expand Down
Loading