diff --git a/backend/Cargo.lock b/backend/Cargo.lock index d70b19302..7c8f8d4d7 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -2514,6 +2514,7 @@ dependencies = [ "jsonwebtoken", "metrics", "metrics-exporter-prometheus", + "once_cell", "postgres-types", "rand 0.8.6", "redis", diff --git a/backend/src/api_error.rs b/backend/src/api_error.rs index e277db83f..86e95049b 100644 --- a/backend/src/api_error.rs +++ b/backend/src/api_error.rs @@ -79,7 +79,7 @@ pub enum ApiError { PayloadTooLarge(String), /// Client has exceeded the configured rate limit for the endpoint. - #[error("Rate limit exceeded. Please slow down and retry after the indicated period.")] + #[error("Rate limit exceeded: {0}")] TooManyRequests(String), } @@ -194,10 +194,7 @@ impl IntoResponse for ApiError { } Self::Timeout => { tracing::warn!(error_code = "TIMEOUT", "Request timeout"); - crate::error_tracking::capture_message( - "Request timed out", - sentry::Level::Warning, - ); + crate::error_tracking::capture_message("Request timed out", sentry::Level::Warning); (StatusCode::GATEWAY_TIMEOUT, self.to_string()) } Self::ServiceUnavailable(msg) => { diff --git a/backend/src/app.rs b/backend/src/app.rs index 3044dfcca..679aafa8b 100644 --- a/backend/src/app.rs +++ b/backend/src/app.rs @@ -1,3 +1,4 @@ +use crate::validation::Path; use axum::{ extract::{Query, State}, http::HeaderMap, @@ -6,7 +7,6 @@ use axum::{ routing::{delete, get, post, put}, Json, Router, }; -use crate::validation::Path; use metrics_exporter_prometheus::PrometheusHandle; use serde_json::{json, Value}; use sqlx::PgPool; @@ -2269,12 +2269,9 @@ async fn get_plan_events( Path(plan_id): Path, AuthenticatedUser(user): AuthenticatedUser, ) -> Result, ApiError> { - let events = crate::will_events::WillEventService::get_plan_events( - &state.db, - plan_id, - user.user_id, - ) - .await?; + let events = + crate::will_events::WillEventService::get_plan_events(&state.db, plan_id, user.user_id) + .await?; Ok(Json( json!({ "status": "success", "data": events, "count": events.len() }), )) @@ -2988,13 +2985,9 @@ async fn get_collateral_value( state.db.clone(), 3600, )); - let info = CollateralManagementService::get_collateral_value( - &state.db, - price_feed, - id, - user.user_id, - ) - .await?; + let info = + CollateralManagementService::get_collateral_value(&state.db, price_feed, id, user.user_id) + .await?; Ok(Json(json!({ "status": "success", "data": info }))) } diff --git a/backend/src/auth.rs b/backend/src/auth.rs index ec8e5321a..6e6f87f6b 100644 --- a/backend/src/auth.rs +++ b/backend/src/auth.rs @@ -6,7 +6,7 @@ use axum::{extract::State, Json}; use bcrypt::verify; use chrono::{DateTime, Duration, Utc}; use hex; -use jsonwebtoken::{encode, decode, EncodingKey, Header, Validation}; +use jsonwebtoken::{decode, encode, EncodingKey, Header, Validation}; use ring::signature; use serde::{Deserialize, Serialize}; use sqlx::FromRow; @@ -98,9 +98,11 @@ pub async fn web3_login( let public_key_bytes = { // Enforce strict Stellar address validation if !payload.wallet_address.starts_with('G') || payload.wallet_address.len() != 56 { - return Err(ApiError::BadRequest("Invalid Stellar address format".to_string())); + return Err(ApiError::BadRequest( + "Invalid Stellar address format".to_string(), + )); } - + let strkey = Strkey::from_string(&payload.wallet_address) .map_err(|_| ApiError::BadRequest("Invalid Stellar address".to_string()))?; @@ -513,17 +515,17 @@ where } let token = auth_header.strip_prefix("Bearer ").unwrap(); let mut validation = Validation::default(); - // Ensure expiration is always validated - validation.validate_exp = true; - validation.required_spec_claims.insert("exp".to_string()); - - let claims: UserClaims = decode( - token, - &jsonwebtoken::DecodingKey::from_secret(config.jwt_secret.as_bytes()), - &validation, - ) - .map_err(|_| ApiError::Unauthorized)? - .claims; + // Ensure expiration is always validated + validation.validate_exp = true; + validation.required_spec_claims.insert("exp".to_string()); + + let claims: UserClaims = decode( + token, + &jsonwebtoken::DecodingKey::from_secret(config.jwt_secret.as_bytes()), + &validation, + ) + .map_err(|_| ApiError::Unauthorized)? + .claims; return Ok(AuthenticatedUser(claims)); } @@ -570,17 +572,17 @@ where } let token = auth_header.strip_prefix("Bearer ").unwrap(); let mut validation = Validation::default(); - // Ensure expiration is always validated - validation.validate_exp = true; - validation.required_spec_claims.insert("exp".to_string()); - - let claims: AdminClaims = decode( - token, - &jsonwebtoken::DecodingKey::from_secret(config.jwt_secret.as_bytes()), - &validation, - ) - .map_err(|_| ApiError::Unauthorized)? - .claims; + // Ensure expiration is always validated + validation.validate_exp = true; + validation.required_spec_claims.insert("exp".to_string()); + + let claims: AdminClaims = decode( + token, + &jsonwebtoken::DecodingKey::from_secret(config.jwt_secret.as_bytes()), + &validation, + ) + .map_err(|_| ApiError::Unauthorized)? + .claims; return Ok(AuthenticatedAdmin(claims)); } diff --git a/backend/src/bin/test_stmt_timeout.rs b/backend/src/bin/test_stmt_timeout.rs index 2786e4077..ef4c35c46 100644 --- a/backend/src/bin/test_stmt_timeout.rs +++ b/backend/src/bin/test_stmt_timeout.rs @@ -3,8 +3,8 @@ use std::time::Instant; #[tokio::main] async fn main() -> Result<(), Box> { - let database_url = std::env::var("DATABASE_URL") - .expect("DATABASE_URL must be set for this test"); + let database_url = + std::env::var("DATABASE_URL").expect("DATABASE_URL must be set for this test"); // Create pool using env-configured settings (including DB_POOL_QUERY_TIMEOUT_SECS). let pool = db::create_pool(&database_url).await?; diff --git a/backend/src/cache.rs b/backend/src/cache.rs index 58c57bb48..18593014a 100644 --- a/backend/src/cache.rs +++ b/backend/src/cache.rs @@ -373,7 +373,8 @@ impl CacheService { /// Invalidate all notification caches for a user. pub async fn invalidate_notification_caches(&self, user_id: &str) -> Result { - self.invalidate_prefix(&format!("notifications:{user_id}:")).await + self.invalidate_prefix(&format!("notifications:{user_id}:")) + .await } /// Invalidate all audit log caches. @@ -383,7 +384,8 @@ impl CacheService { /// Invalidate all collateral-related caches. pub async fn invalidate_collateral_caches(&self, user_id: &str) -> Result { - self.invalidate_prefix(&format!("collateral:{user_id}:")).await + self.invalidate_prefix(&format!("collateral:{user_id}:")) + .await } /// Invalidate price feed caches. diff --git a/backend/src/compliance.rs b/backend/src/compliance.rs index b7c6290ae..289434de0 100644 --- a/backend/src/compliance.rs +++ b/backend/src/compliance.rs @@ -1,11 +1,11 @@ use crate::api_error::ApiError; +use crate::events::{EventType, LendingEvent}; use crate::external_integrations::{ AnchorIntegrationClient, ComplianceApiClient, SanctionsApiClient, }; use crate::notifications::{ audit_action, entity_type, notif_type, AuditLogService, NotificationService, }; -use crate::events::{EventType, LendingEvent}; use async_trait::async_trait; use chrono::{DateTime, Duration as ChronoDuration, Utc}; use once_cell::sync::OnceCell; @@ -23,7 +23,8 @@ const ALERT_COOLDOWN_WINDOW: Duration = Duration::from_secs(300); const EVENT_COMMIT_POLL_INTERVAL: Duration = Duration::from_millis(200); const EVENT_COMMIT_MAX_RETRIES: usize = 5; -static REALTIME_COMPLIANCE_LISTENER: OnceCell> = OnceCell::new(); +static REALTIME_COMPLIANCE_LISTENER: OnceCell> = + OnceCell::new(); #[async_trait] pub trait RealtimeComplianceListener: Send + Sync { @@ -158,12 +159,11 @@ impl ComplianceEngine { async fn wait_for_event_commit(&self, event_id: Uuid) -> Result { for _ in 0..EVENT_COMMIT_MAX_RETRIES { - let exists: Option = sqlx::query_scalar( - "SELECT true FROM lending_events WHERE id = $1", - ) - .bind(event_id) - .fetch_optional(&self.db) - .await?; + let exists: Option = + sqlx::query_scalar("SELECT true FROM lending_events WHERE id = $1") + .bind(event_id) + .fetch_optional(&self.db) + .await?; if exists.is_some() { return Ok(true); @@ -263,12 +263,11 @@ impl ComplianceEngine { None => return Ok(()), }; - let plan_created_at: Option> = sqlx::query_scalar( - "SELECT created_at FROM plans WHERE id = $1", - ) - .bind(plan_id) - .fetch_optional(&self.db) - .await?; + let plan_created_at: Option> = + sqlx::query_scalar("SELECT created_at FROM plans WHERE id = $1") + .bind(plan_id) + .fetch_optional(&self.db) + .await?; let plan_created_at = match plan_created_at { Some(created_at) => created_at, @@ -719,11 +718,11 @@ impl RealtimeComplianceListener for ComplianceEngine { #[cfg(test)] mod tests { use super::*; - use rust_decimal_macros::dec; - use sqlx::PgPool; use anyhow::anyhow; use chrono::Utc; + use rust_decimal_macros::dec; use serde_json::json; + use sqlx::PgPool; use std::sync::Arc; use std::time::Duration; use tokio::sync::{oneshot, Mutex}; diff --git a/backend/src/csrf.rs b/backend/src/csrf.rs index a420c8397..01e502f19 100644 --- a/backend/src/csrf.rs +++ b/backend/src/csrf.rs @@ -54,14 +54,13 @@ async fn rotate_csrf_token( db: &sqlx::PgPool, old_token: &str, ) -> Result<(String, chrono::DateTime), ()> { - let user_id: Option = sqlx::query_scalar( - "SELECT user_id FROM csrf_tokens WHERE token = $1 AND used = FALSE", - ) - .bind(old_token) - .fetch_optional(db) - .await - .map_err(|_| ())? - .ok_or(())?; + let user_id: Option = + sqlx::query_scalar("SELECT user_id FROM csrf_tokens WHERE token = $1 AND used = FALSE") + .bind(old_token) + .fetch_optional(db) + .await + .map_err(|_| ())? + .ok_or(())?; let new_token = generate_csrf_token(); let expires_at = Utc::now() + Duration::minutes(60); diff --git a/backend/src/db.rs b/backend/src/db.rs index 3c6c4d0dc..c4fa5d97d 100644 --- a/backend/src/db.rs +++ b/backend/src/db.rs @@ -119,13 +119,13 @@ pub async fn create_pool_with_config( // Enforce a per-query timeout at the server using `statement_timeout`. // This cancels queries that exceed the configured duration and prevents // client-side tasks from hanging indefinitely while the DB is busy. - .after_connect(move |mut conn| { + .after_connect(move |mut conn, _pool| { let timeout_ms = query_timeout_secs * 1000; Box::pin(async move { // Use an explicit SET on the connection. This returns a // Result which we map to (). let set_stmt = format!("SET statement_timeout = {}", timeout_ms); - sqlx::query(&set_stmt).execute(&mut conn).await.map(|_| ()) + sqlx::query(&set_stmt).execute(conn).await.map(|_| ()) }) }) // Test each connection with a lightweight ping before handing it @@ -350,13 +350,14 @@ pub async fn rollback_migration( ensure_version_table(pool).await?; // Verify the target exists and is not already rolled back. - let row: Option<(bool,)> = sqlx::query_as( - "SELECT rolled_back FROM _migration_versions WHERE version = $1", - ) - .bind(target_version) - .fetch_optional(pool) - .await - .map_err(|e| ApiError::Internal(anyhow::anyhow!("Failed to look up migration: {}", e)))?; + let row: Option<(bool,)> = + sqlx::query_as("SELECT rolled_back FROM _migration_versions WHERE version = $1") + .bind(target_version) + .fetch_optional(pool) + .await + .map_err(|e| { + ApiError::Internal(anyhow::anyhow!("Failed to look up migration: {}", e)) + })?; match row { None => { @@ -381,32 +382,27 @@ pub async fn rollback_migration( .map_err(|e| ApiError::Internal(anyhow::anyhow!("Failed to begin transaction: {}", e)))?; // Run the caller-supplied down SQL. - sqlx::query(down_sql) + sqlx::query(down_sql).execute(&mut *tx).await.map_err(|e| { + ApiError::Internal(anyhow::anyhow!( + "Rollback SQL for version {} failed: {}", + target_version, + e + )) + })?; + + // Mark as rolled back in our version table. + sqlx::query("UPDATE _migration_versions SET rolled_back = TRUE WHERE version = $1") + .bind(target_version) .execute(&mut *tx) .await .map_err(|e| { ApiError::Internal(anyhow::anyhow!( - "Rollback SQL for version {} failed: {}", + "Failed to mark migration {} as rolled back: {}", target_version, e )) })?; - // Mark as rolled back in our version table. - sqlx::query( - "UPDATE _migration_versions SET rolled_back = TRUE WHERE version = $1", - ) - .bind(target_version) - .execute(&mut *tx) - .await - .map_err(|e| { - ApiError::Internal(anyhow::anyhow!( - "Failed to mark migration {} as rolled back: {}", - target_version, - e - )) - })?; - // Remove from SQLx's tracking table so it can be re-applied later. sqlx::query("DELETE FROM _sqlx_migrations WHERE version = $1") .bind(target_version) @@ -427,7 +423,10 @@ pub async fn rollback_migration( )) })?; - info!(version = target_version, "Migration rolled back successfully"); + info!( + version = target_version, + "Migration rolled back successfully" + ); Ok(()) } @@ -443,6 +442,8 @@ pub struct PoolMetrics { pub idle: u32, /// Connections currently checked out by active queries. pub active: u32, + /// Connections currently waiting for a connection from the pool. + pub pending: u32, /// Configured upper bound on pool size. pub max_connections: u32, /// Pool utilisation as a fraction in [0.0, 1.0]. @@ -454,6 +455,9 @@ pub fn pool_metrics(pool: &PgPool) -> PoolMetrics { let size = pool.size(); let idle = pool.num_idle() as u32; let active = size.saturating_sub(idle); + // sqlx 0.7 doesn't expose a `num_waiters()` API on `Pool`. + // Use a conservative placeholder until a more accurate metric is available. + let pending = 0u32; let max_connections = pool.options().get_max_connections(); let utilisation = if max_connections > 0 { active as f64 / max_connections as f64 @@ -465,6 +469,7 @@ pub fn pool_metrics(pool: &PgPool) -> PoolMetrics { size, idle, active, + pending, max_connections, utilisation, } diff --git a/backend/src/emergency_access_jobs.rs b/backend/src/emergency_access_jobs.rs index 41eba77ef..184201ff7 100644 --- a/backend/src/emergency_access_jobs.rs +++ b/backend/src/emergency_access_jobs.rs @@ -31,7 +31,9 @@ impl EmergencyAccessJobService { Err(e) => { error!("Error checking expiring emergency access: {}", e); crate::error_tracking::capture_message( - &format!("EmergencyAccessJobService: check_expiring_access failed: {e}"), + &format!( + "EmergencyAccessJobService: check_expiring_access failed: {e}" + ), sentry::Level::Error, ); } diff --git a/backend/src/error_tracking.rs b/backend/src/error_tracking.rs index d58d8e351..5fbf540e5 100644 --- a/backend/src/error_tracking.rs +++ b/backend/src/error_tracking.rs @@ -278,8 +278,8 @@ mod tests { // The payload is base64url({"user_id":"","exp":9999999999}). use base64::Engine as _; let payload = serde_json::json!({ "user_id": user_id, "exp": 9_999_999_999u64 }); - let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD - .encode(payload.to_string().as_bytes()); + let payload_b64 = + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload.to_string().as_bytes()); // header and signature can be arbitrary for this test format!("eyJhbGciOiJIUzI1NiJ9.{payload_b64}.fakesig") } diff --git a/backend/src/event_handlers.rs b/backend/src/event_handlers.rs index 797f83ebd..fd9227244 100644 --- a/backend/src/event_handlers.rs +++ b/backend/src/event_handlers.rs @@ -2,13 +2,13 @@ use crate::api_error::ApiError; use crate::app::AppState; use crate::auth::AuthenticatedUser; use crate::events::{EventService, EventType, LendingEvent}; +use crate::validation::Path; use axum::{ extract::{Query, State}, http::StatusCode, response::IntoResponse, Json, }; -use crate::validation::Path; use serde::{Deserialize, Serialize}; use std::sync::Arc; use uuid::Uuid; diff --git a/backend/src/governance.rs b/backend/src/governance.rs index a7039c0a8..67524cb50 100644 --- a/backend/src/governance.rs +++ b/backend/src/governance.rs @@ -708,7 +708,11 @@ impl GovernanceService { .await .map_err(|e| ApiError::Internal(anyhow::anyhow!("DB error checking delegation: {}", e)))?; - let action = if existing.is_some() { "redelegated" } else { "delegated" }; + let action = if existing.is_some() { + "redelegated" + } else { + "delegated" + }; // Upsert the delegation (insert or update on conflict). sqlx::query( @@ -780,7 +784,9 @@ impl GovernanceService { .bind(delegator_id) .execute(&mut *tx) .await - .map_err(|e| ApiError::Internal(anyhow::anyhow!("DB error removing delegation: {}", e)))?; + .map_err(|e| { + ApiError::Internal(anyhow::anyhow!("DB error removing delegation: {}", e)) + })?; sqlx::query( "INSERT INTO governance_delegation_history (delegator_id, delegate_id, action) VALUES ($1, $2, 'undelegated')", diff --git a/backend/src/graphql.rs b/backend/src/graphql.rs index e447534f8..8ebd07f54 100644 --- a/backend/src/graphql.rs +++ b/backend/src/graphql.rs @@ -37,13 +37,12 @@ impl QueryRoot { let parsed_id = Uuid::parse_str(&id).map_err(|_| async_graphql::Error::new("Invalid ID"))?; - let record = sqlx::query( - "SELECT id, user_id, status FROM plans WHERE id = $1 AND user_id = $2", - ) - .bind(parsed_id) - .bind(user.user_id) - .fetch_optional(db) - .await?; + let record = + sqlx::query("SELECT id, user_id, status FROM plans WHERE id = $1 AND user_id = $2") + .bind(parsed_id) + .bind(user.user_id) + .fetch_optional(db) + .await?; Ok(record.map(|r| Plan { id: r.get::("id").to_string(), @@ -128,18 +127,20 @@ pub async fn graphql_handler( Some(config) => config.clone(), None => { return GraphQLResponse(async_graphql::BatchResponse::Single( - async_graphql::Response::from_errors(vec![ - async_graphql::ServerError::new("Server configuration unavailable", None), - ]), + async_graphql::Response::from_errors(vec![async_graphql::ServerError::new( + "Server configuration unavailable", + None, + )]), )) } }; let Some(user) = extract_user_claims(&headers, &config) else { return GraphQLResponse(async_graphql::BatchResponse::Single( - async_graphql::Response::from_errors(vec![ - async_graphql::ServerError::new("Authentication required", None), - ]), + async_graphql::Response::from_errors(vec![async_graphql::ServerError::new( + "Authentication required", + None, + )]), )); }; diff --git a/backend/src/insurance_fund.rs b/backend/src/insurance_fund.rs index d4b1de883..cc7d30500 100644 --- a/backend/src/insurance_fund.rs +++ b/backend/src/insurance_fund.rs @@ -595,14 +595,13 @@ impl InsuranceFundService { } if let Some(plan_id) = req.plan_id { - let owner: Option = - sqlx::query_scalar("SELECT user_id FROM plans WHERE id = $1") - .bind(plan_id) - .fetch_optional(&self.db) - .await - .map_err(|e| { - ApiError::Internal(anyhow::anyhow!("DB error resolving plan owner: {}", e)) - })?; + let owner: Option = sqlx::query_scalar("SELECT user_id FROM plans WHERE id = $1") + .bind(plan_id) + .fetch_optional(&self.db) + .await + .map_err(|e| { + ApiError::Internal(anyhow::anyhow!("DB error resolving plan owner: {}", e)) + })?; if let Some(user_id) = owner { return Ok(user_id); } @@ -642,14 +641,12 @@ impl InsuranceFundService { .await .map_err(|e| ApiError::Internal(anyhow::anyhow!("Tx start error: {}", e)))?; - let fund = sqlx::query_as::<_, InsuranceFund>( - "SELECT * FROM insurance_fund WHERE id = $1", - ) - .bind(fund_id) - .fetch_optional(&mut *tx) - .await - .map_err(|e| ApiError::Internal(anyhow::anyhow!("DB error fetching fund: {}", e)))? - .ok_or_else(|| ApiError::NotFound("Fund not found".to_string()))?; + let fund = sqlx::query_as::<_, InsuranceFund>("SELECT * FROM insurance_fund WHERE id = $1") + .bind(fund_id) + .fetch_optional(&mut *tx) + .await + .map_err(|e| ApiError::Internal(anyhow::anyhow!("DB error fetching fund: {}", e)))? + .ok_or_else(|| ApiError::NotFound("Fund not found".to_string()))?; if fund.status == FundStatus::Insolvent.as_str() { return Err(ApiError::BadRequest( @@ -752,7 +749,9 @@ impl InsuranceFundService { .fetch_optional(&mut *tx) .await .map_err(|e| ApiError::Internal(anyhow::anyhow!("DB error fetching fund: {}", e)))? - .ok_or_else(|| ApiError::NotFound(format!("Insurance fund {} not found", claim.fund_id)))?; + .ok_or_else(|| { + ApiError::NotFound(format!("Insurance fund {} not found", claim.fund_id)) + })?; if fund.status == FundStatus::Insolvent.as_str() { return Err(ApiError::BadRequest( @@ -893,9 +892,9 @@ impl InsuranceFundService { )); } - let payout_amount = claim.payout_amount.ok_or_else(|| { - ApiError::BadRequest("Claim has no payout amount set".to_string()) - })?; + let payout_amount = claim + .payout_amount + .ok_or_else(|| ApiError::BadRequest("Claim has no payout amount set".to_string()))?; let fund = sqlx::query_as::<_, InsuranceFund>( "SELECT * FROM insurance_fund WHERE id = $1 FOR UPDATE", diff --git a/backend/src/lending_data_warehouse.rs b/backend/src/lending_data_warehouse.rs index d3807fe4e..0f68d867a 100644 --- a/backend/src/lending_data_warehouse.rs +++ b/backend/src/lending_data_warehouse.rs @@ -32,7 +32,9 @@ impl LendingDataWarehouseService { Err(e) => { error!("Lending data warehouse snapshot failed: {}", e); crate::error_tracking::capture_message( - &format!("LendingDataWarehouseService::snapshot_current_metrics failed: {e}"), + &format!( + "LendingDataWarehouseService::snapshot_current_metrics failed: {e}" + ), sentry::Level::Error, ); } diff --git a/backend/src/loan_lifecycle.rs b/backend/src/loan_lifecycle.rs index b00ff2eb6..a019e23da 100644 --- a/backend/src/loan_lifecycle.rs +++ b/backend/src/loan_lifecycle.rs @@ -151,6 +151,19 @@ pub struct LoanLifecycleRecord { pub updated_at: DateTime, pub repaid_at: Option>, pub liquidated_at: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub plan: Option, +} + +/// A compact plan summary attached to a loan for eager-loaded loan lifecycle +/// queries. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LoanPlanSummary { + pub id: Uuid, + pub title: Option, + pub status: Option, + pub is_paused: Option, } /// Raw sqlx row helper – mirrors the table schema exactly. @@ -174,6 +187,29 @@ pub(crate) struct LoanLifecycleRow { pub liquidated_at: Option>, } +#[derive(sqlx::FromRow)] +pub(crate) struct LoanLifecycleRowWithPlan { + pub id: Uuid, + pub user_id: Uuid, + pub plan_id: Option, + pub borrow_asset: String, + pub collateral_asset: String, + pub principal: Decimal, + pub interest_rate_bps: i32, + pub collateral_amount: Decimal, + pub amount_repaid: Decimal, + pub status: String, + pub due_date: DateTime, + pub transaction_hash: Option, + pub created_at: DateTime, + pub updated_at: DateTime, + pub repaid_at: Option>, + pub liquidated_at: Option>, + pub plan_title: Option, + pub plan_status: Option, + pub plan_is_paused: Option, +} + impl From for LoanLifecycleRecord { fn from(r: LoanLifecycleRow) -> Self { LoanLifecycleRecord { @@ -193,6 +229,38 @@ impl From for LoanLifecycleRecord { updated_at: r.updated_at, repaid_at: r.repaid_at, liquidated_at: r.liquidated_at, + plan: None, + } + } +} + +impl From for LoanLifecycleRecord { + fn from(r: LoanLifecycleRowWithPlan) -> Self { + let plan = r.plan_id.map(|id| LoanPlanSummary { + id, + title: r.plan_title, + status: r.plan_status, + is_paused: r.plan_is_paused, + }); + + LoanLifecycleRecord { + id: r.id, + user_id: r.user_id, + plan_id: r.plan_id, + borrow_asset: r.borrow_asset, + collateral_asset: r.collateral_asset, + principal: r.principal, + interest_rate_bps: r.interest_rate_bps, + collateral_amount: r.collateral_amount, + amount_repaid: r.amount_repaid, + status: r.status, + due_date: r.due_date, + transaction_hash: r.transaction_hash, + created_at: r.created_at, + updated_at: r.updated_at, + repaid_at: r.repaid_at, + liquidated_at: r.liquidated_at, + plan, } } } @@ -262,14 +330,16 @@ impl LoanLifecycleService { id: Uuid, user_id: Uuid, ) -> Result { - let row = sqlx::query_as::<_, LoanLifecycleRow>( + let row = sqlx::query_as::<_, LoanLifecycleRowWithPlan>( r#" - SELECT id, user_id, plan_id, borrow_asset, collateral_asset, - principal, interest_rate_bps, collateral_amount, amount_repaid, - status, due_date, transaction_hash, - created_at, updated_at, repaid_at, liquidated_at - FROM loan_lifecycle - WHERE id = $1 AND user_id = $2 + SELECT ll.id, ll.user_id, ll.plan_id, ll.borrow_asset, ll.collateral_asset, + ll.principal, ll.interest_rate_bps, ll.collateral_amount, ll.amount_repaid, + ll.status, ll.due_date, ll.transaction_hash, + ll.created_at, ll.updated_at, ll.repaid_at, ll.liquidated_at, + p.title AS plan_title, p.status AS plan_status, p.is_paused AS plan_is_paused + FROM loan_lifecycle ll + LEFT JOIN plans p ON p.id = ll.plan_id + WHERE ll.id = $1 AND ll.user_id = $2 "#, ) .bind(id) @@ -283,14 +353,16 @@ impl LoanLifecycleService { /// Fetch a single loan by its `id`. Returns `NotFound` when absent. pub async fn get_loan(db: &PgPool, id: Uuid) -> Result { - let row = sqlx::query_as::<_, LoanLifecycleRow>( + let row = sqlx::query_as::<_, LoanLifecycleRowWithPlan>( r#" - SELECT id, user_id, plan_id, borrow_asset, collateral_asset, - principal, interest_rate_bps, collateral_amount, amount_repaid, - status, due_date, transaction_hash, - created_at, updated_at, repaid_at, liquidated_at - FROM loan_lifecycle - WHERE id = $1 + SELECT ll.id, ll.user_id, ll.plan_id, ll.borrow_asset, ll.collateral_asset, + ll.principal, ll.interest_rate_bps, ll.collateral_amount, ll.amount_repaid, + ll.status, ll.due_date, ll.transaction_hash, + ll.created_at, ll.updated_at, ll.repaid_at, ll.liquidated_at, + p.title AS plan_title, p.status AS plan_status, p.is_paused AS plan_is_paused + FROM loan_lifecycle ll + LEFT JOIN plans p ON p.id = ll.plan_id + WHERE ll.id = $1 "#, ) .bind(id) @@ -306,24 +378,31 @@ impl LoanLifecycleService { db: &PgPool, filters: &LoanListFilters, ) -> Result, ApiError> { - let rows = sqlx::query_as::<_, LoanLifecycleRow>( + let rows = sqlx::query_as::<_, LoanLifecycleRowWithPlan>( r#" - SELECT id, user_id, plan_id, borrow_asset, collateral_asset, - principal, interest_rate_bps, collateral_amount, amount_repaid, - status, due_date, transaction_hash, - created_at, updated_at, repaid_at, liquidated_at - FROM loan_lifecycle - WHERE ($1::uuid IS NULL OR user_id = $1) - AND ($2::uuid IS NULL OR plan_id = $2) - AND ($3::text IS NULL OR status::text = $3) - ORDER BY created_at DESC - "# + SELECT ll.id, ll.user_id, ll.plan_id, ll.borrow_asset, ll.collateral_asset, + ll.principal, ll.interest_rate_bps, ll.collateral_amount, ll.amount_repaid, + ll.status, ll.due_date, ll.transaction_hash, + ll.created_at, ll.updated_at, ll.repaid_at, ll.liquidated_at, + p.title AS plan_title, p.status AS plan_status, p.is_paused AS plan_is_paused + FROM loan_lifecycle ll + LEFT JOIN plans p ON p.id = ll.plan_id + WHERE ($1::uuid IS NULL OR ll.user_id = $1) + AND ($2::uuid IS NULL OR ll.plan_id = $2) + AND ($3::text IS NULL OR ll.status::text = $3) + ORDER BY ll.created_at DESC + "#, ); let rows = rows .bind(filters.user_id) .bind(filters.plan_id) - .bind(filters.status.as_ref().map(|status| status.as_str().to_string())) + .bind( + filters + .status + .as_ref() + .map(|status| status.as_str().to_string()), + ) .fetch_all(db) .await?; Ok(rows.into_iter().map(Into::into).collect()) @@ -336,17 +415,19 @@ impl LoanLifecycleService { limit: i64, offset: i64, ) -> Result, ApiError> { - let rows = sqlx::query_as::<_, LoanLifecycleRow>( + let rows = sqlx::query_as::<_, LoanLifecycleRowWithPlan>( r#" - SELECT id, user_id, plan_id, borrow_asset, collateral_asset, - principal, interest_rate_bps, collateral_amount, amount_repaid, - status, due_date, transaction_hash, - created_at, updated_at, repaid_at, liquidated_at - FROM loan_lifecycle - WHERE ($1::uuid IS NULL OR user_id = $1) - AND ($2::uuid IS NULL OR plan_id = $2) - AND ($3::text IS NULL OR status::text = $3) - ORDER BY created_at DESC + SELECT ll.id, ll.user_id, ll.plan_id, ll.borrow_asset, ll.collateral_asset, + ll.principal, ll.interest_rate_bps, ll.collateral_amount, ll.amount_repaid, + ll.status, ll.due_date, ll.transaction_hash, + ll.created_at, ll.updated_at, ll.repaid_at, ll.liquidated_at, + p.title AS plan_title, p.status AS plan_status, p.is_paused AS plan_is_paused + FROM loan_lifecycle ll + LEFT JOIN plans p ON p.id = ll.plan_id + WHERE ($1::uuid IS NULL OR ll.user_id = $1) + AND ($2::uuid IS NULL OR ll.plan_id = $2) + AND ($3::text IS NULL OR ll.status::text = $3) + ORDER BY ll.created_at DESC LIMIT $4 OFFSET $5 "#, ); @@ -354,7 +435,12 @@ impl LoanLifecycleService { let rows = rows .bind(filters.user_id) .bind(filters.plan_id) - .bind(filters.status.as_ref().map(|status| status.as_str().to_string())) + .bind( + filters + .status + .as_ref() + .map(|status| status.as_str().to_string()), + ) .bind(limit) .bind(offset) .fetch_all(db) @@ -371,13 +457,18 @@ impl LoanLifecycleService { WHERE ($1::uuid IS NULL OR user_id = $1) AND ($2::uuid IS NULL OR plan_id = $2) AND ($3::text IS NULL OR status::text = $3) - "# + "#, ); let count = sql .bind(filters.user_id) .bind(filters.plan_id) - .bind(filters.status.as_ref().map(|status| status.as_str().to_string())) + .bind( + filters + .status + .as_ref() + .map(|status| status.as_str().to_string()), + ) .fetch_one(db) .await?; Ok(count) @@ -882,10 +973,10 @@ impl LoanLifecycleService { })?; let current_status = LoanStatus::from_str(&row.status)?; - + let new_amount_repaid = row.amount_repaid + amount; let fully_repaid = new_amount_repaid >= row.principal; - + if fully_repaid { let next_status = LoanStatus::PaidOff; current_status.validate_transition(next_status)?; @@ -1001,6 +1092,47 @@ impl LoanLifecycleService { crate::metrics::inc_loans_liquidated(); Ok(record) } + + /// Convenience wrapper used by older call-sites to create and immediately + /// activate a loan. This delegates to `create_draft_loan` and then + /// transitions the loan to `active` so higher-level handlers can call a + /// single method. + pub async fn create_loan( + pool: &PgPool, + req: &CreateLoanRequest, + ) -> Result { + let draft = Self::create_draft_loan(pool, req).await?; + // Activate the draft loan on behalf of the requesting user. + Self::activate_loan(pool, draft.id, req.user_id).await + } + + /// Admin-facing wrapper to liquidate a loan. Reuses `default_loan` which + /// marks the loan as defaulted/liquidated and logs metrics. + pub async fn liquidate_loan( + pool: &PgPool, + loan_id: Uuid, + admin_id: Uuid, + ) -> Result { + Self::default_loan(pool, loan_id, admin_id).await + } + + /// Sweep active loans past their due date and mark them as `defaulted`. + /// Returns the list of loan ids that were marked. This is a best-effort + /// implementation useful for tests and cron jobs. + pub async fn mark_overdue_loans(db: &PgPool) -> Result, ApiError> { + let rows = sqlx::query_as::<_, (Uuid,)>( + r#" + UPDATE loan_lifecycle + SET status = 'defaulted', liquidated_at = NOW() + WHERE status = 'active' AND due_date < NOW() + RETURNING id + "#, + ) + .fetch_all(db) + .await?; + + Ok(rows.into_iter().map(|(id,)| id).collect()) + } } // ── Tests ───────────────────────────────────────────────────────────────────── @@ -1042,23 +1174,47 @@ mod tests { #[test] fn valid_state_transitions_pass() { // All valid transitions - assert!(LoanStatus::Draft.validate_transition(LoanStatus::Applied).is_ok()); - assert!(LoanStatus::Applied.validate_transition(LoanStatus::UnderReview).is_ok()); - assert!(LoanStatus::UnderReview.validate_transition(LoanStatus::Approved).is_ok()); - assert!(LoanStatus::UnderReview.validate_transition(LoanStatus::Rejected).is_ok()); - assert!(LoanStatus::Approved.validate_transition(LoanStatus::Active).is_ok()); - assert!(LoanStatus::Active.validate_transition(LoanStatus::PaidOff).is_ok()); - assert!(LoanStatus::Active.validate_transition(LoanStatus::Defaulted).is_ok()); + assert!(LoanStatus::Draft + .validate_transition(LoanStatus::Applied) + .is_ok()); + assert!(LoanStatus::Applied + .validate_transition(LoanStatus::UnderReview) + .is_ok()); + assert!(LoanStatus::UnderReview + .validate_transition(LoanStatus::Approved) + .is_ok()); + assert!(LoanStatus::UnderReview + .validate_transition(LoanStatus::Rejected) + .is_ok()); + assert!(LoanStatus::Approved + .validate_transition(LoanStatus::Active) + .is_ok()); + assert!(LoanStatus::Active + .validate_transition(LoanStatus::PaidOff) + .is_ok()); + assert!(LoanStatus::Active + .validate_transition(LoanStatus::Defaulted) + .is_ok()); } #[test] fn invalid_state_transitions_fail() { // A few invalid transitions - assert!(LoanStatus::Draft.validate_transition(LoanStatus::Active).is_err()); - assert!(LoanStatus::Applied.validate_transition(LoanStatus::Approved).is_err()); - assert!(LoanStatus::Approved.validate_transition(LoanStatus::PaidOff).is_err()); - assert!(LoanStatus::PaidOff.validate_transition(LoanStatus::Active).is_err()); - assert!(LoanStatus::Rejected.validate_transition(LoanStatus::UnderReview).is_err()); + assert!(LoanStatus::Draft + .validate_transition(LoanStatus::Active) + .is_err()); + assert!(LoanStatus::Applied + .validate_transition(LoanStatus::Approved) + .is_err()); + assert!(LoanStatus::Approved + .validate_transition(LoanStatus::PaidOff) + .is_err()); + assert!(LoanStatus::PaidOff + .validate_transition(LoanStatus::Active) + .is_err()); + assert!(LoanStatus::Rejected + .validate_transition(LoanStatus::UnderReview) + .is_err()); } // ── Partial repayment business logic ───────────────────────────────────── diff --git a/backend/src/main.rs b/backend/src/main.rs index e081ef92f..294d6ed0a 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -19,10 +19,7 @@ async fn main() -> Result<(), Box> { // Run the rest of startup inside a helper so any error can be captured // by Sentry before the guard is dropped. if let Err(e) = run().await { - error_tracking::capture_message( - &format!("Fatal startup error: {e}"), - sentry::Level::Fatal, - ); + error_tracking::capture_message(&format!("Fatal startup error: {e}"), sentry::Level::Fatal); // Give Sentry a moment to flush the event before the process exits. std::thread::sleep(std::time::Duration::from_secs(2)); return Err(e); @@ -32,7 +29,6 @@ async fn main() -> Result<(), Box> { } async fn run() -> Result<(), Box> { - // Install Prometheus metrics recorder (Issue #423). let prometheus_handle = metrics::get_or_install_recorder(); diff --git a/backend/src/metrics.rs b/backend/src/metrics.rs index 8313bd657..7059cc321 100644 --- a/backend/src/metrics.rs +++ b/backend/src/metrics.rs @@ -23,6 +23,8 @@ //! | `db_pool_size` | Gauge | — | Open connections (idle + active) | //! | `db_pool_idle` | Gauge | — | Idle connections | //! | `db_pool_active` | Gauge | — | Checked-out connections | +//! | `db_pool_pending` | Gauge | — | Connections waiting for checkout | +//! | `db_pool_max_connections` | Gauge | — | Configured pool size limit | //! | `db_pool_utilisation` | Gauge | — | active / max_connections | //! | `db_query_duration_seconds` | Histogram | operation | Query round-trip latency | //! @@ -194,6 +196,8 @@ pub fn record_pool_metrics(m: &crate::db::PoolMetrics) { metrics::gauge!("db_pool_size").set(m.size as f64); metrics::gauge!("db_pool_idle").set(m.idle as f64); metrics::gauge!("db_pool_active").set(m.active as f64); + metrics::gauge!("db_pool_pending").set(m.pending as f64); + metrics::gauge!("db_pool_max_connections").set(m.max_connections as f64); metrics::gauge!("db_pool_utilisation").set(m.utilisation); } @@ -351,6 +355,7 @@ mod tests { size: 3, idle: 1, active: 2, + pending: 0, max_connections: 10, utilisation: 0.2, }; diff --git a/backend/src/middleware.rs b/backend/src/middleware.rs index 909953571..8db887053 100644 --- a/backend/src/middleware.rs +++ b/backend/src/middleware.rs @@ -7,9 +7,10 @@ use axum::{ }; use serde_json::json; use serde_json::Value as JsonValue; -use std::collections::HashSet; -use std::sync::Arc; -use std::time::Duration; +use std::collections::{hash_map::DefaultHasher, HashSet}; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, OnceLock}; +use std::time::{Duration, Instant}; use tokio::time::timeout; use tower_governor::{errors::GovernorError, key_extractor::KeyExtractor}; use uuid::Uuid; @@ -51,14 +52,15 @@ pub async fn enforce_max_request_size(req: Request, next: Next) -> Respons // Read the body up to the configured cap so we can inspect JSON payloads. let (parts, body) = req.into_parts(); - let bytes: axum::body::Bytes = match axum::body::to_bytes(body, crate::validation::DEFAULT_MAX_BODY_BYTES + 1).await { - Ok(b) => b, - Err(_) => { - // If body couldn't be read, let the inner handler observe the failure. - let req = Request::from_parts(parts, Body::empty()); - return next.run(req).await; - } - }; + let bytes: axum::body::Bytes = + match axum::body::to_bytes(body, crate::validation::DEFAULT_MAX_BODY_BYTES + 1).await { + Ok(b) => b, + Err(_) => { + // If body couldn't be read, let the inner handler observe the failure. + let req = Request::from_parts(parts, Body::empty()); + return next.run(req).await; + } + }; if bytes.len() > crate::validation::DEFAULT_MAX_BODY_BYTES { return ( @@ -258,9 +260,10 @@ pub async fn attach_correlation_id(req: Request, next: Next) -> impl IntoR } /// Logs each incoming request with its method, URI, and assigned request ID. -pub async fn request_logging_middleware(req: Request, next: Next) -> Response { +pub async fn request_logging_middleware(mut req: Request, next: Next) -> Response { let method = req.method().clone(); let uri = req.uri().clone(); + let path = uri.path().to_string(); let request_id = req .headers() .get(&X_REQUEST_ID) @@ -268,19 +271,145 @@ pub async fn request_logging_middleware(req: Request, next: Next) -> Response { .unwrap_or("unknown") .to_owned(); - tracing::info!(request_id = %request_id, method = %method, uri = %uri, "incoming request"); + let context = req + .extensions() + .get::() + .cloned() + .unwrap_or_default(); + + let headers = sanitize_headers(req.headers()); + let start = Instant::now(); let response = next.run(req).await; + let duration_ms = start.elapsed().as_secs_f64() * 1000.0; + let status = response.status(); - tracing::info!( - request_id = %request_id, - status = %response.status(), - "request completed" - ); + let sampling_ratio = log_sampling_ratio(); + let is_high_traffic = is_high_traffic_path(&path); + let should_emit_full_details = + !is_high_traffic || should_sample_request(&request_id, sampling_ratio); + + if should_emit_full_details { + tracing::info!( + request_id = %request_id, + http.method = %method, + http.path = %path, + http.status_code = %status, + http.duration_ms = duration_ms, + user_id = ?context.user_id, + plan_id = ?context.plan_id, + http.request_headers = ?headers, + "request completed" + ); + } else { + tracing::info!( + request_id = %request_id, + http.method = %method, + http.path = %path, + http.status_code = %status, + http.duration_ms = duration_ms, + user_id = ?context.user_id, + plan_id = ?context.plan_id, + "request completed" + ); + } response } +fn sanitize_headers(headers: &HeaderMap) -> Vec<(String, String)> { + headers + .iter() + .map(|(name, value)| { + ( + name.as_str().to_string(), + sanitize_header_value(name, value), + ) + }) + .collect() +} + +fn sanitize_header_value(name: &HeaderName, value: &HeaderValue) -> String { + let sensitive_headers = [ + "authorization", + "cookie", + "set-cookie", + "x-api-key", + "x-csrf-token", + "x-csrf", + ]; + let header_name = name.as_str().to_ascii_lowercase(); + + if sensitive_headers.contains(&header_name.as_str()) { + "***".to_string() + } else { + value.to_str().unwrap_or("").to_string() + } +} + +fn is_high_traffic_path(path: &str) -> bool { + matches!( + path, + "/api/metrics" | "/api/health" | "/api/ping" | "/api/status" | "/api/loans/lifecycle" + ) +} + +fn log_sampling_ratio() -> f64 { + static RATIO: OnceLock = OnceLock::new(); + *RATIO.get_or_init(|| { + std::env::var("LOG_SAMPLING_PERCENT") + .ok() + .and_then(|value| value.parse::().ok()) + .map(|percent| percent.clamp(0.0, 100.0) / 100.0) + .unwrap_or(0.1) + }) +} + +fn should_sample_request(request_id: &str, ratio: f64) -> bool { + if ratio <= 0.0 { + return false; + } + if ratio >= 1.0 { + return true; + } + + let mut hasher = DefaultHasher::new(); + request_id.hash(&mut hasher); + let bucket = hasher.finish() % 1000; + bucket < (ratio * 1000.0).round() as u64 +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::http::HeaderMap; + use axum::http::HeaderValue; + + #[test] + fn sanitize_headers_masks_sensitive_values() { + let mut headers = HeaderMap::new(); + headers.insert("authorization", HeaderValue::from_static("Bearer abc123")); + headers.insert("cookie", HeaderValue::from_static("session=secret")); + headers.insert("x-api-key", HeaderValue::from_static("key")); + headers.insert("content-type", HeaderValue::from_static("application/json")); + + let sanitized = sanitize_headers(&headers); + assert!(sanitized.contains(&("authorization".to_string(), "***".to_string()))); + assert!(sanitized.contains(&("cookie".to_string(), "***".to_string()))); + assert!(sanitized.contains(&("x-api-key".to_string(), "***".to_string()))); + assert!(sanitized.contains(&("content-type".to_string(), "application/json".to_string()))); + } + + #[test] + fn should_sample_request_is_deterministic() { + let request_id = "abc-123"; + let ratio = 0.5; + let first = should_sample_request(request_id, ratio); + let second = should_sample_request(request_id, ratio); + assert_eq!(first, second); + } +} + pub async fn log_rate_limit_violations(req: Request, next: Next) -> impl IntoResponse { let path = req.uri().path().to_string(); let method = req.method().clone(); diff --git a/backend/src/notifications.rs b/backend/src/notifications.rs index eb7d68e42..845f4a087 100644 --- a/backend/src/notifications.rs +++ b/backend/src/notifications.rs @@ -167,10 +167,7 @@ impl NotificationService { } /// Mark notification as delivered. - pub async fn mark_delivered( - db: &PgPool, - notif_id: Uuid, - ) -> Result<(), ApiError> { + pub async fn mark_delivered(db: &PgPool, notif_id: Uuid) -> Result<(), ApiError> { sqlx::query( r#" UPDATE notifications @@ -186,10 +183,7 @@ impl NotificationService { } /// Increment delivery attempts for a notification. - pub async fn increment_delivery_attempts( - db: &PgPool, - notif_id: Uuid, - ) -> Result<(), ApiError> { + pub async fn increment_delivery_attempts(db: &PgPool, notif_id: Uuid) -> Result<(), ApiError> { sqlx::query( r#" UPDATE notifications @@ -209,10 +203,7 @@ impl NotificationService { } /// Fetch undelivered notifications for retry. - pub async fn list_undelivered( - db: &PgPool, - limit: i64, - ) -> Result, ApiError> { + pub async fn list_undelivered(db: &PgPool, limit: i64) -> Result, ApiError> { let rows = sqlx::query_as::<_, Notification>( r#" SELECT id, user_id, type, message, is_read, delivery_status, delivery_attempts, created_at @@ -619,6 +610,8 @@ mod tests { notif_type: notif_type::KYC_APPROVED.to_string(), message: "Approved!".to_string(), is_read: false, + delivery_status: None, + delivery_attempts: None, created_at: Utc::now(), }; let json = serde_json::to_value(&n).unwrap(); @@ -640,6 +633,8 @@ mod tests { notif_type: notif_type::PLAN_CREATED.to_string(), message: "Plan created".to_string(), is_read: false, + delivery_status: None, + delivery_attempts: None, created_at: Utc::now(), }; assert!(!n.is_read); @@ -657,6 +652,7 @@ mod tests { old_value: None, new_value: None, metadata: None, + sequence_number: None, timestamp: Utc::now(), }; let json = serde_json::to_value(&log).unwrap(); @@ -676,6 +672,7 @@ mod tests { old_value: None, new_value: None, metadata: None, + sequence_number: None, timestamp: Utc::now(), }; let json = serde_json::to_value(&log).unwrap(); diff --git a/backend/src/price_feed.rs b/backend/src/price_feed.rs index 176154c91..542083ade 100644 --- a/backend/src/price_feed.rs +++ b/backend/src/price_feed.rs @@ -46,7 +46,10 @@ pub struct AssetPrice { impl AssetPrice { /// Returns true if the price is younger than MAX_PRICE_AGE_SECS. pub fn is_fresh(&self) -> bool { - Utc::now().signed_duration_since(self.timestamp).num_seconds() < MAX_PRICE_AGE_SECS + Utc::now() + .signed_duration_since(self.timestamp) + .num_seconds() + < MAX_PRICE_AGE_SECS } } diff --git a/backend/src/price_feed_handlers.rs b/backend/src/price_feed_handlers.rs index 1002eea10..775d947c9 100644 --- a/backend/src/price_feed_handlers.rs +++ b/backend/src/price_feed_handlers.rs @@ -2,9 +2,9 @@ use crate::api_error::ApiError; use crate::auth::{AuthenticatedAdmin, AuthenticatedUser}; use crate::notifications::AuditLogService; use crate::price_feed::{PriceFeedService, PriceFeedSource}; +use crate::validation::Path; use axum::extract::State; use axum::Json; -use crate::validation::Path; use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -285,7 +285,6 @@ pub async fn get_plan_valuation( }))) } - /// Get all active price feeds (admin only) pub async fn get_active_feeds( State((_db, price_service)): State<(PgPool, Arc)>, diff --git a/backend/src/risk_engine.rs b/backend/src/risk_engine.rs index e10889860..a6f91cc11 100644 --- a/backend/src/risk_engine.rs +++ b/backend/src/risk_engine.rs @@ -3,11 +3,11 @@ use crate::notifications::{ audit_action, entity_type, notif_type, AuditLogService, NotificationService, }; use crate::price_feed::PriceFeedService; +use chrono::Utc; use rust_decimal::Decimal; use sqlx::PgPool; use std::sync::Arc; use std::time::Duration; -use chrono::Utc; use tracing::{error, info, warn}; pub struct RiskEngine { @@ -66,7 +66,10 @@ impl RiskEngine { last_seen = ts_opt; if last_seen.is_some() { if let Err(e) = watcher.check_all_loans().await { - error!("Risk Engine error recalculating after price update: {}", e); + error!( + "Risk Engine error recalculating after price update: {}", + e + ); } else { info!("Risk Engine recalculated health factors after price update."); } @@ -107,7 +110,7 @@ impl RiskEngine { WHERE (ll.principal - ll.amount_repaid) > 0 AND ll.status = 'active' AND (p.is_paused IS NULL OR p.is_paused = false) - "# + "#, ) .fetch_all(&self.db) .await @@ -151,9 +154,9 @@ impl RiskEngine { // Determine liquidation threshold based on collateral asset when possible let asset_upper = collat_asset.to_uppercase(); let liquidation_threshold_for_asset = match asset_upper.as_str() { - "USDC" => Decimal::new(95, 2), // 0.95 - "ETH" | "WETH" => Decimal::new(85, 2), // 0.85 - "BTC" | "WBTC" => Decimal::new(85, 2), // 0.85 + "USDC" => Decimal::new(95, 2), // 0.95 + "ETH" | "WETH" => Decimal::new(85, 2), // 0.85 + "BTC" | "WBTC" => Decimal::new(85, 2), // 0.85 "XLM" | "STELLAR_XLM" => Decimal::new(80, 2), // 0.80 // Fallback to engine-wide threshold if unknown _ => self.liquidation_threshold, diff --git a/backend/src/service.rs b/backend/src/service.rs index d005cd77b..7dc80c213 100644 --- a/backend/src/service.rs +++ b/backend/src/service.rs @@ -459,13 +459,12 @@ impl PlanService { where E: sqlx::Executor<'a, Database = sqlx::Postgres>, { - let exists: bool = sqlx::query_scalar( - "SELECT EXISTS(SELECT 1 FROM plans WHERE id = $1 AND user_id = $2)", - ) - .bind(plan_id) - .bind(user_id) - .fetch_one(executor) - .await?; + let exists: bool = + sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM plans WHERE id = $1 AND user_id = $2)") + .bind(plan_id) + .bind(user_id) + .fetch_one(executor) + .await?; if !exists { return Err(ApiError::NotFound(format!("Plan {plan_id} not found"))); @@ -1600,11 +1599,7 @@ impl RevenueMetricsService { "#, ); - let rows = rows - .bind(trunc) - .bind(interval) - .fetch_all(pool) - .await?; + let rows = rows.bind(trunc).bind(interval).fetch_all(pool).await?; let data = rows .into_iter() @@ -4172,7 +4167,7 @@ impl EmergencyAccessMetricsService { _ => ("30 days", "day"), // default to daily }; - let trend_rows: Vec<(String, i64)> = sqlx::query_as( + let trend_rows: Vec<(String, i64)> = sqlx::query_as::<_, (String, i64)>( r#" SELECT DATE_TRUNC($1, created_at)::DATE::TEXT as date, @@ -4182,13 +4177,11 @@ impl EmergencyAccessMetricsService { GROUP BY 1 ORDER BY 1 "#, - ); - - let trend_rows: Vec<(String, i64)> = trend_rows - .bind(trunc) - .bind(interval) - .fetch_all(db) - .await?; + ) + .bind(trunc) + .bind(interval) + .fetch_all(db) + .await?; let grant_trend = trend_rows .into_iter() diff --git a/backend/src/session.rs b/backend/src/session.rs index 14e2e553f..67bfc3688 100644 --- a/backend/src/session.rs +++ b/backend/src/session.rs @@ -269,6 +269,7 @@ pub async fn list_sessions( pub async fn revoke_session( State(state): State>, crate::validation::Path(session_id): crate::validation::Path, + AuthenticatedUser(user): AuthenticatedUser, req: Request, ) -> Result, ApiError> { // Ensure the session belongs to the authenticated user @@ -294,7 +295,6 @@ pub async fn revoke_session( Ok(Json(json!({ "message": "Session revoked" }))) } - // ── Middleware ──────────────────────────────────────────────────────────────── /// Rejects requests whose JWT has been explicitly revoked OR expired. diff --git a/backend/src/validation.rs b/backend/src/validation.rs index 5224aa958..964ba1094 100644 --- a/backend/src/validation.rs +++ b/backend/src/validation.rs @@ -5,10 +5,9 @@ /// before processing the request body. use crate::api_error::ApiError; use regex::Regex; +use serde_json::Value as JsonValue; use std::collections::HashMap; use std::sync::OnceLock; -use serde_json::Value as JsonValue; -use once_cell::sync::Lazy; /// Collects field-level validation errors. #[derive(Debug, Default)] @@ -44,7 +43,9 @@ impl ValidationErrors { /// Strips leading/trailing whitespace and removes common SQL injection patterns. pub fn sanitize_string(input: &str) -> String { - sql_injection_pattern().replace_all(input.trim(), "").to_string() + sql_injection_pattern() + .replace_all(input.trim(), "") + .to_string() } fn sql_injection_pattern() -> &'static Regex { @@ -79,22 +80,160 @@ pub fn validate_min_length(errors: &mut ValidationErrors, field: &str, value: &s } pub fn validate_email(errors: &mut ValidationErrors, field: &str, value: &str) { - static EMAIL_RE: Lazy = Lazy::new(|| { - // This pattern is a pragmatic RFC 5322 compatible (local@domain) validator. - // It supports quoted local-parts, dots, and IPv4/IPv6 literals in domains. - Regex::new(r"(?xi)^(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|\"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x00-\x7f])*\")@(?:(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?|\[(?:(?:25[0-5]|2[0-4]\d|[01]?\d?\d)(?:\.(?:25[0-5]|2[0-4]\d|[01]?\d?\d)){3}|[a-f0-9:\.]+)\])$").expect("email regex") - }); - let s = value.trim(); - // Per RFCs, the maximum total length for an email address is 254 characters. - if s.len() == 0 || s.len() > 254 { + if s.is_empty() || s.len() > 254 { + errors.add(field, "must be a valid email address"); + return; + } + + let parts: Vec<&str> = s.rsplitn(2, '@').collect(); + if parts.len() != 2 { + errors.add(field, "must be a valid email address"); + return; + } + + let domain = parts[0]; + let local = parts[1]; + + if local.is_empty() || domain.is_empty() { + errors.add(field, "must be a valid email address"); + return; + } + + if local.starts_with('.') || local.ends_with('.') || local.contains("..") { + errors.add(field, "must be a valid email address"); + return; + } + + if domain.starts_with('.') || domain.ends_with('.') || domain.contains("..") { errors.add(field, "must be a valid email address"); return; } - if !EMAIL_RE.is_match(s) { + let valid_local = if local.starts_with('"') && local.ends_with('"') { + validate_quoted_local_part(local) + } else { + is_valid_unquoted_local_part(local) + }; + + if !valid_local { errors.add(field, "must be a valid email address"); + return; } + + if !is_valid_domain(domain) { + errors.add(field, "must be a valid email address"); + } +} + +fn is_valid_unquoted_local_part(local: &str) -> bool { + if local.is_empty() { + return false; + } + + local.as_bytes().iter().all(|&b| match b { + b'a'..=b'z' + | b'A'..=b'Z' + | b'0'..=b'9' + | b'!' + | b'#' + | b'$' + | b'%' + | b'&' + | b'\'' + | b'*' + | b'+' + | b'-' + | b'/' + | b'=' + | b'?' + | b'^' + | b'_' + | b'`' + | b'{' + | b'|' + | b'}' + | b'~' + | b'.' => true, + _ => false, + }) +} + +fn is_valid_domain(domain: &str) -> bool { + if domain.starts_with('[') && domain.ends_with(']') { + let literal = &domain[1..domain.len() - 1]; + return is_valid_ipv4(literal) || is_valid_ipv6_literal(literal); + } + + if domain.len() > 255 { + return false; + } + + for label in domain.split('.') { + if label.is_empty() || label.len() > 63 { + return false; + } + if label.starts_with('-') || label.ends_with('-') { + return false; + } + if !label.as_bytes().iter().all(|&b| match b { + b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'-' => true, + _ => false, + }) { + return false; + } + } + + true +} + +fn is_valid_ipv4(literal: &str) -> bool { + let octets: Vec<&str> = literal.split('.').collect(); + if octets.len() != 4 { + return false; + } + for octet in octets { + if octet.is_empty() || octet.len() > 3 { + return false; + } + if octet.starts_with('0') && octet.len() > 1 { + return false; + } + let value = octet.parse::(); + if value.is_err() { + return false; + } + } + true +} + +fn is_valid_ipv6_literal(literal: &str) -> bool { + literal + .strip_prefix("IPv6:") + .map(|v| { + v.chars() + .all(|ch| ch.is_ascii_hexdigit() || ch == ':' || ch == '.') + }) + .unwrap_or(false) +} + +fn validate_quoted_local_part(local: &str) -> bool { + let inner = &local[1..local.len() - 1]; + let mut chars = inner.chars(); + while let Some(ch) = chars.next() { + if ch == '\\' { + if chars.next().is_none() { + return false; + } + continue; + } + + if ch == '"' || ch == '\r' || ch == '\n' || ch.is_control() { + return false; + } + } + + true } pub fn validate_uuid(errors: &mut ValidationErrors, field: &str, value: &str) { @@ -193,7 +332,6 @@ macro_rules! bail_if_invalid { #[derive(Debug)] pub struct Path(pub T); -#[axum::async_trait] impl axum::extract::FromRequestParts for Path where T: serde::de::DeserializeOwned + Send, @@ -201,14 +339,17 @@ where { type Rejection = ApiError; - async fn from_request_parts( + fn from_request_parts( parts: &mut axum::http::request::Parts, state: &S, - ) -> Result { - match axum::extract::Path::::from_request_parts(parts, state).await { - Ok(axum::extract::Path(value)) => Ok(Path(value)), - Err(err) => { - Err(ApiError::BadRequest(format!("Invalid path parameter: {}", err))) + ) -> impl std::future::Future> + Send { + async move { + match axum::extract::Path::::from_request_parts(parts, state).await { + Ok(axum::extract::Path(value)) => Ok(Path(value)), + Err(err) => Err(ApiError::BadRequest(format!( + "Invalid path parameter: {}", + err + ))), } } } diff --git a/backend/src/webhook.rs b/backend/src/webhook.rs index 96bf34e86..39c86b865 100644 --- a/backend/src/webhook.rs +++ b/backend/src/webhook.rs @@ -1,9 +1,5 @@ -use axum::{ - extract::State, - http::StatusCode, - Json, -}; use crate::validation::Path; +use axum::{extract::State, http::StatusCode, Json}; use base64::engine::general_purpose; use base64::Engine as _; use rand::Rng;