Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion agentfly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@

AGENT_CONFIG_DIR = os.getenv("AGENT_CONFIG_DIR", AGENT_CONFIG_DIR)

ENROOT_HOME = os.getenv("ENROOT_HOME", ENROOT_HOME)
ENROOT_HOME = os.getenv("ENROOT_HOME", ENROOT_HOME)

os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
9 changes: 7 additions & 2 deletions agentfly/agents/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from .utils.messages import MessagesList
from .templates.templates import get_template
from ..__init__ import AGENT_DATA_DIR
from .llm_backend import AsyncVLLMBackend, AsyncVerlBackend, ClientBackend, TransformersBackend, VLLMBackend
from .llm_backends import (
AsyncVLLMBackend,
AsyncVerlBackend,
ClientBackend,
TransformersBackend,
)
from .llm_backends.backend_configs import BACKEND_CONFIGS
from ..utils.logging import get_logger
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
Expand All @@ -18,7 +24,6 @@
import logging
from .chain.streaming_observer import ConsoleStreamObserver, StreamingManager
from .utils.tokenizer import create_processor, create_tokenizer
from .backend_config import BACKEND_CONFIGS
try:
from verl.protocol import DataProto
except ImportError:
Expand Down
66 changes: 0 additions & 66 deletions agentfly/agents/backend_config.py

This file was deleted.

15 changes: 15 additions & 0 deletions agentfly/agents/llm_backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .backend_configs import (
TransformersConfig,
VLLMConfig,
AsyncVLLMConfig,
AsyncVerlConfig,
ClientConfig,
)

from .llm_backends import (
TransformersBackend,
VLLMBackend,
AsyncVLLMBackend,
AsyncVerlBackend,
ClientBackend,
)
122 changes: 122 additions & 0 deletions agentfly/agents/llm_backends/backend_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from dataclasses import dataclass
from typing import Optional, Dict, Any, List

from vllm import AsyncEngineArgs


@dataclass
class TransformersConfig:
    """Settings for the Transformers backend (Hugging Face models).

    Attributes:
        temperature (float): Sampling temperature used for text generation.
            Larger values (e.g., 1.0) give more random output, smaller values
            (e.g., 0.1) give more deterministic output. Defaults to 1.0.
        max_new_tokens (int): Upper bound on the number of newly generated
            tokens. Defaults to 1024.
        trust_remote_code (bool): Whether remote code is trusted when loading
            models; some custom models require it. Defaults to True.
        device_map (str): Device placement strategy for the model, such as
            "auto", "cpu" or "cuda:0". Defaults to "auto".
    """

    # Field order is part of the positional constructor signature; keep it.
    temperature: float = 1.0
    max_new_tokens: int = 1024
    trust_remote_code: bool = True
    device_map: str = "auto"


@dataclass
class VLLMConfig:
    """Settings for the VLLM backend (high-performance inference).

    Attributes:
        temperature (float): Sampling temperature used for text generation.
            Larger values (e.g., 1.0) give more random output, smaller values
            (e.g., 0.1) give more deterministic output. Defaults to 1.0.
        max_new_tokens (int): Upper bound on the number of newly generated
            tokens. Defaults to 1024.
    """

    # Field order is part of the positional constructor signature; keep it.
    temperature: float = 1.0
    max_new_tokens: int = 1024



@dataclass(init=False)
class AsyncVLLMConfig:
    """Configuration for the Async VLLM backend, expressed as vLLM engine arguments.

    Every keyword argument given to the constructor is forwarded unchanged to
    ``AsyncEngineArgs``. The accepted names are therefore vLLM's own engine
    arguments, documented at
    https://docs.vllm.ai/en/latest/configuration/engine_args.html.
    Some important ones, as described there:

    Keyword Args:
        gpu_memory_utilization (float): The fraction of GPU memory to be used
            for the model executor, which can range from 0 to 1.
        max_model_len (int): Model context length (prompt and output). If
            unspecified, it is derived automatically from the model config.
        rope_scaling (dict): Rope scaling. For example,
            {"rope_type":"dynamic","factor":2.0}.
        trust_remote_code (bool): Whether to trust remote code when loading models.
        pipeline_parallel_size (int): Pipeline parallel size.
        data_parallel_size (int): Data parallel size.
        tensor_parallel_size (int): Tensor parallel size.

    Attributes:
        engine_args (AsyncEngineArgs): The engine arguments built from the
            constructor's keyword arguments.
    """

    # Deliberately no class-level default. The previous `= AsyncEngineArgs()`
    # built an engine-args object at import time that was never used (the
    # __init__ below always assigns the attribute), and an unhashable dataclass
    # instance as a field default is rejected by `dataclasses` on Python 3.11+.
    engine_args: AsyncEngineArgs

    def __init__(self, **kwargs):
        """Build ``engine_args`` by passing ``kwargs`` straight to ``AsyncEngineArgs``."""
        self.engine_args = AsyncEngineArgs(**kwargs)


@dataclass
class VerlConfig:
    """Settings for the Verl backend.

    Attributes:
        temperature (float): Sampling temperature used for text generation.
            Larger values (e.g., 1.0) give more random output, smaller values
            (e.g., 0.1) give more deterministic output. Defaults to 1.0.
        max_new_tokens (int): Upper bound on the number of newly generated
            tokens. Defaults to 1024.
    """

    # Field order is part of the positional constructor signature; keep it.
    temperature: float = 1.0
    max_new_tokens: int = 1024


@dataclass
class AsyncVerlConfig:
    """Settings for the Async Verl backend.

    Attributes:
        temperature (float): Sampling temperature used for text generation.
            Larger values (e.g., 1.0) give more random output, smaller values
            (e.g., 0.1) give more deterministic output. Defaults to 1.0.
        max_new_tokens (int): Upper bound on the number of newly generated
            tokens. Defaults to 1024.
    """

    # Field order is part of the positional constructor signature; keep it.
    temperature: float = 1.0
    max_new_tokens: int = 1024


@dataclass
class ClientConfig:
    """Settings for the Client backend (OpenAI-compatible).

    Holds the connection and generation settings for an OpenAI-compatible API
    endpoint, for example a local model served through vLLM, Ollama, or
    another compatible server.

    Attributes:
        base_url: Base URL of the API endpoint. Defaults to
            "http://localhost:8000/v1".
        max_requests_per_minute: Rate limit for API requests. Defaults to 100.
        timeout: Request timeout in seconds. Defaults to 600 (10 minutes).
        api_key: API key used for authentication. Defaults to "EMPTY", which
            is intended for local servers.
        max_new_tokens: Maximum number of tokens to generate. Defaults to 1024.
        temperature: Sampling temperature for text generation. Defaults to 1.0.
    """

    # Connection settings.
    base_url: str = "http://localhost:8000/v1"
    max_requests_per_minute: int = 100
    timeout: int = 600
    api_key: str = "EMPTY"

    # Generation settings.
    max_new_tokens: int = 1024
    temperature: float = 1.0


# Backend configuration mapping: backend name -> configuration class.
# The values are the classes themselves (not instances), so a caller that looks
# up a name here still has to instantiate the class it gets back.
BACKEND_CONFIGS = {
    "transformers": TransformersConfig,
    "vllm": VLLMConfig,
    "async_vllm": AsyncVLLMConfig,
    "verl": VerlConfig,
    "async_verl": AsyncVerlConfig,
    "client": ClientConfig,
}
Loading