diff --git a/core/relay/adaptor/openai/chat.go b/core/relay/adaptor/openai/chat.go index 37a73bd9..874110c8 100644 --- a/core/relay/adaptor/openai/chat.go +++ b/core/relay/adaptor/openai/chat.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "net/http" + "slices" "strconv" "strings" "time" @@ -804,6 +805,148 @@ func ConvertToolsToResponseTools(tools []relaymodel.Tool) []relaymodel.ResponseT return responseTools } +func convertLegacyFunctionsToResponseTools(functions any) []relaymodel.ResponseTool { + functionList, ok := functions.([]any) + if !ok || len(functionList) == 0 { + return nil + } + + responseTools := make([]relaymodel.ResponseTool, 0, len(functionList)) + for _, function := range functionList { + functionMap, ok := function.(map[string]any) + if !ok { + continue + } + + name, _ := functionMap["name"].(string) + if name == "" { + continue + } + + description, _ := functionMap["description"].(string) + + responseTools = append(responseTools, relaymodel.ResponseTool{ + Type: relaymodel.ToolChoiceTypeFunction, + Name: name, + Description: description, + Parameters: CleanToolParameters(functionMap["parameters"]), + }) + } + + return responseTools +} + +func convertLegacyFunctionCallToResponseToolChoice(functionCall any) any { + switch value := functionCall.(type) { + case string: + switch value { + case "auto", "none": + return value + default: + return nil + } + case map[string]any: + name, _ := value["name"].(string) + if name == "" { + return nil + } + + return map[string]any{ + "type": "function", + "name": name, + } + default: + return nil + } +} + +func convertChatToolChoiceToResponseToolChoice(toolChoice any) any { + toolChoiceMap, ok := toolChoice.(map[string]any) + if !ok { + return toolChoice + } + + toolType, _ := toolChoiceMap["type"].(string) + + functionMap, ok := toolChoiceMap["function"].(map[string]any) + if !ok || toolType != relaymodel.ToolChoiceTypeFunction { + return toolChoice + } + + name, _ := functionMap["name"].(string) + if name == "" { + return toolChoice + } + + return map[string]any{ + "type": relaymodel.ToolChoiceTypeFunction, + "name": name, + } +} + +func convertChatResponseFormatToResponseText( + responseFormat *relaymodel.ResponseFormat, +) *relaymodel.ResponseText { + if responseFormat == nil || responseFormat.Type == "" { + return nil + } + + format := relaymodel.ResponseTextFormat{ + Type: responseFormat.Type, + } + + if responseFormat.JSONSchema != nil { + format.Name = responseFormat.JSONSchema.Name + format.Schema = responseFormat.JSONSchema.Schema + format.Strict = responseFormat.JSONSchema.Strict + format.Description = responseFormat.JSONSchema.Description + } + + return &relaymodel.ResponseText{Format: format} +} + +func appendUniqueString(values []string, value string) []string { + if slices.Contains(values, value) { + return values + } + + return append(values, value) +} + +func appendChatContentPartToResponseInput( + inputItem *relaymodel.InputItem, + contentType relaymodel.InputContentType, + part map[string]any, +) { + partType, _ := part["type"].(string) + switch partType { + case relaymodel.ContentTypeText: + text, _ := part["text"].(string) + if text == "" { + return + } + + inputItem.Content = append(inputItem.Content, relaymodel.InputContent{ + Type: contentType, + Text: text, + }) + case relaymodel.ContentTypeImageURL: + imageURL, _ := part["image_url"].(map[string]any) + + url, _ := imageURL["url"].(string) + if url == "" { + return + } + + detail, _ := imageURL["detail"].(string) + inputItem.Content = append(inputItem.Content, relaymodel.InputContent{ + Type: "input_image", + ImageURL: url, + Detail: detail, + }) + } +} + // ConvertMessagesToInputItems converts Message array to InputItem array for Responses API func ConvertMessagesToInputItems(messages []relaymodel.Message) []relaymodel.InputItem { inputItems := make([]relaymodel.InputItem, 0, len(messages)) @@ -913,14 +1056,7 @@ func ConvertMessagesToInputItems(messages []relaymodel.Message) []relaymodel.Inp // Array of content parts (multimodal) for _, part := range content { if partMap, ok := part.(map[string]any); ok { - if partType, ok := partMap["type"].(string); ok && partType == "text" { - if text, ok := partMap["text"].(string); ok { - inputItem.Content = append(inputItem.Content, relaymodel.InputContent{ - Type: contentType, - Text: text, - }) - } - } + appendChatContentPartToResponseInput(&inputItem, contentType, partMap) } } } @@ -963,6 +1099,21 @@ func ConvertChatCompletionToResponsesRequest( responsesReq.TopP = chatReq.TopP } + if chatReq.ResponseFormat != nil { + responsesReq.Text = convertChatResponseFormatToResponseText(chatReq.ResponseFormat) + } + + if chatReq.TopLogprobs != nil { + responsesReq.TopLogprobs = chatReq.TopLogprobs + } + + if chatReq.Logprobs != nil && *chatReq.Logprobs { + responsesReq.Include = appendUniqueString( + responsesReq.Include, + "message.output_text.logprobs", + ) + } + if chatReq.MaxTokens > 0 { responsesReq.MaxOutputTokens = &chatReq.MaxTokens } else if chatReq.MaxCompletionTokens > 0 { @@ -972,10 +1123,20 @@ func ConvertChatCompletionToResponsesRequest( // Map tools if len(chatReq.Tools) > 0 { responsesReq.Tools = ConvertToolsToResponseTools(chatReq.Tools) + } else if chatReq.Functions != nil { + responsesReq.Tools = convertLegacyFunctionsToResponseTools(chatReq.Functions) } if chatReq.ToolChoice != nil { - responsesReq.ToolChoice = chatReq.ToolChoice + responsesReq.ToolChoice = convertChatToolChoiceToResponseToolChoice(chatReq.ToolChoice) + } else if chatReq.FunctionCall != nil { + responsesReq.ToolChoice = convertLegacyFunctionCallToResponseToolChoice( + chatReq.FunctionCall, + ) + } + + if chatReq.ParallelToolCalls != nil { + responsesReq.ParallelToolCalls = chatReq.ParallelToolCalls } // Map service tier diff --git a/core/relay/adaptor/openai/chat_test.go b/core/relay/adaptor/openai/chat_test.go index ee80f330..5f5131e7 100644 --- a/core/relay/adaptor/openai/chat_test.go +++ b/core/relay/adaptor/openai/chat_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "io" "net/http" "net/http/httptest" "strings" @@ -338,6 +339,184 @@ func TestConvertChatCompletionToResponsesRequest(t *testing.T) { assert.Equal(t, "auto", responsesReq.ToolChoice) }, }, + { + name: "request with named tool choice", + inputRequest: relaymodel.GeneralOpenAIRequest{ + Model: "gpt-5-codex", + Messages: []relaymodel.Message{ + {Role: "user", Content: "What's the weather?"}, + }, + Tools: []relaymodel.Tool{ + { + Type: "function", + Function: relaymodel.Function{ + Name: "get_weather", + }, + }, + }, + ToolChoice: map[string]any{ + "type": "function", + "function": map[string]any{"name": "get_weather"}, + }, + }, + checkFunc: func(t *testing.T, responsesReq relaymodel.CreateResponseRequest) { + t.Helper() + assert.Equal( + t, + map[string]any{"type": "function", "name": "get_weather"}, + responsesReq.ToolChoice, + ) + }, + }, + { + name: "request with legacy functions", + inputRequest: relaymodel.GeneralOpenAIRequest{ + Model: "gpt-5-codex", + Messages: []relaymodel.Message{ + {Role: "user", Content: "What's the weather?"}, + }, + Functions: []any{ + map[string]any{ + "name": "get_weather", + "description": "Get weather information", + "parameters": map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{ + "type": "string", + }, + }, + }, + }, + }, + FunctionCall: map[string]any{"name": "get_weather"}, + }, + checkFunc: func(t *testing.T, responsesReq relaymodel.CreateResponseRequest) { + t.Helper() + require.Len(t, responsesReq.Tools, 1) + assert.Equal(t, "function", responsesReq.Tools[0].Type) + assert.Equal(t, "get_weather", responsesReq.Tools[0].Name) + assert.Equal( + t, + map[string]any{"type": "function", "name": "get_weather"}, + responsesReq.ToolChoice, + ) + }, + }, + { + name: "request with image content", + inputRequest: relaymodel.GeneralOpenAIRequest{ + Model: "gpt-5.5", + Messages: []relaymodel.Message{ + { + Role: "user", + Content: []any{ + map[string]any{ + "type": "text", + "text": "What's in this image?", + }, + map[string]any{ + "type": "image_url", + "image_url": map[string]any{ + "url": "https://example.com/image.png", + "detail": "high", + }, + }, + }, + }, + }, + }, + checkFunc: func(t *testing.T, responsesReq relaymodel.CreateResponseRequest) { + t.Helper() + + inputItems, ok := responsesReq.Input.([]any) + require.True(t, ok) + require.Len(t, inputItems, 1) + + userItem, ok := inputItems[0].(map[string]any) + require.True(t, ok) + content, ok := userItem["content"].([]any) + require.True(t, ok) + require.Len(t, content, 2) + + imageContent, ok := content[1].(map[string]any) + require.True(t, ok) + assert.Equal(t, "input_image", imageContent["type"]) + assert.Equal(t, "https://example.com/image.png", imageContent["image_url"]) + assert.Equal(t, "high", imageContent["detail"]) + }, + }, + { + name: "request with response format", + inputRequest: relaymodel.GeneralOpenAIRequest{ + Model: "gpt-5-codex", + Messages: []relaymodel.Message{ + {Role: "user", Content: "Return JSON"}, + }, + ResponseFormat: &relaymodel.ResponseFormat{ + Type: "json_schema", + JSONSchema: &relaymodel.JSONSchema{ + Name: "answer", + Schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "answer": map[string]any{"type": "string"}, + }, + }, + }, + }, + }, + checkFunc: func(t *testing.T, responsesReq relaymodel.CreateResponseRequest) { + t.Helper() + require.NotNil(t, responsesReq.Text) + assert.Equal(t, "json_schema", responsesReq.Text.Format.Type) + assert.Equal(t, "answer", responsesReq.Text.Format.Name) + assert.Equal(t, map[string]any{ + "type": "object", + "properties": map[string]any{ + "answer": map[string]any{"type": "string"}, + }, + }, responsesReq.Text.Format.Schema) + }, + }, + { + name: "request with logprobs", + inputRequest: relaymodel.GeneralOpenAIRequest{ + Model: "gpt-5-codex", + Messages: []relaymodel.Message{ + {Role: "user", Content: "Hello"}, + }, + Logprobs: func() *bool { + value := true + return &value + }(), + TopLogprobs: func() *int { + value := 3 + return &value + }(), + }, + checkFunc: func(t *testing.T, responsesReq relaymodel.CreateResponseRequest) { + t.Helper() + require.NotNil(t, responsesReq.TopLogprobs) + assert.Equal(t, 3, *responsesReq.TopLogprobs) + assert.Contains(t, responsesReq.Include, "message.output_text.logprobs") + }, + }, + { + name: "request with parallel tool calls", + inputRequest: relaymodel.GeneralOpenAIRequest{ + Model: "gpt-5-codex", + Messages: []relaymodel.Message{ + {Role: "user", Content: "Hello"}, + }, + ParallelToolCalls: new(bool), + }, + checkFunc: func(t *testing.T, responsesReq relaymodel.CreateResponseRequest) { + t.Helper() + require.NotNil(t, responsesReq.ParallelToolCalls) + assert.False(t, *responsesReq.ParallelToolCalls) + }, + }, { name: "request with service tier", inputRequest: relaymodel.GeneralOpenAIRequest{ @@ -417,6 +596,105 @@ func TestConvertChatCompletionToResponsesRequest(t *testing.T) { } } +func TestConvertChatCompletionToResponsesRequestAcceptsMultipleChoices(t *testing.T) { + inputRequest := relaymodel.GeneralOpenAIRequest{ + Model: "gpt-5-codex", + Messages: []relaymodel.Message{ + {Role: "user", Content: "Hello"}, + }, + N: 2, + } + reqBody, err := json.Marshal(inputRequest) + require.NoError(t, err) + + req, _ := http.NewRequestWithContext(context.Background(), + http.MethodPost, + "/v1/chat/completions", + bytes.NewReader(reqBody), + ) + req.Header.Set("Content-Type", "application/json") + + m := &meta.Meta{ + ActualModel: inputRequest.Model, + } + + result, err := openai.ConvertChatCompletionToResponsesRequest(m, req) + require.NoError(t, err) + + body, err := io.ReadAll(result.Body) + require.NoError(t, err) + + assert.NotContains(t, string(body), `"n"`) + + var responsesReq relaymodel.CreateResponseRequest + + err = json.Unmarshal(body, &responsesReq) + require.NoError(t, err) + + assert.Equal(t, inputRequest.Model, responsesReq.Model) + assert.Equal(t, false, *responsesReq.Store) +} + +func TestConvertChatCompletionToResponsesRequestFlattensJSONSchemaTextFormat(t *testing.T) { + strict := true + inputRequest := relaymodel.GeneralOpenAIRequest{ + Model: "gpt-5-codex", + Messages: []relaymodel.Message{ + {Role: "user", Content: "Return JSON"}, + }, + ResponseFormat: &relaymodel.ResponseFormat{ + Type: "json_schema", + JSONSchema: &relaymodel.JSONSchema{ + Name: "answer", + Description: "Answer payload", + Strict: &strict, + Schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "answer": map[string]any{"type": "string"}, + }, + "required": []any{"answer"}, + }, + }, + }, + } + reqBody, err := json.Marshal(inputRequest) + require.NoError(t, err) + + req, _ := http.NewRequestWithContext(context.Background(), + http.MethodPost, + "/v1/chat/completions", + bytes.NewReader(reqBody), + ) + req.Header.Set("Content-Type", "application/json") + + result, err := openai.ConvertChatCompletionToResponsesRequest(&meta.Meta{ + ActualModel: inputRequest.Model, + }, req) + require.NoError(t, err) + + body, err := io.ReadAll(result.Body) + require.NoError(t, err) + + assert.NotContains(t, string(body), `"json_schema":`) + + var raw map[string]any + + err = json.Unmarshal(body, &raw) + require.NoError(t, err) + + text, ok := raw["text"].(map[string]any) + require.True(t, ok) + format, ok := text["format"].(map[string]any) + require.True(t, ok) + + assert.Equal(t, "json_schema", format["type"]) + assert.Equal(t, "answer", format["name"]) + assert.Equal(t, "Answer payload", format["description"]) + assert.Equal(t, true, format["strict"]) + assert.NotNil(t, format["schema"]) +} + func TestConvertResponsesToChatCompletionResponse(t *testing.T) { tests := []struct { name string diff --git a/core/relay/model/completions.go b/core/relay/model/completions.go index 031063e7..67bcf114 100644 --- a/core/relay/model/completions.go +++ b/core/relay/model/completions.go @@ -58,12 +58,15 @@ type GeneralOpenAIRequest struct { Size string `json:"size,omitempty"` Messages []Message `json:"messages,omitempty"` Tools []Tool `json:"tools,omitempty"` + Modalities []string `json:"modalities,omitempty"` Seed float64 `json:"seed,omitempty"` + N int `json:"n,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` TopK int `json:"top_k,omitempty"` NumCtx int `json:"num_ctx,omitempty"` Stream bool `json:"stream,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` ReasoningEffort *string `json:"reasoning_effort,omitempty"` EnableThinking *bool `json:"enable_thinking,omitempty"` ThinkingBudget *int `json:"thinking_budget,omitempty"` diff --git a/core/relay/model/response.go b/core/relay/model/response.go index 3f272596..608d8d65 100644 --- a/core/relay/model/response.go +++ b/core/relay/model/response.go @@ -157,7 +157,11 @@ type ResponseReasoning struct { // ResponseTextFormat represents text format configuration type ResponseTextFormat struct { - Type string `json:"type"` + Type string `json:"type"` + Name string `json:"name,omitempty"` + Schema map[string]any `json:"schema,omitempty"` + Strict *bool `json:"strict,omitempty"` + Description string `json:"description,omitempty"` } // ResponseText represents text configuration @@ -189,6 +193,10 @@ type OutputItem struct { type InputContent struct { Type string `json:"type"` Text string `json:"text,omitempty"` + // Fields for input_image type + ImageURL string `json:"image_url,omitempty"` + FileID string `json:"file_id,omitempty"` + Detail string `json:"detail,omitempty"` // Fields for function_call type ID string `json:"id,omitempty"` Name string `json:"name,omitempty"` diff --git a/core/relay/plugin/streamfake/fake.go b/core/relay/plugin/streamfake/fake.go index a7fdf6cf..9f1b7ab8 100644 --- a/core/relay/plugin/streamfake/fake.go +++ b/core/relay/plugin/streamfake/fake.go @@ -182,8 +182,15 @@ func (p *StreamFake) handleFakeStreamResponse( type fakeStreamResponseWriter struct { gin.ResponseWriter - lastChunk *ast.Node - usageNode *ast.Node + lastChunk *ast.Node + usageNode *ast.Node + choices map[int]*fakeStreamChoiceState + + // Azure OpenAI prompt-level content filtering fields + promptFilterResults *ast.Node +} + +type fakeStreamChoiceState struct { contentBuilder bytes.Buffer reasoningContent bytes.Buffer finishReason relaymodel.FinishReason @@ -194,12 +201,25 @@ type fakeStreamResponseWriter struct { audio map[string]*bytes.Buffer audioFields map[string]any - // Azure OpenAI content filtering fields - promptFilterResults *ast.Node // prompt-level filter results (from first chunk) + // Azure OpenAI choice-level content filtering fields contentFilterResults *ast.Node // choice-level filter results contentFilterResult *ast.Node // choice-level filter result (alternative field name) } +func (rw *fakeStreamResponseWriter) choiceState(index int) *fakeStreamChoiceState { + if rw.choices == nil { + rw.choices = make(map[int]*fakeStreamChoiceState) + } + + state := rw.choices[index] + if state == nil { + state = &fakeStreamChoiceState{} + rw.choices[index] = state + } + + return state +} + // ignore flush func (rw *fakeStreamResponseWriter) Flush() {} @@ -256,16 +276,23 @@ func (rw *fakeStreamResponseWriter) parseStreamingData(data []byte) error { } return choicesNode.ForEach(func(_ ast.Sequence, choiceNode *ast.Node) bool { + choiceIndex, err := choiceNode.Get("index").Int64() + if err != nil { + choiceIndex = 0 + } + + state := rw.choiceState(int(choiceIndex)) + // Extract content_filter_results from choice (keep last non-empty value) contentFilterResultsNode := choiceNode.Get("content_filter_results") if err := contentFilterResultsNode.Check(); err == nil { - rw.contentFilterResults = contentFilterResultsNode + state.contentFilterResults = contentFilterResultsNode } // Extract content_filter_result from choice (alternative field name, keep last non-empty value) contentFilterResultNode := choiceNode.Get("content_filter_result") if err := contentFilterResultNode.Check(); err == nil { - rw.contentFilterResult = contentFilterResultNode + state.contentFilterResult = contentFilterResultNode } deltaNode := choiceNode.Get("delta") @@ -277,7 +304,7 @@ func (rw *fakeStreamResponseWriter) parseStreamingData(data []byte) error { if err := contentNode.Check(); err == nil { // Try as string first (common case) if content, err := contentNode.String(); err == nil { - rw.contentBuilder.WriteString(content) + state.contentBuilder.WriteString(content) } else { // Try as array (for image/multimodal content) _ = contentNode.ForEach(func(_ ast.Sequence, partNode *ast.Node) bool { @@ -292,7 +319,7 @@ func (rw *fakeStreamResponseWriter) parseStreamingData(data []byte) error { } // Keep all parts in contentParts for multimodal content - rw.contentParts = append(rw.contentParts, part) + state.contentParts = append(state.contentParts, part) return true }) @@ -301,15 +328,15 @@ func (rw *fakeStreamResponseWriter) parseStreamingData(data []byte) error { reasoningContent, err := deltaNode.Get("reasoning_content").String() if err == nil { - rw.reasoningContent.WriteString(reasoningContent) + state.reasoningContent.WriteString(reasoningContent) } // Handle signature for thought if signature, err := deltaNode.Get("signature").String(); err == nil && signature != "" { - rw.signature = signature + state.signature = signature } - rw.processAudioDelta(deltaNode.Get("audio")) + state.processAudioDelta(deltaNode.Get("audio")) _ = deltaNode.Get("tool_calls"). ForEach(func(_ ast.Sequence, toolCallNode *ast.Node) bool { @@ -327,14 +354,14 @@ func (rw *fakeStreamResponseWriter) parseStreamingData(data []byte) error { return true } - rw.toolCalls = mergeToolCalls(rw.toolCalls, &toolCall) + state.toolCalls = mergeToolCalls(state.toolCalls, &toolCall) return true }) finishReason, err := choiceNode.Get("finish_reason").String() if err == nil && finishReason != "" { - rw.finishReason = finishReason + state.finishReason = finishReason } logprobsContentNode := choiceNode.GetByPath("logprobs", "content") @@ -344,10 +371,14 @@ func (rw *fakeStreamResponseWriter) parseStreamingData(data []byte) error { return true } - rw.logprobsContent = slices.Grow(rw.logprobsContent, l) + state.logprobsContent = slices.Grow(state.logprobsContent, l) _ = logprobsContentNode.ForEach( func(_ ast.Sequence, logprobsContentNode *ast.Node) bool { - rw.logprobsContent = append(rw.logprobsContent, *logprobsContentNode) + state.logprobsContent = append( + state.logprobsContent, + *logprobsContentNode, + ) + return true }, ) @@ -375,80 +406,96 @@ func (rw *fakeStreamResponseWriter) convertToNonStream() ([]byte, error) { } } + indexes := make([]int, 0, len(rw.choices)) + for index := range rw.choices { + indexes = append(indexes, index) + } + + slices.Sort(indexes) + + choices := make([]any, 0, len(indexes)) + for _, index := range indexes { + choices = append(choices, rw.choices[index].buildChoice(index)) + } + + _, err = lastChunk.SetAny("choices", choices) + if err != nil { + return nil, err + } + + // Add prompt_filter_results to response if present + if rw.promptFilterResults != nil { + _, err = lastChunk.Set("prompt_filter_results", *rw.promptFilterResults) + if err != nil { + return nil, err + } + } + + return lastChunk.MarshalJSON() +} + +func (state *fakeStreamChoiceState) buildChoice(index int) map[string]any { message := map[string]any{ "role": "assistant", } // Use contentParts if available (for image/multimodal content), otherwise use string content - if len(rw.contentParts) > 0 { - message["content"] = rw.contentParts + if len(state.contentParts) > 0 { + message["content"] = state.contentParts } else { - message["content"] = rw.contentBuilder.String() + message["content"] = state.contentBuilder.String() } - reasoningContent := rw.reasoningContent.String() + reasoningContent := state.reasoningContent.String() if reasoningContent != "" { message["reasoning_content"] = reasoningContent } - if rw.signature != "" { - message["signature"] = rw.signature + if state.signature != "" { + message["signature"] = state.signature } - if audio := rw.buildAudio(); len(audio) > 0 { + if audio := state.buildAudio(); len(audio) > 0 { message["audio"] = audio } - if len(rw.toolCalls) > 0 { - message["tool_calls"] = rw.buildToolCalls() + if len(state.toolCalls) > 0 { + message["tool_calls"] = state.buildToolCalls() } - if len(rw.logprobsContent) > 0 { + if len(state.logprobsContent) > 0 { message["logprobs"] = map[string]any{ - "content": rw.logprobsContent, + "content": state.logprobsContent, } } // Build choice with content filter fields choice := map[string]any{ - "index": 0, + "index": index, "message": message, - "finish_reason": rw.finishReason, + "finish_reason": state.finishReason, } // Add content_filter_results to choice if present - if rw.contentFilterResults != nil { - contentFilterResultsRaw, err := rw.contentFilterResults.Interface() + if state.contentFilterResults != nil { + contentFilterResultsRaw, err := state.contentFilterResults.Interface() if err == nil { choice["content_filter_results"] = contentFilterResultsRaw } } // Add content_filter_result to choice if present (alternative field name) - if rw.contentFilterResult != nil { - contentFilterResultRaw, err := rw.contentFilterResult.Interface() + if state.contentFilterResult != nil { + contentFilterResultRaw, err := state.contentFilterResult.Interface() if err == nil { choice["content_filter_result"] = contentFilterResultRaw } } - _, err = lastChunk.SetAny("choices", []any{choice}) - if err != nil { - return nil, err - } - - // Add prompt_filter_results to response if present - if rw.promptFilterResults != nil { - _, err = lastChunk.Set("prompt_filter_results", *rw.promptFilterResults) - if err != nil { - return nil, err - } - } - - return lastChunk.MarshalJSON() + return choice } -func (rw *fakeStreamResponseWriter) processAudioDelta(audioNode *ast.Node) { +func (state *fakeStreamChoiceState) processAudioDelta(audioNode *ast.Node) { if audioNode == nil || audioNode.TypeSafe() != ast.V_OBJECT { return } @@ -470,14 +517,14 @@ func (rw *fakeStreamResponseWriter) processAudioDelta(audioNode *ast.Node) { return true } - if rw.audio == nil { - rw.audio = make(map[string]*bytes.Buffer) + if state.audio == nil { + state.audio = make(map[string]*bytes.Buffer) } - builder := rw.audio[key] + builder := state.audio[key] if builder == nil { builder = &bytes.Buffer{} - rw.audio[key] = builder + state.audio[key] = builder } builder.WriteString(value) @@ -490,11 +537,11 @@ func (rw *fakeStreamResponseWriter) processAudioDelta(audioNode *ast.Node) { return true } - if rw.audioFields == nil { - rw.audioFields = make(map[string]any) + if state.audioFields == nil { + state.audioFields = make(map[string]any) } - rw.audioFields[key] = value + state.audioFields[key] = value return true }) @@ -504,36 +551,36 @@ func shouldAppendAudioField(key string) bool { return key == "data" || key == "transcript" } -func (rw *fakeStreamResponseWriter) buildAudio() map[string]any { - audio := make(map[string]any, len(rw.audio)+len(rw.audioFields)) +func (state *fakeStreamChoiceState) buildAudio() map[string]any { + audio := make(map[string]any, len(state.audio)+len(state.audioFields)) - maps.Copy(audio, rw.audioFields) + maps.Copy(audio, state.audioFields) - for key, builder := range rw.audio { + for key, builder := range state.audio { audio[key] = builder.String() } return audio } -func (rw *fakeStreamResponseWriter) buildToolCalls() []*relaymodel.ToolCall { - if len(rw.toolCalls) == 0 { +func (state *fakeStreamChoiceState) buildToolCalls() []*relaymodel.ToolCall { + if len(state.toolCalls) == 0 { return nil } - slices.SortFunc(rw.toolCalls, func(a, b *relaymodel.ToolCall) int { + slices.SortFunc(state.toolCalls, func(a, b *relaymodel.ToolCall) int { return a.Index - b.Index }) - if rw.toolCalls[0].Index == 0 { - return rw.toolCalls + if state.toolCalls[0].Index == 0 { + return state.toolCalls } // fix tool call index start with 0 - for i, v := range rw.toolCalls { + for i, v := range state.toolCalls { v.Index = i } - return rw.toolCalls + return state.toolCalls } func mergeToolCalls( diff --git a/core/relay/plugin/streamfake/fake_test.go b/core/relay/plugin/streamfake/fake_test.go index 67e97614..f798c38a 100644 --- a/core/relay/plugin/streamfake/fake_test.go +++ b/core/relay/plugin/streamfake/fake_test.go @@ -2,13 +2,135 @@ package streamfake import ( + "bytes" + "io" + "net/http" "testing" "github.com/bytedance/sonic" + coremodel "github.com/labring/aiproxy/core/model" + "github.com/labring/aiproxy/core/relay/adaptor" + "github.com/labring/aiproxy/core/relay/meta" + "github.com/labring/aiproxy/core/relay/mode" + "github.com/labring/aiproxy/core/relay/plugin/patch" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +type streamFakeConvertRequestStub struct { + body []byte +} + +func (s *streamFakeConvertRequestStub) ConvertRequest( + _ *meta.Meta, + _ adaptor.Store, + req *http.Request, +) (adaptor.ConvertResult, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return adaptor.ConvertResult{}, err + } + + s.body = body + + return adaptor.ConvertResult{Body: bytes.NewReader(body)}, nil +} + +func newStreamFakeTestMeta() *meta.Meta { + return meta.NewMeta(nil, mode.ChatCompletions, "gpt-4.1", coremodel.ModelConfig{ + Model: "gpt-4.1", + Plugin: map[string]map[string]any{ + "stream-fake": {"enable": true}, + }, + }) +} + +func TestConvertRequestEnablesFakeStreamForMultipleChoices(t *testing.T) { + m := newStreamFakeTestMeta() + req := httptestNewJSONRequest(t, `{"model":"gpt-4.1","messages":[],"n":2}`) + stub := &streamFakeConvertRequestStub{} + + _, err := (&StreamFake{}).ConvertRequest(m, nil, req, stub) + require.NoError(t, err) + + value, ok := m.Get(fakeStreamKey) + require.True(t, ok) + assert.Equal(t, true, value) + assert.Len(t, patch.GetLazyPatches(m), 1) +} + +func TestConvertRequestEnablesFakeStreamForSingleChoice(t *testing.T) { + m := newStreamFakeTestMeta() + req := httptestNewJSONRequest(t, `{"model":"gpt-4.1","messages":[]}`) + stub := &streamFakeConvertRequestStub{} + + _, err := (&StreamFake{}).ConvertRequest(m, nil, req, stub) + require.NoError(t, err) + + value, ok := m.Get(fakeStreamKey) + require.True(t, ok) + assert.Equal(t, true, value) + assert.Len(t, patch.GetLazyPatches(m), 1) +} + +func httptestNewJSONRequest(t *testing.T, body string) *http.Request { + t.Helper() + + req, err := http.NewRequestWithContext( + t.Context(), + http.MethodPost, + "/v1/chat/completions", + bytes.NewReader([]byte(body)), + ) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + return req +} + +func TestConvertToNonStreamPreservesMultipleChoices(t *testing.T) { + rw := &fakeStreamResponseWriter{} + + chunks := []string{ + `{"choices":[{"delta":{"role":"assistant","content":"A"},"finish_reason":null,"index":0},{"delta":{"role":"assistant","content":"B"},"finish_reason":null,"index":1}],"created":1767597874,"id":"chatcmpl-test","model":"gpt-4.1","object":"chat.completion.chunk"}`, + `{"choices":[{"delta":{"content":" one"},"finish_reason":null,"index":0},{"delta":{"content":" two"},"finish_reason":null,"index":1}],"created":1767597874,"id":"chatcmpl-test","model":"gpt-4.1","object":"chat.completion.chunk"}`, + `{"choices":[{"delta":{},"finish_reason":"stop","index":0},{"delta":{},"finish_reason":"length","index":1}],"created":1767597874,"id":"chatcmpl-test","model":"gpt-4.1","object":"chat.completion.chunk","usage":{"completion_tokens":4,"prompt_tokens":10,"total_tokens":14}}`, + } + + for _, chunk := range chunks { + err := rw.parseStreamingData([]byte(chunk)) + require.NoError(t, err) + } + + result, err := rw.convertToNonStream() + require.NoError(t, err) + + var response map[string]any + + err = sonic.Unmarshal(result, &response) + require.NoError(t, err) + + choices, ok := response["choices"].([]any) + require.True(t, ok) + require.Len(t, choices, 2) + + choice0, ok := choices[0].(map[string]any) + require.True(t, ok) + assert.Equal(t, float64(0), choice0["index"]) + assert.Equal(t, "stop", choice0["finish_reason"]) + message0, ok := choice0["message"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "A one", message0["content"]) + + choice1, ok := choices[1].(map[string]any) + require.True(t, ok) + assert.Equal(t, float64(1), choice1["index"]) + assert.Equal(t, "length", choice1["finish_reason"]) + message1, ok := choice1["message"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "B two", message1["content"]) +} + func TestParseStreamingDataWithContentFilterFields(t *testing.T) { tests := []struct { name string @@ -61,11 +183,14 @@ func TestParseStreamingDataWithContentFilterFields(t *testing.T) { require.NoError(t, err) } + state := rw.choices[0] + require.NotNil(t, state) + // Check content - assert.Equal(t, tt.expectedContent, rw.contentBuilder.String()) + assert.Equal(t, tt.expectedContent, state.contentBuilder.String()) // Check finish reason - assert.Equal(t, tt.expectedFinishReason, rw.finishReason) + assert.Equal(t, tt.expectedFinishReason, state.finishReason) // Check prompt_filter_results if tt.hasPromptFilterResults { @@ -76,16 +201,16 @@ func TestParseStreamingDataWithContentFilterFields(t *testing.T) { // Check content_filter_results if tt.hasContentFilterResults { - assert.NotNil(t, rw.contentFilterResults) + assert.NotNil(t, state.contentFilterResults) } else { - assert.Nil(t, rw.contentFilterResults) + assert.Nil(t, state.contentFilterResults) } // Check content_filter_result if tt.hasContentFilterResult { - assert.NotNil(t, rw.contentFilterResult) + assert.NotNil(t, state.contentFilterResult) } else { - assert.Nil(t, rw.contentFilterResult) + assert.Nil(t, state.contentFilterResult) } // Convert to non-stream and verify @@ -283,7 +408,7 @@ func TestParseStreamingDataWithEmptyChoices(t *testing.T) { require.NoError(t, err) assert.NotNil(t, rw.promptFilterResults) - assert.Equal(t, "", rw.contentBuilder.String()) // No content yet + assert.Empty(t, rw.choices) // No choice content yet } func TestContentFilterResultPreservesErrorDetails(t *testing.T) {