From 6977ceabea97336ebcfac5dbba90a423e6562a4c Mon Sep 17 00:00:00 2001 From: hinotoi-agent Date: Sun, 10 May 2026 13:53:50 +0800 Subject: [PATCH] fix: require token for API tool server --- autoagent/server.py | 76 ++++++++++++++++++++++++++++--------- tests/test_server_auth.py | 80 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 18 deletions(-) create mode 100644 tests/test_server_auth.py diff --git a/autoagent/server.py b/autoagent/server.py index c82e108..2990c52 100644 --- a/autoagent/server.py +++ b/autoagent/server.py @@ -1,13 +1,33 @@ -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel +import os from contextlib import asynccontextmanager -from typing import Dict, Any, Optional, List -from autoagent.registry import registry +from typing import Any, Dict, List, Optional + +from fastapi import Depends, FastAPI, HTTPException, Request, status +from pydantic import BaseModel + from autoagent import MetaChain +from autoagent.registry import registry from autoagent.types import Agent, Response -import importlib import inspect + +def api_token() -> str: + token = os.environ.get("AUTOAGENT_API_TOKEN", "").strip() + if not token: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="AUTOAGENT_API_TOKEN must be set before using the API server", + ) + return token + + +async def require_api_token(request: Request, token: str = Depends(api_token)) -> None: + auth = request.headers.get("authorization", "").strip() + prefix = "Bearer " + if not auth.startswith(prefix) or auth[len(prefix) :].strip() != token: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="valid bearer token required") + + # 定义lifespan上下文管理器 @asynccontextmanager async def lifespan(app: FastAPI): @@ -17,24 +37,36 @@ async def lifespan(app: FastAPI): # 关闭时执行 # 清理代码(如果需要) -app = FastAPI(title="MetaChain API", lifespan=lifespan) + +app = FastAPI(title="MetaChain API", lifespan=lifespan, dependencies=[Depends(require_api_token)]) + class ToolRequest(BaseModel): args: Dict[str, Any] + class AgentRequest(BaseModel): model: str query: str context_variables: Optional[Dict[str, Any]] = {} + class Message(BaseModel): role: str content: str + class AgentResponse(BaseModel): result: str messages: List agent_name: str + + +@app.get("/health", dependencies=[]) +async def health(): + return {"status": "ok"} + + # 为所有注册的tools创建endpoints @app.on_event("startup") def create_tool_endpoints(): @@ -48,7 +80,7 @@ async def create_tool_endpoint(request: ToolRequest, func=tool_func): name for name, param in sig.parameters.items() if param.default == inspect.Parameter.empty } - + # 验证是否提供了所有必需参数 if not all(param in request.args for param in required_params): missing = required_params - request.args.keys() @@ -56,36 +88,38 @@ async def create_tool_endpoint(request: ToolRequest, func=tool_func): status_code=400, detail=f"Missing required parameters: {missing}" ) - + result = func(**request.args) return {"status": "success", "result": result} except Exception as e: raise HTTPException(status_code=400, detail=str(e)) - + # 添加endpoint到FastAPI应用 endpoint = create_tool_endpoint endpoint.__name__ = f"tool_{tool_name}" app.post(f"/tools/{tool_name}")(endpoint) + + # 重写agent endpoints创建逻辑 @app.on_event("startup") def create_agent_endpoints(): for agent_name, agent_func in registry.agents.items(): async def create_agent_endpoint( - request: AgentRequest, + request: AgentRequest, func=agent_func ) -> AgentResponse: try: # 创建agent实例 agent = func(model=request.model) - + # 创建MetaChain实例 mc = MetaChain() - + # 构建messages messages = [ {"role": "user", "content": request.query} ] - + # 运行agent response = mc.run( agent=agent, @@ -93,23 +127,24 @@ async def create_agent_endpoint( context_storage=request.context_variables, debug=True ) - + return AgentResponse( result=response.messages[-1]['content'], messages=response.messages, agent_name=agent.name ) - + except Exception as e: raise HTTPException( status_code=400, detail=f"Agent execution failed: {str(e)}" ) - + endpoint = create_agent_endpoint endpoint.__name__ = f"agent_{agent_name}" app.post(f"/agents/{agent_name}/run")(endpoint) + # 获取所有可用的agents信息 @app.get("/agents") async def list_agents(): @@ -122,6 +157,7 @@ async def list_agents(): for name, info in registry.agents_info.items() } + # 获取特定agent的详细信息 @app.get("/agents/{agent_name}") async def get_agent_info(agent_name: str): @@ -130,7 +166,7 @@ async def get_agent_info(agent_name: str): status_code=404, detail=f"Agent {agent_name} not found" ) - + info = registry.agents_info[agent_name] return { "name": agent_name, @@ -139,6 +175,10 @@ async def get_agent_info(agent_name: str): "file_path": info.file_path } + if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file + + host = os.environ.get("AUTOAGENT_API_HOST", "127.0.0.1") + port = int(os.environ.get("AUTOAGENT_API_PORT", "8000")) + uvicorn.run(app, host=host, port=port) diff --git a/tests/test_server_auth.py b/tests/test_server_auth.py new file mode 100644 index 0000000..477842e --- /dev/null +++ b/tests/test_server_auth.py @@ -0,0 +1,80 @@ +import importlib.util +import sys +import types +from pathlib import Path + +from fastapi.testclient import TestClient + + +def load_server(monkeypatch): + calls = [] + + def fake_tool(command, context_variables=None): + calls.append((command, context_variables)) + return {"status": 0, "result": command} + + registry = types.SimpleNamespace( + tools={"execute_command": fake_tool}, + agents={}, + agents_info={}, + ) + mod_registry = types.ModuleType("autoagent.registry") + mod_registry.registry = registry + mod_autoagent = types.ModuleType("autoagent") + mod_autoagent.MetaChain = object + mod_types = types.ModuleType("autoagent.types") + mod_types.Agent = object + mod_types.Response = object + monkeypatch.setitem(sys.modules, "autoagent.registry", mod_registry) + monkeypatch.setitem(sys.modules, "autoagent", mod_autoagent) + monkeypatch.setitem(sys.modules, "autoagent.types", mod_types) + + server_path = Path(__file__).resolve().parents[1] / "autoagent" / "server.py" + spec = importlib.util.spec_from_file_location("autoagent_server_under_test", server_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + module.create_tool_endpoints() + return module, calls + + +def test_tool_endpoint_requires_api_token(monkeypatch): + monkeypatch.setenv("AUTOAGENT_API_TOKEN", "secret-token") + module, calls = load_server(monkeypatch) + client = TestClient(module.app) + + response = client.post( + "/tools/execute_command", + json={"args": {"command": "id", "context_variables": {}}}, + ) + + assert response.status_code == 401 + assert calls == [] + + +def test_tool_endpoint_accepts_valid_bearer_token(monkeypatch): + monkeypatch.setenv("AUTOAGENT_API_TOKEN", "secret-token") + module, calls = load_server(monkeypatch) + client = TestClient(module.app) + + response = client.post( + "/tools/execute_command", + headers={"authorization": "Bearer secret-token"}, + json={"args": {"command": "id", "context_variables": {}}}, + ) + + assert response.status_code == 200 + assert calls == [("id", {})] + + +def test_api_refuses_to_run_when_token_is_unconfigured(monkeypatch): + monkeypatch.delenv("AUTOAGENT_API_TOKEN", raising=False) + module, calls = load_server(monkeypatch) + client = TestClient(module.app) + + response = client.post( + "/tools/execute_command", + json={"args": {"command": "id", "context_variables": {}}}, + ) + + assert response.status_code == 503 + assert calls == []