diff --git a/.gitignore b/.gitignore index 5592ce9..d2646d5 100644 --- a/.gitignore +++ b/.gitignore @@ -132,5 +132,9 @@ outputs # Notebooks agents/tests/*.ipynb +agents/tests/*.jpg +agents/tests/*.jpeg +agents/tests/*.png agents/agents/*.ipynb +agents/temp/ diff --git a/README.md b/README.md index a1bf150..3c88df4 100644 --- a/README.md +++ b/README.md @@ -1,26 +1,53 @@ -# AgentFly: Scalable and Extensible Reinforcement Learning for LLM Agents +# 🪽AgentFly: Training scalable LLM agents with RL (multi-turn, async tools/rewards, multimodal) + +

-Static Badge -Static Badge -Static Badge +Static Badge +Static Badge +Static Badge +Static Badge +Static Badge +

+

+Static Badge +Static Badge +Static Badge +Static Badge +Static Badge +Static Badge +Static Badge +
+Static Badge +Static Badge +Static Badge +Static Badge +Static Badge +
+Static Badge +Static Badge +Static Badge +Static Badge

- AgentFly is an extensible framework for building LLM agents with reinforcement learning. It supports multi-turn training by adapting traditional RL methods with token-level masking. It features a decorator-based interface for defining tools and reward functions, enabling seamless extension and ease of use. To support high-throughput training, it implemented asynchronous execution of tool calls and reward computations, and design a centralized resource management system for scalable environment coordination. A suite of prebuilt tools and environments are provided. ![Overview](assets/images/overview.png) -## 🆕 News +## News -**Multi-Modal (Vision) Agent Training Support** - Thanks to the powerful template system, AgentFly now supports training vision-language agents! 🎉 - -Train agents that can see and understand visual content, including GUI automation and image-based QA. See our [predefined training examples](docs/examples/predefined_training_examples.md) for ready-to-use scripts. +**08/2025 Multi-Modal (Vision) Agent Training Support** - Thanks to the powerful template system, AgentFly now supports training vision-language agents! 🎉 Train agents that can see and understand visual content, including GUI automation and image-based QA. See our [predefined training examples](docs/examples/predefined_training_examples.md) for ready-to-use scripts. --- -**New: Chat Template System** - A flexible framework for creating conversation templates with multi-model support, vision capabilities, and tool integration. [Learn more →](docs/chat_template/) +**08/2025 Chat Template System** - A flexible framework for creating conversation templates with multi-model support, vision capabilities, and tool integration. 
[Learn more →](docs/chat_template/) ## Installation +**Option 1**: One-line Installation: +``` +bash install.sh # Assume conda with python3.10.x +``` +**Option 2**: Customized Installation + Clone and initialize the project: ```bash git clone https://github.com/Agent-One-Lab/AgentFly @@ -42,6 +69,11 @@ Search requires redis to cache results, an optional way to install with conda: conda install conda-forge::redis-server==7.4.0 ``` +## Quick Start +``` + +``` + ## Features ### 1. Multi-Chain Agent Rollout and Multi-Turn Training To support algorithms like GRPO, Reinforce++, we design multi-chain inference, enabling agents to solve one task with multiple paths at the same time. We build RL computation and update LLMs in multi-turn manner by applying token masks. The training is based on [verl](https://github.com/volcengine/verl). @@ -67,6 +99,10 @@ agent = ReactAgent( ### 3. Easy Development Decoupled agent and training module. Simply customize your own agent, which can directly be applied to training. +## Training Curves +Reward curves on Qwen2.5-Instruct 3B and 7B models. +![Curves](assets/images/training_curves.png) + ## Training ### Run Example Training @@ -155,9 +191,6 @@ https://github.com/user-attachments/assets/b8f42534-8d40-48a0-a264-f378e479bb3a [Discord](https://discord.gg/CchUj7Sp) -## Training Curves -Reward curves on Qwen2.5-Instruct 3B and 7B models. 
-![Curves](assets/images/training_curves.png) ## Cite If you used our code or find it helpful, please cite: diff --git a/agents/agents/agents/agent_base.py b/agents/agents/agents/agent_base.py index 4ac693a..a58092e 100644 --- a/agents/agents/agents/agent_base.py +++ b/agents/agents/agents/agent_base.py @@ -70,6 +70,7 @@ def __init__( self.template = template self.max_length = max_length self.tools = tools + self.tool_names = [tool.name for tool in tools] self.system_prompt = system_prompt self.model_name_or_path = model_name_or_path @@ -209,7 +210,7 @@ def trajectories(self): return trajectories - def tokenize_trajectories(self, tokenizer = None, return_reward_mask: bool = False): + def tokenize_trajectories(self, tokenizer = None, return_reward_mask: bool = False, concatenate_mm_inputs: bool = True): if tokenizer is None: tokenizer = self.tokenizer @@ -239,7 +240,16 @@ def tokenize_trajectories(self, tokenizer = None, return_reward_mask: bool = Fal info['last_response'] = last_response other_info_list.append(info) - inputs = tokenize_conversations(messages_list, tokenizer=tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask) + inputs = tokenize_conversations( + messages_list, + tokenizer=tokenizer, + template=self.template, + processor=self.processor, + max_length=self.max_length, + return_reward_mask=return_reward_mask, + add_generation_prompt=True, + concatenate_mm_inputs=concatenate_mm_inputs, + ) position_ids = torch.clip(torch.cumsum(inputs['attention_mask'], dim=-1) - 1, min=0, max=None) inputs['position_ids'] = position_ids @@ -318,7 +328,7 @@ def rewards(self): def get_verl_data_proto(self): - inputs, other_info_list = self.tokenize_trajectories(return_reward_mask=True) + inputs, other_info_list = self.tokenize_trajectories(return_reward_mask=True, concatenate_mm_inputs=False) group_ids = np.array([info["group_id"] for info in other_info_list], dtype=object) # Do evaluation 
here reward_values, other_values = self.rewards @@ -329,6 +339,10 @@ def get_verl_data_proto(self): inputs[f"rm_{key}"] = np.array(values) # We handle the group id in the agent side, to be compatible with GRPO inputs["uid"] = group_ids + + if "mm_inputs" in inputs: + mm_inputs = inputs.pop("mm_inputs") + inputs["multi_modal_inputs"] = np.array(mm_inputs, dtype=object) batch = DataProto.from_single_dict(inputs, meta_info={"use_agent": True}) return batch diff --git a/agents/agents/agents/chain/chain_base.py b/agents/agents/agents/chain/chain_base.py index faf8f71..3fcf12d 100644 --- a/agents/agents/agents/chain/chain_base.py +++ b/agents/agents/agents/chain/chain_base.py @@ -474,7 +474,12 @@ async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id, have_set_tools = True # Execute tool call - result = await submit_tool_call(tool_name, tool_input, id=chain_id) + result = await submit_tool_call( + tool_name, + tool_input, + id=chain_id, + allowed_tool_names=self.tool_names + ) if enable_streaming: # Emit tool observation event diff --git a/agents/agents/agents/specialized/gui_agent.py b/agents/agents/agents/specialized/gui_agent.py index e9dd812..9436819 100644 --- a/agents/agents/agents/specialized/gui_agent.py +++ b/agents/agents/agents/specialized/gui_agent.py @@ -2,10 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import json +import logging from typing import List, Any, Tuple, Dict, Optional from ..agent_base import BaseAgent from agents.utils.ui_action_parser import parse_action_to_structure_output, IMAGE_FACTOR +logger = logging.getLogger(__name__) + # Default image dimensions TEST_IMAGE_HEIGHT = 1080 TEST_IMAGE_WIDTH = 1920 @@ -96,8 +99,8 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]: Returns: List of structured messages with tool calls """ - print(f"[GUIAgent.parse] Number of responses: {len(responses)}") - print(f"[GUIAgent.parse] Raw responses type: {type(responses)}") + 
logger.debug(f"[GUIAgent.parse] Number of responses: {len(responses)}") + logger.debug(f"[GUIAgent.parse] Raw responses type: {type(responses)}") new_messages_list = [] @@ -109,7 +112,7 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]: elif resp and resp.strip(): # Try to reformat responses that don't have the expected format resp_lower = resp.lower() - print(f"[GUIAgent.parse] Response missing format, reformatting: {resp[:100]}") + logger.debug(f"[GUIAgent.parse] Response missing format, reformatting: {resp[:100]}") # Check if it contains action-like content if any(action in resp_lower for action in ['click', 'type', 'scroll']): @@ -128,9 +131,9 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]: # Log responses for debugging for idx, resp in enumerate(responses[:3]): # Log first 3 responses if resp: - print(f"[GUIAgent.parse] Response {idx} length: {len(resp)}, preview: {resp[:200]}") + logger.debug(f"[GUIAgent.parse] Response {idx} length: {len(resp)}, preview: {resp[:200]}") else: - print(f"[GUIAgent.parse] Response {idx} is None or empty") + logger.debug(f"[GUIAgent.parse] Response {idx} is None or empty") # Parse actions from responses action_list = [] @@ -145,13 +148,13 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]: # Create messages with tool calls for i, (response, actions) in enumerate(zip(responses, action_list)): - print(f"[GUIAgent.parse] Processing response {i+1}: response_length={len(response) if response else 0}, actions={actions}") + logger.debug(f"[GUIAgent.parse] Processing response {i+1}: response_length={len(response) if response else 0}, actions={actions}") tool_calls = [] if actions is not None and len(actions) > 0: if len(actions) > 1: - print(f"[GUIAgent.parse] Warning: Multiple actions found ({len(actions)}), using first one") + logger.debug(f"[GUIAgent.parse] Warning: Multiple actions found ({len(actions)}), using first one") action = 
actions[0] tool_calls = [{ "id": str(i), @@ -163,7 +166,7 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]: }] else: # If no action was parsed, create a default click action at center - print(f"[GUIAgent.parse] No action parsed from response, creating default click action") + logger.debug(f"[GUIAgent.parse] No action parsed from response, creating default click action") default_action = { "action_type": "click", "action_inputs": {"start_box": "(960, 540)"}, @@ -184,7 +187,7 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]: status = "terminal" if actions and isinstance(actions[0], dict): action_type = actions[0].get("action_type", "") - print(f"[GUIAgent.parse] Action type: {action_type}, terminating after one turn") + logger.debug(f"[GUIAgent.parse] Action type: {action_type}, terminating after one turn") message = { "role": "assistant", @@ -193,7 +196,7 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]: "loss": True, "status": status } - print(f"[GUIAgent.parse] Created message with status={status}, tool_calls={len(tool_calls)}, content_length={len(response)}") + logger.debug(f"[GUIAgent.parse] Created message with status={status}, tool_calls={len(tool_calls)}, content_length={len(response)}") new_messages_list.append(message) return new_messages_list diff --git a/agents/agents/agents/templates/templates.py b/agents/agents/agents/templates/templates.py index 7697f87..6b8d3ea 100644 --- a/agents/agents/agents/templates/templates.py +++ b/agents/agents/agents/templates/templates.py @@ -345,6 +345,10 @@ def _encode_user_message_with_tools(self, content, tools: str) -> str: for item in content: if item["type"] == "text": text += item["text"] + elif item["type"] in ["image", "image_url"]: + text += self.vision_start + self.image_token + self.vision_end + elif item["type"] == "video": + text += self.vision_start + self.video_token + self.vision_end else: raise 
ValueError(f"Invalid message type: {item['type']}") @@ -949,11 +953,11 @@ def _convert_single_message_to_hf_format(self, message: Dict) -> Dict: # Not sure what to do with other types of content pass - def convert_to_hf_format_messages(self, messages: List[Dict]) -> List[Dict]: + def convert_to_hf_format_messages(self, messages: Union[List[Dict], Dict[str, List[Dict]]]) -> List[Dict]: + hf_messages = [] if messages is None: return None role_label, content_label = self._detect_labels(messages) - hf_messages = [] for message in messages: hf_messages.append({"role": message[role_label], "content": message[content_label]}) diff --git a/agents/agents/agents/templates/utils.py b/agents/agents/agents/templates/utils.py index 9e6b383..46faa0c 100644 --- a/agents/agents/agents/templates/utils.py +++ b/agents/agents/agents/templates/utils.py @@ -1,3 +1,4 @@ +from collections import defaultdict import copy from enum import Enum import os @@ -219,26 +220,56 @@ def process_prompt_with_vision( ) -def tokenize_conversations(messages_list, tokenizer, conv_template, max_length, processor=None, return_tensors="pt", return_reward_mask=False): +def tokenize_conversations( + messages_list, + tokenizer, + template, + max_length, + processor=None, + return_tensors="pt", + return_reward_mask=False, + add_generation_prompt=False, + padding_side="right", + concatenate_mm_inputs=False +): batch_input_ids = [] batch_attention_masks = [] batch_labels = [] batch_action_masks = [] + batch_mm_inputs = [] # TODO: add multiprocessing for messages in messages_list: - inputs = tokenize_conversation(messages, tokenizer, conv_template, max_length, processor=processor) + inputs = tokenize_conversation(messages, tokenizer, template, max_length, processor=processor, add_generation_prompt=add_generation_prompt) batch_input_ids.append(inputs['input_ids'].squeeze(0)) batch_attention_masks.append(inputs['attention_mask'].squeeze(0)) batch_labels.append(inputs['labels'].squeeze(0)) 
batch_action_masks.append(inputs['action_mask'].squeeze(0)) - + mm_inputs = {} + if "pixel_values" in inputs: + mm_inputs["pixel_values"] = inputs["pixel_values"] + else: + mm_inputs["pixel_values"] = None + if "image_grid_thw" in inputs: + mm_inputs["image_grid_thw"] = inputs["image_grid_thw"] + else: + mm_inputs["image_grid_thw"] = None + + batch_mm_inputs.append(mm_inputs) + if return_tensors == "pt": # Use pad_token_id from the tokenizer interface pad_token_id = getattr(tokenizer, 'pad_token_id', 0) - batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=pad_token_id) - batch_attention_masks = torch.nn.utils.rnn.pad_sequence(batch_attention_masks, batch_first=True, padding_value=0) - batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100) - batch_action_masks = torch.nn.utils.rnn.pad_sequence(batch_action_masks, batch_first=True, padding_value=0) + + batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=pad_token_id, padding_side=padding_side) + batch_attention_masks = torch.nn.utils.rnn.pad_sequence(batch_attention_masks, batch_first=True, padding_value=0, padding_side=padding_side) + batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100, padding_side=padding_side) + batch_action_masks = torch.nn.utils.rnn.pad_sequence(batch_action_masks, batch_first=True, padding_value=0, padding_side=padding_side) + + # convert [{"pixel_values": tensor, "image_grid_thw": tensor}, ...] 
to {"key1": concat_tensor, "key2": concat_tensor, ...} + concatenated_mm_inputs = {} + if concatenate_mm_inputs: + for key in batch_mm_inputs[0].keys(): + concatenated_mm_inputs[key] = torch.cat([mm_inputs[key] for mm_inputs in batch_mm_inputs if mm_inputs[key] is not None], dim=0) inputs = dict( input_ids=batch_input_ids, @@ -250,6 +281,20 @@ def tokenize_conversations(messages_list, tokenizer, conv_template, max_length, if return_reward_mask: inputs['reward_mask'] = transform_reward_mask(batch_action_masks) + # Check if we need mm_inputs + mm_keys = list(batch_mm_inputs[0].keys()) + return_mm_inputs = False + for key in mm_keys: + if any(mm_inputs[key] is not None for mm_inputs in batch_mm_inputs): + return_mm_inputs = True + break + + if return_mm_inputs: + if concatenate_mm_inputs: + inputs.update(concatenated_mm_inputs) + else: + inputs["mm_inputs"] = batch_mm_inputs + return inputs def visualize_template(template, messages=None, tools=None, **kwargs): diff --git a/agents/agents/agents/templates/vision_processor.py b/agents/agents/agents/templates/vision_processor.py index ca14aea..63b0062 100644 --- a/agents/agents/agents/templates/vision_processor.py +++ b/agents/agents/agents/templates/vision_processor.py @@ -190,12 +190,16 @@ def process_for_llm( labels.append(-100) action_mask.append(0) + images_to_process = [image for image in images] + videos_to_process = [video for video in videos] # Step 2: Process each element with vision token expansion for element, mask_flag in zip(elements, mask_flags): # Check if element contains vision tokens if self._contains_vision_tokens(element): # Expand vision tokens in this element - expanded_element = self.expand_vision_tokens(element, images, videos, processor) + # Number of images and videos should be equal to the total number of vision tokens in the element + # We check whether all images and videos are processed later. 
+ expanded_element = self.expand_vision_tokens(element, images_to_process, videos_to_process, processor) cur_input_ids = tokenizer.encode(expanded_element, add_special_tokens=False) else: cur_input_ids = tokenizer.encode(element, add_special_tokens=False) @@ -210,6 +214,8 @@ def process_for_llm( else: labels.extend(cur_input_ids) action_mask.extend([1] * len(cur_input_ids)) + + assert len(images_to_process) == len(videos_to_process) == 0, f"All images and videos should be processed, but got {len(images_to_process)} images and {len(videos_to_process)} videos left for vision template {self.config.model_type}." # Step 3: Create base inputs inputs = { @@ -286,6 +292,7 @@ def _load_image_from_input(self, image_input) -> "ImageObject": # Assume it's a file path else: + print(f"Loading image from file path: {image_input}") return Image.open(image_input) # Handle bytes @@ -477,40 +484,52 @@ def expand_vision_tokens( num_image_placeholders = prompt.count(self.config.image_token) num_video_placeholders = prompt.count(self.config.video_token) - if len(images) != num_image_placeholders: - raise ValueError(f"Number of images ({len(images)}) doesn't match placeholders ({num_image_placeholders})") - if len(videos) != num_video_placeholders: - raise ValueError(f"Number of videos ({len(videos)}) doesn't match placeholders ({num_video_placeholders})") - + # if len(images) != num_image_placeholders: + # raise ValueError(f"Number of images ({len(images)}) doesn't match placeholders ({num_image_placeholders})") + # if len(videos) != num_video_placeholders: + # raise ValueError(f"Number of videos ({len(videos)}) doesn't match placeholders ({num_video_placeholders})") + images_slice = [images.pop(0) for _ in range(num_image_placeholders)] + videos_slice = [videos.pop(0) for _ in range(num_video_placeholders)] # Preprocess images and videos to get individual token counts - processed_images = self.preprocess_images(images, processor) if images else {} - processed_videos = 
self.preprocess_videos(videos, processor) if videos else {} + + processed_images = [self.preprocess_images([image], processor) for image in images_slice] + processed_videos = [self.preprocess_videos([video], processor) for video in videos_slice] - # Expand image tokens using regex to avoid infinite loops expanded_prompt = prompt - if self.config.image_token in expanded_prompt: - if processed_images and "pixel_values" in processed_images: - # Calculate tokens for this specific image - image_tokens = self.calculate_image_tokens(processed_images, processor) - replacement = self.config.image_token * image_tokens - else: - replacement = self.config.image_token - - # Use regex to replace all occurrences at once - import re - expanded_prompt = re.sub(re.escape(self.config.image_token), replacement, expanded_prompt) - - # Expand video tokens using regex to avoid infinite loops - if self.config.video_token in expanded_prompt: - if processed_videos and "pixel_values" in processed_videos: - # Calculate tokens for this specific video - video_tokens = self.calculate_video_tokens(processed_videos, processor) - replacement = self.config.video_token * video_tokens - else: - replacement = self.config.video_token - - # Use regex to replace all occurrences at once - expanded_prompt = re.sub(re.escape(self.config.video_token), replacement, expanded_prompt) + if self.config.image_token in expanded_prompt and processed_images: + parts = expanded_prompt.split(self.config.image_token) + expanded_parts = [parts[0]] + for idx in range(len(parts) - 1): + if idx < len(processed_images): + processed_image = processed_images[idx] + if "pixel_values" in processed_image: + image_tokens = self.calculate_image_tokens(processed_image, processor) + replacement = self.config.image_token * image_tokens + else: + replacement = self.config.image_token + else: + replacement = self.config.image_token + expanded_parts.append(replacement) + expanded_parts.append(parts[idx+1]) + expanded_prompt = 
''.join(expanded_parts) + + # Expand video tokens sequentially - each token gets replaced with its corresponding video + if self.config.video_token in expanded_prompt and processed_videos: + parts = expanded_prompt.split(self.config.video_token) + expanded_parts = [parts[0]] + for idx in range(len(parts) - 1): + if idx < len(processed_videos): + processed_video = processed_videos[idx] + if "pixel_values" in processed_video: + video_tokens = self.calculate_video_tokens(processed_video, processor) + replacement = self.config.video_token * video_tokens + else: + replacement = self.config.video_token + else: + replacement = self.config.video_token + expanded_parts.append(replacement) + expanded_parts.append(parts[idx+1]) + expanded_prompt = ''.join(expanded_parts) return expanded_prompt diff --git a/agents/agents/rewards/gui_reward.py b/agents/agents/rewards/gui_reward.py index ae24d7a..8a118c3 100644 --- a/agents/agents/rewards/gui_reward.py +++ b/agents/agents/rewards/gui_reward.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates # SPDX-License-Identifier: Apache-2.0 +import logging import re import json import ast @@ -9,6 +10,8 @@ from .reward_base import reward from agents.utils.ui_action_parser import parse_action_to_structure_output, IMAGE_FACTOR +logger = logging.getLogger(__name__) + # Image dimensions for testing TEST_IMAGE_HEIGHT = 1080 TEST_IMAGE_WIDTH = 1920 @@ -71,7 +74,7 @@ def extract_action(content: str) -> str: return action_type except Exception as e: - print(f"[extract_action] Error: {e}") + logger.debug(f"[extract_action] Error: {e}") return "no action" @@ -96,7 +99,7 @@ def extract_input_text(content: str) -> str: return "" except Exception as e: - print(f"[extract_input_text] Error: {e}") + logger.debug(f"[extract_input_text] Error: {e}") return "" @@ -120,7 +123,7 @@ def extract_coord(content: str) -> Tuple[list, bool]: coords[0] / TEST_IMAGE_WIDTH, coords[1] / TEST_IMAGE_HEIGHT ] - print(f"[extract_coord] Normalized point from {coords} to {normalized_coords}") + logger.debug(f"[extract_coord] Normalized point from {coords} to {normalized_coords}") return normalized_coords, True elif len(coords) == 4: # Box format [x1, y1, x2, y2] in pixels - normalize to 0-1 @@ -130,15 +133,15 @@ def extract_coord(content: str) -> Tuple[list, bool]: coords[2] / TEST_IMAGE_WIDTH, coords[3] / TEST_IMAGE_HEIGHT ] - print(f"[extract_coord] Normalized box from {coords} to {normalized_coords}") + logger.debug(f"[extract_coord] Normalized box from {coords} to {normalized_coords}") return normalized_coords, True else: - print(f"[extract_coord] Unexpected coord format: {coords}") + logger.debug(f"[extract_coord] Unexpected coord format: {coords}") except Exception as e: - print(f"[extract_coord] Error parsing coordinates: {e}") + logger.debug(f"[extract_coord] Error parsing coordinates: {e}") except Exception as e: - print(f"[extract_coord] Error: {e}") + logger.debug(f"[extract_coord] Error: {e}") return [], False @@ -160,9 +163,9 @@ def gui_accuracy_score(predict_str: 
str, gt_action: str, gt_bbox: list, gt_input pred_coord, has_coord = extract_coord(predict_str) pred_input_text = extract_input_text(predict_str) - print(f"[gui_accuracy_score] gt_action: {gt_action}, pred_action: {pred_action}") - print(f"[gui_accuracy_score] gt_bbox: {gt_bbox}, pred_coord: {pred_coord}, has_coord: {has_coord}") - print(f"[gui_accuracy_score] gt_input_text: {gt_input_text}, pred_input_text: {pred_input_text}") + logger.debug(f"[gui_accuracy_score] gt_action: {gt_action}, pred_action: {pred_action}") + logger.debug(f"[gui_accuracy_score] gt_bbox: {gt_bbox}, pred_coord: {pred_coord}, has_coord: {has_coord}") + logger.debug(f"[gui_accuracy_score] gt_input_text: {gt_input_text}, pred_input_text: {pred_input_text}") # Map all click variants to 'click' for the 3-action space action_mapping = { @@ -185,9 +188,9 @@ def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input # 1. Action type matching (0.5 points) if pred_action_normalized == gt_action_normalized: score += 0.5 - print(f"[gui_accuracy_score] Action matched: +0.5 points") + logger.debug(f"[gui_accuracy_score] Action matched: +0.5 points") else: - print(f"[gui_accuracy_score] Action mismatch: {pred_action_normalized} vs {gt_action_normalized}") + logger.debug(f"[gui_accuracy_score] Action mismatch: {pred_action_normalized} vs {gt_action_normalized}") # 2. 
Parameter matching (0.5 points) - depends on action type if gt_action_normalized == 'click': @@ -202,7 +205,7 @@ def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input gt_x = (gt_bbox[0] + gt_bbox[2]) / 2 gt_y = (gt_bbox[1] + gt_bbox[3]) / 2 else: - print(f"[gui_accuracy_score] Invalid gt_bbox format: {gt_bbox}") + logger.debug(f"[gui_accuracy_score] Invalid gt_bbox format: {gt_bbox}") return score # Get predicted center (already normalized 0-1) @@ -212,7 +215,7 @@ def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input pred_x = (pred_coord[0] + pred_coord[2]) / 2 pred_y = (pred_coord[1] + pred_coord[3]) / 2 else: - print(f"[gui_accuracy_score] Invalid pred_coord format: {pred_coord}") + logger.debug(f"[gui_accuracy_score] Invalid pred_coord format: {pred_coord}") return score # Calculate distance in normalized space @@ -223,15 +226,15 @@ def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input if distance < threshold: score += 0.5 - print(f"[gui_accuracy_score] Bbox matched (distance={distance:.4f}): +0.5 points") + logger.debug(f"[gui_accuracy_score] Bbox matched (distance={distance:.4f}): +0.5 points") else: - print(f"[gui_accuracy_score] Bbox too far (distance={distance:.4f}, threshold={threshold:.4f})") + logger.debug(f"[gui_accuracy_score] Bbox too far (distance={distance:.4f}, threshold={threshold:.4f})") else: - print(f"[gui_accuracy_score] No predicted coordinates for click action") + logger.debug(f"[gui_accuracy_score] No predicted coordinates for click action") else: # No gt_bbox required, any click gets parameter points score += 0.5 - print(f"[gui_accuracy_score] No gt_bbox required: +0.5 points") + logger.debug(f"[gui_accuracy_score] No gt_bbox required: +0.5 points") elif gt_action_normalized == 'type': # For type: check text content @@ -239,34 +242,34 @@ def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input f1, _, _ = f1_score(pred_input_text, 
gt_input_text) if f1 >= 0.5: score += 0.5 - print(f"[gui_accuracy_score] Type text matched (f1={f1:.2f}): +0.5 points") + logger.debug(f"[gui_accuracy_score] Type text matched (f1={f1:.2f}): +0.5 points") else: - print(f"[gui_accuracy_score] Type text mismatch (f1={f1:.2f})") + logger.debug(f"[gui_accuracy_score] Type text mismatch (f1={f1:.2f})") else: # No text required, any type action gets parameter points score += 0.5 - print(f"[gui_accuracy_score] No text required: +0.5 points") + logger.debug(f"[gui_accuracy_score] No text required: +0.5 points") elif gt_action_normalized == 'scroll': # For scroll: only check direction (no bbox needed) if gt_input_text and gt_input_text != "no input text": if pred_input_text.lower() == gt_input_text.lower(): score += 0.5 - print(f"[gui_accuracy_score] Scroll direction matched: +0.5 points") + logger.debug(f"[gui_accuracy_score] Scroll direction matched: +0.5 points") else: - print(f"[gui_accuracy_score] Scroll direction mismatch: {pred_input_text} vs {gt_input_text}") + logger.debug(f"[gui_accuracy_score] Scroll direction mismatch: {pred_input_text} vs {gt_input_text}") else: # No direction specified, any scroll gets parameter points score += 0.5 - print(f"[gui_accuracy_score] No scroll direction required: +0.5 points") + logger.debug(f"[gui_accuracy_score] No scroll direction required: +0.5 points") - print(f"[gui_accuracy_score] Final score: {score}") + logger.debug(f"[gui_accuracy_score] Final score: {score}") return score except Exception as e: - print(f"Error in gui_accuracy_score: {e}") - print(f"predict_str: {predict_str}") - print(f"gt_action: {gt_action}, gt_bbox: {gt_bbox}, gt_input_text: {gt_input_text}") + logger.debug(f"Error in gui_accuracy_score: {e}") + logger.debug(f"predict_str: {predict_str}") + logger.debug(f"gt_action: {gt_action}, gt_bbox: {gt_bbox}, gt_input_text: {gt_input_text}") return 0.0 @@ -283,18 +286,18 @@ def gui_reward(prediction: str, trajectory: List[Dict] = None, gt_action: str = Returns: 
Dictionary with reward scores """ - print(f"[gui_reward] Called with prediction: {prediction[:200] if prediction else 'None'}") - print(f"[gui_reward] kwargs keys: {list(kwargs.keys())}") + logger.debug(f"[gui_reward] Called with prediction: {prediction[:200] if prediction else 'None'}") + logger.debug(f"[gui_reward] kwargs keys: {list(kwargs.keys())}") # Handle empty predictions if not prediction or prediction.strip() == "": - print(f"[gui_reward] Warning: Empty prediction received") + logger.debug(f"[gui_reward] Warning: Empty prediction received") # Check if there's a default action in trajectory if trajectory and len(trajectory) > 0: for msg in reversed(trajectory): if msg.get('role') == 'assistant' and msg.get('content'): prediction = msg['content'] - print(f"[gui_reward] Using trajectory content as prediction: {prediction[:100]}") + logger.debug(f"[gui_reward] Using trajectory content as prediction: {prediction[:100]}") break # if not prediction or prediction.strip() == "": @@ -309,7 +312,7 @@ def gui_reward(prediction: str, trajectory: List[Dict] = None, gt_action: str = if hasattr(gt_bbox, 'tolist'): gt_bbox = gt_bbox.tolist() - print(f"[gui_reward] gt_action: {gt_action}, gt_bbox: {gt_bbox}, gt_input_text: {gt_input_text}") + logger.debug(f"[gui_reward] gt_action: {gt_action}, gt_bbox: {gt_bbox}, gt_input_text: {gt_input_text}") # Handle "no input text" as empty if gt_input_text == "no input text": @@ -319,7 +322,7 @@ def gui_reward(prediction: str, trajectory: List[Dict] = None, gt_action: str = # Both prediction and ground truth use normalized coordinates if not gt_action and not gt_bbox and not gt_input_text: - print(f"[gui_reward] Warning: No ground truth data provided - returning 0 reward") + logger.debug(f"[gui_reward] Warning: No ground truth data provided - returning 0 reward") return { "reward": 0.0, "format": gui_format_score(prediction), @@ -333,7 +336,7 @@ def gui_reward(prediction: str, trajectory: List[Dict] = None, gt_action: str = 
format_score = gui_format_score(prediction) accuracy_score = gui_accuracy_score(prediction, gt_action, gt_bbox, gt_input_text) - print(f"[gui_reward] format_score: {format_score}, accuracy_score: {accuracy_score}") + logger.debug(f"[gui_reward] format_score: {format_score}, accuracy_score: {accuracy_score}") # For f1_score, create answer string for backward compatibility answer_dict = { diff --git a/agents/agents/rewards/qa_reward.py b/agents/agents/rewards/qa_reward.py index a6dd139..a78c99d 100644 --- a/agents/agents/rewards/qa_reward.py +++ b/agents/agents/rewards/qa_reward.py @@ -123,6 +123,12 @@ def ok_vqa_reward(prediction: str, answers: List[str], trajectory: List[str]) -> @reward(name="infoseek_reward") def infoseek_reward(prediction: str, answer: Union[str, List[str]], answer_eval: List[str | Dict], trajectory: List[str]) -> float: + # format reward + call_tool_count = 0 + for msg in trajectory: + if msg["role"] == "tool": + call_tool_count += 1 + f1_scores = [] answers = [] if isinstance(answer, str): @@ -137,5 +143,10 @@ def infoseek_reward(prediction: str, answer: Union[str, List[str]], answer_eval: f1, precision, recall = f1_score(prediction, _answer) f1_scores.append(f1) - # All answers are the correct answer, take the max f1 score - return max(f1_scores) + max_f1_score = max(f1_scores) + + call_tool_reward = 1.0 if call_tool_count > 1 else 0.0 + + reward = 0.2 * call_tool_reward + 0.8 * max_f1_score + + return reward diff --git a/agents/agents/tools/tool_base.py b/agents/agents/tools/tool_base.py index 8cd3095..e2ad350 100644 --- a/agents/agents/tools/tool_base.py +++ b/agents/agents/tools/tool_base.py @@ -307,11 +307,19 @@ def factory(): return decorator -async def submit_tool_call(tool_name: str, tool_input: str, id: str=None) -> dict: +async def submit_tool_call( + tool_name: str, + tool_input: str, + id: str=None, + allowed_tool_names: List[str] = None, +) -> dict: """ Submit a tool call to the environment. 
""" - if tool_name not in TOOL_REGISTRY: + if allowed_tool_names is None: + allowed_tool_names = list(TOOL_REGISTRY.keys()) + + if tool_name not in allowed_tool_names: tool_name = "hallucination_tool" tool_input = {"tool_name": str(tool_name)} @@ -332,6 +340,7 @@ async def submit_tool_call(tool_name: str, tool_input: str, id: str=None) -> dic # If the loaded input is not a dict, it means the input is not a valid JSON object if not isinstance(tool_input_json, dict): tool_input_json = None + elif isinstance(tool_input, dict): tool_input_json = tool_input else: @@ -351,14 +360,25 @@ async def submit_tool_call(tool_name: str, tool_input: str, id: str=None) -> dic return await tool_obj(**tool_input_json) -def submit_tool_calls(tool_names: List[str], tool_inputs: List[Dict | str], ids: List[str]) -> List[dict]: +def submit_tool_calls( + tool_names: List[str], + tool_inputs: List[Dict | str], + ids: List[str], + allowed_tool_names: List[str] = None, +) -> List[dict]: """ Submit tool calls to the environment. This is a synchronous wrapper that blocks until all results are ready. Uses ThreadPoolExecutor to run tool calls in parallel. 
""" + + if allowed_tool_names is None: + allowed_tool_names = list(TOOL_REGISTRY.keys()) + + mapped_tool_names = [] mapped_tool_inputs = [] tool_objs = [] + for tool_name, tool_input, id in zip(tool_names, tool_inputs, ids): if isinstance(tool_input, dict): tool_input_json = tool_input @@ -371,10 +391,12 @@ def submit_tool_calls(tool_names: List[str], tool_inputs: List[Dict | str], ids: raise ValueError(f"Invalid tool input: {tool_input}") - if tool_name not in TOOL_REGISTRY: + if tool_name not in allowed_tool_names: + # Called a non-existent tool mapped_tool_name = "hallucination_tool" tool_input_json = {"tool_name": tool_name} elif tool_input_json is None: + # Invalid input mapped_tool_name = "invalid_input_tool" tool_input_json = {"tool_input": tool_input} else: @@ -408,13 +430,11 @@ def submit_tool_calls(tool_names: List[str], tool_inputs: List[Dict | str], ids: @tool() def hallucination_tool(tool_name): - return f"Hallucinated tool: {tool_name}" + return f"Hallucinated tool: {tool_name} does not exist." @tool() def invalid_input_tool(tool_input): - return f"Invalid input: {tool_input}, input mush be a valid JSON object." - - + return f"Invalid input: {tool_input}, input must be a valid JSON object." 
diff --git a/agents/agents/utils/ui_action_parser.py b/agents/agents/utils/ui_action_parser.py index e992fce..0bc55a5 100644 --- a/agents/agents/utils/ui_action_parser.py +++ b/agents/agents/utils/ui_action_parser.py @@ -4,6 +4,9 @@ import ast import math from typing import List, Dict, Any, Optional, Tuple +import logging + +logger = logging.getLogger(__name__) IMAGE_FACTOR = 1 # Changed to match gui_reward.py MIN_PIXELS = 100 * 28 * 28 @@ -69,7 +72,7 @@ def parse_action(action_str: str) -> Optional[Dict[str, Any]]: return {'function': func_name, 'args': kwargs} except Exception as e: - print(f"Failed to parse action '{action_str}': {e}") + logger.debug(f"Failed to parse action '{action_str}': {e}") return None @@ -145,11 +148,11 @@ def parse_action_to_structure_output(text: str, max_pixels: int = 16384 * 28 * 28, min_pixels: int = 100 * 28 * 28) -> Optional[List[Dict[str, Any]]]: """Parse action text to structured output.""" - print(f"[parse_action_to_structure_output] Input text: {text[:500] if text else 'Empty text'}") + logger.debug(f"[parse_action_to_structure_output] Input text: {text[:500] if text else 'Empty text'}") # Handle empty or None responses if not text: - print(f"[parse_action_to_structure_output] Empty text, returning None") + logger.debug(f"[parse_action_to_structure_output] Empty text, returning None") return None text = text.strip() @@ -197,11 +200,11 @@ def parse_action_to_structure_output(text: str, reflection = thought_match.group(1).strip() if "Action:" not in text: - print(f"[parse_action_to_structure_output] No 'Action:' found in text, returning None") + logger.debug(f"[parse_action_to_structure_output] No 'Action:' found in text, returning None") return None action_str = text.split("Action: ")[-1] - print(f"[parse_action_to_structure_output] Extracted action string: {action_str[:200]}") + logger.debug(f"[parse_action_to_structure_output] Extracted action string: {action_str[:200]}") # Parse multiple actions tmp_all_action = 
action_str.split(")\n\n") @@ -228,7 +231,7 @@ def parse_action_to_structure_output(text: str, actions = [] for action_instance, raw_str in zip(parsed_actions, all_action): if action_instance is None: - print(f"Action can't parse: {raw_str}") + logger.debug(f"Action can't parse: {raw_str}") continue action_type = action_instance["function"] @@ -256,7 +259,7 @@ def parse_action_to_structure_output(text: str, for num in numbers: float(num.strip()) except ValueError: - print(f"Warning: Invalid coordinate format in '{param_name}': '{ori_box}'") + logger.debug(f"Warning: Invalid coordinate format in '{param_name}': '{ori_box}'") return None # Convert coordinates based on model type diff --git a/agents/requirements.txt b/agents/requirements.txt index 559e048..8a857b0 100644 --- a/agents/requirements.txt +++ b/agents/requirements.txt @@ -6,7 +6,7 @@ redis docker openai faiss-cpu -vllm==0.9.2 +vllm==0.10.0 termcolor tenacity nest-asyncio diff --git a/agents/tests/unit/agents/templates/test_vision_templates_full_align.py b/agents/tests/unit/agents/templates/test_vision_templates_full_align.py index 167307d..a0b8d8e 100644 --- a/agents/tests/unit/agents/templates/test_vision_templates_full_align.py +++ b/agents/tests/unit/agents/templates/test_vision_templates_full_align.py @@ -60,6 +60,25 @@ {"role": "assistant", "content": "I am fine, thank you."}, {"role": "user", "content": "What is 3 times 5?"}, ], + [ + { + "role": "user", + "content": [ + {"type": "image", "image": "https://images.unsplash.com/photo-1592194996308-7b43878e84a6"}, + {"type": "text", "text": "Describe these images."}, + {"type": "image", "image": "https://images.unsplash.com/photo-1599158164704-ef1ec0c94b1c"} + ] + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "The image is a cat.", + }, + ], + }, + ] ]) @pytest.mark.parametrize("tools", [ None, diff --git a/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py 
b/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py index 31b5ecd..6e17ef8 100644 --- a/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py +++ b/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py @@ -63,6 +63,25 @@ # {"role": "assistant", "content": "I am fine, thank you."}, # {"role": "user", "content": "What is 3 times 5?"}, # ], + [ + { + "role": "user", + "content": [ + {"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"}, + {"type": "text", "text": "Describe these images."}, + {"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"} + ] + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "The image is a cat.", + }, + ], + }, + ] ]) @pytest.mark.parametrize("tools", [ None, @@ -94,7 +113,6 @@ def test_chat_template_equal(template, messages, tools, add_generation_prompt): official_prompt = tokenizer.decode(official_inputs['input_ids'][0]) implemented_prompt = tokenizer.decode(implemented_inputs['input_ids'][0]) print(f"Official prompt image tokens: {official_prompt.count('<|image_pad|>')}\nImplemented prompt image tokens: {implemented_prompt.count('<|image_pad|>')}") - print(f"Official images: {official_inputs['pixel_values'].shape}\nImplemented images: {implemented_inputs['pixel_values'].shape}") assert torch.equal(official_inputs["input_ids"], implemented_inputs["input_ids"]), f"""Offical diff --git a/agents/tests/unit/agents/test_vision_agent.py b/agents/tests/unit/agents/test_vision_agent.py index eedc521..67bf4d6 100644 --- a/agents/tests/unit/agents/test_vision_agent.py +++ b/agents/tests/unit/agents/test_vision_agent.py @@ -1,3 +1,4 @@ +import torch from agents.agents.react.react_agent import ReactAgent from agents.tools import answer_qa import pytest @@ -45,5 +46,12 @@ async def test_vision_agent(): messages = messages_list[0]['messages'] for message in messages: 
print(f"{message['role']}: {message['content']}") - inputs = react_agent.tokenize_trajectories() - print(inputs) \ No newline at end of file + inputs, other_info_list = react_agent.tokenize_trajectories() + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + print(f"{key}: {value.shape}") + else: + print(f"{key}: {value}") + other_info = other_info_list[0] + for key, value in other_info.items(): + print(f"{key}: {value}") \ No newline at end of file diff --git a/install.sh b/install.sh new file mode 100644 index 0000000..b014117 --- /dev/null +++ b/install.sh @@ -0,0 +1,429 @@ +#!/bin/bash + +# AgentFly Installation Script +# This script handles the complete installation of AgentFly and its dependencies + +set -e # Exit on any error + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Function to print colored output +print_status() { + echo -e "${BLUE}[INFO]${NC} $1" +} + +print_success() { + echo -e "${GREEN}[SUCCESS]${NC} $1" +} + +print_warning() { + echo -e "${YELLOW}[WARNING]${NC} $1" +} + +print_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +# Function to check if command exists +command_exists() { + command -v "$1" >/dev/null 2>&1 +} + +# Function to check if user has sudo access +check_sudo() { + if sudo -n true 2>/dev/null; then + return 0 + else + return 1 + fi +} + +# Function to install enroot +install_enroot() { + print_status "Installing enroot..." + + # Check if we're on a supported system + if [[ "$OSTYPE" == "linux-gnu"* ]]; then + # Ubuntu/Debian + if command_exists apt-get; then + print_status "Detected Ubuntu/Debian system, installing enroot via deb packages..." + + # Get architecture + arch=$(dpkg --print-architecture) + if [ $? 
-eq 0 ]; then + print_status "Detected architecture: $arch" + INSTALLATION_STATUS+=("architecture detection: SUCCESS") + else + print_error "Failed to detect architecture" + INSTALLATION_STATUS+=("architecture detection: FAILED") + return 1 + fi + + # Download enroot packages + print_status "Downloading enroot packages..." + curl -fSsL -O "https://github.com/NVIDIA/enroot/releases/download/v3.5.0/enroot-hardened_3.5.0-1_${arch}.deb" + if [ $? -eq 0 ]; then + print_success "Downloaded enroot-hardened package" + INSTALLATION_STATUS+=("enroot-hardened download: SUCCESS") + else + print_error "Failed to download enroot-hardened package" + INSTALLATION_STATUS+=("enroot-hardened download: FAILED") + return 1 + fi + + curl -fSsL -O "https://github.com/NVIDIA/enroot/releases/download/v3.5.0/enroot-hardened+caps_3.5.0-1_${arch}.deb" + if [ $? -eq 0 ]; then + print_success "Downloaded enroot-hardened+caps package" + INSTALLATION_STATUS+=("enroot-hardened+caps download: SUCCESS") + else + print_error "Failed to download enroot-hardened+caps package" + INSTALLATION_STATUS+=("enroot-hardened+caps download: FAILED") + return 1 + fi + + # Install packages + print_status "Installing enroot packages..." + sudo apt install -y ./*.deb + if [ $? -eq 0 ]; then + print_success "enroot packages installed successfully!" + INSTALLATION_STATUS+=("enroot package installation: SUCCESS") + else + print_error "Failed to install enroot packages" + INSTALLATION_STATUS+=("enroot package installation: FAILED") + return 1 + fi + + # Clean up downloaded packages + rm -f ./*.deb + print_status "Cleaned up downloaded packages" + INSTALLATION_STATUS+=("package cleanup: SUCCESS") + + else + print_warning "Unsupported package manager. Please install enroot manually from: https://github.com/NVIDIA/enroot/blob/master/doc/installation.md" + return 1 + fi + else + print_warning "Unsupported operating system. 
Please install enroot manually from: https://github.com/NVIDIA/enroot/blob/master/doc/installation.md" + return 1 + fi + + if command_exists enroot; then + print_success "enroot installed successfully!" + return 0 + else + print_error "Failed to install enroot. Please install manually." + return 1 + fi +} + +# Function to check conda and create agentfly environment +setup_conda_environment() { + if ! command_exists conda; then + print_error "conda not found. Please install conda first." + exit 1 + fi + + print_success "conda found" + + # Check if agentfly environment already exists + if conda env list | grep -q "agentfly"; then + print_status "agentfly environment already exists, activating it..." + conda activate agentfly + if [ $? -eq 0 ]; then + print_success "agentfly environment activated" + INSTALLATION_STATUS+=("conda environment activation: SUCCESS") + else + print_error "Failed to activate existing agentfly environment" + INSTALLATION_STATUS+=("conda environment activation: FAILED") + return 1 + fi + else + print_status "Creating new conda environment 'agentfly' with Python 3.10..." + conda create -n agentfly python=3.10 -y + if [ $? -eq 0 ]; then + print_success "agentfly environment created successfully!" + INSTALLATION_STATUS+=("conda environment creation: SUCCESS") + else + print_error "Failed to create agentfly environment" + INSTALLATION_STATUS+=("conda environment creation: FAILED") + return 1 + fi + + print_status "Activating agentfly environment..." + conda activate agentfly + if [ $? -eq 0 ]; then + print_success "agentfly environment activated" + INSTALLATION_STATUS+=("conda environment activation: SUCCESS") + else + print_error "Failed to activate agentfly environment" + INSTALLATION_STATUS+=("conda environment activation: FAILED") + return 1 + fi + fi +} + +# Function to install redis-server via conda +install_redis() { + print_status "Installing redis-server via conda..." 
+ + # Ensure conda is in PATH + if command_exists conda; then + conda install -y conda-forge::redis-server==7.4.0 + if [ $? -eq 0 ]; then + print_success "redis-server installed successfully!" + return 0 + else + print_error "Failed to install redis-server" + return 1 + fi + else + print_error "conda not found. Please install conda first or install redis-server manually." + return 1 + fi +} + +# Main installation function +main() { + echo "==========================================" + echo " AgentFly Installation Script" + echo "==========================================" + echo "" + + # Check Python version (will be checked again after conda environment setup) + print_status "Checking Python version..." + if command_exists python3; then + PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}') + print_status "Found Python version: $PYTHON_VERSION" + INSTALLATION_STATUS+=("Python availability: SUCCESS") + else + print_error "Python 3 not found. Please install Python 3.10.x first." + INSTALLATION_STATUS+=("Python availability: FAILED") + exit 1 + fi + + # Check pip + print_status "Checking pip..." + if command_exists pip3; then + print_success "pip3 found" + INSTALLATION_STATUS+=("pip availability: SUCCESS") + elif command_exists pip; then + print_success "pip found" + INSTALLATION_STATUS+=("pip availability: SUCCESS") + else + print_error "pip not found. Please install pip first." + INSTALLATION_STATUS+=("pip availability: FAILED") + exit 1 + fi + + # Check git + print_status "Checking git..." + if command_exists git; then + print_success "git found" + INSTALLATION_STATUS+=("git availability: SUCCESS") + else + print_error "git not found. Please install git first." + INSTALLATION_STATUS+=("git availability: FAILED") + exit 1 + fi + + # Initialize git submodules + print_status "Initializing git submodules..." + if [ -d ".git" ]; then + git submodule init + if [ $? -eq 0 ]; then + git submodule update + if [ $? 
-eq 0 ]; then + print_success "Git submodules initialized successfully!" + INSTALLATION_STATUS+=("Git submodules: SUCCESS") + else + print_error "Failed to update git submodules" + INSTALLATION_STATUS+=("Git submodules: FAILED") + fi + else + print_error "Failed to init git submodules" + INSTALLATION_STATUS+=("Git submodules: FAILED") + fi + else + print_warning "Not in a git repository. Skipping submodule initialization." + INSTALLATION_STATUS+=("Git submodules: SKIPPED (not git repo)") + fi + + # Install Python dependencies + print_status "Installing basic Python dependencies..." + pip install -e . > /dev/null + if [ $? -eq 0 ]; then + print_success "Basic dependencies installed successfully!" + INSTALLATION_STATUS+=("Basic Python dependencies: SUCCESS") + else + print_error "Failed to install basic dependencies" + INSTALLATION_STATUS+=("Basic Python dependencies: FAILED") + fi + + print_status "Installing VERL dependencies..." + pip install -e '.[verl]' --no-build-isolation > /dev/null + if [ $? -eq 0 ]; then + print_success "VERL dependencies installed successfully!" + INSTALLATION_STATUS+=("VERL dependencies: SUCCESS") + else + print_error "Failed to install VERL dependencies" + INSTALLATION_STATUS+=("VERL dependencies: FAILED") + fi + + # Check and install enroot if needed + print_status "Checking enroot installation..." + if command_exists enroot; then + print_success "enroot is already installed" + else + print_warning "enroot not found. Some tools require it for container management." + + if check_sudo; then + print_status "Sudo access detected. Attempting to install enroot..." + INSTALLATION_STATUS+=("sudo access: SUCCESS") + if install_enroot; then + print_success "enroot installation completed!" + INSTALLATION_STATUS+=("enroot installation: SUCCESS") + else + print_warning "enroot installation failed. Some tools may not work properly." + INSTALLATION_STATUS+=("enroot installation: FAILED") + fi + else + print_warning "No sudo access. 
Please install enroot manually from: https://github.com/NVIDIA/enroot/blob/master/doc/installation.md" + INSTALLATION_STATUS+=("sudo access: FAILED") + INSTALLATION_STATUS+=("enroot installation: SKIPPED (no sudo)") + fi + fi + + # Check conda availability + print_status "Checking conda availability..." + if command_exists conda; then + print_success "conda found" + INSTALLATION_STATUS+=("conda availability: SUCCESS") + else + print_error "conda not found. Please install conda first." + INSTALLATION_STATUS+=("conda availability: FAILED") + exit 1 + fi + + # Install redis-server (assuming we're already in a conda environment) + print_status "Installing redis-server via conda..." + if install_redis; then + INSTALLATION_STATUS+=("redis-server installation: SUCCESS") + else + INSTALLATION_STATUS+=("redis-server installation: FAILED") + fi + + # Final checks and summary + echo "" + echo "==========================================" + echo " Installation Summary" + echo "==========================================" + + print_status "Checking installed components..." + + if command_exists python3; then + PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}') + if [[ "$PYTHON_VERSION" =~ ^3\.10\. 
]]; then + print_success "βœ“ Python 3.10.x ($PYTHON_VERSION)" + INSTALLATION_STATUS+=("Python 3.10.x verification: SUCCESS") + else + print_error "βœ— Python version $PYTHON_VERSION does not meet requirements (need 3.10.x)" + INSTALLATION_STATUS+=("Python 3.10.x verification: FAILED") + fi + else + print_error "βœ— Python 3 not found" + INSTALLATION_STATUS+=("Python 3.10.x verification: FAILED") + fi + + if [ -d "AgentFly.egg-info" ] || [ -d "agents" ]; then + print_success "βœ“ AgentFly package" + INSTALLATION_STATUS+=("AgentFly package verification: SUCCESS") + else + print_error "βœ— AgentFly package not found" + INSTALLATION_STATUS+=("AgentFly package verification: FAILED") + fi + + if command_exists enroot; then + print_success "βœ“ enroot" + INSTALLATION_STATUS+=("enroot verification: SUCCESS") + else + print_warning "βœ— enroot (not installed - some tools may not work)" + INSTALLATION_STATUS+=("enroot verification: FAILED") + fi + + if command_exists conda; then + print_success "βœ“ conda" + INSTALLATION_STATUS+=("conda verification: SUCCESS") + else + print_error "βœ— conda not found" + INSTALLATION_STATUS+=("conda verification: FAILED") + fi + + # Skip agentfly environment check - assuming we're already in a conda environment + print_success "βœ“ conda environment (assuming active)" + INSTALLATION_STATUS+=("conda environment verification: SKIPPED (assumed active)") + + if command_exists redis-server; then + print_success "βœ“ redis-server" + INSTALLATION_STATUS+=("redis-server verification: SUCCESS") + else + print_error "βœ— redis-server not found" + INSTALLATION_STATUS+=("redis-server verification: FAILED") + fi + + echo "" + echo "==========================================" + echo " Step-by-Step Status Report" + echo "==========================================" + + # Count successes, failures, and skips + SUCCESS_COUNT=0 + FAILED_COUNT=0 + SKIPPED_COUNT=0 + + for status in "${INSTALLATION_STATUS[@]}"; do + if [[ $status == *"SUCCESS"* ]]; then + echo 
-e "${GREEN}βœ“${NC} $status" + ((SUCCESS_COUNT++)) + elif [[ $status == *"FAILED"* ]]; then + echo -e "${RED}βœ—${NC} $status" + ((FAILED_COUNT++)) + else + echo " $status" + ((SKIPPED_COUNT++)) + fi + done + + echo "" + echo "==========================================" + echo " Summary Statistics" + echo "==========================================" + echo -e "${GREEN}Successful steps: $SUCCESS_COUNT${NC}" + echo -e "${RED}Failed steps: $FAILED_COUNT${NC}" + if [ $SKIPPED_COUNT -gt 0 ]; then + echo -e "${YELLOW}Skipped steps: $SKIPPED_COUNT${NC}" + fi + + echo "" + if [ $FAILED_COUNT -eq 0 ]; then + print_success "AgentFly installation completed successfully!" + elif [ $FAILED_COUNT -le 2 ]; then + print_warning "AgentFly installation completed with minor issues. Some features may not work properly." + else + print_error "AgentFly installation completed with significant issues. Please review the failed steps above." + fi + + echo "" + print_status "Next steps:" + echo " 1. If you just installed enroot, you may need to restart your terminal" + echo " 2. Check the documentation at: https://agentfly.readthedocs.io/" + echo " 3. Try running an example: cd verl && bash examples/run_agents/run_code_agent.sh" + echo "" +} + +# Run main function +main "$@" diff --git a/pyproject.toml b/pyproject.toml index 92d6504..068ddf6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "docker", "openai", "faiss-cpu", - "vllm==0.9.2", + "vllm==0.10.0", "termcolor", "tenacity", "bs4",