diff --git a/.gitignore b/.gitignore
index 5592ce9..d2646d5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,5 +132,9 @@ outputs
# Notebooks
agents/tests/*.ipynb
+agents/tests/*.jpg
+agents/tests/*.jpeg
+agents/tests/*.png
agents/agents/*.ipynb
+agents/temp/
diff --git a/README.md b/README.md
index a1bf150..3c88df4 100644
--- a/README.md
+++ b/README.md
@@ -1,26 +1,53 @@
-# AgentFly: Scalable and Extensible Reinforcement Learning for LLM Agents
+# πͺ½AgentFly: Training scalable LLM agents with RL (multi-turn, async tools/rewards, multimodal)
+
+
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
AgentFly is an extensible framework for building LLM agents with reinforcement learning. It supports multi-turn training by adapting traditional RL methods with token-level masking. It features a decorator-based interface for defining tools and reward functions, enabling seamless extension and ease of use. To support high-throughput training, it implemented asynchronous execution of tool calls and reward computations, and design a centralized resource management system for scalable environment coordination. A suite of prebuilt tools and environments are provided.

-## π News
+## News
-**Multi-Modal (Vision) Agent Training Support** - Thanks to the powerful template system, AgentFly now supports training vision-language agents! π
-
-Train agents that can see and understand visual content, including GUI automation and image-based QA. See our [predefined training examples](docs/examples/predefined_training_examples.md) for ready-to-use scripts.
+**08/2025 Multi-Modal (Vision) Agent Training Support** - Thanks to the powerful template system, AgentFly now supports training vision-language agents! π Train agents that can see and understand visual content, including GUI automation and image-based QA. See our [predefined training examples](docs/examples/predefined_training_examples.md) for ready-to-use scripts.
---
-**New: Chat Template System** - A flexible framework for creating conversation templates with multi-model support, vision capabilities, and tool integration. [Learn more β](docs/chat_template/)
+**08/2025 Chat Template System** - A flexible framework for creating conversation templates with multi-model support, vision capabilities, and tool integration. [Learn more β](docs/chat_template/)
## Installation
+**Option 1**: One-line Installation:
+```
+bash install.sh # Assume conda with python3.10.x
+```
+**Option 2**: Customized Installation
+
Clone and initialize the project:
```bash
git clone https://github.com/Agent-One-Lab/AgentFly
@@ -42,6 +69,11 @@ Search requires redis to cache results, an optional way to install with conda:
conda install conda-forge::redis-server==7.4.0
```
+## Quick Start
+```
+
+```
+
## Features
### 1. Multi-Chain Agent Rollout and Multi-Turn Training
To support algorithms like GRPO, Reinforce++, we design multi-chain inference, enabling agents to solve one task with multiple paths at the same time. We build RL computation and update LLMs in multi-turn manner by applying token masks. The training is based on [verl](https://github.com/volcengine/verl).
@@ -67,6 +99,10 @@ agent = ReactAgent(
### 3. Easy Development
Decoupled agent and training module. Simply customize your own agent, which can directly be applied to training.
+## Training Curves
+Reward curves on Qwen2.5-Instruct 3B and 7B models.
+
+
## Training
### Run Example Training
@@ -155,9 +191,6 @@ https://github.com/user-attachments/assets/b8f42534-8d40-48a0-a264-f378e479bb3a
[Discord](https://discord.gg/CchUj7Sp)
-## Training Curves
-Reward curves on Qwen2.5-Instruct 3B and 7B models.
-
## Cite
If you used our code or find it helpful, please cite:
diff --git a/agents/agents/agents/agent_base.py b/agents/agents/agents/agent_base.py
index 4ac693a..a58092e 100644
--- a/agents/agents/agents/agent_base.py
+++ b/agents/agents/agents/agent_base.py
@@ -70,6 +70,7 @@ def __init__(
self.template = template
self.max_length = max_length
self.tools = tools
+ self.tool_names = [tool.name for tool in tools]
self.system_prompt = system_prompt
self.model_name_or_path = model_name_or_path
@@ -209,7 +210,7 @@ def trajectories(self):
return trajectories
- def tokenize_trajectories(self, tokenizer = None, return_reward_mask: bool = False):
+ def tokenize_trajectories(self, tokenizer = None, return_reward_mask: bool = False, concatenate_mm_inputs: bool = True):
if tokenizer is None:
tokenizer = self.tokenizer
@@ -239,7 +240,16 @@ def tokenize_trajectories(self, tokenizer = None, return_reward_mask: bool = Fal
info['last_response'] = last_response
other_info_list.append(info)
- inputs = tokenize_conversations(messages_list, tokenizer=tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask)
+ inputs = tokenize_conversations(
+ messages_list,
+ tokenizer=tokenizer,
+ template=self.template,
+ processor=self.processor,
+ max_length=self.max_length,
+ return_reward_mask=return_reward_mask,
+ add_generation_prompt=True,
+ concatenate_mm_inputs=concatenate_mm_inputs,
+ )
position_ids = torch.clip(torch.cumsum(inputs['attention_mask'], dim=-1) - 1, min=0, max=None)
inputs['position_ids'] = position_ids
@@ -318,7 +328,7 @@ def rewards(self):
def get_verl_data_proto(self):
- inputs, other_info_list = self.tokenize_trajectories(return_reward_mask=True)
+ inputs, other_info_list = self.tokenize_trajectories(return_reward_mask=True, concatenate_mm_inputs=False)
group_ids = np.array([info["group_id"] for info in other_info_list], dtype=object)
# Do evaluation here
reward_values, other_values = self.rewards
@@ -329,6 +339,10 @@ def get_verl_data_proto(self):
inputs[f"rm_{key}"] = np.array(values)
# We handle the group id in the agent side, to be compatible with GRPO
inputs["uid"] = group_ids
+
+ if "mm_inputs" in inputs:
+ mm_inputs = inputs.pop("mm_inputs")
+ inputs["multi_modal_inputs"] = np.array(mm_inputs, dtype=object)
batch = DataProto.from_single_dict(inputs, meta_info={"use_agent": True})
return batch
diff --git a/agents/agents/agents/chain/chain_base.py b/agents/agents/agents/chain/chain_base.py
index faf8f71..3fcf12d 100644
--- a/agents/agents/agents/chain/chain_base.py
+++ b/agents/agents/agents/chain/chain_base.py
@@ -474,7 +474,12 @@ async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id,
have_set_tools = True
# Execute tool call
- result = await submit_tool_call(tool_name, tool_input, id=chain_id)
+ result = await submit_tool_call(
+ tool_name,
+ tool_input,
+ id=chain_id,
+ allowed_tool_names=self.tool_names
+ )
if enable_streaming:
# Emit tool observation event
diff --git a/agents/agents/agents/specialized/gui_agent.py b/agents/agents/agents/specialized/gui_agent.py
index e9dd812..9436819 100644
--- a/agents/agents/agents/specialized/gui_agent.py
+++ b/agents/agents/agents/specialized/gui_agent.py
@@ -2,10 +2,13 @@
# SPDX-License-Identifier: Apache-2.0
import json
+import logging
from typing import List, Any, Tuple, Dict, Optional
from ..agent_base import BaseAgent
from agents.utils.ui_action_parser import parse_action_to_structure_output, IMAGE_FACTOR
+logger = logging.getLogger(__name__)
+
# Default image dimensions
TEST_IMAGE_HEIGHT = 1080
TEST_IMAGE_WIDTH = 1920
@@ -96,8 +99,8 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
Returns:
List of structured messages with tool calls
"""
- print(f"[GUIAgent.parse] Number of responses: {len(responses)}")
- print(f"[GUIAgent.parse] Raw responses type: {type(responses)}")
+ logger.debug(f"[GUIAgent.parse] Number of responses: {len(responses)}")
+ logger.debug(f"[GUIAgent.parse] Raw responses type: {type(responses)}")
new_messages_list = []
@@ -109,7 +112,7 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
elif resp and resp.strip():
# Try to reformat responses that don't have the expected format
resp_lower = resp.lower()
- print(f"[GUIAgent.parse] Response missing format, reformatting: {resp[:100]}")
+ logger.debug(f"[GUIAgent.parse] Response missing format, reformatting: {resp[:100]}")
# Check if it contains action-like content
if any(action in resp_lower for action in ['click', 'type', 'scroll']):
@@ -128,9 +131,9 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
# Log responses for debugging
for idx, resp in enumerate(responses[:3]): # Log first 3 responses
if resp:
- print(f"[GUIAgent.parse] Response {idx} length: {len(resp)}, preview: {resp[:200]}")
+ logger.debug(f"[GUIAgent.parse] Response {idx} length: {len(resp)}, preview: {resp[:200]}")
else:
- print(f"[GUIAgent.parse] Response {idx} is None or empty")
+ logger.debug(f"[GUIAgent.parse] Response {idx} is None or empty")
# Parse actions from responses
action_list = []
@@ -145,13 +148,13 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
# Create messages with tool calls
for i, (response, actions) in enumerate(zip(responses, action_list)):
- print(f"[GUIAgent.parse] Processing response {i+1}: response_length={len(response) if response else 0}, actions={actions}")
+ logger.debug(f"[GUIAgent.parse] Processing response {i+1}: response_length={len(response) if response else 0}, actions={actions}")
tool_calls = []
if actions is not None and len(actions) > 0:
if len(actions) > 1:
- print(f"[GUIAgent.parse] Warning: Multiple actions found ({len(actions)}), using first one")
+ logger.debug(f"[GUIAgent.parse] Warning: Multiple actions found ({len(actions)}), using first one")
action = actions[0]
tool_calls = [{
"id": str(i),
@@ -163,7 +166,7 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
}]
else:
# If no action was parsed, create a default click action at center
- print(f"[GUIAgent.parse] No action parsed from response, creating default click action")
+ logger.debug(f"[GUIAgent.parse] No action parsed from response, creating default click action")
default_action = {
"action_type": "click",
"action_inputs": {"start_box": "(960, 540)"},
@@ -184,7 +187,7 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
status = "terminal"
if actions and isinstance(actions[0], dict):
action_type = actions[0].get("action_type", "")
- print(f"[GUIAgent.parse] Action type: {action_type}, terminating after one turn")
+ logger.debug(f"[GUIAgent.parse] Action type: {action_type}, terminating after one turn")
message = {
"role": "assistant",
@@ -193,7 +196,7 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
"loss": True,
"status": status
}
- print(f"[GUIAgent.parse] Created message with status={status}, tool_calls={len(tool_calls)}, content_length={len(response)}")
+ logger.debug(f"[GUIAgent.parse] Created message with status={status}, tool_calls={len(tool_calls)}, content_length={len(response)}")
new_messages_list.append(message)
return new_messages_list
diff --git a/agents/agents/agents/templates/templates.py b/agents/agents/agents/templates/templates.py
index 7697f87..6b8d3ea 100644
--- a/agents/agents/agents/templates/templates.py
+++ b/agents/agents/agents/templates/templates.py
@@ -345,6 +345,10 @@ def _encode_user_message_with_tools(self, content, tools: str) -> str:
for item in content:
if item["type"] == "text":
text += item["text"]
+ elif item["type"] in ["image", "image_url"]:
+ text += self.vision_start + self.image_token + self.vision_end
+ elif item["type"] == "video":
+ text += self.vision_start + self.video_token + self.vision_end
else:
raise ValueError(f"Invalid message type: {item['type']}")
@@ -949,11 +953,11 @@ def _convert_single_message_to_hf_format(self, message: Dict) -> Dict:
# Not sure what to do with other types of content
pass
- def convert_to_hf_format_messages(self, messages: List[Dict]) -> List[Dict]:
+ def convert_to_hf_format_messages(self, messages: Union[List[Dict], Dict[str, List[Dict]]]) -> List[Dict]:
+ hf_messages = []
if messages is None:
return None
role_label, content_label = self._detect_labels(messages)
- hf_messages = []
for message in messages:
hf_messages.append({"role": message[role_label], "content": message[content_label]})
diff --git a/agents/agents/agents/templates/utils.py b/agents/agents/agents/templates/utils.py
index 9e6b383..46faa0c 100644
--- a/agents/agents/agents/templates/utils.py
+++ b/agents/agents/agents/templates/utils.py
@@ -1,3 +1,4 @@
+from collections import defaultdict
import copy
from enum import Enum
import os
@@ -219,26 +220,56 @@ def process_prompt_with_vision(
)
-def tokenize_conversations(messages_list, tokenizer, conv_template, max_length, processor=None, return_tensors="pt", return_reward_mask=False):
+def tokenize_conversations(
+ messages_list,
+ tokenizer,
+ template,
+ max_length,
+ processor=None,
+ return_tensors="pt",
+ return_reward_mask=False,
+ add_generation_prompt=False,
+ padding_side="right",
+ concatenate_mm_inputs=False
+):
batch_input_ids = []
batch_attention_masks = []
batch_labels = []
batch_action_masks = []
+ batch_mm_inputs = []
# TODO: add multiprocessing
for messages in messages_list:
- inputs = tokenize_conversation(messages, tokenizer, conv_template, max_length, processor=processor)
+ inputs = tokenize_conversation(messages, tokenizer, template, max_length, processor=processor, add_generation_prompt=add_generation_prompt)
batch_input_ids.append(inputs['input_ids'].squeeze(0))
batch_attention_masks.append(inputs['attention_mask'].squeeze(0))
batch_labels.append(inputs['labels'].squeeze(0))
batch_action_masks.append(inputs['action_mask'].squeeze(0))
-
+ mm_inputs = {}
+ if "pixel_values" in inputs:
+ mm_inputs["pixel_values"] = inputs["pixel_values"]
+ else:
+ mm_inputs["pixel_values"] = None
+ if "image_grid_thw" in inputs:
+ mm_inputs["image_grid_thw"] = inputs["image_grid_thw"]
+ else:
+ mm_inputs["image_grid_thw"] = None
+
+ batch_mm_inputs.append(mm_inputs)
+
if return_tensors == "pt":
# Use pad_token_id from the tokenizer interface
pad_token_id = getattr(tokenizer, 'pad_token_id', 0)
- batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=pad_token_id)
- batch_attention_masks = torch.nn.utils.rnn.pad_sequence(batch_attention_masks, batch_first=True, padding_value=0)
- batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100)
- batch_action_masks = torch.nn.utils.rnn.pad_sequence(batch_action_masks, batch_first=True, padding_value=0)
+
+ batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=pad_token_id, padding_side=padding_side)
+ batch_attention_masks = torch.nn.utils.rnn.pad_sequence(batch_attention_masks, batch_first=True, padding_value=0, padding_side=padding_side)
+ batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100, padding_side=padding_side)
+ batch_action_masks = torch.nn.utils.rnn.pad_sequence(batch_action_masks, batch_first=True, padding_value=0, padding_side=padding_side)
+
+ # convert [{"pixel_values": tensor, "image_grid_thw": tensor}, ...] to {"key1": concat_tensor, "key2": concat_tensor, ...}
+ concatenated_mm_inputs = {}
+ if concatenate_mm_inputs:
+ for key in batch_mm_inputs[0].keys():
+ concatenated_mm_inputs[key] = torch.cat([mm_inputs[key] for mm_inputs in batch_mm_inputs if mm_inputs[key] is not None], dim=0)
inputs = dict(
input_ids=batch_input_ids,
@@ -250,6 +281,20 @@ def tokenize_conversations(messages_list, tokenizer, conv_template, max_length,
if return_reward_mask:
inputs['reward_mask'] = transform_reward_mask(batch_action_masks)
+ # Check if we need mm_inputs
+ mm_keys = list(batch_mm_inputs[0].keys())
+ return_mm_inputs = False
+ for key in mm_keys:
+ if any(mm_inputs[key] is not None for mm_inputs in batch_mm_inputs):
+ return_mm_inputs = True
+ break
+
+ if return_mm_inputs:
+ if concatenate_mm_inputs:
+ inputs.update(concatenated_mm_inputs)
+ else:
+ inputs["mm_inputs"] = batch_mm_inputs
+
return inputs
def visualize_template(template, messages=None, tools=None, **kwargs):
diff --git a/agents/agents/agents/templates/vision_processor.py b/agents/agents/agents/templates/vision_processor.py
index ca14aea..63b0062 100644
--- a/agents/agents/agents/templates/vision_processor.py
+++ b/agents/agents/agents/templates/vision_processor.py
@@ -190,12 +190,16 @@ def process_for_llm(
labels.append(-100)
action_mask.append(0)
+ images_to_process = [image for image in images]
+ videos_to_process = [video for video in videos]
# Step 2: Process each element with vision token expansion
for element, mask_flag in zip(elements, mask_flags):
# Check if element contains vision tokens
if self._contains_vision_tokens(element):
# Expand vision tokens in this element
- expanded_element = self.expand_vision_tokens(element, images, videos, processor)
+ # Number of images and videos should be equal to the total number of vision tokens in the element
+ # We check whether all images and videos are processed later.
+ expanded_element = self.expand_vision_tokens(element, images_to_process, videos_to_process, processor)
cur_input_ids = tokenizer.encode(expanded_element, add_special_tokens=False)
else:
cur_input_ids = tokenizer.encode(element, add_special_tokens=False)
@@ -210,6 +214,8 @@ def process_for_llm(
else:
labels.extend(cur_input_ids)
action_mask.extend([1] * len(cur_input_ids))
+
+ assert len(images_to_process) == len(videos_to_process) == 0, f"All images and videos should be processed, but got {len(images_to_process)} images and {len(videos_to_process)} videos left for vision template {self.config.model_type}."
# Step 3: Create base inputs
inputs = {
@@ -286,6 +292,7 @@ def _load_image_from_input(self, image_input) -> "ImageObject":
# Assume it's a file path
else:
+ print(f"Loading image from file path: {image_input}")
return Image.open(image_input)
# Handle bytes
@@ -477,40 +484,52 @@ def expand_vision_tokens(
num_image_placeholders = prompt.count(self.config.image_token)
num_video_placeholders = prompt.count(self.config.video_token)
- if len(images) != num_image_placeholders:
- raise ValueError(f"Number of images ({len(images)}) doesn't match placeholders ({num_image_placeholders})")
- if len(videos) != num_video_placeholders:
- raise ValueError(f"Number of videos ({len(videos)}) doesn't match placeholders ({num_video_placeholders})")
-
+ # if len(images) != num_image_placeholders:
+ # raise ValueError(f"Number of images ({len(images)}) doesn't match placeholders ({num_image_placeholders})")
+ # if len(videos) != num_video_placeholders:
+ # raise ValueError(f"Number of videos ({len(videos)}) doesn't match placeholders ({num_video_placeholders})")
+ images_slice = [images.pop(0) for _ in range(num_image_placeholders)]
+ videos_slice = [videos.pop(0) for _ in range(num_video_placeholders)]
# Preprocess images and videos to get individual token counts
- processed_images = self.preprocess_images(images, processor) if images else {}
- processed_videos = self.preprocess_videos(videos, processor) if videos else {}
+
+ processed_images = [self.preprocess_images([image], processor) for image in images_slice]
+ processed_videos = [self.preprocess_videos([video], processor) for video in videos_slice]
- # Expand image tokens using regex to avoid infinite loops
expanded_prompt = prompt
- if self.config.image_token in expanded_prompt:
- if processed_images and "pixel_values" in processed_images:
- # Calculate tokens for this specific image
- image_tokens = self.calculate_image_tokens(processed_images, processor)
- replacement = self.config.image_token * image_tokens
- else:
- replacement = self.config.image_token
-
- # Use regex to replace all occurrences at once
- import re
- expanded_prompt = re.sub(re.escape(self.config.image_token), replacement, expanded_prompt)
-
- # Expand video tokens using regex to avoid infinite loops
- if self.config.video_token in expanded_prompt:
- if processed_videos and "pixel_values" in processed_videos:
- # Calculate tokens for this specific video
- video_tokens = self.calculate_video_tokens(processed_videos, processor)
- replacement = self.config.video_token * video_tokens
- else:
- replacement = self.config.video_token
-
- # Use regex to replace all occurrences at once
- expanded_prompt = re.sub(re.escape(self.config.video_token), replacement, expanded_prompt)
+ if self.config.image_token in expanded_prompt and processed_images:
+ parts = expanded_prompt.split(self.config.image_token)
+ expanded_parts = [parts[0]]
+ for idx in range(len(parts) - 1):
+ if idx < len(processed_images):
+ processed_image = processed_images[idx]
+ if "pixel_values" in processed_image:
+ image_tokens = self.calculate_image_tokens(processed_image, processor)
+ replacement = self.config.image_token * image_tokens
+ else:
+ replacement = self.config.image_token
+ else:
+ replacement = self.config.image_token
+ expanded_parts.append(replacement)
+ expanded_parts.append(parts[idx+1])
+ expanded_prompt = ''.join(expanded_parts)
+
+ # Expand video tokens sequentially - each token gets replaced with its corresponding video
+ if self.config.video_token in expanded_prompt and processed_videos:
+ parts = expanded_prompt.split(self.config.video_token)
+ expanded_parts = [parts[0]]
+ for idx in range(len(parts) - 1):
+ if idx < len(processed_videos):
+ processed_video = processed_videos[idx]
+ if "pixel_values" in processed_video:
+ video_tokens = self.calculate_video_tokens(processed_video, processor)
+ replacement = self.config.video_token * video_tokens
+ else:
+ replacement = self.config.video_token
+ else:
+ replacement = self.config.video_token
+ expanded_parts.append(replacement)
+ expanded_parts.append(parts[idx+1])
+ expanded_prompt = ''.join(expanded_parts)
return expanded_prompt
diff --git a/agents/agents/rewards/gui_reward.py b/agents/agents/rewards/gui_reward.py
index ae24d7a..8a118c3 100644
--- a/agents/agents/rewards/gui_reward.py
+++ b/agents/agents/rewards/gui_reward.py
@@ -1,6 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
+import logging
import re
import json
import ast
@@ -9,6 +10,8 @@
from .reward_base import reward
from agents.utils.ui_action_parser import parse_action_to_structure_output, IMAGE_FACTOR
+logger = logging.getLogger(__name__)
+
# Image dimensions for testing
TEST_IMAGE_HEIGHT = 1080
TEST_IMAGE_WIDTH = 1920
@@ -71,7 +74,7 @@ def extract_action(content: str) -> str:
return action_type
except Exception as e:
- print(f"[extract_action] Error: {e}")
+ logger.debug(f"[extract_action] Error: {e}")
return "no action"
@@ -96,7 +99,7 @@ def extract_input_text(content: str) -> str:
return ""
except Exception as e:
- print(f"[extract_input_text] Error: {e}")
+ logger.debug(f"[extract_input_text] Error: {e}")
return ""
@@ -120,7 +123,7 @@ def extract_coord(content: str) -> Tuple[list, bool]:
coords[0] / TEST_IMAGE_WIDTH,
coords[1] / TEST_IMAGE_HEIGHT
]
- print(f"[extract_coord] Normalized point from {coords} to {normalized_coords}")
+ logger.debug(f"[extract_coord] Normalized point from {coords} to {normalized_coords}")
return normalized_coords, True
elif len(coords) == 4:
# Box format [x1, y1, x2, y2] in pixels - normalize to 0-1
@@ -130,15 +133,15 @@ def extract_coord(content: str) -> Tuple[list, bool]:
coords[2] / TEST_IMAGE_WIDTH,
coords[3] / TEST_IMAGE_HEIGHT
]
- print(f"[extract_coord] Normalized box from {coords} to {normalized_coords}")
+ logger.debug(f"[extract_coord] Normalized box from {coords} to {normalized_coords}")
return normalized_coords, True
else:
- print(f"[extract_coord] Unexpected coord format: {coords}")
+ logger.debug(f"[extract_coord] Unexpected coord format: {coords}")
except Exception as e:
- print(f"[extract_coord] Error parsing coordinates: {e}")
+ logger.debug(f"[extract_coord] Error parsing coordinates: {e}")
except Exception as e:
- print(f"[extract_coord] Error: {e}")
+ logger.debug(f"[extract_coord] Error: {e}")
return [], False
@@ -160,9 +163,9 @@ def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input
pred_coord, has_coord = extract_coord(predict_str)
pred_input_text = extract_input_text(predict_str)
- print(f"[gui_accuracy_score] gt_action: {gt_action}, pred_action: {pred_action}")
- print(f"[gui_accuracy_score] gt_bbox: {gt_bbox}, pred_coord: {pred_coord}, has_coord: {has_coord}")
- print(f"[gui_accuracy_score] gt_input_text: {gt_input_text}, pred_input_text: {pred_input_text}")
+ logger.debug(f"[gui_accuracy_score] gt_action: {gt_action}, pred_action: {pred_action}")
+ logger.debug(f"[gui_accuracy_score] gt_bbox: {gt_bbox}, pred_coord: {pred_coord}, has_coord: {has_coord}")
+ logger.debug(f"[gui_accuracy_score] gt_input_text: {gt_input_text}, pred_input_text: {pred_input_text}")
# Map all click variants to 'click' for the 3-action space
action_mapping = {
@@ -185,9 +188,9 @@ def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input
# 1. Action type matching (0.5 points)
if pred_action_normalized == gt_action_normalized:
score += 0.5
- print(f"[gui_accuracy_score] Action matched: +0.5 points")
+ logger.debug(f"[gui_accuracy_score] Action matched: +0.5 points")
else:
- print(f"[gui_accuracy_score] Action mismatch: {pred_action_normalized} vs {gt_action_normalized}")
+ logger.debug(f"[gui_accuracy_score] Action mismatch: {pred_action_normalized} vs {gt_action_normalized}")
# 2. Parameter matching (0.5 points) - depends on action type
if gt_action_normalized == 'click':
@@ -202,7 +205,7 @@ def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input
gt_x = (gt_bbox[0] + gt_bbox[2]) / 2
gt_y = (gt_bbox[1] + gt_bbox[3]) / 2
else:
- print(f"[gui_accuracy_score] Invalid gt_bbox format: {gt_bbox}")
+ logger.debug(f"[gui_accuracy_score] Invalid gt_bbox format: {gt_bbox}")
return score
# Get predicted center (already normalized 0-1)
@@ -212,7 +215,7 @@ def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input
pred_x = (pred_coord[0] + pred_coord[2]) / 2
pred_y = (pred_coord[1] + pred_coord[3]) / 2
else:
- print(f"[gui_accuracy_score] Invalid pred_coord format: {pred_coord}")
+ logger.debug(f"[gui_accuracy_score] Invalid pred_coord format: {pred_coord}")
return score
# Calculate distance in normalized space
@@ -223,15 +226,15 @@ def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input
if distance < threshold:
score += 0.5
- print(f"[gui_accuracy_score] Bbox matched (distance={distance:.4f}): +0.5 points")
+ logger.debug(f"[gui_accuracy_score] Bbox matched (distance={distance:.4f}): +0.5 points")
else:
- print(f"[gui_accuracy_score] Bbox too far (distance={distance:.4f}, threshold={threshold:.4f})")
+ logger.debug(f"[gui_accuracy_score] Bbox too far (distance={distance:.4f}, threshold={threshold:.4f})")
else:
- print(f"[gui_accuracy_score] No predicted coordinates for click action")
+ logger.debug(f"[gui_accuracy_score] No predicted coordinates for click action")
else:
# No gt_bbox required, any click gets parameter points
score += 0.5
- print(f"[gui_accuracy_score] No gt_bbox required: +0.5 points")
+ logger.debug(f"[gui_accuracy_score] No gt_bbox required: +0.5 points")
elif gt_action_normalized == 'type':
# For type: check text content
@@ -239,34 +242,34 @@ def gui_accuracy_score(predict_str: str, gt_action: str, gt_bbox: list, gt_input
f1, _, _ = f1_score(pred_input_text, gt_input_text)
if f1 >= 0.5:
score += 0.5
- print(f"[gui_accuracy_score] Type text matched (f1={f1:.2f}): +0.5 points")
+ logger.debug(f"[gui_accuracy_score] Type text matched (f1={f1:.2f}): +0.5 points")
else:
- print(f"[gui_accuracy_score] Type text mismatch (f1={f1:.2f})")
+ logger.debug(f"[gui_accuracy_score] Type text mismatch (f1={f1:.2f})")
else:
# No text required, any type action gets parameter points
score += 0.5
- print(f"[gui_accuracy_score] No text required: +0.5 points")
+ logger.debug(f"[gui_accuracy_score] No text required: +0.5 points")
elif gt_action_normalized == 'scroll':
# For scroll: only check direction (no bbox needed)
if gt_input_text and gt_input_text != "no input text":
if pred_input_text.lower() == gt_input_text.lower():
score += 0.5
- print(f"[gui_accuracy_score] Scroll direction matched: +0.5 points")
+ logger.debug(f"[gui_accuracy_score] Scroll direction matched: +0.5 points")
else:
- print(f"[gui_accuracy_score] Scroll direction mismatch: {pred_input_text} vs {gt_input_text}")
+ logger.debug(f"[gui_accuracy_score] Scroll direction mismatch: {pred_input_text} vs {gt_input_text}")
else:
# No direction specified, any scroll gets parameter points
score += 0.5
- print(f"[gui_accuracy_score] No scroll direction required: +0.5 points")
+ logger.debug(f"[gui_accuracy_score] No scroll direction required: +0.5 points")
- print(f"[gui_accuracy_score] Final score: {score}")
+ logger.debug(f"[gui_accuracy_score] Final score: {score}")
return score
except Exception as e:
- print(f"Error in gui_accuracy_score: {e}")
- print(f"predict_str: {predict_str}")
- print(f"gt_action: {gt_action}, gt_bbox: {gt_bbox}, gt_input_text: {gt_input_text}")
+ logger.debug(f"Error in gui_accuracy_score: {e}")
+ logger.debug(f"predict_str: {predict_str}")
+ logger.debug(f"gt_action: {gt_action}, gt_bbox: {gt_bbox}, gt_input_text: {gt_input_text}")
return 0.0
@@ -283,18 +286,18 @@ def gui_reward(prediction: str, trajectory: List[Dict] = None, gt_action: str =
Returns:
Dictionary with reward scores
"""
- print(f"[gui_reward] Called with prediction: {prediction[:200] if prediction else 'None'}")
- print(f"[gui_reward] kwargs keys: {list(kwargs.keys())}")
+ logger.debug(f"[gui_reward] Called with prediction: {prediction[:200] if prediction else 'None'}")
+ logger.debug(f"[gui_reward] kwargs keys: {list(kwargs.keys())}")
# Handle empty predictions
if not prediction or prediction.strip() == "":
- print(f"[gui_reward] Warning: Empty prediction received")
+ logger.debug(f"[gui_reward] Warning: Empty prediction received")
# Check if there's a default action in trajectory
if trajectory and len(trajectory) > 0:
for msg in reversed(trajectory):
if msg.get('role') == 'assistant' and msg.get('content'):
prediction = msg['content']
- print(f"[gui_reward] Using trajectory content as prediction: {prediction[:100]}")
+ logger.debug(f"[gui_reward] Using trajectory content as prediction: {prediction[:100]}")
break
# if not prediction or prediction.strip() == "":
@@ -309,7 +312,7 @@ def gui_reward(prediction: str, trajectory: List[Dict] = None, gt_action: str =
if hasattr(gt_bbox, 'tolist'):
gt_bbox = gt_bbox.tolist()
- print(f"[gui_reward] gt_action: {gt_action}, gt_bbox: {gt_bbox}, gt_input_text: {gt_input_text}")
+ logger.debug(f"[gui_reward] gt_action: {gt_action}, gt_bbox: {gt_bbox}, gt_input_text: {gt_input_text}")
# Handle "no input text" as empty
if gt_input_text == "no input text":
@@ -319,7 +322,7 @@ def gui_reward(prediction: str, trajectory: List[Dict] = None, gt_action: str =
# Both prediction and ground truth use normalized coordinates
if not gt_action and not gt_bbox and not gt_input_text:
- print(f"[gui_reward] Warning: No ground truth data provided - returning 0 reward")
+ logger.debug(f"[gui_reward] Warning: No ground truth data provided - returning 0 reward")
return {
"reward": 0.0,
"format": gui_format_score(prediction),
@@ -333,7 +336,7 @@ def gui_reward(prediction: str, trajectory: List[Dict] = None, gt_action: str =
format_score = gui_format_score(prediction)
accuracy_score = gui_accuracy_score(prediction, gt_action, gt_bbox, gt_input_text)
- print(f"[gui_reward] format_score: {format_score}, accuracy_score: {accuracy_score}")
+ logger.debug(f"[gui_reward] format_score: {format_score}, accuracy_score: {accuracy_score}")
# For f1_score, create answer string for backward compatibility
answer_dict = {
diff --git a/agents/agents/rewards/qa_reward.py b/agents/agents/rewards/qa_reward.py
index a6dd139..a78c99d 100644
--- a/agents/agents/rewards/qa_reward.py
+++ b/agents/agents/rewards/qa_reward.py
@@ -123,6 +123,12 @@ def ok_vqa_reward(prediction: str, answers: List[str], trajectory: List[str]) ->
@reward(name="infoseek_reward")
def infoseek_reward(prediction: str, answer: Union[str, List[str]], answer_eval: List[str | Dict], trajectory: List[str]) -> float:
+ # format reward
+ call_tool_count = 0
+ for msg in trajectory:
+ if msg["role"] == "tool":
+ call_tool_count += 1
+
f1_scores = []
answers = []
if isinstance(answer, str):
@@ -137,5 +143,10 @@ def infoseek_reward(prediction: str, answer: Union[str, List[str]], answer_eval:
f1, precision, recall = f1_score(prediction, _answer)
f1_scores.append(f1)
- # All answers are the correct answer, take the max f1 score
- return max(f1_scores)
+ max_f1_score = max(f1_scores)
+
+ call_tool_reward = 1.0 if call_tool_count > 1 else 0.0
+
+ reward = 0.2 * call_tool_reward + 0.8 * max_f1_score
+
+ return reward
diff --git a/agents/agents/tools/tool_base.py b/agents/agents/tools/tool_base.py
index 8cd3095..e2ad350 100644
--- a/agents/agents/tools/tool_base.py
+++ b/agents/agents/tools/tool_base.py
@@ -307,11 +307,19 @@ def factory():
return decorator
-async def submit_tool_call(tool_name: str, tool_input: str, id: str=None) -> dict:
+async def submit_tool_call(
+ tool_name: str,
+ tool_input: str,
+ id: str=None,
+ allowed_tool_names: List[str] = None,
+) -> dict:
"""
Submit a tool call to the environment.
"""
- if tool_name not in TOOL_REGISTRY:
+ if allowed_tool_names is None:
+ allowed_tool_names = list(TOOL_REGISTRY.keys())
+
+ if tool_name not in allowed_tool_names:
tool_name = "hallucination_tool"
tool_input = {"tool_name": str(tool_name)}
@@ -332,6 +340,7 @@ async def submit_tool_call(tool_name: str, tool_input: str, id: str=None) -> dic
# If the loaded input is not a dict, it means the input is not a valid JSON object
if not isinstance(tool_input_json, dict):
tool_input_json = None
+
elif isinstance(tool_input, dict):
tool_input_json = tool_input
else:
@@ -351,14 +360,25 @@ async def submit_tool_call(tool_name: str, tool_input: str, id: str=None) -> dic
return await tool_obj(**tool_input_json)
-def submit_tool_calls(tool_names: List[str], tool_inputs: List[Dict | str], ids: List[str]) -> List[dict]:
+def submit_tool_calls(
+ tool_names: List[str],
+ tool_inputs: List[Dict | str],
+ ids: List[str],
+ allowed_tool_names: List[str] = None,
+) -> List[dict]:
"""
Submit tool calls to the environment. This is a synchronous wrapper that blocks until all results are ready.
Uses ThreadPoolExecutor to run tool calls in parallel.
"""
+
+ if allowed_tool_names is None:
+ allowed_tool_names = list(TOOL_REGISTRY.keys())
+
+
mapped_tool_names = []
mapped_tool_inputs = []
tool_objs = []
+
for tool_name, tool_input, id in zip(tool_names, tool_inputs, ids):
if isinstance(tool_input, dict):
tool_input_json = tool_input
@@ -371,10 +391,12 @@ def submit_tool_calls(tool_names: List[str], tool_inputs: List[Dict | str], ids:
raise ValueError(f"Invalid tool input: {tool_input}")
- if tool_name not in TOOL_REGISTRY:
+ if tool_name not in allowed_tool_names:
+ # Called a non-existent tool
mapped_tool_name = "hallucination_tool"
tool_input_json = {"tool_name": tool_name}
elif tool_input_json is None:
+ # Invalid input
mapped_tool_name = "invalid_input_tool"
tool_input_json = {"tool_input": tool_input}
else:
@@ -408,13 +430,11 @@ def submit_tool_calls(tool_names: List[str], tool_inputs: List[Dict | str], ids:
@tool()
def hallucination_tool(tool_name):
- return f"Hallucinated tool: {tool_name}"
+ return f"Hallucinated tool: {tool_name} does not exist."
@tool()
def invalid_input_tool(tool_input):
- return f"Invalid input: {tool_input}, input mush be a valid JSON object."
-
-
+ return f"Invalid input: {tool_input}, input must be a valid JSON object."
diff --git a/agents/agents/utils/ui_action_parser.py b/agents/agents/utils/ui_action_parser.py
index e992fce..0bc55a5 100644
--- a/agents/agents/utils/ui_action_parser.py
+++ b/agents/agents/utils/ui_action_parser.py
@@ -4,6 +4,9 @@
import ast
import math
from typing import List, Dict, Any, Optional, Tuple
+import logging
+
+logger = logging.getLogger(__name__)
IMAGE_FACTOR = 1 # Changed to match gui_reward.py
MIN_PIXELS = 100 * 28 * 28
@@ -69,7 +72,7 @@ def parse_action(action_str: str) -> Optional[Dict[str, Any]]:
return {'function': func_name, 'args': kwargs}
except Exception as e:
- print(f"Failed to parse action '{action_str}': {e}")
+ logger.debug(f"Failed to parse action '{action_str}': {e}")
return None
@@ -145,11 +148,11 @@ def parse_action_to_structure_output(text: str,
max_pixels: int = 16384 * 28 * 28,
min_pixels: int = 100 * 28 * 28) -> Optional[List[Dict[str, Any]]]:
"""Parse action text to structured output."""
- print(f"[parse_action_to_structure_output] Input text: {text[:500] if text else 'Empty text'}")
+ logger.debug(f"[parse_action_to_structure_output] Input text: {text[:500] if text else 'Empty text'}")
# Handle empty or None responses
if not text:
- print(f"[parse_action_to_structure_output] Empty text, returning None")
+ logger.debug(f"[parse_action_to_structure_output] Empty text, returning None")
return None
text = text.strip()
@@ -197,11 +200,11 @@ def parse_action_to_structure_output(text: str,
reflection = thought_match.group(1).strip()
if "Action:" not in text:
- print(f"[parse_action_to_structure_output] No 'Action:' found in text, returning None")
+ logger.debug(f"[parse_action_to_structure_output] No 'Action:' found in text, returning None")
return None
action_str = text.split("Action: ")[-1]
- print(f"[parse_action_to_structure_output] Extracted action string: {action_str[:200]}")
+ logger.debug(f"[parse_action_to_structure_output] Extracted action string: {action_str[:200]}")
# Parse multiple actions
tmp_all_action = action_str.split(")\n\n")
@@ -228,7 +231,7 @@ def parse_action_to_structure_output(text: str,
actions = []
for action_instance, raw_str in zip(parsed_actions, all_action):
if action_instance is None:
- print(f"Action can't parse: {raw_str}")
+ logger.debug(f"Action can't parse: {raw_str}")
continue
action_type = action_instance["function"]
@@ -256,7 +259,7 @@ def parse_action_to_structure_output(text: str,
for num in numbers:
float(num.strip())
except ValueError:
- print(f"Warning: Invalid coordinate format in '{param_name}': '{ori_box}'")
+ logger.debug(f"Warning: Invalid coordinate format in '{param_name}': '{ori_box}'")
return None
# Convert coordinates based on model type
diff --git a/agents/requirements.txt b/agents/requirements.txt
index 559e048..8a857b0 100644
--- a/agents/requirements.txt
+++ b/agents/requirements.txt
@@ -6,7 +6,7 @@ redis
docker
openai
faiss-cpu
-vllm==0.9.2
+vllm==0.10.0
termcolor
tenacity
nest-asyncio
diff --git a/agents/tests/unit/agents/templates/test_vision_templates_full_align.py b/agents/tests/unit/agents/templates/test_vision_templates_full_align.py
index 167307d..a0b8d8e 100644
--- a/agents/tests/unit/agents/templates/test_vision_templates_full_align.py
+++ b/agents/tests/unit/agents/templates/test_vision_templates_full_align.py
@@ -60,6 +60,25 @@
{"role": "assistant", "content": "I am fine, thank you."},
{"role": "user", "content": "What is 3 times 5?"},
],
+ [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "image": "https://images.unsplash.com/photo-1592194996308-7b43878e84a6"},
+ {"type": "text", "text": "Describe these images."},
+ {"type": "image", "image": "https://images.unsplash.com/photo-1599158164704-ef1ec0c94b1c"}
+ ]
+ },
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "text",
+ "text": "The image is a cat.",
+ },
+ ],
+ },
+ ]
])
@pytest.mark.parametrize("tools", [
None,
diff --git a/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py b/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py
index 31b5ecd..6e17ef8 100644
--- a/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py
+++ b/agents/tests/unit/agents/templates/test_vision_templates_tokenize.py
@@ -63,6 +63,25 @@
# {"role": "assistant", "content": "I am fine, thank you."},
# {"role": "user", "content": "What is 3 times 5?"},
# ],
+ [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"},
+ {"type": "text", "text": "Describe these images."},
+ {"type": "image", "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"}
+ ]
+ },
+ {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "text",
+ "text": "The image is a cat.",
+ },
+ ],
+ },
+ ]
])
@pytest.mark.parametrize("tools", [
None,
@@ -94,7 +113,6 @@ def test_chat_template_equal(template, messages, tools, add_generation_prompt):
official_prompt = tokenizer.decode(official_inputs['input_ids'][0])
implemented_prompt = tokenizer.decode(implemented_inputs['input_ids'][0])
print(f"Official prompt image tokens: {official_prompt.count('<|image_pad|>')}\nImplemented prompt image tokens: {implemented_prompt.count('<|image_pad|>')}")
-
print(f"Official images: {official_inputs['pixel_values'].shape}\nImplemented images: {implemented_inputs['pixel_values'].shape}")
assert torch.equal(official_inputs["input_ids"], implemented_inputs["input_ids"]), f"""Offical
diff --git a/agents/tests/unit/agents/test_vision_agent.py b/agents/tests/unit/agents/test_vision_agent.py
index eedc521..67bf4d6 100644
--- a/agents/tests/unit/agents/test_vision_agent.py
+++ b/agents/tests/unit/agents/test_vision_agent.py
@@ -1,3 +1,4 @@
+import torch
from agents.agents.react.react_agent import ReactAgent
from agents.tools import answer_qa
import pytest
@@ -45,5 +46,12 @@ async def test_vision_agent():
messages = messages_list[0]['messages']
for message in messages:
print(f"{message['role']}: {message['content']}")
- inputs = react_agent.tokenize_trajectories()
- print(inputs)
\ No newline at end of file
+ inputs, other_info_list = react_agent.tokenize_trajectories()
+ for key, value in inputs.items():
+ if isinstance(value, torch.Tensor):
+ print(f"{key}: {value.shape}")
+ else:
+ print(f"{key}: {value}")
+ other_info = other_info_list[0]
+ for key, value in other_info.items():
+ print(f"{key}: {value}")
\ No newline at end of file
diff --git a/install.sh b/install.sh
new file mode 100644
index 0000000..b014117
--- /dev/null
+++ b/install.sh
@@ -0,0 +1,429 @@
+#!/bin/bash
+
+# AgentFly Installation Script
+# This script handles the complete installation of AgentFly and its dependencies
+
+set -e # Exit on any error
+
+# Colors for output
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+BLUE='\033[0;34m'
+NC='\033[0m' # No Color
+
+# Function to print colored output
+print_status() {
+ echo -e "${BLUE}[INFO]${NC} $1"
+}
+
+print_success() {
+ echo -e "${GREEN}[SUCCESS]${NC} $1"
+}
+
+print_warning() {
+ echo -e "${YELLOW}[WARNING]${NC} $1"
+}
+
+print_error() {
+ echo -e "${RED}[ERROR]${NC} $1"
+}
+
+# Function to check if command exists
+command_exists() {
+ command -v "$1" >/dev/null 2>&1
+}
+
+# Function to check if user has sudo access
+check_sudo() {
+ if sudo -n true 2>/dev/null; then
+ return 0
+ else
+ return 1
+ fi
+}
+
+# Function to install enroot
+install_enroot() {
+ print_status "Installing enroot..."
+
+ # Check if we're on a supported system
+ if [[ "$OSTYPE" == "linux-gnu"* ]]; then
+ # Ubuntu/Debian
+ if command_exists apt-get; then
+ print_status "Detected Ubuntu/Debian system, installing enroot via deb packages..."
+
+ # Get architecture
+ arch=$(dpkg --print-architecture)
+ if [ $? -eq 0 ]; then
+ print_status "Detected architecture: $arch"
+ INSTALLATION_STATUS+=("architecture detection: SUCCESS")
+ else
+ print_error "Failed to detect architecture"
+ INSTALLATION_STATUS+=("architecture detection: FAILED")
+ return 1
+ fi
+
+ # Download enroot packages
+ print_status "Downloading enroot packages..."
+ curl -fSsL -O "https://github.com/NVIDIA/enroot/releases/download/v3.5.0/enroot-hardened_3.5.0-1_${arch}.deb"
+ if [ $? -eq 0 ]; then
+ print_success "Downloaded enroot-hardened package"
+ INSTALLATION_STATUS+=("enroot-hardened download: SUCCESS")
+ else
+ print_error "Failed to download enroot-hardened package"
+ INSTALLATION_STATUS+=("enroot-hardened download: FAILED")
+ return 1
+ fi
+
+ curl -fSsL -O "https://github.com/NVIDIA/enroot/releases/download/v3.5.0/enroot-hardened+caps_3.5.0-1_${arch}.deb"
+ if [ $? -eq 0 ]; then
+ print_success "Downloaded enroot-hardened+caps package"
+ INSTALLATION_STATUS+=("enroot-hardened+caps download: SUCCESS")
+ else
+ print_error "Failed to download enroot-hardened+caps package"
+ INSTALLATION_STATUS+=("enroot-hardened+caps download: FAILED")
+ return 1
+ fi
+
+ # Install packages
+ print_status "Installing enroot packages..."
+ sudo apt install -y ./*.deb
+ if [ $? -eq 0 ]; then
+ print_success "enroot packages installed successfully!"
+ INSTALLATION_STATUS+=("enroot package installation: SUCCESS")
+ else
+ print_error "Failed to install enroot packages"
+ INSTALLATION_STATUS+=("enroot package installation: FAILED")
+ return 1
+ fi
+
+ # Clean up downloaded packages
+ rm -f ./*.deb
+ print_status "Cleaned up downloaded packages"
+ INSTALLATION_STATUS+=("package cleanup: SUCCESS")
+
+ else
+ print_warning "Unsupported package manager. Please install enroot manually from: https://github.com/NVIDIA/enroot/blob/master/doc/installation.md"
+ return 1
+ fi
+ else
+ print_warning "Unsupported operating system. Please install enroot manually from: https://github.com/NVIDIA/enroot/blob/master/doc/installation.md"
+ return 1
+ fi
+
+ if command_exists enroot; then
+ print_success "enroot installed successfully!"
+ return 0
+ else
+ print_error "Failed to install enroot. Please install manually."
+ return 1
+ fi
+}
+
+# Function to check conda and create agentfly environment
+setup_conda_environment() {
+ if ! command_exists conda; then
+ print_error "conda not found. Please install conda first."
+ exit 1
+ fi
+
+ print_success "conda found"
+
+ # Check if agentfly environment already exists
+ if conda env list | grep -q "agentfly"; then
+ print_status "agentfly environment already exists, activating it..."
+ conda activate agentfly
+ if [ $? -eq 0 ]; then
+ print_success "agentfly environment activated"
+ INSTALLATION_STATUS+=("conda environment activation: SUCCESS")
+ else
+ print_error "Failed to activate existing agentfly environment"
+ INSTALLATION_STATUS+=("conda environment activation: FAILED")
+ return 1
+ fi
+ else
+ print_status "Creating new conda environment 'agentfly' with Python 3.10..."
+ conda create -n agentfly python=3.10 -y
+ if [ $? -eq 0 ]; then
+ print_success "agentfly environment created successfully!"
+ INSTALLATION_STATUS+=("conda environment creation: SUCCESS")
+ else
+ print_error "Failed to create agentfly environment"
+ INSTALLATION_STATUS+=("conda environment creation: FAILED")
+ return 1
+ fi
+
+ print_status "Activating agentfly environment..."
+ conda activate agentfly
+ if [ $? -eq 0 ]; then
+ print_success "agentfly environment activated"
+ INSTALLATION_STATUS+=("conda environment activation: SUCCESS")
+ else
+ print_error "Failed to activate agentfly environment"
+ INSTALLATION_STATUS+=("conda environment activation: FAILED")
+ return 1
+ fi
+ fi
+}
+
+# Function to install redis-server via conda
+install_redis() {
+ print_status "Installing redis-server via conda..."
+
+ # Ensure conda is in PATH
+ if command_exists conda; then
+ conda install -y conda-forge::redis-server==7.4.0
+ if [ $? -eq 0 ]; then
+ print_success "redis-server installed successfully!"
+ return 0
+ else
+ print_error "Failed to install redis-server"
+ return 1
+ fi
+ else
+ print_error "conda not found. Please install conda first or install redis-server manually."
+ return 1
+ fi
+}
+
+# Main installation function
+main() {
+ echo "=========================================="
+ echo " AgentFly Installation Script"
+ echo "=========================================="
+ echo ""
+
+ # Check Python version (will be checked again after conda environment setup)
+ print_status "Checking Python version..."
+ if command_exists python3; then
+ PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
+ print_status "Found Python version: $PYTHON_VERSION"
+ INSTALLATION_STATUS+=("Python availability: SUCCESS")
+ else
+ print_error "Python 3 not found. Please install Python 3.10.x first."
+ INSTALLATION_STATUS+=("Python availability: FAILED")
+ exit 1
+ fi
+
+ # Check pip
+ print_status "Checking pip..."
+ if command_exists pip3; then
+ print_success "pip3 found"
+ INSTALLATION_STATUS+=("pip availability: SUCCESS")
+ elif command_exists pip; then
+ print_success "pip found"
+ INSTALLATION_STATUS+=("pip availability: SUCCESS")
+ else
+ print_error "pip not found. Please install pip first."
+ INSTALLATION_STATUS+=("pip availability: FAILED")
+ exit 1
+ fi
+
+ # Check git
+ print_status "Checking git..."
+ if command_exists git; then
+ print_success "git found"
+ INSTALLATION_STATUS+=("git availability: SUCCESS")
+ else
+ print_error "git not found. Please install git first."
+ INSTALLATION_STATUS+=("git availability: FAILED")
+ exit 1
+ fi
+
+ # Initialize git submodules
+ print_status "Initializing git submodules..."
+ if [ -d ".git" ]; then
+ git submodule init
+ if [ $? -eq 0 ]; then
+ git submodule update
+ if [ $? -eq 0 ]; then
+ print_success "Git submodules initialized successfully!"
+ INSTALLATION_STATUS+=("Git submodules: SUCCESS")
+ else
+ print_error "Failed to update git submodules"
+ INSTALLATION_STATUS+=("Git submodules: FAILED")
+ fi
+ else
+ print_error "Failed to init git submodules"
+ INSTALLATION_STATUS+=("Git submodules: FAILED")
+ fi
+ else
+ print_warning "Not in a git repository. Skipping submodule initialization."
+ INSTALLATION_STATUS+=("Git submodules: SKIPPED (not git repo)")
+ fi
+
+ # Install Python dependencies
+ print_status "Installing basic Python dependencies..."
+ pip install -e . > /dev/null
+ if [ $? -eq 0 ]; then
+ print_success "Basic dependencies installed successfully!"
+ INSTALLATION_STATUS+=("Basic Python dependencies: SUCCESS")
+ else
+ print_error "Failed to install basic dependencies"
+ INSTALLATION_STATUS+=("Basic Python dependencies: FAILED")
+ fi
+
+ print_status "Installing VERL dependencies..."
+ pip install -e '.[verl]' --no-build-isolation > /dev/null
+ if [ $? -eq 0 ]; then
+ print_success "VERL dependencies installed successfully!"
+ INSTALLATION_STATUS+=("VERL dependencies: SUCCESS")
+ else
+ print_error "Failed to install VERL dependencies"
+ INSTALLATION_STATUS+=("VERL dependencies: FAILED")
+ fi
+
+ # Check and install enroot if needed
+ print_status "Checking enroot installation..."
+ if command_exists enroot; then
+ print_success "enroot is already installed"
+ else
+ print_warning "enroot not found. Some tools require it for container management."
+
+ if check_sudo; then
+ print_status "Sudo access detected. Attempting to install enroot..."
+ INSTALLATION_STATUS+=("sudo access: SUCCESS")
+ if install_enroot; then
+ print_success "enroot installation completed!"
+ INSTALLATION_STATUS+=("enroot installation: SUCCESS")
+ else
+ print_warning "enroot installation failed. Some tools may not work properly."
+ INSTALLATION_STATUS+=("enroot installation: FAILED")
+ fi
+ else
+ print_warning "No sudo access. Please install enroot manually from: https://github.com/NVIDIA/enroot/blob/master/doc/installation.md"
+ INSTALLATION_STATUS+=("sudo access: FAILED")
+ INSTALLATION_STATUS+=("enroot installation: SKIPPED (no sudo)")
+ fi
+ fi
+
+ # Check conda availability
+ print_status "Checking conda availability..."
+ if command_exists conda; then
+ print_success "conda found"
+ INSTALLATION_STATUS+=("conda availability: SUCCESS")
+ else
+ print_error "conda not found. Please install conda first."
+ INSTALLATION_STATUS+=("conda availability: FAILED")
+ exit 1
+ fi
+
+ # Install redis-server (assuming we're already in a conda environment)
+ print_status "Installing redis-server via conda..."
+ if install_redis; then
+ INSTALLATION_STATUS+=("redis-server installation: SUCCESS")
+ else
+ INSTALLATION_STATUS+=("redis-server installation: FAILED")
+ fi
+
+ # Final checks and summary
+ echo ""
+ echo "=========================================="
+ echo " Installation Summary"
+ echo "=========================================="
+
+ print_status "Checking installed components..."
+
+ if command_exists python3; then
+ PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
+ if [[ "$PYTHON_VERSION" =~ ^3\.10\. ]]; then
+ print_success "β Python 3.10.x ($PYTHON_VERSION)"
+ INSTALLATION_STATUS+=("Python 3.10.x verification: SUCCESS")
+ else
+ print_error "β Python version $PYTHON_VERSION does not meet requirements (need 3.10.x)"
+ INSTALLATION_STATUS+=("Python 3.10.x verification: FAILED")
+ fi
+ else
+ print_error "β Python 3 not found"
+ INSTALLATION_STATUS+=("Python 3.10.x verification: FAILED")
+ fi
+
+ if [ -d "AgentFly.egg-info" ] || [ -d "agents" ]; then
+ print_success "β AgentFly package"
+ INSTALLATION_STATUS+=("AgentFly package verification: SUCCESS")
+ else
+ print_error "β AgentFly package not found"
+ INSTALLATION_STATUS+=("AgentFly package verification: FAILED")
+ fi
+
+ if command_exists enroot; then
+ print_success "β enroot"
+ INSTALLATION_STATUS+=("enroot verification: SUCCESS")
+ else
+ print_warning "β enroot (not installed - some tools may not work)"
+ INSTALLATION_STATUS+=("enroot verification: FAILED")
+ fi
+
+ if command_exists conda; then
+ print_success "β conda"
+ INSTALLATION_STATUS+=("conda verification: SUCCESS")
+ else
+ print_error "β conda not found"
+ INSTALLATION_STATUS+=("conda verification: FAILED")
+ fi
+
+ # Skip agentfly environment check - assuming we're already in a conda environment
+ print_success "β conda environment (assuming active)"
+ INSTALLATION_STATUS+=("conda environment verification: SKIPPED (assumed active)")
+
+ if command_exists redis-server; then
+ print_success "β redis-server"
+ INSTALLATION_STATUS+=("redis-server verification: SUCCESS")
+ else
+ print_error "β redis-server not found"
+ INSTALLATION_STATUS+=("redis-server verification: FAILED")
+ fi
+
+ echo ""
+ echo "=========================================="
+ echo " Step-by-Step Status Report"
+ echo "=========================================="
+
+ # Count successes, failures, and skips
+ SUCCESS_COUNT=0
+ FAILED_COUNT=0
+ SKIPPED_COUNT=0
+
+ for status in "${INSTALLATION_STATUS[@]}"; do
+ if [[ $status == *"SUCCESS"* ]]; then
+ echo -e "${GREEN}β${NC} $status"
+ ((SUCCESS_COUNT++))
+ elif [[ $status == *"FAILED"* ]]; then
+ echo -e "${RED}β${NC} $status"
+ ((FAILED_COUNT++))
+ else
+ echo " $status"
+ ((SKIPPED_COUNT++))
+ fi
+ done
+
+ echo ""
+ echo "=========================================="
+ echo " Summary Statistics"
+ echo "=========================================="
+ echo -e "${GREEN}Successful steps: $SUCCESS_COUNT${NC}"
+ echo -e "${RED}Failed steps: $FAILED_COUNT${NC}"
+ if [ $SKIPPED_COUNT -gt 0 ]; then
+ echo -e "${YELLOW}Skipped steps: $SKIPPED_COUNT${NC}"
+ fi
+
+ echo ""
+ if [ $FAILED_COUNT -eq 0 ]; then
+ print_success "AgentFly installation completed successfully!"
+ elif [ $FAILED_COUNT -le 2 ]; then
+ print_warning "AgentFly installation completed with minor issues. Some features may not work properly."
+ else
+ print_error "AgentFly installation completed with significant issues. Please review the failed steps above."
+ fi
+
+ echo ""
+ print_status "Next steps:"
+ echo " 1. If you just installed enroot, you may need to restart your terminal"
+ echo " 2. Check the documentation at: https://agentfly.readthedocs.io/"
+ echo " 3. Try running an example: cd verl && bash examples/run_agents/run_code_agent.sh"
+ echo ""
+}
+
+# Run main function
+main "$@"
diff --git a/pyproject.toml b/pyproject.toml
index 92d6504..068ddf6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,7 @@ dependencies = [
"docker",
"openai",
"faiss-cpu",
- "vllm==0.9.2",
+ "vllm==0.10.0",
"termcolor",
"tenacity",
"bs4",