diff --git a/frontend/lib/services/core/api-client.ts b/frontend/lib/services/core/api-client.ts index ff94a7c..8fa5ec2 100644 --- a/frontend/lib/services/core/api-client.ts +++ b/frontend/lib/services/core/api-client.ts @@ -23,6 +23,7 @@ const apiClient = axios.create({ withCredentials: apiConfig.withCredentials, headers: { 'Content-Type': 'application/json', + 'X-Requested-With': 'XMLHttpRequest', }, }); diff --git a/internal/router/middlewares.go b/internal/router/middlewares.go index 92c46eb..b10b426 100644 --- a/internal/router/middlewares.go +++ b/internal/router/middlewares.go @@ -17,6 +17,7 @@ limitations under the License. package router import ( + "net/http" "strconv" "time" @@ -24,10 +25,25 @@ import ( "github.com/linux-do/credit/internal/config" "github.com/linux-do/credit/internal/logger" "github.com/linux-do/credit/internal/otel_trace" + "github.com/linux-do/credit/internal/util" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" ) +func csrfMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + method := c.Request.Method + if method == http.MethodPost || method == http.MethodPut || method == http.MethodDelete || method == http.MethodPatch { + if c.GetHeader("X-Requested-With") != "XMLHttpRequest" { + c.AbortWithStatusJSON(http.StatusForbidden, util.Err("CSRF 验证失败")) + return + } + } + + c.Next() + } +} + func loggerMiddleware() gin.HandlerFunc { return func(c *gin.Context) { // 初始化 Trace diff --git a/internal/router/router.go b/internal/router/router.go index 4d60a9d..b4b1ab6 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -117,6 +117,7 @@ func Serve() { r.GET("/f/:id", upload.ServeFileByID) apiGroup := r.Group(config.Config.App.APIPrefix) + apiGroup.Use(csrfMiddleware()) { if !config.Config.App.IsProduction() { // Swagger diff --git a/internal/service/payment.go b/internal/service/payment.go index f3f47c2..85d8ff0 100644 --- a/internal/service/payment.go +++ b/internal/service/payment.go @@ -140,9 +140,9 @@ func GetTodayUsedAmount(db *gorm.DB, userID uint64) (decimal.Decimal, error) { var total decimal.Decimal err := db.Model(&model.Order{}). - Where("payer_user_id = ? AND status = ? AND type IN ? AND trade_time >= ? AND trade_time < ?", + Where("payer_user_id = ? AND status IN ? AND type IN ? AND trade_time >= ? AND trade_time < ?", userID, - model.OrderStatusSuccess, + []model.OrderStatus{model.OrderStatusSuccess, model.OrderStatusDisputing, model.OrderStatusRefused}, []model.OrderType{model.OrderTypePayment, model.OrderTypeOnline, model.OrderTypeDistribute, model.OrderTypeTransfer}, todayStart, todayEnd).