From bfd2e3d66e82dc6521b77df665566b1651637fce Mon Sep 17 00:00:00 2001 From: hzhaoy Date: Tue, 2 Jun 2026 11:25:19 +0800 Subject: [PATCH 1/2] Preserve TUI tool-call context across streaming updates Tool calls arrive through several LangChain and DeepAgents shapes, including partial tool_call_chunks and later updates snapshots. The TUI now normalizes those shapes, accumulates argument fragments, and keeps message segments from bleeding across tool boundaries. Constraint: LangChain streams tool-call arguments through partial chunks and may only provide complete arguments in updates snapshots Rejected: Only display arguments from token.tool_calls | real streams often expose empty args there Rejected: Treat result text containing Error as failure | source files can legitimately contain that text Confidence: high Scope-risk: moderate Directive: Keep tool-call argument handling covered by stream_handler regression tests before changing accumulator behavior Tested: uv run pytest tests/tui/test_stream_handler.py -q --no-cov Tested: uv run pytest Not-tested: Manual TUI session after commit Co-authored-by: OmX --- .../tui/bridge/stream_handler.py | 396 +++++++++++++++--- tests/tui/test_stream_handler.py | 296 ++++++++++++- 2 files changed, 643 insertions(+), 49 deletions(-) diff --git a/src/deep_code_agent/tui/bridge/stream_handler.py b/src/deep_code_agent/tui/bridge/stream_handler.py index 30d0c80..10f915c 100644 --- a/src/deep_code_agent/tui/bridge/stream_handler.py +++ b/src/deep_code_agent/tui/bridge/stream_handler.py @@ -1,6 +1,7 @@ """Stream handler for processing LangGraph Agent output.""" import asyncio +import json from dataclasses import dataclass from enum import Enum, auto from typing import Any, AsyncIterator @@ -63,7 +64,274 @@ def __init__(self, agent, config: dict): self.config = config self._message_chunks: list[str] = [] self._interrupted = False - self._seen_tool_call_ids: set[str] = set() + self._seen_tool_call_signatures: set[tuple[str, str, str]] = set() + self._tool_arg_fragments: dict[str, str] = {} + self._seen_tool_arg_fragment_keys: set[tuple[str, int | None, str]] = set() + self._tool_names_by_id: dict[str, str] = {} + self._tool_ids_by_index: dict[int, str] = {} + + def _reset_run_state(self) -> None: + self._message_chunks.clear() + self._interrupted = False + self._seen_tool_call_signatures.clear() + self._tool_arg_fragments.clear() + self._seen_tool_arg_fragment_keys.clear() + self._tool_names_by_id.clear() + self._tool_ids_by_index.clear() + + def _coerce_tool_args(self, args: Any) -> dict: + """Normalize tool arguments from dict, JSON string, or object forms.""" + if args is None: + return {} + + if hasattr(args, "model_dump"): + try: + args = args.model_dump() + except Exception: + pass + + if isinstance(args, dict): + return args + + if isinstance(args, str): + args_text = args.strip() + if not args_text: + return {} + try: + parsed = json.loads(args_text) + except json.JSONDecodeError: + return {"arguments": args_text} + if isinstance(parsed, dict): + return parsed + return {"arguments": parsed} + + return {"value": args} + + def _tool_call_signature(self, tool_id: str, tool_name: str, tool_args: dict) -> tuple[str, str, str]: + try: + args_signature = json.dumps(tool_args, sort_keys=True, default=str) + except TypeError: + args_signature = str(tool_args) + return (tool_id, tool_name, args_signature) + + def _iter_tool_call_chunks(self, token: Any) -> list[dict]: + """Collect provider/LangChain tool-call chunk shapes from a token.""" + chunks: list[dict] = [] + seen_chunks: set[tuple[str | None, str | None, str, int | None]] = set() + + def add_chunk(item: dict) -> None: + raw_args = item.get("args") if "args" in item else item.get("arguments") + key = ( + item.get("id"), + item.get("name"), + str(raw_args), + item.get("index"), + ) + if key in seen_chunks: + return + seen_chunks.add(key) + chunks.append( + { + "id": item.get("id"), + "name": item.get("name"), + "args": raw_args, + "index": item.get("index"), + "type": item.get("type"), + } + ) + + for item in getattr(token, "tool_call_chunks", None) or []: + if hasattr(item, "model_dump"): + try: + item = item.model_dump() + except Exception: + pass + if isinstance(item, dict): + add_chunk(item) + + additional_kwargs = getattr(token, "additional_kwargs", None) + if isinstance(additional_kwargs, dict): + for item in additional_kwargs.get("tool_calls") or []: + if not isinstance(item, dict): + continue + function = item.get("function") if isinstance(item.get("function"), dict) else {} + add_chunk( + { + "id": item.get("id"), + "name": item.get("name") or function.get("name"), + "args": item.get("args") if "args" in item else function.get("arguments"), + "index": item.get("index"), + "type": item.get("type"), + } + ) + + for item in getattr(token, "content_blocks", None) or []: + if hasattr(item, "model_dump"): + try: + item = item.model_dump() + except Exception: + pass + if not isinstance(item, dict): + continue + block_type = str(item.get("type", "")) + if "tool_call" not in block_type: + continue + add_chunk( + { + "id": item.get("id"), + "name": item.get("name"), + "args": item.get("args") if "args" in item else item.get("arguments"), + "index": item.get("index"), + "type": item.get("type"), + } + ) + + return chunks + + def _args_from_tool_call_chunks(self, tool_id: str, tool_name: str, chunks: list[dict]) -> dict: + for chunk in chunks: + chunk_id = chunk.get("id") + chunk_name = chunk.get("name") + if chunk_id not in (None, tool_id) and chunk_id != tool_id: + continue + if chunk_id is None and chunk_name not in (None, tool_name): + continue + + raw_args = chunk.get("args") + direct_args = self._coerce_tool_args(raw_args) + if not (isinstance(raw_args, str) and set(direct_args) == {"arguments"}): + if direct_args: + return direct_args + continue + + fragment = raw_args.strip() + if not fragment: + continue + fragment_key = (tool_id, chunk.get("index"), fragment) + if fragment_key in self._seen_tool_arg_fragment_keys: + continue + self._seen_tool_arg_fragment_keys.add(fragment_key) + accumulated = self._tool_arg_fragments.get(tool_id, "") + combined = f"{accumulated}{fragment}" + self._tool_arg_fragments[tool_id] = combined + combined_args = self._coerce_tool_args(combined) + if not (set(combined_args) == {"arguments"} and isinstance(combined_args.get("arguments"), str)): + return combined_args + + return {} + + def _tool_call_debug_metadata(self, token: Any, tool_call_chunks: list[dict], tc_preview: Any) -> dict: + debug_candidates = { + "token.type": type(token).__name__, + "token.name": getattr(token, "name", None), + "token.tool_call_chunks": tool_call_chunks[:3], + } + additional_kwargs = getattr(token, "additional_kwargs", None) + if isinstance(additional_kwargs, dict): + debug_candidates["token.additional_kwargs.keys"] = list(additional_kwargs.keys())[:20] + return {"debug_tool_call": debug_candidates, "debug_tc_preview": tc_preview} + + def _remember_tool_call(self, tool_id: str, tool_name: str, chunks: list[dict] | None = None) -> None: + self._tool_names_by_id[tool_id] = tool_name + for chunk in chunks or []: + chunk_index = chunk.get("index") + if not isinstance(chunk_index, int): + continue + if chunk.get("id") == tool_id: + self._tool_ids_by_index[chunk_index] = tool_id + + def _normalize_tool_call(self, tc: Any, *, fallback_name: str | None = None) -> dict | None: + """Normalize a tool call from dict/object/update snapshot shapes.""" + if hasattr(tc, "model_dump"): + try: + tc = tc.model_dump() + except Exception: + pass + + if isinstance(tc, dict): + function = tc.get("function") if isinstance(tc.get("function"), dict) else {} + tool_name = tc.get("name") or tc.get("tool_name") or function.get("name") or "" + raw_args = tc.get("args") if "args" in tc else function.get("arguments") + tool_args = self._coerce_tool_args(raw_args) + tool_id = tc.get("id") or tc.get("tool_call_id") or "" + else: + tool_name = getattr(tc, "name", "") or getattr(tc, "tool_name", "") + tool_args = self._coerce_tool_args(getattr(tc, "args", None)) + tool_id = getattr(tc, "id", "") or getattr(tc, "tool_call_id", "") + + tool_name = str(tool_name or fallback_name or "").strip() + if isinstance(tool_id, str): + tool_id = tool_id.strip() + elif tool_id is None: + tool_id = "" + else: + tool_id = str(tool_id).strip() + + if not tool_name or not tool_id: + return None + return {"name": tool_name, "args": tool_args, "id": tool_id} + + def _find_tool_calls_payload(self, chunk: Any, *, max_depth: int = 5) -> list[dict]: + """Find normalized tool calls/action requests in an updates chunk.""" + if max_depth < 0: + return [] + + if hasattr(chunk, "model_dump"): + try: + chunk = chunk.model_dump() + except Exception: + pass + + found: list[dict] = [] + if isinstance(chunk, dict): + for key in ("tool_calls", "action_requests"): + calls = chunk.get(key) + if isinstance(calls, list): + for call in calls: + normalized = self._normalize_tool_call(call) + if normalized is not None: + found.append(normalized) + + for value in chunk.values(): + found.extend(self._find_tool_calls_payload(value, max_depth=max_depth - 1)) + + elif isinstance(chunk, list): + for item in chunk: + found.extend(self._find_tool_calls_payload(item, max_depth=max_depth - 1)) + + else: + calls = getattr(chunk, "tool_calls", None) + if isinstance(calls, list): + for call in calls: + normalized = self._normalize_tool_call(call) + if normalized is not None: + found.append(normalized) + + deduped: list[dict] = [] + seen: set[tuple[str, str, str]] = set() + for call in found: + signature = self._tool_call_signature(call["id"], call["name"], call["args"]) + if signature in seen: + continue + seen.add(signature) + deduped.append(call) + return deduped + + def _emit_tool_call_event(self, tool_call: dict, *, metadata: dict | None = None) -> AgentEvent | None: + tool_id = tool_call["id"] + tool_name = tool_call["name"] + tool_args = tool_call["args"] + + self._remember_tool_call(tool_id, tool_name) + tool_call_signature = self._tool_call_signature(tool_id, tool_name, tool_args) + if tool_call_signature in self._seen_tool_call_signatures: + return None + self._seen_tool_call_signatures.add(tool_call_signature) + return AgentEvent( + type=EventType.TOOL_CALL, + data={"name": tool_name, "args": tool_args, "id": tool_id}, + metadata=metadata, + ) def _normalize_todo_item(self, item: Any) -> dict[str, str] | None: """Normalize one todo item from dict/object forms. @@ -151,25 +419,20 @@ async def _process_stream(self, stream, include_tool_calls: bool = True) -> Asyn async for mode, chunk in stream: if mode == "messages": token, metadata = chunk + tool_call_chunks = self._iter_tool_call_chunks(token) # Check for tool calls in the message if include_tool_calls and hasattr(token, "tool_calls") and token.tool_calls: for tc in token.tool_calls: - # Handle both dict and object formats + tool_call = self._normalize_tool_call(tc, fallback_name=getattr(token, "name", None)) + if tool_call is None: + continue + tool_id = tool_call["id"] + tool_name = tool_call["name"] + tool_args = tool_call["args"] if isinstance(tc, dict): - tool_name = tc.get("name", "") - if not tool_name: - tool_name = tc.get("tool_name", "") - if not tool_name and isinstance(tc.get("function"), dict): - tool_name = tc["function"].get("name", "") - tool_args = tc.get("args", {}) - tool_id = tc.get("id", "") tc_preview = {k: tc.get(k) for k in list(tc.keys())[:12]} else: - # Handle object format (e.g., ToolCall object) - tool_name = getattr(tc, "name", "") - tool_args = getattr(tc, "args", {}) or {} - tool_id = getattr(tc, "id", "") tc_preview = { "type": type(tc).__name__, "name": getattr(tc, "name", None), @@ -177,25 +440,10 @@ async def _process_stream(self, stream, include_tool_calls: bool = True) -> Asyn "id": getattr(tc, "id", None), } - # If tool_name is empty, try to get it from the token's name attribute - # (some LLM providers put the tool name there) - if not tool_name and hasattr(token, "name"): - tool_name = token.name - - tool_name = (tool_name or "").strip() - if isinstance(tool_id, str): - tool_id = tool_id.strip() or None - elif tool_id is None: - tool_id = None - else: - tool_id = str(tool_id).strip() or None - - tool_args = tool_args or {} - if not tool_id or not tool_name: - continue - if tool_id in self._seen_tool_call_ids: - continue - self._seen_tool_call_ids.add(tool_id) + self._message_chunks.clear() + self._remember_tool_call(tool_id, tool_name, tool_call_chunks) + if not tool_args: + tool_args = self._args_from_tool_call_chunks(tool_id, tool_name, tool_call_chunks) debug_candidates: dict = {} if isinstance(tc, dict): @@ -215,14 +463,56 @@ async def _process_stream(self, stream, include_tool_calls: bool = True) -> Asyn "tc.id": getattr(tc, "id", None), "tc.type": type(tc).__name__, } - debug_candidates["token.type"] = type(token).__name__ - debug_candidates["token.name"] = getattr(token, "name", None) + event = self._emit_tool_call_event( + {"name": tool_name, "args": tool_args, "id": tool_id}, + metadata={ + "debug_tool_call": { + **debug_candidates, + "token.type": type(token).__name__, + "token.name": getattr(token, "name", None), + "token.tool_call_chunks": tool_call_chunks[:3], + "token.additional_kwargs.keys": ( + list(getattr(token, "additional_kwargs", {}).keys())[:20] + if isinstance(getattr(token, "additional_kwargs", None), dict) + else None + ), + }, + "debug_tc_preview": tc_preview, + }, + ) + if event is not None: + yield event - yield AgentEvent( - type=EventType.TOOL_CALL, - data={"name": tool_name, "args": tool_args, "id": tool_id}, - metadata={"debug_tool_call": debug_candidates, "debug_tc_preview": tc_preview}, + if include_tool_calls and tool_call_chunks: + for tool_chunk in tool_call_chunks: + tool_id = tool_chunk.get("id") + if isinstance(tool_id, str): + tool_id = tool_id.strip() or None + chunk_index = tool_chunk.get("index") + if not tool_id and isinstance(chunk_index, int): + tool_id = self._tool_ids_by_index.get(chunk_index) + if not tool_id: + continue + self._message_chunks.clear() + + tool_name = tool_chunk.get("name") or self._tool_names_by_id.get(tool_id, "") + tool_name = str(tool_name).strip() + if not tool_name: + continue + self._remember_tool_call(tool_id, tool_name, [tool_chunk]) + if isinstance(chunk_index, int): + self._tool_ids_by_index[chunk_index] = tool_id + + tool_args = self._args_from_tool_call_chunks(tool_id, tool_name, [tool_chunk]) + if not tool_args: + continue + + event = self._emit_tool_call_event( + {"name": tool_name, "args": tool_args, "id": tool_id}, + metadata=self._tool_call_debug_metadata(token, [tool_chunk], tool_chunk), ) + if event is not None: + yield event # Check for tool execution results (ToolMessage) if include_tool_calls and hasattr(token, "name") and hasattr(token, "content"): @@ -230,13 +520,12 @@ async def _process_stream(self, stream, include_tool_calls: bool = True) -> Asyn from langchain_core.messages import ToolMessage if isinstance(token, ToolMessage): - # Check if content indicates an error + self._message_chunks.clear() content_str = str(token.content) if token.content is not None else "(empty result)" if not content_str.strip(): content_str = "(empty result)" - is_error = any( - word in content_str.lower() for word in ["error", "failed", "exception", "traceback"] - ) + status = getattr(token, "status", None) + is_error = status == "error" yield AgentEvent( type=EventType.TOOL_SUCCESS if not is_error else EventType.TOOL_ERROR, @@ -262,6 +551,21 @@ async def _process_stream(self, stream, include_tool_calls: bool = True) -> Asyn if todos: yield AgentEvent(type=EventType.TODOS_UPDATE, data=todos) + if include_tool_calls: + for tool_call in self._find_tool_calls_payload(chunk): + event = self._emit_tool_call_event( + tool_call, + metadata={ + "debug_tool_call": { + "updates.tool_call": tool_call, + "updates.type": type(chunk).__name__, + }, + "debug_tc_preview": tool_call, + }, + ) + if event is not None: + yield event + async def process(self, state: dict) -> AsyncIterator[AgentEvent]: """Process a user request and yield events. @@ -271,9 +575,7 @@ async def process(self, state: dict) -> AsyncIterator[AgentEvent]: Yields: AgentEvent instances """ - self._message_chunks.clear() - self._interrupted = False - self._seen_tool_call_ids.clear() + self._reset_run_state() try: # Signal that thinking has started @@ -309,9 +611,7 @@ async def resume_with_decision(self, decision: dict) -> AsyncIterator[AgentEvent """ from langgraph.types import Command - self._message_chunks.clear() - self._interrupted = False - self._seen_tool_call_ids.clear() + self._reset_run_state() try: # Handle both single decision and multiple decisions diff --git a/tests/tui/test_stream_handler.py b/tests/tui/test_stream_handler.py index c48b161..7aaf007 100644 --- a/tests/tui/test_stream_handler.py +++ b/tests/tui/test_stream_handler.py @@ -4,10 +4,22 @@ class _FakeToken: - def __init__(self, *, tool_calls=None, content=None, name=None): + def __init__( + self, + *, + tool_calls=None, + content=None, + name=None, + tool_call_chunks=None, + additional_kwargs=None, + content_blocks=None, + ): self.tool_calls = tool_calls self.content = content self.name = name + self.tool_call_chunks = tool_call_chunks or [] + self.additional_kwargs = additional_kwargs or {} + self.content_blocks = content_blocks or [] class _FakeAgent: @@ -44,6 +56,71 @@ def test_event_types_include_new_tool_events(): assert hasattr(EventType, 'TODOS_UPDATE') +def test_tool_message_with_error_text_but_success_status_is_success(): + from langchain_core.messages import ToolMessage + + from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler + + token = ToolMessage( + content='Error: 1 """Terminal command execution tools."""', + name="read_file", + tool_call_id="call_read", + status="success", + ) + agent = _FakeAgent(events=[("messages", (token, {}))]) + handler = StreamHandler(agent, config={}) + + events = asyncio.run(_collect(handler.process({"messages": []}))) + tool_events = [e for e in events if e.type in (EventType.TOOL_SUCCESS, EventType.TOOL_ERROR)] + assert len(tool_events) == 1 + assert tool_events[0].type == EventType.TOOL_SUCCESS + + +def test_tool_message_with_error_status_is_error(): + from langchain_core.messages import ToolMessage + + from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler + + token = ToolMessage( + content="permission denied", + name="read_file", + tool_call_id="call_read", + status="error", + ) + agent = _FakeAgent(events=[("messages", (token, {}))]) + handler = StreamHandler(agent, config={}) + + events = asyncio.run(_collect(handler.process({"messages": []}))) + tool_events = [e for e in events if e.type in (EventType.TOOL_SUCCESS, EventType.TOOL_ERROR)] + assert len(tool_events) == 1 + assert tool_events[0].type == EventType.TOOL_ERROR + + +def test_message_complete_only_contains_current_segment_after_tool_call(): + from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler + + first_text = _FakeToken(tool_calls=[], content="Before tool.", name=None) + tool_call = _FakeToken( + tool_calls=[{"name": "read_file", "args": {"file_path": "x.txt"}, "id": "call_read", "type": "tool_call"}], + content=None, + name=None, + ) + second_text = _FakeToken(tool_calls=[], content="After tool.", name=None) + agent = _FakeAgent( + events=[ + ("messages", (first_text, {})), + ("messages", (tool_call, {})), + ("messages", (second_text, {})), + ] + ) + handler = StreamHandler(agent, config={}) + + events = asyncio.run(_collect(handler.process({"messages": []}))) + complete_events = [e for e in events if e.type == EventType.MESSAGE_COMPLETE] + assert len(complete_events) == 1 + assert complete_events[0].data == "After tool." + + def test_filters_incomplete_tool_calls_without_id_or_name(): from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler @@ -74,6 +151,223 @@ def test_dedupes_repeated_tool_calls_by_id(): assert tool_call_events[0].data == {"name": "write_file", "args": {}, "id": "call_123"} +def test_extracts_tool_args_from_function_arguments_json(): + from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler + + tc = { + "id": "call_read", + "function": { + "name": "read_file", + "arguments": '{"file_path": "src/deep_code_agent/cli.py"}', + }, + } + token = _FakeToken(tool_calls=[tc], content=None, name=None) + agent = _FakeAgent(events=[("messages", (token, {}))]) + handler = StreamHandler(agent, config={}) + + events = asyncio.run(_collect(handler.process({"messages": []}))) + tool_call_events = [e for e in events if e.type == EventType.TOOL_CALL] + assert len(tool_call_events) == 1 + assert tool_call_events[0].data == { + "name": "read_file", + "args": {"file_path": "src/deep_code_agent/cli.py"}, + "id": "call_read", + } + + +def test_emits_later_tool_call_when_arguments_arrive_after_initial_chunk(): + from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler + + token1 = _FakeToken( + tool_calls=[{"name": "read_file", "args": {}, "id": "call_read", "type": "tool_call"}], + content=None, + name=None, + ) + token2 = _FakeToken( + tool_calls=[ + { + "name": "read_file", + "args": {"file_path": "src/deep_code_agent/cli.py"}, + "id": "call_read", + "type": "tool_call", + } + ], + content=None, + name=None, + ) + agent = _FakeAgent(events=[("messages", (token1, {})), ("messages", (token2, {}))]) + handler = StreamHandler(agent, config={}) + + events = asyncio.run(_collect(handler.process({"messages": []}))) + tool_call_events = [e for e in events if e.type == EventType.TOOL_CALL] + assert [e.data for e in tool_call_events] == [ + {"name": "read_file", "args": {}, "id": "call_read"}, + {"name": "read_file", "args": {"file_path": "src/deep_code_agent/cli.py"}, "id": "call_read"}, + ] + + +def test_fills_empty_tool_call_args_from_matching_tool_call_chunk(): + from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler + + token = _FakeToken( + tool_calls=[{"name": "read_file", "args": {}, "id": "call_read", "type": "tool_call"}], + tool_call_chunks=[ + { + "name": "read_file", + "args": '{"file_path": "src/deep_code_agent/prompts.py"}', + "id": "call_read", + "index": 0, + "type": "tool_call_chunk", + } + ], + content=None, + name=None, + ) + agent = _FakeAgent(events=[("messages", (token, {}))]) + handler = StreamHandler(agent, config={}) + + events = asyncio.run(_collect(handler.process({"messages": []}))) + tool_call_events = [e for e in events if e.type == EventType.TOOL_CALL] + assert len(tool_call_events) == 1 + assert tool_call_events[0].data == { + "name": "read_file", + "args": {"file_path": "src/deep_code_agent/prompts.py"}, + "id": "call_read", + } + + +def test_accumulates_streamed_tool_call_chunk_argument_fragments(): + from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler + + token1 = _FakeToken( + tool_calls=[{"name": "read_file", "args": {}, "id": "call_read", "type": "tool_call"}], + tool_call_chunks=[ + {"name": "read_file", "args": '{"file_path": "', "id": "call_read", "index": 0}, + ], + content=None, + name=None, + ) + token2 = _FakeToken( + tool_calls=[], + tool_call_chunks=[ + {"name": None, "args": "src/deep_code_agent/cli.py\"}", "id": "call_read", "index": 0}, + ], + content=None, + name=None, + ) + agent = _FakeAgent(events=[("messages", (token1, {})), ("messages", (token2, {}))]) + handler = StreamHandler(agent, config={}) + + events = asyncio.run(_collect(handler.process({"messages": []}))) + tool_call_events = [e for e in events if e.type == EventType.TOOL_CALL] + assert [e.data for e in tool_call_events] == [ + {"name": "read_file", "args": {}, "id": "call_read"}, + {"name": "read_file", "args": {"file_path": "src/deep_code_agent/cli.py"}, "id": "call_read"}, + ] + + +def test_dedupes_duplicate_tool_call_chunks_from_content_blocks(): + from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler + + first_fragment = {"name": "read_file", "args": "{", "id": "call_read", "index": 0, "type": "tool_call_chunk"} + token1 = _FakeToken( + tool_calls=[{"name": "read_file", "args": {}, "id": "call_read", "type": "tool_call"}], + tool_call_chunks=[first_fragment], + content_blocks=[first_fragment], + content=None, + name=None, + ) + token2 = _FakeToken( + tool_calls=[], + tool_call_chunks=[ + { + "name": None, + "args": '"file_path": "src/deep_code_agent/cli.py"}', + "id": "call_read", + "index": 0, + }, + ], + content=None, + name=None, + ) + agent = _FakeAgent(events=[("messages", (token1, {})), ("messages", (token2, {}))]) + handler = StreamHandler(agent, config={}) + + events = asyncio.run(_collect(handler.process({"messages": []}))) + tool_call_events = [e for e in events if e.type == EventType.TOOL_CALL] + assert [e.data for e in tool_call_events] == [ + {"name": "read_file", "args": {}, "id": "call_read"}, + {"name": "read_file", "args": {"file_path": "src/deep_code_agent/cli.py"}, "id": "call_read"}, + ] + + +def test_accumulates_tool_call_chunks_by_index_when_later_chunks_have_no_id(): + from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler + + token1 = _FakeToken( + tool_calls=[{"name": "read_file", "args": {}, "id": "call_read", "type": "tool_call"}], + tool_call_chunks=[ + {"name": "read_file", "args": "{", "id": "call_read", "index": 2, "type": "tool_call_chunk"}, + ], + content=None, + name=None, + ) + token2 = _FakeToken( + tool_calls=[], + tool_call_chunks=[ + {"name": None, "args": '"file_path": "src/deep_code_agent/config.py"}', "id": None, "index": 2}, + ], + content=None, + name=None, + ) + agent = _FakeAgent(events=[("messages", (token1, {})), ("messages", (token2, {}))]) + handler = StreamHandler(agent, config={}) + + events = asyncio.run(_collect(handler.process({"messages": []}))) + tool_call_events = [e for e in events if e.type == EventType.TOOL_CALL] + assert [e.data for e in tool_call_events] == [ + {"name": "read_file", "args": {}, "id": "call_read"}, + {"name": "read_file", "args": {"file_path": "src/deep_code_agent/config.py"}, "id": "call_read"}, + ] + + +def test_emits_tool_call_args_from_updates_messages_snapshot(): + from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler + + token = _FakeToken( + tool_calls=[{"name": "read_file", "args": {}, "id": "call_read", "type": "tool_call"}], + tool_call_chunks=[ + {"name": "read_file", "args": "{", "id": "call_read", "index": 0, "type": "tool_call_chunk"}, + ], + content=None, + name=None, + ) + snapshot = { + "agent": { + "messages": [ + { + "tool_calls": [ + { + "name": "read_file", + "args": {"file_path": "src/deep_code_agent/code_agent.py"}, + "id": "call_read", + } + ] + } + ] + } + } + agent = _FakeAgent(events=[("messages", (token, {})), ("updates", snapshot)]) + handler = StreamHandler(agent, config={}) + + events = asyncio.run(_collect(handler.process({"messages": []}))) + tool_call_events = [e for e in events if e.type == EventType.TOOL_CALL] + assert [e.data for e in tool_call_events] == [ + {"name": "read_file", "args": {}, "id": "call_read"}, + {"name": "read_file", "args": {"file_path": "src/deep_code_agent/code_agent.py"}, "id": "call_read"}, + ] + + def test_emits_todos_update_from_top_level_updates_chunk(): from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler From 20293c7c1bcf67886e7f5e3ea1ea35301ba4e791 Mon Sep 17 00:00:00 2001 From: Elwin <61868295+hzhaoy@users.noreply.github.com> Date: Tue, 2 Jun 2026 12:35:36 +0800 Subject: [PATCH 2/2] Preserve streamed tool arg whitespace --- .../tui/bridge/stream_handler.py | 4 +-- tests/tui/test_stream_handler.py | 30 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/deep_code_agent/tui/bridge/stream_handler.py b/src/deep_code_agent/tui/bridge/stream_handler.py index 10f915c..12f2fa8 100644 --- a/src/deep_code_agent/tui/bridge/stream_handler.py +++ b/src/deep_code_agent/tui/bridge/stream_handler.py @@ -204,8 +204,8 @@ def _args_from_tool_call_chunks(self, tool_id: str, tool_name: str, chunks: list return direct_args continue - fragment = raw_args.strip() - if not fragment: + fragment = raw_args + if fragment == "": continue fragment_key = (tool_id, chunk.get("index"), fragment) if fragment_key in self._seen_tool_arg_fragment_keys: diff --git a/tests/tui/test_stream_handler.py b/tests/tui/test_stream_handler.py index 7aaf007..bc71f2a 100644 --- a/tests/tui/test_stream_handler.py +++ b/tests/tui/test_stream_handler.py @@ -266,6 +266,36 @@ def test_accumulates_streamed_tool_call_chunk_argument_fragments(): ] +def test_preserves_whitespace_in_streamed_tool_call_chunk_argument_fragments(): + from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler + + token1 = _FakeToken( + tool_calls=[{"name": "terminal", "args": {}, "id": "call_terminal", "type": "tool_call"}], + tool_call_chunks=[ + {"name": "terminal", "args": '{"command": "echo', "id": "call_terminal", "index": 0}, + ], + content=None, + name=None, + ) + token2 = _FakeToken( + tool_calls=[], + tool_call_chunks=[ + {"name": None, "args": ' hello"}', "id": "call_terminal", "index": 0}, + ], + content=None, + name=None, + ) + agent = _FakeAgent(events=[("messages", (token1, {})), ("messages", (token2, {}))]) + handler = StreamHandler(agent, config={}) + + events = asyncio.run(_collect(handler.process({"messages": []}))) + tool_call_events = [e for e in events if e.type == EventType.TOOL_CALL] + assert [e.data for e in tool_call_events] == [ + {"name": "terminal", "args": {}, "id": "call_terminal"}, + {"name": "terminal", "args": {"command": "echo hello"}, "id": "call_terminal"}, + ] + + def test_dedupes_duplicate_tool_call_chunks_from_content_blocks(): from deep_code_agent.tui.bridge.stream_handler import EventType, StreamHandler