Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,4 @@ test_outputs/
agentfly/data/
*.ipynb

./*.jpg
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ Please refer to [installation.md](docs/start/installation.md) for customized ins
```python
# Really small example to build an agent and run
from agentfly.agents import HFAgent
from agentfly.tools import calculate
from agentfly.tools import calculator
messages = [{"role": "user", "content": "What is the result of 1 + 1?"}]
agent = HFAgent(
model_name_or_path="Qwen/Qwen2.5-3B-Instruct",
tools=[calculate],
tools=[calculator],
template="qwen2.5",
backend="async_vllm",
)
Expand Down
3 changes: 1 addition & 2 deletions agentfly/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@
from .specialized.code_agent import CodeAgent
from .specialized.think_agent import ThinkAgent
from .specialized.gui_agent import GUIAgent
from .specialized.hf_agent import HFAgent
from .templates.utils import process_vision_info, tokenize_conversation, tokenize_conversations
from .specialized.hf_agent import HFAgent
5 changes: 2 additions & 3 deletions agentfly/agents/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
import json
from .utils.messages import MessagesList
from .templates.templates import get_template
from ..templates.templates import get_template
from ..__init__ import AGENT_DATA_DIR
from .llm_backends import (
AsyncVLLMBackend,
Expand All @@ -15,8 +15,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from .templates.utils import tokenize_conversations
from .templates.vision_processor import is_vision_template
from ..templates import tokenize_conversations
from .chain.chain_base import ChainRollout
import os
import transformers
Expand Down
3 changes: 1 addition & 2 deletions agentfly/agents/llm_backends/llm_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from ...utils.verl import pad_tensor_to_rank_size
from vllm import LLM, AsyncLLMEngine, SamplingParams, AsyncEngineArgs
import openai
from ..templates.templates import Chat
from ..templates.vision_processor import get_processor
from ...templates import Chat
import logging
import PIL

Expand Down
4 changes: 2 additions & 2 deletions agentfly/agents/specialized/openai_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from ...tools import answer_qa
from ...tools.tool_base import tool
from ..agent_base import BaseAgent
from ..llm_backend import ClientBackend
from ..backend_config import ClientConfig
from ..llm_backends import ClientBackend
from ..llm_backends.backend_configs import ClientConfig
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored
import json
Expand Down
Empty file.
10 changes: 10 additions & 0 deletions agentfly/templates/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .templates import Template, Chat, get_template, register_template
from .utils import (
process_vision_info,
tokenize_conversation,
tokenize_conversations,
compare_hf_template,
)
from .tool_policy import ToolPolicy, JsonFormatter
from .system_policy import SystemPolicy
from .global_policy import GlobalPolicy
File renamed without changes.
5 changes: 5 additions & 0 deletions agentfly/templates/global_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import dataclasses
from typing import Optional

@dataclasses.dataclass
class GlobalPolicy:
    """Global policy settings attached to a chat template.

    This is a plain dataclass: instances compare by field value and can be
    built positionally or by keyword.
    """

    # Optional prefix string; ``None`` (the default) means no prefix is set.
    # NOTE(review): presumably this is prepended to the rendered prompt --
    # confirm against the template code that consumes this policy.
    prefix: Optional[str] = None
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .system_policy import Llama32DateProcessor, SystemPolicy
from .tool_policy import ToolPolicy
from .constants import ToolPlacement, Role
from .global_policy import GlobalPolicy

Logger = logging.getLogger(__name__)

Expand All @@ -36,15 +37,25 @@
console_handler.setFormatter(formatter)
Logger.addHandler(console_handler)

@dataclasses.dataclass
class GlobalPolicy:
prefix: str = None


@dataclasses.dataclass
class Template:
"""A class that manages prompt templates and keeps all conversation history."""
# Properties
"""Class that holds all the components of a chat template. Converts messages to string prompts, tokenizes messages to token IDs, and generates Jinja-based chat templates.

Args:
name: The name of this template
system_template: The system template component
system_template_with_tools: The system template with tool usage component
system_message: The default system message
stop_words: The stop words where the model stops generating (usually EOS token)
tool_template: The tool response template component
user_template: The user template component
user_template_with_tools: The user template with tool usage component
assistant_template: The assistant template component
global_policy: The global policy, controls the behavior of the template
system_policy: The system message policy, controls the behavior of forming the system message
tool_policy: The tool policy for the template, controls the behavior of forming tools.
"""
# The name of this template
name: str
# The template of the system prompt
Expand Down Expand Up @@ -142,17 +153,24 @@ def _supports_tool_call(self) -> bool:
return False

def render(self, messages: List[Dict], tools=None, add_generation_prompt: bool = False) -> str:
"""Render the template and return
1. the final prompt string,
2. the list of string *elements* that compose the prompt, and
3. the corresponding list of *roles* (used by downstream post-processing).
"""Render the template.

The heavy lifting is delegated to small, single-purpose helpers so the
high-level flow is immediately apparent:

1. _insert_tools – decide where the tool catalogue lives
2. _encode_turns – encode every conversation turn
3. _maybe_add_generation_prompt – append the generation prefix if requested

Args:
messages: The list of messages
tools: The list of tools
add_generation_prompt: Whether to add the generation prefix

Returns:
prompt: The final prompt string
elements: The list of string *elements* that compose the prompt
roles: The corresponding list of *roles* (used by downstream post-processing)
"""

# Step 1 – decide tool placement & clone messages
Expand All @@ -173,15 +191,14 @@ def _insert_tools(self, messages: List[Dict], tools):
"""Clone *messages* and compute where (and how) the tool catalogue
should be injected.

Returns
-------
work_messages : List[Dict]
A deepcopy of the original *messages* so we never mutate caller data.
tools_str : Optional[str]
The formatted tool catalogue or *None* if `tools` is falsy.
insert_tools_idx : int
Index of the *user* message that receives the catalogue, or -1 when
no injection is required.
Returns:
work_messages : List[Dict]
A deepcopy of the original *messages* so we never mutate caller data.
tools_str : Optional[str]
The formatted tool catalogue or *None* if `tools` is falsy.
insert_tools_idx : int
Index of the *user* message that receives the catalogue, or -1 when
no injection is required.
"""

work_messages = deepcopy(messages)
Expand Down Expand Up @@ -418,6 +435,19 @@ def _split_assistant_message(self, assistant_message: str) -> List[str]:


def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None) -> str:
"""Encode the messages to token ids.

Args:
messages: The list of messages
tokenizer: The tokenizer
return_tensors: The format of the returned tensors (e.g. "pt")
tools: The list of tools
add_generation_prompt: Whether to add the generation prefix
processor: The processor for vision templates

Returns:
inputs: The dictionary of input ids, attention mask, labels, and action mask
"""
if processor is None and self.supports_vision():
raise ValueError(f"Processor is required for vision templates: {self.name}")

Expand Down Expand Up @@ -587,6 +617,11 @@ def get_vision_inputs(self, messages: List[Dict]):
return vision_inputs

def jinja_template(self) -> str:
"""Interface for getting the Jinja template.

Returns:
The Jinja template string
"""
if self.chat_template:
return self.chat_template
else:
Expand Down Expand Up @@ -926,6 +961,13 @@ def dict(self):

class Chat:
def __init__(self, template: str, messages: List[List[str]]=None, tools=None, tokenizer: PreTrainedTokenizer = None):
"""
Args:
template: The name of the template to use.
messages: The messages to use for the chat.
tools: The tools to use for the chat.
tokenizer: The tokenizer to use for the chat.
"""
self.template = get_template(template)
self.messages = self.convert_to_hf_format_messages(messages)
self.tokenizer = tokenizer
Expand Down Expand Up @@ -967,10 +1009,21 @@ def convert_to_hf_format_messages(self, messages: Union[List[Dict], Dict[str, Li
return hf_messages

def set_messages(self, messages: List[Dict]):
"""Set the messages for the chat."""
self.messages = self.convert_to_hf_format_messages(messages)

def prompt(self, add_generation_prompt=False, tools=None) -> str:
"""Get the prompt for the chat.

Args:
add_generation_prompt: Whether to add the generation prompt.
tools: The tools to use for the chat.

Returns:
The prompt for the chat.
"""
self.flags['add_generation_prompt'] = add_generation_prompt
tools = tools or self.tools
prompt, _, _ = self.template.render(messages=self.messages, tools=tools, add_generation_prompt=add_generation_prompt)
return prompt

Expand All @@ -982,13 +1035,32 @@ def vision_inputs(self) -> List[Any]:
return self.template.get_vision_inputs(self.messages)

def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=False, tools=None, processor=None) -> List[int]:
"""Tokenize the messages.

Args:
tokenizer: The tokenizer to use for the chat.
add_generation_prompt: Whether to add the generation prompt.
tools: The tools to use for the chat.
processor: The processor to use for the chat.

Returns:
inputs (dict): Inputs for helping training.
- input_ids
- attention_mask
- labels
- action_mask
- multi_modal_inputs
"""
if tokenizer is None:
if self.tokenizer is None:
raise ValueError("Tokenizer is not set. Set it when initializing the chat or pass it as an argument.")
tokenizer = self.tokenizer

if tools is None:
tools = self.tools
return self.template.encode(messages=self.messages, tokenizer=tokenizer, return_tensors="pt", tools=tools, add_generation_prompt=add_generation_prompt, processor=processor)

def append(self, message: Union[Dict, List[Dict]]):
def append(self, message: Union[Dict]):
self._convert_single_message_to_hf_format(message)
self.messages.append(message)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
import re
import logging
from .templates import Chat, get_template
from ... import AGENT_DATA_DIR
from .. import AGENT_DATA_DIR
from typing import Any
from .vision_processor import get_processor
# Set up logging that won't be overridden by other modules

LOGGER = logging.getLogger(__name__)

ANSI_RE = re.compile(r'\x1b\[[0-9;]*m') # matches any ANSI color/style code
Expand Down Expand Up @@ -269,7 +269,8 @@ def tokenize_conversations(
concatenated_mm_inputs = {}
if concatenate_mm_inputs:
for key in batch_mm_inputs[0].keys():
concatenated_mm_inputs[key] = torch.cat([mm_inputs[key] for mm_inputs in batch_mm_inputs if mm_inputs[key] is not None], dim=0)
if mm_inputs[key]:
concatenated_mm_inputs[key] = torch.cat([mm_inputs[key] for mm_inputs in batch_mm_inputs if mm_inputs[key] is not None], dim=0)

inputs = dict(
input_ids=batch_input_ids,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""


from .....agents.templates.utils import compare_hf_template
from .....templates import compare_hf_template
from transformers import AutoTokenizer
import pytest

Expand Down
7 changes: 6 additions & 1 deletion agentfly/tools/tool_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
import concurrent.futures
from ..envs.manager.env_manager import EnvironmentManager

import logging

logger = logging.getLogger(__name__)

# current_env = contextvars.ContextVar("current_env")

class Tool:
Expand Down Expand Up @@ -271,7 +275,8 @@ def decorator(func):
func_name = func.__name__
final_name = name or func_name
if name and name != func_name:
warnings.warn(f"Tool name {func_name!r} overridden by {name!r}")
logger.warning(f"Tool name {func_name!r} overridden by {name!r}")
# warnings.warn(f"Tool name {func_name!r} overridden by {name!r}")

signature = extract_signatures(func)
docs = parse_docstring(inspect.getdoc(func))
Expand Down
5 changes: 4 additions & 1 deletion agentfly/tools/utils/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import inspect
import warnings
from copy import deepcopy
import logging

logger = logging.getLogger(__name__)

def extract_signatures(func):
sig = inspect.signature(func)
Expand Down Expand Up @@ -153,7 +156,7 @@ def validate_schema(name, description, signature, docs):
else:
# Maybe this should raise an error
properties[param]['type'] = "unknown"
warnings.warn(f"Parameter {param} has no type in signature or docstring.")
logger.warning(f"Parameter {param} has no type in signature or docstring.")

if "default" in signature[param]:
properties[param]['default'] = signature[param]['default']
Expand Down
6 changes: 0 additions & 6 deletions docs/api_references/agents/agent.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,4 @@ The foundation class for all agents in AgentFly:
:members:
:show-inheritance:

Chain Generation
----------------

Base class for chain-based generation:

.. autoclass:: agentfly.agents.chain.chain_base.ChainRollout
:members:
Loading