diff --git a/Cargo.lock b/Cargo.lock index bb69e64..51ed786 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -944,11 +944,12 @@ dependencies = [ [[package]] name = "ghpool" -version = "0.1.0" +version = "0.2.0" dependencies = [ "aws-config", "aws-sdk-secretsmanager", "axum", + "hyper 1.10.0", "moka", "reqwest", "serde", @@ -956,6 +957,7 @@ dependencies = [ "time", "tokio", "toml", + "tower", "tower-http", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index 0d8826b..f8cf122 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,7 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "local-time"] } time = { version = "0.3", features = ["formatting", "parsing", "macros"] } tower-http = { version = "0.6", features = ["trace", "cors"] } + +[dev-dependencies] +tower = { version = "0.5", features = ["util"] } +hyper = "1" diff --git a/ghp/src/main.rs b/ghp/src/main.rs index aa9b943..906d8a7 100644 --- a/ghp/src/main.rs +++ b/ghp/src/main.rs @@ -28,6 +28,8 @@ fn try_pooled(args: &[String], base: &str) -> Option { "issue" if args.get(1).map(|s| s.as_str()) == Some("view") => try_issue_view(args, base), "pr" if args.get(1).map(|s| s.as_str()) == Some("list") => try_pr_list(args, base), "pr" if args.get(1).map(|s| s.as_str()) == Some("view") => try_pr_view(args, base), + "pr" if args.get(1).map(|s| s.as_str()) == Some("diff") => try_pr_diff(args, base), + "pr" if args.get(1).map(|s| s.as_str()) == Some("checks") => try_pr_checks(args, base), "run" if args.get(1).map(|s| s.as_str()) == Some("list") => try_run_list(args, base), _ => None, } @@ -131,6 +133,48 @@ fn try_pr_view(args: &[String], base: &str) -> Option { Some(0) } +// gh pr diff -R owner/repo +fn try_pr_diff(args: &[String], base: &str) -> Option { + let repo = repo_flag(args)?; + let number = args.get(2).and_then(|s| s.parse::().ok())?; + + let url = format!("{}/raw/repos/{}/pulls/{}", base, repo, number); + let client = reqwest::blocking::Client::new(); + let resp = client.get(&url) + .header("Accept", "application/vnd.github.v3.diff") + .send().ok()?; + if !resp.status().is_success() { return None; } + print!("{}", resp.text().ok()?); + Some(0) +} + +// gh pr checks -R owner/repo +fn try_pr_checks(args: &[String], base: &str) -> Option { + let repo = repo_flag(args)?; + let number = args.get(2).and_then(|s| s.parse::().ok())?; + + // Get the PR head SHA + let pr_url = format!("{}/repos/{}/pulls/{}", base, repo, number); + let pr_body = http_get(&pr_url)?; + let pr: serde_json::Value = serde_json::from_str(&pr_body).ok()?; + let sha = pr["head"]["sha"].as_str()?; + + // Get check runs for that SHA + let url = format!("{}/repos/{}/commits/{}/check-runs", base, repo, sha); + let body = http_get(&url)?; + let v: serde_json::Value = serde_json::from_str(&body).ok()?; + let runs = v["check_runs"].as_array()?; + + for run in runs { + let name = run["name"].as_str().unwrap_or(""); + let status = run["status"].as_str().unwrap_or(""); + let conclusion = run["conclusion"].as_str().unwrap_or(""); + let display = if status == "completed" { conclusion } else { status }; + println!("{}\t{}", display, name); + } + Some(0) +} + // gh run list -R owner/repo fn try_run_list(args: &[String], base: &str) -> Option { let repo = repo_flag(args)?; @@ -205,3 +249,101 @@ fn find_real_gh() -> String { } "/usr/bin/gh".to_string() } + +#[cfg(test)] +mod tests { + use super::*; + + fn args(s: &[&str]) -> Vec { + s.iter().map(|x| x.to_string()).collect() + } + + #[test] + fn test_flag_val() { + let a = args(&["pr", "view", "123", "-R", "owner/repo"]); + assert_eq!(flag_val(&a, "-R"), Some("owner/repo".to_string())); + assert_eq!(flag_val(&a, "--repo"), None); + } + + #[test] + fn test_repo_flag() { + let a = args(&["pr", "diff", "42", "--repo", "foo/bar"]); + assert_eq!(repo_flag(&a), Some("foo/bar".to_string())); + + let a = args(&["pr", "diff", "42", "-R", "baz/qux"]); + assert_eq!(repo_flag(&a), Some("baz/qux".to_string())); + } + + #[test] + fn test_repo_flag_missing() { + let a = args(&["pr", "diff", "42"]); + assert_eq!(repo_flag(&a), None); + } + + #[test] + fn test_jq_extract_simple() { + let val: serde_json::Value = serde_json::json!({"name": "hello", "count": 42}); + assert_eq!(jq_extract(&val, ".name"), "hello"); + assert_eq!(jq_extract(&val, ".count"), "42"); + assert_eq!(jq_extract(&val, ".missing"), "null"); + } + + #[test] + fn test_jq_extract_nested() { + let val: serde_json::Value = serde_json::json!({"head": {"sha": "abc123"}}); + assert_eq!(jq_extract(&val, ".head.sha"), "abc123"); + } + + #[test] + fn test_try_pooled_returns_none_for_writes() { + // pr create, pr merge, etc. should not be handled + let a = args(&["pr", "create", "--title", "test"]); + assert_eq!(try_pooled(&a, "http://fake:8080"), None); + + let a = args(&["pr", "merge", "123", "-R", "o/r"]); + assert_eq!(try_pooled(&a, "http://fake:8080"), None); + } + + #[test] + fn test_try_pooled_returns_none_for_empty() { + let a: Vec = vec![]; + assert_eq!(try_pooled(&a, "http://fake:8080"), None); + } + + #[test] + fn test_try_pr_diff_missing_repo() { + // No -R flag → returns None (falls through) + let a = args(&["pr", "diff", "123"]); + assert_eq!(try_pr_diff(&a, "http://unreachable:9999"), None); + } + + #[test] + fn test_try_pr_diff_bad_number() { + let a = args(&["pr", "diff", "notanumber", "-R", "o/r"]); + assert_eq!(try_pr_diff(&a, "http://unreachable:9999"), None); + } + + #[test] + fn test_try_pr_checks_missing_repo() { + let a = args(&["pr", "checks", "123"]); + assert_eq!(try_pr_checks(&a, "http://unreachable:9999"), None); + } + + #[test] + fn test_try_pr_checks_bad_number() { + let a = args(&["pr", "checks", "abc", "-R", "o/r"]); + assert_eq!(try_pr_checks(&a, "http://unreachable:9999"), None); + } + + #[test] + fn test_try_api_write_indicators_return_none() { + let a = args(&["api", "/repos/o/r/issues", "-X", "POST"]); + assert_eq!(try_api(&a, "http://fake:8080"), None); + + let a = args(&["api", "/repos/o/r/issues", "-f", "title=x"]); + assert_eq!(try_api(&a, "http://fake:8080"), None); + + let a = args(&["api", "graphql"]); + assert_eq!(try_api(&a, "http://fake:8080"), None); + } +} diff --git a/src/main.rs b/src/main.rs index 5cdd672..b7c61f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,6 +19,7 @@ struct AppState { cache: cache::Cache, config: config::Config, token_users: moka::future::Cache, + http: reqwest::Client, } #[tokio::main] @@ -39,12 +40,14 @@ async fn main() { cache, config: config.clone(), token_users: moka::future::Cache::builder().max_capacity(100).build(), + http: reqwest::Client::new(), }); let app = Router::new() .route("/healthz", get(healthz)) .route("/stats", get(stats)) .route("/graphql", post(graphql_proxy)) + .route("/raw/{*path}", get(proxy_raw)) .route("/{*path}", get(proxy)) .with_state(state); @@ -100,8 +103,7 @@ async fn proxy( } // Forward request - let client = reqwest::Client::new(); - let mut req = client.get(&url) + let mut req = state.http.get(&url) .header("Authorization", format!("Bearer {}", identity.token)) .header("User-Agent", "ghpool/0.1.0") .header("Accept", "application/vnd.github+json"); @@ -146,6 +148,65 @@ async fn proxy( Ok(Json(body)) } +async fn proxy_raw( + State(state): State>, + Path(path): Path, + Query(query): Query>, + headers: HeaderMap, +) -> Result<(StatusCode, HeaderMap, String), StatusCode> { + let api_path = format!("/{}", path); + + if !is_allowed_path(&api_path, &state.config.allowed_owners) { + return Err(StatusCode::FORBIDDEN); + } + + let identity = state.pool.select().map_err(|_| StatusCode::SERVICE_UNAVAILABLE)?; + + let mut url = format!("https://api.github.com{}", api_path); + if !query.is_empty() { + let qs: Vec = query.iter().map(|(k, v)| format!("{}={}", k, v)).collect(); + url = format!("{}?{}", url, qs.join("&")); + } + + let accept = headers.get("accept") + .and_then(|v| v.to_str().ok()) + .unwrap_or("application/vnd.github.v3.diff"); + + let resp = state.http.get(&url) + .header("Authorization", format!("Bearer {}", identity.token)) + .header("User-Agent", "ghpool/0.1.0") + .header("Accept", accept) + .send() + .await + .map_err(|e| { + tracing::error!("github request failed: {}", e); + StatusCode::BAD_GATEWAY + })?; + + let rate_remaining = resp.headers() + .get("x-ratelimit-remaining") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.parse::().ok()); + let rate_reset = resp.headers() + .get("x-ratelimit-reset") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.parse::().ok()); + state.pool.update_rate(&identity.id, rate_remaining, rate_reset); + + let status = resp.status(); + let body = resp.text().await.map_err(|_| StatusCode::BAD_GATEWAY)?; + + if !status.is_success() { + tracing::warn!("github returned {}: {}", status, api_path); + return Err(StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::BAD_GATEWAY)); + } + + tracing::info!("200 OK {} [raw, via {}]", api_path, identity.id); + let mut resp_headers = HeaderMap::new(); + resp_headers.insert("content-type", "text/plain".parse().unwrap()); + Ok((StatusCode::OK, resp_headers, body)) +} + fn is_allowed_path(path: &str, allowed_owners: &[String]) -> bool { let parts: Vec<&str> = path.split('/').collect(); if parts.len() >= 3 && parts[1] == "repos" { @@ -191,8 +252,7 @@ async fn graphql_proxy( (format!("Bearer {}", identity.token), identity.id.clone()) }; - let client = reqwest::Client::new(); - let resp = client.post("https://api.github.com/graphql") + let resp = state.http.post("https://api.github.com/graphql") .header("Authorization", &auth_header) .header("User-Agent", "ghpool/0.1.0") .header("Content-Type", "application/json") @@ -240,8 +300,7 @@ async fn resolve_token_user(state: &AppState, auth_header: &str) -> String { if let Some(user) = state.token_users.get(&key).await { return user; } - let client = reqwest::Client::new(); - let user = match client.get("https://api.github.com/user") + let user = match state.http.get("https://api.github.com/user") .header("Authorization", auth_header) .header("User-Agent", "ghpool/0.1.0") .send() @@ -257,3 +316,95 @@ async fn resolve_token_user(state: &AppState, auth_header: &str) -> String { state.token_users.insert(key, user.clone()).await; user } + +#[cfg(test)] +mod tests { + use super::*; + use axum::body::Body; + use axum::http::Request; + use tower::ServiceExt; + + fn test_state(allowed_owners: Vec<&str>) -> Arc { + let identities = vec![config::IdentityConfig { + id: "test".to_string(), + token: "fake-token".to_string(), + }]; + let pool = pool::PatPool::new(&identities); + let cache_config = config::CacheConfig::default(); + let cache = cache::Cache::new(&cache_config); + Arc::new(AppState { + pool, + cache, + config: config::Config { + port: 8080, + identities, + allowed_owners: allowed_owners.iter().map(|s| s.to_string()).collect(), + cache: cache_config, + }, + token_users: moka::future::Cache::builder().max_capacity(10).build(), + http: reqwest::Client::new(), + }) + } + + fn app(state: Arc) -> axum::Router { + axum::Router::new() + .route("/healthz", axum::routing::get(healthz)) + .route("/stats", axum::routing::get(stats)) + .route("/raw/{*path}", axum::routing::get(proxy_raw)) + .route("/{*path}", axum::routing::get(proxy)) + .with_state(state) + } + + #[tokio::test] + async fn test_healthz() { + let state = test_state(vec!["openabdev"]); + let resp = app(state) + .oneshot(Request::builder().uri("/healthz").body(Body::empty()).unwrap()) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + } + + #[tokio::test] + async fn test_forbidden_owner() { + let state = test_state(vec!["openabdev"]); + let resp = app(state) + .oneshot(Request::builder().uri("/repos/evil-org/repo/pulls/1").body(Body::empty()).unwrap()) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::FORBIDDEN); + } + + #[tokio::test] + async fn test_raw_forbidden_owner() { + let state = test_state(vec!["openabdev"]); + let resp = app(state) + .oneshot(Request::builder().uri("/raw/repos/evil-org/repo/pulls/1").body(Body::empty()).unwrap()) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::FORBIDDEN); + } + + #[tokio::test] + async fn test_allowed_non_repo_path() { + // Non-repo paths like /rate_limit are allowed (will fail at GitHub but not 403) + let state = test_state(vec!["openabdev"]); + let resp = app(state) + .oneshot(Request::builder().uri("/rate_limit").body(Body::empty()).unwrap()) + .await + .unwrap(); + // Will be BAD_GATEWAY since fake token can't reach GitHub, but NOT FORBIDDEN + assert_ne!(resp.status(), StatusCode::FORBIDDEN); + } + + #[test] + fn test_is_allowed_path() { + let owners = vec!["openabdev".to_string(), "oablab".to_string()]; + assert!(is_allowed_path("/repos/openabdev/ghpool/pulls/1", &owners)); + assert!(is_allowed_path("/repos/oablab/chi/issues", &owners)); + assert!(!is_allowed_path("/repos/evil/repo/pulls/1", &owners)); + // Non-repo paths are allowed + assert!(is_allowed_path("/rate_limit", &owners)); + assert!(is_allowed_path("/user", &owners)); + } +}