diff --git a/core/relay/adaptor/ali/chat.go b/core/relay/adaptor/ali/chat.go index a28d2c47..5576afa5 100644 --- a/core/relay/adaptor/ali/chat.go +++ b/core/relay/adaptor/ali/chat.go @@ -21,7 +21,7 @@ func aliModelMatches(meta *meta.Meta, match func(string) bool) bool { if meta == nil { return false } - return utils.FirstMatchingModelName(meta.OriginModel, meta.ActualModel, match) != "" + return utils.FirstMatchingModelName(match, meta.OriginModel, meta.ActualModel) != "" } func isAliQwen3Model(meta *meta.Meta) bool { diff --git a/core/relay/adaptor/anthropic/openai.go b/core/relay/adaptor/anthropic/openai.go index 8eadf202..3a3ab477 100644 --- a/core/relay/adaptor/anthropic/openai.go +++ b/core/relay/adaptor/anthropic/openai.go @@ -141,11 +141,11 @@ func openAIConvertRequest( claudeRequest.OutputConfig = nil } } else if utils.FirstMatchingModelName( - meta.OriginModel, - meta.ActualModel, func(modelName string) bool { return strings.Contains(strings.ToLower(modelName), "think") }, + meta.OriginModel, + meta.ActualModel, ) != "" { thinkingType := relaymodel.ClaudeThinkingTypeEnabled if shouldAutoUseAdaptiveThinking(resolvedModel) { diff --git a/core/relay/adaptor/anthropic/utils.go b/core/relay/adaptor/anthropic/utils.go index f9c58506..27b0de37 100644 --- a/core/relay/adaptor/anthropic/utils.go +++ b/core/relay/adaptor/anthropic/utils.go @@ -24,12 +24,12 @@ func shouldAutoUseAdaptiveThinking(model string) bool { func ResolveModelName(originModel, actualModel string) string { if modelName := relayutils.FirstMatchingModelName( - originModel, - actualModel, func(modelName string) bool { modelName = strings.ToLower(modelName) return strings.Contains(modelName, "claude") || strings.Contains(modelName, "mythos") }, + originModel, + actualModel, ); modelName != "" { return modelName } diff --git a/core/relay/adaptor/baidu/adaptor.go b/core/relay/adaptor/baidu/adaptor.go index 3ab3c904..c553336b 100644 --- a/core/relay/adaptor/baidu/adaptor.go +++ b/core/relay/adaptor/baidu/adaptor.go @@ -87,12 +87,12 @@ func (a *Adaptor) GetRequestURL( } if endpointModel := utils.FirstMatchingModelName( - meta.OriginModel, - meta.ActualModel, func(modelName string) bool { _, ok := modelEndpointMap[modelName] return ok }, + meta.OriginModel, + meta.ActualModel, ); endpointModel != "" { modelName = endpointModel } diff --git a/core/relay/adaptor/doubao/chat.go b/core/relay/adaptor/doubao/chat.go index 13c6cc3d..9ddfdb04 100644 --- a/core/relay/adaptor/doubao/chat.go +++ b/core/relay/adaptor/doubao/chat.go @@ -27,11 +27,11 @@ func ConvertChatCompletionsRequest( } if utils.FirstMatchingModelName( - meta.OriginModel, - meta.ActualModel, func(modelName string) bool { return strings.HasPrefix(strings.ToLower(modelName), "deepseek-reasoner") }, + meta.OriginModel, + meta.ActualModel, ) != "" { callbacks = append(callbacks, patchDeepseekReasonerSystemPrompt) } diff --git a/core/relay/adaptor/doubao/main.go b/core/relay/adaptor/doubao/main.go index f7bbc63f..f9df727d 100644 --- a/core/relay/adaptor/doubao/main.go +++ b/core/relay/adaptor/doubao/main.go @@ -27,12 +27,12 @@ func featureModel(meta *meta.Meta) string { } if modelName := utils.FirstMatchingModelName( - meta.OriginModel, - meta.ActualModel, func(modelName string) bool { modelName = strings.ToLower(modelName) return strings.HasPrefix(modelName, "bot-") || strings.Contains(modelName, "vision") }, + meta.OriginModel, + meta.ActualModel, ); modelName != "" { return modelName } diff --git a/core/relay/adaptor/gemini/adaptor.go b/core/relay/adaptor/gemini/adaptor.go index db241500..ed695594 100644 --- a/core/relay/adaptor/gemini/adaptor.go +++ b/core/relay/adaptor/gemini/adaptor.go @@ -68,12 +68,12 @@ func requestVersionModel(meta *meta.Meta) string { } if modelName := utils.FirstMatchingModelName( - meta.OriginModel, - meta.ActualModel, func(modelName string) bool { _, ok := v1ModelMap[modelName] return ok }, + meta.OriginModel, + meta.ActualModel, ); modelName != "" { return modelName } diff --git a/core/relay/adaptor/gemini/openai.go b/core/relay/adaptor/gemini/openai.go index e4a34c06..2973e55c 100644 --- a/core/relay/adaptor/gemini/openai.go +++ b/core/relay/adaptor/gemini/openai.go @@ -99,11 +99,11 @@ func resolveGeminiFeatureModel(meta *meta.Meta) string { } if modelName := utils.FirstMatchingModelName( - meta.OriginModel, - meta.ActualModel, func(modelName string) bool { return strings.Contains(strings.ToLower(modelName), "gemini") }, + meta.OriginModel, + meta.ActualModel, ); modelName != "" { return modelName } diff --git a/core/relay/adaptor/openai/chat.go b/core/relay/adaptor/openai/chat.go index 7eae6bdb..37a73bd9 100644 --- a/core/relay/adaptor/openai/chat.go +++ b/core/relay/adaptor/openai/chat.go @@ -30,6 +30,18 @@ type chatCompletionStreamState struct { toolCallArgs string } +func responseModelName(meta *meta.Meta) string { + if meta == nil { + return "" + } + + if meta.OriginModel != "" { + return meta.OriginModel + } + + return meta.ActualModel +} + // handleResponseCreated handles response.created event for ChatCompletion func (s *chatCompletionStreamState) handleResponseCreated( event *relaymodel.ResponseStreamEvent, @@ -44,7 +56,7 @@ func (s *chatCompletionStreamState) handleResponseCreated( ID: s.messageID, Object: relaymodel.ChatCompletionChunkObject, Created: event.Response.CreatedAt, - Model: event.Response.Model, + Model: responseModelName(s.meta), Choices: []*relaymodel.ChatCompletionsStreamResponseChoice{ { Index: 0, @@ -68,7 +80,7 @@ func (s *chatCompletionStreamState) handleOutputTextDelta( ID: s.messageID, Object: relaymodel.ChatCompletionChunkObject, Created: time.Now().Unix(), - Model: s.meta.ActualModel, + Model: responseModelName(s.meta), Choices: []*relaymodel.ChatCompletionsStreamResponseChoice{ { Index: 0, @@ -106,7 +118,7 @@ func (s *chatCompletionStreamState) handleOutputItemAdded( ID: s.messageID, Object: relaymodel.ChatCompletionChunkObject, Created: time.Now().Unix(), - Model: s.meta.ActualModel, + Model: responseModelName(s.meta), Choices: []*relaymodel.ChatCompletionsStreamResponseChoice{ { Index: 0, @@ -133,7 +145,7 @@ func (s *chatCompletionStreamState) handleOutputItemAdded( ID: s.messageID, Object: relaymodel.ChatCompletionChunkObject, Created: time.Now().Unix(), - Model: s.meta.ActualModel, + Model: responseModelName(s.meta), Choices: []*relaymodel.ChatCompletionsStreamResponseChoice{ { Index: 0, @@ -164,7 +176,7 @@ func (s *chatCompletionStreamState) handleFunctionCallArgumentsDelta( ID: s.messageID, Object: relaymodel.ChatCompletionChunkObject, Created: time.Now().Unix(), - Model: s.meta.ActualModel, + Model: responseModelName(s.meta), Choices: []*relaymodel.ChatCompletionsStreamResponseChoice{ { Index: 0, @@ -223,7 +235,7 @@ func (s *chatCompletionStreamState) handleResponseCompleted( ID: s.messageID, Object: relaymodel.ChatCompletionChunkObject, Created: time.Now().Unix(), - Model: s.meta.ActualModel, + Model: responseModelName(s.meta), Choices: []*relaymodel.ChatCompletionsStreamResponseChoice{ { Index: 0, @@ -417,7 +429,7 @@ func StreamHandler( responseText := strings.Builder{} - scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.ActualModel) + scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.OriginModel, meta.ActualModel) defer cleanup() var ( @@ -857,8 +869,11 @@ func ConvertMessagesToInputItems(messages []relaymodel.Message) []relaymodel.Inp // Handle regular messages role := msg.Role // Tool role without ToolCallID is treated as user role - if role == relaymodel.RoleTool { + switch role { + case relaymodel.RoleTool: role = relaymodel.RoleUser + case relaymodel.RoleSystem: + role = relaymodel.RoleDeveloper } inputItem := relaymodel.InputItem{ @@ -1052,13 +1067,17 @@ func ConvertResponsesToChatCompletionResponse( ID: responsesResp.ID, Object: relaymodel.ChatCompletionObject, Created: responsesResp.CreatedAt, - Model: responsesResp.Model, + Model: responseModelName(meta), Choices: []*relaymodel.TextResponseChoice{}, Usage: relaymodel.ChatUsage{}, } // Convert output items to choices for _, outputItem := range responsesResp.Output { + if outputItem.Type != "" && outputItem.Type != relaymodel.InputItemTypeMessage { + continue + } + choice := relaymodel.TextResponseChoice{ Index: 0, // Responses API doesn't have index, default to 0 Message: relaymodel.Message{ @@ -1127,6 +1146,85 @@ func ConvertResponsesToChatCompletionResponse( }, nil } +type responsesStreamErrorState struct { + usage model.Usage + responseID string + lastResponse *relaymodel.Response + pendingFailure *relaymodel.ResponseStreamEvent +} + +func (s *responsesStreamErrorState) update(event *relaymodel.ResponseStreamEvent) { + if event.Response == nil { + return + } + + if s.responseID == "" { + s.responseID = event.Response.ID + } + + s.lastResponse = event.Response + s.usage = event.Response.ToModelUsage() +} + +func (s *responsesStreamErrorState) result() adaptor.DoResponseResult { + asyncUsage := responseNeedsAsyncUsage(s.lastResponse) + if s.pendingFailure != nil { + asyncUsage = false + } + + return adaptor.DoResponseResult{ + Usage: s.usage, + UpstreamID: s.responseID, + AsyncUsage: asyncUsage, + } +} + +func (s *responsesStreamErrorState) errorBeforeEvent( + event *relaymodel.ResponseStreamEvent, +) adaptor.Error { + if s.pendingFailure == nil { + return nil + } + + if event.Type == relaymodel.EventError { + return responseStreamError(event) + } + + return responseStreamError(s.pendingFailure) +} + +func (s *responsesStreamErrorState) handleFailure( + event *relaymodel.ResponseStreamEvent, +) (adaptor.Error, bool) { + if event.Type != relaymodel.EventResponseFailed && event.Type != relaymodel.EventError { + return nil, false + } + + if event.Type == relaymodel.EventResponseFailed && event.Response != nil && + event.Response.Error == nil { + s.update(event) + + pendingEvent := *event + s.pendingFailure = &pendingEvent + + return nil, true + } + + return responseStreamError(event), true +} + +func responseStreamEventCanDelay(eventType string) bool { + switch eventType { + case relaymodel.EventResponseCreated, + relaymodel.EventResponseInProgress, + relaymodel.EventResponseQueued, + relaymodel.EventKeepAlive: + return true + default: + return false + } +} + // ConvertResponsesToChatCompletionStreamResponse converts Responses API stream to ChatCompletion stream func ConvertResponsesToChatCompletionStreamResponse( meta *meta.Meta, @@ -1141,30 +1239,57 @@ func ConvertResponsesToChatCompletionStreamResponse( log := common.GetLogger(c) - scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.ActualModel) + scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.OriginModel, meta.ActualModel) defer cleanup() var ( - usage model.Usage - responseID string - lastResponse *relaymodel.Response + usage model.Usage + responseID string + lastResponse *relaymodel.Response + pendingInitialChunk *relaymodel.ChatCompletionsStreamResponse + wroteStream bool ) + errorState := responsesStreamErrorState{} + state := &chatCompletionStreamState{ meta: meta, c: c, } + stopStream := false - for scanner.Scan() { + var writeChatStreamResp func(*relaymodel.ChatCompletionsStreamResponse) + + writeChatStreamResp = func(chatStreamResp *relaymodel.ChatCompletionsStreamResponse) { + if chatStreamResp == nil { + return + } + + if pendingInitialChunk != nil { + initialChunk := pendingInitialChunk + pendingInitialChunk = nil + + writeChatStreamResp(initialChunk) + } + + chunkData, err := sonic.Marshal(chatStreamResp) + if err != nil { + log.Error("error marshalling chat stream response: " + err.Error()) + return + } + + render.OpenaiBytesData(c, chunkData) + + wroteStream = true + } + + for scanner.Scan() && !stopStream { data := scanner.Bytes() if !render.IsValidSSEData(data) { continue } data = render.ExtractSSEData(data) - if render.IsSSEDone(data) { - break - } // Parse the stream event var event relaymodel.ResponseStreamEvent @@ -1184,12 +1309,20 @@ func ConvertResponsesToChatCompletionStreamResponse( usage = event.Response.ToModelUsage() } + errorState.usage = usage + errorState.responseID = responseID + errorState.lastResponse = lastResponse + + if err := errorState.errorBeforeEvent(&event); err != nil { + return errorState.result(), err + } + // Handle event and get response var chatStreamResp *relaymodel.ChatCompletionsStreamResponse switch event.Type { case relaymodel.EventResponseCreated: - chatStreamResp = state.handleResponseCreated(&event) + pendingInitialChunk = state.handleResponseCreated(&event) case relaymodel.EventOutputTextDelta: chatStreamResp = state.handleOutputTextDelta(&event) case relaymodel.EventOutputItemAdded: @@ -1200,27 +1333,143 @@ func ConvertResponsesToChatCompletionStreamResponse( state.handleOutputItemDone(&event) case relaymodel.EventResponseCompleted, relaymodel.EventResponseDone: chatStreamResp = state.handleResponseCompleted(&event) - } + case relaymodel.EventResponseFailed, relaymodel.EventError: + if wroteStream { + log.Error( + "response stream failed after data was sent: " + responseStreamErrorMessage( + &event, + ), + ) - // Send the converted chunk - if chatStreamResp != nil { - chunkData, err := sonic.Marshal(chatStreamResp) - if err != nil { - log.Error("error marshalling chat stream response: " + err.Error()) + stopStream = true + + break + } + + err, handled := errorState.handleFailure(&event) + if handled && err == nil { continue } - render.OpenaiBytesData(c, chunkData) + if handled { + return errorState.result(), err + } } + + writeChatStreamResp(chatStreamResp) } if err := scanner.Err(); err != nil { log.Error("error reading response stream: " + err.Error()) } + if errorState.pendingFailure != nil && !wroteStream { + return errorState.result(), responseStreamError(errorState.pendingFailure) + } + + if wroteStream { + render.OpenaiDone(c) + } + return adaptor.DoResponseResult{ Usage: usage, UpstreamID: responseID, AsyncUsage: responseNeedsAsyncUsage(lastResponse), }, nil } + +func responseStreamError(event *relaymodel.ResponseStreamEvent) adaptor.Error { + openAIError := relaymodel.OpenAIError{ + Message: responseStreamErrorMessage(event), + Type: relaymodel.ErrorTypeUpstream, + Code: relaymodel.ErrorCodeBadResponse, + } + statusCode := http.StatusBadGateway + + if event.Error != nil { + openAIError = *event.Error + if openAIError.Type == "" { + openAIError.Type = relaymodel.ErrorTypeUpstream + } + + if openAIError.Message == "" { + openAIError.Message = responseStreamErrorMessage(event) + } + + if openAIError.Code == nil { + openAIError.Code = relaymodel.ErrorCodeBadResponse + } + } + + if event.Response != nil && event.Response.Error != nil { + openAIError.Message = event.Response.Error.Message + if event.Response.Error.Code != "" { + openAIError.Code = event.Response.Error.Code + } + } + + if status, ok := streamErrorStatusCode(openAIError.Code); ok { + statusCode = status + } else if status, ok := streamErrorStatusCode(openAIError.Type); ok { + statusCode = status + } else if status, ok := streamErrorStatusCode(openAIError.Message); ok { + statusCode = status + } + + return relaymodel.NewOpenAIError(statusCode, openAIError) +} + +func streamErrorStatusCode(code any) (int, bool) { + switch value := code.(type) { + case int: + return statusCodeFromNumericCode(value) + case int64: + return statusCodeFromNumericCode(int(value)) + case float64: + if value == float64(int(value)) { + return statusCodeFromNumericCode(int(value)) + } + case string: + if status, err := strconv.Atoi(value); err == nil { + return statusCodeFromNumericCode(status) + } + + switch value { + case "too_many_requests", "rate_limit_exceeded": + return http.StatusTooManyRequests, true + case "invalid_request_error", "bad_request", "bad_request_error", "invalid_request": + return http.StatusBadRequest, true + } + + lowerValue := strings.ToLower(value) + if strings.Contains(lowerValue, "system messages are not allowed") { + return http.StatusBadRequest, true + } + } + + return 0, false +} + +func statusCodeFromNumericCode(code int) (int, bool) { + if code >= http.StatusBadRequest && code < 600 { + return code, true + } + + return 0, false +} + +func responseStreamErrorMessage(event *relaymodel.ResponseStreamEvent) string { + if event.Error != nil && event.Error.Message != "" { + return event.Error.Message + } + + if event.Response != nil && event.Response.Error != nil && event.Response.Error.Message != "" { + return event.Response.Error.Message + } + + if event.Type != "" { + return "response stream failed: " + event.Type + } + + return "response stream failed" +} diff --git a/core/relay/adaptor/openai/chat_test.go b/core/relay/adaptor/openai/chat_test.go index 8cd68133..ee80f330 100644 --- a/core/relay/adaptor/openai/chat_test.go +++ b/core/relay/adaptor/openai/chat_test.go @@ -251,6 +251,27 @@ func TestConvertChatCompletionToResponsesRequest(t *testing.T) { assert.False(t, *responsesReq.Store) }, }, + { + name: "system messages become developer messages", + inputRequest: relaymodel.GeneralOpenAIRequest{ + Model: "gpt-5.5", + Messages: []relaymodel.Message{ + {Role: "system", Content: "You are concise."}, + {Role: "user", Content: "Hello"}, + }, + }, + checkFunc: func(t *testing.T, responsesReq relaymodel.CreateResponseRequest) { + t.Helper() + + inputItems, ok := responsesReq.Input.([]any) + require.True(t, ok) + require.Len(t, inputItems, 2) + + systemItem, ok := inputItems[0].(map[string]any) + require.True(t, ok) + assert.Equal(t, "developer", systemItem["role"]) + }, + }, { name: "request with temperature and max_tokens", inputRequest: relaymodel.GeneralOpenAIRequest{ @@ -428,7 +449,6 @@ func TestConvertResponsesToChatCompletionResponse(t *testing.T) { checkFunc: func(t *testing.T, chatResp relaymodel.TextResponse) { t.Helper() assert.Equal(t, "resp_123", chatResp.ID) - assert.Equal(t, "gpt-5-codex", chatResp.Model) assert.Equal(t, "chat.completion", chatResp.Object) require.Len(t, chatResp.Choices, 1) assert.Contains(t, chatResp.Choices[0].Message.Content, "Hello, world!") @@ -468,16 +488,8 @@ func TestConvertResponsesToChatCompletionResponse(t *testing.T) { }, checkFunc: func(t *testing.T, chatResp relaymodel.TextResponse) { t.Helper() - // Current implementation creates one choice per output item - require.Len(t, chatResp.Choices, 2) - // First choice is reasoning - assert.Contains( - t, - chatResp.Choices[0].Message.Content, - "Let me think about this...", - ) - // Second choice is the message - assert.Contains(t, chatResp.Choices[1].Message.Content, "The answer is 42.") + require.Len(t, chatResp.Choices, 1) + assert.Contains(t, chatResp.Choices[0].Message.Content, "The answer is 42.") }, expectedStatus: http.StatusOK, }, @@ -500,6 +512,7 @@ func TestConvertResponsesToChatCompletionResponse(t *testing.T) { c, _ := gin.CreateTestContext(w) m := &meta.Meta{ + OriginModel: "client-gpt-5", ActualModel: tt.responsesResp.Model, } @@ -511,6 +524,7 @@ func TestConvertResponsesToChatCompletionResponse(t *testing.T) { err = json.Unmarshal(w.Body.Bytes(), &chatResp) require.NoError(t, err) + assert.Equal(t, "client-gpt-5", chatResp.Model) tt.checkFunc(t, chatResp) }) } @@ -530,8 +544,6 @@ func TestConvertResponsesToChatCompletionStreamResponseSkipsOutputItemDoneConten "", `data: {"type":"response.completed","response":{"id":"resp_123","object":"response","created_at":1780731105,"status":"completed","model":"gpt-5.1","output":[{"id":"msg_123","type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello! What would you like to discuss or work on?"}]}],"parallel_tool_calls":true,"store":false,"usage":{"input_tokens":7,"output_tokens":22,"total_tokens":29}}}`, "", - `data: [DONE]`, - "", }, "\n") httpResp := &http.Response{ @@ -563,6 +575,271 @@ func TestConvertResponsesToChatCompletionStreamResponseSkipsOutputItemDoneConten 1, strings.Count(w.Body.String(), "Hello! What would you like to discuss or work on?"), ) + assert.Equal(t, 1, strings.Count(w.Body.String(), "data: [DONE]")) +} + +func TestConvertResponsesToChatCompletionStreamResponseReturnsErrorBeforeDownstreamWrite( + t *testing.T, +) { + gin.SetMode(gin.TestMode) + + stream := strings.Join([]string{ + `event: response.created`, + `data: {"type":"response.created","response":{"id":"resp_123","object":"response","created_at":1781332973,"status":"in_progress","model":"gpt-5-mini","output":[],"parallel_tool_calls":true,"store":false}}`, + "", + `event: response.failed`, + `data: {"type":"response.failed","response":{"id":"resp_123","object":"response","created_at":1781332973,"status":"failed","model":"gpt-5-mini","output":[],"parallel_tool_calls":true,"store":false},"sequence_number":1}`, + "", + `event: error`, + `data: {"type":"error","error":{"type":"too_many_requests","code":"too_many_requests","message":"Too Many Requests","param":null},"sequence_number":2}`, + "", + }, "\n") + + httpResp := &http.Response{ + StatusCode: http.StatusOK, + Body: &mockReadCloser{Reader: bytes.NewReader([]byte(stream))}, + Header: make(http.Header), + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/chat/completions", + nil, + ) + + m := &meta.Meta{ + ActualModel: "gpt-5-mini", + } + + result, err := openai.ConvertResponsesToChatCompletionStreamResponse(m, c, httpResp) + require.NotNil(t, err) + assert.Equal(t, http.StatusTooManyRequests, err.StatusCode()) + assert.Equal(t, "resp_123", result.UpstreamID) + assert.Empty(t, w.Body.String()) +} + +func TestConvertResponsesToChatCompletionStreamResponsePreservesNumericStreamErrorStatus( + t *testing.T, +) { + gin.SetMode(gin.TestMode) + + stream := strings.Join([]string{ + `event: response.created`, + `data: {"type":"response.created","response":{"id":"resp_429","object":"response","created_at":1781332973,"status":"in_progress","model":"gpt-5-mini","output":[],"parallel_tool_calls":true,"store":false}}`, + "", + `event: error`, + `data: {"type":"error","error":{"type":"too_many_requests","code":429,"message":"Too Many Requests","param":null},"sequence_number":1}`, + "", + }, "\n") + + httpResp := &http.Response{ + StatusCode: http.StatusOK, + Body: &mockReadCloser{Reader: bytes.NewReader([]byte(stream))}, + Header: make(http.Header), + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/chat/completions", + nil, + ) + + m := &meta.Meta{ + ActualModel: "gpt-5-mini", + } + + result, err := openai.ConvertResponsesToChatCompletionStreamResponse(m, c, httpResp) + require.NotNil(t, err) + assert.Equal(t, http.StatusTooManyRequests, err.StatusCode()) + assert.Equal(t, "resp_429", result.UpstreamID) + assert.Empty(t, w.Body.String()) +} + +func TestConvertResponsesToChatCompletionStreamResponseFailedWithoutErrorDoesNotMarkAsyncUsage( + t *testing.T, +) { + gin.SetMode(gin.TestMode) + + stream := strings.Join([]string{ + `event: response.created`, + `data: {"type":"response.created","response":{"id":"resp_failed","object":"response","created_at":1781332973,"status":"in_progress","model":"gpt-5-mini","output":[],"parallel_tool_calls":true,"store":true}}`, + "", + `event: response.in_progress`, + `data: {"type":"response.in_progress","response":{"id":"resp_failed","object":"response","created_at":1781332973,"status":"in_progress","model":"gpt-5-mini","output":[],"parallel_tool_calls":true,"store":true}}`, + "", + `event: response.failed`, + `data: {"type":"response.failed","response":{"id":"resp_failed","object":"response","created_at":1781332973,"status":"failed","model":"gpt-5-mini","output":[],"parallel_tool_calls":true,"store":true},"sequence_number":2}`, + "", + }, "\n") + + httpResp := &http.Response{ + StatusCode: http.StatusOK, + Body: &mockReadCloser{Reader: bytes.NewReader([]byte(stream))}, + Header: make(http.Header), + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/chat/completions", + nil, + ) + + m := &meta.Meta{ + ActualModel: "gpt-5-mini", + } + + result, err := openai.ConvertResponsesToChatCompletionStreamResponse(m, c, httpResp) + require.NotNil(t, err) + assert.Equal(t, http.StatusBadGateway, err.StatusCode()) + assert.Equal(t, "resp_failed", result.UpstreamID) + assert.False(t, result.AsyncUsage) + assert.Empty(t, w.Body.String()) +} + +func TestConvertResponsesToChatCompletionStreamResponseMapsInvalidRequestErrorToBadRequest( + t *testing.T, +) { + gin.SetMode(gin.TestMode) + + stream := strings.Join([]string{ + `event: error`, + `data: {"type":"error","error":{"type":"invalid_request_error","code":"bad_response","message":"System messages are not allowed","param":null},"sequence_number":1}`, + "", + }, "\n") + + httpResp := &http.Response{ + StatusCode: http.StatusOK, + Body: &mockReadCloser{Reader: bytes.NewReader([]byte(stream))}, + Header: make(http.Header), + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/chat/completions", + nil, + ) + + m := &meta.Meta{ + ActualModel: "gpt-5.5", + } + + result, err := openai.ConvertResponsesToChatCompletionStreamResponse(m, c, httpResp) + require.NotNil(t, err) + assert.Equal(t, http.StatusBadRequest, err.StatusCode()) + assert.Empty(t, result.UpstreamID) + assert.Empty(t, w.Body.String()) +} + +func TestConvertResponsesToChatCompletionStreamResponseHandlesErrorAfterDownstreamWrite( + t *testing.T, +) { + gin.SetMode(gin.TestMode) + + stream := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_123","object":"response","created_at":1781332973,"status":"in_progress","model":"gpt-5-mini","output":[],"parallel_tool_calls":true,"store":false}}`, + "", + `data: {"type":"response.output_text.delta","delta":"partial"}`, + "", + `event: error`, + `data: {"type":"error","error":{"type":"server_error","code":"server_error","message":"stream failed","param":null},"sequence_number":2}`, + "", + `data: {"type":"response.output_text.delta","delta":"late"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_123","object":"response","created_at":1781332973,"status":"completed","model":"gpt-5-mini","output":[],"parallel_tool_calls":true,"store":false}}`, + "", + }, "\n") + + httpResp := &http.Response{ + StatusCode: http.StatusOK, + Body: &mockReadCloser{Reader: bytes.NewReader([]byte(stream))}, + Header: make(http.Header), + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/chat/completions", + nil, + ) + + m := &meta.Meta{ + ActualModel: "gpt-5-mini", + } + + _, err := openai.ConvertResponsesToChatCompletionStreamResponse(m, c, httpResp) + require.Nil(t, err) + assert.Equal(t, "partial", collectChatCompletionStreamContent(t, w.Body.String())) + assert.Equal(t, 1, strings.Count(w.Body.String(), "data: [DONE]")) + assert.NotContains(t, w.Body.String(), "late") +} + +func TestConvertResponsesToChatCompletionStreamResponseUsesOriginModelForEveryChunk( + t *testing.T, +) { + gin.SetMode(gin.TestMode) + + stream := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_123","object":"response","created_at":1781332973,"status":"in_progress","model":"mapped-gpt-5","output":[],"parallel_tool_calls":true,"store":false}}`, + "", + `data: {"type":"response.output_text.delta","delta":"partial"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_123","object":"response","created_at":1781332973,"status":"completed","model":"mapped-gpt-5","output":[],"parallel_tool_calls":true,"store":false,"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`, + "", + }, "\n") + + httpResp := &http.Response{ + StatusCode: http.StatusOK, + Body: &mockReadCloser{Reader: bytes.NewReader([]byte(stream))}, + Header: make(http.Header), + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/chat/completions", + nil, + ) + + m := &meta.Meta{ + OriginModel: "gpt-5", + ActualModel: "mapped-gpt-5", + } + + _, err := openai.ConvertResponsesToChatCompletionStreamResponse(m, c, httpResp) + require.Nil(t, err) + assert.NotContains(t, w.Body.String(), "mapped-gpt-5") + + chunkCount := 0 + for line := range strings.SplitSeq(w.Body.String(), "\n") { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data: ") || line == "data: [DONE]" { + continue + } + + var chunk relaymodel.ChatCompletionsStreamResponse + require.NoError(t, json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &chunk)) + assert.Equal(t, "gpt-5", chunk.Model) + + chunkCount++ + } + + assert.GreaterOrEqual(t, chunkCount, 2) } func collectChatCompletionStreamContent(t *testing.T, body string) string { diff --git a/core/relay/adaptor/openai/claude.go b/core/relay/adaptor/openai/claude.go index 2c46248f..4dc2207c 100644 --- a/core/relay/adaptor/openai/claude.go +++ b/core/relay/adaptor/openai/claude.go @@ -10,7 +10,6 @@ import ( "github.com/bytedance/sonic" "github.com/gin-gonic/gin" "github.com/labring/aiproxy/core/common" - "github.com/labring/aiproxy/core/model" "github.com/labring/aiproxy/core/relay/adaptor" "github.com/labring/aiproxy/core/relay/meta" relaymodel "github.com/labring/aiproxy/core/relay/model" @@ -354,7 +353,7 @@ func ClaudeStreamHandler( log := common.GetLogger(c) - scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.ActualModel) + scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.OriginModel, meta.ActualModel) defer cleanup() // Initialize Claude response tracking @@ -416,7 +415,7 @@ func ClaudeStreamHandler( ID: messageID, Type: relaymodel.ClaudeTypeMessage, Role: relaymodel.RoleAssistant, - Model: meta.ActualModel, + Model: responseModelName(meta), Content: []relaymodel.ClaudeContent{}, }, } @@ -632,7 +631,7 @@ func ClaudeHandler( ID: "msg_" + common.ShortUUID(), Type: relaymodel.ClaudeTypeMessage, Role: relaymodel.RoleAssistant, - Model: meta.ActualModel, + Model: responseModelName(meta), Content: []relaymodel.ClaudeContent{}, StopReason: "", StopSequence: nil, @@ -887,7 +886,7 @@ func ConvertResponsesToClaudeResponse( ID: responsesResp.ID, Type: relaymodel.ClaudeTypeMessage, Role: relaymodel.RoleAssistant, - Model: responsesResp.Model, + Model: responseModelName(meta), Content: []relaymodel.ClaudeContent{}, } @@ -970,13 +969,13 @@ func ConvertResponsesToClaudeStreamResponse( log := common.GetLogger(c) - scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.ActualModel) + scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.OriginModel, meta.ActualModel) defer cleanup() var ( - usage model.Usage - responseID string - lastResponse *relaymodel.Response + errorState responsesStreamErrorState + pendingCreated *relaymodel.ResponseStreamEvent + wroteStream bool ) state := &claudeStreamState{ @@ -991,9 +990,6 @@ func ConvertResponsesToClaudeStreamResponse( } data = render.ExtractSSEData(data) - if render.IsSSEDone(data) { - break - } // Parse the stream event var event relaymodel.ResponseStreamEvent @@ -1004,19 +1000,56 @@ func ConvertResponsesToClaudeStreamResponse( continue } - if event.Response != nil { - if responseID == "" { - responseID = event.Response.ID + errorState.update(&event) + + if err := errorState.errorBeforeEvent(&event); err != nil { + return errorState.result(), err + } + + if event.Type == relaymodel.EventResponseFailed || event.Type == relaymodel.EventError { + if wroteStream { + log.Error( + "response stream failed after data was sent: " + responseStreamErrorMessage( + &event, + ), + ) + + break } - lastResponse = event.Response - usage = event.Response.ToModelUsage() + err, handled := errorState.handleFailure(&event) + if handled && err == nil { + continue + } + + if handled { + return errorState.result(), err + } + } + + if event.Type == relaymodel.EventResponseCreated && !wroteStream { + pendingEvent := event + pendingCreated = &pendingEvent + continue + } + + if pendingCreated != nil && !claudeResponseStreamEventWrites(event.Type) { + continue + } + + if pendingCreated != nil { + state.handleResponseCreated(pendingCreated) + + wroteStream = true + pendingCreated = nil } // Handle events switch event.Type { case relaymodel.EventResponseCreated: state.handleResponseCreated(&event) + + wroteStream = true case relaymodel.EventOutputItemAdded: state.handleOutputItemAdded(&event) case relaymodel.EventContentPartAdded: @@ -1038,11 +1071,27 @@ func ConvertResponsesToClaudeStreamResponse( log.Error("error reading response stream: " + err.Error()) } - return adaptor.DoResponseResult{ - Usage: usage, - UpstreamID: responseID, - AsyncUsage: responseNeedsAsyncUsage(lastResponse), - }, nil + if errorState.pendingFailure != nil && !wroteStream { + return errorState.result(), responseStreamError(errorState.pendingFailure) + } + + return errorState.result(), nil +} + +func claudeResponseStreamEventWrites(eventType string) bool { + switch eventType { + case relaymodel.EventOutputItemAdded, + relaymodel.EventContentPartAdded, + relaymodel.EventReasoningTextDelta, + relaymodel.EventOutputTextDelta, + relaymodel.EventFunctionCallArgumentsDelta, + relaymodel.EventOutputItemDone, + relaymodel.EventResponseCompleted, + relaymodel.EventResponseDone: + return true + default: + return false + } } // claudeStreamState manages state for Claude stream conversion @@ -1075,7 +1124,7 @@ func (s *claudeStreamState) handleResponseCreated(event *relaymodel.ResponseStre ID: s.messageID, Type: relaymodel.ClaudeTypeMessage, Role: relaymodel.RoleAssistant, - Model: event.Response.Model, + Model: responseModelName(s.meta), Content: []relaymodel.ClaudeContent{}, }, }) diff --git a/core/relay/adaptor/openai/claude_test.go b/core/relay/adaptor/openai/claude_test.go index 87069413..fa83d5f6 100644 --- a/core/relay/adaptor/openai/claude_test.go +++ b/core/relay/adaptor/openai/claude_test.go @@ -4,8 +4,10 @@ import ( "bytes" "context" "encoding/json" + "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/gin-gonic/gin" @@ -459,6 +461,7 @@ func TestConvertResponsesToClaudeResponse(t *testing.T) { c, _ := gin.CreateTestContext(w) m := &meta.Meta{ + OriginModel: "client-claude", ActualModel: tt.responsesResp.Model, } @@ -475,6 +478,7 @@ func TestConvertResponsesToClaudeResponse(t *testing.T) { // Verify assert.Equal(t, tt.expectedType, claudeResp.Type) assert.Equal(t, tt.expectedRole, claudeResp.Role) + assert.Equal(t, "client-claude", claudeResp.Model) assert.NotEmpty(t, claudeResp.Content) if tt.hasReasoning { @@ -606,3 +610,46 @@ func TestConvertClaudeToolsToOpenAI_WithRequiredField(t *testing.T) { }) } } + +func TestConvertResponsesToClaudeStreamResponseReturnsErrorBeforeRealOutputAfterLifecycleEvents( + t *testing.T, +) { + gin.SetMode(gin.TestMode) + + stream := strings.Join([]string{ + `event: response.created`, + `data: {"type":"response.created","response":{"id":"resp_123","object":"response","created_at":1,"status":"in_progress","model":"gpt-5","output":[],"parallel_tool_calls":true,"store":false}}`, + "", + `event: response.in_progress`, + `data: {"type":"response.in_progress","response":{"id":"resp_123","object":"response","created_at":1,"status":"in_progress","model":"gpt-5","output":[],"parallel_tool_calls":true,"store":false}}`, + "", + `event: error`, + `data: {"type":"error","error":{"type":"server_error","code":"server_error","message":"stream failed"}}`, + "", + }, "\n") + + httpResp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(stream))), + Header: make(http.Header), + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/messages", + nil, + ) + + m := &meta.Meta{ + ActualModel: "gpt-5", + } + + result, err := openai.ConvertResponsesToClaudeStreamResponse(m, c, httpResp) + require.NotNil(t, err) + assert.Equal(t, http.StatusBadGateway, err.StatusCode()) + assert.Equal(t, "resp_123", result.UpstreamID) + assert.Empty(t, w.Body.String()) +} diff --git a/core/relay/adaptor/openai/gemini.go b/core/relay/adaptor/openai/gemini.go index 70dd8f38..32c091d7 100644 --- a/core/relay/adaptor/openai/gemini.go +++ b/core/relay/adaptor/openai/gemini.go @@ -114,7 +114,7 @@ func ConvertOpenAIToGeminiResponse( openaiResp *relaymodel.TextResponse, ) *relaymodel.GeminiChatResponse { geminiResp := &relaymodel.GeminiChatResponse{ - ModelVersion: meta.ActualModel, + ModelVersion: responseModelName(meta), } if openaiResp.Usage.TotalTokens > 0 { @@ -201,7 +201,7 @@ func GeminiStreamHandler( defer resp.Body.Close() - scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.ActualModel) + scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.OriginModel, meta.ActualModel) defer cleanup() usage := model.Usage{} @@ -257,7 +257,7 @@ func (s *GeminiStreamState) ConvertOpenAIStreamToGemini( openaiResp *relaymodel.ChatCompletionsStreamResponse, ) *relaymodel.GeminiChatResponse { geminiResp := &relaymodel.GeminiChatResponse{ - ModelVersion: meta.ActualModel, + ModelVersion: responseModelName(meta), Candidates: []*relaymodel.GeminiChatCandidate{}, } @@ -975,7 +975,7 @@ func ConvertResponsesToGeminiResponse( // Convert to Gemini format geminiResp := relaymodel.GeminiChatResponse{ - ModelVersion: responsesResp.Model, + ModelVersion: responseModelName(meta), Candidates: []*relaymodel.GeminiChatCandidate{}, } @@ -1096,13 +1096,12 @@ func ConvertResponsesToGeminiStreamResponse( log := common.GetLogger(c) - scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.ActualModel) + scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.OriginModel, meta.ActualModel) defer cleanup() var ( - usage model.Usage - responseID string - lastResponse *relaymodel.Response + errorState responsesStreamErrorState + wroteStream bool ) state := &geminiStreamState{ @@ -1110,16 +1109,15 @@ func ConvertResponsesToGeminiStreamResponse( c: c, } - for scanner.Scan() { + stopStream := false + + for scanner.Scan() && !stopStream { data := scanner.Bytes() if !render.IsValidSSEData(data) { continue } data = render.ExtractSSEData(data) - if render.IsSSEDone(data) { - break - } // Parse the stream event var event relaymodel.ResponseStreamEvent @@ -1130,13 +1128,33 @@ func ConvertResponsesToGeminiStreamResponse( continue } - if event.Response != nil { - if responseID == "" { - responseID = event.Response.ID + errorState.update(&event) + + if err := errorState.errorBeforeEvent(&event); err != nil { + return errorState.result(), err + } + + if event.Type == relaymodel.EventResponseFailed || event.Type == relaymodel.EventError { + if wroteStream { + log.Error( + "response stream failed after data was sent: " + responseStreamErrorMessage( + &event, + ), + ) + + stopStream = true + + continue } - lastResponse = event.Response - usage = event.Response.ToModelUsage() + err, handled := errorState.handleFailure(&event) + if handled && err == nil { + continue + } + + if handled { + return errorState.result(), err + } } // Handle events @@ -1146,11 +1164,17 @@ func ConvertResponsesToGeminiStreamResponse( case relaymodel.EventOutputItemAdded: state.handleOutputItemAdded(&event) case relaymodel.EventOutputTextDelta: - state.handleOutputTextDelta(&event) + if state.handleOutputTextDelta(&event) { + wroteStream = true + } case relaymodel.EventFunctionCallArgumentsDone: - state.handleFunctionCallArgumentsDone(&event) + if state.handleFunctionCallArgumentsDone(&event) { + wroteStream = true + } case relaymodel.EventResponseCompleted, relaymodel.EventResponseDone: - state.handleResponseCompleted(&event) + if state.handleResponseCompleted(&event) { + wroteStream = true + } } } @@ -1158,11 +1182,11 @@ func ConvertResponsesToGeminiStreamResponse( log.Error("error reading response stream: " + err.Error()) } - return adaptor.DoResponseResult{ - Usage: usage, - UpstreamID: responseID, - AsyncUsage: responseNeedsAsyncUsage(lastResponse), - }, nil + if errorState.pendingFailure != nil && !wroteStream { + return errorState.result(), responseStreamError(errorState.pendingFailure) + } + + return errorState.result(), nil } // geminiStreamState manages state for Gemini stream conversion @@ -1189,14 +1213,14 @@ func (s *geminiStreamState) handleOutputItemAdded(event *relaymodel.ResponseStre } // handleOutputTextDelta handles response.output_text.delta event for Gemini -func (s *geminiStreamState) handleOutputTextDelta(event *relaymodel.ResponseStreamEvent) { +func (s *geminiStreamState) handleOutputTextDelta(event *relaymodel.ResponseStreamEvent) bool { if event.Delta == "" { - return + return false } // Send text delta geminiResp := relaymodel.GeminiChatResponse{ - ModelVersion: s.meta.ActualModel, + ModelVersion: responseModelName(s.meta), Candidates: []*relaymodel.GeminiChatCandidate{ { Index: 0, @@ -1213,29 +1237,33 @@ func (s *geminiStreamState) handleOutputTextDelta(event *relaymodel.ResponseStre } _ = render.GeminiObjectData(s.c, geminiResp) + + return true } // handleFunctionCallArgumentsDone handles response.function_call_arguments.done event for Gemini -func (s *geminiStreamState) handleFunctionCallArgumentsDone(event *relaymodel.ResponseStreamEvent) { +func (s *geminiStreamState) handleFunctionCallArgumentsDone( + event *relaymodel.ResponseStreamEvent, +) bool { if event.Arguments == "" || event.ItemID == "" { - return + return false } // Get function name from tracked state functionName := s.functionCallNames[event.ItemID] if functionName == "" { - return + return false } // Parse arguments var args map[string]any if err := sonic.UnmarshalString(event.Arguments, &args); err != nil { - return + return false } // Send complete function call geminiResp := relaymodel.GeminiChatResponse{ - ModelVersion: s.meta.ActualModel, + ModelVersion: responseModelName(s.meta), Candidates: []*relaymodel.GeminiChatCandidate{ { Index: 0, @@ -1255,18 +1283,20 @@ func (s *geminiStreamState) handleFunctionCallArgumentsDone(event *relaymodel.Re } _ = render.GeminiObjectData(s.c, geminiResp) + + return true } // handleResponseCompleted handles response.completed/done event for Gemini -func (s *geminiStreamState) handleResponseCompleted(event *relaymodel.ResponseStreamEvent) { +func (s *geminiStreamState) handleResponseCompleted(event *relaymodel.ResponseStreamEvent) bool { if event.Response == nil || event.Response.Usage == nil { - return + return false } // Send final response with usage geminiUsage := event.Response.Usage.ToGeminiUsage() geminiResp := relaymodel.GeminiChatResponse{ - ModelVersion: s.meta.ActualModel, + ModelVersion: responseModelName(s.meta), UsageMetadata: &geminiUsage, Candidates: []*relaymodel.GeminiChatCandidate{ { @@ -1281,4 +1311,6 @@ func (s *geminiStreamState) handleResponseCompleted(event *relaymodel.ResponseSt } _ = render.GeminiObjectData(s.c, geminiResp) + + return true } diff --git a/core/relay/adaptor/openai/gemini_test.go b/core/relay/adaptor/openai/gemini_test.go index dbbecbef..da1c06bb 100644 --- a/core/relay/adaptor/openai/gemini_test.go +++ b/core/relay/adaptor/openai/gemini_test.go @@ -1,16 +1,21 @@ package openai_test import ( + "bytes" "context" "encoding/json" "io" "net/http" + "net/http/httptest" "strings" "testing" + "github.com/gin-gonic/gin" "github.com/labring/aiproxy/core/relay/adaptor/openai" "github.com/labring/aiproxy/core/relay/meta" relaymodel "github.com/labring/aiproxy/core/relay/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestConvertGeminiRequest_MapsThinkingConfigToReasoningEffort(t *testing.T) { @@ -132,6 +137,45 @@ func TestConvertGeminiRequest_MapsThinkingConfigToReasoningEffort(t *testing.T) } } +func TestConvertResponsesToGeminiStreamResponseReturnsErrorAfterUnwrittenFunctionCall( + t *testing.T, +) { + gin.SetMode(gin.TestMode) + + stream := strings.Join([]string{ + `data: {"type":"response.function_call_arguments.done","item_id":"fc_missing","arguments":"{\"query\":\"hello\"}"}`, + "", + `event: error`, + `data: {"type":"error","error":{"type":"server_error","code":"server_error","message":"stream failed"}}`, + "", + }, "\n") + + httpResp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(stream))), + Header: make(http.Header), + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1beta/models/gemini-pro:streamGenerateContent", + nil, + ) + + m := &meta.Meta{ + ActualModel: "gpt-5", + } + + result, err := openai.ConvertResponsesToGeminiStreamResponse(m, c, httpResp) + require.NotNil(t, err) + assert.Equal(t, http.StatusBadGateway, err.StatusCode()) + assert.Empty(t, result.UpstreamID) + assert.Empty(t, w.Body.String()) +} + func TestConvertGeminiRequest_ToolResponse(t *testing.T) { // Reproduce the user's scenario: // 1. Setup message (User) diff --git a/core/relay/adaptor/openai/gemini_to_responses_test.go b/core/relay/adaptor/openai/gemini_to_responses_test.go index 286689dc..d6283170 100644 --- a/core/relay/adaptor/openai/gemini_to_responses_test.go +++ b/core/relay/adaptor/openai/gemini_to_responses_test.go @@ -98,14 +98,14 @@ func TestConvertGeminiToResponsesRequest_WithFunctionCalls(t *testing.T) { t, 4, len(inputArray), - "Should have 4 items: system message, user message, function call, function result", + "Should have 4 items: developer message, user message, function call, function result", ) - // Verify system message + // Verify developer message systemMsg, ok := inputArray[0].(map[string]any) require.True(t, ok) assert.Equal(t, "message", systemMsg["type"]) - assert.Equal(t, "system", systemMsg["role"]) + assert.Equal(t, "developer", systemMsg["role"]) // Verify user message userMsg, ok := inputArray[1].(map[string]any) @@ -470,6 +470,7 @@ func TestConvertResponsesToGeminiResponse(t *testing.T) { c, _ := gin.CreateTestContext(w) m := &meta.Meta{ + OriginModel: "client-gemini", ActualModel: tt.responsesResp.Model, } @@ -484,7 +485,7 @@ func TestConvertResponsesToGeminiResponse(t *testing.T) { require.NoError(t, err) // Verify - assert.Equal(t, tt.responsesResp.Model, geminiResp.ModelVersion) + assert.Equal(t, "client-gemini", geminiResp.ModelVersion) assert.NotEmpty(t, geminiResp.Candidates) if tt.hasReasoning { diff --git a/core/relay/adaptor/openai/image.go b/core/relay/adaptor/openai/image.go index 65d344fd..5eed2472 100644 --- a/core/relay/adaptor/openai/image.go +++ b/core/relay/adaptor/openai/image.go @@ -253,7 +253,7 @@ func ImagesStreamHandler( log := common.GetLogger(c) - scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.ActualModel) + scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.OriginModel, meta.ActualModel) defer cleanup() usage := model.Usage{} diff --git a/core/relay/adaptor/openai/reasoning.go b/core/relay/adaptor/openai/reasoning.go index a77d1eb8..5eb1b065 100644 --- a/core/relay/adaptor/openai/reasoning.go +++ b/core/relay/adaptor/openai/reasoning.go @@ -239,12 +239,12 @@ func openAIReasoningEffortsForModel( actualModel string, ) ([]relaymodel.ReasoningEffort, bool) { return openAIReasoningEffortsForName(utils.FirstMatchingModelName( - originModel, - actualModel, func(modelName string) bool { _, ok := openAIReasoningEffortsForName(modelName) return ok }, + originModel, + actualModel, )) } diff --git a/core/relay/adaptor/openai/response.go b/core/relay/adaptor/openai/response.go index 1a9f26a6..07a2b94c 100644 --- a/core/relay/adaptor/openai/response.go +++ b/core/relay/adaptor/openai/response.go @@ -18,8 +18,11 @@ import ( relaymodel "github.com/labring/aiproxy/core/relay/model" "github.com/labring/aiproxy/core/relay/render" "github.com/labring/aiproxy/core/relay/utils" + "github.com/sirupsen/logrus" ) +var responseStreamInitialBufferTimeout = 2 * time.Second + // ConvertResponseRequest converts a response creation request func ConvertResponseRequest( meta *meta.Meta, @@ -41,6 +44,10 @@ func ConvertResponseRequest( } } + if err := normalizeResponsesInputSystemRole(&node); err != nil { + return adaptor.ConvertResult{}, err + } + // Set the model _, err = node.Set("model", ast.NewString(meta.ActualModel)) if err != nil { @@ -61,6 +68,50 @@ func ConvertResponseRequest( }, nil } +func normalizeResponsesInputSystemRole(node *ast.Node) error { + inputNode := node.Get("input") + if !inputNode.Exists() || inputNode.TypeSafe() != ast.V_ARRAY { + return nil + } + + inputItems, err := inputNode.ArrayUseNode() + if err != nil { + return err + } + + for index, inputItem := range inputItems { + if inputItem.TypeSafe() != ast.V_OBJECT { + continue + } + + roleNode := inputItem.Get("role") + if !roleNode.Exists() || roleNode.TypeSafe() != ast.V_STRING { + continue + } + + role, err := roleNode.String() + if err != nil { + return err + } + + if role != relaymodel.RoleSystem { + continue + } + + _, err = inputItem.Set("role", ast.NewString(relaymodel.RoleDeveloper)) + if err != nil { + return err + } + + _, err = inputNode.SetByIndex(index, inputItem) + if err != nil { + return err + } + } + + return nil +} + // ResponseHandler handles non-streaming response func ResponseHandler( meta *meta.Meta, @@ -111,6 +162,15 @@ func ResponseHandler( } } + responseBody, err = rewriteTopLevelModel(responseBody, responseModelName(meta)) + if err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIError( + err, + "rewrite_response_model_failed", + http.StatusInternalServerError, + ) + } + // Write response c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Length", strconv.Itoa(len(responseBody))) @@ -141,69 +201,355 @@ func ResponseStreamHandler( log := common.GetLogger(c) - scanner, cleanup := utils.NewStreamScanner(resp.Body, meta.ActualModel) - defer cleanup() + done := make(chan struct{}) + defer close(done) + + events := scanResponseStreamEvents(resp.Body, done, meta.OriginModel, meta.ActualModel) var ( - usage model.Usage - responseID string - lastResponse *relaymodel.Response + errorState responsesStreamErrorState + pendingEvents [][]byte + wroteStream bool + bufferTimer *time.Timer ) + defer func() { + stopResponseStreamBufferTimer(bufferTimer) + }() + +readLoop: + for { + var item responseStreamEventItem + + if bufferTimer == nil { + next, ok := <-events + if !ok { + break + } + + item = next + } else { + select { + case next, ok := <-events: + if !ok { + break readLoop + } + + item = next + case <-bufferTimer.C: + bufferTimer = nil + + flushDelayedResponseStreamEvents(c, &pendingEvents, &wroteStream) + continue + } + } - for scanner.Scan() { - data := scanner.Bytes() - if !render.IsValidSSEData(data) { + if item.scanErr != nil { + log.Error("error reading response stream: " + item.scanErr.Error()) continue } - data = render.ExtractSSEData(data) + if item.parseErr != nil { + log.Error("error unmarshalling response stream: " + item.parseErr.Error()) + continue + } - // Parse the stream event - var event relaymodel.ResponseStreamEvent + event := item.event + data := item.data + data = rewriteResponseStreamEventModel(data, &event, responseModelName(meta), log) - err := sonic.Unmarshal(data, &event) - if err != nil { - log.Error("error unmarshalling response stream: " + err.Error()) - continue + if err := errorState.errorBeforeEvent(&event); err != nil { + return errorState.result(), err } // Store response ID if this is the first event with a response - if event.Response != nil && responseID == "" { - responseID = event.Response.ID - if event.Response.Store && responseID != "" { - err = store.SaveStore(adaptor.StoreCache{ - ID: model.ResponseStoreID(responseID), + if event.Response != nil && errorState.responseID == "" { + errorState.responseID = event.Response.ID + if event.Response.Store && errorState.responseID != "" { + saveErr := store.SaveStore(adaptor.StoreCache{ + ID: model.ResponseStoreID(errorState.responseID), GroupID: meta.Group.ID, TokenID: meta.Token.ID, ChannelID: meta.Channel.ID, Model: meta.OriginModel, ExpiresAt: time.Now().Add(time.Hour * 24 * 7), }) - if err != nil { - log.Errorf("save response store failed: %v", err) + if saveErr != nil { + log.Errorf("save response store failed: %v", saveErr) } } } // Update usage if available - if event.Response != nil { - lastResponse = event.Response - usage = event.Response.ToModelUsage() + errorState.update(&event) + + if event.Type == relaymodel.EventResponseFailed || event.Type == relaymodel.EventError { + if wroteStream { + log.Error( + "response stream failed after data was sent: " + responseStreamErrorMessage( + &event, + ), + ) + } else { + err, handled := errorState.handleFailure(&event) + if handled && err == nil { + continue + } + + if handled { + return errorState.result(), err + } + } } + if responseStreamEventCanDelay(event.Type) && !wroteStream { + pendingEvents = append(pendingEvents, append([]byte(nil), data...)) + + if bufferTimer == nil { + bufferTimer = time.NewTimer(responseStreamInitialBufferTimeout) + } + + continue + } + + stopResponseStreamBufferTimer(bufferTimer) + bufferTimer = nil + + flushDelayedResponseStreamEvents(c, &pendingEvents, &wroteStream) + // Forward the event render.ResponsesData(c, data) + + wroteStream = true } - if err := scanner.Err(); err != nil { - log.Error("error reading response stream: " + err.Error()) + if errorState.pendingFailure != nil && !wroteStream { + return errorState.result(), responseStreamError(errorState.pendingFailure) } - return adaptor.DoResponseResult{ - Usage: usage, - UpstreamID: responseID, - AsyncUsage: responseNeedsAsyncUsage(lastResponse), - }, nil + flushDelayedResponseStreamEvents(c, &pendingEvents, &wroteStream) + + return errorState.result(), nil +} + +type responseStreamEventItem struct { + data []byte + event relaymodel.ResponseStreamEvent + parseErr error + scanErr error +} + +func scanResponseStreamEvents( + body io.Reader, + done <-chan struct{}, + modelNames ...string, +) <-chan responseStreamEventItem { + events := make(chan responseStreamEventItem, 16) + + go func() { + defer close(events) + + scanner, cleanup := utils.NewStreamScanner(body, modelNames...) + defer cleanup() + + for scanner.Scan() { + data := scanner.Bytes() + if !render.IsValidSSEData(data) { + continue + } + + data = append([]byte(nil), render.ExtractSSEData(data)...) + + var event relaymodel.ResponseStreamEvent + + parseErr := sonic.Unmarshal(data, &event) + + item := responseStreamEventItem{ + data: data, + event: event, + parseErr: parseErr, + } + + select { + case events <- item: + case <-done: + return + } + } + + if err := scanner.Err(); err != nil { + select { + case events <- responseStreamEventItem{scanErr: err}: + case <-done: + } + } + }() + + return events +} + +func stopResponseStreamBufferTimer(timer *time.Timer) { + if timer == nil { + return + } + + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } +} + +func flushDelayedResponseStreamEvents( + c *gin.Context, + pendingEvents *[][]byte, + wroteStream *bool, +) { + if len(*pendingEvents) == 0 { + return + } + + for _, pendingEvent := range *pendingEvents { + render.ResponsesData(c, pendingEvent) + } + + *pendingEvents = nil + *wroteStream = true +} + +func rewriteResponseStreamEventModel( + data []byte, + event *relaymodel.ResponseStreamEvent, + originModel string, + log *logrus.Entry, +) []byte { + if originModel == "" || event == nil || event.Response == nil || + event.Response.Model == originModel { + return data + } + + rewrittenData, err := rewriteNestedModel(data, originModel, "response", "model") + if err != nil { + log.Error("error rewriting response stream event model: " + err.Error()) + return data + } + + return rewrittenData +} + +func writeResponseObjectWithOriginModel( + meta *meta.Meta, + c *gin.Context, + resp *http.Response, +) (relaymodel.Response, adaptor.Error) { + responseBody, err := common.GetResponseBody(resp) + if err != nil { + return relaymodel.Response{}, relaymodel.WrapperOpenAIError( + err, + "read_response_body_failed", + http.StatusInternalServerError, + ) + } + + var response relaymodel.Response + + err = sonic.Unmarshal(responseBody, &response) + if err != nil { + return relaymodel.Response{}, relaymodel.WrapperOpenAIError( + err, + "unmarshal_response_body_failed", + http.StatusInternalServerError, + ) + } + + responseBody, err = rewriteTopLevelModel(responseBody, responseModelName(meta)) + if err != nil { + return response, relaymodel.WrapperOpenAIError( + err, + "rewrite_response_model_failed", + http.StatusInternalServerError, + ) + } + + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.Header().Set("Content-Length", strconv.Itoa(len(responseBody))) + _, _ = c.Writer.Write(responseBody) + + return response, nil +} + +func rewriteTopLevelModel(data []byte, modelName string) ([]byte, error) { + node, err := common.GetJSONNodeNoCopy(data) + if err != nil { + return nil, err + } + + return rewriteTopLevelModelNode(data, &node, modelName) +} + +func rewriteNestedModel(data []byte, modelName string, path ...string) ([]byte, error) { + if modelName == "" { + return data, nil + } + + node, err := common.GetJSONNodeNoCopy(data) + if err != nil { + return nil, err + } + + if len(path) == 0 { + return rewriteTopLevelModelNode(data, &node, modelName) + } + + parent := &node + for _, key := range path[:len(path)-1] { + next := parent.Get(key) + if !next.Exists() { + return data, nil + } + + parent = next + } + + modelKey := path[len(path)-1] + + modelNode := parent.Get(modelKey) + if !modelNode.Exists() { + return data, nil + } + + if currentModel, err := modelNode.String(); err == nil && currentModel == modelName { + return data, nil + } + + _, err = parent.Set(modelKey, ast.NewString(modelName)) + if err != nil { + return nil, err + } + + return node.MarshalJSON() +} + +func rewriteTopLevelModelNode(data []byte, node *ast.Node, modelName string) ([]byte, error) { + if modelName == "" || node == nil { + return data, nil + } + + modelNode := node.Get("model") + if !modelNode.Exists() { + return data, nil + } + + if currentModel, err := modelNode.String(); err == nil && currentModel == modelName { + return data, nil + } + + _, err := node.Set("model", ast.NewString(modelName)) + if err != nil { + return nil, err + } + + return node.MarshalJSON() } func responseNeedsAsyncUsage(response *relaymodel.Response) bool { @@ -225,7 +571,7 @@ func responseNeedsAsyncUsage(response *relaymodel.Response) bool { // GetResponseHandler handles GET /v1/responses/{response_id} func GetResponseHandler( - _ *meta.Meta, + meta *meta.Meta, c *gin.Context, resp *http.Response, ) (adaptor.DoResponseResult, adaptor.Error) { @@ -235,11 +581,14 @@ func GetResponseHandler( defer resp.Body.Close() - c.Writer.Header().Set("Content-Type", resp.Header.Get("Content-Type")) - c.Writer.Header().Set("Content-Length", resp.Header.Get("Content-Length")) - _, _ = io.Copy(c.Writer, resp.Body) + response, err := writeResponseObjectWithOriginModel(meta, c, resp) + if err != nil { + return adaptor.DoResponseResult{}, err + } - return adaptor.DoResponseResult{}, nil + return adaptor.DoResponseResult{ + UpstreamID: response.ID, + }, nil } // DeleteResponseHandler handles DELETE /v1/responses/{response_id} @@ -264,7 +613,7 @@ func DeleteResponseHandler( // CancelResponseHandler handles POST /v1/responses/{response_id}/cancel func CancelResponseHandler( - _ *meta.Meta, + meta *meta.Meta, c *gin.Context, resp *http.Response, ) (adaptor.DoResponseResult, adaptor.Error) { @@ -274,11 +623,16 @@ func CancelResponseHandler( defer resp.Body.Close() - c.Writer.Header().Set("Content-Type", resp.Header.Get("Content-Type")) - c.Writer.Header().Set("Content-Length", resp.Header.Get("Content-Length")) - _, _ = io.Copy(c.Writer, resp.Body) + response, err := writeResponseObjectWithOriginModel(meta, c, resp) + if err != nil { + return adaptor.DoResponseResult{}, err + } - return adaptor.DoResponseResult{}, nil + return adaptor.DoResponseResult{ + Usage: response.ToModelUsage(), + UpstreamID: response.ID, + AsyncUsage: responseNeedsAsyncUsage(&response), + }, nil } // GetInputItemsHandler handles GET /v1/responses/{response_id}/input_items diff --git a/core/relay/adaptor/openai/response_test.go b/core/relay/adaptor/openai/response_test.go index 46f6e8aa..82bc9f43 100644 --- a/core/relay/adaptor/openai/response_test.go +++ b/core/relay/adaptor/openai/response_test.go @@ -3,15 +3,20 @@ package openai import ( "bytes" + "encoding/json" "io" "net/http" "net/http/httptest" + "strings" + "sync" "testing" + "time" "github.com/gin-gonic/gin" "github.com/labring/aiproxy/core/model" "github.com/labring/aiproxy/core/relay/adaptor" "github.com/labring/aiproxy/core/relay/meta" + relaymodel "github.com/labring/aiproxy/core/relay/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -21,6 +26,8 @@ type responseTestStore struct { savedIfNotExist []adaptor.StoreCache } +var responseStreamInitialBufferTimeoutTestMu sync.Mutex + func (s *responseTestStore) GetStore(string, int, string) (adaptor.StoreCache, error) { return adaptor.StoreCache{}, nil } @@ -143,6 +150,219 @@ func TestResponseStreamHandlerPromptCacheRetention(t *testing.T) { assert.Equal(t, model.ZeroNullInt64(20), result.Usage.TotalTokens) } +func TestResponseStreamHandlerForwardsErrorAfterDownstreamWrite(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/responses", + nil, + ) + + body := "event: response.created\n" + + "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_123\",\"object\":\"response\",\"created_at\":1,\"status\":\"in_progress\",\"model\":\"gpt-5\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false}}\n\n" + + "event: response.output_text.delta\n" + + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"partial\"}\n\n" + + "event: error\n" + + "data: {\"type\":\"error\",\"error\":{\"type\":\"server_error\",\"code\":\"server_error\",\"message\":\"stream failed\"}}\n\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(body)), + Header: make(http.Header), + } + + result, err := ResponseStreamHandler(&meta.Meta{}, &responseTestStore{}, c, resp) + require.Nil(t, err) + assert.Equal(t, "resp_123", result.UpstreamID) + assert.Contains(t, recorder.Body.String(), `"type":"error"`) + assert.Contains(t, recorder.Body.String(), `"message":"stream failed"`) +} + +func TestResponseStreamHandlerReturnsErrorBeforeRealOutputAfterLifecycleEvents(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/responses", + nil, + ) + + body := "event: response.created\n" + + "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_123\",\"object\":\"response\",\"created_at\":1,\"status\":\"in_progress\",\"model\":\"gpt-5\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false}}\n\n" + + "event: response.in_progress\n" + + "data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_123\",\"object\":\"response\",\"created_at\":1,\"status\":\"in_progress\",\"model\":\"gpt-5\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false}}\n\n" + + "event: error\n" + + "data: {\"type\":\"error\",\"error\":{\"type\":\"server_error\",\"code\":\"server_error\",\"message\":\"stream failed\"}}\n\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(body)), + Header: make(http.Header), + } + + result, err := ResponseStreamHandler(&meta.Meta{}, &responseTestStore{}, c, resp) + require.NotNil(t, err) + assert.Equal(t, http.StatusBadGateway, err.StatusCode()) + assert.Equal(t, "resp_123", result.UpstreamID) + assert.Empty(t, recorder.Body.String()) +} + +func TestResponseStreamHandlerFailedWithoutErrorDoesNotMarkAsyncUsage(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/responses", + nil, + ) + + body := "event: response.created\n" + + "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_failed\",\"object\":\"response\",\"created_at\":1,\"status\":\"in_progress\",\"model\":\"gpt-5\",\"output\":[],\"parallel_tool_calls\":true,\"store\":true}}\n\n" + + "event: response.in_progress\n" + + "data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_failed\",\"object\":\"response\",\"created_at\":1,\"status\":\"in_progress\",\"model\":\"gpt-5\",\"output\":[],\"parallel_tool_calls\":true,\"store\":true}}\n\n" + + "event: response.failed\n" + + "data: {\"type\":\"response.failed\",\"response\":{\"id\":\"resp_failed\",\"object\":\"response\",\"created_at\":1,\"status\":\"failed\",\"model\":\"gpt-5\",\"output\":[],\"parallel_tool_calls\":true,\"store\":true}}\n\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(body)), + Header: make(http.Header), + } + + result, err := ResponseStreamHandler(&meta.Meta{}, &responseTestStore{}, c, resp) + require.NotNil(t, err) + assert.Equal(t, http.StatusBadGateway, err.StatusCode()) + assert.Equal(t, "resp_failed", result.UpstreamID) + assert.False(t, result.AsyncUsage) + assert.Empty(t, recorder.Body.String()) +} + +func TestResponseStreamHandlerFlushesLifecycleEventsOnOfficialTextStreamOrder(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/responses", + nil, + ) + + body := strings.Join([]string{ + "event: response.created", + "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_text\",\"object\":\"response\",\"created_at\":1,\"status\":\"in_progress\",\"model\":\"gpt-5.4\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false,\"usage\":null}}", + "", + "event: response.in_progress", + "data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_text\",\"object\":\"response\",\"created_at\":1,\"status\":\"in_progress\",\"model\":\"gpt-5.4\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false,\"usage\":null}}", + "", + "event: response.output_item.added", + "data: {\"type\":\"response.output_item.added\",\"output_index\":0,\"item\":{\"id\":\"msg_text\",\"type\":\"message\",\"status\":\"in_progress\",\"role\":\"assistant\",\"content\":[]}}", + "", + "event: response.content_part.added", + "data: {\"type\":\"response.content_part.added\",\"item_id\":\"msg_text\",\"output_index\":0,\"content_index\":0,\"part\":{\"type\":\"output_text\",\"text\":\"\",\"annotations\":[]}}", + "", + "event: response.output_text.delta", + "data: {\"type\":\"response.output_text.delta\",\"item_id\":\"msg_text\",\"output_index\":0,\"content_index\":0,\"delta\":\"Hi\"}", + "", + "event: response.output_text.done", + "data: {\"type\":\"response.output_text.done\",\"item_id\":\"msg_text\",\"output_index\":0,\"content_index\":0,\"text\":\"Hi\"}", + "", + "event: response.content_part.done", + "data: {\"type\":\"response.content_part.done\",\"item_id\":\"msg_text\",\"output_index\":0,\"content_index\":0,\"part\":{\"type\":\"output_text\",\"text\":\"Hi\",\"annotations\":[]}}", + "", + "event: response.output_item.done", + "data: {\"type\":\"response.output_item.done\",\"output_index\":0,\"item\":{\"id\":\"msg_text\",\"type\":\"message\",\"status\":\"completed\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"Hi\",\"annotations\":[]}]}}", + "", + "event: response.completed", + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_text\",\"object\":\"response\",\"created_at\":1,\"status\":\"completed\",\"model\":\"gpt-5.4\",\"output\":[{\"id\":\"msg_text\",\"type\":\"message\",\"status\":\"completed\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"Hi\",\"annotations\":[]}]}],\"parallel_tool_calls\":true,\"store\":false,\"usage\":{\"input_tokens\":1,\"output_tokens\":1,\"total_tokens\":2}}}", + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(body)), + Header: make(http.Header), + } + + result, err := ResponseStreamHandler(&meta.Meta{}, &responseTestStore{}, c, resp) + require.Nil(t, err) + assert.Equal(t, "resp_text", result.UpstreamID) + assert.Equal(t, model.ZeroNullInt64(2), result.Usage.TotalTokens) + + output := recorder.Body.String() + assert.Contains(t, output, "response.created") + assert.Contains(t, output, "response.in_progress") + assert.Contains(t, output, "response.output_item.added") + assert.Contains(t, output, "response.content_part.added") + assert.Contains(t, output, "response.output_text.delta") + assert.Contains(t, output, "response.completed") +} + +func TestResponseStreamHandlerStartsBufferTimeoutFromFirstDelayedEvent(t *testing.T) { + responseStreamInitialBufferTimeoutTestMu.Lock() + defer responseStreamInitialBufferTimeoutTestMu.Unlock() + + gin.SetMode(gin.TestMode) + + oldTimeout := responseStreamInitialBufferTimeout + responseStreamInitialBufferTimeout = time.Millisecond + t.Cleanup(func() { + responseStreamInitialBufferTimeout = oldTimeout + }) + + reader, writer := io.Pipe() + defer writer.Close() + + go func() { + _, _ = writer.Write([]byte(strings.Join([]string{ + "event: response.in_progress", + "data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_timeout\",\"object\":\"response\",\"created_at\":1,\"status\":\"in_progress\",\"model\":\"gpt-5.4\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false,\"usage\":null}}", + "", + }, "\n"))) + + time.Sleep(20 * time.Millisecond) + + _, _ = writer.Write([]byte(strings.Join([]string{ + "event: response.completed", + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_timeout\",\"object\":\"response\",\"created_at\":1,\"status\":\"completed\",\"model\":\"gpt-5.4\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false,\"usage\":{\"input_tokens\":1,\"output_tokens\":1,\"total_tokens\":2}}}", + "", + }, "\n"))) + _ = writer.Close() + }() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/responses", + nil, + ) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: reader, + Header: make(http.Header), + } + + result, err := ResponseStreamHandler(&meta.Meta{}, &responseTestStore{}, c, resp) + require.Nil(t, err) + assert.Equal(t, "resp_timeout", result.UpstreamID) + assert.Contains(t, recorder.Body.String(), "response.in_progress") + assert.Contains(t, recorder.Body.String(), "response.completed") +} + func TestResponseHandlerWebSearchCountFromToolUsage(t *testing.T) { t.Parallel() gin.SetMode(gin.TestMode) @@ -238,6 +458,212 @@ func TestResponseHandlerStoreUsesOriginModel(t *testing.T) { assert.Equal(t, "resp_store_origin", result.UpstreamID) require.Len(t, store.saved, 1) assert.Equal(t, "gpt-5", store.saved[0].Model) + + var payload relaymodel.Response + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + assert.Equal(t, "gpt-5", payload.Model) + assert.NotContains(t, recorder.Body.String(), "mapped-gpt-5") +} + +func TestResponseHandlerRewritesOnlyModelField(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/responses", nil) + m := &meta.Meta{ + OriginModel: "gpt-5", + ActualModel: "mapped-gpt-5", + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{ + "id":"resp_extra", + "object":"response", + "created_at":1, + "status":"completed", + "model":"mapped-gpt-5", + "output":[], + "parallel_tool_calls":true, + "store":false, + "provider_extra":{"future_field":"kept"}, + "future_top_level":"kept" + }`)), + Header: make(http.Header), + } + + _, err := ResponseHandler(m, &responseTestStore{}, c, resp) + require.Nil(t, err) + assert.Contains(t, recorder.Body.String(), `"model":"gpt-5"`) + assert.Contains(t, recorder.Body.String(), `"provider_extra":{"future_field":"kept"}`) + assert.Contains(t, recorder.Body.String(), `"future_top_level":"kept"`) + assert.NotContains(t, recorder.Body.String(), "mapped-gpt-5") +} + +func TestConvertResponseRequestMapsSystemInputRoleToDeveloper(t *testing.T) { + t.Parallel() + + req := httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/responses", + strings.NewReader(`{ + "model":"gpt-5.5", + "input":[ + {"type":"message","role":"system","content":[{"type":"input_text","text":"Be concise"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"Hello"}]} + ] + }`), + ) + req.Header.Set("Content-Type", "application/json") + + m := &meta.Meta{ + ActualModel: "mapped-gpt-5.5", + } + + result, err := ConvertResponseRequest(m, req) + require.NoError(t, err) + + var body map[string]any + + err = json.NewDecoder(result.Body).Decode(&body) + require.NoError(t, err) + + input, ok := body["input"].([]any) + require.True(t, ok) + require.Len(t, input, 2) + + first, ok := input[0].(map[string]any) + require.True(t, ok) + assert.Equal(t, "developer", first["role"]) + + second, ok := input[1].(map[string]any) + require.True(t, ok) + assert.Equal(t, "user", second["role"]) + assert.Equal(t, "mapped-gpt-5.5", body["model"]) +} + +func TestGetResponseHandlerRewritesModelToOriginModel(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodGet, + "/v1/responses/resp_1", + nil, + ) + m := &meta.Meta{ + OriginModel: "gpt-5", + ActualModel: "mapped-gpt-5", + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{ + "id":"resp_1", + "object":"response", + "created_at":1, + "status":"completed", + "model":"mapped-gpt-5", + "output":[], + "parallel_tool_calls":true, + "store":false + }`)), + Header: make(http.Header), + } + + result, err := GetResponseHandler(m, c, resp) + require.Nil(t, err) + assert.Equal(t, "resp_1", result.UpstreamID) + assert.Zero(t, result.Usage.TotalTokens) + assert.False(t, result.AsyncUsage) + assert.Contains(t, recorder.Body.String(), `"model":"gpt-5"`) + assert.NotContains(t, recorder.Body.String(), "mapped-gpt-5") +} + +func TestGetResponseHandlerDoesNotReportUsageForCompletedResponse(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodGet, + "/v1/responses/resp_1", + nil, + ) + m := &meta.Meta{ + OriginModel: "gpt-5", + ActualModel: "mapped-gpt-5", + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{ + "id":"resp_1", + "object":"response", + "created_at":1, + "status":"completed", + "model":"mapped-gpt-5", + "output":[], + "parallel_tool_calls":true, + "store":true, + "usage":{"input_tokens":7,"output_tokens":13,"total_tokens":20} + }`)), + Header: make(http.Header), + } + + result, err := GetResponseHandler(m, c, resp) + require.Nil(t, err) + assert.Equal(t, "resp_1", result.UpstreamID) + assert.Zero(t, result.Usage.TotalTokens) + assert.False(t, result.AsyncUsage) + assert.Contains(t, recorder.Body.String(), `"usage"`) +} + +func TestCancelResponseHandlerRewritesModelToOriginModel(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/responses/resp_1/cancel", + nil, + ) + m := &meta.Meta{ + OriginModel: "gpt-5", + ActualModel: "mapped-gpt-5", + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{ + "id":"resp_1", + "object":"response", + "created_at":1, + "status":"in_progress", + "model":"mapped-gpt-5", + "output":[], + "parallel_tool_calls":true, + "store":false + }`)), + Header: make(http.Header), + } + + result, err := CancelResponseHandler(m, c, resp) + require.Nil(t, err) + assert.Equal(t, "resp_1", result.UpstreamID) + assert.Contains(t, recorder.Body.String(), `"model":"gpt-5"`) + assert.NotContains(t, recorder.Body.String(), "mapped-gpt-5") } func TestResponseHandlerAsyncUsageInProgress(t *testing.T) { @@ -315,18 +741,26 @@ func TestResponseStreamHandlerForegroundImageGenerationContinuesToCompleted(t *t nil, ) - body := "event: response.created\n" + - "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_generating_async\",\"object\":\"response\",\"created_at\":1,\"status\":\"in_progress\",\"background\":false,\"model\":\"gpt-5.4\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false,\"tool_usage\":{\"image_gen\":{\"input_tokens\":0,\"output_tokens\":0,\"total_tokens\":0},\"web_search\":{\"num_requests\":0}},\"tools\":[{\"type\":\"image_generation\",\"background\":\"auto\",\"model\":\"gpt-image-2\"}],\"usage\":null}}\n\n" + - "event: response.in_progress\n" + - "data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_generating_async\",\"object\":\"response\",\"created_at\":1,\"status\":\"in_progress\",\"background\":false,\"model\":\"gpt-5.4\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false,\"tool_usage\":{\"image_gen\":{\"input_tokens\":0,\"output_tokens\":0,\"total_tokens\":0},\"web_search\":{\"num_requests\":0}},\"tools\":[{\"type\":\"image_generation\",\"background\":\"auto\",\"model\":\"gpt-image-2\"}],\"usage\":null}}\n\n" + - "event: response.output_item.added\n" + - "data: {\"type\":\"response.output_item.added\",\"item\":{\"id\":\"ig_generating_async\",\"type\":\"image_generation_call\",\"status\":\"in_progress\"},\"output_index\":0,\"sequence_number\":2}\n\n" + - "event: response.image_generation_call.generating\n" + - "data: {\"type\":\"response.image_generation_call.generating\",\"item_id\":\"ig_generating_async\",\"output_index\":0,\"sequence_number\":3}\n\n" + - "event: keepalive\n" + - "data: {\"type\":\"keepalive\",\"sequence_number\":4}\n\n" + - "event: response.completed\n" + - "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_generating_async\",\"object\":\"response\",\"created_at\":2,\"status\":\"completed\",\"model\":\"gpt-5.4\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false,\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"total_tokens\":3}}}\n\n" + body := strings.Join([]string{ + "event: response.created", + "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_generating_async\",\"object\":\"response\",\"created_at\":1,\"status\":\"in_progress\",\"background\":false,\"model\":\"gpt-5.4\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false,\"tool_usage\":{\"image_gen\":{\"input_tokens\":0,\"output_tokens\":0,\"total_tokens\":0},\"web_search\":{\"num_requests\":0}},\"tools\":[{\"type\":\"image_generation\",\"background\":\"auto\",\"model\":\"gpt-image-2\"}],\"usage\":null}}", + "", + "event: response.in_progress", + "data: {\"type\":\"response.in_progress\",\"response\":{\"id\":\"resp_generating_async\",\"object\":\"response\",\"created_at\":1,\"status\":\"in_progress\",\"background\":false,\"model\":\"gpt-5.4\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false,\"tool_usage\":{\"image_gen\":{\"input_tokens\":0,\"output_tokens\":0,\"total_tokens\":0},\"web_search\":{\"num_requests\":0}},\"tools\":[{\"type\":\"image_generation\",\"background\":\"auto\",\"model\":\"gpt-image-2\"}],\"usage\":null}}", + "", + "event: response.output_item.added", + "data: {\"type\":\"response.output_item.added\",\"item\":{\"id\":\"ig_generating_async\",\"type\":\"image_generation_call\",\"status\":\"in_progress\"},\"output_index\":0,\"sequence_number\":2}", + "", + "event: response.image_generation_call.generating", + "data: {\"type\":\"response.image_generation_call.generating\",\"item_id\":\"ig_generating_async\",\"output_index\":0,\"sequence_number\":3}", + "", + "event: " + relaymodel.EventKeepAlive, + "data: {\"type\":\"" + relaymodel.EventKeepAlive + "\",\"sequence_number\":4}", + "", + "event: response.completed", + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_generating_async\",\"object\":\"response\",\"created_at\":2,\"status\":\"completed\",\"model\":\"gpt-5.4\",\"output\":[],\"parallel_tool_calls\":true,\"store\":false,\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"total_tokens\":3}}}", + "", + }, "\n") resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewBufferString(body)), @@ -450,6 +884,43 @@ func TestResponseStreamHandlerStoreUsesOriginModel(t *testing.T) { assert.Equal(t, "resp_stream_store_origin", result.UpstreamID) require.Len(t, store.saved, 1) assert.Equal(t, "gpt-5", store.saved[0].Model) + assert.Contains(t, recorder.Body.String(), `"model":"gpt-5"`) + assert.NotContains(t, recorder.Body.String(), "mapped-gpt-5") +} + +func TestResponseStreamHandlerRewritesOnlyResponseModelField(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/responses", nil) + m := &meta.Meta{ + OriginModel: "gpt-5", + ActualModel: "mapped-gpt-5", + } + + body := strings.Join([]string{ + "event: response.created", + `data: {"type":"response.created","response":{"id":"resp_stream_extra","object":"response","created_at":1,"status":"in_progress","model":"mapped-gpt-5","output":[],"parallel_tool_calls":true,"store":false,"provider_extra":{"future_field":"kept"},"future_response_field":"kept"},"future_event_field":"kept"}`, + "", + "event: response.output_text.delta", + `data: {"type":"response.output_text.delta","delta":"hi","future_event_field":"kept"}`, + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(body)), + Header: make(http.Header), + } + + _, err := ResponseStreamHandler(m, &responseTestStore{}, c, resp) + require.Nil(t, err) + assert.Contains(t, recorder.Body.String(), `"model":"gpt-5"`) + assert.Contains(t, recorder.Body.String(), `"provider_extra":{"future_field":"kept"}`) + assert.Contains(t, recorder.Body.String(), `"future_response_field":"kept"`) + assert.Contains(t, recorder.Body.String(), `"future_event_field":"kept"`) + assert.NotContains(t, recorder.Body.String(), "mapped-gpt-5") } func TestVideoHandlerMarksAsyncUsage(t *testing.T) { @@ -467,7 +938,7 @@ func TestVideoHandlerMarksAsyncUsage(t *testing.T) { m := &meta.Meta{ OriginModel: "sora-2", - ActualModel: "sora-2", + ActualModel: "mapped-sora-2", Group: model.GroupCache{ID: "group-1"}, Token: model.TokenCache{ID: 7}, Channel: meta.ChannelMeta{ID: 9}, @@ -478,7 +949,7 @@ func TestVideoHandlerMarksAsyncUsage(t *testing.T) { "id":"video_job_async", "object":"video.generation.job", "status":"queued", - "model":"sora-2" + "model":"mapped-sora-2" }`)), Header: make(http.Header), } @@ -487,6 +958,8 @@ func TestVideoHandlerMarksAsyncUsage(t *testing.T) { require.Nil(t, err) assert.Equal(t, "video_job_async", result.UpstreamID) assert.True(t, result.AsyncUsage) + assert.Contains(t, recorder.Body.String(), `"model":"sora-2"`) + assert.NotContains(t, recorder.Body.String(), "mapped-sora-2") } func TestVideosHandlerStoresVideoAndMarksAsyncUsage(t *testing.T) { @@ -504,7 +977,7 @@ func TestVideosHandlerStoresVideoAndMarksAsyncUsage(t *testing.T) { m := &meta.Meta{ OriginModel: "sora-2", - ActualModel: "sora-2", + ActualModel: "mapped-sora-2", Group: model.GroupCache{ID: "group-1"}, Token: model.TokenCache{ID: 7}, Channel: meta.ChannelMeta{ID: 9}, @@ -516,7 +989,7 @@ func TestVideosHandlerStoresVideoAndMarksAsyncUsage(t *testing.T) { "id":"video_async", "object":"video", "status":"queued", - "model":"sora-2" + "model":"mapped-sora-2" }`)), Header: make(http.Header), } @@ -527,4 +1000,79 @@ func TestVideosHandlerStoresVideoAndMarksAsyncUsage(t *testing.T) { assert.True(t, result.AsyncUsage) require.Len(t, store.saved, 1) assert.Equal(t, model.VideoGenerationStoreID("video_async"), store.saved[0].ID) + assert.Contains(t, recorder.Body.String(), `"model":"sora-2"`) + assert.NotContains(t, recorder.Body.String(), "mapped-sora-2") +} + +func TestVideoGetJobsHandlerRewritesModelToOriginModel(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodGet, + "/v1/video/generations/jobs/job_1", + nil, + ) + + m := &meta.Meta{ + OriginModel: "sora-2", + ActualModel: "mapped-sora-2", + Group: model.GroupCache{ID: "group-1"}, + Token: model.TokenCache{ID: 7}, + Channel: meta.ChannelMeta{ID: 9}, + } + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{ + "id":"job_1", + "object":"video.generation.job", + "status":"succeeded", + "model":"mapped-sora-2", + "expires_at":1780000000, + "generations":[{"id":"gen_1","object":"video.generation","job_id":"job_1"}] + }`)), + Header: make(http.Header), + } + + _, err := VideoGetJobsHandler(m, &responseTestStore{}, c, resp) + require.Nil(t, err) + assert.Contains(t, recorder.Body.String(), `"model":"sora-2"`) + assert.NotContains(t, recorder.Body.String(), "mapped-sora-2") +} + +func TestVideosGetHandlerRewritesModelToOriginModel(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequestWithContext( + t.Context(), + http.MethodGet, + "/v1/videos/video_1", + nil, + ) + + m := &meta.Meta{ + OriginModel: "sora-2", + ActualModel: "mapped-sora-2", + } + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(`{ + "id":"video_1", + "object":"video", + "status":"completed", + "model":"mapped-sora-2" + }`)), + Header: make(http.Header), + } + + _, err := VideosGetHandler(m, c, resp) + require.Nil(t, err) + assert.Contains(t, recorder.Body.String(), `"model":"sora-2"`) + assert.NotContains(t, recorder.Body.String(), "mapped-sora-2") } diff --git a/core/relay/adaptor/openai/video.go b/core/relay/adaptor/openai/video.go index 7982fa12..19269f3a 100644 --- a/core/relay/adaptor/openai/video.go +++ b/core/relay/adaptor/openai/video.go @@ -174,7 +174,7 @@ func VideoHandler( ) } - idNode, err := common.GetJSONNodeNoCopy(responseBody, "id") + node, err := common.GetJSONNodeNoCopy(responseBody) if err != nil { return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( err, @@ -182,6 +182,8 @@ func VideoHandler( ) } + idNode := node.Get("id") + id, err := idNode.String() if err != nil { return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( @@ -203,6 +205,14 @@ func VideoHandler( log.Errorf("save store failed: %v", err) } + responseBody, err = rewriteTopLevelModelNode(responseBody, &node, responseModelName(meta)) + if err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( + err, + http.StatusInternalServerError, + ) + } + c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Length", strconv.Itoa(len(responseBody))) _, _ = c.Writer.Write(responseBody) @@ -269,7 +279,7 @@ func videosHandler( ) } - idNode, err := common.GetJSONNodeNoCopy(responseBody, "id") + node, err := common.GetJSONNodeNoCopy(responseBody) if err != nil { return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( err, @@ -277,6 +287,8 @@ func videosHandler( ) } + idNode := node.Get("id") + id, err := idNode.String() if err != nil { return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( @@ -290,6 +302,14 @@ func videosHandler( log.Errorf("save video store failed: %v", err) } + responseBody, err = rewriteTopLevelModelNode(responseBody, &node, responseModelName(meta)) + if err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( + err, + http.StatusInternalServerError, + ) + } + c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Length", strconv.Itoa(len(responseBody))) _, _ = c.Writer.Write(responseBody) @@ -390,6 +410,14 @@ func VideoGetJobsHandler( ) } + responseBody, err = rewriteTopLevelModelNode(responseBody, &node, responseModelName(meta)) + if err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( + err, + http.StatusInternalServerError, + ) + } + c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Length", strconv.Itoa(len(responseBody))) _, _ = c.Writer.Write(responseBody) @@ -421,7 +449,7 @@ func VideosGetHandler( c *gin.Context, resp *http.Response, ) (adaptor.DoResponseResult, adaptor.Error) { - return videoObjectHandler(c, resp) + return videoObjectHandler(meta, c, resp) } func VideosContentHandler( @@ -433,6 +461,7 @@ func VideosContentHandler( } func videoObjectHandler( + meta *meta.Meta, c *gin.Context, resp *http.Response, ) (adaptor.DoResponseResult, adaptor.Error) { @@ -450,6 +479,14 @@ func videoObjectHandler( ) } + responseBody, err = rewriteTopLevelModel(responseBody, responseModelName(meta)) + if err != nil { + return adaptor.DoResponseResult{}, relaymodel.WrapperOpenAIVideoError( + err, + http.StatusInternalServerError, + ) + } + c.Writer.Header().Set("Content-Type", firstNonEmptyString( resp.Header.Get("Content-Type"), "application/json", diff --git a/core/relay/adaptor/qianfan/adaptor.go b/core/relay/adaptor/qianfan/adaptor.go index 461ba22c..d009d493 100644 --- a/core/relay/adaptor/qianfan/adaptor.go +++ b/core/relay/adaptor/qianfan/adaptor.go @@ -102,7 +102,7 @@ func qianfanModelMatches(mt *meta.Meta, match func(string) bool) bool { return false } - return utils.FirstMatchingModelName(mt.OriginModel, mt.ActualModel, match) != "" + return utils.FirstMatchingModelName(match, mt.OriginModel, mt.ActualModel) != "" } func isBuiltinResponsesModel(modelName string) bool { diff --git a/core/relay/adaptor/vertexai/adaptor.go b/core/relay/adaptor/vertexai/adaptor.go index d7cf9350..1578ef3a 100644 --- a/core/relay/adaptor/vertexai/adaptor.go +++ b/core/relay/adaptor/vertexai/adaptor.go @@ -71,12 +71,12 @@ func resolveFeatureModel(meta *meta.Meta) string { } if modelName := utils.FirstMatchingModelName( - meta.OriginModel, - meta.ActualModel, func(modelName string) bool { modelName = strings.ToLower(modelName) return strings.Contains(modelName, "gemini") || strings.Contains(modelName, "claude") }, + meta.OriginModel, + meta.ActualModel, ); modelName != "" { return modelName } diff --git a/core/relay/model/constant.go b/core/relay/model/constant.go index 5c8091ce..4d79e23f 100644 --- a/core/relay/model/constant.go +++ b/core/relay/model/constant.go @@ -3,6 +3,7 @@ package model // Common Role constants (used across different API formats) const ( RoleSystem = "system" + RoleDeveloper = "developer" RoleUser = "user" RoleAssistant = "assistant" RoleTool = "tool" diff --git a/core/relay/model/response.go b/core/relay/model/response.go index c89c5d76..3f272596 100644 --- a/core/relay/model/response.go +++ b/core/relay/model/response.go @@ -119,6 +119,9 @@ const ( // Error event EventError ResponseStreamEventType = "error" + + // Compatibility event + EventKeepAlive ResponseStreamEventType = "keepalive" ) // ResponseError represents an error in a response @@ -323,6 +326,7 @@ type InputItemList struct { type ResponseStreamEvent struct { Type string `json:"type"` Response *Response `json:"response,omitempty"` + Error *OpenAIError `json:"error,omitempty"` OutputIndex *int `json:"output_index,omitempty"` Item *OutputItem `json:"item,omitempty"` ItemID string `json:"item_id,omitempty"` diff --git a/core/relay/utils/reasoning.go b/core/relay/utils/reasoning.go index 8d82fa07..8a82cca6 100644 --- a/core/relay/utils/reasoning.go +++ b/core/relay/utils/reasoning.go @@ -677,20 +677,37 @@ func PreferredModelName(originModel, actualModel string) string { } func FirstMatchingModelName( - originModel string, - actualModel string, match func(string) bool, + modelNames ...string, ) string { if match == nil { - return PreferredModelName(originModel, actualModel) - } + if len(modelNames) == 0 { + return "" + } - if originModel != "" && match(originModel) { - return originModel + for _, modelName := range modelNames { + if modelName != "" { + return modelName + } + } + + return "" } - if actualModel != "" && actualModel != originModel && match(actualModel) { - return actualModel + seen := map[string]struct{}{} + for _, modelName := range modelNames { + if modelName == "" { + continue + } + + if _, ok := seen[modelName]; ok { + continue + } + + seen[modelName] = struct{}{} + if match(modelName) { + return modelName + } } return "" @@ -739,9 +756,9 @@ func ClampReasoningBudget(maxTokens *int, budget int) int { } func resolveGeminiModelName(originModel, actualModel string) string { - if modelName := FirstMatchingModelName(originModel, actualModel, func(modelName string) bool { + if modelName := FirstMatchingModelName(func(modelName string) bool { return strings.Contains(strings.ToLower(modelName), "gemini") - }); modelName != "" { + }, originModel, actualModel); modelName != "" { return modelName } @@ -749,14 +766,14 @@ func resolveGeminiModelName(originModel, actualModel string) string { } func resolveAliModelName(originModel, actualModel string) string { - if modelName := FirstMatchingModelName(originModel, actualModel, func(modelName string) bool { + if modelName := FirstMatchingModelName(func(modelName string) bool { modelName = strings.ToLower(modelName) return strings.HasPrefix(modelName, "qwen") || strings.HasPrefix(modelName, "qwq-") || strings.Contains(modelName, "glm") || strings.Contains(modelName, "kimi") - }); modelName != "" { + }, originModel, actualModel); modelName != "" { return modelName } diff --git a/core/relay/utils/utils.go b/core/relay/utils/utils.go index dcee836e..614e0efd 100644 --- a/core/relay/utils/utils.go +++ b/core/relay/utils/utils.go @@ -277,10 +277,10 @@ func IsImageModel(modelName string) bool { // NewStreamScanner creates a bufio.Scanner with appropriate buffer size based on model type. // Returns the scanner and a cleanup function that must be called when done. -func NewStreamScanner(r io.Reader, modelName string) (*bufio.Scanner, func()) { +func NewStreamScanner(r io.Reader, modelNames ...string) (*bufio.Scanner, func()) { scanner := bufio.NewScanner(r) - if IsImageModel(modelName) { + if FirstMatchingModelName(IsImageModel, modelNames...) != "" { buf := GetImageScannerBuffer() scanner.Buffer(*buf, cap(*buf)) diff --git a/core/relay/utils/utils_test.go b/core/relay/utils/utils_test.go index 308334db..cc0c6273 100644 --- a/core/relay/utils/utils_test.go +++ b/core/relay/utils/utils_test.go @@ -54,6 +54,29 @@ func TestScannerBuffer(t *testing.T) { }) } +func TestNewStreamScannerUsesImageBufferForAnyMappedModel(t *testing.T) { + convey.Convey( + "NewStreamScanner should use image buffer when origin or actual model is image", + t, + func() { + largeLine := bytes.Repeat([]byte("x"), utils.ScannerBufferSize+1) + lineLength := len(largeLine) + largeLine = append(largeLine, '\n') + + scanner, cleanup := utils.NewStreamScanner( + bytes.NewReader(largeLine), + "gpt-image-1", + "mapped-chat-model", + ) + defer cleanup() + + convey.So(scanner.Scan(), convey.ShouldBeTrue) + convey.So(len(scanner.Bytes()), convey.ShouldEqual, lineLength) + convey.So(scanner.Err(), convey.ShouldBeNil) + }, + ) +} + func TestDoRequest(t *testing.T) { convey.Convey("DoRequest", t, func() { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {