From b27230a415366db3ef69b1917365360a8f142ebb Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Mon, 21 Jul 2025 06:27:47 +0000 Subject: [PATCH 1/8] Initial streaming --- README.md | 3 + agents/agents/agents/agent_base.py | 21 ++ .../agents/agents/chain/README_STREAMING.md | 324 ++++++++++++++++++ agents/agents/agents/chain/chain_base.py | 281 ++++++++++----- .../agents/agents/chain/streaming_observer.py | 259 ++++++++++++++ .../agents/chain/websocket_streaming.py | 227 ++++++++++++ agents/agents/agents/llm_backend.py | 82 ++++- .../examples/multi_chain_streaming_example.py | 307 +++++++++++++++++ agents/agents/examples/streaming_example.py | 227 ++++++++++++ 9 files changed, 1650 insertions(+), 81 deletions(-) create mode 100644 agents/agents/agents/chain/README_STREAMING.md create mode 100644 agents/agents/agents/chain/streaming_observer.py create mode 100644 agents/agents/agents/chain/websocket_streaming.py create mode 100644 agents/agents/examples/multi_chain_streaming_example.py create mode 100644 agents/agents/examples/streaming_example.py diff --git a/README.md b/README.md index b5ffdb2..b00dc00 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ # AgentFly: Scalable and Extendable Reinforcement Larning for LLM Agents +
+[📃paper]() | [📖documentation]() +
This library is an extandable framework for building LLM agents with reinforcement learning. It provides a flexible and powerful system for creating agent that interact with tools, learn from rewards in multi-turn manner and complete tasks automatically. ![Overview](assets/images/overview.png) diff --git a/agents/agents/agents/agent_base.py b/agents/agents/agents/agent_base.py index c2a8639..ca89159 100644 --- a/agents/agents/agents/agent_base.py +++ b/agents/agents/agents/agent_base.py @@ -118,6 +118,27 @@ async def generate_async(self, messages_list_or_inputs: List[List[Dict]], **args List of responses. """ return await self.llm_engine.generate_async(messages_list_or_inputs, **args) + + async def generate_streaming(self, messages_list_or_inputs: List[List[Dict]], streaming_callback=None, **args): + """ + Generate responses with streaming support. This method yields response chunks as they are generated. + + Args: + messages_list_or_inputs: List of messages to generate responses for. + streaming_callback: Optional callback function for streaming chunks. + **args: Additional arguments for generation. + + Yields: + str: Response chunks as they are generated. + """ + if hasattr(self.llm_engine, 'generate_streaming'): + async for chunk in self.llm_engine.generate_streaming(messages_list_or_inputs, streaming_callback=streaming_callback, **args): + yield chunk + else: + # Fallback to non-streaming generation + responses = await self.generate_async(messages_list_or_inputs, **args) + for response in responses: + yield response @property def timing_data(self): diff --git a/agents/agents/agents/chain/README_STREAMING.md b/agents/agents/agents/chain/README_STREAMING.md new file mode 100644 index 0000000..2869b05 --- /dev/null +++ b/agents/agents/agents/chain/README_STREAMING.md @@ -0,0 +1,324 @@ +# Streaming Functionality for LLM Agent Reinforcement Learning + +This document describes the streaming functionality added to the LLM agent reinforcement learning framework, which allows real-time monitoring of agent responses and tool observations. + +## Overview + +The streaming functionality provides: + +1. **Real-time LLM response streaming** - See tokens as they are generated +2. **Tool observation streaming** - Monitor tool calls and their results in real-time +3. **Event-based architecture** - Flexible observer pattern for different use cases +4. **Multiple output formats** - Console, JSON, WebSocket, and custom callbacks +5. **Async support** - Non-blocking streaming with proper async/await patterns +6. **Multi-chain support** - Handle multiple chains without mixing outputs + +## Architecture + +### Core Components + +1. **StreamEvent** - Represents a streaming event with metadata +2. **StreamObserver** - Abstract base class for event observers +3. **StreamingManager** - Manages observers and event distribution +4. **ChainGeneration** - Enhanced with streaming support + +### Event Types + +- `LLM_GENERATION_START` - LLM generation begins +- `LLM_GENERATION_CHUNK` - Individual token/chunk generated +- `LLM_GENERATION_END` - LLM generation completes +- `TOOL_CALL_START` - Tool call begins +- `TOOL_CALL_END` - Tool call completes +- `TOOL_OBSERVATION` - Tool observation received +- `CHAIN_START` - Agent chain begins +- `CHAIN_END` - Agent chain completes +- `ERROR` - Error occurred + +### Multi-Chain Problem and Solutions + +**Problem**: When running multiple chains with streaming, outputs from different chains get mixed together, making it impossible to follow which events belong to which chain. + +**Solutions**: + +1. **Automatic Color Coding** - Each chain gets a different color in console output +2. **Chain Filtering** - Filter events to only show specific chains +3. **Chain-Specific Observers** - Create separate observers for each chain +4. **Separate Async Generators** - Process each chain's events independently +5. **Multi-Chain Observer** - Organize observers by chain ID + +## Usage Examples + +### Basic Console Streaming + +```python +from agents.agents.agents.agent_base import BaseAgent +from agents.agents.agents.chain.streaming_observer import ConsoleStreamObserver + +# Initialize agent +agent = YourAgent(model_name="your-model", tools=[...]) + +# Add console observer +console_observer = ConsoleStreamObserver(show_timestamps=True) +agent.streaming_manager.add_observer(console_observer) + +# Run with streaming +await agent.run_async( + max_steps=5, + start_messages=your_messages, + num_chains=1, + enable_streaming=True +) +``` + +### Multi-Chain Streaming Solutions + +When running multiple chains, the streaming output can become mixed. Here are several solutions: + +#### 1. Colored Output (Automatic) + +```python +# Each chain gets a different color automatically +console_observer = ConsoleStreamObserver(show_timestamps=True) +agent.streaming_manager.add_observer(console_observer) + +await agent.run_async( + max_steps=5, + start_messages=your_messages, + num_chains=3, # Multiple chains + enable_streaming=True +) +``` + +#### 2. Chain-Specific Observers + +```python +from agents.agents.agents.chain.streaming_observer import MultiChainStreamObserver, ChainSpecificStreamObserver + +# Create multi-chain observer +multi_observer = MultiChainStreamObserver() + +# Add observers for specific chains +for chain_id in ["chain_0", "chain_1", "chain_2"]: + console_observer = ConsoleStreamObserver(show_timestamps=True) + chain_observer = ChainSpecificStreamObserver(chain_id, console_observer) + multi_observer.add_chain_observer(chain_id, chain_observer) + +agent.streaming_manager.add_observer(multi_observer) +``` + +#### 3. Filter by Chain + +```python +# Only observe events from a specific chain +filtered_observer = ConsoleStreamObserver( + show_timestamps=True, + chain_filter="chain_0" # Only show chain_0 events +) +agent.streaming_manager.add_observer(filtered_observer) +``` + +#### 4. Separate Async Generators per Chain + +```python +from agents.agents.agents.chain.streaming_observer import AsyncGeneratorStreamObserver + +# Create separate generators for each chain +chain_generators = {} +for i in range(3): + chain_id = f"chain_{i}" + async_observer = AsyncGeneratorStreamObserver(chain_filter=chain_id) + agent.streaming_manager.add_observer(async_observer) + chain_generators[chain_id] = async_observer.events() + +# Process each chain separately +async def process_chain(chain_id, generator): + async for event in generator: + print(f"{chain_id}: {event.event_type.value}") + +# Run all chains concurrently +tasks = [process_chain(chain_id, generator) for chain_id, generator in chain_generators.items()] +await asyncio.gather(*tasks) +``` + +### JSON Logging + +```python +from agents.agents.agents.chain.streaming_observer import JSONStreamObserver + +# Add JSON observer +json_observer = JSONStreamObserver(file_path="events.jsonl") +agent.streaming_manager.add_observer(json_observer) +``` + +### Custom Streaming Callback + +```python +async def custom_callback(chunk: str): + print(f"🔄 {chunk}", end="", flush=True) + +await agent.run_async( + max_steps=5, + start_messages=your_messages, + num_chains=1, + enable_streaming=True, + streaming_callback=custom_callback +) +``` + +### WebSocket Streaming + +```python +from agents.agents.agents.chain.websocket_streaming import WebSocketStreamingServer + +# Start WebSocket server +server = WebSocketStreamingServer(host="localhost", port=8765) +await server.start() + +# Add WebSocket observer +agent.streaming_manager.add_observer(server.get_observer()) + +# Run agent +await agent.run_async(..., enable_streaming=True) +``` + +### Async Generator Events + +```python +from agents.agents.agents.chain.streaming_observer import AsyncGeneratorStreamObserver + +# Create async generator observer +async_observer = AsyncGeneratorStreamObserver() +agent.streaming_manager.add_observer(async_observer) + +# Start agent run +run_task = asyncio.create_task(agent.run_async(..., enable_streaming=True)) + +# Process events as they arrive +async for event in async_observer.events(): + print(f"Event: {event.event_type.value}") + if event.event_type.value == "llm_generation_chunk": + print(f"Content: {event.data.get('content', '')}") + +await run_task +``` + +## Backend Support + +### Transformers Backend + +The Transformers backend supports streaming through token-by-token generation: + +```python +agent = YourAgent( + model_name="your-model", + backend="transformers", + # ... other args +) +``` + +### Async vLLM Backend + +The Async vLLM backend provides efficient streaming: + +```python +agent = YourAgent( + model_name="your-model", + backend="async_vllm", + # ... other args +) +``` + +## Event Structure + +Each streaming event contains: + +```python +@dataclass +class StreamEvent: + event_type: StreamEventType + chain_id: str + timestamp: float + data: Dict[str, Any] + step: Optional[int] = None + depth: Optional[int] = None +``` + +### Example Events + +**LLM Generation Chunk:** +```json +{ + "event_type": "llm_generation_chunk", + "chain_id": "uuid-123", + "timestamp": 1234567890.123, + "data": {"content": "def factorial"}, + "step": 1, + "depth": 1 +} +``` + +**Tool Observation:** +```json +{ + "event_type": "tool_observation", + "chain_id": "uuid-123", + "timestamp": 1234567890.456, + "data": { + "tool_name": "code_interpreter", + "observation": "120", + "status": "success" + }, + "step": 1, + "depth": 1 +} +``` + +## Performance Considerations + +1. **Memory Usage** - Streaming events are lightweight but can accumulate +2. **Network Overhead** - WebSocket streaming adds minimal overhead +3. **Backend Compatibility** - Not all backends support streaming equally +4. **Observer Performance** - Heavy observers can slow down the main loop + +## Best Practices + +1. **Use appropriate observers** - Console for debugging, JSON for logging, WebSocket for web apps +2. **Handle errors gracefully** - Implement error handling in custom observers +3. **Clean up resources** - Properly close WebSocket connections and file handles +4. **Monitor performance** - Watch for memory leaks in long-running streams +5. **Test thoroughly** - Streaming adds complexity, test edge cases + +## Integration with Existing Code + +The streaming functionality is designed to be non-intrusive: + +- Existing code continues to work without changes +- Streaming is opt-in via the `run_async_streaming` method +- Observers can be added/removed at runtime +- No performance impact when streaming is disabled + +## Troubleshooting + +### Common Issues + +1. **No streaming output** - Check if backend supports streaming +2. **WebSocket connection issues** - Verify port availability and firewall settings +3. **Memory leaks** - Ensure observers are properly cleaned up +4. **Performance issues** - Consider using fewer observers or lighter implementations + +### Debug Mode + +Enable debug logging to troubleshoot streaming issues: + +```python +import logging +logging.basicConfig(level=logging.DEBUG) +``` + +## Future Enhancements + +1. **Filtering** - Event filtering based on type, chain_id, etc. +2. **Batching** - Batch multiple events for efficiency +3. **Compression** - Compress WebSocket messages for large-scale deployments +4. **Authentication** - Add authentication to WebSocket connections +5. **Metrics** - Built-in streaming metrics and monitoring \ No newline at end of file diff --git a/agents/agents/agents/chain/chain_base.py b/agents/agents/agents/chain/chain_base.py index 0ee77be..f363a3c 100644 --- a/agents/agents/agents/chain/chain_base.py +++ b/agents/agents/agents/chain/chain_base.py @@ -2,8 +2,9 @@ from collections import defaultdict from dataclasses import dataclass, field import json +import time from ...utils.timing import Timer -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Callable import uuid from termcolor import colored import numpy as np @@ -13,6 +14,8 @@ from ...utils.monitor import JsonlSink, MetricEvent, Monitor, WandbSink, emit, serialize_for_json from ... import AGENT_DATA_DIR import wandb +from .streaming_observer import StreamingManager, StreamEvent, StreamEventType + @dataclass class Node: is_terminal: bool = False @@ -135,6 +138,7 @@ def __init__(self): self.finished_chains_count = 0 self.initialize_monitor() self.monitor_info = defaultdict(list) + self.streaming_manager = StreamingManager() def reset(self) -> None: self.status_code: str = "continue" @@ -215,33 +219,20 @@ async def run_async(self, max_steps: int, start_messages: Union[List[dict], np.ndarray], num_chains: int, - generation_config: Optional[Dict[str, Any]] = None + generation_config: Optional[Dict[str, Any]] = None, + enable_streaming: bool = False, + streaming_callback: Optional[Callable] = None, ): """ - Run the chain-based rollout. + Run the chain-based rollout with optional streaming support. Args: max_steps: Maximum number of steps for each chain. - start_messages: List of messages to start the chains. Each message should be a dict - with "messages" key containing a list of message dictionaries. + start_messages: List of messages to start the chains. num_chains: Number of chains to run for each message. generation_config: Generation configuration dictionary. - - Example: - .. code-block:: python - - start_messages = [ - { - "messages": [{"role": "user", "content": "..."}], - "info": {"question": "..."}, - "answer": "..." - }, - { - "messages": [{"role": "user", "content": "..."}], - "info": {"question": "..."}, - "answer": "..." - } - ] + enable_streaming: Whether to enable streaming mode. + streaming_callback: Optional callback for streaming events. """ assert max_steps >= 1, "max_steps must be at least 1." Monitor.ensure_started() @@ -254,30 +245,34 @@ async def run_async(self, ) tool_schemas = [tool.schema for tool in self.tools] + # Emit chain start events if streaming is enabled + if enable_streaming: + for cid in first_nodes.keys(): + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.CHAIN_START, + chain_id=cid, + timestamp=time.time(), + data={"max_steps": max_steps}, + step=0, + depth=0 + )) + done_q = asyncio.Queue() tasks = [ asyncio.create_task( - self._run_chain_async( + self._run_single_chain( cid, node, chains[cid], tool_schemas, max_steps=max_steps, - done_queue=done_q) + done_queue=done_q, + enable_streaming=enable_streaming, + streaming_callback=streaming_callback) ) for cid, node in first_nodes.items() ] - # Throttle the number of concurrent chains - # print([tool.parallel_size for tool in self.tools]) - - # minimal_tool_parallel_size = 1 - # sem = asyncio.Semaphore(minimal_tool_parallel_size) - # async def guarded_run(cid, *args): - # async with sem: - # return await self._run_chain_async(cid, *args) - # tasks = [guarded_run(cid, node, chains[cid], max_steps, done_q) for cid, node in first_nodes.items()] - # await asyncio.gather(*tasks) await tqdm_asyncio.gather(*tasks) self.chains = {} @@ -289,29 +284,32 @@ async def run_async(self, self.global_step += 1 self.monitor_step() - async def _run_chain_async(self, + async def _run_single_chain(self, chain_id: str, first_node: Node, chain: Chain, tools: List[Dict], max_steps: int, - done_queue: asyncio.Queue + done_queue: asyncio.Queue, + enable_streaming: bool = False, + streaming_callback: Optional[Callable] = None, ): """ - Drives *one* trajectory until it terminates or max_steps is reached. - Writes (chain_id, chain) to done_queue when finished. + Run a single chain with optional streaming support. """ current_node = first_node depth = 0 - final_result = None have_set_tools = False while not current_node.is_terminal and depth < max_steps: newest_messages = current_node.messages + if not current_node.is_terminal: - responses = await self.generate_async([current_node.messages], tools=tools, num_return_sequences=1) - new_msg = self.parse(responses, self.tools) - new_msg = new_msg[0] + # Generate response + new_msg = await self._generate_response( + current_node, tools, depth, chain_id, enable_streaming, streaming_callback + ) + newest_messages.append(new_msg) thought_node = chain.add_node( type="Thought", @@ -321,47 +319,167 @@ async def _run_chain_async(self, thought_node.is_terminal = new_msg.get("status", "continue") in self.terminal_status current_node = thought_node + # Handle tool calls if current_node.messages[-1].get("tool_calls"): for tool_call in current_node.messages[-1]["tool_calls"]: - tool_name = tool_call["function"]["name"] - tool_input = tool_call["function"]["arguments"] - action_node = chain.add_node( - type="Action", - messages=deepcopy(newest_messages), - description=tool_name - ) - if not have_set_tools: - await self.set_tools(chain_id, chain.info) - have_set_tools = True - - result = await submit_tool_call(tool_name, tool_input, id=chain_id) - final_result = result - action_input_node = chain.add_node( - type="Action Input", - messages=deepcopy(newest_messages), - description=result.get("arguments", "") + current_node = await self._execute_tool_call( + tool_call, newest_messages, chain, chain_id, depth, + have_set_tools, enable_streaming ) - observation = result["observation"] - observation_json = json.dumps({ - "name": result["name"], - "content": observation, - }, indent=4) - action_input_node.observation = observation_json - action_input_node.observation_code = result["status"] - newest_messages.append({ - "role": "tool", - "tool_call_id": tool_call["id"], - "content": [{"type": "text", "text": observation_json}], - }) - action_input_node.messages = deepcopy(newest_messages) - action_input_node.is_terminal = result["status"] in self.terminal_status - current_node = action_input_node + have_set_tools = True else: - # If there is no tool call, we assume the chain is finished + # No tool calls, chain is finished break depth += 1 + # Finalize chain + await self._finalize_chain(chain_id, chain, current_node, depth, enable_streaming) + await done_queue.put((chain_id, chain, current_node)) + + self.finished_chains_count += 1 + self.monitor_chain() + + async def _generate_response(self, current_node, tools, depth, chain_id, enable_streaming, streaming_callback): + """Generate response with optional streaming support.""" + if enable_streaming: + # Emit generation start event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_START, + chain_id=chain_id, + timestamp=time.time(), + data={"depth": depth}, + step=depth, + depth=depth + )) + + # Use streaming generation if available + if hasattr(self, 'generate_streaming') and streaming_callback: + async def stream_callback(chunk: str): + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_CHUNK, + chain_id=chain_id, + timestamp=time.time(), + data={"content": chunk}, + step=depth, + depth=depth + )) + if streaming_callback: + await streaming_callback(chunk) + + # Collect full response from streaming + full_response = "" + async for chunk in self.generate_streaming([current_node.messages], tools=tools, streaming_callback=stream_callback): + full_response += chunk + + # Emit generation end event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_END, + chain_id=chain_id, + timestamp=time.time(), + data={"full_response": full_response}, + step=depth, + depth=depth + )) + + # Parse response + new_msg = self.parse([full_response], self.tools) + return new_msg[0] + else: + # Fallback to non-streaming generation + responses = await self.generate_async([current_node.messages], tools=tools, num_return_sequences=1) + new_msg = self.parse(responses, self.tools) + return new_msg[0] + else: + # Non-streaming generation + responses = await self.generate_async([current_node.messages], tools=tools, num_return_sequences=1) + new_msg = self.parse(responses, self.tools) + return new_msg[0] + + async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id, depth, have_set_tools, enable_streaming): + """Execute a tool call with optional streaming support.""" + tool_name = tool_call["function"]["name"] + tool_input = tool_call["function"]["arguments"] + + if enable_streaming: + # Emit tool call start event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.TOOL_CALL_START, + chain_id=chain_id, + timestamp=time.time(), + data={"tool_name": tool_name, "arguments": tool_input}, + step=depth, + depth=depth + )) + + # Create action node + action_node = chain.add_node( + type="Action", + messages=deepcopy(newest_messages), + description=tool_name + ) + + # Set up tools if needed + if not have_set_tools: + await self.set_tools(chain_id, chain.info) + have_set_tools = True + + # Execute tool call + result = await submit_tool_call(tool_name, tool_input, id=chain_id) + + if enable_streaming: + # Emit tool observation event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.TOOL_OBSERVATION, + chain_id=chain_id, + timestamp=time.time(), + data={ + "tool_name": tool_name, + "observation": result["observation"], + "status": result["status"] + }, + step=depth, + depth=depth + )) + + # Emit tool call end event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.TOOL_CALL_END, + chain_id=chain_id, + timestamp=time.time(), + data={"tool_name": tool_name, "result": result}, + step=depth, + depth=depth + )) + + # Create action input node + action_input_node = chain.add_node( + type="Action Input", + messages=deepcopy(newest_messages), + description=result.get("arguments", "") + ) + + # Process observation + observation = result["observation"] + observation_json = json.dumps({ + "name": result["name"], + "content": observation, + }, indent=4) + + action_input_node.observation = observation_json + action_input_node.observation_code = result["status"] + newest_messages.append({ + "role": "tool", + "tool_call_id": tool_call["id"], + "content": [{"type": "text", "text": observation_json}], + }) + action_input_node.messages = deepcopy(newest_messages) + action_input_node.is_terminal = result["status"] in self.terminal_status + + return action_input_node + + async def _finalize_chain(self, chain_id, chain, current_node, depth, enable_streaming): + """Finalize the chain with reward calculation and cleanup.""" if self._reward_fn is not None: trajectory = current_node.messages final_response = self.extract_final_response(trajectory) @@ -369,12 +487,19 @@ async def _run_chain_async(self, chain.info["reward"] = await self._reward_fn(prediction=final_response, **other_args, trajectory=trajectory, id=chain_id) else: chain.info["reward"] = None + await self.release_resources(chain_id) - await done_queue.put((chain_id, chain, current_node)) - - self.finished_chains_count += 1 - self.monitor_chain() + if enable_streaming: + # Emit chain end event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.CHAIN_END, + chain_id=chain_id, + timestamp=time.time(), + data={"final_depth": depth, "reward": chain.info.get("reward")}, + step=depth, + depth=depth + )) async def release_resources(self, id: str) -> None: for tool in self.tools: diff --git a/agents/agents/agents/chain/streaming_observer.py b/agents/agents/agents/chain/streaming_observer.py new file mode 100644 index 0000000..d4de1a3 --- /dev/null +++ b/agents/agents/agents/chain/streaming_observer.py @@ -0,0 +1,259 @@ +import asyncio +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Callable, AsyncGenerator, Set +from enum import Enum +import json +import time +from termcolor import colored + + +class StreamEventType(Enum): + """Types of streaming events""" + LLM_GENERATION_START = "llm_generation_start" + LLM_GENERATION_CHUNK = "llm_generation_chunk" + LLM_GENERATION_END = "llm_generation_end" + TOOL_CALL_START = "tool_call_start" + TOOL_CALL_END = "tool_call_end" + TOOL_OBSERVATION = "tool_observation" + CHAIN_START = "chain_start" + CHAIN_END = "chain_end" + ERROR = "error" + + +@dataclass +class StreamEvent: + """A streaming event with metadata""" + event_type: StreamEventType + chain_id: str + timestamp: float + data: Dict[str, Any] + step: Optional[int] = None + depth: Optional[int] = None + + def __post_init__(self): + # Add a unique identifier for this event + if not hasattr(self, 'event_id'): + self.event_id = f"{self.chain_id}_{self.timestamp}_{self.event_type.value}" + + +class StreamObserver(ABC): + """Abstract base class for stream observers""" + + @abstractmethod + async def on_event(self, event: StreamEvent) -> None: + """Handle a streaming event""" + pass + + async def on_error(self, error: Exception, chain_id: str) -> None: + """Handle an error event""" + event = StreamEvent( + event_type=StreamEventType.ERROR, + chain_id=chain_id, + timestamp=time.time(), + data={"error": str(error), "error_type": type(error).__name__} + ) + await self.on_event(event) + + +class StreamingManager: + """Manages streaming observers and event distribution""" + + def __init__(self): + self.observers: List[StreamObserver] = [] + self.enabled = False + self.active_chains: Set[str] = set() + self.chain_events: Dict[str, List[StreamEvent]] = {} + + def add_observer(self, observer: StreamObserver) -> None: + """Add a streaming observer""" + self.observers.append(observer) + self.enabled = True + + def remove_observer(self, observer: StreamObserver) -> None: + """Remove a streaming observer""" + if observer in self.observers: + self.observers.remove(observer) + if not self.observers: + self.enabled = False + + async def emit_event(self, event: StreamEvent) -> None: + """Emit an event to all observers""" + if not self.enabled: + return + + # Track active chains + if event.event_type == StreamEventType.CHAIN_START: + self.active_chains.add(event.chain_id) + self.chain_events[event.chain_id] = [] + elif event.event_type == StreamEventType.CHAIN_END: + self.active_chains.discard(event.chain_id) + + # Store event for this chain + if event.chain_id in self.chain_events: + self.chain_events[event.chain_id].append(event) + + tasks = [observer.on_event(event) for observer in self.observers] + await asyncio.gather(*tasks, return_exceptions=True) + + async def emit_error(self, error: Exception, chain_id: str) -> None: + """Emit an error event to all observers""" + if not self.enabled: + return + + tasks = [observer.on_error(error, chain_id) for observer in self.observers] + await asyncio.gather(*tasks, return_exceptions=True) + + def get_chain_events(self, chain_id: str) -> List[StreamEvent]: + """Get all events for a specific chain""" + return self.chain_events.get(chain_id, []) + + def get_active_chains(self) -> Set[str]: + """Get all currently active chain IDs""" + return self.active_chains.copy() + + +class ConsoleStreamObserver(StreamObserver): + """Simple console-based stream observer for debugging""" + + def __init__(self, show_timestamps: bool = True, chain_filter: Optional[str] = None): + self.show_timestamps = show_timestamps + self.chain_filter = chain_filter # Only show events for this chain_id + + async def on_event(self, event: StreamEvent) -> None: + # Filter by chain if specified + if self.chain_filter and event.chain_id != self.chain_filter: + return + + timestamp = f"[{event.timestamp:.2f}s] " if self.show_timestamps else "" + step_info = f" (step {event.step})" if event.step is not None else "" + depth_info = f" (depth {event.depth})" if event.depth is not None else "" + + # Use different colors for different chains + chain_colors = ["red", "green", "blue", "yellow", "magenta", "cyan"] + chain_index = hash(event.chain_id) % len(chain_colors) + chain_color = chain_colors[chain_index] + + print(colored(f"{timestamp}Chain {event.chain_id}{step_info}{depth_info}: {event.event_type.value}", color=chain_color)) + + if event.event_type == StreamEventType.LLM_GENERATION_CHUNK: + content = event.data.get("content", "") + if content: + print(colored(f" → {content}", color=chain_color)) + elif event.event_type == StreamEventType.TOOL_OBSERVATION: + observation = event.data.get("observation", "") + tool_name = event.data.get("tool_name", "") + print(colored(f" 🔧 {tool_name}: {observation[:200]}{'...' if len(observation) > 200 else ''}", color=chain_color)) + elif event.event_type == StreamEventType.ERROR: + error_msg = event.data.get("error", "") + print(colored(f" ❌ Error: {error_msg}", color=chain_color)) + + +class JSONStreamObserver(StreamObserver): + """JSON-based stream observer for structured logging""" + + def __init__(self, file_path: Optional[str] = None, chain_filter: Optional[str] = None): + self.file_path = file_path + self.chain_filter = chain_filter + + async def on_event(self, event: StreamEvent) -> None: + # Filter by chain if specified + if self.chain_filter and event.chain_id != self.chain_filter: + return + + event_dict = { + "event_type": event.event_type.value, + "chain_id": event.chain_id, + "timestamp": event.timestamp, + "step": event.step, + "depth": event.depth, + "data": event.data + } + + if self.file_path: + with open(self.file_path, "a") as f: + f.write(json.dumps(event_dict) + "\n") + else: + print(json.dumps(event_dict)) + + +class AsyncGeneratorStreamObserver(StreamObserver): + """Stream observer that yields events as an async generator""" + + def __init__(self, chain_filter: Optional[str] = None): + self.queue = asyncio.Queue() + self.chain_filter = chain_filter + + async def on_event(self, event: StreamEvent) -> None: + # Filter by chain if specified + if self.chain_filter and event.chain_id != self.chain_filter: + return + + await self.queue.put(event) + + async def events(self) -> AsyncGenerator[StreamEvent, None]: + """Yield events as they arrive""" + while True: + try: + event = await self.queue.get() + if event.event_type == StreamEventType.CHAIN_END: + # Send the final event and stop + yield event + break + yield event + except asyncio.CancelledError: + break + + +class ChainSpecificStreamObserver(StreamObserver): + """Stream observer that only handles events for a specific chain""" + + def __init__(self, target_chain_id: str, base_observer: StreamObserver): + self.target_chain_id = target_chain_id + self.base_observer = base_observer + + async def on_event(self, event: StreamEvent) -> None: + if event.chain_id == self.target_chain_id: + await self.base_observer.on_event(event) + + async def on_error(self, error: Exception, chain_id: str) -> None: + if chain_id == self.target_chain_id: + await self.base_observer.on_error(error, chain_id) + + +class MultiChainStreamObserver(StreamObserver): + """Stream observer that organizes events by chain""" + + def __init__(self): + self.chain_observers: Dict[str, List[StreamObserver]] = {} + self.global_observers: List[StreamObserver] = [] + + def add_chain_observer(self, chain_id: str, observer: StreamObserver) -> None: + """Add an observer for a specific chain""" + if chain_id not in self.chain_observers: + self.chain_observers[chain_id] = [] + self.chain_observers[chain_id].append(observer) + + def add_global_observer(self, observer: StreamObserver) -> None: + """Add an observer for all chains""" + self.global_observers.append(observer) + + async def on_event(self, event: StreamEvent) -> None: + # Send to chain-specific observers + if event.chain_id in self.chain_observers: + tasks = [obs.on_event(event) for obs in self.chain_observers[event.chain_id]] + await asyncio.gather(*tasks, return_exceptions=True) + + # Send to global observers + tasks = [obs.on_event(event) for obs in self.global_observers] + await asyncio.gather(*tasks, return_exceptions=True) + + async def on_error(self, error: Exception, chain_id: str) -> None: + # Send to chain-specific observers + if chain_id in self.chain_observers: + tasks = [obs.on_error(error, chain_id) for obs in self.chain_observers[chain_id]] + await asyncio.gather(*tasks, return_exceptions=True) + + # Send to global observers + tasks = [obs.on_error(error, chain_id) for obs in self.global_observers] + await asyncio.gather(*tasks, return_exceptions=True) \ No newline at end of file diff --git a/agents/agents/agents/chain/websocket_streaming.py b/agents/agents/agents/chain/websocket_streaming.py new file mode 100644 index 0000000..4ce4fc6 --- /dev/null +++ b/agents/agents/agents/chain/websocket_streaming.py @@ -0,0 +1,227 @@ +""" +WebSocket-based streaming interface for real-time agent interactions. +This module provides a WebSocket server that can stream agent events to web clients. +""" + +import asyncio +import json +import websockets +from typing import Dict, Set, Optional, Callable +from .streaming_observer import StreamObserver, StreamEvent, StreamEventType +import logging + +logger = logging.getLogger(__name__) + + +class WebSocketStreamObserver(StreamObserver): + """Stream observer that broadcasts events to WebSocket clients""" + + def __init__(self): + self.clients: Set[websockets.WebSocketServerProtocol] = set() + self.lock = asyncio.Lock() + + async def on_event(self, event: StreamEvent) -> None: + """Broadcast event to all connected WebSocket clients""" + if not self.clients: + return + + # Convert event to JSON + event_data = { + "event_type": event.event_type.value, + "chain_id": event.chain_id, + "timestamp": event.timestamp, + "step": event.step, + "depth": event.depth, + "data": event.data + } + + message = json.dumps(event_data) + + # Broadcast to all clients + disconnected_clients = set() + async with self.lock: + for client in self.clients: + try: + await client.send(message) + except websockets.exceptions.ConnectionClosed: + disconnected_clients.add(client) + except Exception as e: + logger.error(f"Error sending to WebSocket client: {e}") + disconnected_clients.add(client) + + # Remove disconnected clients + self.clients -= disconnected_clients + + async def add_client(self, websocket: websockets.WebSocketServerProtocol) -> None: + """Add a new WebSocket client""" + async with self.lock: + self.clients.add(websocket) + logger.info(f"WebSocket client connected. Total clients: {len(self.clients)}") + + async def remove_client(self, websocket: websockets.WebSocketServerProtocol) -> None: + """Remove a WebSocket client""" + async with self.lock: + self.clients.discard(websocket) + logger.info(f"WebSocket client disconnected. Total clients: {len(self.clients)}") + + +class WebSocketStreamingServer: + """WebSocket server for streaming agent events""" + + def __init__(self, host: str = "localhost", port: int = 8765): + self.host = host + self.port = port + self.observer = WebSocketStreamObserver() + self.server = None + + async def handle_client(self, websocket, path): + """Handle individual WebSocket client connection""" + await self.observer.add_client(websocket) + try: + # Keep connection alive and handle incoming messages + async for message in websocket: + try: + data = json.loads(message) + # Handle client messages if needed + logger.info(f"Received message from client: {data}") + except json.JSONDecodeError: + logger.warning(f"Invalid JSON from client: {message}") + except Exception as e: + logger.error(f"Error handling client message: {e}") + except websockets.exceptions.ConnectionClosed: + pass + finally: + await self.observer.remove_client(websocket) + + async def start(self): + """Start the WebSocket server""" + self.server = await websockets.serve( + self.handle_client, + self.host, + self.port + ) + logger.info(f"WebSocket server started on ws://{self.host}:{self.port}") + return self.server + + async def stop(self): + """Stop the WebSocket server""" + if self.server: + self.server.close() + await self.server.wait_closed() + logger.info("WebSocket server stopped") + + def get_observer(self) -> WebSocketStreamObserver: + """Get the WebSocket stream observer""" + return self.observer + + +class WebSocketStreamingClient: + """WebSocket client for receiving streaming events""" + + def __init__(self, uri: str = "ws://localhost:8765"): + self.uri = uri + self.websocket = None + + async def connect(self): + """Connect to the WebSocket server""" + self.websocket = await websockets.connect(self.uri) + logger.info(f"Connected to WebSocket server at {self.uri}") + + async def disconnect(self): + """Disconnect from the WebSocket server""" + if self.websocket: + await self.websocket.close() + logger.info("Disconnected from WebSocket server") + + async def receive_events(self, event_handler: Optional[Callable] = None): + """Receive and handle streaming events""" + if not self.websocket: + await self.connect() + + try: + async for message in self.websocket: + try: + event_data = json.loads(message) + if event_handler: + await event_handler(event_data) + else: + # Default event handling + event_type = event_data.get("event_type") + chain_id = event_data.get("chain_id") + data = event_data.get("data", {}) + + if event_type == "llm_generation_chunk": + content = data.get("content", "") + print(f"🤖 Chain {chain_id}: {content}", end="", flush=True) + elif event_type == "tool_observation": + tool_name = data.get("tool_name", "") + observation = data.get("observation", "") + print(f"\n🔧 {tool_name}: {observation[:100]}...") + elif event_type == "chain_end": + print(f"\n✅ Chain {chain_id} completed!") + + except json.JSONDecodeError: + logger.warning(f"Invalid JSON received: {message}") + except Exception as e: + logger.error(f"Error handling event: {e}") + except websockets.exceptions.ConnectionClosed: + logger.info("WebSocket connection closed") + except Exception as e: + logger.error(f"WebSocket error: {e}") + + +# Example usage functions +async def start_websocket_server(): + """Start the WebSocket streaming server""" + server = WebSocketStreamingServer() + await server.start() + return server + + +async def run_agent_with_websocket_streaming(agent, start_messages, max_steps=5, num_chains=1): + """Run an agent with WebSocket streaming""" + + # Start WebSocket server + server = await start_websocket_server() + + # Add WebSocket observer to agent + agent.streaming_manager.add_observer(server.get_observer()) + + try: + # Run the agent + await agent.run_async( + max_steps=max_steps, + start_messages=start_messages, + num_chains=num_chains, + enable_streaming=True + ) + finally: + # Stop the server + await server.stop() + + +async def connect_and_monitor(): + """Connect to WebSocket server and monitor events""" + client = WebSocketStreamingClient() + + async def event_handler(event_data): + """Custom event handler""" + event_type = event_data.get("event_type") + chain_id = event_data.get("chain_id") + + if event_type == "llm_generation_start": + print(f"🚀 Chain {chain_id}: Starting LLM generation...") + elif event_type == "tool_call_start": + tool_name = event_data.get("data", {}).get("tool_name", "") + print(f"🔧 Chain {chain_id}: Calling tool {tool_name}...") + elif event_type == "chain_end": + final_depth = event_data.get("data", {}).get("final_depth", 0) + reward = event_data.get("data", {}).get("reward") + print(f"✅ Chain {chain_id}: Completed in {final_depth} steps (reward: {reward})") + + await client.receive_events(event_handler) + + +if __name__ == "__main__": + # Example: Start WebSocket server + asyncio.run(start_websocket_server()) \ No newline at end of file diff --git a/agents/agents/agents/llm_backend.py b/agents/agents/agents/llm_backend.py index bc527f7..3e481a8 100644 --- a/agents/agents/agents/llm_backend.py +++ b/agents/agents/agents/llm_backend.py @@ -7,7 +7,7 @@ from collections import deque from functools import partial import time -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List, Optional, Callable, AsyncGenerator import uuid from .templates.utils import convert_messages_to_openai_format import numpy as np @@ -54,6 +54,10 @@ def apply_chat_template(self, messages_list: List[List[Dict]], template: str, ad def generate(self, messages_list: str, **kwargs) -> str: """Generate text from prompt""" raise NotImplementedError("Subclasses must implement generate()") + + async def generate_streaming(self, messages_list: List[List[Dict]], streaming_callback: Optional[Callable] = None, **kwargs) -> AsyncGenerator[str, None]: + """Generate text with streaming support""" + raise NotImplementedError("Subclasses must implement generate_streaming()") class TransformersBackend(LLMBackend): """HuggingFace Transformers implementation""" @@ -100,8 +104,50 @@ def generate(self, messages_list: str, **kwargs) -> str: response_texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) return response_texts - def generate_async(self, messages_list: str, **kwargs) -> str: - raise NotImplementedError("Transformers backend does not support async generation") + async def generate_async(self, messages_list: str, **kwargs) -> str: + """Async wrapper for generate""" + return self.generate(messages_list, **kwargs) + + async def generate_streaming(self, messages_list: List[List[Dict]], streaming_callback: Optional[Callable] = None, **kwargs) -> AsyncGenerator[str, None]: + """Generate text with streaming support using Transformers""" + max_new_tokens = kwargs.get("max_new_tokens", self.max_new_tokens) + temperature = kwargs.get("temperature", self.temperature) + + prompts, _ = self.apply_chat_template(messages_list, self.template) + + inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").to(self.llm_engine.device) + input_length = inputs['input_ids'].shape[1] + + # Use streaming generation + generated_tokens = [] + for i in range(max_new_tokens): + outputs = self.llm_engine.generate( + **inputs, + max_new_tokens=1, + temperature=temperature, + do_sample=temperature > 0, + pad_token_id=self.tokenizer.eos_token_id, + use_cache=True + ) + + new_token = outputs[0][-1].unsqueeze(0) + generated_tokens.append(new_token) + + # Decode the new token + new_text = self.tokenizer.decode(new_token, skip_special_tokens=True) + + if streaming_callback: + await streaming_callback(new_text) + + yield new_text + + # Check for EOS + if new_token.item() == self.tokenizer.eos_token_id: + break + + # Update input for next iteration + inputs['input_ids'] = torch.cat([inputs['input_ids'], new_token.unsqueeze(0)], dim=1) + inputs['attention_mask'] = torch.cat([inputs['attention_mask'], torch.ones(1, 1, device=inputs['attention_mask'].device)], dim=1) class VLLMBackend(LLMBackend): """vLLM implementation""" @@ -221,6 +267,36 @@ async def generate_async(self, messages_list: str, **kwargs) -> str: LOGGER.debug(f"[AsyncVLLMBackend] response_texts: {response_texts}") return response_texts + + async def generate_streaming(self, messages_list: List[List[Dict]], streaming_callback: Optional[Callable] = None, **kwargs) -> AsyncGenerator[str, None]: + """Generate text with streaming support using Async vLLM""" + max_new_tokens = kwargs.get("max_new_tokens", self.max_new_tokens) + temperature = kwargs.get("temperature", self.temperature) + sampling_params = SamplingParams( + n=1, + max_tokens=max_new_tokens, + temperature=temperature, + ) + + tools = kwargs.get("tools", None) + prompts, vision_inputs = self.apply_chat_template(messages_list, self.template, tools=tools) + inputs = self._process_inputs(prompts, vision_inputs) + + # For streaming, we process one input at a time + for input_data in inputs: + outputs_gen = self.llm_engine.generate( + input_data, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + ) + + async for output in outputs_gen: + for sequence in output.outputs: + # Stream each token + if hasattr(sequence, 'text'): + if streaming_callback: + await streaming_callback(sequence.text) + yield sequence.text class VerlBackend(LLMBackend): """Verl implementation""" diff --git a/agents/agents/examples/multi_chain_streaming_example.py b/agents/agents/examples/multi_chain_streaming_example.py new file mode 100644 index 0000000..7cdba3a --- /dev/null +++ b/agents/agents/examples/multi_chain_streaming_example.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 +""" +Example demonstrating how to handle multiple chains with streaming +without mixing outputs from different chains. +""" + +import asyncio +import json +from typing import List, Dict, Any +from agents.agents.agents.agent_base import BaseAgent +from agents.agents.agents.chain.streaming_observer import ( + ConsoleStreamObserver, + JSONStreamObserver, + ChainSpecificStreamObserver, + MultiChainStreamObserver, + AsyncGeneratorStreamObserver +) +from agents.agents.tools.src.code.tools import code_interpreter + + +class StreamingAgent(BaseAgent): + """Example agent with streaming support""" + + def parse(self, responses: List[str], tools: List[Any], **args) -> List[Dict[str, Any]]: + """Parse tool calls from responses""" + parsed_responses = [] + for response in responses: + if "```python" in response: + start = response.find("```python") + 9 + end = response.find("```", start) + if end != -1: + code = response[start:end].strip() + parsed_responses.append({ + "role": "assistant", + "content": [{"type": "text", "text": response}], + "tool_calls": [{ + "id": "call_1", + "function": { + "name": "code_interpreter", + "arguments": json.dumps({"code": code}) + } + }] + }) + else: + parsed_responses.append({ + "role": "assistant", + "content": [{"type": "text", "text": response}] + }) + else: + parsed_responses.append({ + "role": "assistant", + "content": [{"type": "text", "text": response}] + }) + + return parsed_responses + + +async def example_1_colored_output(): + """Example 1: Use colored output to distinguish chains""" + print("=== Example 1: Colored Output for Multiple Chains ===") + + agent = StreamingAgent( + model_name_or_path="microsoft/DialoGPT-medium", + template="chatml", + tools=[code_interpreter], + backend="transformers", + debug=True + ) + + # Add colored console observer + console_observer = ConsoleStreamObserver(show_timestamps=True) + agent.streaming_manager.add_observer(console_observer) + + start_messages = [ + { + "messages": [{"role": "user", "content": "Write a function to calculate factorial of 5."}], + "info": {"task": "factorial"}, + "answer": "def factorial(n): return 1 if n <= 1 else n * factorial(n-1)\nprint(factorial(5))" + }, + { + "messages": [{"role": "user", "content": "Write a function to calculate fibonacci of 10."}], + "info": {"task": "fibonacci"}, + "answer": "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)\nprint(fibonacci(10))" + } + ] + + await agent.run_async( + max_steps=3, + start_messages=start_messages, + num_chains=2, # Run 2 chains per message = 4 total chains + enable_streaming=True + ) + + +async def example_2_chain_specific_observers(): + """Example 2: Use chain-specific observers""" + print("\n=== Example 2: Chain-Specific Observers ===") + + agent = StreamingAgent( + model_name_or_path="microsoft/DialoGPT-medium", + template="chatml", + tools=[code_interpreter], + backend="transformers", + debug=True + ) + + # Create a multi-chain observer + multi_observer = MultiChainStreamObserver() + + # Add chain-specific observers + for i in range(4): # We'll have 4 chains + chain_id = f"chain_{i}" + console_observer = ConsoleStreamObserver(show_timestamps=True) + json_observer = JSONStreamObserver(file_path=f"chain_{i}_events.jsonl") + + # Create chain-specific observers + chain_console = ChainSpecificStreamObserver(chain_id, console_observer) + chain_json = ChainSpecificStreamObserver(chain_id, json_observer) + + multi_observer.add_chain_observer(chain_id, chain_console) + multi_observer.add_chain_observer(chain_id, chain_json) + + agent.streaming_manager.add_observer(multi_observer) + + start_messages = [ + { + "messages": [{"role": "user", "content": "Write a function to add two numbers."}], + "info": {"task": "addition"}, + "answer": "def add(a, b): return a + b" + }, + { + "messages": [{"role": "user", "content": "Write a function to multiply two numbers."}], + "info": {"task": "multiplication"}, + "answer": "def multiply(a, b): return a * b" + } + ] + + await agent.run_async( + max_steps=3, + start_messages=start_messages, + num_chains=2, + enable_streaming=True + ) + + +async def example_3_filter_by_chain(): + """Example 3: Filter events by specific chain""" + print("\n=== Example 3: Filter by Specific Chain ===") + + agent = StreamingAgent( + model_name_or_path="microsoft/DialoGPT-medium", + template="chatml", + tools=[code_interpreter], + backend="transformers", + debug=True + ) + + # Only observe events from a specific chain + target_chain_id = "chain_0" # We'll focus on the first chain + filtered_observer = ConsoleStreamObserver( + show_timestamps=True, + chain_filter=target_chain_id + ) + + agent.streaming_manager.add_observer(filtered_observer) + + start_messages = [ + { + "messages": [{"role": "user", "content": "Write a function to check if a number is prime."}], + "info": {"task": "prime_check"}, + "answer": "def is_prime(n): return n > 1 and all(n % i != 0 for i in range(2, int(n**0.5) + 1))" + } + ] + + await agent.run_async( + max_steps=3, + start_messages=start_messages, + num_chains=3, # Run 3 chains but only observe one + enable_streaming=True + ) + + +async def example_4_async_generator_per_chain(): + """Example 4: Use async generators for each chain""" + print("\n=== Example 4: Async Generator per Chain ===") + + agent = StreamingAgent( + model_name_or_path="microsoft/DialoGPT-medium", + template="chatml", + tools=[code_interpreter], + backend="transformers", + debug=True + ) + + # Create async generators for each chain + chain_generators = {} + for i in range(2): + chain_id = f"chain_{i}" + async_observer = AsyncGeneratorStreamObserver(chain_filter=chain_id) + agent.streaming_manager.add_observer(async_observer) + chain_generators[chain_id] = async_observer.events() + + start_messages = [ + { + "messages": [{"role": "user", "content": "Write a function to reverse a string."}], + "info": {"task": "string_reverse"}, + "answer": "def reverse_string(s): return s[::-1]" + } + ] + + # Start the agent run + run_task = asyncio.create_task( + agent.run_async( + max_steps=3, + start_messages=start_messages, + num_chains=2, + enable_streaming=True + ) + ) + + # Process events for each chain separately + async def process_chain_events(chain_id: str, generator): + print(f"\n--- Processing events for {chain_id} ---") + async for event in generator: + print(f"{chain_id}: {event.event_type.value} - {event.data.get('content', '')[:50]}...") + + # Process all chains concurrently + tasks = [ + process_chain_events(chain_id, generator) + for chain_id, generator in chain_generators.items() + ] + + await asyncio.gather(*tasks) + await run_task + + +async def example_5_web_interface_simulation(): + """Example 5: Simulate web interface with separate streams""" + print("\n=== Example 5: Web Interface Simulation ===") + + agent = StreamingAgent( + model_name_or_path="microsoft/DialoGPT-medium", + template="chatml", + tools=[code_interpreter], + backend="transformers", + debug=True + ) + + # Simulate web interface with separate streams per chain + web_streams = {} + + async def create_web_stream(chain_id: str): + """Simulate a web stream for a specific chain""" + print(f"🌐 Web stream created for {chain_id}") + + async def web_stream_handler(event): + # Simulate sending to web client + web_message = { + "chain_id": event.chain_id, + "event_type": event.event_type.value, + "timestamp": event.timestamp, + "data": event.data + } + print(f"📡 Web client {chain_id}: {web_message['event_type']}") + + return web_stream_handler + + # Create web streams for each chain + for i in range(2): + chain_id = f"web_chain_{i}" + handler = await create_web_stream(chain_id) + web_streams[chain_id] = handler + + # Add observers that send to web streams + for chain_id, handler in web_streams.items(): + chain_observer = ChainSpecificStreamObserver(chain_id, handler) + agent.streaming_manager.add_observer(chain_observer) + + start_messages = [ + { + "messages": [{"role": "user", "content": "Write a function to sort a list."}], + "info": {"task": "sorting"}, + "answer": "def sort_list(lst): return sorted(lst)" + } + ] + + await agent.run_async( + max_steps=3, + start_messages=start_messages, + num_chains=2, + enable_streaming=True + ) + + +if __name__ == "__main__": + print("Multi-Chain Streaming Examples") + print("=" * 50) + + # Run all examples + asyncio.run(example_1_colored_output()) + asyncio.run(example_2_chain_specific_observers()) + asyncio.run(example_3_filter_by_chain()) + asyncio.run(example_4_async_generator_per_chain()) + asyncio.run(example_5_web_interface_simulation()) + + print("\n" + "=" * 50) + print("All examples completed!") \ No newline at end of file diff --git a/agents/agents/examples/streaming_example.py b/agents/agents/examples/streaming_example.py new file mode 100644 index 0000000..996a42c --- /dev/null +++ b/agents/agents/examples/streaming_example.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +""" +Example script demonstrating streaming functionality for LLM agent reinforcement learning. + +This example shows how to: +1. Set up streaming observers +2. Run agent chains with streaming +3. Handle real-time events +""" + +import asyncio +import json +from typing import List, Dict, Any +from agents.agents.agents.agent_base import BaseAgent +from agents.agents.agents.chain.streaming_observer import ( + StreamingManager, + ConsoleStreamObserver, + JSONStreamObserver, + AsyncGeneratorStreamObserver, + StreamEvent +) +from agents.agents.tools.src.code.tools import code_interpreter + + +class StreamingAgent(BaseAgent): + """Example agent with streaming support""" + + def parse(self, responses: List[str], tools: List[Any], **args) -> List[Dict[str, Any]]: + """Parse tool calls from responses""" + # Simple parsing - in practice you'd use a more sophisticated parser + parsed_responses = [] + for response in responses: + # Check if response contains tool call + if "```python" in response: + # Extract code between ```python and ``` + start = response.find("```python") + 9 + end = response.find("```", start) + if end != -1: + code = response[start:end].strip() + parsed_responses.append({ + "role": "assistant", + "content": [{"type": "text", "text": response}], + "tool_calls": [{ + "id": "call_1", + "function": { + "name": "code_interpreter", + "arguments": json.dumps({"code": code}) + } + }] + }) + else: + parsed_responses.append({ + "role": "assistant", + "content": [{"type": "text", "text": response}] + }) + else: + parsed_responses.append({ + "role": "assistant", + "content": [{"type": "text", "text": response}] + }) + + return parsed_responses + + +async def main(): + """Main example function""" + + # Initialize the agent + agent = StreamingAgent( + model_name_or_path="microsoft/DialoGPT-medium", # Replace with your model + template="chatml", + tools=[code_interpreter], + backend="transformers", # or "async_vllm" for streaming support + debug=True + ) + + # Set up streaming observers + console_observer = ConsoleStreamObserver(show_timestamps=True) + json_observer = JSONStreamObserver(file_path="streaming_events.jsonl") + + # Add observers to the streaming manager + agent.streaming_manager.add_observer(console_observer) + agent.streaming_manager.add_observer(json_observer) + + # Example messages + start_messages = [ + { + "messages": [ + {"role": "user", "content": "Write a Python function to calculate the factorial of a number and test it with input 5."} + ], + "info": {"task": "factorial_calculation"}, + "answer": "def factorial(n): return 1 if n <= 1 else n * factorial(n-1)\nprint(factorial(5))" + } + ] + + print("Starting streaming agent run...") + print("=" * 50) + + # Run with streaming + await agent.run_async( + max_steps=5, + start_messages=start_messages, + num_chains=1, + generation_config={"temperature": 0.7, "max_new_tokens": 512}, + enable_streaming=True + ) + + print("=" * 50) + print("Streaming run completed!") + + # Print final results + print("\nFinal trajectories:") + for i, trajectory in enumerate(agent.trajectories): + print(f"\nChain {i}:") + for msg in trajectory["messages"]: + if msg["role"] == "assistant": + print(f"Assistant: {msg['content'][0]['text'][:100]}...") + elif msg["role"] == "tool": + print(f"Tool: {msg['content'][0]['text'][:100]}...") + + +async def streaming_with_custom_callback(): + """Example with custom streaming callback""" + + agent = StreamingAgent( + model_name_or_path="microsoft/DialoGPT-medium", + template="chatml", + tools=[code_interpreter], + backend="transformers", + debug=True + ) + + # Custom streaming callback + async def custom_streaming_callback(chunk: str): + """Custom callback to handle streaming chunks""" + print(f"🔄 Streaming chunk: {chunk}", end="", flush=True) + + # Add console observer + console_observer = ConsoleStreamObserver(show_timestamps=True) + agent.streaming_manager.add_observer(console_observer) + + start_messages = [ + { + "messages": [ + {"role": "user", "content": "Write a simple Python function to add two numbers."} + ], + "info": {"task": "simple_addition"}, + "answer": "def add(a, b): return a + b" + } + ] + + print("Starting streaming with custom callback...") + print("=" * 50) + + await agent.run_async( + max_steps=3, + start_messages=start_messages, + num_chains=1, + generation_config={"temperature": 0.7, "max_new_tokens": 256}, + enable_streaming=True, + streaming_callback=custom_streaming_callback + ) + + +async def async_generator_example(): + """Example using AsyncGeneratorStreamObserver""" + + agent = StreamingAgent( + model_name_or_path="microsoft/DialoGPT-medium", + template="chatml", + tools=[code_interpreter], + backend="transformers", + debug=True + ) + + # Create async generator observer + async_observer = AsyncGeneratorStreamObserver() + agent.streaming_manager.add_observer(async_observer) + + start_messages = [ + { + "messages": [ + {"role": "user", "content": "Write a Python function to check if a number is prime."} + ], + "info": {"task": "prime_check"}, + "answer": "def is_prime(n): return n > 1 and all(n % i != 0 for i in range(2, int(n**0.5) + 1))" + } + ] + + print("Starting async generator example...") + print("=" * 50) + + # Start the agent run in background + run_task = asyncio.create_task( + agent.run_async( + max_steps=3, + start_messages=start_messages, + num_chains=1, + generation_config={"temperature": 0.7, "max_new_tokens": 256}, + enable_streaming=True + ) + ) + + # Process events as they arrive + async for event in async_observer.events(): + print(f"📡 Event: {event.event_type.value} - Chain: {event.chain_id}") + if event.event_type.value == "llm_generation_chunk": + print(f" Content: {event.data.get('content', '')}") + elif event.event_type.value == "tool_observation": + print(f" Tool: {event.data.get('tool_name', '')} - {event.data.get('observation', '')[:50]}...") + + # Wait for the run to complete + await run_task + + +if __name__ == "__main__": + print("Streaming Agent Example") + print("=" * 50) + + # Run different examples + asyncio.run(main()) + print("\n" + "=" * 50) + + asyncio.run(streaming_with_custom_callback()) + print("\n" + "=" * 50) + + asyncio.run(async_generator_example()) \ No newline at end of file From 14c4fd94b844aea76e6cfdd04038bb56216eb3a7 Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Tue, 22 Jul 2025 14:36:30 +0000 Subject: [PATCH 2/8] Test console streaming --- agents/agents/agents/chain/chain_base.py | 119 ++++--- .../agents/agents/chain/streaming_observer.py | 75 ++--- agents/agents/agents/llm_backend.py | 33 +- .../examples/multi_chain_streaming_example.py | 307 ------------------ agents/agents/examples/streaming_example.py | 248 +++----------- 5 files changed, 179 insertions(+), 603 deletions(-) delete mode 100644 agents/agents/examples/multi_chain_streaming_example.py diff --git a/agents/agents/agents/chain/chain_base.py b/agents/agents/agents/chain/chain_base.py index f363a3c..91faaf0 100644 --- a/agents/agents/agents/chain/chain_base.py +++ b/agents/agents/agents/chain/chain_base.py @@ -14,7 +14,7 @@ from ...utils.monitor import JsonlSink, MetricEvent, Monitor, WandbSink, emit, serialize_for_json from ... import AGENT_DATA_DIR import wandb -from .streaming_observer import StreamingManager, StreamEvent, StreamEventType +from .streaming_observer import ConsoleStreamObserver, StreamingManager, StreamEvent, StreamEventType @dataclass class Node: @@ -214,6 +214,14 @@ def prepare_chain_messages(self, start_messages: Union[List[dict], np.ndarray]): other_info_list.append(info) return messages_list, other_info_list + + def validate_run_args(self, max_steps: int, num_chains: int): + assert max_steps >= 1, "max_steps must be at least 1." + assert num_chains >= 1, "num_chains must be at least 1." + for observer in self.streaming_manager.observers: + if isinstance(observer, ConsoleStreamObserver): + assert num_chains == 1, "num_chains must be 1 when ConsoleStreamObserver is used." + async def run_async(self, max_steps: int, @@ -234,7 +242,7 @@ async def run_async(self, enable_streaming: Whether to enable streaming mode. streaming_callback: Optional callback for streaming events. """ - assert max_steps >= 1, "max_steps must be at least 1." + self.validate_run_args(max_steps, num_chains) Monitor.ensure_started() self.reset() messages_list, other_info_list = self.prepare_chain_messages(start_messages) @@ -260,15 +268,15 @@ async def run_async(self, done_q = asyncio.Queue() tasks = [ asyncio.create_task( - self._run_single_chain( - cid, - node, - chains[cid], - tool_schemas, - max_steps=max_steps, - done_queue=done_q, - enable_streaming=enable_streaming, - streaming_callback=streaming_callback) + self._run_single_chain( + cid, + node, + chains[cid], + tool_schemas, + max_steps=max_steps, + done_queue=done_q, + enable_streaming=enable_streaming + ) ) for cid, node in first_nodes.items() ] @@ -291,8 +299,7 @@ async def _run_single_chain(self, tools: List[Dict], max_steps: int, done_queue: asyncio.Queue, - enable_streaming: bool = False, - streaming_callback: Optional[Callable] = None, + enable_streaming: bool = False ): """ Run a single chain with optional streaming support. @@ -302,12 +309,12 @@ async def _run_single_chain(self, have_set_tools = False while not current_node.is_terminal and depth < max_steps: - newest_messages = current_node.messages + newest_messages = deepcopy(current_node.messages) if not current_node.is_terminal: # Generate response new_msg = await self._generate_response( - current_node, tools, depth, chain_id, enable_streaming, streaming_callback + current_node, tools, depth, chain_id, enable_streaming ) newest_messages.append(new_msg) @@ -340,7 +347,7 @@ async def _run_single_chain(self, self.finished_chains_count += 1 self.monitor_chain() - async def _generate_response(self, current_node, tools, depth, chain_id, enable_streaming, streaming_callback): + async def _generate_response(self, current_node, tools, depth, chain_id, enable_streaming): """Generate response with optional streaming support.""" if enable_streaming: # Emit generation start event @@ -353,9 +360,22 @@ async def _generate_response(self, current_node, tools, depth, chain_id, enable_ depth=depth )) - # Use streaming generation if available - if hasattr(self, 'generate_streaming') and streaming_callback: - async def stream_callback(chunk: str): + # Check if we have streaming capabilities + has_streaming = False + if hasattr(self, 'generate_streaming'): + has_streaming = True + elif hasattr(self, 'llm_engine') and hasattr(self.llm_engine, 'generate_streaming'): + has_streaming = True + # Create a wrapper to use the LLM engine's streaming + async def generate_streaming_wrapper(messages_list, **kwargs): + async for chunk in self.llm_engine.generate_streaming(messages_list, **kwargs): + yield chunk + self.generate_streaming = generate_streaming_wrapper + + if has_streaming: + # Collect full response from streaming + full_response = "" + async for chunk in self.generate_streaming([current_node.messages], tools=tools): await self.streaming_manager.emit_event(StreamEvent( event_type=StreamEventType.LLM_GENERATION_CHUNK, chain_id=chain_id, @@ -364,13 +384,8 @@ async def stream_callback(chunk: str): step=depth, depth=depth )) - if streaming_callback: - await streaming_callback(chunk) - - # Collect full response from streaming - full_response = "" - async for chunk in self.generate_streaming([current_node.messages], tools=tools, streaming_callback=stream_callback): - full_response += chunk + # chunk is the whole generated text + full_response = chunk # Emit generation end event await self.streaming_manager.emit_event(StreamEvent( @@ -389,6 +404,37 @@ async def stream_callback(chunk: str): # Fallback to non-streaming generation responses = await self.generate_async([current_node.messages], tools=tools, num_return_sequences=1) new_msg = self.parse(responses, self.tools) + + # Emit a single chunk event for the full response + full_response = new_msg[0].get("content", "") + if isinstance(full_response, list) and len(full_response) > 0: + # Handle case where content is a list of content blocks + if isinstance(full_response[0], dict) and "text" in full_response[0]: + full_response = full_response[0]["text"] + else: + full_response = str(full_response) + elif not isinstance(full_response, str): + full_response = str(full_response) + + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_CHUNK, + chain_id=chain_id, + timestamp=time.time(), + data={"content": full_response}, + step=depth, + depth=depth + )) + + # Emit generation end event + await self.streaming_manager.emit_event(StreamEvent( + event_type=StreamEventType.LLM_GENERATION_END, + chain_id=chain_id, + timestamp=time.time(), + data={"full_response": full_response}, + step=depth, + depth=depth + )) + return new_msg[0] else: # Non-streaming generation @@ -401,17 +447,6 @@ async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id, tool_name = tool_call["function"]["name"] tool_input = tool_call["function"]["arguments"] - if enable_streaming: - # Emit tool call start event - await self.streaming_manager.emit_event(StreamEvent( - event_type=StreamEventType.TOOL_CALL_START, - chain_id=chain_id, - timestamp=time.time(), - data={"tool_name": tool_name, "arguments": tool_input}, - step=depth, - depth=depth - )) - # Create action node action_node = chain.add_node( type="Action", @@ -424,6 +459,7 @@ async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id, await self.set_tools(chain_id, chain.info) have_set_tools = True + # Execute tool call result = await submit_tool_call(tool_name, tool_input, id=chain_id) @@ -442,15 +478,6 @@ async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id, depth=depth )) - # Emit tool call end event - await self.streaming_manager.emit_event(StreamEvent( - event_type=StreamEventType.TOOL_CALL_END, - chain_id=chain_id, - timestamp=time.time(), - data={"tool_name": tool_name, "result": result}, - step=depth, - depth=depth - )) # Create action input node action_input_node = chain.add_node( diff --git a/agents/agents/agents/chain/streaming_observer.py b/agents/agents/agents/chain/streaming_observer.py index d4de1a3..961ca2a 100644 --- a/agents/agents/agents/chain/streaming_observer.py +++ b/agents/agents/agents/chain/streaming_observer.py @@ -1,6 +1,7 @@ import asyncio from abc import ABC, abstractmethod from dataclasses import dataclass +import os from typing import Any, Dict, List, Optional, Callable, AsyncGenerator, Set from enum import Enum import json @@ -119,6 +120,9 @@ class ConsoleStreamObserver(StreamObserver): def __init__(self, show_timestamps: bool = True, chain_filter: Optional[str] = None): self.show_timestamps = show_timestamps self.chain_filter = chain_filter # Only show events for this chain_id + self.chain_colors = ["red", "green", "blue", "yellow", "magenta", "cyan"] + self.chain_id_data = {} + async def on_event(self, event: StreamEvent) -> None: # Filter by chain if specified @@ -126,55 +130,48 @@ async def on_event(self, event: StreamEvent) -> None: return timestamp = f"[{event.timestamp:.2f}s] " if self.show_timestamps else "" - step_info = f" (step {event.step})" if event.step is not None else "" - depth_info = f" (depth {event.depth})" if event.depth is not None else "" + turn_info = f" (turn {event.step})" if event.step is not None else "" # Use different colors for different chains - chain_colors = ["red", "green", "blue", "yellow", "magenta", "cyan"] - chain_index = hash(event.chain_id) % len(chain_colors) - chain_color = chain_colors[chain_index] - - print(colored(f"{timestamp}Chain {event.chain_id}{step_info}{depth_info}: {event.event_type.value}", color=chain_color)) - - if event.event_type == StreamEventType.LLM_GENERATION_CHUNK: + chain_index = hash(event.chain_id) % len(self.chain_colors) + chain_color = self.chain_colors[chain_index] + if event.chain_id not in self.chain_id_data: + self.chain_id_data[event.chain_id] = { + "color": chain_color, + "timestamp": event.timestamp, + "step": event.step, + "depth": event.depth, + "event_type": event.event_type.value, + "content_buffer": "" + } + else: + self.chain_id_data[event.chain_id]["timestamp"] = event.timestamp + + if event.event_type == StreamEventType.LLM_GENERATION_START: + print(colored(f"{timestamp} {turn_info}", color=chain_color), end="", flush=True) + elif event.event_type == StreamEventType.LLM_GENERATION_CHUNK: content = event.data.get("content", "") if content: - print(colored(f" → {content}", color=chain_color)) + # clear the terminal + if self.chain_id_data[event.chain_id]["event_type"] == StreamEventType.LLM_GENERATION_CHUNK: + print(colored(f"""{content[len(self.chain_id_data[event.chain_id]["content_buffer"]):]}""", color=chain_color), end="", flush=True) + self.chain_id_data[event.chain_id]["content_buffer"] = content + else: + self.chain_id_data[event.chain_id]["content_buffer"] = content + print(colored(f"{content}", color=chain_color), end="", flush=True) + self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.LLM_GENERATION_CHUNK + elif event.event_type == StreamEventType.LLM_GENERATION_END: + print(colored(f"\n{event.data.get('timestamp', '')}", color=chain_color), end="", flush=True) + self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.LLM_GENERATION_END elif event.event_type == StreamEventType.TOOL_OBSERVATION: observation = event.data.get("observation", "") tool_name = event.data.get("tool_name", "") - print(colored(f" 🔧 {tool_name}: {observation[:200]}{'...' if len(observation) > 200 else ''}", color=chain_color)) + print(colored(f"Tool: [{tool_name}] {observation[:200]}{'...' if len(observation) > 200 else ''}", color=chain_color)) + self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.TOOL_OBSERVATION elif event.event_type == StreamEventType.ERROR: error_msg = event.data.get("error", "") print(colored(f" ❌ Error: {error_msg}", color=chain_color)) - - -class JSONStreamObserver(StreamObserver): - """JSON-based stream observer for structured logging""" - - def __init__(self, file_path: Optional[str] = None, chain_filter: Optional[str] = None): - self.file_path = file_path - self.chain_filter = chain_filter - - async def on_event(self, event: StreamEvent) -> None: - # Filter by chain if specified - if self.chain_filter and event.chain_id != self.chain_filter: - return - - event_dict = { - "event_type": event.event_type.value, - "chain_id": event.chain_id, - "timestamp": event.timestamp, - "step": event.step, - "depth": event.depth, - "data": event.data - } - - if self.file_path: - with open(self.file_path, "a") as f: - f.write(json.dumps(event_dict) + "\n") - else: - print(json.dumps(event_dict)) + self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.ERROR class AsyncGeneratorStreamObserver(StreamObserver): diff --git a/agents/agents/agents/llm_backend.py b/agents/agents/agents/llm_backend.py index 3e481a8..2401137 100644 --- a/agents/agents/agents/llm_backend.py +++ b/agents/agents/agents/llm_backend.py @@ -202,6 +202,35 @@ def generate(self, messages_list: str, **kwargs) -> str: def generate_async(self, messages_list: str, **kwargs) -> str: raise NotImplementedError("VLLM backend does not support async generation") + async def generate_streaming(self, messages_list: List[List[Dict]], streaming_callback: Optional[Callable] = None, **kwargs) -> AsyncGenerator[str, None]: + """Generate text with streaming support using vLLM""" + max_new_tokens = kwargs.get("max_new_tokens", self.max_new_tokens) + temperature = kwargs.get("temperature", self.temperature) + sampling_params = SamplingParams( + n=1, + max_tokens=max_new_tokens, + temperature=temperature, + ) + + tools = kwargs.get("tools", None) + prompts, vision_inputs = self.apply_chat_template(messages_list, self.template, tools=tools) + inputs = self._process_inputs(prompts, vision_inputs) + + # For streaming, we process one input at a time + for input_data in inputs: + outputs_gen = self.llm_engine.generate( + input_data, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + ) + + async for output in outputs_gen: + for sequence in output.outputs: + # Stream each token + if hasattr(sequence, 'text'): + if streaming_callback: + await streaming_callback(sequence.text) + yield sequence.text class AsyncVLLMBackend(LLMBackend): """Async vLLM implementation""" @@ -268,7 +297,7 @@ async def generate_async(self, messages_list: str, **kwargs) -> str: return response_texts - async def generate_streaming(self, messages_list: List[List[Dict]], streaming_callback: Optional[Callable] = None, **kwargs) -> AsyncGenerator[str, None]: + async def generate_streaming(self, messages_list: List[List[Dict]], **kwargs) -> AsyncGenerator[str, None]: """Generate text with streaming support using Async vLLM""" max_new_tokens = kwargs.get("max_new_tokens", self.max_new_tokens) temperature = kwargs.get("temperature", self.temperature) @@ -294,8 +323,6 @@ async def generate_streaming(self, messages_list: List[List[Dict]], streaming_ca for sequence in output.outputs: # Stream each token if hasattr(sequence, 'text'): - if streaming_callback: - await streaming_callback(sequence.text) yield sequence.text class VerlBackend(LLMBackend): diff --git a/agents/agents/examples/multi_chain_streaming_example.py b/agents/agents/examples/multi_chain_streaming_example.py deleted file mode 100644 index 7cdba3a..0000000 --- a/agents/agents/examples/multi_chain_streaming_example.py +++ /dev/null @@ -1,307 +0,0 @@ -#!/usr/bin/env python3 -""" -Example demonstrating how to handle multiple chains with streaming -without mixing outputs from different chains. -""" - -import asyncio -import json -from typing import List, Dict, Any -from agents.agents.agents.agent_base import BaseAgent -from agents.agents.agents.chain.streaming_observer import ( - ConsoleStreamObserver, - JSONStreamObserver, - ChainSpecificStreamObserver, - MultiChainStreamObserver, - AsyncGeneratorStreamObserver -) -from agents.agents.tools.src.code.tools import code_interpreter - - -class StreamingAgent(BaseAgent): - """Example agent with streaming support""" - - def parse(self, responses: List[str], tools: List[Any], **args) -> List[Dict[str, Any]]: - """Parse tool calls from responses""" - parsed_responses = [] - for response in responses: - if "```python" in response: - start = response.find("```python") + 9 - end = response.find("```", start) - if end != -1: - code = response[start:end].strip() - parsed_responses.append({ - "role": "assistant", - "content": [{"type": "text", "text": response}], - "tool_calls": [{ - "id": "call_1", - "function": { - "name": "code_interpreter", - "arguments": json.dumps({"code": code}) - } - }] - }) - else: - parsed_responses.append({ - "role": "assistant", - "content": [{"type": "text", "text": response}] - }) - else: - parsed_responses.append({ - "role": "assistant", - "content": [{"type": "text", "text": response}] - }) - - return parsed_responses - - -async def example_1_colored_output(): - """Example 1: Use colored output to distinguish chains""" - print("=== Example 1: Colored Output for Multiple Chains ===") - - agent = StreamingAgent( - model_name_or_path="microsoft/DialoGPT-medium", - template="chatml", - tools=[code_interpreter], - backend="transformers", - debug=True - ) - - # Add colored console observer - console_observer = ConsoleStreamObserver(show_timestamps=True) - agent.streaming_manager.add_observer(console_observer) - - start_messages = [ - { - "messages": [{"role": "user", "content": "Write a function to calculate factorial of 5."}], - "info": {"task": "factorial"}, - "answer": "def factorial(n): return 1 if n <= 1 else n * factorial(n-1)\nprint(factorial(5))" - }, - { - "messages": [{"role": "user", "content": "Write a function to calculate fibonacci of 10."}], - "info": {"task": "fibonacci"}, - "answer": "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)\nprint(fibonacci(10))" - } - ] - - await agent.run_async( - max_steps=3, - start_messages=start_messages, - num_chains=2, # Run 2 chains per message = 4 total chains - enable_streaming=True - ) - - -async def example_2_chain_specific_observers(): - """Example 2: Use chain-specific observers""" - print("\n=== Example 2: Chain-Specific Observers ===") - - agent = StreamingAgent( - model_name_or_path="microsoft/DialoGPT-medium", - template="chatml", - tools=[code_interpreter], - backend="transformers", - debug=True - ) - - # Create a multi-chain observer - multi_observer = MultiChainStreamObserver() - - # Add chain-specific observers - for i in range(4): # We'll have 4 chains - chain_id = f"chain_{i}" - console_observer = ConsoleStreamObserver(show_timestamps=True) - json_observer = JSONStreamObserver(file_path=f"chain_{i}_events.jsonl") - - # Create chain-specific observers - chain_console = ChainSpecificStreamObserver(chain_id, console_observer) - chain_json = ChainSpecificStreamObserver(chain_id, json_observer) - - multi_observer.add_chain_observer(chain_id, chain_console) - multi_observer.add_chain_observer(chain_id, chain_json) - - agent.streaming_manager.add_observer(multi_observer) - - start_messages = [ - { - "messages": [{"role": "user", "content": "Write a function to add two numbers."}], - "info": {"task": "addition"}, - "answer": "def add(a, b): return a + b" - }, - { - "messages": [{"role": "user", "content": "Write a function to multiply two numbers."}], - "info": {"task": "multiplication"}, - "answer": "def multiply(a, b): return a * b" - } - ] - - await agent.run_async( - max_steps=3, - start_messages=start_messages, - num_chains=2, - enable_streaming=True - ) - - -async def example_3_filter_by_chain(): - """Example 3: Filter events by specific chain""" - print("\n=== Example 3: Filter by Specific Chain ===") - - agent = StreamingAgent( - model_name_or_path="microsoft/DialoGPT-medium", - template="chatml", - tools=[code_interpreter], - backend="transformers", - debug=True - ) - - # Only observe events from a specific chain - target_chain_id = "chain_0" # We'll focus on the first chain - filtered_observer = ConsoleStreamObserver( - show_timestamps=True, - chain_filter=target_chain_id - ) - - agent.streaming_manager.add_observer(filtered_observer) - - start_messages = [ - { - "messages": [{"role": "user", "content": "Write a function to check if a number is prime."}], - "info": {"task": "prime_check"}, - "answer": "def is_prime(n): return n > 1 and all(n % i != 0 for i in range(2, int(n**0.5) + 1))" - } - ] - - await agent.run_async( - max_steps=3, - start_messages=start_messages, - num_chains=3, # Run 3 chains but only observe one - enable_streaming=True - ) - - -async def example_4_async_generator_per_chain(): - """Example 4: Use async generators for each chain""" - print("\n=== Example 4: Async Generator per Chain ===") - - agent = StreamingAgent( - model_name_or_path="microsoft/DialoGPT-medium", - template="chatml", - tools=[code_interpreter], - backend="transformers", - debug=True - ) - - # Create async generators for each chain - chain_generators = {} - for i in range(2): - chain_id = f"chain_{i}" - async_observer = AsyncGeneratorStreamObserver(chain_filter=chain_id) - agent.streaming_manager.add_observer(async_observer) - chain_generators[chain_id] = async_observer.events() - - start_messages = [ - { - "messages": [{"role": "user", "content": "Write a function to reverse a string."}], - "info": {"task": "string_reverse"}, - "answer": "def reverse_string(s): return s[::-1]" - } - ] - - # Start the agent run - run_task = asyncio.create_task( - agent.run_async( - max_steps=3, - start_messages=start_messages, - num_chains=2, - enable_streaming=True - ) - ) - - # Process events for each chain separately - async def process_chain_events(chain_id: str, generator): - print(f"\n--- Processing events for {chain_id} ---") - async for event in generator: - print(f"{chain_id}: {event.event_type.value} - {event.data.get('content', '')[:50]}...") - - # Process all chains concurrently - tasks = [ - process_chain_events(chain_id, generator) - for chain_id, generator in chain_generators.items() - ] - - await asyncio.gather(*tasks) - await run_task - - -async def example_5_web_interface_simulation(): - """Example 5: Simulate web interface with separate streams""" - print("\n=== Example 5: Web Interface Simulation ===") - - agent = StreamingAgent( - model_name_or_path="microsoft/DialoGPT-medium", - template="chatml", - tools=[code_interpreter], - backend="transformers", - debug=True - ) - - # Simulate web interface with separate streams per chain - web_streams = {} - - async def create_web_stream(chain_id: str): - """Simulate a web stream for a specific chain""" - print(f"🌐 Web stream created for {chain_id}") - - async def web_stream_handler(event): - # Simulate sending to web client - web_message = { - "chain_id": event.chain_id, - "event_type": event.event_type.value, - "timestamp": event.timestamp, - "data": event.data - } - print(f"📡 Web client {chain_id}: {web_message['event_type']}") - - return web_stream_handler - - # Create web streams for each chain - for i in range(2): - chain_id = f"web_chain_{i}" - handler = await create_web_stream(chain_id) - web_streams[chain_id] = handler - - # Add observers that send to web streams - for chain_id, handler in web_streams.items(): - chain_observer = ChainSpecificStreamObserver(chain_id, handler) - agent.streaming_manager.add_observer(chain_observer) - - start_messages = [ - { - "messages": [{"role": "user", "content": "Write a function to sort a list."}], - "info": {"task": "sorting"}, - "answer": "def sort_list(lst): return sorted(lst)" - } - ] - - await agent.run_async( - max_steps=3, - start_messages=start_messages, - num_chains=2, - enable_streaming=True - ) - - -if __name__ == "__main__": - print("Multi-Chain Streaming Examples") - print("=" * 50) - - # Run all examples - asyncio.run(example_1_colored_output()) - asyncio.run(example_2_chain_specific_observers()) - asyncio.run(example_3_filter_by_chain()) - asyncio.run(example_4_async_generator_per_chain()) - asyncio.run(example_5_web_interface_simulation()) - - print("\n" + "=" * 50) - print("All examples completed!") \ No newline at end of file diff --git a/agents/agents/examples/streaming_example.py b/agents/agents/examples/streaming_example.py index 996a42c..c4fd08b 100644 --- a/agents/agents/examples/streaming_example.py +++ b/agents/agents/examples/streaming_example.py @@ -1,227 +1,59 @@ -#!/usr/bin/env python3 -""" -Example script demonstrating streaming functionality for LLM agent reinforcement learning. - -This example shows how to: -1. Set up streaming observers -2. Run agent chains with streaming -3. Handle real-time events -""" - import asyncio +from agents.agents.react.react_agent import ReactAgent +from agents.tools import code_interpreter +from agents.rewards import math_reward_tool import json -from typing import List, Dict, Any -from agents.agents.agents.agent_base import BaseAgent -from agents.agents.agents.chain.streaming_observer import ( - StreamingManager, - ConsoleStreamObserver, - JSONStreamObserver, - AsyncGeneratorStreamObserver, - StreamEvent -) -from agents.agents.tools.src.code.tools import code_interpreter - - -class StreamingAgent(BaseAgent): - """Example agent with streaming support""" - - def parse(self, responses: List[str], tools: List[Any], **args) -> List[Dict[str, Any]]: - """Parse tool calls from responses""" - # Simple parsing - in practice you'd use a more sophisticated parser - parsed_responses = [] - for response in responses: - # Check if response contains tool call - if "```python" in response: - # Extract code between ```python and ``` - start = response.find("```python") + 9 - end = response.find("```", start) - if end != -1: - code = response[start:end].strip() - parsed_responses.append({ - "role": "assistant", - "content": [{"type": "text", "text": response}], - "tool_calls": [{ - "id": "call_1", - "function": { - "name": "code_interpreter", - "arguments": json.dumps({"code": code}) - } - }] - }) - else: - parsed_responses.append({ - "role": "assistant", - "content": [{"type": "text", "text": response}] - }) - else: - parsed_responses.append({ - "role": "assistant", - "content": [{"type": "text", "text": response}] - }) - - return parsed_responses +from agents.agents.chain.streaming_observer import ConsoleStreamObserver async def main(): - """Main example function""" - - # Initialize the agent - agent = StreamingAgent( - model_name_or_path="microsoft/DialoGPT-medium", # Replace with your model - template="chatml", - tools=[code_interpreter], - backend="transformers", # or "async_vllm" for streaming support + tools = [code_interpreter] + + agent = ReactAgent( + "Qwen/Qwen2.5-7B-Instruct", + tools=tools, + template="qwen2.5-no-tool", + backend="async_vllm", + reward_fn=math_reward_tool, debug=True ) - - # Set up streaming observers - console_observer = ConsoleStreamObserver(show_timestamps=True) - json_observer = JSONStreamObserver(file_path="streaming_events.jsonl") - - # Add observers to the streaming manager - agent.streaming_manager.add_observer(console_observer) - agent.streaming_manager.add_observer(json_observer) - - # Example messages - start_messages = [ - { - "messages": [ - {"role": "user", "content": "Write a Python function to calculate the factorial of a number and test it with input 5."} - ], - "info": {"task": "factorial_calculation"}, - "answer": "def factorial(n): return 1 if n <= 1 else n * factorial(n-1)\nprint(factorial(5))" - } - ] - - print("Starting streaming agent run...") - print("=" * 50) - - # Run with streaming - await agent.run_async( - max_steps=5, - start_messages=start_messages, - num_chains=1, - generation_config={"temperature": 0.7, "max_new_tokens": 512}, - enable_streaming=True - ) - - print("=" * 50) - print("Streaming run completed!") - - # Print final results - print("\nFinal trajectories:") - for i, trajectory in enumerate(agent.trajectories): - print(f"\nChain {i}:") - for msg in trajectory["messages"]: - if msg["role"] == "assistant": - print(f"Assistant: {msg['content'][0]['text'][:100]}...") - elif msg["role"] == "tool": - print(f"Tool: {msg['content'][0]['text'][:100]}...") -async def streaming_with_custom_callback(): - """Example with custom streaming callback""" - - agent = StreamingAgent( - model_name_or_path="microsoft/DialoGPT-medium", - template="chatml", - tools=[code_interpreter], - backend="transformers", - debug=True - ) - - # Custom streaming callback - async def custom_streaming_callback(chunk: str): - """Custom callback to handle streaming chunks""" - print(f"🔄 Streaming chunk: {chunk}", end="", flush=True) - - # Add console observer - console_observer = ConsoleStreamObserver(show_timestamps=True) - agent.streaming_manager.add_observer(console_observer) - - start_messages = [ - { - "messages": [ - {"role": "user", "content": "Write a simple Python function to add two numbers."} - ], - "info": {"task": "simple_addition"}, - "answer": "def add(a, b): return a + b" - } - ] - - print("Starting streaming with custom callback...") - print("=" * 50) - - await agent.run_async( - max_steps=3, - start_messages=start_messages, - num_chains=1, - generation_config={"temperature": 0.7, "max_new_tokens": 256}, - enable_streaming=True, - streaming_callback=custom_streaming_callback - ) + console_stream_observer = ConsoleStreamObserver() + agent.streaming_manager.add_observer(console_stream_observer) -async def async_generator_example(): - """Example using AsyncGeneratorStreamObserver""" - - agent = StreamingAgent( - model_name_or_path="microsoft/DialoGPT-medium", - template="chatml", - tools=[code_interpreter], - backend="transformers", - debug=True - ) - - # Create async generator observer - async_observer = AsyncGeneratorStreamObserver() - agent.streaming_manager.add_observer(async_observer) - - start_messages = [ + question1 = "Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop." + answer1 = "204" + question2 = "$P(x)$ is a polynomial of degree $3n$ such that\n\\begin{eqnarray*} P(0) = P(3) = \\cdots &=& P(3n) = 2, \\\\ P(1) = P(4) = \\cdots &=& P(3n-2) = 1, \\\\ P(2) = P(5) = \\cdots &=& P(3n-1) = 0, \\quad\\text{ and }\\\\ && P(3n+1) = 730.\\end{eqnarray*}\nDetermine $n$." + answer2 = "3" + + + messages = [ { "messages": [ - {"role": "user", "content": "Write a Python function to check if a number is prime."} + {"role": "user", "content": f"{question1}"} ], - "info": {"task": "prime_check"}, - "answer": "def is_prime(n): return n > 1 and all(n % i != 0 for i in range(2, int(n**0.5) + 1))" - } + "question": f"{question1}", + "answer": f"{answer1}" + }, + # { + # "messages": [ + # {"role": "user", "content": f"{question2}"} + # ], + # "question": f"{question2}", + # "answer": f"{answer2}" + # } ] - - print("Starting async generator example...") - print("=" * 50) - - # Start the agent run in background - run_task = asyncio.create_task( - agent.run_async( - max_steps=3, - start_messages=start_messages, - num_chains=1, - generation_config={"temperature": 0.7, "max_new_tokens": 256}, - enable_streaming=True - ) - ) - - # Process events as they arrive - async for event in async_observer.events(): - print(f"📡 Event: {event.event_type.value} - Chain: {event.chain_id}") - if event.event_type.value == "llm_generation_chunk": - print(f" Content: {event.data.get('content', '')}") - elif event.event_type.value == "tool_observation": - print(f" Tool: {event.data.get('tool_name', '')} - {event.data.get('observation', '')[:50]}...") - - # Wait for the run to complete - await run_task + await agent.run_async( + max_steps=4, + start_messages=messages, + num_chains=1, + enable_streaming=True + ) + if __name__ == "__main__": - print("Streaming Agent Example") - print("=" * 50) - - # Run different examples asyncio.run(main()) - print("\n" + "=" * 50) - - asyncio.run(streaming_with_custom_callback()) - print("\n" + "=" * 50) - - asyncio.run(async_generator_example()) \ No newline at end of file + From 849b7ca9e50b7a2b7d260b7a4b71efa611a983ad Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Wed, 23 Jul 2025 03:43:37 +0000 Subject: [PATCH 3/8] Adjust concolse print --- agents/agents/agents/chain/streaming_observer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/agents/agents/agents/chain/streaming_observer.py b/agents/agents/agents/chain/streaming_observer.py index 961ca2a..7ac9447 100644 --- a/agents/agents/agents/chain/streaming_observer.py +++ b/agents/agents/agents/chain/streaming_observer.py @@ -128,8 +128,7 @@ async def on_event(self, event: StreamEvent) -> None: # Filter by chain if specified if self.chain_filter and event.chain_id != self.chain_filter: return - - timestamp = f"[{event.timestamp:.2f}s] " if self.show_timestamps else "" + turn_info = f" (turn {event.step})" if event.step is not None else "" # Use different colors for different chains @@ -144,11 +143,9 @@ async def on_event(self, event: StreamEvent) -> None: "event_type": event.event_type.value, "content_buffer": "" } - else: - self.chain_id_data[event.chain_id]["timestamp"] = event.timestamp if event.event_type == StreamEventType.LLM_GENERATION_START: - print(colored(f"{timestamp} {turn_info}", color=chain_color), end="", flush=True) + print(colored(f"{event.timestamp - self.chain_id_data[event.chain_id]['timestamp']:.2f}s {turn_info} =====================", color=chain_color), flush=True) elif event.event_type == StreamEventType.LLM_GENERATION_CHUNK: content = event.data.get("content", "") if content: @@ -161,11 +158,12 @@ async def on_event(self, event: StreamEvent) -> None: print(colored(f"{content}", color=chain_color), end="", flush=True) self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.LLM_GENERATION_CHUNK elif event.event_type == StreamEventType.LLM_GENERATION_END: - print(colored(f"\n{event.data.get('timestamp', '')}", color=chain_color), end="", flush=True) + print(colored(f"\n{event.timestamp - self.chain_id_data[event.chain_id]['timestamp']:.2f}s", color=chain_color), flush=True) self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.LLM_GENERATION_END elif event.event_type == StreamEventType.TOOL_OBSERVATION: observation = event.data.get("observation", "") tool_name = event.data.get("tool_name", "") + print(colored(f"{event.timestamp - self.chain_id_data[event.chain_id]['timestamp']:.2f}s {turn_info} =====================", color=chain_color), flush=True) print(colored(f"Tool: [{tool_name}] {observation[:200]}{'...' if len(observation) > 200 else ''}", color=chain_color)) self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.TOOL_OBSERVATION elif event.event_type == StreamEventType.ERROR: From 07e94b4c58adf836e2573ae8916384226e6b8226 Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Wed, 23 Jul 2025 10:12:15 +0000 Subject: [PATCH 4/8] update streaming --- .../agents/agents/chain/streaming_observer.py | 2 +- agents/agents/agents/react/react_agent.py | 78 ++++++++++++++----- agents/agents/rewards/math_reward.py | 6 +- agents/agents/tools/src/code/tools.py | 2 +- agents/requirements.txt | 2 +- 5 files changed, 65 insertions(+), 25 deletions(-) diff --git a/agents/agents/agents/chain/streaming_observer.py b/agents/agents/agents/chain/streaming_observer.py index 7ac9447..f358224 100644 --- a/agents/agents/agents/chain/streaming_observer.py +++ b/agents/agents/agents/chain/streaming_observer.py @@ -164,7 +164,7 @@ async def on_event(self, event: StreamEvent) -> None: observation = event.data.get("observation", "") tool_name = event.data.get("tool_name", "") print(colored(f"{event.timestamp - self.chain_id_data[event.chain_id]['timestamp']:.2f}s {turn_info} =====================", color=chain_color), flush=True) - print(colored(f"Tool: [{tool_name}] {observation[:200]}{'...' if len(observation) > 200 else ''}", color=chain_color)) + print(colored(f"Tool: [{tool_name}] {observation[:1024]}{'...' if len(observation) > 200 else ''}", color=chain_color)) self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.TOOL_OBSERVATION elif event.event_type == StreamEventType.ERROR: error_msg = event.data.get("error", "") diff --git a/agents/agents/agents/react/react_agent.py b/agents/agents/agents/react/react_agent.py index 8f32c37..765528f 100644 --- a/agents/agents/agents/react/react_agent.py +++ b/agents/agents/agents/react/react_agent.py @@ -1,7 +1,7 @@ import json -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Dict, List, Optional from ..utils.json import jsonish from ...tools.tool_base import Tool try: @@ -13,12 +13,6 @@ import numpy as np import re -import re -from typing import List, Dict - -import re -from typing import Dict, Optional - def parse_react_step(text: str) -> Dict[str, Optional[str]]: """ Parse a single ReAct-style step (one Thought→Action→Input) into its components. @@ -45,7 +39,32 @@ def parse_react_step(text: str) -> Dict[str, Optional[str]]: "input": m.group("input").strip(), } - +def extract_tool_calls(action_input: str) -> List[Dict]: + if action_input is None: + return [] + + tool_call_str = "" + # Extract the tool call from the action input + # 1. Extract with qwen style + pattern = re.compile(r"\s*(.*?)\s*", re.DOTALL) + m = pattern.search(action_input) + # If we find a tool call, extract it + if m: + tool_call_str = m.group(1).strip() + try: + tool_call = jsonish(tool_call_str) + return [tool_call] + except: + pass + + # 2. Extract directly + try: + tool_call = jsonish(action_input) + return [tool_call] + except: + pass + + return [] ReactSystemPromptTemplate = """You are a ReAct-style agent. When you receive a user query, in each step, you must: @@ -122,26 +141,47 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict]: new_messages_list = [] for response, thought_action in zip(responses, thought_actions): + thought = thought_action["thought"] action = thought_action["action"] action_input = thought_action["input"] - if action_input is not None: - action_input = jsonish(action_input) if action is None: tool_calls = [] else: - tool_calls = [{ - "id": None, - "type": "function", - "function": { - "name": action, - "arguments": action_input - } - }] + tool_calls = extract_tool_calls(action_input) + + formatted_tool_calls = [] + # We only support one tool call for now + if len(tool_calls) == 1: + tool_call = tool_calls[0] + try: + tool_call = json.loads(tool_call) + # {"name": "...", "arguments": "..."} + if "name" in tool_call and "arguments" in tool_call: + name = tool_call["name"] + arguments = tool_call["arguments"] + # {"param1": "...", "param2": "..."} + else: + name = action + arguments = tool_call + formatted_tool_calls.append({ + "id": None, + "type": "function", + "function": { + "name": name, + "arguments": arguments + } + }) + except Exception as e: + name = action + arguments = tool_call + else: + pass + message = { "role": "assistant", "content": [{"type": "text", "text": response}], - "tool_calls": tool_calls, + "tool_calls": formatted_tool_calls, "loss": True } new_messages_list.append(message) diff --git a/agents/agents/rewards/math_reward.py b/agents/agents/rewards/math_reward.py index 85ba7e8..01c06d6 100644 --- a/agents/agents/rewards/math_reward.py +++ b/agents/agents/rewards/math_reward.py @@ -489,8 +489,8 @@ def math_reward_tool(prediction: str, answer: str, trajectory: List[Dict]) -> fl "acc": 1.0 if answer_correct else 0.0, } -@reward(name="math_reward_thought") -def math_reward_thought(prediction: str, answer: str, trajectory: List[Dict]) -> float: +@reward(name="math_reward_thought_with_tool") +def math_reward_thought_with_tool(prediction: str, answer: str, trajectory: List[Dict]) -> float: has_called_tool = False for msg in trajectory: if msg["role"] == "tool": @@ -519,7 +519,7 @@ def math_reward_thought(prediction: str, answer: str, trajectory: List[Dict]) -> elif has_called_tool and all_have_thought and not answer_correct: reward = 0.1 elif has_called_tool and not all_have_thought and answer_correct: - reward = 0.1 + reward = 0.0 elif has_called_tool and all_have_thought and answer_correct: reward = 1.0 else: diff --git a/agents/agents/tools/src/code/tools.py b/agents/agents/tools/src/code/tools.py index 2845ff7..0ddc378 100644 --- a/agents/agents/tools/src/code/tools.py +++ b/agents/agents/tools/src/code/tools.py @@ -54,7 +54,7 @@ def make_request(url, payload, headers, timeout=20): # else: # return str(response) -@tool(env_cls=PythonSandboxEnv, name="code_interpreter", description="Run the code in docker container and return the output from stdout or stderr", stateful=True, pool_size=16) +@tool(env_cls=PythonSandboxEnv, name="code_interpreter", description="Run the code in docker container and return the output from stdout or stderr. Output should be printed.", stateful=True, pool_size=16) async def code_interpreter(code: str, env: PythonSandboxEnv): """ Run the code in docker container and return the output from stdout or stderr diff --git a/agents/requirements.txt b/agents/requirements.txt index 15b0b23..55963d4 100644 --- a/agents/requirements.txt +++ b/agents/requirements.txt @@ -6,7 +6,7 @@ redis docker openai faiss-cpu -vllm==0.9.1 +vllm==0.9.2 termcolor tenacity nest-asyncio From 0de34c7f01e80a6c464663ee16661c920cb47bf2 Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Thu, 24 Jul 2025 10:02:22 +0000 Subject: [PATCH 5/8] Update console streaming --- agents/agents/agents/chain/streaming_observer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/agents/agents/agents/chain/streaming_observer.py b/agents/agents/agents/chain/streaming_observer.py index f358224..f2ab24e 100644 --- a/agents/agents/agents/chain/streaming_observer.py +++ b/agents/agents/agents/chain/streaming_observer.py @@ -145,7 +145,7 @@ async def on_event(self, event: StreamEvent) -> None: } if event.event_type == StreamEventType.LLM_GENERATION_START: - print(colored(f"{event.timestamp - self.chain_id_data[event.chain_id]['timestamp']:.2f}s {turn_info} =====================", color=chain_color), flush=True) + print(f"{event.timestamp - self.chain_id_data[event.chain_id]['timestamp']:.2f}s {turn_info}".center(80, "="), flush=True) elif event.event_type == StreamEventType.LLM_GENERATION_CHUNK: content = event.data.get("content", "") if content: @@ -163,8 +163,8 @@ async def on_event(self, event: StreamEvent) -> None: elif event.event_type == StreamEventType.TOOL_OBSERVATION: observation = event.data.get("observation", "") tool_name = event.data.get("tool_name", "") - print(colored(f"{event.timestamp - self.chain_id_data[event.chain_id]['timestamp']:.2f}s {turn_info} =====================", color=chain_color), flush=True) print(colored(f"Tool: [{tool_name}] {observation[:1024]}{'...' if len(observation) > 200 else ''}", color=chain_color)) + print(f"".center(80, "="), flush=True) self.chain_id_data[event.chain_id]["event_type"] = StreamEventType.TOOL_OBSERVATION elif event.event_type == StreamEventType.ERROR: error_msg = event.data.get("error", "") From 78e3eb66b997f14437c083a3084a36e26f5235e7 Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Sun, 27 Jul 2025 07:23:50 +0000 Subject: [PATCH 6/8] Update react agent and verl --- agents/agents/agents/react/react_agent.py | 43 ++++++++++++-------- agents/tests/unit/agents/test_react_agent.py | 22 ++++------ 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/agents/agents/agents/react/react_agent.py b/agents/agents/agents/react/react_agent.py index 765528f..e176149 100644 --- a/agents/agents/agents/react/react_agent.py +++ b/agents/agents/agents/react/react_agent.py @@ -15,29 +15,36 @@ def parse_react_step(text: str) -> Dict[str, Optional[str]]: """ - Parse a single ReAct-style step (one Thought→Action→Input) into its components. + Parse a single ReAct-style step into its components. Args: - text: A string containing exactly one Thought:, one Action:, and one Input:. + text: A string that may contain Thought:, Action:, and/or Input: components. Returns: - A dict with keys 'thought', 'action', and 'input', or None if not found. + A dict with keys 'thought', 'action', and 'input', with None for missing components. """ - pattern = re.compile( - r"Thought:\s*(?P.*?)\s*" - r"Action:\s*(?P.*?)\s*" - r"Input:\s*(?P.*)", - re.IGNORECASE | re.DOTALL - ) - m = pattern.search(text) - if not m: - return {"thought": None, "action": None, "input": None} - - return { - "thought": m.group("thought").strip(), - "action": m.group("action").strip(), - "input": m.group("input").strip(), - } + # Initialize result with None values + result = {"thought": None, "action": None, "input": None} + + # Pattern for Thought: + thought_pattern = re.compile(r"Thought:\s*(.*?)(?=\s*(?:Action:|Input:|$))", re.IGNORECASE | re.DOTALL) + thought_match = thought_pattern.search(text) + if thought_match: + result["thought"] = thought_match.group(1).strip() + + # Pattern for Action: + action_pattern = re.compile(r"Action:\s*(.*?)(?=\s*(?:Thought:|Input:|$))", re.IGNORECASE | re.DOTALL) + action_match = action_pattern.search(text) + if action_match: + result["action"] = action_match.group(1).strip() + + # Pattern for Input: + input_pattern = re.compile(r"Input:\s*(.*?)(?=\s*(?:Thought:|Action:|$))", re.IGNORECASE | re.DOTALL) + input_match = input_pattern.search(text) + if input_match: + result["input"] = input_match.group(1).strip() + + return result def extract_tool_calls(action_input: str) -> List[Dict]: if action_input is None: diff --git a/agents/tests/unit/agents/test_react_agent.py b/agents/tests/unit/agents/test_react_agent.py index 1e84eda..dcb2043 100644 --- a/agents/tests/unit/agents/test_react_agent.py +++ b/agents/tests/unit/agents/test_react_agent.py @@ -10,7 +10,7 @@ def test_react_agent_initialization(): agent = ReactAgent( "Qwen/Qwen2.5-3B-Instruct", tools=tools, - template="qwen-7b-chat", + template="qwen2.5", task_info=task_info, backend="client" ) @@ -35,7 +35,7 @@ def test_parse_react_step(): # Test with missing components text_missing = "Thought: I'm thinking about something." result_missing = parse_react_step(text_missing) - assert result_missing["thought"] is None + assert result_missing["thought"] == "I'm thinking about something." assert result_missing["action"] is None assert result_missing["input"] is None @@ -45,25 +45,19 @@ def test_react_agent_parse(): agent = ReactAgent( "Qwen/Qwen2.5-3B-Instruct", tools=tools, - template="qwen-7b-chat", + template="qwen2.5", backend="client" ) - # Mock the generate method to return a predefined response - def mock_generate(*args, **kwargs): - return ["""Thought: I need to search for information. + responses = ["""Thought: I need to search for information. Action: google_search Input: {"query": "test query"}"""] - agent.generate = mock_generate - - # Test the parse method - messages_list = [[{"role": "user", "content": "Find information about Python"}]] - result = agent.parse(messages_list, tools) - + result = agent.parse(responses, tools) + print(result) assert len(result) == 1 assert result[0]["role"] == "assistant" - assert "Thought: I need to search for information." in result[0]["content"] + assert "Thought: I need to search for information." in result[0]["content"][0]["text"] assert len(result[0]["tool_calls"]) == 1 assert result[0]["tool_calls"][0]["function"]["name"] == "google_search" - assert result[0]["tool_calls"][0]["function"]["arguments"] == '{"query": "test query"}' \ No newline at end of file + assert result[0]["tool_calls"][0]["function"]["arguments"] == {"query": "test query"} \ No newline at end of file From b5842482f7f2468a6421eef06b642ae54cd201fc Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Sun, 27 Jul 2025 13:33:32 +0000 Subject: [PATCH 7/8] Add streaming option and update docs --- agents/agents/agents/agent_base.py | 8 + agents/agents/agents/chain/chain_base.py | 1 - agents/agents/envs/manager/env_manager.py | 1 + agents/agents/rewards/qa_reward.py | 6 +- .../tools/src/search/dense_retriever.py | 6 +- docs/index.rst | 1 + docs/start/agent_examples.md | 215 ++++++++++++++++++ docs/start/use_agent.md | 74 ------ 8 files changed, 231 insertions(+), 81 deletions(-) create mode 100644 docs/start/agent_examples.md delete mode 100644 docs/start/use_agent.md diff --git a/agents/agents/agents/agent_base.py b/agents/agents/agents/agent_base.py index ca89159..0d712ce 100644 --- a/agents/agents/agents/agent_base.py +++ b/agents/agents/agents/agent_base.py @@ -14,6 +14,7 @@ import os import transformers import warnings +from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager try: from verl.protocol import DataProto except ImportError: @@ -43,6 +44,7 @@ def __init__( log_file: str = "agent", project_name: str = None, run_name: str = None, + streaming: str = "console", **kwargs # To pass other unused arguments ): """ @@ -68,6 +70,12 @@ def __init__( self.jinja_template = get_template(self.template).jinja_template() self.project_name = project_name self.run_name = run_name + self.streaming_manager = StreamingManager() + if streaming == "console": + self.streaming_manager.add_observer(ConsoleStreamObserver()) + else: + # TODO: Support other streaming modes + raise ValueError(f"Streaming mode {streaming} is not supported.") super().__init__() if kwargs: warnings.warn(f"Unused arguments for agent initialization: {kwargs}") diff --git a/agents/agents/agents/chain/chain_base.py b/agents/agents/agents/chain/chain_base.py index b80a67c..61cd9d7 100644 --- a/agents/agents/agents/chain/chain_base.py +++ b/agents/agents/agents/chain/chain_base.py @@ -138,7 +138,6 @@ def __init__(self): self.finished_chains_count = 0 self.initialize_monitor() self.monitor_info = defaultdict(list) - self.streaming_manager = StreamingManager() def reset(self) -> None: self.status_code: str = "continue" diff --git a/agents/agents/envs/manager/env_manager.py b/agents/agents/envs/manager/env_manager.py index c74dc7d..495d6d9 100644 --- a/agents/agents/envs/manager/env_manager.py +++ b/agents/agents/envs/manager/env_manager.py @@ -19,6 +19,7 @@ async def start(cls, env_cls: type[BaseEnv], size: int = 1, env_kwargs: dict | N or add more envs to the existing pool if the size is larger. If the size is smaller, do nothing. """ + # TODO: Currently, WarmPool will start all the envs at once. This should be fine for training, but might be wasteful for showing the demo, we may need to support feature to start a new env when acquiring, or make it a configurable option. key = env_cls if key not in cls._pools: cls._pools[key] = WarmPool( diff --git a/agents/agents/rewards/qa_reward.py b/agents/agents/rewards/qa_reward.py index 83b596f..583519f 100644 --- a/agents/agents/rewards/qa_reward.py +++ b/agents/agents/rewards/qa_reward.py @@ -52,11 +52,11 @@ def em_score(prediction, ground_truth): @reward(name="qa_f1_reward") -def qa_f1_reward(prediction: str, golden_answer: str, trajectory: List[str]) -> float: +def qa_f1_reward(prediction: str, answer: str, trajectory: List[str]) -> float: # Extract answer from agent's response response = prediction - f1, precision, recall = f1_score(response, golden_answer) - em = em_score(response, golden_answer) + f1, precision, recall = f1_score(response, answer) + em = em_score(response, answer) return { "reward": f1, diff --git a/agents/agents/tools/src/search/dense_retriever.py b/agents/agents/tools/src/search/dense_retriever.py index 9ea0288..d8b77d2 100644 --- a/agents/agents/tools/src/search/dense_retriever.py +++ b/agents/agents/tools/src/search/dense_retriever.py @@ -4,7 +4,7 @@ from transformers import AutoTokenizer, AutoModel from torch import Tensor from ...tool_base import tool -from ....__init__ import AGENT_DATA_DIR +from ....__init__ import AGENT_CACHE_DIR def load_corpus(corpus_path: str): corpus = datasets.load_dataset( @@ -49,12 +49,12 @@ async def search(self, queries: list[str], top_k: int): @tool(name="dense_retrieve", description="Use a dense retriever to retrieve documents from a corpus.", max_length=4096) async def dense_retrieve(query: str): - global AGENT_DATA_DIR + global AGENT_CACHE_DIR if not query.startswith("query:"): query = "query: " + query global GLOBAL_RETRIEVER if GLOBAL_RETRIEVER is None: - GLOBAL_RETRIEVER = DenseRetriever(corpus_file=os.path.join(AGENT_DATA_DIR, "search", "wiki-18.jsonl"), index_file=os.path.join(AGENT_DATA_DIR, "search", "e5_Flat.index")) + GLOBAL_RETRIEVER = DenseRetriever(corpus_file=os.path.join(AGENT_CACHE_DIR, "data", "search", "wiki-18.jsonl"), index_file=os.path.join(AGENT_CACHE_DIR, "data", "search", "e5_Flat.index")) doc_list = await GLOBAL_RETRIEVER.search(query, 3) doc_list = doc_list[0] content = "" diff --git a/docs/index.rst b/docs/index.rst index 80dd89c..3278b99 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,6 +11,7 @@ AgentFly is a scalable and extensible Agent-RL framework designed to empower LM start/installation start/training_example + start/agent_examples .. toctree:: :maxdepth: 2 diff --git a/docs/start/agent_examples.md b/docs/start/agent_examples.md new file mode 100644 index 0000000..2f5f3b0 --- /dev/null +++ b/docs/start/agent_examples.md @@ -0,0 +1,215 @@ +## Build an Agent + +### Use a Predefined Agent +We can specify the following arguments to use a predefined agent: + +- model_name: the path or name or the model, used to load weights +- tools: tools that will be used by the agent +- template: chat template +- backend: what type of backend + +The following shows an example to use Qwen2.5-7B-Instruct as a react agent: + +```python +from agents.agents.react.react_agent import ReactAgent +from agents.tools.src.code.tools import code_interpreter +from agents.tools.src.search.google_search import google_search_serper +from agents.tools.src.react.tools import answer + +tools = [google_search_serper, answer] + +task_info = "Use code to get answers. Result must be printed." + +react_agent = ReactAgent( + "Qwen/Qwen2.5-7B-Instruct", + tools=tools, + template="qwen2.5-no-tool", + task_info=task_info, + backend="async_vllm" +) + +question = "Solve the equation 2x + 5y = 4 such that sum of x and y is 7." +messages = [ + { + "messages": [ + {"role": "user", "content": f"{question}"} + ], + "question": f"{question}", + }, +] + +await react_agent.run_async( + max_steps=4, + start_messages=messages, + num_chains=5 # for the question, the agent will generate 5 trajectories +) + +``` + +After the rollout, we can obtain the trajectories: + +```python +react_agent.trajectories +``` + +Obtaining the rewards (if you specified reward function and give necessary parameters in input messages) +``` +react_agent.rewards) +``` + +### Customize Agent + +You can customize your own agent by defining how the agent do generation and handle tool calls. + +```python +class CustomizedAgent(BaseAgent): + def __init__(self, + **kwargs + ) + super().__init__(**kwargs) + + async def generate_async(self, messages_list: List[List[Dict]], **args): + return await self.llm_engine.generate_async(messages_list, **args) + + def parse(self, responses: List(str), tools): + # parse responses into tool calls + ... +``` + +### Use Trained Agent +We provide the following agent that we can try: + +- WebShop Agent: +```python +import asyncio +from agents.agents import ReactAgent +from agents.tools import webshop_browser +from agents.rewards import webshop_reward +from agents.agents.chain.streaming_observer import ConsoleStreamObserver + +tools = [webshop_browser] + +agent = ReactAgent( + "Agent-One/Qwen2.5-3B-Instruct-WebShop", + tools=tools, + template="qwen2.5", + backend="async_vllm", + streaming="console" +) + +question = "I am looking for a gluten free, 100% vegan plant based protein shake that is soy-free, and price lower than 40.00 dollars" + +messages = [ + { + "messages": [ + {"role": "user", "content": f"{question}"} + ], + "question": f"{question}" + }, +] + +await agent.run_async( + max_steps=10, + start_messages=messages, + num_chains=1, + enable_streaming=True +) +``` +- Science World Agent +```python +from agents.tools import scienceworld_explorer +from agents.rewards import scienceworld_reward + + +tools = [scienceworld_explorer] + +agent = ReactAgent( + "Agent-One/Qwen2.5-7B-Instruct-ScienceWorld-REINFORCEPP", + tools=tools, + template="qwen2.5-no-system-tool", + backend="async_vllm", + reward_fn=scienceworld_reward, + streaming="console" +) + +# question = "Your task is to boil ice cream. For compounds without a boiling point, combusting the substance is also acceptable. First, focus on the substance. Then, take actions that will cause it to change its state of matter." +# task_name = "boil" +# variation_idx = 14 + + +# question = "Your task is to find a(n) animal. First, focus on the thing. Then, move it to the green box in the kitchen." +# task_name = "find-animal" +# variation_idx = 150 + +# question = "Your task is to find a(n) plant. First, focus on the thing. Then, move it to the purple box in the bathroom." +# task_name = "find-plant" +# variation_idx = 191 + + +# question = "Your task is to grow a orange. This will require growing several plants, and them being crosspollinated to produce fruit. Seeds can be found in the kitchen. To complete the task, focus on the grown orange." +# task_name = "grow-fruit" +# variation_idx = 92 + +# question = "Your task is to determine which of the two inclined planes (bronze, copper) has the most friction. After completing your experiment, focus on the inclined plane with the most friction." +# task_name = "inclined-plane-friction-named-surfaces" +# variation_idx = 776 + +question = "Your task is to measure the melting point of orange juice, which is located around the kitchen. First, focus on the thermometer. Next, focus on the orange juice. If the melting point of orange juice is above 200.0 degrees celsius, focus on the yellow box. If the melting point of orange juice is below 200.0 degrees celsius, focus on the purple box. The boxes are located around the kitchen." +task_name = "measure-melting-point-known-substance" +variation_idx = 247 + +messages = [ + { + "messages": [ + {"role": "user", "content": f"{question}"} + ], + "question": f"{question}", + "task_name": task_name, + "variation_idx": variation_idx + }, +] + +await agent.run_async( + max_steps=20, + start_messages=messages, + num_chains=1, + enable_streaming=True +) + +print(agent.rewards) +``` + +- Retrieval Agent + +```python +from agents.tools import dense_retrieve, asyncdense_retrieve + +tools = [dense_retrieve] + +agent = ReactAgent( + "Agent-One/Qwen2.5-3B-Instruct-Retrieval-GRPO", + tools=tools, + template="qwen2.5-no-system-tool", + backend="async_vllm", + streaming="console" +) + +question = "Who is Geoffrey Hinton" + + +messages = [ + { + "messages": [ + {"role": "user", "content": f"{question}"} + ], + "question": f"{question}", + }, +] + +await agent.run_async( + max_steps=6, + start_messages=messages, + num_chains=1, + enable_streaming=True +) +``` \ No newline at end of file diff --git a/docs/start/use_agent.md b/docs/start/use_agent.md deleted file mode 100644 index 0ed10ea..0000000 --- a/docs/start/use_agent.md +++ /dev/null @@ -1,74 +0,0 @@ -## Build an Agent - -### Use a Predefined Agent -We can specify the following arguments to use a predefined agent: - -- model_name: the path or name or the model, used to load weights -- tools: tools that will be used by the agent -- template: chat template -- backend: what type of backend - -The following shows an example to use Qwen2.5-7B-Instruct as a react agent: - -```python -from agents.agents.react.react_agent import ReactAgent -from agents.tools.src.code.tools import code_interpreter -from agents.tools.src.search.google_search import google_search_serper -from agents.tools.src.react.tools import answer - -tools = [google_search_serper, answer] - -task_info = "Use code to get answers. Result must be printed." - -react_agent = ReactAgent( - "Qwen/Qwen2.5-7B-Instruct", - tools=tools, - template="qwen2.5-no-tool", - task_info=task_info, - backend="async_vllm" -) - -question = "Solve the equation 2x + 5y = 4 such that sum of x and y is 7." -messages = [ - { - "messages": [ - {"role": "user", "content": f"{question}"} - ], - "question": f"{question}", - }, -] - -await react_agent.run_async( - max_steps=4, - start_messages=messages, - num_chains=5 # for the question, the agent will generate 5 trajectories -) - -``` - -After the rollout, we can obtain the trajectories: - -```python -react_agent.trajectories -``` - -### Customize Agent - -You can customize your own agent by defining how the agent do generation and handle tool calls. - -```python -class CustomizedAgent(BaseAgent): - def __init__(self, - **kwargs - ) - super().__init__(**kwargs) - - async def generate_async(self, messages_list: List[List[Dict]], **args): - return await self.llm_engine.generate_async(messages_list, **args) - - def parse(self, responses: List(str), tools): - # parse responses into tool calls - ... -``` - - From 4a49bad957a8e308ec37927c7023ba39405ca8a3 Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Sun, 27 Jul 2025 14:17:21 +0000 Subject: [PATCH 8/8] Remove deprecated file --- .../agents/agents/chain/README_STREAMING.md | 324 ------------------ 1 file changed, 324 deletions(-) delete mode 100644 agents/agents/agents/chain/README_STREAMING.md diff --git a/agents/agents/agents/chain/README_STREAMING.md b/agents/agents/agents/chain/README_STREAMING.md deleted file mode 100644 index 2869b05..0000000 --- a/agents/agents/agents/chain/README_STREAMING.md +++ /dev/null @@ -1,324 +0,0 @@ -# Streaming Functionality for LLM Agent Reinforcement Learning - -This document describes the streaming functionality added to the LLM agent reinforcement learning framework, which allows real-time monitoring of agent responses and tool observations. - -## Overview - -The streaming functionality provides: - -1. **Real-time LLM response streaming** - See tokens as they are generated -2. **Tool observation streaming** - Monitor tool calls and their results in real-time -3. **Event-based architecture** - Flexible observer pattern for different use cases -4. **Multiple output formats** - Console, JSON, WebSocket, and custom callbacks -5. **Async support** - Non-blocking streaming with proper async/await patterns -6. **Multi-chain support** - Handle multiple chains without mixing outputs - -## Architecture - -### Core Components - -1. **StreamEvent** - Represents a streaming event with metadata -2. **StreamObserver** - Abstract base class for event observers -3. **StreamingManager** - Manages observers and event distribution -4. **ChainGeneration** - Enhanced with streaming support - -### Event Types - -- `LLM_GENERATION_START` - LLM generation begins -- `LLM_GENERATION_CHUNK` - Individual token/chunk generated -- `LLM_GENERATION_END` - LLM generation completes -- `TOOL_CALL_START` - Tool call begins -- `TOOL_CALL_END` - Tool call completes -- `TOOL_OBSERVATION` - Tool observation received -- `CHAIN_START` - Agent chain begins -- `CHAIN_END` - Agent chain completes -- `ERROR` - Error occurred - -### Multi-Chain Problem and Solutions - -**Problem**: When running multiple chains with streaming, outputs from different chains get mixed together, making it impossible to follow which events belong to which chain. - -**Solutions**: - -1. **Automatic Color Coding** - Each chain gets a different color in console output -2. **Chain Filtering** - Filter events to only show specific chains -3. **Chain-Specific Observers** - Create separate observers for each chain -4. **Separate Async Generators** - Process each chain's events independently -5. **Multi-Chain Observer** - Organize observers by chain ID - -## Usage Examples - -### Basic Console Streaming - -```python -from agents.agents.agents.agent_base import BaseAgent -from agents.agents.agents.chain.streaming_observer import ConsoleStreamObserver - -# Initialize agent -agent = YourAgent(model_name="your-model", tools=[...]) - -# Add console observer -console_observer = ConsoleStreamObserver(show_timestamps=True) -agent.streaming_manager.add_observer(console_observer) - -# Run with streaming -await agent.run_async( - max_steps=5, - start_messages=your_messages, - num_chains=1, - enable_streaming=True -) -``` - -### Multi-Chain Streaming Solutions - -When running multiple chains, the streaming output can become mixed. Here are several solutions: - -#### 1. Colored Output (Automatic) - -```python -# Each chain gets a different color automatically -console_observer = ConsoleStreamObserver(show_timestamps=True) -agent.streaming_manager.add_observer(console_observer) - -await agent.run_async( - max_steps=5, - start_messages=your_messages, - num_chains=3, # Multiple chains - enable_streaming=True -) -``` - -#### 2. Chain-Specific Observers - -```python -from agents.agents.agents.chain.streaming_observer import MultiChainStreamObserver, ChainSpecificStreamObserver - -# Create multi-chain observer -multi_observer = MultiChainStreamObserver() - -# Add observers for specific chains -for chain_id in ["chain_0", "chain_1", "chain_2"]: - console_observer = ConsoleStreamObserver(show_timestamps=True) - chain_observer = ChainSpecificStreamObserver(chain_id, console_observer) - multi_observer.add_chain_observer(chain_id, chain_observer) - -agent.streaming_manager.add_observer(multi_observer) -``` - -#### 3. Filter by Chain - -```python -# Only observe events from a specific chain -filtered_observer = ConsoleStreamObserver( - show_timestamps=True, - chain_filter="chain_0" # Only show chain_0 events -) -agent.streaming_manager.add_observer(filtered_observer) -``` - -#### 4. Separate Async Generators per Chain - -```python -from agents.agents.agents.chain.streaming_observer import AsyncGeneratorStreamObserver - -# Create separate generators for each chain -chain_generators = {} -for i in range(3): - chain_id = f"chain_{i}" - async_observer = AsyncGeneratorStreamObserver(chain_filter=chain_id) - agent.streaming_manager.add_observer(async_observer) - chain_generators[chain_id] = async_observer.events() - -# Process each chain separately -async def process_chain(chain_id, generator): - async for event in generator: - print(f"{chain_id}: {event.event_type.value}") - -# Run all chains concurrently -tasks = [process_chain(chain_id, generator) for chain_id, generator in chain_generators.items()] -await asyncio.gather(*tasks) -``` - -### JSON Logging - -```python -from agents.agents.agents.chain.streaming_observer import JSONStreamObserver - -# Add JSON observer -json_observer = JSONStreamObserver(file_path="events.jsonl") -agent.streaming_manager.add_observer(json_observer) -``` - -### Custom Streaming Callback - -```python -async def custom_callback(chunk: str): - print(f"🔄 {chunk}", end="", flush=True) - -await agent.run_async( - max_steps=5, - start_messages=your_messages, - num_chains=1, - enable_streaming=True, - streaming_callback=custom_callback -) -``` - -### WebSocket Streaming - -```python -from agents.agents.agents.chain.websocket_streaming import WebSocketStreamingServer - -# Start WebSocket server -server = WebSocketStreamingServer(host="localhost", port=8765) -await server.start() - -# Add WebSocket observer -agent.streaming_manager.add_observer(server.get_observer()) - -# Run agent -await agent.run_async(..., enable_streaming=True) -``` - -### Async Generator Events - -```python -from agents.agents.agents.chain.streaming_observer import AsyncGeneratorStreamObserver - -# Create async generator observer -async_observer = AsyncGeneratorStreamObserver() -agent.streaming_manager.add_observer(async_observer) - -# Start agent run -run_task = asyncio.create_task(agent.run_async(..., enable_streaming=True)) - -# Process events as they arrive -async for event in async_observer.events(): - print(f"Event: {event.event_type.value}") - if event.event_type.value == "llm_generation_chunk": - print(f"Content: {event.data.get('content', '')}") - -await run_task -``` - -## Backend Support - -### Transformers Backend - -The Transformers backend supports streaming through token-by-token generation: - -```python -agent = YourAgent( - model_name="your-model", - backend="transformers", - # ... other args -) -``` - -### Async vLLM Backend - -The Async vLLM backend provides efficient streaming: - -```python -agent = YourAgent( - model_name="your-model", - backend="async_vllm", - # ... other args -) -``` - -## Event Structure - -Each streaming event contains: - -```python -@dataclass -class StreamEvent: - event_type: StreamEventType - chain_id: str - timestamp: float - data: Dict[str, Any] - step: Optional[int] = None - depth: Optional[int] = None -``` - -### Example Events - -**LLM Generation Chunk:** -```json -{ - "event_type": "llm_generation_chunk", - "chain_id": "uuid-123", - "timestamp": 1234567890.123, - "data": {"content": "def factorial"}, - "step": 1, - "depth": 1 -} -``` - -**Tool Observation:** -```json -{ - "event_type": "tool_observation", - "chain_id": "uuid-123", - "timestamp": 1234567890.456, - "data": { - "tool_name": "code_interpreter", - "observation": "120", - "status": "success" - }, - "step": 1, - "depth": 1 -} -``` - -## Performance Considerations - -1. **Memory Usage** - Streaming events are lightweight but can accumulate -2. **Network Overhead** - WebSocket streaming adds minimal overhead -3. **Backend Compatibility** - Not all backends support streaming equally -4. **Observer Performance** - Heavy observers can slow down the main loop - -## Best Practices - -1. **Use appropriate observers** - Console for debugging, JSON for logging, WebSocket for web apps -2. **Handle errors gracefully** - Implement error handling in custom observers -3. **Clean up resources** - Properly close WebSocket connections and file handles -4. **Monitor performance** - Watch for memory leaks in long-running streams -5. **Test thoroughly** - Streaming adds complexity, test edge cases - -## Integration with Existing Code - -The streaming functionality is designed to be non-intrusive: - -- Existing code continues to work without changes -- Streaming is opt-in via the `run_async_streaming` method -- Observers can be added/removed at runtime -- No performance impact when streaming is disabled - -## Troubleshooting - -### Common Issues - -1. **No streaming output** - Check if backend supports streaming -2. **WebSocket connection issues** - Verify port availability and firewall settings -3. **Memory leaks** - Ensure observers are properly cleaned up -4. **Performance issues** - Consider using fewer observers or lighter implementations - -### Debug Mode - -Enable debug logging to troubleshoot streaming issues: - -```python -import logging -logging.basicConfig(level=logging.DEBUG) -``` - -## Future Enhancements - -1. **Filtering** - Event filtering based on type, chain_id, etc. -2. **Batching** - Batch multiple events for efficiency -3. **Compression** - Compress WebSocket messages for large-scale deployments -4. **Authentication** - Add authentication to WebSocket connections -5. **Metrics** - Built-in streaming metrics and monitoring \ No newline at end of file