diff --git a/.changeset/fix-toolcallid-tracking.md b/.changeset/fix-toolcallid-tracking.md new file mode 100644 index 000000000000..e5e2507e3552 --- /dev/null +++ b/.changeset/fix-toolcallid-tracking.md @@ -0,0 +1,5 @@ +--- +'ai': patch +--- + +Use toolCallId instead of generateId for parallel tool execution tracking to prevent premature stream closure diff --git a/packages/ai/src/generate-text/run-tools-transformation.test.ts b/packages/ai/src/generate-text/run-tools-transformation.test.ts index 6adade0f83d5..1c5005d949b5 100644 --- a/packages/ai/src/generate-text/run-tools-transformation.test.ts +++ b/packages/ai/src/generate-text/run-tools-transformation.test.ts @@ -12,7 +12,23 @@ import { describe, expect, it } from 'vitest'; import { z } from 'zod/v4'; import { NoSuchToolError } from '../error/no-such-tool-error'; import { MockTracer } from '../test/mock-tracer'; -import { runToolsTransformation } from './run-tools-transformation'; +import { + runToolsTransformation, + SingleRequestTextStreamPart, +} from './run-tools-transformation'; +import { ToolSet } from './tool-set'; + +function isToolResult( + part: SingleRequestTextStreamPart, +): part is SingleRequestTextStreamPart & { type: 'tool-result' } { + return part.type === 'tool-result'; +} + +function isToolCall( + part: SingleRequestTextStreamPart, +): part is SingleRequestTextStreamPart & { type: 'tool-call' } { + return part.type === 'tool-call'; +} const testUsage: LanguageModelV3Usage = { inputTokens: { @@ -1140,4 +1156,321 @@ describe('runToolsTransformation', () => { }); }); }); + + describe('parallel tool execution', () => { + it('should use toolCallId for tracking (not generateId) to handle parallel tools correctly', async () => { + // Frameworks can override _internal.generateId for message grouping, returning + // a constant pendingMessageId for all calls within a request. Tool execution + // tracking must use toolCallId (unique per LLM tool call) instead. + const pendingMessageId = 'msg-abc123'; + const frameworkGenerateId = () => pendingMessageId; + + const inputStream: ReadableStream = + convertArrayToReadableStream([ + { + type: 'tool-call', + toolCallId: 'unique-call-1', + toolName: 'toolA', + input: `{ "value": "a" }`, + }, + { + type: 'tool-call', + toolCallId: 'unique-call-2', + toolName: 'toolB', + input: `{ "value": "b" }`, + }, + { + type: 'tool-call', + toolCallId: 'unique-call-3', + toolName: 'toolC', + input: `{ "value": "c" }`, + }, + { + type: 'finish', + finishReason: { unified: 'tool-calls', raw: 'tool_calls' }, + usage: testUsage, + }, + ]); + + const transformedStream = runToolsTransformation({ + generateId: frameworkGenerateId, + tools: { + toolA: { + title: 'Tool A', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(30); + return `${value}-result`; + }, + }, + toolB: { + title: 'Tool B', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(10); + return `${value}-result`; + }, + }, + toolC: { + title: 'Tool C', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(20); + return `${value}-result`; + }, + }, + }, + generatorStream: inputStream, + tracer: new MockTracer(), + telemetry: undefined, + messages: [], + system: undefined, + abortSignal: undefined, + repairToolCall: undefined, + experimental_context: undefined, + }); + + const result = await convertReadableStreamToArray(transformedStream); + + // All three tool results should be captured + // (Bug: without the fix, only 1 result would be captured because + // outstandingToolResults Set would use the same ID for all tools) + const toolResults = result.filter(isToolResult); + expect(toolResults).toHaveLength(3); + expect(toolResults.map(r => r.toolCallId).sort()).toEqual([ + 'unique-call-1', + 'unique-call-2', + 'unique-call-3', + ]); + + // Finish should be last + expect(result[result.length - 1]).toMatchObject({ + type: 'finish', + }); + }); + + it('should capture all results when multiple tools execute in parallel with different delays', async () => { + const inputStream: ReadableStream = + convertArrayToReadableStream([ + { + type: 'tool-call', + toolCallId: 'call-1', + toolName: 'slowTool', + input: `{ "value": "slow" }`, + }, + { + type: 'tool-call', + toolCallId: 'call-2', + toolName: 'fastTool', + input: `{ "value": "fast" }`, + }, + { + type: 'tool-call', + toolCallId: 'call-3', + toolName: 'mediumTool', + input: `{ "value": "medium" }`, + }, + { + type: 'finish', + finishReason: { unified: 'tool-calls', raw: 'tool_calls' }, + usage: testUsage, + }, + ]); + + const transformedStream = runToolsTransformation({ + generateId: mockId({ prefix: 'id' }), + tools: { + slowTool: { + title: 'Slow Tool', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(50); // Slowest + return `${value}-result`; + }, + }, + fastTool: { + title: 'Fast Tool', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(10); // Fastest + return `${value}-result`; + }, + }, + mediumTool: { + title: 'Medium Tool', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(30); // Medium + return `${value}-result`; + }, + }, + }, + generatorStream: inputStream, + tracer: new MockTracer(), + telemetry: undefined, + messages: [], + system: undefined, + abortSignal: undefined, + repairToolCall: undefined, + experimental_context: undefined, + }); + + const result = await convertReadableStreamToArray(transformedStream); + + // All three tool calls should be present + const toolCalls = result.filter(isToolCall); + expect(toolCalls).toHaveLength(3); + + // All three tool results should be present + const toolResults = result.filter(isToolResult); + expect(toolResults).toHaveLength(3); + expect(toolResults.map(r => r.toolCallId).sort()).toEqual([ + 'call-1', + 'call-2', + 'call-3', + ]); + + // Finish should be last + expect(result[result.length - 1]).toMatchObject({ + type: 'finish', + }); + }); + + it('should not close stream prematurely when fast tool completes before slow tool', async () => { + const executionOrder: string[] = []; + + const inputStream: ReadableStream = + convertArrayToReadableStream([ + { + type: 'tool-call', + toolCallId: 'slow-call', + toolName: 'slowTool', + input: `{ "value": "slow" }`, + }, + { + type: 'tool-call', + toolCallId: 'fast-call', + toolName: 'fastTool', + input: `{ "value": "fast" }`, + }, + { + type: 'finish', + finishReason: { unified: 'tool-calls', raw: 'tool_calls' }, + usage: testUsage, + }, + ]); + + const transformedStream = runToolsTransformation({ + generateId: mockId({ prefix: 'id' }), + tools: { + slowTool: { + title: 'Slow Tool', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(50); + executionOrder.push('slow-completed'); + return `${value}-slow-result`; + }, + }, + fastTool: { + title: 'Fast Tool', + inputSchema: z.object({ value: z.string() }), + execute: async ({ value }) => { + await delay(5); + executionOrder.push('fast-completed'); + return `${value}-fast-result`; + }, + }, + }, + generatorStream: inputStream, + tracer: new MockTracer(), + telemetry: undefined, + messages: [], + system: undefined, + abortSignal: undefined, + repairToolCall: undefined, + experimental_context: undefined, + }); + + const result = await convertReadableStreamToArray(transformedStream); + + // Fast tool should complete first + expect(executionOrder).toEqual(['fast-completed', 'slow-completed']); + + // Both results should be captured + const toolResults = result.filter(isToolResult); + expect(toolResults).toHaveLength(2); + expect(toolResults.map(r => r.output).sort()).toEqual([ + 'fast-fast-result', + 'slow-slow-result', + ]); + + // Stream should close properly after all tools complete + expect(result[result.length - 1]).toMatchObject({ + type: 'finish', + }); + }); + + it('should handle many parallel tool calls without losing results', async () => { + const toolCount = 10; + const toolCalls = Array.from({ length: toolCount }, (_, i) => ({ + type: 'tool-call' as const, + toolCallId: `call-${i}`, + toolName: 'parallelTool', + input: `{ "index": ${i} }`, + })); + + const inputStream: ReadableStream = + convertArrayToReadableStream([ + ...toolCalls, + { + type: 'finish', + finishReason: { unified: 'tool-calls', raw: 'tool_calls' }, + usage: testUsage, + }, + ]); + + const transformedStream = runToolsTransformation({ + generateId: mockId({ prefix: 'id' }), + tools: { + parallelTool: { + title: 'Parallel Tool', + inputSchema: z.object({ index: z.number() }), + execute: async ({ index }) => { + // Random delay to simulate real-world variance + await delay(Math.random() * 20); + return `result-${index}`; + }, + }, + }, + generatorStream: inputStream, + tracer: new MockTracer(), + telemetry: undefined, + messages: [], + system: undefined, + abortSignal: undefined, + repairToolCall: undefined, + experimental_context: undefined, + }); + + const result = await convertReadableStreamToArray(transformedStream); + + // All tool results should be captured + const toolResults = result.filter(isToolResult); + expect(toolResults).toHaveLength(toolCount); + + // Verify all results are present (order may vary) + const resultOutputs = toolResults.map(r => r.output).sort(); + const expectedOutputs = Array.from( + { length: toolCount }, + (_, i) => `result-${i}`, + ).sort(); + expect(resultOutputs).toEqual(expectedOutputs); + + // Finish should be last + expect(result[result.length - 1]).toMatchObject({ + type: 'finish', + }); + }); + }); }); diff --git a/packages/ai/src/generate-text/run-tools-transformation.ts b/packages/ai/src/generate-text/run-tools-transformation.ts index 29865865aad7..a21d05a249d0 100644 --- a/packages/ai/src/generate-text/run-tools-transformation.ts +++ b/packages/ai/src/generate-text/run-tools-transformation.ts @@ -309,7 +309,11 @@ export function runToolsTransformation({ // Only execute tools that are not provider-executed: if (tool.execute != null && toolCall.providerExecuted !== true) { - const toolExecutionId = generateId(); // use our own id to guarantee uniqueness + // Use toolCallId for tracking - it's unique per tool call from the LLM. + // Don't use generateId() here because frameworks can override it for + // message grouping (returning the same ID for all tools in a request), + // which would cause the Set to track only one tool instead of all. + const toolExecutionId = toolCall.toolCallId; outstandingToolResults.add(toolExecutionId); // Note: we don't await the tool execution here (by leaving out 'await' on recordSpan),