Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions backend/api/account/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""
用户管理部分的路由
"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Response, Request
from fastapi.responses import JSONResponse
import hashlib
import jwt
import datetime
Expand All @@ -19,6 +20,7 @@
ADMIN_KEY = settings.ADMIN_KEY
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
COOKIE_NAME = "access_token"
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

Expand All @@ -42,15 +44,31 @@ def create_access_token(data: dict, expires_delta: datetime.timedelta = None):
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt

@router.post("/login", response_model=LoginResponse)
async def login(loginRequest: LoginRequest):
def set_token_cookie(response: Response, token: str, expires_minutes: int):
expires = datetime.datetime.utcnow() + datetime.timedelta(minutes=expires_minutes)
response.set_cookie(
key=COOKIE_NAME,
value=token,
httponly=True,
secure=False,
samesite="lax",
expires=expires,
path="/"
)

def clear_token_cookie(response: Response):
response.delete_cookie(
key=COOKIE_NAME,
path="/"
)

@router.post("/login")
async def login(loginRequest: LoginRequest, response: Response):
mysql_client = MysqlClient()
try:
username = loginRequest.username
password = loginRequest.password

# Workaround: Explicitly select needed columns to avoid 'status' column error
# Ideally, ensure the UserInfo model and database schema match.
user_data = mysql_client.db.query(
UserInfo.username,
UserInfo.password,
Expand All @@ -64,22 +82,22 @@ async def login(loginRequest: LoginRequest):

if not verify_password(password, hashed_password):
raise HTTPException(status_code=400, detail="Incorrect username or password")
# Use the fetched delete_sign value
if delete_sign == True:
raise HTTPException(status_code=400, detail="Account disabled")

access_token_expires = datetime.timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
# Use the fetched username for the token
access_token = create_access_token(
data={"sub": fetched_username}, expires_delta=access_token_expires
)

set_token_cookie(response, access_token, ACCESS_TOKEN_EXPIRE_MINUTES)

return LoginResponse(code=200, data=AccessToken(access_token=access_token,token_type="bearer"), message="Login Successful")
finally:
mysql_client.db.close()

@router.post("/signup", response_model=SignUpResponse)
def signup(signupRequest: SignUpRequest):
@router.post("/signup")
def signup(signupRequest: SignUpRequest, response: Response):
mysql_client = MysqlClient()
try:
username = signupRequest.username
Expand All @@ -102,12 +120,14 @@ def signup(signupRequest: SignUpRequest):
data={"sub": new_user.username}, expires_delta=access_token_expires
)

set_token_cookie(response, access_token, ACCESS_TOKEN_EXPIRE_MINUTES)

return SignUpResponse(code=200, data=AccessToken(access_token=access_token,token_type="bearer"), message="Sign Up Successful")
finally:
mysql_client.db.close()

@router.post("/signup_admin", response_model=SignUpResponse)
def signup(signupRequest: SignUpAdminRequest):
@router.post("/signup_admin")
def signup(signupRequest: SignUpAdminRequest, response: Response):
mysql_client = MysqlClient()
try:
username = signupRequest.username
Expand All @@ -133,10 +153,17 @@ def signup(signupRequest: SignUpAdminRequest):
data={"sub": new_user.username}, expires_delta=access_token_expires
)

set_token_cookie(response, access_token, ACCESS_TOKEN_EXPIRE_MINUTES)

return SignUpResponse(code=200, data=AccessToken(access_token=access_token,token_type="bearer"), message="Sign Up Successful")
finally:
mysql_client.db.close()

@router.post("/logout")
def logout(response: Response):
clear_token_cookie(response)
return {"code": 200, "message": "Logout successful"}

# 获取当前用户信息
@router.get("/me")
def read_users_me(token: str = Depends(get_current_user)):
Expand Down
25 changes: 19 additions & 6 deletions backend/api/account/user.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from typing import Any,List,Dict
class LoginRequest(BaseModel):
import re

class UsernameValidator(BaseModel):
username: str

@field_validator('username')
@classmethod
def validate_username(cls, v):
if not v:
raise ValueError('用户名不能为空')
if len(v) < 3 or len(v) > 20:
raise ValueError('用户名长度必须在3-20个字符之间')
if not re.match(r'^[a-zA-Z0-9_]+$', v):
raise ValueError('用户名只能包含字母、数字和下划线')
return v

class LoginRequest(UsernameValidator):
password: str

class AccessToken(BaseModel):
Expand All @@ -15,12 +30,10 @@ class LoginResponse(BaseModel):



class SignUpRequest(BaseModel):
username: str
class SignUpRequest(UsernameValidator):
password: str

class SignUpAdminRequest(BaseModel):
username: str
class SignUpAdminRequest(UsernameValidator):
password: str
admin_key: str

Expand Down
33 changes: 26 additions & 7 deletions backend/core/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import uuid
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi import FastAPI, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from config.config_info import settings
import jwt
Expand All @@ -18,22 +18,36 @@
SECRET_KEY = settings.SECRET_KEY
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
COOKIE_NAME = "access_token"

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

def generate_unique_filename(original_filename):
# 获取文件扩展名
extension = os.path.splitext(original_filename)[1]
# 生成唯一的UUID
unique_id = uuid.uuid4()
# 组合唯一标识符和扩展名
unique_filename = f"{unique_id}{extension}"
return unique_filename

async def get_current_user(token: str = Depends(oauth2_scheme)):
async def get_token_from_request(request: Request):
token = request.cookies.get(COOKIE_NAME)
if token:
return token

authorization = request.headers.get("Authorization")
if authorization and authorization.startswith("Bearer "):
return authorization[7:]

return None

async def get_current_user(request: Request):
credentials_exception = HTTPException(
status_code=401, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"},
)

token = await get_token_from_request(request)
if not token:
raise credentials_exception

try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
Expand All @@ -47,20 +61,25 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
if user is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not Login")


return username

async def get_is_admin(token: str = Depends(oauth2_scheme)):
async def get_is_admin(request: Request):
credentials_exception = HTTPException(
status_code=401, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"},
)

token = await get_token_from_request(request)
if not token:
raise credentials_exception

try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except jwt.PyJWTError:
raise credentials_exception

mysql_client = MysqlClient()
user = mysql_client.db.query(UserInfo).filter(UserInfo.username == username).first()

Expand Down
Loading