Skip to content

Commit 03bf9a2

Browse files
fix(frontend): render all parallel tool calls instead of last only
Fixes #266 MessageGroup previously rendered only the array-last tool-call step from the most recent AI message, masking in-flight siblings when tools fired in parallel and finished out of order. Extracted convertToToolCallSteps and a new partitionStepsForDisplay helper into core/tools/utils.ts so all tool-call steps sharing the most-recent AI tool-call message id render as siblings, each with its own pending/result state. isLast=true is reserved for the final sibling so artifact auto-open still fires once. Tool results pair by tool_call_id, handling out-of-order completion. Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 8ba01df commit 03bf9a2

3 files changed

Lines changed: 236 additions & 90 deletions

File tree

frontend/src/components/workspace/messages/message-group.tsx

Lines changed: 26 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ import {
2424
import { CodeBlock } from "@/components/ai-elements/code-block";
2525
import { Button } from "@/components/ui/button";
2626
import { useI18n } from "@/core/i18n/hooks";
27-
import {
28-
extractReasoningContentFromMessage,
29-
findToolCallResult,
30-
} from "@/core/messages/utils";
3127
import { useRehypeSplitWordsIntoSpans } from "@/core/rehype";
28+
import {
29+
convertToToolCallSteps,
30+
partitionStepsForDisplay,
31+
} from "@/core/tools/utils";
3232
import { extractTitleFromMarkdown } from "@/core/utils/markdown";
3333
import { env } from "@/env";
3434
import { cn } from "@/lib/utils";
@@ -55,18 +55,15 @@ export function MessageGroup({
5555
const [showLastThinking, setShowLastThinking] = useState(
5656
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true",
5757
);
58-
const steps = useMemo(() => convertToSteps(messages), [messages]);
59-
const lastToolCallStep = useMemo(() => {
60-
const filteredSteps = steps.filter((step) => step.type === "toolCall");
61-
return filteredSteps[filteredSteps.length - 1];
62-
}, [steps]);
63-
const aboveLastToolCallSteps = useMemo(() => {
64-
if (lastToolCallStep) {
65-
const index = steps.indexOf(lastToolCallStep);
66-
return steps.slice(0, index);
67-
}
68-
return [];
69-
}, [lastToolCallStep, steps]);
58+
const steps = useMemo(() => convertToToolCallSteps(messages), [messages]);
59+
const {
60+
aboveSteps: aboveLastToolCallSteps,
61+
activeSteps: activeToolCallSteps,
62+
} = useMemo(() => partitionStepsForDisplay(steps), [steps]);
63+
const lastToolCallStep = useMemo(
64+
() => activeToolCallSteps[activeToolCallSteps.length - 1],
65+
[activeToolCallSteps],
66+
);
7067
const lastReasoningStep = useMemo(() => {
7168
if (lastToolCallStep) {
7269
const index = steps.indexOf(lastToolCallStep);
@@ -127,16 +124,19 @@ export function MessageGroup({
127124
<ToolCall key={step.id} {...step} isLoading={isLoading} />
128125
),
129126
)}
130-
{lastToolCallStep && (
131-
<FlipDisplay uniqueKey={lastToolCallStep.id ?? ""}>
132-
<ToolCall
133-
key={lastToolCallStep.id}
134-
{...lastToolCallStep}
135-
isLast={true}
136-
isLoading={isLoading}
137-
/>
138-
</FlipDisplay>
139-
)}
127+
{activeToolCallSteps.map((step, index) => {
128+
const isLast = index === activeToolCallSteps.length - 1;
129+
return (
130+
<FlipDisplay key={step.id} uniqueKey={step.id ?? ""}>
131+
<ToolCall
132+
key={step.id}
133+
{...step}
134+
isLast={isLast}
135+
isLoading={isLoading}
136+
/>
137+
</FlipDisplay>
138+
);
139+
})}
140140
</ChainOfThoughtContent>
141141
)}
142142
{lastReasoningStep && (
@@ -422,65 +422,3 @@ function ToolCall({
422422
);
423423
}
424424
}
425-
426-
interface GenericCoTStep<T extends string = string> {
427-
id?: string;
428-
messageId?: string;
429-
type: T;
430-
}
431-
432-
interface CoTReasoningStep extends GenericCoTStep<"reasoning"> {
433-
reasoning: string | null;
434-
}
435-
436-
interface CoTToolCallStep extends GenericCoTStep<"toolCall"> {
437-
name: string;
438-
args: Record<string, unknown>;
439-
result?: string;
440-
}
441-
442-
type CoTStep = CoTReasoningStep | CoTToolCallStep;
443-
444-
function convertToSteps(messages: Message[]): CoTStep[] {
445-
const steps: CoTStep[] = [];
446-
for (const message of messages) {
447-
if (message.type === "ai") {
448-
const reasoning = extractReasoningContentFromMessage(message);
449-
if (reasoning) {
450-
const step: CoTReasoningStep = {
451-
id: message.id,
452-
messageId: message.id,
453-
type: "reasoning",
454-
reasoning,
455-
};
456-
steps.push(step);
457-
}
458-
for (const tool_call of message.tool_calls ?? []) {
459-
if (tool_call.name === "task") {
460-
continue;
461-
}
462-
const step: CoTToolCallStep = {
463-
id: tool_call.id,
464-
messageId: message.id,
465-
type: "toolCall",
466-
name: tool_call.name,
467-
args: tool_call.args,
468-
};
469-
const toolCallId = tool_call.id;
470-
if (toolCallId) {
471-
const toolCallResult = findToolCallResult(toolCallId, messages);
472-
if (toolCallResult) {
473-
try {
474-
const json = JSON.parse(toolCallResult);
475-
step.result = json;
476-
} catch {
477-
step.result = toolCallResult;
478-
}
479-
}
480-
}
481-
steps.push(step);
482-
}
483-
}
484-
}
485-
return steps;
486-
}

frontend/src/core/tools/utils.ts

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import type { ToolCall } from "@langchain/core/messages";
2-
import type { AIMessage } from "@langchain/langgraph-sdk";
2+
import type { AIMessage, Message } from "@langchain/langgraph-sdk";
33

44
import type { Translations } from "../i18n";
5-
import { hasToolCalls } from "../messages/utils";
5+
import {
6+
extractReasoningContentFromMessage,
7+
findToolCallResult,
8+
hasToolCalls,
9+
} from "../messages/utils";
610

711
export function explainLastToolCall(message: AIMessage, t: Translations) {
812
if (hasToolCalls(message)) {
@@ -27,3 +31,90 @@ export function explainToolCall(toolCall: ToolCall, t: Translations) {
2731
return t.toolCalls.useTool(toolCall.name);
2832
}
2933
}
34+
35+
interface GenericCoTStep<T extends string = string> {
36+
id?: string;
37+
messageId?: string;
38+
type: T;
39+
}
40+
41+
export interface CoTReasoningStep extends GenericCoTStep<"reasoning"> {
42+
reasoning: string | null;
43+
}
44+
45+
export interface CoTToolCallStep extends GenericCoTStep<"toolCall"> {
46+
name: string;
47+
args: Record<string, unknown>;
48+
result?: string | Record<string, unknown>;
49+
}
50+
51+
export type CoTStep = CoTReasoningStep | CoTToolCallStep;
52+
53+
export function convertToToolCallSteps(messages: Message[]): CoTStep[] {
54+
const steps: CoTStep[] = [];
55+
for (const message of messages) {
56+
if (message.type !== "ai") {
57+
continue;
58+
}
59+
const reasoning = extractReasoningContentFromMessage(message);
60+
if (reasoning) {
61+
steps.push({
62+
id: message.id,
63+
messageId: message.id,
64+
type: "reasoning",
65+
reasoning,
66+
});
67+
}
68+
for (const tool_call of message.tool_calls ?? []) {
69+
if (tool_call.name === "task") {
70+
continue;
71+
}
72+
const step: CoTToolCallStep = {
73+
id: tool_call.id,
74+
messageId: message.id,
75+
type: "toolCall",
76+
name: tool_call.name,
77+
args: tool_call.args,
78+
};
79+
const toolCallId = tool_call.id;
80+
if (toolCallId) {
81+
const toolCallResult = findToolCallResult(toolCallId, messages);
82+
if (toolCallResult) {
83+
try {
84+
step.result = JSON.parse(toolCallResult);
85+
} catch {
86+
step.result = toolCallResult;
87+
}
88+
}
89+
}
90+
steps.push(step);
91+
}
92+
}
93+
return steps;
94+
}
95+
96+
export function partitionStepsForDisplay(steps: CoTStep[]): {
97+
aboveSteps: CoTStep[];
98+
activeSteps: CoTToolCallStep[];
99+
} {
100+
const toolCallSteps = steps.filter(
101+
(step): step is CoTToolCallStep => step.type === "toolCall",
102+
);
103+
if (toolCallSteps.length === 0) {
104+
return { aboveSteps: [], activeSteps: [] };
105+
}
106+
107+
const lastAIMessageId = toolCallSteps[toolCallSteps.length - 1]!.messageId;
108+
const activeSteps = toolCallSteps.filter(
109+
(step) =>
110+
step.messageId !== undefined && step.messageId === lastAIMessageId,
111+
);
112+
if (activeSteps.length === 0) {
113+
return { aboveSteps: [], activeSteps: [] };
114+
}
115+
116+
const firstActiveIndex = steps.indexOf(activeSteps[0]!);
117+
const aboveSteps = steps.slice(0, firstActiveIndex);
118+
119+
return { aboveSteps, activeSteps };
120+
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import type { Message } from "@langchain/langgraph-sdk";
2+
import { expect, test } from "vitest";
3+
4+
import {
5+
convertToToolCallSteps,
6+
partitionStepsForDisplay,
7+
} from "@/core/tools/utils";
8+
9+
function aiMessage(
10+
id: string,
11+
toolCalls: { id: string; name: string; args?: Record<string, unknown> }[],
12+
reasoning?: string,
13+
): Message {
14+
return {
15+
type: "ai",
16+
id,
17+
content: "",
18+
tool_calls: toolCalls.map((tc) => ({
19+
id: tc.id,
20+
name: tc.name,
21+
args: tc.args ?? {},
22+
})),
23+
additional_kwargs: reasoning ? { reasoning_content: reasoning } : {},
24+
} as Message;
25+
}
26+
27+
function toolMessage(toolCallId: string, content: string): Message {
28+
return {
29+
type: "tool",
30+
id: `tool-msg-${toolCallId}`,
31+
tool_call_id: toolCallId,
32+
content,
33+
} as Message;
34+
}
35+
36+
test("a single tool call in the latest AI message stays the only active step", () => {
37+
const messages: Message[] = [
38+
aiMessage("ai-1", [
39+
{ id: "call-1", name: "web_search", args: { query: "x" } },
40+
]),
41+
toolMessage("call-1", "[]"),
42+
aiMessage("ai-2", [
43+
{ id: "call-2", name: "web_fetch", args: { url: "u" } },
44+
]),
45+
];
46+
const steps = convertToToolCallSteps(messages);
47+
const { aboveSteps, activeSteps } = partitionStepsForDisplay(steps);
48+
49+
expect(activeSteps.map((s) => s.id)).toEqual(["call-2"]);
50+
expect(aboveSteps.map((s) => s.id)).toEqual(["call-1"]);
51+
});
52+
53+
test("all parallel siblings stay active until each one completes", () => {
54+
const messages: Message[] = [
55+
aiMessage("ai-1", [
56+
{ id: "call-1", name: "web_search", args: { query: "a" } },
57+
{ id: "call-2", name: "web_search", args: { query: "b" } },
58+
{ id: "call-3", name: "web_search", args: { query: "c" } },
59+
]),
60+
toolMessage("call-2", "[]"),
61+
];
62+
const steps = convertToToolCallSteps(messages);
63+
const { aboveSteps, activeSteps } = partitionStepsForDisplay(steps);
64+
65+
expect(activeSteps.map((s) => s.id)).toEqual(["call-1", "call-2", "call-3"]);
66+
expect(aboveSteps).toEqual([]);
67+
});
68+
69+
test("parallel tool results pair by tool_call_id regardless of arrival order", () => {
70+
const messages: Message[] = [
71+
aiMessage("ai-1", [
72+
{ id: "call-1", name: "web_search", args: { query: "a" } },
73+
{ id: "call-2", name: "web_search", args: { query: "b" } },
74+
]),
75+
toolMessage("call-2", '[{"url":"u2","title":"t2"}]'),
76+
toolMessage("call-1", '[{"url":"u1","title":"t1"}]'),
77+
];
78+
const steps = convertToToolCallSteps(messages);
79+
const toolSteps = steps.filter((s) => s.type === "toolCall");
80+
const byId = new Map(toolSteps.map((s) => [s.id, s]));
81+
expect(byId.get("call-1")?.result).toEqual([{ url: "u1", title: "t1" }]);
82+
expect(byId.get("call-2")?.result).toEqual([{ url: "u2", title: "t2" }]);
83+
});
84+
85+
test("reasoning emitted with parallel tool calls stays visible above the active batch", () => {
86+
const messages: Message[] = [
87+
aiMessage(
88+
"ai-1",
89+
[
90+
{ id: "call-1", name: "web_search", args: { query: "a" } },
91+
{ id: "call-2", name: "web_search", args: { query: "b" } },
92+
],
93+
"considering both queries in parallel",
94+
),
95+
];
96+
const steps = convertToToolCallSteps(messages);
97+
const { aboveSteps, activeSteps } = partitionStepsForDisplay(steps);
98+
99+
expect(aboveSteps.map((s) => s.type)).toEqual(["reasoning"]);
100+
expect(activeSteps.map((s) => s.id)).toEqual(["call-1", "call-2"]);
101+
});
102+
103+
test("earlier serial tool calls collapse above a fresh parallel batch", () => {
104+
const messages: Message[] = [
105+
aiMessage("ai-1", [{ id: "call-0", name: "ls", args: { path: "/" } }]),
106+
toolMessage("call-0", "ok"),
107+
aiMessage("ai-2", [
108+
{ id: "call-1", name: "web_search", args: { query: "a" } },
109+
{ id: "call-2", name: "web_search", args: { query: "b" } },
110+
]),
111+
];
112+
const steps = convertToToolCallSteps(messages);
113+
const { aboveSteps, activeSteps } = partitionStepsForDisplay(steps);
114+
115+
expect(aboveSteps.map((s) => s.id)).toEqual(["call-0"]);
116+
expect(activeSteps.map((s) => s.id)).toEqual(["call-1", "call-2"]);
117+
});

0 commit comments

Comments
 (0)