Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 58 additions & 18 deletions autoagent/server.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from contextlib import asynccontextmanager
from typing import Dict, Any, Optional, List
from autoagent.registry import registry
from typing import Any, Dict, List, Optional

from fastapi import Depends, FastAPI, HTTPException, Request, status
from pydantic import BaseModel

from autoagent import MetaChain
from autoagent.registry import registry
from autoagent.types import Agent, Response
import importlib
import inspect


def api_token() -> str:
token = os.environ.get("AUTOAGENT_API_TOKEN", "").strip()
if not token:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="AUTOAGENT_API_TOKEN must be set before using the API server",
)
return token


async def require_api_token(request: Request, token: str = Depends(api_token)) -> None:
auth = request.headers.get("authorization", "").strip()
prefix = "Bearer "
if not auth.startswith(prefix) or auth[len(prefix) :].strip() != token:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="valid bearer token required")


# 定义lifespan上下文管理器
@asynccontextmanager
async def lifespan(app: FastAPI):
Expand All @@ -17,24 +37,36 @@ async def lifespan(app: FastAPI):
# 关闭时执行
# 清理代码(如果需要)

app = FastAPI(title="MetaChain API", lifespan=lifespan)

app = FastAPI(title="MetaChain API", lifespan=lifespan, dependencies=[Depends(require_api_token)])


class ToolRequest(BaseModel):
args: Dict[str, Any]


class AgentRequest(BaseModel):
model: str
query: str
context_variables: Optional[Dict[str, Any]] = {}


class Message(BaseModel):
role: str
content: str


class AgentResponse(BaseModel):
result: str
messages: List
agent_name: str


@app.get("/health", dependencies=[])
async def health():
return {"status": "ok"}


# 为所有注册的tools创建endpoints
@app.on_event("startup")
def create_tool_endpoints():
Expand All @@ -48,68 +80,71 @@ async def create_tool_endpoint(request: ToolRequest, func=tool_func):
name for name, param in sig.parameters.items()
if param.default == inspect.Parameter.empty
}

# 验证是否提供了所有必需参数
if not all(param in request.args for param in required_params):
missing = required_params - request.args.keys()
raise HTTPException(
status_code=400,
detail=f"Missing required parameters: {missing}"
)

result = func(**request.args)
return {"status": "success", "result": result}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

# 添加endpoint到FastAPI应用
endpoint = create_tool_endpoint
endpoint.__name__ = f"tool_{tool_name}"
app.post(f"/tools/{tool_name}")(endpoint)


# 重写agent endpoints创建逻辑
@app.on_event("startup")
def create_agent_endpoints():
for agent_name, agent_func in registry.agents.items():
async def create_agent_endpoint(
request: AgentRequest,
request: AgentRequest,
func=agent_func
) -> AgentResponse:
try:
# 创建agent实例
agent = func(model=request.model)

# 创建MetaChain实例
mc = MetaChain()

# 构建messages
messages = [
{"role": "user", "content": request.query}
]

# 运行agent
response = mc.run(
agent=agent,
messages=messages,
context_storage=request.context_variables,
debug=True
)

return AgentResponse(
result=response.messages[-1]['content'],
messages=response.messages,
agent_name=agent.name
)

except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Agent execution failed: {str(e)}"
)

endpoint = create_agent_endpoint
endpoint.__name__ = f"agent_{agent_name}"
app.post(f"/agents/{agent_name}/run")(endpoint)


# 获取所有可用的agents信息
@app.get("/agents")
async def list_agents():
Expand All @@ -122,6 +157,7 @@ async def list_agents():
for name, info in registry.agents_info.items()
}


# 获取特定agent的详细信息
@app.get("/agents/{agent_name}")
async def get_agent_info(agent_name: str):
Expand All @@ -130,7 +166,7 @@ async def get_agent_info(agent_name: str):
status_code=404,
detail=f"Agent {agent_name} not found"
)

info = registry.agents_info[agent_name]
return {
"name": agent_name,
Expand All @@ -139,6 +175,10 @@ async def get_agent_info(agent_name: str):
"file_path": info.file_path
}


if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

host = os.environ.get("AUTOAGENT_API_HOST", "127.0.0.1")
port = int(os.environ.get("AUTOAGENT_API_PORT", "8000"))
uvicorn.run(app, host=host, port=port)
80 changes: 80 additions & 0 deletions tests/test_server_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import importlib.util
import sys
import types
from pathlib import Path

from fastapi.testclient import TestClient


def load_server(monkeypatch):
calls = []

def fake_tool(command, context_variables=None):
calls.append((command, context_variables))
return {"status": 0, "result": command}

registry = types.SimpleNamespace(
tools={"execute_command": fake_tool},
agents={},
agents_info={},
)
mod_registry = types.ModuleType("autoagent.registry")
mod_registry.registry = registry
mod_autoagent = types.ModuleType("autoagent")
mod_autoagent.MetaChain = object
mod_types = types.ModuleType("autoagent.types")
mod_types.Agent = object
mod_types.Response = object
monkeypatch.setitem(sys.modules, "autoagent.registry", mod_registry)
monkeypatch.setitem(sys.modules, "autoagent", mod_autoagent)
monkeypatch.setitem(sys.modules, "autoagent.types", mod_types)

server_path = Path(__file__).resolve().parents[1] / "autoagent" / "server.py"
spec = importlib.util.spec_from_file_location("autoagent_server_under_test", server_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
module.create_tool_endpoints()
return module, calls


def test_tool_endpoint_requires_api_token(monkeypatch):
monkeypatch.setenv("AUTOAGENT_API_TOKEN", "secret-token")
module, calls = load_server(monkeypatch)
client = TestClient(module.app)

response = client.post(
"/tools/execute_command",
json={"args": {"command": "id", "context_variables": {}}},
)

assert response.status_code == 401
assert calls == []


def test_tool_endpoint_accepts_valid_bearer_token(monkeypatch):
monkeypatch.setenv("AUTOAGENT_API_TOKEN", "secret-token")
module, calls = load_server(monkeypatch)
client = TestClient(module.app)

response = client.post(
"/tools/execute_command",
headers={"authorization": "Bearer secret-token"},
json={"args": {"command": "id", "context_variables": {}}},
)

assert response.status_code == 200
assert calls == [("id", {})]


def test_api_refuses_to_run_when_token_is_unconfigured(monkeypatch):
monkeypatch.delenv("AUTOAGENT_API_TOKEN", raising=False)
module, calls = load_server(monkeypatch)
client = TestClient(module.app)

response = client.post(
"/tools/execute_command",
json={"args": {"command": "id", "context_variables": {}}},
)

assert response.status_code == 503
assert calls == []