Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 70 additions & 25 deletions backend/src/auth/jwt_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ impl Default for JwtConfig {
pub struct SessionData {
pub user_id: Uuid,
pub session_id: String,
pub token_type: TokenType,
pub device_id: Option<String>,
pub created_at: i64,
pub last_activity: i64,
Expand Down Expand Up @@ -165,6 +166,7 @@ impl KeyRotation {
}

/// Main JWT Service
#[derive(Clone)]
pub struct JwtService {
config: JwtConfig,
redis: ConnectionManager,
Expand Down Expand Up @@ -209,8 +211,15 @@ impl JwtService {
let token = encode(&Header::new(self.config.algorithm), &claims, &encoding_key)
.map_err(|e| JwtError::TokenGeneration(e.to_string()))?;

// Store session in Redis
self.store_session(&session_id, user_id, device_id).await?;
// Store access session in Redis
self.store_session(
&session_id,
user_id,
device_id.clone(),
TokenType::Access,
self.config.access_token_expiry,
)
.await?;

info!(user_id = %user_id, session_id = %session_id, "Access token generated");

Expand Down Expand Up @@ -243,6 +252,15 @@ impl JwtService {
let token = encode(&Header::new(self.config.algorithm), &claims, &encoding_key)
.map_err(|e| JwtError::TokenGeneration(e.to_string()))?;

self.store_session(
&session_id,
user_id,
device_id.clone(),
TokenType::Refresh,
self.config.refresh_token_expiry,
)
.await?;

info!(user_id = %user_id, session_id = %session_id, "Refresh token generated");

Ok(token)
Expand Down Expand Up @@ -333,6 +351,28 @@ impl JwtService {
Ok(token_data.claims)
}

async fn decode_with_key_rotation(&self, token: &str) -> Result<Claims, JwtError> {
let key_rotation = self.key_rotation.read().await;
match self.decode_token(token, &key_rotation.current_key) {
Ok(claims) => Ok(claims),
Err(err) => {
if let Some(ref previous_key) = key_rotation.previous_key {
debug!("Trying previous key for token rotation");
self.decode_token(token, previous_key).or(Err(err))
} else {
Err(err)
}
}
}
}

async fn extract_jti(&self, token: &str) -> Option<String> {
self.decode_with_key_rotation(token)
.await
.ok()
.map(|claims| claims.jti)
}

/// Refresh access token using refresh token
pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenPair, JwtError> {
let claims = self.validate_token(refresh_token).await?;
Expand Down Expand Up @@ -361,12 +401,17 @@ impl JwtService {
/// Blacklist a token
pub async fn blacklist_token(&self, token: &str, reason: &str) -> Result<(), JwtError> {
// Decode token to get expiration
let key_rotation = self.key_rotation.read().await;
let claims = self.decode_token(token, &key_rotation.current_key)?;
let claims = match self.decode_with_key_rotation(token).await {
Ok(claims) => claims,
Err(JwtError::TokenExpired) => {
// Already expired, no need to store blacklist metadata.
return Ok(());
}
Err(err) => return Err(err),
};

let exp_duration = claims.exp - Utc::now().timestamp();
if exp_duration <= 0 {
// Token already expired, no need to blacklist
return Ok(());
}

Expand All @@ -390,13 +435,7 @@ impl JwtService {
/// that a concurrent `rotate_keys()` write is never blocked for the full
/// duration of the network call.
pub async fn is_token_blacklisted(&self, token: &str) -> Result<bool, JwtError> {
// Extract the JTI while holding the lock, then drop it immediately.
let jti_opt = {
let key_rotation = self.key_rotation.read().await;
self.decode_token(token, &key_rotation.current_key)
.ok()
.map(|c| c.jti)
}; // lock dropped here
let jti_opt = self.extract_jti(token).await;

match jti_opt {
Some(jti) => {
Expand All @@ -420,10 +459,13 @@ impl JwtService {
session_id: &str,
user_id: Uuid,
device_id: Option<String>,
token_type: TokenType,
ttl: Duration,
) -> Result<(), JwtError> {
let session_data = SessionData {
user_id,
session_id: session_id.to_string(),
token_type,
device_id,
created_at: Utc::now().timestamp(),
last_activity: Utc::now().timestamp(),
Expand All @@ -436,12 +478,8 @@ impl JwtService {
.map_err(|e| JwtError::RedisError(e.to_string()))?;

let mut conn = self.redis.clone();
conn.set_ex(
&session_key,
session_json,
self.config.access_token_expiry.num_seconds() as u64,
)
.await?;
conn.set_ex(&session_key, session_json, ttl.num_seconds() as u64)
.await?;

// Add to user's active sessions set; expire the set with the refresh TTL
// so it outlives individual access token sessions.
Expand Down Expand Up @@ -481,13 +519,12 @@ impl JwtService {
let updated_json =
serde_json::to_string(&session).map_err(|e| JwtError::RedisError(e.to_string()))?;

// Refresh the TTL using access_token_expiry (consistent with store_session).
conn.set_ex(
&session_key,
updated_json,
self.config.access_token_expiry.num_seconds() as u64,
)
.await?;
let ttl_seconds = match session.token_type {
TokenType::Access => self.config.access_token_expiry.num_seconds() as u64,
TokenType::Refresh => self.config.refresh_token_expiry.num_seconds() as u64,
};

conn.set_ex(&session_key, updated_json, ttl_seconds).await?;
}

Ok(())
Expand Down Expand Up @@ -537,6 +574,14 @@ impl JwtService {
pub async fn revoke_session(&self, session_id: &str) -> Result<(), JwtError> {
let session_key = format!("session:{}", session_id);
let mut conn = self.redis.clone();

if let Some(session_json) = conn.get::<_, Option<String>>(&session_key).await? {
if let Ok(session) = serde_json::from_str::<SessionData>(&session_json) {
let user_sessions_key = format!("user_sessions:{}", session.user_id);
conn.srem(&user_sessions_key, session_id).await?;
}
}

conn.del(&session_key).await?;

info!(session_id = %session_id, "Session revoked");
Expand Down
39 changes: 39 additions & 0 deletions backend/src/auth/jwt_service_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,45 @@ mod integration_tests {
assert!(stats.total_validated > 0);
}

#[tokio::test]
async fn test_refresh_token_creates_refresh_session() {
let service = create_test_service().await;
let user_id = Uuid::new_v4();
let roles = vec!["user".to_string()];

let token_pair = service
.generate_token_pair(user_id, roles.clone(), None)
.await
.unwrap();

let sessions = service.get_user_sessions(user_id).await.unwrap();
assert!(sessions.iter().any(|s| s.token_type == TokenType::Refresh));
assert!(sessions.iter().any(|s| s.token_type == TokenType::Access));
assert_eq!(sessions.len(), 2);

let refreshed = service.refresh_token(&token_pair.refresh_token).await;
assert!(refreshed.is_ok());
}

#[tokio::test]
async fn test_revoke_session_invalidates_token() {
let service = create_test_service().await;
let user_id = Uuid::new_v4();
let roles = vec!["user".to_string()];

let token = service
.generate_access_token(user_id, roles.clone(), Some("device-test".to_string()))
.await
.unwrap();

let claims = service.validate_token(&token).await.unwrap();
service.revoke_session(&claims.session_id).await.unwrap();

let result = service.validate_token(&token).await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), JwtError::SessionNotFound));
}

#[tokio::test]
async fn test_key_rotation() {
let service = create_test_service().await;
Expand Down
2 changes: 2 additions & 0 deletions backend/src/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
pub mod device_service;
pub mod jwt_service;
#[cfg(test)]
mod jwt_service_test;
pub mod middleware;

pub use device_service::{
Expand Down
89 changes: 81 additions & 8 deletions backend/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,54 @@ use crate::middleware::idempotency_middleware::IdempotencyMiddleware;
use crate::middleware::rate_limit::RateLimitMiddleware;
use crate::middleware::security::{SecurityConfig, SecurityMiddleware};
use crate::service::ReaperService;
use crate::service::AuthService;
use crate::realtime::event_bus::EventBus;
use crate::realtime::session_registry::SessionRegistry;
use crate::realtime::ws_broadcaster::{WsAddressBook, WsBroadcaster};
use crate::service::matchmaker::{MatchmakerService, MatchmakingConfig, EloEngine};
use crate::service::soroban_service::{NetworkConfig, SorobanService};
use crate::service::tournament_service::TournamentService;
use crate::telemetry::init_telemetry;
use anyhow::anyhow;
use redis::AsyncCommands;

async fn create_redis_manager(redis_url: &str) -> io::Result<redis::aio::ConnectionManager> {
let redis_client = redis::Client::open(redis_url).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("Invalid Redis URL '{}': {}", redis_url, e),
)
})?;

let manager = redis::aio::ConnectionManager::new(redis_client.clone())
.await
.map_err(|e| {
io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("Failed to initialize Redis connection manager for '{}': {}", redis_url, e),
)
})?;

let mut probe = manager.clone();
let pong: String = probe
.ping()
.await
.map_err(|e| {
io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("Failed to ping Redis server at '{}': {}", redis_url, e),
)
})?;

if pong != "PONG" {
return Err(io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("Unexpected Redis ping response from '{}': {}", redis_url, pong),
));
}

Ok(manager)
}

#[tokio::main]
async fn main() -> io::Result<()> {
Expand All @@ -51,21 +92,16 @@ async fn main() -> io::Result<()> {
let reaper = Arc::new(ReaperService::new(db_pool.clone()));
reaper.run();

// Create Redis client (placeholder)
// let redis_client = redis::Client::open(config.redis.url.clone()).unwrap();
// Spawn tournament orchestrator polling worker
let _orchestrator_handle = crate::orchestrator::TournamentOrchestrator::spawn_polling_worker(
db_pool.clone(),
60,
);
tracing::info!("Tournament orchestrator polling worker started");

// Create Redis connection manager
let redis_client = redis::Client::open(config.redis.url.clone())
.expect("Failed to create Redis client");
let redis_conn = redis::aio::ConnectionManager::new(redis_client.clone())
.await
.expect("Failed to create Redis connection manager");
// Create and validate Redis client before startup.
let redis_conn = create_redis_manager(&config.redis.url).await?;
let redis_data = web::Data::new(redis_conn.clone());

// Initialize matchmaking service — pass the shared ConnectionManager so
// the service never opens a new connection per request.
Expand Down Expand Up @@ -120,6 +156,9 @@ async fn main() -> io::Result<()> {
let jwt_service = Arc::new(crate::auth::jwt_service::JwtService::new(jwt_config, redis_conn.clone()));
let auth_guard = Arc::new(crate::realtime::auth::RealtimeAuth::new(db_pool.clone()));

// Build the HTTP AuthService (wires JWT+Redis into register/login/logout/refresh)
let auth_service = Arc::new(AuthService::new(db_pool.clone(), jwt_service.clone()));

// Start Redis Pub/Sub subscriber (broadcasts to local WebSocket actors)
let broadcaster = WsBroadcaster::new(
config.redis.url.clone(),
Expand All @@ -140,6 +179,8 @@ async fn main() -> io::Result<()> {
let server = HttpServer::new(move || {
App::new()
.app_data(web::Data::new(db_pool.clone()))
.app_data(redis_data.clone())
.app_data(web::Data::new(auth_service.clone()))
.app_data(web::Data::new(event_bus.clone()))
.app_data(web::Data::new(session_registry.clone()))
.app_data(web::Data::new(address_book.clone()))
Expand Down Expand Up @@ -293,3 +334,35 @@ async fn main() -> io::Result<()> {

server.await
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
async fn test_create_redis_manager_invalid_url_fails() {
let url = "not-a-redis-url";
let result = create_redis_manager(url).await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("Invalid Redis URL"));
}

#[tokio::test]
async fn test_init_redis_connection_invalid_host_fails() {
let url = "redis://127.0.0.1:1/";
let result = create_redis_manager(url).await;
assert!(result.is_err());
}

#[tokio::test]
async fn test_create_redis_manager_success_when_redis_available() {
let url = match std::env::var("REDIS_TEST_URL") {
Ok(url) => url,
Err(_) => return,
};

let result = create_redis_manager(&url).await;
assert!(result.is_ok(), "Redis startup should succeed with a running Redis server");
}
}
Loading