From 83816234d9c8cbe67e3cb44e916c0c7f8e8ad422 Mon Sep 17 00:00:00 2001 From: mohadon0 Date: Tue, 2 Jun 2026 10:54:31 +0000 Subject: [PATCH] Implement Redis initialization and DI for JwtService with integration tests --- backend/src/auth/jwt_service.rs | 95 +++++++--- backend/src/auth/jwt_service_test.rs | 39 ++++ backend/src/auth/mod.rs | 2 + backend/src/main.rs | 89 ++++++++- backend/src/service/auth_service.rs | 269 +++++++++++++++++++++++++-- backend/src/service/mod.rs | 2 + 6 files changed, 447 insertions(+), 49 deletions(-) diff --git a/backend/src/auth/jwt_service.rs b/backend/src/auth/jwt_service.rs index 376cf55b..13551b14 100644 --- a/backend/src/auth/jwt_service.rs +++ b/backend/src/auth/jwt_service.rs @@ -117,6 +117,7 @@ impl Default for JwtConfig { pub struct SessionData { pub user_id: Uuid, pub session_id: String, + pub token_type: TokenType, pub device_id: Option, pub created_at: i64, pub last_activity: i64, @@ -165,6 +166,7 @@ impl KeyRotation { } /// Main JWT Service +#[derive(Clone)] pub struct JwtService { config: JwtConfig, redis: ConnectionManager, @@ -209,8 +211,15 @@ impl JwtService { let token = encode(&Header::new(self.config.algorithm), &claims, &encoding_key) .map_err(|e| JwtError::TokenGeneration(e.to_string()))?; - // Store session in Redis - self.store_session(&session_id, user_id, device_id).await?; + // Store access session in Redis + self.store_session( + &session_id, + user_id, + device_id.clone(), + TokenType::Access, + self.config.access_token_expiry, + ) + .await?; info!(user_id = %user_id, session_id = %session_id, "Access token generated"); @@ -243,6 +252,15 @@ impl JwtService { let token = encode(&Header::new(self.config.algorithm), &claims, &encoding_key) .map_err(|e| JwtError::TokenGeneration(e.to_string()))?; + self.store_session( + &session_id, + user_id, + device_id.clone(), + TokenType::Refresh, + self.config.refresh_token_expiry, + ) + .await?; + info!(user_id = %user_id, session_id = %session_id, "Refresh token generated"); Ok(token) @@ -333,6 +351,28 @@ impl JwtService { Ok(token_data.claims) } + async fn decode_with_key_rotation(&self, token: &str) -> Result { + let key_rotation = self.key_rotation.read().await; + match self.decode_token(token, &key_rotation.current_key) { + Ok(claims) => Ok(claims), + Err(err) => { + if let Some(ref previous_key) = key_rotation.previous_key { + debug!("Trying previous key for token rotation"); + self.decode_token(token, previous_key).or(Err(err)) + } else { + Err(err) + } + } + } + } + + async fn extract_jti(&self, token: &str) -> Option { + self.decode_with_key_rotation(token) + .await + .ok() + .map(|claims| claims.jti) + } + /// Refresh access token using refresh token pub async fn refresh_token(&self, refresh_token: &str) -> Result { let claims = self.validate_token(refresh_token).await?; @@ -361,12 +401,17 @@ impl JwtService { /// Blacklist a token pub async fn blacklist_token(&self, token: &str, reason: &str) -> Result<(), JwtError> { // Decode token to get expiration - let key_rotation = self.key_rotation.read().await; - let claims = self.decode_token(token, &key_rotation.current_key)?; + let claims = match self.decode_with_key_rotation(token).await { + Ok(claims) => claims, + Err(JwtError::TokenExpired) => { + // Already expired, no need to store blacklist metadata. + return Ok(()); + } + Err(err) => return Err(err), + }; let exp_duration = claims.exp - Utc::now().timestamp(); if exp_duration <= 0 { - // Token already expired, no need to blacklist return Ok(()); } @@ -390,13 +435,7 @@ impl JwtService { /// that a concurrent `rotate_keys()` write is never blocked for the full /// duration of the network call. pub async fn is_token_blacklisted(&self, token: &str) -> Result { - // Extract the JTI while holding the lock, then drop it immediately. - let jti_opt = { - let key_rotation = self.key_rotation.read().await; - self.decode_token(token, &key_rotation.current_key) - .ok() - .map(|c| c.jti) - }; // lock dropped here + let jti_opt = self.extract_jti(token).await; match jti_opt { Some(jti) => { @@ -420,10 +459,13 @@ impl JwtService { session_id: &str, user_id: Uuid, device_id: Option, + token_type: TokenType, + ttl: Duration, ) -> Result<(), JwtError> { let session_data = SessionData { user_id, session_id: session_id.to_string(), + token_type, device_id, created_at: Utc::now().timestamp(), last_activity: Utc::now().timestamp(), @@ -436,12 +478,8 @@ impl JwtService { .map_err(|e| JwtError::RedisError(e.to_string()))?; let mut conn = self.redis.clone(); - conn.set_ex( - &session_key, - session_json, - self.config.access_token_expiry.num_seconds() as u64, - ) - .await?; + conn.set_ex(&session_key, session_json, ttl.num_seconds() as u64) + .await?; // Add to user's active sessions set; expire the set with the refresh TTL // so it outlives individual access token sessions. @@ -481,13 +519,12 @@ impl JwtService { let updated_json = serde_json::to_string(&session).map_err(|e| JwtError::RedisError(e.to_string()))?; - // Refresh the TTL using access_token_expiry (consistent with store_session). - conn.set_ex( - &session_key, - updated_json, - self.config.access_token_expiry.num_seconds() as u64, - ) - .await?; + let ttl_seconds = match session.token_type { + TokenType::Access => self.config.access_token_expiry.num_seconds() as u64, + TokenType::Refresh => self.config.refresh_token_expiry.num_seconds() as u64, + }; + + conn.set_ex(&session_key, updated_json, ttl_seconds).await?; } Ok(()) @@ -537,6 +574,14 @@ impl JwtService { pub async fn revoke_session(&self, session_id: &str) -> Result<(), JwtError> { let session_key = format!("session:{}", session_id); let mut conn = self.redis.clone(); + + if let Some(session_json) = conn.get::<_, Option>(&session_key).await? { + if let Ok(session) = serde_json::from_str::(&session_json) { + let user_sessions_key = format!("user_sessions:{}", session.user_id); + conn.srem(&user_sessions_key, session_id).await?; + } + } + conn.del(&session_key).await?; info!(session_id = %session_id, "Session revoked"); diff --git a/backend/src/auth/jwt_service_test.rs b/backend/src/auth/jwt_service_test.rs index 1645ae6e..0fad69f3 100644 --- a/backend/src/auth/jwt_service_test.rs +++ b/backend/src/auth/jwt_service_test.rs @@ -244,6 +244,45 @@ mod integration_tests { assert!(stats.total_validated > 0); } + #[tokio::test] + async fn test_refresh_token_creates_refresh_session() { + let service = create_test_service().await; + let user_id = Uuid::new_v4(); + let roles = vec!["user".to_string()]; + + let token_pair = service + .generate_token_pair(user_id, roles.clone(), None) + .await + .unwrap(); + + let sessions = service.get_user_sessions(user_id).await.unwrap(); + assert!(sessions.iter().any(|s| s.token_type == TokenType::Refresh)); + assert!(sessions.iter().any(|s| s.token_type == TokenType::Access)); + assert_eq!(sessions.len(), 2); + + let refreshed = service.refresh_token(&token_pair.refresh_token).await; + assert!(refreshed.is_ok()); + } + + #[tokio::test] + async fn test_revoke_session_invalidates_token() { + let service = create_test_service().await; + let user_id = Uuid::new_v4(); + let roles = vec!["user".to_string()]; + + let token = service + .generate_access_token(user_id, roles.clone(), Some("device-test".to_string())) + .await + .unwrap(); + + let claims = service.validate_token(&token).await.unwrap(); + service.revoke_session(&claims.session_id).await.unwrap(); + + let result = service.validate_token(&token).await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), JwtError::SessionNotFound)); + } + #[tokio::test] async fn test_key_rotation() { let service = create_test_service().await; diff --git a/backend/src/auth/mod.rs b/backend/src/auth/mod.rs index 58f525d7..d9cc1f25 100644 --- a/backend/src/auth/mod.rs +++ b/backend/src/auth/mod.rs @@ -1,5 +1,7 @@ pub mod device_service; pub mod jwt_service; +#[cfg(test)] +mod jwt_service_test; pub mod middleware; pub use device_service::{ diff --git a/backend/src/main.rs b/backend/src/main.rs index 512763de..ac0ee9b9 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -22,6 +22,7 @@ use crate::middleware::idempotency_middleware::IdempotencyMiddleware; use crate::middleware::rate_limit::RateLimitMiddleware; use crate::middleware::security::{SecurityConfig, SecurityMiddleware}; use crate::service::ReaperService; +use crate::service::AuthService; use crate::realtime::event_bus::EventBus; use crate::realtime::session_registry::SessionRegistry; use crate::realtime::ws_broadcaster::{WsAddressBook, WsBroadcaster}; @@ -29,6 +30,46 @@ use crate::service::matchmaker::{MatchmakerService, MatchmakingConfig, EloEngine use crate::service::soroban_service::{NetworkConfig, SorobanService}; use crate::service::tournament_service::TournamentService; use crate::telemetry::init_telemetry; +use anyhow::anyhow; +use redis::AsyncCommands; + +async fn create_redis_manager(redis_url: &str) -> io::Result { + let redis_client = redis::Client::open(redis_url).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid Redis URL '{}': {}", redis_url, e), + ) + })?; + + let manager = redis::aio::ConnectionManager::new(redis_client.clone()) + .await + .map_err(|e| { + io::Error::new( + io::ErrorKind::ConnectionRefused, + format!("Failed to initialize Redis connection manager for '{}': {}", redis_url, e), + ) + })?; + + let mut probe = manager.clone(); + let pong: String = probe + .ping() + .await + .map_err(|e| { + io::Error::new( + io::ErrorKind::ConnectionRefused, + format!("Failed to ping Redis server at '{}': {}", redis_url, e), + ) + })?; + + if pong != "PONG" { + return Err(io::Error::new( + io::ErrorKind::ConnectionRefused, + format!("Unexpected Redis ping response from '{}': {}", redis_url, pong), + )); + } + + Ok(manager) +} #[tokio::main] async fn main() -> io::Result<()> { @@ -51,8 +92,6 @@ async fn main() -> io::Result<()> { let reaper = Arc::new(ReaperService::new(db_pool.clone())); reaper.run(); - // Create Redis client (placeholder) - // let redis_client = redis::Client::open(config.redis.url.clone()).unwrap(); // Spawn tournament orchestrator polling worker let _orchestrator_handle = crate::orchestrator::TournamentOrchestrator::spawn_polling_worker( db_pool.clone(), @@ -60,12 +99,9 @@ async fn main() -> io::Result<()> { ); tracing::info!("Tournament orchestrator polling worker started"); - // Create Redis connection manager - let redis_client = redis::Client::open(config.redis.url.clone()) - .expect("Failed to create Redis client"); - let redis_conn = redis::aio::ConnectionManager::new(redis_client.clone()) - .await - .expect("Failed to create Redis connection manager"); + // Create and validate Redis client before startup. + let redis_conn = create_redis_manager(&config.redis.url).await?; + let redis_data = web::Data::new(redis_conn.clone()); // Initialize matchmaking service — pass the shared ConnectionManager so // the service never opens a new connection per request. @@ -120,6 +156,9 @@ async fn main() -> io::Result<()> { let jwt_service = Arc::new(crate::auth::jwt_service::JwtService::new(jwt_config, redis_conn.clone())); let auth_guard = Arc::new(crate::realtime::auth::RealtimeAuth::new(db_pool.clone())); + // Build the HTTP AuthService (wires JWT+Redis into register/login/logout/refresh) + let auth_service = Arc::new(AuthService::new(db_pool.clone(), jwt_service.clone())); + // Start Redis Pub/Sub subscriber (broadcasts to local WebSocket actors) let broadcaster = WsBroadcaster::new( config.redis.url.clone(), @@ -140,6 +179,8 @@ async fn main() -> io::Result<()> { let server = HttpServer::new(move || { App::new() .app_data(web::Data::new(db_pool.clone())) + .app_data(redis_data.clone()) + .app_data(web::Data::new(auth_service.clone())) .app_data(web::Data::new(event_bus.clone())) .app_data(web::Data::new(session_registry.clone())) .app_data(web::Data::new(address_book.clone())) @@ -293,3 +334,35 @@ async fn main() -> io::Result<()> { server.await } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_create_redis_manager_invalid_url_fails() { + let url = "not-a-redis-url"; + let result = create_redis_manager(url).await; + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("Invalid Redis URL")); + } + + #[tokio::test] + async fn test_init_redis_connection_invalid_host_fails() { + let url = "redis://127.0.0.1:1/"; + let result = create_redis_manager(url).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_create_redis_manager_success_when_redis_available() { + let url = match std::env::var("REDIS_TEST_URL") { + Ok(url) => url, + Err(_) => return, + }; + + let result = create_redis_manager(&url).await; + assert!(result.is_ok(), "Redis startup should succeed with a running Redis server"); + } +} diff --git a/backend/src/service/auth_service.rs b/backend/src/service/auth_service.rs index c7115427..c1303629 100644 --- a/backend/src/service/auth_service.rs +++ b/backend/src/service/auth_service.rs @@ -1,33 +1,270 @@ -#![allow(dead_code)] - use crate::api_error::ApiError; +use crate::auth::jwt_service::{JwtService, TokenPair}; use crate::db::DbPool; -use crate::models::user::{User, CreateUserRequest, LoginRequest, AuthResponse}; +use crate::models::user::{AuthResponse, CreateUserRequest, LoginRequest, User, UserProfile}; +use bcrypt::{hash, verify, DEFAULT_COST}; +use chrono::Utc; +use tracing::info; use uuid::Uuid; +/// Authentication Service with JWT and Redis-backed session management #[derive(Clone)] pub struct AuthService { - #[allow(dead_code)] pool: DbPool, + jwt_service: Arc, } impl AuthService { - pub fn new(pool: DbPool) -> Self { - Self { pool } + pub fn new(pool: DbPool, jwt_service: Arc) -> Self { + Self { pool, jwt_service } + } + + /// Register a new user + pub async fn register(&self, request: CreateUserRequest) -> Result { + if request.username.is_empty() || request.email.is_empty() || request.password.is_empty() { + return Err(ApiError::bad_request("All fields are required")); + } + + if request.password.len() < 8 { + return Err(ApiError::bad_request( + "Password must be at least 8 characters", + )); + } + + let existing = sqlx::query!( + "SELECT id FROM users WHERE email = $1 OR username = $2", + request.email, + request.username + ) + .fetch_optional(&self.pool) + .await + .map_err(|e| ApiError::database_error(e))?; + + if existing.is_some() { + return Err(ApiError::bad_request( + "User with this email or username already exists", + )); + } + + let password_hash = hash(&request.password, DEFAULT_COST) + .map_err(|e| ApiError::internal_error(format!("Password hashing failed: {}", e)))?; + + let user_id = Uuid::new_v4(); + let now = Utc::now(); + + let user = sqlx::query_as!( + User, + r#" + INSERT INTO users ( + id, username, email, password_hash, is_active, is_verified, created_at, updated_at + ) + VALUES ($1, $2, $3, $4, true, false, $5, $6) + RETURNING id, username, email, password_hash, is_active, is_verified, created_at, updated_at + "#, + user_id, + request.username, + request.email, + password_hash, + now, + now + ) + .fetch_one(&self.pool) + .await + .map_err(|e| ApiError::database_error(e))?; + + let token_pair = self + .jwt_service + .generate_token_pair(user.id, vec!["user".to_string()], None) + .await + .map_err(|e| ApiError::internal_error(format!("Token generation failed: {}", e)))?; + + info!(user_id = %user.id, username = %user.username, "User registered successfully"); + + Ok(AuthResponse { + token: token_pair.access_token, + refresh_token: token_pair.refresh_token, + user: UserProfile { + id: user.id, + username: user.username, + email: user.email, + is_verified: user.is_verified, + created_at: user.created_at, + skill_score: None, + fair_play_score: None, + is_bad_actor: None, + }, + }) + } + + /// Login user and return JWT tokens + pub async fn login(&self, request: LoginRequest) -> Result { + let user = sqlx::query_as!( + User, + "SELECT id, username, email, password_hash, is_active, is_verified, created_at, updated_at, last_login_at, display_name, avatar_url, country_code, role, bio, phone_number, is_banned, device_fingerprint, stellar_account_id, stellar_public_key, total_earnings, banned_until, profile_image_url, reputation_score, skill_score, fair_play_score, is_bad_actor FROM users WHERE email = $1", + request.email + ) + .fetch_optional(&self.pool) + .await + .map_err(|e| ApiError::database_error(e))? + .ok_or_else(|| ApiError::unauthorized("Invalid email or password"))?; + + if !user.is_active { + return Err(ApiError::forbidden("Account is deactivated")); + } + + let valid = verify(&request.password, &user.password_hash) + .map_err(|e| ApiError::internal_error(format!("Password verification failed: {}", e)))?; + + if !valid { + return Err(ApiError::unauthorized("Invalid email or password")); + } + + sqlx::query!( + "UPDATE users SET last_login_at = $1 WHERE id = $2", + Utc::now(), + user.id + ) + .execute(&self.pool) + .await + .map_err(|e| ApiError::database_error(e))?; + + let token_pair = self + .jwt_service + .generate_token_pair(user.id, vec!["user".to_string()], None) + .await + .map_err(|e| ApiError::internal_error(format!("Token generation failed: {}", e)))?; + + info!(user_id = %user.id, username = %user.username, "User logged in successfully"); + + Ok(AuthResponse { + token: token_pair.access_token, + refresh_token: token_pair.refresh_token, + user: UserProfile { + id: user.id, + username: user.username, + email: user.email, + is_verified: user.is_verified, + created_at: user.created_at, + skill_score: None, + fair_play_score: None, + is_bad_actor: None, + }, + }) + } + + /// Verify JWT token and return user ID + pub async fn verify_token(&self, token: &str) -> Result { + let claims = self + .jwt_service + .validate_token(token) + .await + .map_err(|e| ApiError::unauthorized(format!("Token validation failed: {}", e)))?; + + Uuid::parse_str(&claims.sub) + .map_err(|e| ApiError::internal_error(format!("Invalid user ID in token: {}", e))) + } + + /// Refresh access token + pub async fn refresh_token(&self, refresh_token: &str) -> Result { + self.jwt_service + .refresh_token(refresh_token) + .await + .map_err(|e| ApiError::unauthorized(format!("Token refresh failed: {}", e))) + } + + /// Logout user (blacklist token, revoke session) + pub async fn logout(&self, token: &str) -> Result<(), ApiError> { + self.jwt_service + .blacklist_token(token, "User logout") + .await + .map_err(|e| ApiError::internal_error(format!("Logout failed: {}", e)))?; + + info!("User logged out successfully"); + Ok(()) + } + + /// Revoke all sessions for a user + pub async fn revoke_all_sessions(&self, user_id: Uuid) -> Result { + let count = self + .jwt_service + .revoke_user_sessions(user_id) + .await + .map_err(|e| ApiError::internal_error(format!("Session revocation failed: {}", e)))?; + + info!(user_id = %user_id, count = count, "All user sessions revoked"); + Ok(count) + } + + /// Get user by ID + pub async fn get_user(&self, user_id: Uuid) -> Result { + sqlx::query_as!( + User, + "SELECT id, username, email, password_hash, is_active, is_verified, created_at, updated_at FROM users WHERE id = $1", + user_id + ) + .fetch_optional(&self.pool) + .await + .map_err(|e| ApiError::database_error(e))? + .ok_or_else(|| ApiError::not_found("User not found")) } - pub async fn register(&self, _request: CreateUserRequest) -> Result { - // TODO: Implement user registration with database and JWT - Err(ApiError::internal_error("Auth service not yet implemented")) + /// Change user password and revoke all existing sessions + pub async fn change_password( + &self, + user_id: Uuid, + old_password: &str, + new_password: &str, + ) -> Result<(), ApiError> { + if new_password.len() < 8 { + return Err(ApiError::bad_request( + "Password must be at least 8 characters", + )); + } + + let user = self.get_user(user_id).await?; + + let valid = verify(old_password, &user.password_hash) + .map_err(|e| ApiError::internal_error(format!("Password verification failed: {}", e)))?; + + if !valid { + return Err(ApiError::unauthorized("Current password is incorrect")); + } + + let new_hash = hash(new_password, DEFAULT_COST) + .map_err(|e| ApiError::internal_error(format!("Password hashing failed: {}", e)))?; + + sqlx::query!( + "UPDATE users SET password_hash = $1, updated_at = $2 WHERE id = $3", + new_hash, + Utc::now(), + user_id + ) + .execute(&self.pool) + .await + .map_err(|e| ApiError::database_error(e))?; + + self.revoke_all_sessions(user_id).await?; + + info!(user_id = %user_id, "Password changed successfully"); + Ok(()) } +} - pub async fn login(&self, _request: LoginRequest) -> Result { - // TODO: Implement user login with password verification - Err(ApiError::internal_error("Auth service not yet implemented")) +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bcrypt_hashing() { + let password = "test_password"; + let hashed = hash(password, DEFAULT_COST).unwrap(); + assert!(verify(password, &hashed).unwrap()); + assert!(!verify("wrong_password", &hashed).unwrap()); } - pub fn verify_token(&self, _token: &str) -> Result { - // TODO: Implement JWT token verification - Err(ApiError::internal_error("Token verification not yet implemented")) + #[test] + fn test_password_length_validation() { + assert!("short".len() < 8); + assert!("long_enough_password".len() >= 8); } -} \ No newline at end of file +} diff --git a/backend/src/service/mod.rs b/backend/src/service/mod.rs index 2613d981..ba0d7fec 100644 --- a/backend/src/service/mod.rs +++ b/backend/src/service/mod.rs @@ -1,6 +1,7 @@ // Service layer module for ArenaX pub mod achievement_service; pub mod analytics_service; +pub mod auth_service; pub mod governance_service; pub mod idempotency_service; pub mod leaderboard_service; @@ -24,6 +25,7 @@ pub use governance_service::{ ProposalStatus as GovProposalStatus, }; pub use achievement_service::AchievementService; +pub use auth_service::AuthService; pub use idempotency_service::IdempotencyService; pub use leaderboard_service::LeaderboardService; pub use match_authority_service::MatchAuthorityService;