Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 94 additions & 44 deletions backend/api/auth/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
UserLogin,
UserLoginResponse,
TokenResponse,
CsrfTokenResponse,
ResetPasswordRequest,
TokenValidationResponse,
LogoutRequest,
Expand All @@ -37,6 +38,8 @@
get_password_reset_cooldown,
verify_email,
resend_verification_email,
get_or_create_csrf_token,
verify_csrf_token,
)
from utils.custom_exception import (
ConflictException,
Expand Down Expand Up @@ -80,15 +83,10 @@ async def register_api(
user = result["user"]
session_id = result["session_id"]
access_token = result["access_token"]

response.set_cookie(
key="session_id",
value=session_id,
httponly=settings.COOKIE_HTTPONLY,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
max_age=settings.SESSION_EXPIRE_MINUTES * 60
)
csrf_token = result["csrf_token"]

_set_session_cookie(response, session_id)
_set_csrf_cookie(response, csrf_token)

user_response = UserResponse(
id=user.id,
Expand Down Expand Up @@ -142,23 +140,18 @@ async def login_api(
user = result["user"]
session_id = result["session_id"]
access_token = result["access_token"]

csrf_token = result["csrf_token"]

user_response = UserResponse(
id=user.id,
first_name=user.first_name,
last_name=user.last_name,
email=user.email,
phone=user.phone
)

response.set_cookie(
key="session_id",
value=session_id,
httponly=settings.COOKIE_HTTPONLY,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
max_age=settings.SESSION_EXPIRE_MINUTES * 60
)

_set_session_cookie(response, session_id)
_set_csrf_cookie(response, csrf_token)

response_data = UserLoginResponse(
access_token=access_token,
Expand Down Expand Up @@ -208,7 +201,7 @@ async def logout_api(
# Logout from all devices
if await logout_all_devices(db, redis_client, user_id):
if response:
response.delete_cookie("session_id")
_clear_auth_cookies(response)
return APIResponse(code=200, message="User logged out successfully")
else:
# Logout from current device only
Expand All @@ -217,7 +210,7 @@ async def logout_api(

if await logout(db, redis_client, user_id, session_id):
if response:
response.delete_cookie("session_id")
_clear_auth_cookies(response)
return APIResponse(code=200, message="User logged out successfully")
except AuthenticationException:
raise HTTPException(status_code=401, detail="Invalid or expired session")
Expand All @@ -231,26 +224,67 @@ async def logout_api(
summary="Refresh token",
responses=parse_responses({
200: ("Token refreshed successfully", TokenResponse),
401: ("Invalid or expired session", None)
401: ("Invalid or expired session / Invalid or expired CSRF token", None)
}, common_responses)
)
async def token_api(
request: Request,
db: AsyncSession = Depends(get_db),
redis_client = Depends(get_redis)
):
"""Use session_id Cookie to get new access_token"""
"""Validate CSRF token then use session_id cookie to issue new access_token"""
try:
csrf_token = request.headers.get("X-CSRF-Token")
sid_from_csrf = await verify_csrf_token(redis_client, csrf_token)

session_id = request.cookies.get("session_id")
if not session_id:
if not session_id or session_id != sid_from_csrf:
raise AuthenticationException("Invalid or expired session")

new_access_token = await token(db, redis_client, session_id)
response_data = TokenResponse(
access_token=new_access_token,
expires_at=datetime.now().astimezone() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
)
return APIResponse(code=200, message="Token refreshed successfully", data=response_data)
except (AuthenticationException, NotFoundException):
except AuthenticationException as e:
raise HTTPException(status_code=401, detail=e.message)
except NotFoundException:
raise HTTPException(status_code=401, detail="Invalid or expired session")
except Exception:
raise HTTPException(status_code=500)

@router.post(
"/csrf-token",
response_model=APIResponse[CsrfTokenResponse],
response_model_exclude_none=True,
summary="Get CSRF token",
responses=parse_responses({
200: ("CSRF token retrieved successfully", CsrfTokenResponse),
401: ("Invalid or expired session", None),
}, common_responses),
)
async def csrf_token_api(
request: Request,
response: Response,
redis_client=Depends(get_redis),
):
"""Issue CSRF token for the current session (cookie-based)."""
try:
session_id = request.cookies.get("session_id")
if not session_id:
raise AuthenticationException("Invalid or expired session")

csrf_token = await get_or_create_csrf_token(redis_client, session_id)
_set_csrf_cookie(response, csrf_token)

response_data = CsrfTokenResponse(
csrf_token=csrf_token,
expires_at=datetime.now().astimezone()
+ timedelta(minutes=settings.CSRF_TOKEN_EXPIRE_MINUTES),
)
return APIResponse(code=200, message="CSRF token retrieved successfully", data=response_data)
except AuthenticationException:
raise HTTPException(status_code=401, detail="Invalid or expired session")
except Exception:
raise HTTPException(status_code=500)
Expand Down Expand Up @@ -284,6 +318,7 @@ async def reset_password_api(
user = result["user"]
session_id = result["session_id"]
access_token = result["access_token"]
csrf_token = result["csrf_token"]

user_response = UserResponse(
id=user.id,
Expand All @@ -298,15 +333,9 @@ async def reset_password_api(
expires_at=datetime.now().astimezone() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
user=user_response
)

response.set_cookie(
key="session_id",
value=session_id,
httponly=settings.COOKIE_HTTPONLY,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
max_age=settings.SESSION_EXPIRE_MINUTES * 60
)

_set_session_cookie(response, session_id)
_set_csrf_cookie(response, csrf_token)

return APIResponse(code=200, message="Password reset successfully", data=response_data)
except AuthenticationException:
Expand Down Expand Up @@ -431,6 +460,7 @@ async def verify_email_api(
user = result["user"]
session_id = result["session_id"]
access_token = result["access_token"]
csrf_token = result["csrf_token"]

user_response = UserResponse(
id=user.id,
Expand All @@ -445,15 +475,9 @@ async def verify_email_api(
expires_at=datetime.now().astimezone() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
user=user_response
)

response.set_cookie(
key="session_id",
value=session_id,
httponly=settings.COOKIE_HTTPONLY,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
max_age=settings.SESSION_EXPIRE_MINUTES * 60
)

_set_session_cookie(response, session_id)
_set_csrf_cookie(response, csrf_token)

return APIResponse(code=200, message="Email verified successfully", data=response_data)
except AuthenticationException:
Expand Down Expand Up @@ -530,4 +554,30 @@ async def get_email_verification_cooldown_api(
)
return APIResponse(code=200, message="Cooldown status retrieved", data=response_data)
except Exception:
raise HTTPException(status_code=500)
raise HTTPException(status_code=500)

def _set_session_cookie(response: Response, session_id: str) -> None:
response.set_cookie(
key="session_id",
value=session_id,
httponly=settings.COOKIE_HTTPONLY,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
max_age=settings.SESSION_EXPIRE_MINUTES * 60,
)


def _set_csrf_cookie(response: Response, csrf_token: str) -> None:
response.set_cookie(
key="csrf_token",
value=csrf_token,
httponly=False,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
max_age=settings.CSRF_TOKEN_EXPIRE_MINUTES * 60,
)


def _clear_auth_cookies(response: Response) -> None:
response.delete_cookie("session_id")
response.delete_cookie("csrf_token")
8 changes: 7 additions & 1 deletion backend/api/auth/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

class LoginResult(TypedDict):
user: "UserResponse"
session_id: str = Field(..., description="Session ID")
session_id: str = Field(..., description="Session ID")
access_token: str = Field(..., description="JWT access token")
csrf_token: str = Field(..., description="CSRF token")

class SessionResult(TypedDict):
session_id: str = Field(..., description="Session ID")
access_token: str = Field(..., description="JWT access token")
csrf_token: str = Field(..., description="CSRF token")

class UserRegister(BaseModel):
first_name: str = Field(..., min_length=1, max_length=50, description="First name")
Expand Down Expand Up @@ -39,6 +41,10 @@ class TokenResponse(BaseModel):
access_token: str = Field(..., description="JWT access token")
expires_at: datetime = Field(..., description="Token expiration time")

class CsrfTokenResponse(BaseModel):
csrf_token: str = Field(..., description="CSRF token")
expires_at: datetime = Field(..., description="CSRF token expiration time")

class ActionRequiredResponse(BaseModel):
action_type: str = Field(..., description="Action type for frontend routing: 'password_reset' or 'email_verification'")
token: Optional[str] = Field(default=None, description="Token for the password reset")
Expand Down
Loading
Loading