diff --git a/fixtures/fixtures.go b/fixtures/fixtures.go index c731e0fb..c6b3a78c 100644 --- a/fixtures/fixtures.go +++ b/fixtures/fixtures.go @@ -108,6 +108,9 @@ var ( //go:embed openai/responses/blocking/wrong_response_format.txtar OaiResponsesBlockingWrongResponseFormat []byte + + //go:embed openai/responses/blocking/web_search.txtar + OaiResponsesBlockingWebSearch []byte ) var ( diff --git a/fixtures/openai/responses/blocking/web_search.txtar b/fixtures/openai/responses/blocking/web_search.txtar new file mode 100644 index 00000000..ff139219 --- /dev/null +++ b/fixtures/openai/responses/blocking/web_search.txtar @@ -0,0 +1,103 @@ +-- request -- +{ + "input": [ + { + "role": "user", + "content": "What is the current weather in Cape Town?" + } + ], + "model": "gpt-5", + "stream": false, + "tools": [ + { + "type": "web_search" + } + ] +} + +-- non-streaming -- +{ + "id": "resp_a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1", + "object": "response", + "created_at": 1767875200, + "status": "completed", + "background": false, + "billing": { + "payer": "developer" + }, + "completed_at": 1767875205, + "error": null, + "incomplete_details": null, + "instructions": null, + "max_output_tokens": null, + "max_tool_calls": null, + "model": "gpt-5-0806", + "output": [ + { + "id": "ws_a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1", + "type": "web_search_call", + "status": "completed" + }, + { + "id": "msg_a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1", + "type": "message", + "status": "completed", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": "Based on my search, the current weather in Cape Town is partly cloudy with a temperature of around 22°C (72°F).", + "annotations": [ + { + "type": "url_citation", + "start_index": 28, + "end_index": 47, + "url": "https://weather.example.com/cape-town", + "title": "Cape Town Weather" + } + ] + } + ] + } + ], + "parallel_tool_calls": true, + "previous_response_id": null, + "prompt_cache_key": null, + "prompt_cache_retention": null, + "reasoning": { + "effort": null, + "summary": null + }, + "safety_identifier": null, + "service_tier": "default", + "store": true, + "temperature": 1.0, + "text": { + "format": { + "type": "text" + }, + "verbosity": "medium" + }, + "tool_choice": "auto", + "tools": [ + { + "type": "web_search" + } + ], + "top_logprobs": 0, + "top_p": 1.0, + "truncation": "disabled", + "usage": { + "input_tokens": 42, + "input_tokens_details": { + "cached_tokens": 0 + }, + "output_tokens": 150, + "output_tokens_details": { + "reasoning_tokens": 0 + }, + "total_tokens": 192 + }, + "user": null, + "metadata": {} +} diff --git a/intercept/responses/base.go b/intercept/responses/base.go index daccc300..76f142dd 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -204,14 +204,54 @@ func (i *responsesInterceptionBase) recordNonInjectedToolUsage(ctx context.Conte } for _, item := range response.Output { - var args recorder.ToolArgs + var ( + args recorder.ToolArgs + toolName string + itemID string + callID string + ) - // recording other function types to be considered: https://github.com/coder/aibridge/issues/121 switch item.Type { case string(constant.ValueOf[constant.FunctionCall]()): args = i.parseFunctionCallJSONArgs(ctx, item.Arguments) + toolName = item.Name + itemID = item.ID + callID = item.CallID + case string(constant.ValueOf[constant.CustomToolCall]()): args = item.Input + toolName = item.Name + itemID = item.ID + callID = item.CallID + + // Agentic tools: the client sends a corresponding *_output + // item correlated by call_id. + case "computer_call", + string(constant.ValueOf[constant.LocalShellCall]()), + string(constant.ValueOf[constant.ShellCall]()), + string(constant.ValueOf[constant.ApplyPatchCall]()): + toolName = item.Name + if toolName == "" { + toolName = item.Type + } + itemID = item.ID + callID = item.CallID + + // Hosted tools: executed server-side, these output items + // carry only an id field — not call_id. The client never + // submits output for them. + // https://platform.openai.com/docs/api-reference/responses/create + case string(constant.ValueOf[constant.WebSearchCall]()), + string(constant.ValueOf[constant.FileSearchCall]()), + string(constant.ValueOf[constant.CodeInterpreterCall]()), + string(constant.ValueOf[constant.ImageGenerationCall]()), + string(constant.ValueOf[constant.McpCall]()): + toolName = item.Name + if toolName == "" { + toolName = item.Type + } + itemID = item.ID + default: continue } @@ -219,12 +259,13 @@ func (i *responsesInterceptionBase) recordNonInjectedToolUsage(ctx context.Conte if err := i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: response.ID, - ToolCallID: item.CallID, - Tool: item.Name, + ItemID: itemID, + ToolCallID: callID, + Tool: toolName, Args: args, Injected: false, }); err != nil { - i.logger.Warn(ctx, "failed to record tool usage", slog.Error(err), slog.F("tool", item.Name)) + i.logger.Warn(ctx, "failed to record tool usage", slog.Error(err), slog.F("tool", toolName)) } } } diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index ea5c87b5..d787813c 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -113,6 +113,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_456", + ItemID: "", ToolCallID: "call_abc", Tool: "get_weather", Args: "", @@ -160,6 +161,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_789", + ItemID: "", ToolCallID: "call_1", Tool: "get_weather", Args: map[string]any{"location": "NYC"}, @@ -168,6 +170,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_789", + ItemID: "", ToolCallID: "call_2", Tool: "bad_json_args", Args: `{"bad": args`, @@ -176,6 +179,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_789", + ItemID: "", ToolCallID: "call_3", Tool: "search", Args: `{\"query\": \"test\"}`, @@ -184,6 +188,7 @@ func TestRecordToolUsage(t *testing.T) { { InterceptionID: id.String(), MsgID: "resp_789", + ItemID: "", ToolCallID: "call_4", Tool: "calculate", Args: map[string]any{"a": float64(1), "b": float64(2)}, @@ -191,6 +196,146 @@ func TestRecordToolUsage(t *testing.T) { }, }, }, + { + name: "web_search_call_with_no_name", + response: &oairesponses.Response{ + ID: "resp_ws", + Output: []oairesponses.ResponseOutputItemUnion{ + { + Type: "web_search_call", + ID: "ws_abc", + }, + }, + }, + expected: []*recorder.ToolUsageRecord{ + { + InterceptionID: id.String(), + MsgID: "resp_ws", + ItemID: "ws_abc", + ToolCallID: "", + Tool: "web_search_call", + Injected: false, + }, + }, + }, + { + name: "all_additional_tool_types", + response: &oairesponses.Response{ + ID: "resp_all", + Output: []oairesponses.ResponseOutputItemUnion{ + { + Type: "web_search_call", + ID: "ws_1", + }, + { + Type: "computer_call", + CallID: "call_comp", + }, + { + Type: "local_shell_call", + CallID: "call_lsh", + }, + { + Type: "shell_call", + CallID: "call_sh", + }, + { + Type: "apply_patch_call", + CallID: "call_ap", + }, + { + Type: "code_interpreter_call", + ID: "ci_1", + }, + { + Type: "mcp_call", + ID: "mcp_1", + Name: "my_mcp_tool", + }, + { + Type: "file_search_call", + ID: "fs_1", + }, + { + Type: "image_generation_call", + ID: "ig_1", + }, + { + Type: "message", + ID: "msg_skip", + }, + { + Type: "reasoning", + ID: "rs_skip", + }, + }, + }, + expected: []*recorder.ToolUsageRecord{ + { + InterceptionID: id.String(), + MsgID: "resp_all", + ItemID: "ws_1", + Tool: "web_search_call", + Injected: false, + }, + { + InterceptionID: id.String(), + MsgID: "resp_all", + ToolCallID: "call_comp", + Tool: "computer_call", + Injected: false, + }, + { + InterceptionID: id.String(), + MsgID: "resp_all", + ToolCallID: "call_lsh", + Tool: "local_shell_call", + Injected: false, + }, + { + InterceptionID: id.String(), + MsgID: "resp_all", + ToolCallID: "call_sh", + Tool: "shell_call", + Injected: false, + }, + { + InterceptionID: id.String(), + MsgID: "resp_all", + ToolCallID: "call_ap", + Tool: "apply_patch_call", + Injected: false, + }, + { + InterceptionID: id.String(), + MsgID: "resp_all", + ItemID: "ci_1", + Tool: "code_interpreter_call", + Injected: false, + }, + { + InterceptionID: id.String(), + MsgID: "resp_all", + ItemID: "mcp_1", + Tool: "my_mcp_tool", + Injected: false, + }, + { + InterceptionID: id.String(), + MsgID: "resp_all", + ItemID: "fs_1", + Tool: "file_search_call", + Injected: false, + }, + { + InterceptionID: id.String(), + MsgID: "resp_all", + ItemID: "ig_1", + Tool: "image_generation_call", + Injected: false, + }, + }, + }, } for _, tc := range tests { diff --git a/intercept/responses/injected_tools.go b/intercept/responses/injected_tools.go index dd44014b..b5872467 100644 --- a/intercept/responses/injected_tools.go +++ b/intercept/responses/injected_tools.go @@ -178,6 +178,7 @@ func (i *responsesInterceptionBase) invokeInjectedTool(ctx context.Context, resp _ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{ InterceptionID: i.ID().String(), MsgID: responseID, + ItemID: fc.ID, ToolCallID: fc.CallID, ServerURL: &tool.ServerURL, Tool: tool.Name, diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 1a35f707..d3a4983f 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -73,6 +73,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectToolRecorded: &recorder.ToolUsageRecord{ MsgID: "resp_0da6045a8b68fa5200695fa23dcc2c81a19c849f627abf8a31", Tool: "add", + ItemID: "fc_0da6045a8b68fa5200695fa23e198081a19bf68887d47ae93d", ToolCallID: "call_CJSaa2u51JG996575oVljuNq", Args: map[string]any{"a": float64(3), "b": float64(5)}, Injected: false, @@ -115,6 +116,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectToolRecorded: &recorder.ToolUsageRecord{ MsgID: "resp_09c614364030cdf000696942589da081a0af07f5859acb7308", Tool: "code_exec", + ItemID: "ctc_09c614364030cdf0006969425bf33481a09cc0f9522af2d980", ToolCallID: "call_haf8njtwrVZ1754Gm6fjAtuA", Args: "print(\"hello world\")", Injected: false, @@ -166,8 +168,30 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectedClient: aibridge.ClientUnknown, }, { - name: "streaming_simple", - fixture: fixtures.OaiResponsesStreamingSimple, + name: "blocking_web_search", + fixture: fixtures.OaiResponsesBlockingWebSearch, + expectModel: "gpt-5", + expectPromptRecorded: "What is the current weather in Cape Town?", + expectToolRecorded: &recorder.ToolUsageRecord{ + MsgID: "resp_a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1", + Tool: "web_search_call", + ItemID: "ws_a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1", + Injected: false, + }, + expectTokenUsage: &recorder.TokenUsageRecord{ + MsgID: "resp_a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1", + Input: 42, + Output: 150, + ExtraTokenTypes: map[string]int64{ + "input_cached": 0, + "output_reasoning": 0, + "total_tokens": 192, + }, + }, + expectedClient: aibridge.ClientUnknown, + }, + { + name: "streaming_simple", fixture: fixtures.OaiResponsesStreamingSimple, streaming: true, expectModel: "gpt-4o-mini", expectPromptRecorded: "tell me a joke", @@ -212,6 +236,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectToolRecorded: &recorder.ToolUsageRecord{ MsgID: "resp_0c3fb28cfcf463a500695fa2f0239481a095ec6ce3dfe4d458", Tool: "add", + ItemID: "fc_0c3fb28cfcf463a500695fa2f0b0a881a0890103ba88b0628e", ToolCallID: "call_7VaiUXZYuuuwWwviCrckxq6t", Args: map[string]any{"a": float64(3), "b": float64(5)}, Injected: false, @@ -256,6 +281,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { expectToolRecorded: &recorder.ToolUsageRecord{ MsgID: "resp_0c26996bc41c2a0500696942e83634819fb71b2b8ff8a4a76c", Tool: "code_exec", + ItemID: "ctc_0c26996bc41c2a0500696942ee6db8819fa6e841317eecbfb2", ToolCallID: "call_2gSnF58IEhXLwlbnqbm5XKMd", Args: "print(\"hello world\")", Injected: false, diff --git a/recorder/types.go b/recorder/types.go index cd541eeb..22a923f1 100644 --- a/recorder/types.go +++ b/recorder/types.go @@ -75,6 +75,7 @@ type ToolUsageRecord struct { InterceptionID string MsgID string Tool string + ItemID string ToolCallID string ServerURL *string Args ToolArgs