From f9498a6268ef6758b9d4f72a7f1b880a43bd4012 Mon Sep 17 00:00:00 2001 From: ChengWei Date: Tue, 26 May 2026 23:48:04 +0800 Subject: [PATCH 1/2] feat(auth): add CSRF protection to authentication flows - Add CSRF token generation and validation - Integrate CSRF handling into auth and session flows - Add endpoint for retrieving CSRF tokens - Update related schemas, services, and tests --- backend/api/auth/controller.py | 138 +++++++++++++++------- backend/api/auth/schema.py | 8 +- backend/api/auth/services.py | 86 +++++++++++++- backend/core/config.py | 1 + backend/core/security.py | 23 +++- backend/tests/api/auth/test_controller.py | 107 ++++++++++++++--- backend/tests/api/auth/test_service.py | 5 +- 7 files changed, 300 insertions(+), 68 deletions(-) diff --git a/backend/api/auth/controller.py b/backend/api/auth/controller.py index 5a255d4..44f036c 100644 --- a/backend/api/auth/controller.py +++ b/backend/api/auth/controller.py @@ -16,6 +16,7 @@ UserLogin, UserLoginResponse, TokenResponse, + CsrfTokenResponse, ResetPasswordRequest, TokenValidationResponse, LogoutRequest, @@ -37,6 +38,8 @@ get_password_reset_cooldown, verify_email, resend_verification_email, + get_or_create_csrf_token, + verify_csrf_token, ) from utils.custom_exception import ( ConflictException, @@ -80,15 +83,10 @@ async def register_api( user = result["user"] session_id = result["session_id"] access_token = result["access_token"] - - response.set_cookie( - key="session_id", - value=session_id, - httponly=settings.COOKIE_HTTPONLY, - secure=settings.COOKIE_SECURE, - samesite=settings.COOKIE_SAMESITE, - max_age=settings.SESSION_EXPIRE_MINUTES * 60 - ) + csrf_token = result["csrf_token"] + + _set_session_cookie(response, session_id) + _set_csrf_cookie(response, csrf_token) user_response = UserResponse( id=user.id, @@ -142,7 +140,8 @@ async def login_api( user = result["user"] session_id = result["session_id"] access_token = result["access_token"] - + csrf_token = result["csrf_token"] + user_response = UserResponse( id=user.id, first_name=user.first_name, @@ -150,15 +149,9 @@ async def login_api( email=user.email, phone=user.phone ) - - response.set_cookie( - key="session_id", - value=session_id, - httponly=settings.COOKIE_HTTPONLY, - secure=settings.COOKIE_SECURE, - samesite=settings.COOKIE_SAMESITE, - max_age=settings.SESSION_EXPIRE_MINUTES * 60 - ) + + _set_session_cookie(response, session_id) + _set_csrf_cookie(response, csrf_token) response_data = UserLoginResponse( access_token=access_token, @@ -208,7 +201,7 @@ async def logout_api( # Logout from all devices if await logout_all_devices(db, redis_client, user_id): if response: - response.delete_cookie("session_id") + _clear_auth_cookies(response) return APIResponse(code=200, message="User logged out successfully") else: # Logout from current device only @@ -217,7 +210,7 @@ async def logout_api( if await logout(db, redis_client, user_id, session_id): if response: - response.delete_cookie("session_id") + _clear_auth_cookies(response) return APIResponse(code=200, message="User logged out successfully") except AuthenticationException: raise HTTPException(status_code=401, detail="Invalid or expired session") @@ -231,7 +224,7 @@ async def logout_api( summary="Refresh token", responses=parse_responses({ 200: ("Token refreshed successfully", TokenResponse), - 401: ("Invalid or expired session", None) + 401: ("Invalid or expired session / Invalid or expired CSRF token", None) }, common_responses) ) async def token_api( @@ -239,18 +232,59 @@ async def token_api( db: AsyncSession = Depends(get_db), redis_client = Depends(get_redis) ): - """Use session_id Cookie to get new access_token""" + """Validate CSRF token then use session_id cookie to issue new access_token""" try: + csrf_token = request.headers.get("X-CSRF-Token") + sid_from_csrf = await verify_csrf_token(redis_client, csrf_token) + session_id = request.cookies.get("session_id") - if not session_id: + if not session_id or session_id != sid_from_csrf: raise AuthenticationException("Invalid or expired session") + new_access_token = await token(db, redis_client, session_id) response_data = TokenResponse( access_token=new_access_token, expires_at=datetime.now().astimezone() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) ) return APIResponse(code=200, message="Token refreshed successfully", data=response_data) - except (AuthenticationException, NotFoundException): + except AuthenticationException as e: + raise HTTPException(status_code=401, detail=e.message) + except NotFoundException: + raise HTTPException(status_code=401, detail="Invalid or expired session") + except Exception: + raise HTTPException(status_code=500) + +@router.post( + "/csrf-token", + response_model=APIResponse[CsrfTokenResponse], + response_model_exclude_none=True, + summary="Get CSRF token", + responses=parse_responses({ + 200: ("CSRF token retrieved successfully", CsrfTokenResponse), + 401: ("Invalid or expired session", None), + }, common_responses), +) +async def csrf_token_api( + request: Request, + response: Response, + redis_client=Depends(get_redis), +): + """Issue CSRF token for the current session (cookie-based).""" + try: + session_id = request.cookies.get("session_id") + if not session_id: + raise AuthenticationException("Invalid or expired session") + + csrf_token = await get_or_create_csrf_token(redis_client, session_id) + _set_csrf_cookie(response, csrf_token) + + response_data = CsrfTokenResponse( + csrf_token=csrf_token, + expires_at=datetime.now().astimezone() + + timedelta(minutes=settings.CSRF_TOKEN_EXPIRE_MINUTES), + ) + return APIResponse(code=200, message="CSRF token retrieved successfully", data=response_data) + except AuthenticationException: raise HTTPException(status_code=401, detail="Invalid or expired session") except Exception: raise HTTPException(status_code=500) @@ -284,6 +318,7 @@ async def reset_password_api( user = result["user"] session_id = result["session_id"] access_token = result["access_token"] + csrf_token = result["csrf_token"] user_response = UserResponse( id=user.id, @@ -298,15 +333,9 @@ async def reset_password_api( expires_at=datetime.now().astimezone() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES), user=user_response ) - - response.set_cookie( - key="session_id", - value=session_id, - httponly=settings.COOKIE_HTTPONLY, - secure=settings.COOKIE_SECURE, - samesite=settings.COOKIE_SAMESITE, - max_age=settings.SESSION_EXPIRE_MINUTES * 60 - ) + + _set_session_cookie(response, session_id) + _set_csrf_cookie(response, csrf_token) return APIResponse(code=200, message="Password reset successfully", data=response_data) except AuthenticationException: @@ -431,6 +460,7 @@ async def verify_email_api( user = result["user"] session_id = result["session_id"] access_token = result["access_token"] + csrf_token = result["csrf_token"] user_response = UserResponse( id=user.id, @@ -445,15 +475,9 @@ async def verify_email_api( expires_at=datetime.now().astimezone() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES), user=user_response ) - - response.set_cookie( - key="session_id", - value=session_id, - httponly=settings.COOKIE_HTTPONLY, - secure=settings.COOKIE_SECURE, - samesite=settings.COOKIE_SAMESITE, - max_age=settings.SESSION_EXPIRE_MINUTES * 60 - ) + + _set_session_cookie(response, session_id) + _set_csrf_cookie(response, csrf_token) return APIResponse(code=200, message="Email verified successfully", data=response_data) except AuthenticationException: @@ -530,4 +554,30 @@ async def get_email_verification_cooldown_api( ) return APIResponse(code=200, message="Cooldown status retrieved", data=response_data) except Exception: - raise HTTPException(status_code=500) \ No newline at end of file + raise HTTPException(status_code=500) + +def _set_session_cookie(response: Response, session_id: str) -> None: + response.set_cookie( + key="session_id", + value=session_id, + httponly=settings.COOKIE_HTTPONLY, + secure=settings.COOKIE_SECURE, + samesite=settings.COOKIE_SAMESITE, + max_age=settings.SESSION_EXPIRE_MINUTES * 60, + ) + + +def _set_csrf_cookie(response: Response, csrf_token: str) -> None: + response.set_cookie( + key="csrf_token", + value=csrf_token, + httponly=False, + secure=settings.COOKIE_SECURE, + samesite=settings.COOKIE_SAMESITE, + max_age=settings.CSRF_TOKEN_EXPIRE_MINUTES * 60, + ) + + +def _clear_auth_cookies(response: Response) -> None: + response.delete_cookie("session_id") + response.delete_cookie("csrf_token") \ No newline at end of file diff --git a/backend/api/auth/schema.py b/backend/api/auth/schema.py index 4b83b5a..bb29c01 100644 --- a/backend/api/auth/schema.py +++ b/backend/api/auth/schema.py @@ -5,12 +5,14 @@ class LoginResult(TypedDict): user: "UserResponse" - session_id: str = Field(..., description="Session ID") + session_id: str = Field(..., description="Session ID") access_token: str = Field(..., description="JWT access token") + csrf_token: str = Field(..., description="CSRF token") class SessionResult(TypedDict): session_id: str = Field(..., description="Session ID") access_token: str = Field(..., description="JWT access token") + csrf_token: str = Field(..., description="CSRF token") class UserRegister(BaseModel): first_name: str = Field(..., min_length=1, max_length=50, description="First name") @@ -39,6 +41,10 @@ class TokenResponse(BaseModel): access_token: str = Field(..., description="JWT access token") expires_at: datetime = Field(..., description="Token expiration time") +class CsrfTokenResponse(BaseModel): + csrf_token: str = Field(..., description="CSRF token") + expires_at: datetime = Field(..., description="CSRF token expiration time") + class ActionRequiredResponse(BaseModel): action_type: str = Field(..., description="Action type for frontend routing: 'password_reset' or 'email_verification'") token: Optional[str] = Field(default=None, description="Token for the password reset") diff --git a/backend/api/auth/services.py b/backend/api/auth/services.py index c83c583..41726a1 100644 --- a/backend/api/auth/services.py +++ b/backend/api/auth/services.py @@ -1,5 +1,6 @@ import ast import redis +from jose import jwt, JWTError from typing import Optional from sqlalchemy import select, update, or_ from urllib.parse import quote @@ -28,6 +29,7 @@ from core.security import ( verify_password, create_access_token, + create_csrf_token, hash_password, extend_session_ttl, clear_user_all_sessions, @@ -116,7 +118,8 @@ async def register( return { "user": user, "session_id": session_result["session_id"], - "access_token": session_result["access_token"] + "access_token": session_result["access_token"], + "csrf_token": session_result["csrf_token"], } async def login( @@ -261,7 +264,8 @@ async def login( return { "user": user, "session_id": session_result["session_id"], - "access_token": session_result["access_token"] + "access_token": session_result["access_token"], + "csrf_token": session_result["csrf_token"], } async def logout( @@ -273,7 +277,7 @@ async def logout( """User logout""" try: redis_key = f"session:{session_id}" - await redis_client.delete(redis_key) + await redis_client.delete(redis_key, f"csrf:{session_id}") result = await db.execute( select(UserSessions).where( @@ -301,6 +305,57 @@ async def logout_all_devices( except Exception: raise ServerException("Failed to logout all devices") +async def get_or_create_csrf_token( + redis_client: redis.Redis, + session_id: str, +) -> str: + """Return existing CSRF token or create a new one (does not extend TTL).""" + session_raw = await redis_client.get(f"session:{session_id}") + if not session_raw: + raise AuthenticationException("Invalid or expired session") + + existing = await redis_client.get(_csrf_redis_key(session_id)) + if existing: + return existing.decode() if isinstance(existing, bytes) else existing + + return await _create_csrf_token_for_session(redis_client, session_id) + + +async def verify_csrf_token( + redis_client: redis.Redis, + csrf_token: Optional[str], +) -> str: + """Validate CSRF token and return session_id.""" + if not csrf_token: + raise AuthenticationException("Invalid or expired CSRF token") + + try: + payload = jwt.decode( + csrf_token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM], + ) + except JWTError: + raise AuthenticationException("Invalid or expired CSRF token") + + if payload.get("token_type") != "csrf": + raise AuthenticationException("Invalid or expired CSRF token") + + session_id = payload.get("sid") + if not session_id: + raise AuthenticationException("Invalid or expired CSRF token") + + stored = await redis_client.get(_csrf_redis_key(session_id)) + if not stored: + raise AuthenticationException("Invalid or expired CSRF token") + + stored_token = stored.decode() if isinstance(stored, bytes) else stored + if stored_token != csrf_token: + raise AuthenticationException("Invalid or expired CSRF token") + + return session_id + + async def token( db: AsyncSession, redis_client: redis.Redis, @@ -391,7 +446,8 @@ async def reset_password( return { "user": user, "session_id": session_result["session_id"], - "access_token": session_result["access_token"] + "access_token": session_result["access_token"], + "csrf_token": session_result["csrf_token"], } except (AuthenticationException, NotFoundException): @@ -615,10 +671,13 @@ async def _create_user_session( settings.SESSION_EXPIRE_MINUTES * 60, str(session_data) ) + + csrf_token = await _create_csrf_token_for_session(redis_client, session_id) return { "session_id": session_id, - "access_token": access_token + "access_token": access_token, + "csrf_token": csrf_token, } except Exception as e: raise ServerException(f"Failed to create user session: {str(e)}") @@ -762,7 +821,8 @@ async def verify_email( return { "user": user, "session_id": session_result["session_id"], - "access_token": session_result["access_token"] + "access_token": session_result["access_token"], + "csrf_token": session_result["csrf_token"], } except (AuthenticationException, NotFoundException, ConflictException): @@ -862,6 +922,20 @@ async def resend_verification_email( except Exception as e: raise ServerException(f"Failed to send verification email: {str(e)}") + +def _csrf_redis_key(session_id: str) -> str: + return f"csrf:{session_id}" + + +async def _create_csrf_token_for_session( + redis_client: redis.Redis, + session_id: str, +) -> str: + csrf_token = await create_csrf_token(session_id) + ttl = settings.CSRF_TOKEN_EXPIRE_MINUTES * 60 + await redis_client.setex(_csrf_redis_key(session_id), ttl, csrf_token) + return csrf_token + async def _send_registration_verification_email( db: AsyncSession, mailer: SMTPMailer, diff --git a/backend/core/config.py b/backend/core/config.py index 40cdf16..e6830fd 100644 --- a/backend/core/config.py +++ b/backend/core/config.py @@ -50,6 +50,7 @@ class Settings(BaseSettings): # Session settings SESSION_EXPIRE_MINUTES: int = 10080 # 7 days + CSRF_TOKEN_EXPIRE_MINUTES: int = 30 # Cookie settings COOKIE_SECURE: bool = SSL_ENABLE diff --git a/backend/core/security.py b/backend/core/security.py index fb7d573..c7d275d 100644 --- a/backend/core/security.py +++ b/backend/core/security.py @@ -55,6 +55,24 @@ async def create_password_reset_token(user_id: str, email: str) -> str: except Exception as e: raise ServerException(f"Failed to create password reset token: {str(e)}") +async def create_csrf_token(session_id: str) -> str: + """Create CSRF token bound to a session""" + try: + now = datetime.now().astimezone() + expire = now + timedelta(minutes=settings.CSRF_TOKEN_EXPIRE_MINUTES) + + payload = { + "sid": session_id, + "token_type": "csrf", + "iat": now, + "exp": expire, + } + + return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM) + + except Exception as e: + raise ServerException(f"Failed to create CSRF token: {str(e)}") + async def create_email_verification_token(user_id: str, email: str, token_type: str) -> str: """Create email verification token""" try: @@ -275,7 +293,10 @@ async def clear_user_all_sessions(db: AsyncSession, redis_client: redis.Redis, u session_ids = result.scalars().all() if session_ids: - redis_keys = [f"session:{sid}" for sid in session_ids] + redis_keys = [] + for sid in session_ids: + redis_keys.append(f"session:{sid}") + redis_keys.append(f"csrf:{sid}") await redis_client.delete(*redis_keys) return True diff --git a/backend/tests/api/auth/test_controller.py b/backend/tests/api/auth/test_controller.py index a6fceb8..8a5d10c 100644 --- a/backend/tests/api/auth/test_controller.py +++ b/backend/tests/api/auth/test_controller.py @@ -52,6 +52,7 @@ async def test_register_success(self, client: AsyncClient): )(), "session_id": "test-session-id", "access_token": "test-access-token", + "csrf_token": "test-csrf-token", } response = await client.post("/api/auth/register", json=register_data) @@ -61,6 +62,7 @@ async def test_register_success(self, client: AsyncClient): assert data["code"] == 200 assert data["message"] == "User registered successfully" assert "session_id" in response.cookies + assert "csrf_token" in response.cookies @pytest.mark.asyncio async def test_register_email_already_exists(self, client: AsyncClient): @@ -165,6 +167,7 @@ async def test_login_success(self, client: AsyncClient): )(), "session_id": "test-session-id", "access_token": "test-access-token", + "csrf_token": "test-csrf-token", } response = await client.post("/api/auth/login", json=login_data) @@ -174,6 +177,7 @@ async def test_login_success(self, client: AsyncClient): assert data["code"] == 200 assert data["message"] == "User logged in successfully" assert "session_id" in response.cookies + assert "csrf_token" in response.cookies @pytest.mark.asyncio async def test_login_invalid_credentials(self, client: AsyncClient): @@ -290,11 +294,16 @@ async def mock_verify_token(): @pytest.mark.asyncio async def test_token_refresh_success(self, client: AsyncClient): """Test successful token refresh""" - with patch("api.auth.controller.token") as mock_token: + with patch("api.auth.controller.verify_csrf_token") as mock_verify_csrf, patch( + "api.auth.controller.token" + ) as mock_token: + mock_verify_csrf.return_value = "test-session-id" mock_token.return_value = "new-access-token" response = await client.post( - "/api/auth/token", cookies={"session_id": "test-session-id"} + "/api/auth/token", + cookies={"session_id": "test-session-id"}, + headers={"X-CSRF-Token": "test-csrf-token"}, ) assert response.status_code == 200 @@ -304,40 +313,70 @@ async def test_token_refresh_success(self, client: AsyncClient): assert data["data"]["access_token"] == "new-access-token" @pytest.mark.asyncio - async def test_token_refresh_no_session(self, client: AsyncClient): - """Test token refresh without session cookie""" - response = await client.post("/api/auth/token") + async def test_token_refresh_no_csrf_header(self, client: AsyncClient): + """Test token refresh without CSRF header""" + response = await client.post( + "/api/auth/token", + cookies={"session_id": "test-session-id"}, + ) assert response.status_code == 401 data = response.json() assert data["code"] == 401 - assert data["message"] == "Invalid or expired session" + assert data["message"] == "Invalid or expired CSRF token" + + @pytest.mark.asyncio + async def test_token_refresh_invalid_csrf(self, client: AsyncClient): + """Test token refresh with invalid CSRF token""" + with patch("api.auth.controller.verify_csrf_token") as mock_verify_csrf: + mock_verify_csrf.side_effect = AuthenticationException( + "Invalid or expired CSRF token" + ) + + response = await client.post( + "/api/auth/token", + cookies={"session_id": "test-session-id"}, + headers={"X-CSRF-Token": "invalid-csrf-token"}, + ) + + assert response.status_code == 401 + data = response.json() + assert data["code"] == 401 + assert data["message"] == "Invalid or expired CSRF token" @pytest.mark.asyncio async def test_token_refresh_invalid_session(self, client: AsyncClient): """Test token refresh with invalid session""" - with patch("api.auth.controller.token") as mock_token: - mock_token.side_effect = AuthenticationException( - "Invalid or expired session" - ) + with patch("api.auth.controller.verify_csrf_token") as mock_verify_csrf, patch( + "api.auth.controller.token" + ) as mock_token: + mock_verify_csrf.return_value = "other-session-id" response = await client.post( - "/api/auth/token", cookies={"session_id": "invalid-session-id"} + "/api/auth/token", + cookies={"session_id": "invalid-session-id"}, + headers={"X-CSRF-Token": "test-csrf-token"}, ) assert response.status_code == 401 data = response.json() assert data["code"] == 401 assert data["message"] == "Invalid or expired session" + mock_token.assert_not_called() @pytest.mark.asyncio async def test_token_refresh_user_not_found(self, client: AsyncClient): """Test token refresh with user not found""" - with patch("api.auth.controller.token") as mock_token: + with patch("api.auth.controller.verify_csrf_token") as mock_verify_csrf, patch( + "api.auth.controller.token" + ) as mock_token: + mock_verify_csrf.return_value = "test-session-id" mock_token.side_effect = NotFoundException("User not found") response = await client.post( - "/api/auth/token", cookies={"session_id": "test-session-id"} + "/api/auth/token", + cookies={"session_id": "test-session-id"}, + headers={"X-CSRF-Token": "test-csrf-token"}, ) assert response.status_code == 401 @@ -348,15 +387,48 @@ async def test_token_refresh_user_not_found(self, client: AsyncClient): @pytest.mark.asyncio async def test_token_refresh_server_error(self, client: AsyncClient): """Test token refresh with server error""" - with patch("api.auth.controller.token") as mock_token: + with patch("api.auth.controller.verify_csrf_token") as mock_verify_csrf, patch( + "api.auth.controller.token" + ) as mock_token: + mock_verify_csrf.return_value = "test-session-id" mock_token.side_effect = Exception("Database error") response = await client.post( - "/api/auth/token", cookies={"session_id": "test-session-id"} + "/api/auth/token", + cookies={"session_id": "test-session-id"}, + headers={"X-CSRF-Token": "test-csrf-token"}, ) assert response.status_code == 500 + @pytest.mark.asyncio + async def test_csrf_token_success(self, client: AsyncClient): + """Test successful CSRF token retrieval""" + with patch("api.auth.controller.get_or_create_csrf_token") as mock_get_csrf: + mock_get_csrf.return_value = "test-csrf-token" + + response = await client.post( + "/api/auth/csrf-token", + cookies={"session_id": "test-session-id"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["code"] == 200 + assert data["message"] == "CSRF token retrieved successfully" + assert data["data"]["csrf_token"] == "test-csrf-token" + assert "csrf_token" in response.cookies + + @pytest.mark.asyncio + async def test_csrf_token_no_session(self, client: AsyncClient): + """Test CSRF token retrieval without session cookie""" + response = await client.post("/api/auth/csrf-token") + + assert response.status_code == 401 + data = response.json() + assert data["code"] == 401 + assert data["message"] == "Invalid or expired session" + @pytest.mark.asyncio async def test_login_password_reset_required(self, client: AsyncClient): """Test login with password reset required""" @@ -426,6 +498,7 @@ async def mock_verify_password_reset_token(): )(), "session_id": "test-session-id", "access_token": "test-access-token", + "csrf_token": "test-csrf-token", } app.dependency_overrides[verify_password_reset_token] = ( @@ -670,6 +743,7 @@ async def mock_verify_password_reset_token(): )(), "session_id": "test-session-id", "access_token": "test-access-token", + "csrf_token": "test-csrf-token", } app.dependency_overrides[verify_password_reset_token] = ( @@ -684,6 +758,7 @@ async def mock_verify_password_reset_token(): ) assert response.status_code == 200 assert "session_id" in response.cookies + assert "csrf_token" in response.cookies finally: app.dependency_overrides.pop(verify_password_reset_token, None) @@ -816,6 +891,7 @@ async def mock_verify_email_token(): )(), "session_id": "test-session-id", "access_token": "test-access-token", + "csrf_token": "test-csrf-token", } app.dependency_overrides[verify_email_verification_token] = ( mock_verify_email_token @@ -831,6 +907,7 @@ async def mock_verify_email_token(): assert data["code"] == 200 assert data["message"] == "Email verified successfully" assert "session_id" in response.cookies + assert "csrf_token" in response.cookies finally: app.dependency_overrides.pop(verify_email_verification_token, None) diff --git a/backend/tests/api/auth/test_service.py b/backend/tests/api/auth/test_service.py index 3845c3c..a2917ee 100644 --- a/backend/tests/api/auth/test_service.py +++ b/backend/tests/api/auth/test_service.py @@ -237,7 +237,10 @@ async def test_logout_success( ) assert result is True - mock_redis.delete.assert_called_once_with(f"session:{test_user_session.id}") + mock_redis.delete.assert_called_once_with( + f"session:{test_user_session.id}", + f"csrf:{test_user_session.id}", + ) @pytest.mark.asyncio async def test_logout_all_devices_success( From 940e6dd59bb72bc01fd80de9bd6bccb57267eaeb Mon Sep 17 00:00:00 2001 From: ChengWei Date: Tue, 26 May 2026 23:50:02 +0800 Subject: [PATCH 2/2] feat(frontend): add CSRF validation and improve auth error handling - Add CSRF token retrieval and validation in authentication flow - Improve handling of CSRF and session invalidation errors - Introduce utilities for error parsing and CSRF token management - Update services and hooks to support enhanced auth flow --- frontend/src/hooks/useAuth.jsx | 154 +++++++++++++++----------- frontend/src/lib/authErrors.js | 14 +++ frontend/src/lib/cookies.js | 16 +++ frontend/src/services/auth.service.js | 31 +++++- 4 files changed, 146 insertions(+), 69 deletions(-) create mode 100644 frontend/src/lib/authErrors.js create mode 100644 frontend/src/lib/cookies.js diff --git a/frontend/src/hooks/useAuth.jsx b/frontend/src/hooks/useAuth.jsx index 6cff73c..bec88bd 100644 --- a/frontend/src/hooks/useAuth.jsx +++ b/frontend/src/hooks/useAuth.jsx @@ -4,6 +4,31 @@ import authService from '@/services/auth.service'; import accountService from '@/services/account.service'; import rolesService from '@/services/roles.service'; import { debugError } from '@/lib/utils'; +import { getCsrfTokenFromCookie } from '@/lib/cookies'; +import { + getApiErrorMessage, + isCsrfInvalidError, + isSessionInvalidError, +} from '@/lib/authErrors'; + +function parseAccessToken(result) { + if (!result) { + return null; + } + if (typeof result === 'string') { + return result; + } + if (result.access_token) { + return result.access_token; + } + if (result.data?.access_token) { + return result.data.access_token; + } + if (result.data && typeof result.data === 'string') { + return result.data; + } + return null; +} export const useAuth = () => { const context = useAuthContext(); @@ -262,52 +287,81 @@ export const useAuth = () => { } }, [clearAuth]); - // Get authentication token from server - const getToken = useCallback(async (isInit = false, skipProfile = false) => { + // Get authentication token from server (with CSRF validation and retry) + const getToken = useCallback(async (isInit = false, _skipProfile = false) => { + const applyAccessToken = (accessToken) => { + if (state.user) { + loginSuccess(state.user, accessToken); + } else { + setToken(accessToken); + } + }; + + const attemptRefresh = async (csrfToken) => { + const result = await authService.getToken({ + showErrorToast: false, + csrfToken: csrfToken ?? getCsrfTokenFromCookie() ?? '', + }); + const accessToken = parseAccessToken(result); + if (!accessToken) { + throw new Error('Unable to get token'); + } + applyAccessToken(accessToken); + return { success: true, token: accessToken }; + }; + + const invalidateSession = async () => { + await logout(true); + }; + try { setLoading(true); - - const result = await authService.getToken({ showErrorToast: false }); - - let access_token; - if (typeof result === 'string') { - access_token = result; - } else if (result?.data) { - access_token = result.data.access_token || result.data; - } else if (result?.access_token) { - access_token = result.access_token; - } - - if (access_token) { - // Preserve user if exists, otherwise just set token - if (state.user) { - loginSuccess(state.user, access_token); - } else { - setToken(access_token); + + try { + const success = await attemptRefresh(); + setLoading(false); + return success; + } catch (error) { + if (isCsrfInvalidError(error)) { + try { + const csrfResult = await authService.getCsrfToken({ showErrorToast: false }); + const newCsrf = csrfResult?.csrf_token ?? getCsrfTokenFromCookie(); + if (!newCsrf) { + throw error; + } + const success = await attemptRefresh(newCsrf); + setLoading(false); + return success; + } catch (retryError) { + if (isCsrfInvalidError(retryError) || isSessionInvalidError(retryError)) { + await invalidateSession(); + setLoading(false); + return { success: false, error: getApiErrorMessage(retryError) }; + } + throw retryError; + } } + + if (isSessionInvalidError(error)) { + await invalidateSession(); + setLoading(false); + return { success: false, error: getApiErrorMessage(error) }; + } + setLoading(false); - return { success: true, token: access_token }; - } - - setLoading(false); - - // Only clear auth if not initializing and no token exists - if (!isInit && !state.token) { - clearAuth(); + if (!isInit && !state.token) { + clearAuth(); + } + return { success: false, error: getApiErrorMessage(error) }; } - - return { success: false, error: 'Unable to get token' }; } catch (error) { setLoading(false); - - // Only clear auth if not initializing and no token exists if (!isInit && !state.token) { clearAuth(); } - - return { success: false, error: error.message }; + return { success: false, error: getApiErrorMessage(error) }; } - }, [setLoading, setToken, clearAuth, state.token, state.user, loginSuccess]); + }, [setLoading, setToken, clearAuth, state.token, state.user, loginSuccess, logout]); // Reset password and auto-login user const resetPassword = useCallback(async (newPassword, resetToken) => { @@ -468,40 +522,14 @@ export const useAuth = () => { const init = async () => { try { - setLoading(true); - - let result; - try { - result = await authService.getToken(); - } catch (tokenError) { - result = null; - } - - let access_token; - if (result) { - if (typeof result === 'string') { - access_token = result; - } else if (result?.data) { - access_token = result.data.access_token || result.data; - } else if (result?.access_token) { - access_token = result.access_token; - } - } - - if (access_token) { - setToken(access_token); - } - - setLoading(false); - } catch (error) { - setLoading(false); + await getToken(true); } finally { isInitializingRef.current = false; } }; init(); - }, [setLoading, setToken]); + }, [getToken]); // Auto-load profile and permissions when authenticated useEffect(() => { diff --git a/frontend/src/lib/authErrors.js b/frontend/src/lib/authErrors.js new file mode 100644 index 0000000..a458ae0 --- /dev/null +++ b/frontend/src/lib/authErrors.js @@ -0,0 +1,14 @@ +export const CSRF_INVALID_MSG = 'Invalid or expired CSRF token'; +export const SESSION_INVALID_MSG = 'Invalid or expired session'; + +export function getApiErrorMessage(error) { + return error?.response?.data?.message ?? error?.message ?? ''; +} + +export function isCsrfInvalidError(error) { + return getApiErrorMessage(error) === CSRF_INVALID_MSG; +} + +export function isSessionInvalidError(error) { + return getApiErrorMessage(error) === SESSION_INVALID_MSG; +} \ No newline at end of file diff --git a/frontend/src/lib/cookies.js b/frontend/src/lib/cookies.js new file mode 100644 index 0000000..12932c9 --- /dev/null +++ b/frontend/src/lib/cookies.js @@ -0,0 +1,16 @@ +export const CSRF_COOKIE_NAME = 'csrf_token'; + +export function getCookie(name) { + if (typeof document === 'undefined' || !document.cookie) { + return null; + } + + const escapedName = name.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); + const match = document.cookie.match(new RegExp(`(?:^|;\\s*)${escapedName}=([^;]*)`)); + + return match ? decodeURIComponent(match[1]) : null; +} + +export function getCsrfTokenFromCookie() { + return getCookie(CSRF_COOKIE_NAME); +} \ No newline at end of file diff --git a/frontend/src/services/auth.service.js b/frontend/src/services/auth.service.js index 2b0d552..95c2cb9 100644 --- a/frontend/src/services/auth.service.js +++ b/frontend/src/services/auth.service.js @@ -1,5 +1,6 @@ import apiService from './api.service'; import i18n from '@/i18n'; +import { getCsrfTokenFromCookie } from '@/lib/cookies'; const BASE_AUTH = '/auth'; @@ -46,19 +47,37 @@ export const authService = { ...config, }), - // Get token - getToken: (config = {}) => - apiService.post(`${BASE_AUTH}/token`, {}, { + // Get CSRF token (session cookie required) + getCsrfToken: (config = {}) => + apiService.post(`${BASE_AUTH}/csrf-token`, {}, { + noToken: true, + retryOn401: false, + showErrorToast: false, + showSuccessToast: false, + ...config, + }), + + // Get token (requires session_id cookie and X-CSRF-Token header) + getToken: (config = {}) => { + const { csrfToken, ...restConfig } = config; + const headers = { + ...(restConfig.headers || {}), + 'X-CSRF-Token': csrfToken ?? getCsrfTokenFromCookie() ?? '', + }; + + return apiService.post(`${BASE_AUTH}/token`, {}, { noToken: true, retryOn401: false, showErrorToast: false, showSuccessToast: false, messageMap: { 401: i18n.t('pages.auth.login.messages.invalidCredentials', 'Invalid email or password'), - ...config.messageMap, + ...restConfig.messageMap, }, - ...config - }), + headers, + ...restConfig, + }); + }, // Reset password resetPassword: (newPassword, resetToken, config = {}) =>