Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,5 +132,9 @@ outputs

# Notebooks
agents/tests/*.ipynb
agents/tests/*.jpg
agents/tests/*.jpeg
agents/tests/*.png
agents/agents/*.ipynb
agents/temp/

59 changes: 46 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,26 +1,53 @@
# AgentFly: Scalable and Extensible Reinforcement Learning for LLM Agents
# 🪽AgentFly: Training scalable LLM agents with RL (multi-turn, async tools/rewards, multimodal)



<p align="center">
<a href="https://arxiv.org/pdf/2507.14897" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/Paper-arXiv-%23ffc8dd?style=plastic&link=https%3A%2F%2Farxiv.org%2Fpdf%2F2507.14897"></a>
<a href="https://agentfly.readthedocs.io/" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/Docs-AgentFly-%23a2d2ff?style=plastic&link=https%3A%2F%2Fagentfly.readthedocs.io%2F"></a>
<a href="https://huggingface.co/collections/Agent-One/agentfly-6882061c6cf08537cb66c12b" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/Model-%F0%9F%A4%97HF-%23ffb703"></a>
<a href="https://arxiv.org/pdf/2507.14897" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/arXiv-Paper-%23cdb4db?style=for-the-badge&logo=arxiv"></a>
<a href="https://agentfly.readthedocs.io/" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/DOC-AgentFly-%23ffc8dd?style=for-the-badge&logo=readthedocs"></a>
<a href="https://wandb.ai/AgentRL/Open" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/W%26B-LOG-%23ffafcc?style=for-the-badge&logo=weightsandbiases"></a>
<a href="https://huggingface.co/collections/Agent-One/agentfly-6882061c6cf08537cb66c12b" target="_blank"><img alt="Static Badge" src="https://img.shields.io/badge/HF-MODEL-%23bde0fe?style=for-the-badge&logo=huggingface"></a>
<a href="https://github.com/Agent-One-Lab/AgentFly" target="_blank"><img alt="Static Badge" src="https://img.shields.io/github/stars/Agent-One-Lab/AgentFly?style=for-the-badge&logo=github&color=a2d2ff"></a>
</p>
<p align="center">
<img alt="Static Badge" src="https://img.shields.io/badge/Agent%20RL--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/Multi%20Turn--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/Multi%20Modal--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/Tool%20--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/Reward%20--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/Container%20--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/Decouple%20--%23000000?style=social">
<br>
<img alt="Static Badge" src="https://img.shields.io/badge/Code%20Interpreter--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/WebShop--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/ScienceWorld--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/Search--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/ALFWorld--%23000000?style=social">
<br>
<img alt="Static Badge" src="https://img.shields.io/badge/Chat%20Template--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/Masking--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/Asynchronous--%23000000?style=social">
<img alt="Static Badge" src="https://img.shields.io/badge/Chain%20Rollout--%23000000?style=social">
</p>


AgentFly is an extensible framework for building LLM agents with reinforcement learning. It supports multi-turn training by adapting traditional RL methods with token-level masking. It features a decorator-based interface for defining tools and reward functions, enabling seamless extension and ease of use. To support high-throughput training, it implemented asynchronous execution of tool calls and reward computations, and design a centralized resource management system for scalable environment coordination. A suite of prebuilt tools and environments are provided.
![Overview](assets/images/overview.png)

## 🆕 News
## News

**Multi-Modal (Vision) Agent Training Support** - Thanks to the powerful template system, AgentFly now supports training vision-language agents! 🎉

Train agents that can see and understand visual content, including GUI automation and image-based QA. See our [predefined training examples](docs/examples/predefined_training_examples.md) for ready-to-use scripts.
**08/2025 Multi-Modal (Vision) Agent Training Support** - Thanks to the powerful template system, AgentFly now supports training vision-language agents! 🎉 Train agents that can see and understand visual content, including GUI automation and image-based QA. See our [predefined training examples](docs/examples/predefined_training_examples.md) for ready-to-use scripts.

---

**New: Chat Template System** - A flexible framework for creating conversation templates with multi-model support, vision capabilities, and tool integration. [Learn more →](docs/chat_template/)
**08/2025 Chat Template System** - A flexible framework for creating conversation templates with multi-model support, vision capabilities, and tool integration. [Learn more →](docs/chat_template/)

## Installation
**Option 1**: One-line Installation:
```
bash install.sh # Assume conda with python3.10.x
```
**Option 2**: Customized Installation

Clone and initialize the project:
```bash
git clone https://github.com/Agent-One-Lab/AgentFly
Expand All @@ -42,6 +69,11 @@ Search requires redis to cache results, an optional way to install with conda:
conda install conda-forge::redis-server==7.4.0
```

## Quick Start
```

```

## Features
### 1. Multi-Chain Agent Rollout and Multi-Turn Training
To support algorithms like GRPO, Reinforce++, we design multi-chain inference, enabling agents to solve one task with multiple paths at the same time. We build RL computation and update LLMs in multi-turn manner by applying token masks. The training is based on [verl](https://github.com/volcengine/verl).
Expand All @@ -67,6 +99,10 @@ agent = ReactAgent(
### 3. Easy Development
Decoupled agent and training module. Simply customize your own agent, which can directly be applied to training.

## Training Curves
Reward curves on Qwen2.5-Instruct 3B and 7B models.
![Curves](assets/images/training_curves.png)


## Training
### Run Example Training
Expand Down Expand Up @@ -155,9 +191,6 @@ https://github.com/user-attachments/assets/b8f42534-8d40-48a0-a264-f378e479bb3a

[Discord](https://discord.gg/CchUj7Sp)

## Training Curves
Reward curves on Qwen2.5-Instruct 3B and 7B models.
![Curves](assets/images/training_curves.png)

## Cite
If you used our code or find it helpful, please cite:
Expand Down
20 changes: 17 additions & 3 deletions agents/agents/agents/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
self.template = template
self.max_length = max_length
self.tools = tools
self.tool_names = [tool.name for tool in tools]
self.system_prompt = system_prompt
self.model_name_or_path = model_name_or_path

Expand Down Expand Up @@ -209,7 +210,7 @@ def trajectories(self):

return trajectories

def tokenize_trajectories(self, tokenizer = None, return_reward_mask: bool = False):
def tokenize_trajectories(self, tokenizer = None, return_reward_mask: bool = False, concatenate_mm_inputs: bool = True):
if tokenizer is None:
tokenizer = self.tokenizer

Expand Down Expand Up @@ -239,7 +240,16 @@ def tokenize_trajectories(self, tokenizer = None, return_reward_mask: bool = Fal
info['last_response'] = last_response
other_info_list.append(info)

inputs = tokenize_conversations(messages_list, tokenizer=tokenizer, conv_template=self.template, processor=self.processor, max_length=self.max_length, return_reward_mask=return_reward_mask)
inputs = tokenize_conversations(
messages_list,
tokenizer=tokenizer,
template=self.template,
processor=self.processor,
max_length=self.max_length,
return_reward_mask=return_reward_mask,
add_generation_prompt=True,
concatenate_mm_inputs=concatenate_mm_inputs,
)
position_ids = torch.clip(torch.cumsum(inputs['attention_mask'], dim=-1) - 1, min=0, max=None)
inputs['position_ids'] = position_ids

Expand Down Expand Up @@ -318,7 +328,7 @@ def rewards(self):


def get_verl_data_proto(self):
inputs, other_info_list = self.tokenize_trajectories(return_reward_mask=True)
inputs, other_info_list = self.tokenize_trajectories(return_reward_mask=True, concatenate_mm_inputs=False)
group_ids = np.array([info["group_id"] for info in other_info_list], dtype=object)
# Do evaluation here
reward_values, other_values = self.rewards
Expand All @@ -329,6 +339,10 @@ def get_verl_data_proto(self):
inputs[f"rm_{key}"] = np.array(values)
# We handle the group id in the agent side, to be compatible with GRPO
inputs["uid"] = group_ids

if "mm_inputs" in inputs:
mm_inputs = inputs.pop("mm_inputs")
inputs["multi_modal_inputs"] = np.array(mm_inputs, dtype=object)
batch = DataProto.from_single_dict(inputs, meta_info={"use_agent": True})

return batch
Expand Down
7 changes: 6 additions & 1 deletion agents/agents/agents/chain/chain_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,12 @@ async def _execute_tool_call(self, tool_call, newest_messages, chain, chain_id,
have_set_tools = True

# Execute tool call
result = await submit_tool_call(tool_name, tool_input, id=chain_id)
result = await submit_tool_call(
tool_name,
tool_input,
id=chain_id,
allowed_tool_names=self.tool_names
)

if enable_streaming:
# Emit tool observation event
Expand Down
23 changes: 13 additions & 10 deletions agents/agents/agents/specialized/gui_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
# SPDX-License-Identifier: Apache-2.0

import json
import logging
from typing import List, Any, Tuple, Dict, Optional
from ..agent_base import BaseAgent
from agents.utils.ui_action_parser import parse_action_to_structure_output, IMAGE_FACTOR

logger = logging.getLogger(__name__)

# Default image dimensions
TEST_IMAGE_HEIGHT = 1080
TEST_IMAGE_WIDTH = 1920
Expand Down Expand Up @@ -96,8 +99,8 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
Returns:
List of structured messages with tool calls
"""
print(f"[GUIAgent.parse] Number of responses: {len(responses)}")
print(f"[GUIAgent.parse] Raw responses type: {type(responses)}")
logger.debug(f"[GUIAgent.parse] Number of responses: {len(responses)}")
logger.debug(f"[GUIAgent.parse] Raw responses type: {type(responses)}")

new_messages_list = []

Expand All @@ -109,7 +112,7 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
elif resp and resp.strip():
# Try to reformat responses that don't have the expected format
resp_lower = resp.lower()
print(f"[GUIAgent.parse] Response missing format, reformatting: {resp[:100]}")
logger.debug(f"[GUIAgent.parse] Response missing format, reformatting: {resp[:100]}")

# Check if it contains action-like content
if any(action in resp_lower for action in ['click', 'type', 'scroll']):
Expand All @@ -128,9 +131,9 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
# Log responses for debugging
for idx, resp in enumerate(responses[:3]): # Log first 3 responses
if resp:
print(f"[GUIAgent.parse] Response {idx} length: {len(resp)}, preview: {resp[:200]}")
logger.debug(f"[GUIAgent.parse] Response {idx} length: {len(resp)}, preview: {resp[:200]}")
else:
print(f"[GUIAgent.parse] Response {idx} is None or empty")
logger.debug(f"[GUIAgent.parse] Response {idx} is None or empty")

# Parse actions from responses
action_list = []
Expand All @@ -145,13 +148,13 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:

# Create messages with tool calls
for i, (response, actions) in enumerate(zip(responses, action_list)):
print(f"[GUIAgent.parse] Processing response {i+1}: response_length={len(response) if response else 0}, actions={actions}")
logger.debug(f"[GUIAgent.parse] Processing response {i+1}: response_length={len(response) if response else 0}, actions={actions}")

tool_calls = []

if actions is not None and len(actions) > 0:
if len(actions) > 1:
print(f"[GUIAgent.parse] Warning: Multiple actions found ({len(actions)}), using first one")
logger.debug(f"[GUIAgent.parse] Warning: Multiple actions found ({len(actions)}), using first one")
action = actions[0]
tool_calls = [{
"id": str(i),
Expand All @@ -163,7 +166,7 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
}]
else:
# If no action was parsed, create a default click action at center
print(f"[GUIAgent.parse] No action parsed from response, creating default click action")
logger.debug(f"[GUIAgent.parse] No action parsed from response, creating default click action")
default_action = {
"action_type": "click",
"action_inputs": {"start_box": "(960, 540)"},
Expand All @@ -184,7 +187,7 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
status = "terminal"
if actions and isinstance(actions[0], dict):
action_type = actions[0].get("action_type", "")
print(f"[GUIAgent.parse] Action type: {action_type}, terminating after one turn")
logger.debug(f"[GUIAgent.parse] Action type: {action_type}, terminating after one turn")

message = {
"role": "assistant",
Expand All @@ -193,7 +196,7 @@ def parse(self, responses: List[str], tools: List[Any]) -> List[Dict[str, Any]]:
"loss": True,
"status": status
}
print(f"[GUIAgent.parse] Created message with status={status}, tool_calls={len(tool_calls)}, content_length={len(response)}")
logger.debug(f"[GUIAgent.parse] Created message with status={status}, tool_calls={len(tool_calls)}, content_length={len(response)}")
new_messages_list.append(message)

return new_messages_list
Expand Down
8 changes: 6 additions & 2 deletions agents/agents/agents/templates/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,10 @@ def _encode_user_message_with_tools(self, content, tools: str) -> str:
for item in content:
if item["type"] == "text":
text += item["text"]
elif item["type"] in ["image", "image_url"]:
text += self.vision_start + self.image_token + self.vision_end
elif item["type"] == "video":
text += self.vision_start + self.video_token + self.vision_end
else:
raise ValueError(f"Invalid message type: {item['type']}")

Expand Down Expand Up @@ -949,11 +953,11 @@ def _convert_single_message_to_hf_format(self, message: Dict) -> Dict:
# Not sure what to do with other types of content
pass

def convert_to_hf_format_messages(self, messages: List[Dict]) -> List[Dict]:
def convert_to_hf_format_messages(self, messages: Union[List[Dict], Dict[str, List[Dict]]]) -> List[Dict]:
hf_messages = []
if messages is None:
return None
role_label, content_label = self._detect_labels(messages)
hf_messages = []
for message in messages:
hf_messages.append({"role": message[role_label], "content": message[content_label]})

Expand Down
59 changes: 52 additions & 7 deletions agents/agents/agents/templates/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
import copy
from enum import Enum
import os
Expand Down Expand Up @@ -219,26 +220,56 @@ def process_prompt_with_vision(
)


def tokenize_conversations(messages_list, tokenizer, conv_template, max_length, processor=None, return_tensors="pt", return_reward_mask=False):
def tokenize_conversations(
messages_list,
tokenizer,
template,
max_length,
processor=None,
return_tensors="pt",
return_reward_mask=False,
add_generation_prompt=False,
padding_side="right",
concatenate_mm_inputs=False
):
batch_input_ids = []
batch_attention_masks = []
batch_labels = []
batch_action_masks = []
batch_mm_inputs = []
# TODO: add multiprocessing
for messages in messages_list:
inputs = tokenize_conversation(messages, tokenizer, conv_template, max_length, processor=processor)
inputs = tokenize_conversation(messages, tokenizer, template, max_length, processor=processor, add_generation_prompt=add_generation_prompt)
batch_input_ids.append(inputs['input_ids'].squeeze(0))
batch_attention_masks.append(inputs['attention_mask'].squeeze(0))
batch_labels.append(inputs['labels'].squeeze(0))
batch_action_masks.append(inputs['action_mask'].squeeze(0))

mm_inputs = {}
if "pixel_values" in inputs:
mm_inputs["pixel_values"] = inputs["pixel_values"]
else:
mm_inputs["pixel_values"] = None
if "image_grid_thw" in inputs:
mm_inputs["image_grid_thw"] = inputs["image_grid_thw"]
else:
mm_inputs["image_grid_thw"] = None

batch_mm_inputs.append(mm_inputs)

if return_tensors == "pt":
# Use pad_token_id from the tokenizer interface
pad_token_id = getattr(tokenizer, 'pad_token_id', 0)
batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=pad_token_id)
batch_attention_masks = torch.nn.utils.rnn.pad_sequence(batch_attention_masks, batch_first=True, padding_value=0)
batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100)
batch_action_masks = torch.nn.utils.rnn.pad_sequence(batch_action_masks, batch_first=True, padding_value=0)

batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=pad_token_id, padding_side=padding_side)
batch_attention_masks = torch.nn.utils.rnn.pad_sequence(batch_attention_masks, batch_first=True, padding_value=0, padding_side=padding_side)
batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100, padding_side=padding_side)
batch_action_masks = torch.nn.utils.rnn.pad_sequence(batch_action_masks, batch_first=True, padding_value=0, padding_side=padding_side)

# convert [{"pixel_values": tensor, "image_grid_thw": tensor}, ...] to {"key1": concat_tensor, "key2": concat_tensor, ...}
concatenated_mm_inputs = {}
if concatenate_mm_inputs:
for key in batch_mm_inputs[0].keys():
concatenated_mm_inputs[key] = torch.cat([mm_inputs[key] for mm_inputs in batch_mm_inputs if mm_inputs[key] is not None], dim=0)

inputs = dict(
input_ids=batch_input_ids,
Expand All @@ -250,6 +281,20 @@ def tokenize_conversations(messages_list, tokenizer, conv_template, max_length,
if return_reward_mask:
inputs['reward_mask'] = transform_reward_mask(batch_action_masks)

# Check if we need mm_inputs
mm_keys = list(batch_mm_inputs[0].keys())
return_mm_inputs = False
for key in mm_keys:
if any(mm_inputs[key] is not None for mm_inputs in batch_mm_inputs):
return_mm_inputs = True
break

if return_mm_inputs:
if concatenate_mm_inputs:
inputs.update(concatenated_mm_inputs)
else:
inputs["mm_inputs"] = batch_mm_inputs

return inputs

def visualize_template(template, messages=None, tools=None, **kwargs):
Expand Down
Loading