diff --git a/README.md b/README.md index 28d2814..9f63f6a 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ - [A2A](#a2a-agent-to-agent) - [Retrieval (RAG)](#retrieval-rag) - [Sub-agents](#sub-agents) + - [Managing Capabilities at Runtime](#managing-capabilities-at-runtime) - [Approvals](#approvals) - [Timeouts and deadlines](#timeouts-and-deadlines) - [Custom tools](#custom-tools) @@ -347,11 +348,13 @@ Custom tools may also implement: - `interfaces.ToolApproval` — tool-level hint for **interactive human approval**. Use this when a person should decide whether the tool runs, and no agent-level approval policy is set. - `interfaces.ToolAuthorizer` — tool-level **programmatic authorization**. Use this when code should decide whether the tool runs before approval/execute (for example: scopes, tenancy, environment flags, or feature access). Return `Allow=false` to deny the tool call without executing it. +- `interfaces.ToolKindProvider` — optional interface that reports the tool's origin category. The built-in tool wrappers already implement it (`"mcp"`, `"a2a"`, `"sub-agent"`, `"retriever"`). Implement it on custom tools when you want to distinguish origin in logs or metrics. Use `interfaces.KindOf(tool)` to read the kind from any tool; returns `"native"` when the interface is not implemented. ```go -reg := tools.NewRegistry() -reg.Register(calculator.New()) -reg.Register(weather.New()) +reg := agent.NewToolRegistry() +if err := agent.RegisterTools(reg, calculator.New(), weather.New()); err != nil { + log.Fatal(err) +} a, _ := agent.NewAgent( agent.WithTemporalConfig(...), @@ -372,7 +375,7 @@ result, _ := a.Run(ctx, "What's the weather in Tokyo?", nil) MCP servers extend your agent with external tools that work identically to built-in tools across `Run`, `Stream`, `RunAsync`, and approval gates. 
Each server needs a **unique** name in config (the `WithMCPConfig` map key or the first argument to `mcpclient.NewClient`); tools are registered under stable names so they do not collide when several servers expose the same logical tool id. -At `NewAgent`, the SDK connects to each server, discovers its tools, applies any `**ToolFilter`** (`AllowTools`/`BlockTools`), and registers the results — failing fast if a server is unreachable. +At `NewAgent`, the SDK connects to each server, discovers its tools, applies any `**ToolFilter`** (`AllowTools`/`BlockTools`), and validates the setup — failing fast if a server is unreachable. After creation, add or remove MCP servers at any time via `a.MCPRegistry()` (see [Managing Capabilities at Runtime](#managing-capabilities-at-runtime)); the next `Run`, `Stream`, or `RunAsync` uses whatever servers are in the registry at that point. Use `mcp.MCPStdio` (local process) or `mcp.MCPStreamableHTTP` (remote) from `pkg/mcp` for transport. Streamable HTTP supports `Token`, `OAuthClientCreds`, custom `Headers`, and `SkipTLSVerify` for local HTTPS. You can register multiple servers per agent with different transports, timeouts, retries, and filters per server. @@ -516,7 +519,7 @@ if err := a.RunA2A(ctx); err != nil { Remote [A2A](https://github.com/a2aproject/A2A) agents connect as tool providers: the SDK fetches the agent card, discovers skills, and registers each skill as a first-class tool available to the LLM across `Run`, `Stream`, `RunAsync`, and approval gates. Each server entry needs a **unique** name (the `WithA2AConfig` map key or the first argument to `a2aclient.NewClient`); tools are registered under stable names (`a2a__`) that do not collide across multiple remote agents. -At `NewAgent`, the SDK resolves the agent card, applies any `**SkillFilter`** (`AllowSkills`/`BlockSkills`), and registers the resulting tools — failing fast if a server is unreachable. 
+At `NewAgent`, the SDK resolves the agent card, applies any `**SkillFilter`** (`AllowSkills`/`BlockSkills`), and validates the setup — failing fast if a server is unreachable. After creation, add or remove A2A agents at any time via `a.A2ARegistry()` (see [Managing Capabilities at Runtime](#managing-capabilities-at-runtime)); the next run uses whatever agents are in the registry at that point. Configure auth, timeout, and skill filtering per server entry. `SkipTLSVerify` is available for local HTTPS development only. @@ -701,6 +704,66 @@ result, _ := mainAgent.Run(ctx, "What is 144 divided by 12?", nil) **Stream event fan-in:** Subscribe once on the main agent; the stream includes the full tree (tool events, `**AgentEventTypeCustom`** for approvals/delegation, optional `**AgentEventTypeStepStarted` / `AgentEventTypeStepFinished**` around sub-agent runs, `**AgentEventTypeRunFinished**`, etc.). For each event, use `**ev.Type()**` and type-assert to the concrete struct (see [examples/agent_with_stream](examples/agent_with_stream), [examples/agent_with_subagents](examples/agent_with_subagents)). For `**CUSTOM**`, assert `***AgentCustomEvent**`, then `[ParseCustomEventApproval](pkg/agent/event.go)` or `[ParseCustomEventDelegation](pkg/agent/event.go)` to read `**AgentName**`, `**ApprovalToken**`, `**ToolName**` or `**SubAgentName**`, and call `[OnApproval](pkg/agent/approval.go)` with the token. +### Managing Capabilities at Runtime + +All capabilities are resolved from their respective registries at execution time. Each `Run`, `Stream`, or `RunAsync` picks up the current registry state at call time — no restart needed. 
+ +| What you want to change | Accessor | Methods | +|---|---|---| +| Native / custom tools | `a.ToolRegistry()` | `Register(tool)` · `Unregister(name)` | +| MCP servers | `a.MCPRegistry()` | `Register(name, config)` · `RegisterClient(cl)` · `Unregister(name)` | +| A2A remote agents | `a.A2ARegistry()` | `Register(name, config)` · `RegisterClient(cl)` · `Unregister(name)` | +| Specialist sub-agents | `a.SubAgentRegistry()` | `Register(sub)` · `Unregister(name)` | + +All registries are safe for concurrent use. Name uniqueness is enforced — `Register` on a name already in the registry returns `agent.ErrRegistryDuplicate`. To update an existing entry, call `Unregister` first then `Register` with the new configuration. + +**Tools example** + +```go +a, _ := agent.NewAgent(agent.WithToolRegistry(reg), ...) + +// first run — only tools already in reg +result, _ := a.Run(ctx, "What is 17 * 23?", nil) + +// add a tool before the next run +_ = a.ToolRegistry().Register(calculator.New()) +result, _ = a.Run(ctx, "What is 17 * 23?", nil) // now has calculator + +// remove it again +_ = a.ToolRegistry().Unregister("calculator") +``` + +**MCP example** — add a new MCP server after the agent is already running: + +```go +mcpReg := a.MCPRegistry() +cl, _ := mcpclient.NewClient("extra-server", mcp.MCPStreamableHTTP{URL: "https://..."}) +_ = mcpReg.RegisterClient(cl) +// next run includes tools from extra-server +result, _ := a.Run(ctx, prompt, nil) +``` + +**Sub-agent example** — attach a specialist to a running main agent: + +```go +math, _ := agent.NewAgent(agent.WithName("Math"), ...) +_ = a.SubAgentRegistry().Register(math) +// next run can delegate to Math +result, _ := a.Run(ctx, "What is 144 / 12?", nil) + +// revoke delegation +_ = a.SubAgentRegistry().Unregister("Math") +``` + +See [examples/agent_with_tools/dynamic_registry](examples/agent_with_tools/dynamic_registry) for a runnable tool registration example. 
+ +**When to use static vs dynamic setup** + +- **Static** (`WithTools`, `WithMCPConfig`, `WithSubAgents`, etc.) — all options resolved at `NewAgent`; fastest, most straightforward for fixed configurations. +- **Dynamic** (registries after `NewAgent`) — needed when the capability set changes based on tenant, user session, runtime feature flags, or incremental deployment. + +Both styles can be combined: pass initial tools via `WithTools` or `WithToolRegistry` at creation, then add or remove via the registry later. + ### Approvals The model can trigger registry tools (`WithTools` / registry), MCP tools, and delegation to specialists (`WithSubAgents`). **User approval** can be required before any of those run. `WithToolApprovalPolicy` is the one setting that governs all of them. If you omit it, the default is **require-all**—each path goes through your approval handler. For `Run`, set `WithApprovalHandler` whenever approvals can occur. See [examples/agent_with_subagents](examples/agent_with_subagents). @@ -745,7 +808,7 @@ Math agent: WithToolApprovalPolicy(Auto) → calculator inside speciali Math agent: WithToolApprovalPolicy(RequireAll) → calculator inside specialist → approval (fan-in on main stream) ``` -Each `ApprovalRequest` includes `Respond`; call `req.Respond(Approved|Rejected)` when ready (same as RunAsync): +Each `ApprovalRequest` includes `Respond`; call `req.Respond(Approved|Rejected)` in `WithApprovalHandler`: ```go a, _ := agent.NewAgent( @@ -777,21 +840,20 @@ for ev := range eventCh { } ``` -**RunAsync** — channel-based completion without streaming. Do not set `WithApprovalHandler` for this path (it is replaced for the duration of the run). Receive each pending approval on `approvalCh` and call `req.Respond` (same idea as `WithApprovalHandler`): +**RunAsync** — starts the run in a goroutine and returns `resultCh`, which delivers one `AgentRunAsyncResult` when the run finishes (including after any tool approvals). 
Use `resultCh` for the final response; handle tool and sub-agent approvals with `WithApprovalHandler`, the same callback as `Run`. ```go -resultCh, approvalCh, err := a.RunAsync(ctx, prompt, nil) -if err != nil { /* validation error before goroutine started */ } - -go func() { - for req := range approvalCh { - _ = req.Respond(agent.ApprovalStatusApproved) // or Rejected - } -}() +a, _ := agent.NewAgent( + agent.WithApprovalHandler(func(ctx context.Context, req *agent.ApprovalRequest) { + _ = req.Respond(agent.ApprovalStatusApproved) + }), + // ... +) +resultCh, _ := a.RunAsync(ctx, prompt, nil) res := <-resultCh -if res.Err != nil { /* handle */ } -// res.Response.Content +if res.Error != nil { /* handle */ } +fmt.Println(res.Result.Content) ``` For **Run** / **RunAsync**, use `req.Respond` only. For **Stream**, use `**OnApproval`** as in the snippet above—the activity token string is `**ApprovalToken**` from `**ParseCustomEventApproval**` / `**ParseCustomEventDelegation**` (not a field on the `**AgentEvent**` interface). @@ -804,7 +866,7 @@ For **Run** / **RunAsync**, use `req.Respond` only. For **Stream**, use `**OnApp - **Run:** `Run()` returns `nil, err` with the failure. - **Stream:** An `AgentEventError` is emitted on the event channel with the error message. -- **RunAsync:** `resultCh` receives `RunAsyncResult` with `Err` set. +- **RunAsync:** `resultCh` receives `AgentRunAsyncResult` with `Error` set. 
### Timeouts and deadlines @@ -1113,7 +1175,8 @@ a, _ := agent.NewAgent( | Span | Emitted by | |---|---| -| `agent.run` | `Agent.Run` / `Agent.RunAsync` | +| `agent.run` | `Agent.Run` | +| `agent.run.async` | `Agent.RunAsync` | | `agent.stream` | `Agent.Stream` (dispatch phase) | | `a2a.execute` | A2A server executor per request | | `llm.generate` | `AgentLLMActivity` (sync LLM call) | diff --git a/benchmarks/setup/mock_tool.go b/benchmarks/setup/mock_tool.go index 6dde63c..5d76c6c 100644 --- a/benchmarks/setup/mock_tool.go +++ b/benchmarks/setup/mock_tool.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand" + "github.com/agenticenv/agent-sdk-go/pkg/agent" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/tools" ) @@ -50,10 +51,12 @@ func (t *MockBenchmarkTool) Execute(ctx context.Context, args map[string]any) (a return map[string]any{"tool": t.name, "input": input, "status": "ok"}, nil } -func RegisterBenchmarkTools(count int, cfg ToolConfig, rng *rand.Rand) *tools.Registry { - reg := tools.NewRegistry() +func RegisterBenchmarkTools(count int, cfg ToolConfig, rng *rand.Rand) agent.ToolRegistry { + reg := agent.NewToolRegistry() for i := 1; i <= count; i++ { - reg.Register(NewMockBenchmarkTool(i, cfg, rng)) + if err := reg.Register(NewMockBenchmarkTool(i, cfg, rng)); err != nil { + panic(err) + } } return reg } diff --git a/cmd/main.go b/cmd/main.go index b231bce..d635812 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -14,7 +14,6 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/agent" "github.com/agenticenv/agent-sdk-go/pkg/conversation/inmem" - "github.com/agenticenv/agent-sdk-go/pkg/tools" "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" "github.com/agenticenv/agent-sdk-go/pkg/tools/currenttime" "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" @@ -58,15 +57,18 @@ func main() { log.Fatalf("failed to create LLM client: %v", err) } - reg := tools.NewRegistry() - 
reg.Register(echo.New()) - reg.Register(currenttime.New()) - reg.Register(random.New()) - reg.Register(calculator.New()) - reg.Register(weather.New()) - reg.Register(wikipedia.New()) - reg.Register(search.New()) - + reg := agent.NewToolRegistry() + if err := agent.RegisterTools(reg, + echo.New(), + currenttime.New(), + random.New(), + calculator.New(), + weather.New(), + wikipedia.New(), + search.New(), + ); err != nil { + log.Fatalf("register tools: %v", err) + } mcpServers, err := BuildMCPServers(cfg) if err != nil { log.Fatalf("mcp config: %v", err) diff --git a/examples/README.md b/examples/README.md index 3dc41fb..e8c654c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -29,10 +29,11 @@ These examples run with `AGENT_RUNTIME=local` (default) or `AGENT_RUNTIME=tempor | `agent_with_tools/approval` | Tools + `WithApprovalHandler` — user approves or rejects each tool run (`Run` only) | — | | `agent_with_tools/authorizer` | Custom tool authorization via `interfaces.ToolAuthorizer` — denied calls surface as `tool_result` with `denied` status | — | | `agent_with_tools/custom` | Custom tools via `WithTools` — implementing `interfaces.Tool` | — | +| `agent_with_tools/dynamic_registry` | Register a tool on a live agent between two runs — shows `ToolRegistry().Register` changing what the model can call without restarting | — | | `agent_with_stream` | Streaming with `Stream` — **`TEXT_MESSAGE_*`**, **`TOOL_CALL_*`**, **`RUN_FINISHED`**; prints token usage from **`RUN_FINISHED`** result when present | — | | `agent_with_agui` | Go **`POST /agui` SSE** + **Next.js + CopilotKit** ([`agent_with_agui/README.md`](agent_with_agui/README.md)) — agent server, then `ui/` dev server | UI manual (`npm run dev` in `ui/`) | | `agent_with_stream_conversation` | Stream + conversation; avoid printing the same text twice (**`TEXT_MESSAGE_CONTENT`** deltas vs **`RUN_FINISHED`** body) | — | -| `agent_with_run_async` | `RunAsync` — `resultCh` + `approvalCh`; use `req.Respond` (no 
`WithApprovalHandler`) | — | +| `agent_with_run_async` | `RunAsync` — `resultCh`; `WithApprovalHandler` for approvals (same as `Run`) | — | | `multiple_agents` | Multiple agents with `WithInstanceId` — sequential or concurrent | — | | `agent_with_subagents` | Main agent + math specialist — `WithSubAgents`; prints **`STEP_STARTED` / `STEP_FINISHED`** (sub-agent name) around each child run when using `Stream` | — | | `agent_with_json_response` | Structured LLM output — `WithResponseFormat` + `interfaces.JSONSchema` (JSON with schema; no tools) | — | @@ -99,6 +100,7 @@ go run ./agent_with_tools/basic "What's the weather in Tokyo?" go run ./agent_with_tools/approval "What is 15 + 27?" go run ./agent_with_tools/authorizer "Get the protected note for roadmap." go run ./agent_with_tools/custom "Reverse 'hello world'" +go run ./agent_with_tools/dynamic_registry ``` ### Streaming (partial content as tokens arrive) diff --git a/examples/agent_with_a2a_server/main.go b/examples/agent_with_a2a_server/main.go index fb9bad5..6c4a3ec 100644 --- a/examples/agent_with_a2a_server/main.go +++ b/examples/agent_with_a2a_server/main.go @@ -12,7 +12,6 @@ import ( "github.com/a2aproject/a2a-go/v2/a2asrv" config "github.com/agenticenv/agent-sdk-go/examples" "github.com/agenticenv/agent-sdk-go/pkg/agent" - "github.com/agenticenv/agent-sdk-go/pkg/tools" "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" ) @@ -24,9 +23,12 @@ func main() { log.Fatalf("failed to create LLM client: %v", err) } - reg := tools.NewRegistry() - reg.Register(echo.New()) - + reg := agent.NewToolRegistry() + if err := agent.RegisterTools(reg, + echo.New(), + ); err != nil { + log.Fatalf("register tools: %v", err) + } opts := []agent.Option{ agent.WithName("agent-with-a2a-server"), agent.WithDescription("Example agent exposed as an A2A HTTP server (agent card + JSON-RPC)."), diff --git a/examples/agent_with_agui/server/main.go b/examples/agent_with_agui/server/main.go index 6beb785..345193c 100644 --- 
a/examples/agent_with_agui/server/main.go +++ b/examples/agent_with_agui/server/main.go @@ -10,7 +10,6 @@ import ( config "github.com/agenticenv/agent-sdk-go/examples" "github.com/agenticenv/agent-sdk-go/pkg/agent" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" - "github.com/agenticenv/agent-sdk-go/pkg/tools" "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" ) @@ -45,10 +44,13 @@ func main() { log.Fatalf("LLM client: %v", err) } - reg := tools.NewRegistry() - reg.Register(echo.New()) - reg.Register(calculator.New()) - + reg := agent.NewToolRegistry() + if err := agent.RegisterTools(reg, + echo.New(), + calculator.New(), + ); err != nil { + log.Fatalf("register tools: %v", err) + } agentOpts := []agent.Option{ agent.WithName("agui-demo-agent"), agent.WithDescription("Streaming demo for AG-UI / CopilotKit"), diff --git a/examples/agent_with_conversation/main.go b/examples/agent_with_conversation/main.go index bd74528..ff503f3 100644 --- a/examples/agent_with_conversation/main.go +++ b/examples/agent_with_conversation/main.go @@ -11,7 +11,6 @@ import ( config "github.com/agenticenv/agent-sdk-go/examples" "github.com/agenticenv/agent-sdk-go/pkg/agent" "github.com/agenticenv/agent-sdk-go/pkg/conversation/redis" - "github.com/agenticenv/agent-sdk-go/pkg/tools" "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" ) @@ -34,10 +33,13 @@ func main() { } defer func() { _ = conv.Close() }() - reg := tools.NewRegistry() - reg.Register(echo.New()) - reg.Register(calculator.New()) - + reg := agent.NewToolRegistry() + if err := agent.RegisterTools(reg, + echo.New(), + calculator.New(), + ); err != nil { + log.Fatalf("register tools: %v", err) + } opts := []agent.Option{ agent.WithName("agent-with-conversation"), agent.WithDescription("Agent with Redis conversation and tools for multi-turn context"), diff --git a/examples/agent_with_observability/setup/setup.go 
b/examples/agent_with_observability/setup/setup.go index 428a70e..5a21e0d 100644 --- a/examples/agent_with_observability/setup/setup.go +++ b/examples/agent_with_observability/setup/setup.go @@ -10,7 +10,6 @@ import ( "github.com/agenticenv/agent-sdk-go/pkg/agent" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/observability" - "github.com/agenticenv/agent-sdk-go/pkg/tools" "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" ) @@ -49,9 +48,10 @@ func MustParseOTLP() OTLP { // BaseAgentOptions returns shared [agent.Option]s for both examples (identity, Temporal, LLM, logger). func BaseAgentOptions(cfg *excfg.Config, llm interfaces.LLMClient) []agent.Option { - reg := tools.NewRegistry() - reg.Register(calculator.New()) - + reg := agent.NewToolRegistry() + if err := agent.RegisterTools(reg, calculator.New()); err != nil { + log.Fatalf("register tools: %v", err) + } opts := []agent.Option{ agent.WithName("observability-example-agent"), agent.WithDescription("Agent demonstrating OTLP wiring (see examples/agent_with_observability)."), diff --git a/examples/agent_with_run_async/main.go b/examples/agent_with_run_async/main.go index 49170b5..df47395 100644 --- a/examples/agent_with_run_async/main.go +++ b/examples/agent_with_run_async/main.go @@ -1,5 +1,5 @@ -// agent_with_run_async demonstrates RunAsync: result and approval channels without -// WithApprovalHandler or Stream. Complete each approval with req.Respond. +// agent_with_run_async demonstrates RunAsync: non-blocking result channel with +// WithApprovalHandler for tool approvals (same as Run). 
package main import ( @@ -10,11 +10,9 @@ import ( "log" "os" "strings" - "sync" config "github.com/agenticenv/agent-sdk-go/examples" "github.com/agenticenv/agent-sdk-go/pkg/agent" - "github.com/agenticenv/agent-sdk-go/pkg/tools" "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" ) @@ -27,10 +25,13 @@ func main() { log.Fatalf("failed to create LLM client: %v", err) } - reg := tools.NewRegistry() - reg.Register(echo.New()) - reg.Register(calculator.New()) - + reg := agent.NewToolRegistry() + if err := agent.RegisterTools(reg, + echo.New(), + calculator.New(), + ); err != nil { + log.Fatalf("register tools: %v", err) + } lineCh := make(chan string) go func() { scanner := bufio.NewScanner(os.Stdin) @@ -42,10 +43,11 @@ func main() { opts := []agent.Option{ agent.WithName("agent-with-run-async"), - agent.WithDescription("RunAsync demo: approvals on approvalCh, outcome on resultCh"), + agent.WithDescription("RunAsync demo: WithApprovalHandler, outcome on resultCh"), agent.WithSystemPrompt("You are a helpful assistant. Use the echo or calculator tool when asked."), agent.WithLLMClient(llmClient), agent.WithToolRegistry(reg), + agent.WithApprovalHandler(makeApprovalHandler(lineCh)), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), } opts = append(opts, config.ToolApprovalOptions()...) @@ -63,43 +65,13 @@ func main() { } ctx := context.Background() - resultCh, approvalCh, err := a.RunAsync(ctx, prompt, nil) + resultCh, err := a.RunAsync(ctx, prompt, nil) if err != nil { log.Fatalf("RunAsync: %v", err) } - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - for req := range approvalCh { - v, err := agent.ParseToolApproval(req) - if err != nil { - log.Printf("approval from RunAsync: %v", err) - continue - } - args := v.Args - if args == nil { - args = map[string]any{} - } - argsJSON, _ := json.MarshalIndent(args, "", " ") - fmt.Printf("\n--- Tool approval required ---\nTool: %s\nArgs:\n%s\nApprove? 
(y/n): ", v.ToolName, string(argsJSON)) - line, ok := <-lineCh - if ok && strings.TrimSpace(strings.ToLower(line)) == "y" { - if err := req.Respond(agent.ApprovalStatusApproved); err != nil { - log.Printf("respond approved: %v", err) - } - } else if ok { - if err := req.Respond(agent.ApprovalStatusRejected); err != nil { - log.Printf("respond rejected: %v", err) - } - } - } - }() - fmt.Println("user:", prompt) res := <-resultCh - wg.Wait() if res.Error != nil { log.Printf("run failed: %v", res.Error) @@ -111,3 +83,29 @@ func main() { } fmt.Println("agent:", res.Result.Content) } + +func makeApprovalHandler(lineCh <-chan string) agent.ApprovalHandler { + return func(ctx context.Context, req *agent.ApprovalRequest) { + v, err := agent.ParseToolApproval(req) + if err != nil { + log.Printf("approval handler: %v", err) + return + } + args := v.Args + if args == nil { + args = map[string]any{} + } + argsJSON, _ := json.MarshalIndent(args, "", " ") + fmt.Printf("\n--- Tool approval required ---\nTool: %s\nArgs:\n%s\nApprove? 
(y/n): ", v.ToolName, string(argsJSON)) + select { + case <-ctx.Done(): + return + case line, ok := <-lineCh: + if ok && strings.TrimSpace(strings.ToLower(line)) == "y" { + _ = req.Respond(agent.ApprovalStatusApproved) + } else if ok { + _ = req.Respond(agent.ApprovalStatusRejected) + } + } + } +} diff --git a/examples/agent_with_stream/main.go b/examples/agent_with_stream/main.go index 11d2baf..fa6a22f 100644 --- a/examples/agent_with_stream/main.go +++ b/examples/agent_with_stream/main.go @@ -12,7 +12,6 @@ import ( "github.com/agenticenv/agent-sdk-go/examples/shared" "github.com/agenticenv/agent-sdk-go/pkg/agent" - "github.com/agenticenv/agent-sdk-go/pkg/tools" "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" "github.com/agenticenv/agent-sdk-go/pkg/tools/currenttime" "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" @@ -30,15 +29,18 @@ func main() { log.Fatalf("failed to create LLM client: %v", err) } - reg := tools.NewRegistry() - reg.Register(echo.New()) - reg.Register(currenttime.New()) - reg.Register(random.New()) - reg.Register(calculator.New()) - reg.Register(weather.New()) - reg.Register(wikipedia.New()) - reg.Register(search.New()) - + reg := agent.NewToolRegistry() + if err := agent.RegisterTools(reg, + echo.New(), + currenttime.New(), + random.New(), + calculator.New(), + weather.New(), + wikipedia.New(), + search.New(), + ); err != nil { + log.Fatalf("register tools: %v", err) + } opts := []agent.Option{ agent.WithName("agent-with-stream"), agent.WithDescription("Agent that streams events via Stream"), diff --git a/examples/agent_with_stream_conversation/main.go b/examples/agent_with_stream_conversation/main.go index ecbe875..1ca885d 100644 --- a/examples/agent_with_stream_conversation/main.go +++ b/examples/agent_with_stream_conversation/main.go @@ -12,7 +12,6 @@ import ( "github.com/agenticenv/agent-sdk-go/examples/shared" "github.com/agenticenv/agent-sdk-go/pkg/agent" "github.com/agenticenv/agent-sdk-go/pkg/conversation/inmem" - 
"github.com/agenticenv/agent-sdk-go/pkg/tools" "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" ) @@ -29,10 +28,13 @@ func main() { conv := inmem.NewInMemoryConversation(inmem.WithMaxSize(100)) - reg := tools.NewRegistry() - reg.Register(echo.New()) - reg.Register(calculator.New()) - + reg := agent.NewToolRegistry() + if err := agent.RegisterTools(reg, + echo.New(), + calculator.New(), + ); err != nil { + log.Fatalf("register tools: %v", err) + } opts := []agent.Option{ agent.WithName("agent-stream-conversation"), agent.WithDescription("Stream with conversation; shows event handling pattern to avoid duplicate output"), diff --git a/examples/agent_with_subagents/main.go b/examples/agent_with_subagents/main.go index d490b78..6d1b0e9 100644 --- a/examples/agent_with_subagents/main.go +++ b/examples/agent_with_subagents/main.go @@ -12,7 +12,6 @@ import ( config "github.com/agenticenv/agent-sdk-go/examples" "github.com/agenticenv/agent-sdk-go/examples/shared" "github.com/agenticenv/agent-sdk-go/pkg/agent" - "github.com/agenticenv/agent-sdk-go/pkg/tools" "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" ) @@ -46,8 +45,10 @@ func main() { mathQueue := baseQueue + "-math-specialist" mainQueue := baseQueue + "-main-agent" - mathReg := tools.NewRegistry() - mathReg.Register(calculator.New()) + mathReg := agent.NewToolRegistry() + if err := mathReg.Register(calculator.New()); err != nil { + log.Fatalf("register tools: %v", err) + } mathAgentOpts := []agent.Option{ agent.WithName("MathSpecialist"), diff --git a/examples/agent_with_tools/approval/main.go b/examples/agent_with_tools/approval/main.go index 352d7c3..22f1e97 100644 --- a/examples/agent_with_tools/approval/main.go +++ b/examples/agent_with_tools/approval/main.go @@ -11,7 +11,6 @@ import ( config "github.com/agenticenv/agent-sdk-go/examples" "github.com/agenticenv/agent-sdk-go/pkg/agent" - "github.com/agenticenv/agent-sdk-go/pkg/tools" 
"github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" ) @@ -24,10 +23,13 @@ func main() { log.Fatalf("failed to create LLM client: %v", err) } - reg := tools.NewRegistry() - reg.Register(echo.New()) - reg.Register(calculator.New()) - + reg := agent.NewToolRegistry() + if err := agent.RegisterTools(reg, + echo.New(), + calculator.New(), + ); err != nil { + log.Fatalf("register tools: %v", err) + } lineCh := make(chan string) go func() { scanner := bufio.NewScanner(os.Stdin) diff --git a/examples/agent_with_tools/basic/main.go b/examples/agent_with_tools/basic/main.go index c898717..2e7958a 100644 --- a/examples/agent_with_tools/basic/main.go +++ b/examples/agent_with_tools/basic/main.go @@ -9,7 +9,6 @@ import ( config "github.com/agenticenv/agent-sdk-go/examples" "github.com/agenticenv/agent-sdk-go/pkg/agent" - "github.com/agenticenv/agent-sdk-go/pkg/tools" "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" "github.com/agenticenv/agent-sdk-go/pkg/tools/currenttime" "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" @@ -27,15 +26,18 @@ func main() { log.Fatalf("failed to create LLM client: %v", err) } - reg := tools.NewRegistry() - reg.Register(echo.New()) - reg.Register(currenttime.New()) - reg.Register(random.New()) - reg.Register(calculator.New()) - reg.Register(weather.New()) - reg.Register(wikipedia.New()) - reg.Register(search.New()) - + reg := agent.NewToolRegistry() + if err := agent.RegisterTools(reg, + echo.New(), + currenttime.New(), + random.New(), + calculator.New(), + weather.New(), + wikipedia.New(), + search.New(), + ); err != nil { + log.Fatalf("register tools: %v", err) + } opts := []agent.Option{ agent.WithName("agent-with-tools"), agent.WithDescription("Agent with echo, currenttime, random, calculator, weather, wikipedia, search tools"), diff --git a/examples/agent_with_tools/dynamic_registry/main.go b/examples/agent_with_tools/dynamic_registry/main.go new file mode 100644 index 
0000000..3e36e63 --- /dev/null +++ b/examples/agent_with_tools/dynamic_registry/main.go @@ -0,0 +1,69 @@ +package main + +import ( + "context" + "fmt" + "log" + + config "github.com/agenticenv/agent-sdk-go/examples" + "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" + "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" +) + +func main() { + cfg := config.LoadFromEnv() + + llmClient, err := config.NewLLMClientFromConfig(cfg) + if err != nil { + log.Fatalf("failed to create LLM client: %v", err) + } + + reg := agent.NewToolRegistry() + if err := agent.RegisterTools(reg, echo.New()); err != nil { + log.Fatalf("register tools: %v", err) + } + + opts := []agent.Option{ + agent.WithName("dynamic-registry"), + agent.WithDescription("Agent whose tools can change between runs via ToolRegistry"), + agent.WithSystemPrompt("You are a helpful assistant. Use tools when they are available; do not guess numeric results when a calculator tool exists."), + agent.WithLLMClient(llmClient), + agent.WithToolRegistry(reg), + agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), + agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), + } + opts = append(opts, config.RuntimeOption(cfg)...) + + a, err := agent.NewAgent(opts...) + if err != nil { + log.Fatal(config.FormatNewAgentError("failed to create agent", err)) + } + defer a.Close() + + ctx := context.Background() + mathPrompt := "What is 17 times 23? Use the calculator tool if you have it." 
+ + fmt.Println("--- run 1 (echo only) ---") + fmt.Println("user:", mathPrompt) + result, err := a.Run(ctx, mathPrompt, nil) + if err != nil { + log.Printf("run 1 failed: %v", err) + } else { + fmt.Println("agent:", result.Content) + } + + if err := a.ToolRegistry().Register(calculator.New()); err != nil { + log.Fatalf("register calculator: %v", err) + } + fmt.Println("\nregistered calculator on ToolRegistry()") + + fmt.Println("\n--- run 2 (echo + calculator) ---") + fmt.Println("user:", mathPrompt) + result, err = a.Run(ctx, mathPrompt, nil) + if err != nil { + log.Printf("run 2 failed: %v", err) + return + } + fmt.Println("agent:", result.Content) +} diff --git a/internal/runtime/base/runtime.go b/internal/runtime/base/runtime.go index 9c0c0ae..1dc0792 100644 --- a/internal/runtime/base/runtime.go +++ b/internal/runtime/base/runtime.go @@ -22,10 +22,10 @@ import ( // Runtime holds the execution inputs shared by all runtime backends. // Local and Temporal runtimes embed this struct and call its methods directly. type Runtime struct { - AgentSpec runtime.AgentSpec - AgentExecution runtime.AgentExecution - Tracer interfaces.Tracer - Metrics interfaces.Metrics + AgentSpec runtime.AgentSpec + AgentConfig runtime.AgentConfig + Tracer interfaces.Tracer + Metrics interfaces.Metrics // ToolExecutionMode controls whether tool calls in one LLM round are executed // in parallel or sequentially. Defaults to parallel when empty. ToolExecutionMode types.AgentToolExecutionMode @@ -33,9 +33,8 @@ type Runtime struct { // BuildLLMRequest constructs an LLMRequest from the given messages and options. // When retrieverContext is non-empty it is appended to the system prompt (prefetch/hybrid mode). -// Returns the request and the resolved tools slice for later use in response parsing. 
-func (rt *Runtime) BuildLLMRequest(messages []interfaces.Message, skipTools bool, retrieverContext string) (*interfaces.LLMRequest, []interfaces.Tool) { - tools := rt.AgentExecution.Tools.Tools +// tools is the per-run resolved tool list from [runtime.ExecuteRequest] or activity resolve. +func (rt *Runtime) BuildLLMRequest(messages []interfaces.Message, skipTools bool, retrieverContext string, tools []interfaces.Tool) *interfaces.LLMRequest { systemMessage := rt.AgentSpec.SystemPrompt if retrieverContext != "" { systemMessage = fmt.Sprintf("%s\n\nRelevant Context:\n%s", rt.AgentSpec.SystemPrompt, retrieverContext) @@ -45,25 +44,25 @@ func (rt *Runtime) BuildLLMRequest(messages []interfaces.Message, skipTools bool ResponseFormat: rt.AgentSpec.ResponseFormat, Messages: messages, } - ApplyLLMSampling(rt.AgentExecution.LLM.Sampling, req) + ApplyLLMSampling(rt.AgentConfig.LLM.Sampling, req) if skipTools { req.Tools = []interfaces.ToolSpec{} } else { req.Tools = interfaces.ToolsToSpecs(tools) } - return req, tools + return req } // RequiresApproval reports whether t requires human approval before execution. // When no approval policy is configured the tool's own ApprovalRequired flag is used. func (rt *Runtime) RequiresApproval(t interfaces.Tool) bool { - if rt.AgentExecution.Tools.ApprovalPolicy == nil { + if rt.AgentConfig.ToolApprovalPolicy == nil { if ar, ok := t.(interfaces.ToolApproval); ok && ar.ApprovalRequired() { return true } return false } - return rt.AgentExecution.Tools.ApprovalPolicy.RequiresApproval(t) + return rt.AgentConfig.ToolApprovalPolicy.RequiresApproval(t) } // FetchConversationMessages loads prior messages from the conversation store. 
@@ -71,11 +70,11 @@ func (rt *Runtime) RequiresApproval(t interfaces.Tool) bool { func (rt *Runtime) FetchConversationMessages(ctx context.Context, log logger.Logger, conversationID string) ([]interfaces.Message, error) { log.Debug(ctx, "runtime: loading conversation history", slog.String("scope", "runtime"), slog.String("conversationID", conversationID)) - if rt.AgentExecution.Session.Conversation == nil { + if rt.AgentConfig.Session.Conversation == nil { return nil, fmt.Errorf("conversation is not configured") } - limit := rt.AgentExecution.Session.ConversationSize + limit := rt.AgentConfig.Session.ConversationSize if limit <= 0 { limit = 20 } @@ -86,7 +85,7 @@ func (rt *Runtime) FetchConversationMessages(ctx context.Context, log logger.Log ) defer sp.End() - messages, err := rt.AgentExecution.Session.Conversation.ListMessages(ctx, conversationID, interfaces.WithLimit(limit)) + messages, err := rt.AgentConfig.Session.Conversation.ListMessages(ctx, conversationID, interfaces.WithLimit(limit)) if err != nil { sp.RecordError(err) return nil, fmt.Errorf("failed to list conversation messages: %w", err) @@ -141,11 +140,12 @@ func (rt *Runtime) ExecuteLLM( messages []interfaces.Message, skipTools bool, retrieverContext string, + tools []interfaces.Tool, emit func(events.AgentEvent), ) (*LLMResult, error) { - req, tools := rt.BuildLLMRequest(messages, skipTools, retrieverContext) + req := rt.BuildLLMRequest(messages, skipTools, retrieverContext, tools) - llmClient := rt.AgentExecution.LLM.Client + llmClient := rt.AgentConfig.LLM.Client model := llmClient.GetModel() provider := string(llmClient.GetProvider()) modelAttr := interfaces.Attribute{Key: types.MetricAttrModel, Value: model} @@ -204,11 +204,12 @@ func (rt *Runtime) ExecuteLLMStream( messages []interfaces.Message, skipTools bool, retrieverContext string, + tools []interfaces.Tool, emit func(events.AgentEvent), ) (*LLMResult, error) { - req, tools := rt.BuildLLMRequest(messages, skipTools, retrieverContext) + req 
:= rt.BuildLLMRequest(messages, skipTools, retrieverContext, tools) - llmClient := rt.AgentExecution.LLM.Client + llmClient := rt.AgentConfig.LLM.Client model := llmClient.GetModel() provider := string(llmClient.GetProvider()) modelAttr := interfaces.Attribute{Key: types.MetricAttrModel, Value: model} @@ -372,10 +373,10 @@ func (rt *Runtime) ExecuteLLMStream( // ExecuteTool finds the named tool and executes it, recording tracing and metrics. // Returns the string representation of the tool result. -func (rt *Runtime) ExecuteTool(ctx context.Context, log logger.Logger, toolName string, args map[string]any) (string, error) { +func (rt *Runtime) ExecuteTool(ctx context.Context, log logger.Logger, tools []interfaces.Tool, toolName string, args map[string]any) (string, error) { log.Debug(ctx, "runtime: tool execute started", slog.String("scope", "runtime"), slog.String("tool", toolName), slog.Int("argCount", len(args))) - tool, ok := FindToolByName(rt.AgentExecution.Tools.Tools, toolName) + tool, ok := FindToolByName(tools, toolName) if !ok { log.Warn(ctx, "runtime: unknown tool", slog.String("scope", "runtime"), slog.String("tool", toolName)) return "", fmt.Errorf("unknown tool: %s", toolName) @@ -408,10 +409,10 @@ func (rt *Runtime) ExecuteTool(ctx context.Context, log logger.Logger, toolName // AuthorizeTool checks programmatic authorization for a tool before approval/execution. // Tools that do not implement interfaces.ToolAuthorizer are allowed by default. 
-func (rt *Runtime) AuthorizeTool(ctx context.Context, log logger.Logger, toolName string, args map[string]any) (AuthorizeResult, error) { +func (rt *Runtime) AuthorizeTool(ctx context.Context, log logger.Logger, tools []interfaces.Tool, toolName string, args map[string]any) (AuthorizeResult, error) { log.Debug(ctx, "runtime: tool authorize started", slog.String("scope", "runtime"), slog.String("tool", toolName), slog.Int("argCount", len(args))) - tool, ok := FindToolByName(rt.AgentExecution.Tools.Tools, toolName) + tool, ok := FindToolByName(tools, toolName) if !ok { log.Warn(ctx, "runtime: unknown tool in authorization", slog.String("scope", "runtime"), slog.String("tool", toolName)) return AuthorizeResult{}, fmt.Errorf("unknown tool: %s", toolName) @@ -453,7 +454,7 @@ func (rt *Runtime) AuthorizeTool(ctx context.Context, log logger.Logger, toolNam // returns a combined document context string for injection into the LLM system prompt. // Partial failures are logged and skipped; all retrievers failing returns an error. func (rt *Runtime) ExecuteRetrievers(ctx context.Context, log logger.Logger, query string) (string, error) { - retrievers := rt.AgentExecution.Retrievers.Retrievers + retrievers := rt.AgentConfig.Retrievers.Retrievers if len(retrievers) == 0 { return "", nil } diff --git a/internal/runtime/base/runtime_test.go b/internal/runtime/base/runtime_test.go index f59e296..db3a882 100644 --- a/internal/runtime/base/runtime_test.go +++ b/internal/runtime/base/runtime_test.go @@ -17,15 +17,15 @@ import ( ) // newTestRuntime returns a Runtime wired with noop tracer/metrics and the provided execution. 
-func newTestRuntime(exec sdkruntime.AgentExecution) *Runtime { +func newTestRuntime(exec sdkruntime.AgentConfig) *Runtime { return &Runtime{ AgentSpec: sdkruntime.AgentSpec{ Name: "test-agent", SystemPrompt: "you are helpful", }, - AgentExecution: exec, - Tracer: observability.DefaultNoopTracer, - Metrics: observability.DefaultNoopMetrics, + AgentConfig: exec, + Tracer: observability.DefaultNoopTracer, + Metrics: observability.DefaultNoopMetrics, } } @@ -50,22 +50,22 @@ func (stubLLMClient) IsStreamSupported() bool { return false } // --- BuildLLMRequest --- func TestBuildLLMRequest_Basic(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{}}, }) msgs := []interfaces.Message{{Role: interfaces.MessageRoleUser, Content: "hello"}} - req, tools := rt.BuildLLMRequest(msgs, false, "") + req := rt.BuildLLMRequest(msgs, false, "", nil) require.Equal(t, "you are helpful", req.SystemMessage) require.Equal(t, msgs, req.Messages) - require.Empty(t, tools) + require.Empty(t, req.Tools) } func TestBuildLLMRequest_WithRetrieverContext(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{}}, }) - req, _ := rt.BuildLLMRequest(nil, false, "extra context") + req := rt.BuildLLMRequest(nil, false, "extra context", nil) require.Contains(t, req.SystemMessage, "you are helpful") require.Contains(t, req.SystemMessage, "extra context") } @@ -77,11 +77,10 @@ func TestBuildLLMRequest_SkipTools(t *testing.T) { tool.EXPECT().Description().Return("").AnyTimes() tool.EXPECT().Parameters().Return(nil).AnyTimes() - rt := newTestRuntime(sdkruntime.AgentExecution{ - LLM: sdkruntime.AgentLLM{Client: stubLLMClient{}}, - Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, + rt := newTestRuntime(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{}}, }) - req, _ 
:= rt.BuildLLMRequest(nil, true, "") + req := rt.BuildLLMRequest(nil, true, "", []interfaces.Tool{tool}) require.Empty(t, req.Tools) } @@ -120,7 +119,7 @@ func (a authorizerToolStub) Authorize(_ context.Context, _ map[string]any) (inte // --- RequiresApproval --- func TestRequiresApproval_NoPolicyToolHasApproval(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{}) + rt := newTestRuntime(sdkruntime.AgentConfig{}) tool := approvalToolStub{name: "t", approvalRequired: true} require.True(t, rt.RequiresApproval(tool)) } @@ -128,14 +127,14 @@ func TestRequiresApproval_NoPolicyToolHasApproval(t *testing.T) { func TestRequiresApproval_NoPolicyToolNoApproval(t *testing.T) { ctrl := gomock.NewController(t) tool := ifmocks.NewMockTool(ctrl) - rt := newTestRuntime(sdkruntime.AgentExecution{}) + rt := newTestRuntime(sdkruntime.AgentConfig{}) require.False(t, rt.RequiresApproval(tool)) } // --- FetchConversationMessages --- func TestFetchConversationMessages_NoConversation(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ Session: sdkruntime.AgentSession{Conversation: nil}, }) _, err := rt.FetchConversationMessages(context.Background(), noopLog(), "conv-1") @@ -149,7 +148,7 @@ func TestFetchConversationMessages_Success(t *testing.T) { msgs := []interfaces.Message{{Role: interfaces.MessageRoleUser, Content: "hi"}} conv.EXPECT().ListMessages(gomock.Any(), "conv-1", gomock.Any()).Return(msgs, nil) - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ Session: sdkruntime.AgentSession{Conversation: conv, ConversationSize: 10}, }) got, err := rt.FetchConversationMessages(context.Background(), noopLog(), "conv-1") @@ -162,7 +161,7 @@ func TestFetchConversationMessages_Error(t *testing.T) { conv := ifmocks.NewMockConversation(ctrl) conv.EXPECT().ListMessages(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("store down")) - rt := 
newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ Session: sdkruntime.AgentSession{Conversation: conv}, }) _, err := rt.FetchConversationMessages(context.Background(), noopLog(), "c") @@ -173,10 +172,8 @@ func TestFetchConversationMessages_Error(t *testing.T) { // --- ExecuteTool --- func TestExecuteTool_UnknownTool(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ - Tools: sdkruntime.AgentTools{Tools: nil}, - }) - _, err := rt.ExecuteTool(context.Background(), noopLog(), "missing", nil) + rt := newTestRuntime(sdkruntime.AgentConfig{}) + _, err := rt.ExecuteTool(context.Background(), noopLog(), nil, "missing", nil) require.Error(t, err) require.Contains(t, err.Error(), "unknown tool") } @@ -187,10 +184,8 @@ func TestExecuteTool_Success(t *testing.T) { tool.EXPECT().Name().Return("calc").AnyTimes() tool.EXPECT().Execute(gomock.Any(), gomock.Any()).Return("42", nil) - rt := newTestRuntime(sdkruntime.AgentExecution{ - Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, - }) - result, err := rt.ExecuteTool(context.Background(), noopLog(), "calc", map[string]any{"x": 1}) + rt := newTestRuntime(sdkruntime.AgentConfig{}) + result, err := rt.ExecuteTool(context.Background(), noopLog(), []interfaces.Tool{tool}, "calc", map[string]any{"x": 1}) require.NoError(t, err) require.Equal(t, "42", result) } @@ -201,10 +196,8 @@ func TestExecuteTool_ToolError(t *testing.T) { tool.EXPECT().Name().Return("fail-tool").AnyTimes() tool.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(nil, errors.New("tool failed")) - rt := newTestRuntime(sdkruntime.AgentExecution{ - Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, - }) - _, err := rt.ExecuteTool(context.Background(), noopLog(), "fail-tool", nil) + rt := newTestRuntime(sdkruntime.AgentConfig{}) + _, err := rt.ExecuteTool(context.Background(), noopLog(), []interfaces.Tool{tool}, "fail-tool", nil) require.Error(t, err) require.Contains(t, err.Error(), "tool 
failed") } @@ -212,8 +205,8 @@ func TestExecuteTool_ToolError(t *testing.T) { // --- AuthorizeTool --- func TestAuthorizeTool_UnknownTool(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{}) - _, err := rt.AuthorizeTool(context.Background(), noopLog(), "ghost", nil) + rt := newTestRuntime(sdkruntime.AgentConfig{}) + _, err := rt.AuthorizeTool(context.Background(), noopLog(), nil, "ghost", nil) require.Error(t, err) } @@ -222,30 +215,24 @@ func TestAuthorizeTool_NoAuthorizer_AllowedByDefault(t *testing.T) { tool := ifmocks.NewMockTool(ctrl) tool.EXPECT().Name().Return("plain").AnyTimes() - rt := newTestRuntime(sdkruntime.AgentExecution{ - Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, - }) - result, err := rt.AuthorizeTool(context.Background(), noopLog(), "plain", nil) + rt := newTestRuntime(sdkruntime.AgentConfig{}) + result, err := rt.AuthorizeTool(context.Background(), noopLog(), []interfaces.Tool{tool}, "plain", nil) require.NoError(t, err) require.True(t, result.Allowed) } func TestAuthorizeTool_Allowed(t *testing.T) { tool := authorizerToolStub{name: "secure", allow: true} - rt := newTestRuntime(sdkruntime.AgentExecution{ - Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, - }) - result, err := rt.AuthorizeTool(context.Background(), noopLog(), "secure", nil) + rt := newTestRuntime(sdkruntime.AgentConfig{}) + result, err := rt.AuthorizeTool(context.Background(), noopLog(), []interfaces.Tool{tool}, "secure", nil) require.NoError(t, err) require.True(t, result.Allowed) } func TestAuthorizeTool_Denied(t *testing.T) { tool := authorizerToolStub{name: "gated", allow: false, reason: "not allowed"} - rt := newTestRuntime(sdkruntime.AgentExecution{ - Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, - }) - result, err := rt.AuthorizeTool(context.Background(), noopLog(), "gated", nil) + rt := newTestRuntime(sdkruntime.AgentConfig{}) + result, err := rt.AuthorizeTool(context.Background(), noopLog(), 
[]interfaces.Tool{tool}, "gated", nil) require.NoError(t, err) require.False(t, result.Allowed) require.Equal(t, "not allowed", result.Reason) @@ -254,7 +241,7 @@ func TestAuthorizeTool_Denied(t *testing.T) { // --- ExecuteRetrievers --- func TestExecuteRetrievers_NoRetrievers(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{}) + rt := newTestRuntime(sdkruntime.AgentConfig{}) got, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "query") require.NoError(t, err) require.Equal(t, "", got) @@ -266,7 +253,7 @@ func TestExecuteRetrievers_AllFail(t *testing.T) { r.EXPECT().Name().Return("r1").AnyTimes() r.EXPECT().Search(gomock.Any(), gomock.Any()).Return(nil, errors.New("down")) - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{r}}, }) _, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "q") @@ -282,7 +269,7 @@ func TestExecuteRetrievers_Success(t *testing.T) { {Content: "doc content", Source: "src", Score: 0.95}, }, nil) - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{r}}, }) got, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "my query") @@ -293,28 +280,28 @@ func TestExecuteRetrievers_Success(t *testing.T) { // --- ExecuteLLM --- func TestExecuteLLM_LLMError(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{err: errors.New("llm unavailable")}}, }) - _, err := rt.ExecuteLLM(context.Background(), noopLog(), "agent", "msg-1", nil, false, "", nil) + _, err := rt.ExecuteLLM(context.Background(), noopLog(), "agent", "msg-1", nil, false, "", nil, nil) require.Error(t, err) require.Contains(t, err.Error(), "llm unavailable") } func TestExecuteLLM_Success_NoTools(t 
*testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ resp: &interfaces.LLMResponse{Content: "hello world"}, }}, }) - result, err := rt.ExecuteLLM(context.Background(), noopLog(), "agent", "msg-1", nil, false, "", nil) + result, err := rt.ExecuteLLM(context.Background(), noopLog(), "agent", "msg-1", nil, false, "", nil, nil) require.NoError(t, err) require.Equal(t, "hello world", result.Content) require.Empty(t, result.ToolCalls) } func TestExecuteLLM_EmitsTextMessageEvents(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ resp: &interfaces.LLMResponse{Content: "response text"}, }}, @@ -325,7 +312,7 @@ func TestExecuteLLM_EmitsTextMessageEvents(t *testing.T) { emitted = append(emitted, ev.Type()) } - _, err := rt.ExecuteLLM(context.Background(), noopLog(), "agent", "msg-1", nil, false, "", emit) + _, err := rt.ExecuteLLM(context.Background(), noopLog(), "agent", "msg-1", nil, false, "", nil, emit) require.NoError(t, err) require.Equal(t, []events.AgentEventType{ events.AgentEventTypeTextMessageStart, @@ -335,18 +322,18 @@ func TestExecuteLLM_EmitsTextMessageEvents(t *testing.T) { } func TestExecuteLLM_NilEmitDoesNotPanic(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ resp: &interfaces.LLMResponse{Content: "ok"}, }}, }) require.NotPanics(t, func() { - _, _ = rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil) + _, _ = rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil, nil) }) } func TestExecuteLLM_UnknownToolCallReturnsError(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ resp: 
&interfaces.LLMResponse{ Content: "", @@ -356,13 +343,13 @@ func TestExecuteLLM_UnknownToolCallReturnsError(t *testing.T) { }, }}, }) - _, err := rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil) + _, err := rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil, nil) require.Error(t, err) require.Contains(t, err.Error(), "unknown tool") } func TestExecuteLLM_WithUsageMetrics(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ resp: &interfaces.LLMResponse{ Content: "ok", @@ -370,7 +357,7 @@ func TestExecuteLLM_WithUsageMetrics(t *testing.T) { }, }}, }) - result, err := rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil) + result, err := rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil, nil) require.NoError(t, err) require.NotNil(t, result.Usage) require.EqualValues(t, 10, result.Usage.PromptTokens) @@ -384,7 +371,7 @@ func TestExecuteLLM_ToolCallWithEmptyDisplayName(t *testing.T) { tool.EXPECT().Parameters().Return(nil).AnyTimes() tool.EXPECT().DisplayName().Return("").AnyTimes() // empty → falls back to tool name - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ resp: &interfaces.LLMResponse{ ToolCalls: []*interfaces.ToolCall{ @@ -392,16 +379,15 @@ func TestExecuteLLM_ToolCallWithEmptyDisplayName(t *testing.T) { }, }, }}, - Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, }) - result, err := rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil) + result, err := rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", []interfaces.Tool{tool}, nil) require.NoError(t, err) require.Len(t, result.ToolCalls, 1) require.Equal(t, "my-tool", result.ToolCalls[0].ToolDisplayName) } func 
TestExecuteLLM_NilToolCallInResponse(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ resp: &interfaces.LLMResponse{ Content: "answer", @@ -409,7 +395,7 @@ func TestExecuteLLM_NilToolCallInResponse(t *testing.T) { }, }}, }) - result, err := rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil) + result, err := rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil, nil) require.NoError(t, err) require.Empty(t, result.ToolCalls) } @@ -422,8 +408,8 @@ func TestRequiresApproval_PolicyOverrides(t *testing.T) { policy := ifmocks.NewMockAgentToolApprovalPolicy(ctrl) policy.EXPECT().RequiresApproval(tool).Return(true) - rt := newTestRuntime(sdkruntime.AgentExecution{ - Tools: sdkruntime.AgentTools{ApprovalPolicy: policy}, + rt := newTestRuntime(sdkruntime.AgentConfig{ + ToolApprovalPolicy: policy, }) require.True(t, rt.RequiresApproval(tool)) } @@ -432,10 +418,8 @@ func TestRequiresApproval_PolicyOverrides(t *testing.T) { func TestAuthorizeTool_AuthorizerError(t *testing.T) { tool := authorizerToolStub{name: "err-tool", err: errors.New("auth backend down")} - rt := newTestRuntime(sdkruntime.AgentExecution{ - Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, - }) - _, err := rt.AuthorizeTool(context.Background(), noopLog(), "err-tool", nil) + rt := newTestRuntime(sdkruntime.AgentConfig{}) + _, err := rt.AuthorizeTool(context.Background(), noopLog(), []interfaces.Tool{tool}, "err-tool", nil) require.Error(t, err) require.Contains(t, err.Error(), "auth backend down") } @@ -454,7 +438,7 @@ func TestExecuteRetrievers_PartialFailure(t *testing.T) { bad.EXPECT().Name().Return("bad").AnyTimes() bad.EXPECT().Search(gomock.Any(), gomock.Any()).Return(nil, errors.New("timeout")) - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ Retrievers: 
sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{good, bad}}, }) got, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "q") @@ -502,27 +486,27 @@ func (s *fixedStream) Err() error { return s.err } func (s *fixedStream) GetResult() *interfaces.LLMResponse { return s.result } func TestExecuteLLMStream_FallbackGenerate_Success(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ resp: &interfaces.LLMResponse{Content: "fallback answer"}, }}, }) - result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil, nil) require.NoError(t, err) require.Equal(t, "fallback answer", result.Content) } func TestExecuteLLMStream_FallbackGenerate_LLMError(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{err: errors.New("llm down")}}, }) - _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil, nil) require.Error(t, err) require.Contains(t, err.Error(), "llm down") } func TestExecuteLLMStream_FallbackGenerate_EmitsEvents(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ resp: &interfaces.LLMResponse{Content: "hi"}, }}, @@ -530,7 +514,7 @@ func TestExecuteLLMStream_FallbackGenerate_EmitsEvents(t *testing.T) { var emitted []events.AgentEventType emit := func(ev events.AgentEvent) { emitted = append(emitted, ev.Type()) } - _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", emit) + _, err := 
rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil, emit) require.NoError(t, err) require.Equal(t, []events.AgentEventType{ events.AgentEventTypeTextMessageStart, @@ -540,12 +524,12 @@ func TestExecuteLLMStream_FallbackGenerate_EmitsEvents(t *testing.T) { } func TestExecuteLLMStream_GenerateStreamError(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{ streamErr: errors.New("stream init failed"), }}, }) - _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil, nil) require.Error(t, err) require.Contains(t, err.Error(), "stream init failed") } @@ -556,10 +540,10 @@ func TestExecuteLLMStream_StreamError_AfterChunks(t *testing.T) { }, nil) s.err = errors.New("connection reset") - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, }) - _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil, nil) require.Error(t, err) require.Contains(t, err.Error(), "connection reset") } @@ -567,10 +551,10 @@ func TestExecuteLLMStream_StreamError_AfterChunks(t *testing.T) { func TestExecuteLLMStream_StreamNilResult(t *testing.T) { s := newFixedStream(nil, nil) // no chunks, GetResult() returns nil - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, }) - _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), 
"agent", "msg", nil, false, "", nil, nil) require.Error(t, err) require.Contains(t, err.Error(), "stream completed without result") } @@ -581,14 +565,14 @@ func TestExecuteLLMStream_TextChunks_EmitsCorrectEvents(t *testing.T) { {ContentDelta: " world"}, }, &interfaces.LLMResponse{Content: "hello world"}) - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, }) var emitted []events.AgentEventType emit := func(ev events.AgentEvent) { emitted = append(emitted, ev.Type()) } - result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", emit) + result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil, emit) require.NoError(t, err) require.Equal(t, "hello world", result.Content) require.Equal(t, events.AgentEventTypeTextMessageStart, emitted[0]) @@ -603,14 +587,14 @@ func TestExecuteLLMStream_ReasoningChunks_EmitsReasoningEvents(t *testing.T) { {ContentDelta: "answer"}, }, &interfaces.LLMResponse{Content: "answer"}) - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, }) var emitted []events.AgentEventType emit := func(ev events.AgentEvent) { emitted = append(emitted, ev.Type()) } - _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", emit) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil, emit) require.NoError(t, err) // Reasoning events must appear before text events @@ -635,15 +619,14 @@ func TestExecuteLLMStream_ToolOnlyResponse_EmitsEmptyAssistantTurn(t *testing.T) ToolCalls: []*interfaces.ToolCall{{ToolCallID: "1", ToolName: "search"}}, }) - rt := newTestRuntime(sdkruntime.AgentExecution{ - LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, - Tools: 
sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, + rt := newTestRuntime(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, }) var emitted []events.AgentEventType emit := func(ev events.AgentEvent) { emitted = append(emitted, ev.Type()) } - result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", emit) + result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", []interfaces.Tool{tool}, emit) require.NoError(t, err) require.Len(t, result.ToolCalls, 1) // finalizeAssistantText emits a start/content/end even when no text chunks arrived @@ -656,10 +639,10 @@ func TestExecuteLLMStream_WithUsageMetrics(t *testing.T) { Content: "done", Usage: &interfaces.LLMUsage{PromptTokens: 8, CompletionTokens: 4, TotalTokens: 12}, }) - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, }) - result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil, nil) require.NoError(t, err) require.NotNil(t, result.Usage) require.EqualValues(t, 8, result.Usage.PromptTokens) @@ -671,16 +654,16 @@ func TestExecuteLLMStream_NilChunkSkipped(t *testing.T) { {ContentDelta: "text"}, }, &interfaces.LLMResponse{Content: "text"}) - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, }) - result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil, nil) require.NoError(t, err) require.Equal(t, "text", result.Content) } func 
TestExecuteLLMStream_FallbackGenerate_WithUsage(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ resp: &interfaces.LLMResponse{ Content: "done", @@ -688,21 +671,21 @@ func TestExecuteLLMStream_FallbackGenerate_WithUsage(t *testing.T) { }, }}, }) - result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil, nil) require.NoError(t, err) require.NotNil(t, result.Usage) require.EqualValues(t, 5, result.Usage.PromptTokens) } func TestExecuteLLMStream_FallbackGenerate_UnknownToolCallError(t *testing.T) { - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ resp: &interfaces.LLMResponse{ ToolCalls: []*interfaces.ToolCall{{ToolCallID: "1", ToolName: "ghost"}}, }, }}, }) - _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil, nil) require.Error(t, err) require.Contains(t, err.Error(), "unknown tool") } @@ -711,10 +694,10 @@ func TestExecuteLLMStream_Stream_UnknownToolCallError(t *testing.T) { s := newFixedStream(nil, &interfaces.LLMResponse{ ToolCalls: []*interfaces.ToolCall{{ToolCallID: "1", ToolName: "ghost"}}, }) - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, }) - _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil, nil) require.Error(t, err) require.Contains(t, err.Error(), "unknown tool") } @@ -725,7 +708,7 @@ 
func TestExecuteRetrievers_EmptyDocsSkipped(t *testing.T) { r.EXPECT().Name().Return("empty-kb").AnyTimes() r.EXPECT().Search(gomock.Any(), gomock.Any()).Return([]interfaces.Document{}, nil) // no docs - rt := newTestRuntime(sdkruntime.AgentExecution{ + rt := newTestRuntime(sdkruntime.AgentConfig{ Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{r}}, }) got, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "q") diff --git a/internal/runtime/local/agent_loop.go b/internal/runtime/local/agent_loop.go index 6ac4980..693e624 100644 --- a/internal/runtime/local/agent_loop.go +++ b/internal/runtime/local/agent_loop.go @@ -23,11 +23,13 @@ const ( // AgentLoopInput holds per-run execution inputs for one local agent run. // Mirrors AgentWorkflowInput (Temporal) for in-process execution — same fields, same semantics. -// Agent-level configuration (ToolExecutionMode, LLM, tools, limits) lives on the runtime itself. +// Static agent wiring lives on the runtime [base.Runtime.AgentConfig]; resolved tools are per-run on Tools. type AgentLoopInput struct { UserPrompt string ConversationID string StreamingEnabled bool + // Tools is the resolved tool list for this run. + Tools []interfaces.Tool // ChannelName is the eventbus channel events are published to during this run. // Sub-agents receive the parent's ChannelName so their events go directly to the parent stream. // Empty = no event fanout. 
@@ -77,9 +79,11 @@ func (rt *LocalRuntime) publishEventToChannel(ctx context.Context, channelName s func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) (*AgentLoopResult, error) { log := rt.logger agentName := rt.AgentSpec.Name - model := rt.AgentExecution.LLM.Client.GetModel() + model := rt.AgentConfig.LLM.Client.GetModel() - maxIter := rt.AgentExecution.Limits.MaxIterations + tools := input.Tools + + maxIter := rt.AgentConfig.Limits.MaxIterations if maxIter <= 0 { maxIter = 10 } @@ -116,13 +120,13 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) // Pre-fetch retriever context for prefetch/hybrid modes. retrieverContext := "" - retrieverMode := rt.AgentExecution.Retrievers.Mode + retrieverMode := rt.AgentConfig.Retrievers.Mode if (retrieverMode == types.RetrieverModePrefetch || retrieverMode == types.RetrieverModeHybrid) && - len(rt.AgentExecution.Retrievers.Retrievers) > 0 { + len(rt.AgentConfig.Retrievers.Retrievers) > 0 { log.Debug(ctx, "local: retriever prefetch started", slog.String("scope", "loop"), slog.String("mode", string(retrieverMode)), - slog.Int("retrieverCount", len(rt.AgentExecution.Retrievers.Retrievers))) + slog.Int("retrieverCount", len(rt.AgentConfig.Retrievers.Retrievers))) rc, err := rt.ExecuteRetrievers(ctx, log, input.UserPrompt) if err != nil { return nil, fmt.Errorf("retriever prefetch: %w", err) @@ -147,9 +151,9 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) var llmResult *base.LLMResult var err error if input.StreamingEnabled { - llmResult, err = rt.ExecuteLLMStream(ctx, log, agentName, messageID, messages, false, retrieverContext, emit) + llmResult, err = rt.ExecuteLLMStream(ctx, log, agentName, messageID, messages, false, retrieverContext, tools, emit) } else { - llmResult, err = rt.ExecuteLLM(ctx, log, agentName, messageID, messages, false, retrieverContext, emit) + llmResult, err = rt.ExecuteLLM(ctx, log, agentName, messageID, messages, 
false, retrieverContext, tools, emit) } if err != nil { return nil, fmt.Errorf("llm call (iter %d): %w", iter, err) @@ -174,9 +178,9 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) slog.Int("iteration", iter)) finalMessageID := uuid.New().String() if input.StreamingEnabled { - llmResult, err = rt.ExecuteLLMStream(ctx, log, agentName, finalMessageID, messages, true, retrieverContext, emit) + llmResult, err = rt.ExecuteLLMStream(ctx, log, agentName, finalMessageID, messages, true, retrieverContext, tools, emit) } else { - llmResult, err = rt.ExecuteLLM(ctx, log, agentName, finalMessageID, messages, true, retrieverContext, emit) + llmResult, err = rt.ExecuteLLM(ctx, log, agentName, finalMessageID, messages, true, retrieverContext, tools, emit) } if err != nil { return nil, fmt.Errorf("llm final call (iter %d): %w", iter, err) @@ -224,7 +228,7 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) messages = append(messages, toolResults...) - if rt.conversationMemoryEnabled(input) && rt.AgentExecution.Session.ConversationSaveOnIteration && len(messages) > persistedMessageCount { + if rt.conversationMemoryEnabled(input) && rt.AgentConfig.Session.ConversationSaveOnIteration && len(messages) > persistedMessageCount { if err := persistConversationMessages(ctx, rt, input.ConversationID, messages[persistedMessageCount:]); err != nil { log.Warn(ctx, "local: persist conversation failed", slog.String("scope", "loop"), @@ -256,7 +260,7 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) } func (rt *LocalRuntime) conversationMemoryEnabled(input AgentLoopInput) bool { - return input.ConversationID != "" && rt.AgentExecution.Session.Conversation != nil + return input.ConversationID != "" && rt.AgentConfig.Session.Conversation != nil } // executeToolsParallel runs all tool calls concurrently and collects results in submission order. 
@@ -336,6 +340,7 @@ func (rt *LocalRuntime) executeSingleTool( emit func(events.AgentEvent), ) (interfaces.Message, error) { log := rt.logger + tools := input.Tools emitToolEndThenResult := func(toolCallID, content string) { emit(events.NewAgentToolCallEndEvent(toolCallID)) @@ -354,7 +359,7 @@ func (rt *LocalRuntime) executeSingleTool( } // Authorization check. - authResult, err := rt.AuthorizeTool(ctx, log, tc.ToolName, tc.Args) + authResult, err := rt.AuthorizeTool(ctx, log, tools, tc.ToolName, tc.Args) if err != nil { return interfaces.Message{}, fmt.Errorf("tool authorization error for %q: %w", tc.ToolName, err) } @@ -483,6 +488,7 @@ func (rt *LocalRuntime) executeSingleTool( SubAgentRoutes: subAgentRoute.children, SubAgentDepth: input.SubAgentDepth + 1, MaxSubAgentDepth: input.MaxSubAgentDepth, + Tools: subAgentRoute.tools, }) emit(events.NewAgentStepFinishedEvent(stepName)) if execErr != nil { @@ -498,7 +504,7 @@ func (rt *LocalRuntime) executeSingleTool( slog.String("scope", "loop"), slog.String("tool", tc.ToolName), slog.String("toolCallID", tc.ToolCallID)) - result, execErr := rt.ExecuteTool(ctx, log, tc.ToolName, tc.Args) + result, execErr := rt.ExecuteTool(ctx, log, tools, tc.ToolName, tc.Args) if execErr != nil { content = "Tool execution failed: " + execErr.Error() } else { @@ -524,7 +530,7 @@ func (rt *LocalRuntime) executeSingleTool( // persistConversationMessages stores all accumulated messages from the run into the conversation store. 
func persistConversationMessages(ctx context.Context, rt *LocalRuntime, conversationID string, messages []interfaces.Message) error { - conv := rt.AgentExecution.Session.Conversation + conv := rt.AgentConfig.Session.Conversation if conv == nil { return nil } diff --git a/internal/runtime/local/agent_loop_test.go b/internal/runtime/local/agent_loop_test.go index 96b4d64..9559ee4 100644 --- a/internal/runtime/local/agent_loop_test.go +++ b/internal/runtime/local/agent_loop_test.go @@ -19,21 +19,30 @@ import ( "github.com/stretchr/testify/require" ) -// newLoopRT builds a LocalRuntime with the given LLM client and tools. -// Unlike newLocalRT it also accepts MaxIterations so loop tests can control iterations precisely. -func newLoopRT(t *testing.T, maxIter int, client interfaces.LLMClient, tools ...interfaces.Tool) *LocalRuntime { +// newLoopRT builds a LocalRuntime with the given LLM client and optional tools. +func newLoopRT(t *testing.T, maxIter int, client interfaces.LLMClient, tools ...interfaces.Tool) (*LocalRuntime, []interfaces.Tool) { t.Helper() rt, err := NewLocalRuntime( WithLogger(logger.NoopLogger()), WithAgentSpec(sdkruntime.AgentSpec{Name: "loop-agent", SystemPrompt: "sys"}), - WithAgentExecution(sdkruntime.AgentExecution{ + WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: client}, - Tools: sdkruntime.AgentTools{Tools: tools}, Limits: sdkruntime.AgentLimits{MaxIterations: maxIter, Timeout: 10 * time.Second}, }), ) require.NoError(t, err) - return rt + return rt, tools +} + +func runLoop(ctx context.Context, rt *LocalRuntime, tools []interfaces.Tool, in AgentLoopInput) (*AgentLoopResult, error) { + if len(in.Tools) == 0 { + in.Tools = tools + } + return rt.RunAgentLoop(ctx, in) +} + +func loopToolsInput(tools []interfaces.Tool) AgentLoopInput { + return AgentLoopInput{Tools: tools} } // noopEmit discards all events. 
@@ -53,18 +62,18 @@ func TestRunAgentLoop_SimpleTextResponse(t *testing.T) { client := &seqLLMClient{ responses: []*interfaces.LLMResponse{{Content: "hello world"}}, } - rt := newLoopRT(t, 5, client) + rt, _ := newLoopRT(t, 5, client) - result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "hi"}) + result, err := runLoop(context.Background(), rt, nil, AgentLoopInput{UserPrompt: "hi"}) require.NoError(t, err) require.Equal(t, "hello world", result.Content) } func TestRunAgentLoop_LLMError(t *testing.T) { client := &seqLLMClient{errs: []error{errors.New("llm fail")}} - rt := newLoopRT(t, 5, client) + rt, _ := newLoopRT(t, 5, client) - _, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "hi"}) + _, err := runLoop(context.Background(), rt, nil, AgentLoopInput{UserPrompt: "hi"}) require.Error(t, err) require.Contains(t, err.Error(), "llm fail") } @@ -77,14 +86,14 @@ func TestRunAgentLoop_DefaultMaxIterations(t *testing.T) { } rt, err := NewLocalRuntime( WithLogger(logger.NoopLogger()), - WithAgentExecution(sdkruntime.AgentExecution{ + WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: client}, Limits: sdkruntime.AgentLimits{MaxIterations: 0, Timeout: 10 * time.Second}, }), ) require.NoError(t, err) - result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "hi"}) + result, err := runLoop(context.Background(), rt, nil, AgentLoopInput{UserPrompt: "hi"}) require.NoError(t, err) require.Equal(t, "early exit", result.Content) } @@ -97,9 +106,9 @@ func TestRunAgentLoop_ToolCallThenFinalAnswer(t *testing.T) { }, } tool := stubTool{name: "add", result: "7"} - rt := newLoopRT(t, 5, client, tool) + rt, tools := newLoopRT(t, 5, client, tool) - result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "add"}) + result, err := runLoop(context.Background(), rt, tools, AgentLoopInput{UserPrompt: "add"}) require.NoError(t, err) require.Equal(t, "sum is 7", result.Content) 
} @@ -114,9 +123,9 @@ func TestRunAgentLoop_MaxIterationsForcesFinalCall(t *testing.T) { }, } tool := stubTool{name: "add", result: "7"} - rt := newLoopRT(t, 1, client, tool) + rt, tools := newLoopRT(t, 1, client, tool) - result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "add"}) + result, err := runLoop(context.Background(), rt, tools, AgentLoopInput{UserPrompt: "add"}) require.NoError(t, err) require.Equal(t, "forced final answer", result.Content) } @@ -137,10 +146,10 @@ func TestRunAgentLoop_SequentialMode(t *testing.T) { } tool1 := stubTool{name: "t1", result: "r1"} tool2 := stubTool{name: "t2", result: "r2"} - rt := newLoopRT(t, 5, client, tool1, tool2) + rt, tools := newLoopRT(t, 5, client, tool1, tool2) rt.ToolExecutionMode = types.AgentToolExecutionModeSequential - result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "go"}) + result, err := runLoop(context.Background(), rt, tools, AgentLoopInput{UserPrompt: "go"}) require.NoError(t, err) require.Equal(t, "sequential done", result.Content) } @@ -152,10 +161,10 @@ func TestRunAgentLoop_InvalidToolMode(t *testing.T) { }, } tool := stubTool{name: "t1", result: "r"} - rt := newLoopRT(t, 5, client, tool) + rt, tools := newLoopRT(t, 5, client, tool) rt.ToolExecutionMode = "invalid-mode" - _, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "go"}) + _, err := runLoop(context.Background(), rt, tools, AgentLoopInput{UserPrompt: "go"}) require.Error(t, err) require.Contains(t, err.Error(), "invalid tool execution mode") } @@ -178,7 +187,7 @@ func TestRunAgentLoop_WithConversationID(t *testing.T) { } rt, err := NewLocalRuntime( WithLogger(logger.NoopLogger()), - WithAgentExecution(sdkruntime.AgentExecution{ + WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: client}, Session: sdkruntime.AgentSession{Conversation: conv, ConversationSize: 10}, Limits: sdkruntime.AgentLimits{MaxIterations: 5, Timeout: 5 * time.Second}, @@ 
-186,7 +195,7 @@ func TestRunAgentLoop_WithConversationID(t *testing.T) { ) require.NoError(t, err) - result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{ + result, err := runLoop(context.Background(), rt, nil, AgentLoopInput{ UserPrompt: "new question", ConversationID: "conv-x", }) @@ -206,7 +215,7 @@ func TestRunAgentLoop_ConversationFetchErrorContinues(t *testing.T) { } rt, err := NewLocalRuntime( WithLogger(logger.NoopLogger()), - WithAgentExecution(sdkruntime.AgentExecution{ + WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: client}, Session: sdkruntime.AgentSession{Conversation: conv}, Limits: sdkruntime.AgentLimits{MaxIterations: 5, Timeout: 5 * time.Second}, @@ -214,7 +223,7 @@ func TestRunAgentLoop_ConversationFetchErrorContinues(t *testing.T) { ) require.NoError(t, err) - result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{ + result, err := runLoop(context.Background(), rt, nil, AgentLoopInput{ UserPrompt: "hi", ConversationID: "bad-conv", }) @@ -240,7 +249,7 @@ func TestRunAgentLoop_RetrieverPrefetch(t *testing.T) { } rt, err := NewLocalRuntime( WithLogger(logger.NoopLogger()), - WithAgentExecution(sdkruntime.AgentExecution{ + WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: client}, Retrievers: sdkruntime.AgentRetrievers{ Mode: types.RetrieverModePrefetch, @@ -251,7 +260,7 @@ func TestRunAgentLoop_RetrieverPrefetch(t *testing.T) { ) require.NoError(t, err) - result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "fetch me"}) + result, err := runLoop(context.Background(), rt, nil, AgentLoopInput{UserPrompt: "fetch me"}) require.NoError(t, err) require.Equal(t, "answer with context", result.Content) } @@ -265,7 +274,7 @@ func TestRunAgentLoop_RetrieverPrefetchError(t *testing.T) { client := &seqLLMClient{} rt, err := NewLocalRuntime( WithLogger(logger.NoopLogger()), - WithAgentExecution(sdkruntime.AgentExecution{ + 
WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: client}, Retrievers: sdkruntime.AgentRetrievers{ Mode: types.RetrieverModePrefetch, @@ -276,7 +285,7 @@ func TestRunAgentLoop_RetrieverPrefetchError(t *testing.T) { ) require.NoError(t, err) - _, err = rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "fetch"}) + _, err = runLoop(context.Background(), rt, nil, AgentLoopInput{UserPrompt: "fetch"}) require.Error(t, err) require.Contains(t, err.Error(), "retriever prefetch") } @@ -293,7 +302,7 @@ func TestRunAgentLoop_ToolEventsEmittedToChannel(t *testing.T) { }, } tool := stubTool{name: "calc", result: "99"} - rt := newLoopRT(t, 5, client, tool) + rt, tools := newLoopRT(t, 5, client, tool) ctx := context.Background() channel := "test-tool-events" @@ -307,7 +316,7 @@ func TestRunAgentLoop_ToolEventsEmittedToChannel(t *testing.T) { // Run the loop in a goroutine; close the subscription after it finishes so eventCh drains. go func() { - _, _ = rt.RunAgentLoop(ctx, AgentLoopInput{ + _, _ = runLoop(ctx, rt, tools, AgentLoopInput{ UserPrompt: "compute", ChannelName: channel, }) @@ -343,14 +352,14 @@ done: func TestExecuteToolsParallel_AllSucceed(t *testing.T) { t1 := stubTool{name: "t1", result: "r1"} t2 := stubTool{name: "t2", result: "r2"} - rt := newLoopRT(t, 5, &seqLLMClient{}, t1, t2) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, t1, t2) calls := []base.ToolCallRequest{ {ToolCallID: "c1", ToolName: "t1"}, {ToolCallID: "c2", ToolName: "t2"}, } - msgs, err := rt.executeToolsParallel(context.Background(), AgentLoopInput{}, "msg-1", calls, noopEmit) + msgs, err := rt.executeToolsParallel(context.Background(), loopToolsInput(tools), "msg-1", calls, noopEmit) require.NoError(t, err) require.Len(t, msgs, 2) // Order must match submission order (parallel but results are indexed). 
@@ -361,10 +370,10 @@ func TestExecuteToolsParallel_AllSucceed(t *testing.T) { func TestExecuteToolsParallel_ToolErrorInMessage(t *testing.T) { // Parallel: individual tool errors become synthetic error messages, not hard failures. failing := stubTool{name: "bad", execErr: errors.New("boom")} - rt := newLoopRT(t, 5, &seqLLMClient{}, failing) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, failing) calls := []base.ToolCallRequest{{ToolCallID: "c1", ToolName: "bad"}} - msgs, err := rt.executeToolsParallel(context.Background(), AgentLoopInput{}, "msg", calls, noopEmit) + msgs, err := rt.executeToolsParallel(context.Background(), loopToolsInput(tools), "msg", calls, noopEmit) require.NoError(t, err) // parallel swallows into message require.Len(t, msgs, 1) require.Contains(t, msgs[0].Content, "boom") @@ -372,19 +381,19 @@ func TestExecuteToolsParallel_ToolErrorInMessage(t *testing.T) { func TestExecuteToolsParallel_ResultsOrderPreserved(t *testing.T) { // Three tools; verify result order matches submission order despite concurrency. - tools := []interfaces.Tool{ + toolSet := []interfaces.Tool{ stubTool{name: "a", result: "A"}, stubTool{name: "b", result: "B"}, stubTool{name: "c", result: "C"}, } - rt := newLoopRT(t, 5, &seqLLMClient{}, tools...) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, toolSet...) 
calls := []base.ToolCallRequest{ {ToolCallID: "1", ToolName: "a"}, {ToolCallID: "2", ToolName: "b"}, {ToolCallID: "3", ToolName: "c"}, } - msgs, err := rt.executeToolsParallel(context.Background(), AgentLoopInput{}, "m", calls, noopEmit) + msgs, err := rt.executeToolsParallel(context.Background(), loopToolsInput(tools), "m", calls, noopEmit) require.NoError(t, err) require.Equal(t, []string{"A", "B", "C"}, []string{msgs[0].Content, msgs[1].Content, msgs[2].Content}) } @@ -396,13 +405,13 @@ func TestExecuteToolsParallel_ResultsOrderPreserved(t *testing.T) { func TestExecuteToolsSequential_AllSucceed(t *testing.T) { t1 := stubTool{name: "s1", result: "v1"} t2 := stubTool{name: "s2", result: "v2"} - rt := newLoopRT(t, 5, &seqLLMClient{}, t1, t2) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, t1, t2) calls := []base.ToolCallRequest{ {ToolCallID: "c1", ToolName: "s1"}, {ToolCallID: "c2", ToolName: "s2"}, } - msgs, err := rt.executeToolsSequential(context.Background(), AgentLoopInput{}, "msg", calls, noopEmit) + msgs, err := rt.executeToolsSequential(context.Background(), loopToolsInput(tools), "msg", calls, noopEmit) require.NoError(t, err) require.Len(t, msgs, 2) require.Equal(t, "v1", msgs[0].Content) @@ -412,7 +421,7 @@ func TestExecuteToolsSequential_AllSucceed(t *testing.T) { func TestExecuteToolsSequential_HardErrorOnContextCancel(t *testing.T) { // A tool that blocks until ctx is cancelled → executeSingleTool returns ctx.Err(). // Sequential should propagate that error. - rt := newLoopRT(t, 5, &seqLLMClient{}) + rt, _ := newLoopRT(t, 5, &seqLLMClient{}) // Add a fake tool that needs approval with no channel or handler → unavailable (not an error). // Instead: use a blocking LLM as a proxy — but we need a tool-level error. // We'll cancel the context before calling. 
@@ -431,10 +440,10 @@ func TestExecuteToolsSequential_HardErrorOnContextCancel(t *testing.T) { func TestExecuteSingleTool_Approved(t *testing.T) { tool := stubTool{name: "my-tool", result: "hello"} - rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, tool) emit, evs := captureEmit() - msg, err := rt.executeSingleTool(context.Background(), AgentLoopInput{}, "msg-1", + msg, err := rt.executeSingleTool(context.Background(), loopToolsInput(tools), "msg-1", base.ToolCallRequest{ToolCallID: "c1", ToolName: "my-tool"}, emit) require.NoError(t, err) @@ -450,16 +459,16 @@ func TestExecuteSingleTool_Approved(t *testing.T) { func TestExecuteSingleTool_ToolExecError(t *testing.T) { tool := stubTool{name: "boom", execErr: errors.New("exec failed")} - rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, tool) - msg, err := rt.executeSingleTool(context.Background(), AgentLoopInput{}, "msg", + msg, err := rt.executeSingleTool(context.Background(), loopToolsInput(tools), "msg", base.ToolCallRequest{ToolCallID: "c1", ToolName: "boom"}, noopEmit) require.NoError(t, err) // tool errors become a content message, not a hard error require.Contains(t, msg.Content, "exec failed") } func TestExecuteSingleTool_UnknownToolErrors(t *testing.T) { - rt := newLoopRT(t, 5, &seqLLMClient{}) // no tools registered + rt, _ := newLoopRT(t, 5, &seqLLMClient{}) // no tools registered _, err := rt.executeSingleTool(context.Background(), AgentLoopInput{}, "msg", base.ToolCallRequest{ToolCallID: "c1", ToolName: "ghost"}, noopEmit) @@ -480,9 +489,9 @@ func TestExecuteSingleTool_AuthorizationDenied(t *testing.T) { // Use an authorizerToolStub from the runtime_test helpers (same package). 
authTool := authorizerStubLocal{name: "restricted", allow: false, reason: "policy denied"} - rt := newLoopRT(t, 5, &seqLLMClient{}, authTool) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, authTool) - msg, err := rt.executeSingleTool(context.Background(), AgentLoopInput{}, "msg", + msg, err := rt.executeSingleTool(context.Background(), loopToolsInput(tools), "msg", base.ToolCallRequest{ToolCallID: "c1", ToolName: "restricted"}, noopEmit) require.NoError(t, err) require.Contains(t, msg.Content, msgToolUnauthorized) @@ -491,9 +500,9 @@ func TestExecuteSingleTool_AuthorizationDenied(t *testing.T) { func TestExecuteSingleTool_AuthorizationError(t *testing.T) { authTool := authorizerStubLocal{name: "err-tool", allow: false, authErr: errors.New("auth backend down")} - rt := newLoopRT(t, 5, &seqLLMClient{}, authTool) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, authTool) - _, err := rt.executeSingleTool(context.Background(), AgentLoopInput{}, "msg", + _, err := rt.executeSingleTool(context.Background(), loopToolsInput(tools), "msg", base.ToolCallRequest{ToolCallID: "c1", ToolName: "err-tool"}, noopEmit) require.Error(t, err) require.Contains(t, err.Error(), "auth backend down") @@ -502,10 +511,10 @@ func TestExecuteSingleTool_AuthorizationError(t *testing.T) { func TestExecuteSingleTool_ApprovalUnavailable(t *testing.T) { // No channel, no handler → approval status = unavailable, tool not run. 
tool := stubTool{name: "guarded", result: "secret", needsApproval: true} - rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, tool) msg, err := rt.executeSingleTool(context.Background(), - AgentLoopInput{ChannelName: "", ApprovalHandler: nil}, "msg", + AgentLoopInput{ChannelName: "", ApprovalHandler: nil, Tools: tools}, "msg", base.ToolCallRequest{ToolCallID: "c1", ToolName: "guarded", NeedsApproval: true}, noopEmit) require.NoError(t, err) require.Contains(t, msg.Content, msgToolApprovalUnavailable) @@ -513,14 +522,14 @@ func TestExecuteSingleTool_ApprovalUnavailable(t *testing.T) { func TestExecuteSingleTool_ApprovalHandlerApproves(t *testing.T) { tool := stubTool{name: "guarded", result: "ok", needsApproval: true} - rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, tool) handler := func(_ context.Context, req *types.ApprovalRequest) { _ = req.Respond(types.ApprovalStatusApproved) } msg, err := rt.executeSingleTool(context.Background(), - AgentLoopInput{ApprovalHandler: handler}, "msg", + AgentLoopInput{ApprovalHandler: handler, Tools: tools}, "msg", base.ToolCallRequest{ToolCallID: "c1", ToolName: "guarded", NeedsApproval: true}, noopEmit) require.NoError(t, err) require.Equal(t, "ok", msg.Content) @@ -528,14 +537,14 @@ func TestExecuteSingleTool_ApprovalHandlerApproves(t *testing.T) { func TestExecuteSingleTool_ApprovalHandlerRejects(t *testing.T) { tool := stubTool{name: "guarded", result: "secret", needsApproval: true} - rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, tool) handler := func(_ context.Context, req *types.ApprovalRequest) { _ = req.Respond(types.ApprovalStatusRejected) } msg, err := rt.executeSingleTool(context.Background(), - AgentLoopInput{ApprovalHandler: handler}, "msg", + AgentLoopInput{ApprovalHandler: handler, Tools: tools}, "msg", base.ToolCallRequest{ToolCallID: "c1", ToolName: "guarded", NeedsApproval: 
true}, noopEmit) require.NoError(t, err) require.Equal(t, msgToolRejected, msg.Content) @@ -545,7 +554,7 @@ func TestExecuteSingleTool_StreamingApproveUnblocks(t *testing.T) { // Streaming path: ChannelName set, no ApprovalHandler. // We call rt.Approve from a goroutine to unblock executeSingleTool. tool := stubTool{name: "guarded", result: "stream-ok", needsApproval: true} - rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, tool) // Capture the approval token from the emitted CUSTOM event. var capturedToken string @@ -583,7 +592,7 @@ func TestExecuteSingleTool_StreamingApproveUnblocks(t *testing.T) { defer close(done) resultMsg, resultErr = rt.executeSingleTool( context.Background(), - AgentLoopInput{ChannelName: "some-channel"}, // streaming path + AgentLoopInput{ChannelName: "some-channel", Tools: tools}, // streaming path "msg", base.ToolCallRequest{ToolCallID: "c1", ToolName: "guarded", NeedsApproval: true}, emit, @@ -609,7 +618,7 @@ func TestExecuteSingleTool_StreamingApproveUnblocks(t *testing.T) { func TestExecuteSingleTool_ApprovalContextCancel(t *testing.T) { tool := stubTool{name: "guarded", result: "should not run", needsApproval: true} - rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + rt, tools := newLoopRT(t, 5, &seqLLMClient{}, tool) ctx, cancel := context.WithCancel(context.Background()) @@ -621,7 +630,7 @@ func TestExecuteSingleTool_ApprovalContextCancel(t *testing.T) { }() _, err := rt.executeSingleTool(ctx, - AgentLoopInput{ChannelName: "some-channel"}, "msg", + AgentLoopInput{ChannelName: "some-channel", Tools: tools}, "msg", base.ToolCallRequest{ToolCallID: "c1", ToolName: "guarded", NeedsApproval: true}, noopEmit) <-done @@ -634,14 +643,14 @@ func TestExecuteSingleTool_ApprovalContextCancel(t *testing.T) { // --------------------------------------------------------------------------- func TestPublishEventToChannel_NoOpWhenChannelEmpty(t *testing.T) { - rt := newLoopRT(t, 5, &seqLLMClient{}) + rt, _ := 
newLoopRT(t, 5, &seqLLMClient{}) require.NotPanics(t, func() { rt.publishEventToChannel(context.Background(), "", events.NewAgentRunErrorEvent("x")) }) } func TestPublishEventToChannel_NoOpWhenNilEvent(t *testing.T) { - rt := newLoopRT(t, 5, &seqLLMClient{}) + rt, _ := newLoopRT(t, 5, &seqLLMClient{}) require.NotPanics(t, func() { rt.publishEventToChannel(context.Background(), "ch", nil) }) @@ -665,7 +674,7 @@ func TestPublishEventToChannel_NoOpWhenNilEventbus(t *testing.T) { // --------------------------------------------------------------------------- func TestPersistConversationMessages_NilConversation(t *testing.T) { - rt := newLoopRT(t, 5, &seqLLMClient{}) + rt, _ := newLoopRT(t, 5, &seqLLMClient{}) // No conversation configured — must not panic or error. err := persistConversationMessages(context.Background(), rt, "c", []interfaces.Message{ {Role: interfaces.MessageRoleUser, Content: "hi"}, @@ -680,7 +689,7 @@ func TestPersistConversationMessages_StoresAllMessages(t *testing.T) { rt, err := NewLocalRuntime( WithLogger(logger.NoopLogger()), - WithAgentExecution(sdkruntime.AgentExecution{ + WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: &seqLLMClient{}}, Session: sdkruntime.AgentSession{Conversation: conv}, Limits: sdkruntime.AgentLimits{Timeout: 5 * time.Second}, @@ -704,7 +713,7 @@ func TestPersistConversationMessages_AddMessageErrorWarnsOnly(t *testing.T) { rt, err := NewLocalRuntime( WithLogger(logger.NoopLogger()), - WithAgentExecution(sdkruntime.AgentExecution{ + WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: &seqLLMClient{}}, Session: sdkruntime.AgentSession{Conversation: conv}, Limits: sdkruntime.AgentLimits{Timeout: 5 * time.Second}, diff --git a/internal/runtime/local/options.go b/internal/runtime/local/options.go index 74640a9..b9fa196 100644 --- a/internal/runtime/local/options.go +++ b/internal/runtime/local/options.go @@ -28,9 +28,9 @@ func WithAgentSpec(spec sdkruntime.AgentSpec) Option { } } 
-func WithAgentExecution(execution sdkruntime.AgentExecution) Option { +func WithAgentConfig(cfg sdkruntime.AgentConfig) Option { return func(r *LocalRuntime) { - r.AgentExecution = execution + r.AgentConfig = cfg } } @@ -58,7 +58,7 @@ func buildLocalRuntime(opts ...Option) (*LocalRuntime, error) { opt(r) } - if r.AgentExecution.LLM.Client == nil { + if r.AgentConfig.LLM.Client == nil { return nil, fmt.Errorf("llm client is required") } diff --git a/internal/runtime/local/runtime.go b/internal/runtime/local/runtime.go index 4fd7788..0b5d3dc 100644 --- a/internal/runtime/local/runtime.go +++ b/internal/runtime/local/runtime.go @@ -13,6 +13,7 @@ import ( sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" "github.com/agenticenv/agent-sdk-go/internal/runtime/base" "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/logger" "github.com/google/uuid" ) @@ -100,7 +101,7 @@ func (rt *LocalRuntime) publishLifecycleEvent(channel string, ev events.AgentEve // Execute runs the agent loop synchronously and returns the final result. // Approval is handled inline via req.ApprovalHandler (no out-of-band tokens). func (rt *LocalRuntime) Execute(ctx context.Context, req *sdkruntime.ExecuteRequest) (*types.AgentRunResult, error) { - agentName := agentNameFromRequest(req) + agentName := agentNameFromRuntime(rt) rt.logger.Debug(ctx, "runtime execute", slog.String("scope", "runtime"), slog.String("agent", agentName), @@ -108,7 +109,7 @@ func (rt *LocalRuntime) Execute(ctx context.Context, req *sdkruntime.ExecuteRequ // Apply agent timeout when the caller has not set a deadline. 
runCtx := ctx - if d := rt.AgentExecution.Limits.Timeout; d > 0 { + if d := rt.AgentConfig.Limits.Timeout; d > 0 { if _, hasDeadline := ctx.Deadline(); !hasDeadline { var cancel context.CancelFunc runCtx, cancel = context.WithTimeout(ctx, d) @@ -119,6 +120,8 @@ func (rt *LocalRuntime) Execute(ctx context.Context, req *sdkruntime.ExecuteRequ conversationID := base.GetConversationID(req) runID := uuid.New().String() + tools := req.Tools + loopResult, err := rt.RunAgentLoop(runCtx, AgentLoopInput{ UserPrompt: req.UserPrompt, ConversationID: conversationID, @@ -128,6 +131,7 @@ func (rt *LocalRuntime) Execute(ctx context.Context, req *sdkruntime.ExecuteRequ SubAgentRoutes: buildSubAgentRoutes(req.SubAgents), SubAgentDepth: 0, MaxSubAgentDepth: req.MaxSubAgentDepth, + Tools: tools, }) if err != nil { return nil, err @@ -137,7 +141,7 @@ func (rt *LocalRuntime) Execute(ctx context.Context, req *sdkruntime.ExecuteRequ return &types.AgentRunResult{ Content: loopResult.Content, AgentName: strings.TrimSpace(agentName), - Model: rt.AgentExecution.LLM.Client.GetModel(), + Model: rt.AgentConfig.LLM.Client.GetModel(), Metadata: map[string]any{}, Usage: loopResult.Usage, }, nil @@ -146,7 +150,7 @@ func (rt *LocalRuntime) Execute(ctx context.Context, req *sdkruntime.ExecuteRequ // ExecuteStream starts the agent loop in a goroutine and returns a channel of AgentEvent. // RUN_STARTED is emitted before the loop begins; RUN_FINISHED or RUN_ERROR closes the channel. func (rt *LocalRuntime) ExecuteStream(ctx context.Context, req *sdkruntime.ExecuteRequest) (<-chan events.AgentEvent, error) { - agentName := agentNameFromRequest(req) + agentName := agentNameFromRuntime(rt) rt.logger.Debug(ctx, "runtime execute stream", slog.String("scope", "runtime"), slog.String("agent", agentName), @@ -164,7 +168,7 @@ func (rt *LocalRuntime) ExecuteStream(ctx context.Context, req *sdkruntime.Execu // Apply agent timeout. 
runCtx := ctx var runCancel context.CancelFunc - if d := rt.AgentExecution.Limits.Timeout; d > 0 { + if d := rt.AgentConfig.Limits.Timeout; d > 0 { if _, hasDeadline := ctx.Deadline(); !hasDeadline { runCtx, runCancel = context.WithTimeout(ctx, d) } @@ -196,13 +200,16 @@ func (rt *LocalRuntime) ExecuteStream(ctx context.Context, req *sdkruntime.Execu // Run the agent loop in a goroutine; emit lifecycle terminal event on completion. go func() { + var tools []interfaces.Tool + if req != nil { + tools = req.Tools + } defer func() { if runCancel != nil { runCancel() } _ = closeSub() }() - result, loopErr := rt.RunAgentLoop(runCtx, AgentLoopInput{ UserPrompt: req.UserPrompt, ConversationID: conversationID, @@ -212,6 +219,7 @@ func (rt *LocalRuntime) ExecuteStream(ctx context.Context, req *sdkruntime.Execu SubAgentRoutes: buildSubAgentRoutes(req.SubAgents), SubAgentDepth: 0, MaxSubAgentDepth: req.MaxSubAgentDepth, + Tools: tools, }) if loopErr != nil { @@ -226,7 +234,7 @@ func (rt *LocalRuntime) ExecuteStream(ctx context.Context, req *sdkruntime.Execu agentRunResult := &types.AgentRunResult{ Content: result.Content, AgentName: strings.TrimSpace(agentName), - Model: rt.AgentExecution.LLM.Client.GetModel(), + Model: rt.AgentConfig.LLM.Client.GetModel(), Metadata: map[string]any{}, Usage: result.Usage, } @@ -268,9 +276,9 @@ func (rt *LocalRuntime) SetEventBus(bus eventbus.EventBus) { rt.eventbus = bus } -func agentNameFromRequest(req *sdkruntime.ExecuteRequest) string { - if req == nil || req.AgentSpec == nil { +func agentNameFromRuntime(rt *LocalRuntime) string { + if rt == nil { return "" } - return req.AgentSpec.Name + return rt.AgentSpec.Name } diff --git a/internal/runtime/local/runtime_test.go b/internal/runtime/local/runtime_test.go index d181a71..b3a9b9a 100644 --- a/internal/runtime/local/runtime_test.go +++ b/internal/runtime/local/runtime_test.go @@ -76,9 +76,8 @@ func newLocalRT(t *testing.T, client interfaces.LLMClient, tools ...interfaces.T rt, err := 
NewLocalRuntime( WithLogger(logger.NoopLogger()), WithAgentSpec(sdkruntime.AgentSpec{Name: "test-agent", SystemPrompt: "you are helpful"}), - WithAgentExecution(sdkruntime.AgentExecution{ - LLM: sdkruntime.AgentLLM{Client: client}, - Tools: sdkruntime.AgentTools{Tools: tools}, + WithAgentConfig(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: client}, Limits: sdkruntime.AgentLimits{ MaxIterations: 5, Timeout: 30 * time.Second, @@ -86,9 +85,14 @@ func newLocalRT(t *testing.T, client interfaces.LLMClient, tools ...interfaces.T }), ) require.NoError(t, err) + _ = tools // callers pass resolved tools on ExecuteRequest.Tools return rt } +func execReq(prompt string, tools ...interfaces.Tool) *sdkruntime.ExecuteRequest { + return &sdkruntime.ExecuteRequest{UserPrompt: prompt, Tools: tools} +} + // collectEvents drains an event channel until it is closed or timeout elapses, // returning all events received. func collectEvents(t *testing.T, ch <-chan events.AgentEvent, timeout time.Duration) []events.AgentEvent { @@ -127,7 +131,7 @@ func eventTypes(evs []events.AgentEvent) []events.AgentEventType { func TestNewLocalRuntime_MissingLLMClient(t *testing.T) { _, err := NewLocalRuntime( WithAgentSpec(sdkruntime.AgentSpec{Name: "agent"}), - WithAgentExecution(sdkruntime.AgentExecution{}), + WithAgentConfig(sdkruntime.AgentConfig{}), ) require.Error(t, err) require.Contains(t, err.Error(), "llm client is required") @@ -135,7 +139,7 @@ func TestNewLocalRuntime_MissingLLMClient(t *testing.T) { func TestNewLocalRuntime_DefaultNoopObservability(t *testing.T) { rt, err := NewLocalRuntime( - WithAgentExecution(sdkruntime.AgentExecution{ + WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: &seqLLMClient{}}, }), ) @@ -152,7 +156,7 @@ func TestNewLocalRuntime_WithAllOptions(t *testing.T) { rt, err := NewLocalRuntime( WithLogger(logger.NoopLogger()), WithAgentSpec(sdkruntime.AgentSpec{Name: "my-agent"}), - WithAgentExecution(sdkruntime.AgentExecution{ + 
WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: &seqLLMClient{}}, }), WithTracer(tracer), @@ -172,20 +176,16 @@ func TestNewLocalRuntime_EventBusInitialised(t *testing.T) { } // --------------------------------------------------------------------------- -// agentNameFromRequest +// agentNameFromRuntime // --------------------------------------------------------------------------- -func TestAgentNameFromRequest_NilRequest(t *testing.T) { - require.Equal(t, "", agentNameFromRequest(nil)) -} - -func TestAgentNameFromRequest_NilSpec(t *testing.T) { - require.Equal(t, "", agentNameFromRequest(&sdkruntime.ExecuteRequest{})) +func TestAgentNameFromRuntime_NilRuntime(t *testing.T) { + require.Equal(t, "", agentNameFromRuntime(nil)) } -func TestAgentNameFromRequest_WithName(t *testing.T) { - req := &sdkruntime.ExecuteRequest{AgentSpec: &sdkruntime.AgentSpec{Name: "hello"}} - require.Equal(t, "hello", agentNameFromRequest(req)) +func TestAgentNameFromRuntime_WithName(t *testing.T) { + rt := newLocalRT(t, &seqLLMClient{}) + require.Equal(t, "test-agent", agentNameFromRuntime(rt)) } // --------------------------------------------------------------------------- @@ -202,7 +202,6 @@ func TestExecute_SimpleTextResponse(t *testing.T) { result, err := rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{ UserPrompt: "hi", - AgentSpec: &sdkruntime.AgentSpec{Name: "test-agent"}, }) require.NoError(t, err) @@ -227,7 +226,7 @@ func TestExecute_AppliesTimeoutWhenNoDeadline(t *testing.T) { blocking := &blockingLLMClient{block: make(chan struct{})} rt, err := NewLocalRuntime( WithLogger(logger.NoopLogger()), - WithAgentExecution(sdkruntime.AgentExecution{ + WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: blocking}, Limits: sdkruntime.AgentLimits{ MaxIterations: 1, @@ -285,6 +284,7 @@ func TestExecute_WithApprovalHandler(t *testing.T) { result, err := rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{ UserPrompt: "run tool", 
+ Tools: []interfaces.Tool{tool}, ApprovalHandler: handler, }) @@ -305,7 +305,6 @@ func TestExecuteStream_EmitsRunStartedAndFinished(t *testing.T) { ch, err := rt.ExecuteStream(context.Background(), &sdkruntime.ExecuteRequest{ UserPrompt: "hello", - AgentSpec: &sdkruntime.AgentSpec{Name: "test-agent"}, }) require.NoError(t, err) @@ -459,6 +458,7 @@ func TestApprove_StreamingEndToEnd(t *testing.T) { ch, err := rt.ExecuteStream(context.Background(), &sdkruntime.ExecuteRequest{ UserPrompt: "run guarded tool", + Tools: []interfaces.Tool{tool}, }) require.NoError(t, err) @@ -611,9 +611,7 @@ func TestExecute_ToolCallThenFinalAnswer(t *testing.T) { tool := stubTool{name: "calc", result: "42"} rt := newLocalRT(t, client, tool) - result, err := rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{ - UserPrompt: "compute", - }) + result, err := rt.Execute(context.Background(), execReq("compute", tool)) require.NoError(t, err) require.Equal(t, "the answer is 42", result.Content) } @@ -637,7 +635,7 @@ func TestExecute_PersistsConversationMessages(t *testing.T) { rt, err := NewLocalRuntime( WithLogger(logger.NoopLogger()), WithAgentSpec(sdkruntime.AgentSpec{Name: "agent"}), - WithAgentExecution(sdkruntime.AgentExecution{ + WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: client}, Session: sdkruntime.AgentSession{Conversation: conv, ConversationSize: 20}, Limits: sdkruntime.AgentLimits{MaxIterations: 5, Timeout: 5 * time.Second}, diff --git a/internal/runtime/local/subagent.go b/internal/runtime/local/subagent.go index e57db4c..b972ea3 100644 --- a/internal/runtime/local/subagent.go +++ b/internal/runtime/local/subagent.go @@ -1,6 +1,9 @@ package local -import sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" +import ( + sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) // subAgentRoute is the local runtime's internal representation of a delegatable sub-agent. 
// Built from ExecuteRequest.SubAgents by buildSubAgentRoutes; not shared with any other package. @@ -8,6 +11,7 @@ type subAgentRoute struct { name string runtime *LocalRuntime children map[string]subAgentRoute + tools []interfaces.Tool } // buildSubAgentRoutes converts the runtime-agnostic SubAgentSpec tree (from ExecuteRequest) @@ -30,6 +34,7 @@ func buildSubAgentRoutes(specs []*sdkruntime.SubAgentSpec) map[string]subAgentRo name: spec.Name, runtime: lr, children: buildSubAgentRoutes(spec.Children), + tools: spec.Tools, } } if len(out) == 0 { diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 6852492..920a9f6 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -26,7 +26,7 @@ type Runtime interface { // Execute runs one execution and returns the result. The agent package supplies approval via ExecuteRequest when needed. // Use WithTimeout or a context with deadline to avoid blocking. // When using conversation, pass the conversation ID on the request; agent and worker must use the same ID. - // Agent identity is on req.AgentSpec.Name when AgentSpec is set. + // Agent identity lives on the runtime [AgentSpec] configured at construction. Execute(ctx context.Context, req *ExecuteRequest) (*types.AgentRunResult, error) // ExecuteStream starts the run and returns a channel of AgentEvent. Streams RUN_* lifecycle, @@ -36,7 +36,6 @@ type Runtime interface { // For approvals (tool or delegation), receive CUSTOM (AgentEventTypeCustom) events and use the agent // package approval path (e.g. OnApproval with the token from the custom payload). // When using conversation, pass the conversation ID on the request. - // Agent identity is on req.AgentSpec.Name when AgentSpec is set. 
ExecuteStream(ctx context.Context, req *ExecuteRequest) (<-chan events.AgentEvent, error) // Approve completes a pending tool approval when the runtime uses out-of-band approval @@ -79,11 +78,11 @@ type SubAgentSpec struct { ToolName string // tool name used to invoke this sub-agent (key in runtime route maps) Runtime Runtime // the sub-agent's runtime instance Children []*SubAgentSpec + // Tools is the registry-resolved tool list for this sub-agent at request time. + Tools []interfaces.Tool `json:"-"` } -// AgentSpec describes agent identity and structured-output preferences for one run. -// It is attached to [ExecuteRequest.AgentSpec] so custom Runtime implementations can read name, prompts, -// and response format without importing pkg/agent. +// AgentSpec describes agent identity and structured-output preferences configured on the runtime. type AgentSpec struct { // Name is a human-readable label (may include spaces). Runtimes may sanitize it when embedding in workflow IDs. Name string @@ -92,15 +91,13 @@ type AgentSpec struct { ResponseFormat *interfaces.ResponseFormat } -// AgentExecution groups per-run execution inputs for custom Runtime implementations. Sub-structs -// stay stable so callers do not depend on a single flat blob that might be reshaped later. -// Temporal-backed runtimes typically use worker-local configuration for activities; this is a snapshot. -type AgentExecution struct { - LLM AgentLLM - Tools AgentTools - Retrievers AgentRetrievers - Session AgentSession - Limits AgentLimits +// AgentConfig is static agent wiring on the runtime at construction: LLM client, tool approval policy, session, limits, and retriever config. +type AgentConfig struct { + LLM AgentLLM + ToolApprovalPolicy interfaces.AgentToolApprovalPolicy + Retrievers AgentRetrievers + Session AgentSession + Limits AgentLimits } // AgentRetrievers holds the retriever instances and mode for prefetch and hybrid RAG. 
@@ -121,13 +118,6 @@ type AgentLLM struct { Sampling *LLMSampling } -// AgentTools is registered tools, optional registry, and approval policy for this run. -type AgentTools struct { - Tools []interfaces.Tool - Registry interfaces.ToolRegistry - ApprovalPolicy interfaces.AgentToolApprovalPolicy -} - // AgentSession is conversation storage and how many messages to include in LLM context. type AgentSession struct { Conversation interfaces.Conversation @@ -143,10 +133,6 @@ type AgentLimits struct { } // ExecuteRequest carries one execution request from Agent to Runtime. -// -// AgentSpec and AgentExecution are populated by pkg/agent from its configuration so implementations -// can read identity (including agent name on AgentSpec.Name), prompts, LLM, tools, and policies -// for this run. Implementations may ignore fields they do not use. type ExecuteRequest struct { UserPrompt string `json:"user_prompt"` // RunOptions is the per-call options forwarded from pkg/agent (e.g. conversation session). May be nil. @@ -159,10 +145,8 @@ type ExecuteRequest struct { SubAgents []*SubAgentSpec `json:"sub_agents,omitempty"` MaxSubAgentDepth int `json:"max_sub_agent_depth"` - ApprovalHandler types.ApprovalHandler `json:"approval_handler"` + // Tools is the registry-resolved tool list for this run. + Tools []interfaces.Tool `json:"-"` - // AgentSpec is identity and output-format metadata for this run (name, description, system prompt, response format). - AgentSpec *AgentSpec `json:"agent_spec"` - // AgentExecution is LLM, tools, conversation, sampling, and policy for this run. 
- AgentExecution *AgentExecution `json:"agent_execution"` + ApprovalHandler types.ApprovalHandler `json:"approval_handler"` } diff --git a/internal/runtime/temporal/agent_workflow.go b/internal/runtime/temporal/agent_workflow.go index 2f80d08..393a779 100644 --- a/internal/runtime/temporal/agent_workflow.go +++ b/internal/runtime/temporal/agent_workflow.go @@ -115,7 +115,7 @@ func (rt *TemporalRuntime) sendAgentEventWorkflowUpdate(ctx context.Context, eve // EventTaskQueue is the Temporal task queue for AgentEventWorkflow (e.g. main TaskQueue + "-events"); required // for UpdateWithStartWorkflow when EventWorkflowID is set. // EventTypes is set by the SDK; a single "*" element means emit all event kinds (used for Stream). -// AgentFingerprint is the SHA-256 hex digest of the worker-local agent config; activities reject on mismatch. +// AgentFingerprint is the per-run digest (config + resolved tools). Caller and worker compute it at resolve time. type AgentWorkflowInput struct { UserPrompt string `json:"user_prompt,omitempty"` EventWorkflowID string `json:"event_workflow_id,omitempty"` @@ -289,9 +289,9 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl eventWorkflowID := input.EventWorkflowID eventTaskQueue := input.EventTaskQueue agentName := rt.AgentSpec.Name - model := rt.AgentExecution.LLM.Client.GetModel() + model := rt.AgentConfig.LLM.Client.GetModel() - maxIter := rt.AgentExecution.Limits.MaxIterations + maxIter := rt.AgentConfig.Limits.MaxIterations var activityIDSuffix string err := workflow.SideEffect(ctx, func(ctx workflow.Context) interface{} { @@ -379,7 +379,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl return nil } - useStreaming := input.StreamingEnabled && rt.AgentExecution.LLM.Client.IsStreamSupported() + useStreaming := input.StreamingEnabled && rt.AgentConfig.LLM.Client.IsStreamSupported() // State restored after ContinueAsNew (iteration + conversation messages). 
if input.State == nil { @@ -395,10 +395,10 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl // The resulting retrieverContext is forwarded to every AgentLLMInput in the run so the LLM always // sees the retrieved documents in its system prompt, regardless of the number of iterations. retrieverContext := "" - retrieverMode := rt.AgentExecution.Retrievers.Mode + retrieverMode := rt.AgentConfig.Retrievers.Mode if (retrieverMode == types.RetrieverModePrefetch || retrieverMode == types.RetrieverModeHybrid) && - len(rt.AgentExecution.Retrievers.Retrievers) > 0 { - logger.Debug("workflow: retriever prefetch started", "scope", "workflow", "retrieverMode", string(retrieverMode), "retrieverCount", len(rt.AgentExecution.Retrievers.Retrievers)) + len(rt.AgentConfig.Retrievers.Retrievers) > 0 { + logger.Debug("workflow: retriever prefetch started", "scope", "workflow", "retrieverMode", string(retrieverMode), "retrieverCount", len(rt.AgentConfig.Retrievers.Retrievers)) retrieverInput := AgentRetrieverInput{ AgentFingerprint: input.AgentFingerprint, UserPrompt: input.UserPrompt, @@ -624,7 +624,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl messages = append(messages, toolResults...) 
- if rt.conversationMemoryEnabled(input.ConversationID) && rt.AgentExecution.Session.ConversationSaveOnIteration && len(messages) > 0 { + if rt.conversationMemoryEnabled(input.ConversationID) && rt.AgentConfig.Session.ConversationSaveOnIteration && len(messages) > 0 { if err := workflow.ExecuteActivity(convCtx, rt.AddConversationMessagesActivity, AddConversationMessagesInput{ ConversationID: input.ConversationID, Messages: messages, @@ -664,7 +664,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl AgentFingerprint: input.AgentFingerprint, }).Get(convCtx, nil); err != nil { logger.Warn("workflow: persist conversation failed", "scope", "workflow", "conversationID", input.ConversationID, "messagesCount", len(messages), "error", err) - if !rt.AgentExecution.Session.ConversationSaveOnIteration { + if !rt.AgentConfig.Session.ConversationSaveOnIteration { return nil, err } } @@ -678,7 +678,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl } func (rt *TemporalRuntime) conversationMemoryEnabled(conversationID string) bool { - return conversationID != "" && rt.AgentExecution.Session.Conversation != nil + return conversationID != "" && rt.AgentConfig.Session.Conversation != nil } // newAgentToolCallInput builds activity contexts for one tool-call branch. @@ -690,7 +690,7 @@ func (rt *TemporalRuntime) newAgentToolCallInput( emitAgentEvent func(workflow.Context, events.AgentEvent) error, parallelSlot string, ) agentToolCallInput { - approvalTaskTimeout := rt.AgentExecution.Limits.ApprovalTimeout + approvalTaskTimeout := rt.AgentConfig.Limits.ApprovalTimeout if approvalTaskTimeout == 0 { approvalTaskTimeout = types.MaxApprovalTimeout } @@ -921,7 +921,11 @@ func (rt *TemporalRuntime) publishAgentEventToStream(ctx context.Context, agentN // (REASONING_*), then TEXT_MESSAGE_START → TEXT_MESSAGE_CONTENT* → TEXT_MESSAGE_END. 
// When input.ConversationID is set, fetches messages from conversation and prepends to workflow messages. func (rt *TemporalRuntime) AgentLLMStreamActivity(ctx context.Context, input AgentLLMInput) (*AgentLLMResult, error) { - if err := rt.verifyAgentFingerprint(input.AgentFingerprint); err != nil { + tools, err := rt.fetchTools(ctx) + if err != nil { + return nil, err + } + if err := rt.verifyAgentFingerprint(ctx, input.AgentFingerprint, tools); err != nil { return nil, err } stopHB := startLongActivityHeartbeats(ctx) @@ -943,7 +947,7 @@ func (rt *TemporalRuntime) AgentLLMStreamActivity(ctx context.Context, input Age rt.publishAgentEventToStream(ctx, agentName, input.LocalChannelName, input.EventWorkflowID, input.EventTaskQueue, ev) } - result, err := rt.ExecuteLLMStream(ctx, actLog, agentName, input.MessageID, messages, input.SkipTools, input.RetrieverContext, emit) + result, err := rt.ExecuteLLMStream(ctx, actLog, agentName, input.MessageID, messages, input.SkipTools, input.RetrieverContext, tools, emit) if err != nil { return nil, err } @@ -956,7 +960,7 @@ func (rt *TemporalRuntime) AgentLLMStreamActivity(ctx context.Context, input Age // Partial failures (some retrievers fail) are logged and skipped; if all retrievers fail, the activity // returns an error so Temporal can retry per the retry policy. func (rt *TemporalRuntime) AgentRetrieverActivity(ctx context.Context, input AgentRetrieverInput) (*AgentRetrieverResult, error) { - if err := rt.verifyAgentFingerprint(input.AgentFingerprint); err != nil { + if err := rt.verifyAgentFingerprint(ctx, input.AgentFingerprint, nil); err != nil { return nil, err } actLog := newActivityLogger(activity.GetLogger(ctx)) @@ -970,7 +974,11 @@ func (rt *TemporalRuntime) AgentRetrieverActivity(ctx context.Context, input Age // AgentLLMActivity calls the LLM and returns content plus any tool calls. // When input.ConversationID is set, fetches from store and adds assistant message on completion. 
func (rt *TemporalRuntime) AgentLLMActivity(ctx context.Context, input AgentLLMInput) (*AgentLLMResult, error) { - if err := rt.verifyAgentFingerprint(input.AgentFingerprint); err != nil { + tools, err := rt.fetchTools(ctx) + if err != nil { + return nil, err + } + if err := rt.verifyAgentFingerprint(ctx, input.AgentFingerprint, tools); err != nil { return nil, err } actLog := newActivityLogger(activity.GetLogger(ctx)) @@ -989,7 +997,7 @@ func (rt *TemporalRuntime) AgentLLMActivity(ctx context.Context, input AgentLLMI rt.publishAgentEventToStream(ctx, agentName, input.LocalChannelName, input.EventWorkflowID, input.EventTaskQueue, ev) } - result, err := rt.ExecuteLLM(ctx, actLog, agentName, input.MessageID, messages, input.SkipTools, input.RetrieverContext, emit) + result, err := rt.ExecuteLLM(ctx, actLog, agentName, input.MessageID, messages, input.SkipTools, input.RetrieverContext, tools, emit) if err != nil { return nil, err } @@ -1001,7 +1009,7 @@ func (rt *TemporalRuntime) AgentLLMActivity(ctx context.Context, input AgentLLMI // When EventWorkflowID is set, UpdateWorkflow uses WorkflowUpdateStageCompleted and updateWorkflowApprovalRPCTimeout // so the event handler has returned before ErrResultPending; RPC timeout maps to ApprovalStatusUnavailable. func (rt *TemporalRuntime) AgentToolApprovalActivity(ctx context.Context, input AgentToolApprovalInput) (types.ApprovalStatus, error) { - if err := rt.verifyAgentFingerprint(input.AgentFingerprint); err != nil { + if err := rt.verifyAgentFingerprint(ctx, input.AgentFingerprint, nil); err != nil { return types.ApprovalStatusNone, err } logger := activity.GetLogger(ctx) @@ -1110,7 +1118,7 @@ func (rt *TemporalRuntime) SendAgentEventUpdateActivity(ctx context.Context, in // AddConversationMessagesActivity adds messages to the conversation memory. 
func (rt *TemporalRuntime) AddConversationMessagesActivity(ctx context.Context, input AddConversationMessagesInput) error { - if err := rt.verifyAgentFingerprint(input.AgentFingerprint); err != nil { + if err := rt.verifyAgentFingerprint(ctx, input.AgentFingerprint, nil); err != nil { return err } conversationID := input.ConversationID @@ -1121,7 +1129,7 @@ func (rt *TemporalRuntime) AddConversationMessagesActivity(ctx context.Context, logger.Debug("activity: add conversation messages started", "scope", "activity", "conversationID", conversationID, "messagesCount", msgCount) - if rt.AgentExecution.Session.Conversation == nil { + if rt.AgentConfig.Session.Conversation == nil { return fmt.Errorf("conversation is not configured") } @@ -1133,7 +1141,7 @@ func (rt *TemporalRuntime) AddConversationMessagesActivity(ctx context.Context, failCount := 0 for _, msg := range messages { - if err := rt.AgentExecution.Session.Conversation.AddMessage(ctx, conversationID, msg); err != nil { + if err := rt.AgentConfig.Session.Conversation.AddMessage(ctx, conversationID, msg); err != nil { failCount++ msgCount-- logger.Warn("activity: add conversation message failed", "scope", "activity", "conversationID", conversationID, "error", err) @@ -1149,22 +1157,30 @@ func (rt *TemporalRuntime) AddConversationMessagesActivity(ctx context.Context, // AgentToolExecuteActivity executes a tool by name and adds tool message to conversation when ConversationID is set. 
func (rt *TemporalRuntime) AgentToolExecuteActivity(ctx context.Context, input AgentToolExecuteInput) (string, error) { - if err := rt.verifyAgentFingerprint(input.AgentFingerprint); err != nil { + tools, err := rt.fetchTools(ctx) + if err != nil { + return "", err + } + if err := rt.verifyAgentFingerprint(ctx, input.AgentFingerprint, tools); err != nil { return "", err } stopHB := startLongActivityHeartbeats(ctx) defer stopHB() actLog := newActivityLogger(activity.GetLogger(ctx)) - return rt.ExecuteTool(ctx, actLog, input.ToolName, input.Args) + return rt.ExecuteTool(ctx, actLog, tools, input.ToolName, input.Args) } // AgentToolAuthorizeActivity checks optional programmatic authorization before approval/execute. func (rt *TemporalRuntime) AgentToolAuthorizeActivity(ctx context.Context, input AgentToolAuthorizeInput) (AgentToolAuthorizeResult, error) { - if err := rt.verifyAgentFingerprint(input.AgentFingerprint); err != nil { + tools, err := rt.fetchTools(ctx) + if err != nil { + return AgentToolAuthorizeResult{}, err + } + if err := rt.verifyAgentFingerprint(ctx, input.AgentFingerprint, tools); err != nil { return AgentToolAuthorizeResult{}, err } actLog := newActivityLogger(activity.GetLogger(ctx)) - authResult, err := rt.AuthorizeTool(ctx, actLog, input.ToolName, input.Args) + authResult, err := rt.AuthorizeTool(ctx, actLog, tools, input.ToolName, input.Args) if err != nil { return AgentToolAuthorizeResult{}, err } @@ -1273,7 +1289,7 @@ func (rt *TemporalRuntime) delegateToSubAgent(ctx workflow.Context, input AgentW // Uses the main agent worker's agent timeout (same package as delegateToSubAgent); sub-agent workers may define // their own limits separately, but this bounds the child execution from the main agent's perspective. 
func (rt *TemporalRuntime) subAgentChildWorkflowTimeout() time.Duration { - return rt.AgentExecution.Limits.Timeout + return rt.AgentConfig.Limits.Timeout } func retryPolicy(maxAttempts int32) *temporal.RetryPolicy { diff --git a/internal/runtime/temporal/agent_workflow_test.go b/internal/runtime/temporal/agent_workflow_test.go index 1582fe8..137fd4c 100644 --- a/internal/runtime/temporal/agent_workflow_test.go +++ b/internal/runtime/temporal/agent_workflow_test.go @@ -22,13 +22,12 @@ import ( func testRuntimeForWorkflow(t *testing.T) *TemporalRuntime { t.Helper() - return &TemporalRuntime{ + rt := &TemporalRuntime{ Runtime: base.Runtime{ AgentSpec: sdkruntime.AgentSpec{Name: "WorkflowTestAgent"}, - AgentExecution: sdkruntime.AgentExecution{ + AgentConfig: sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLM{}}, Limits: sdkruntime.AgentLimits{MaxIterations: 5}, - Tools: sdkruntime.AgentTools{Tools: nil}, Session: sdkruntime.AgentSession{}, }, Tracer: observability.DefaultNoopTracer, @@ -36,6 +35,18 @@ func testRuntimeForWorkflow(t *testing.T) *TemporalRuntime { }, logger: logger.NoopLogger(), } + wireTestToolsResolver(rt, nil) + return rt +} + +// wireTestToolsResolver connects activity tests with a fixed resolved tool list. +func wireTestToolsResolver(rt *TemporalRuntime, tools []interfaces.Tool) { + if rt == nil { + return + } + rt.resolveToolsFn = func(ctx context.Context) ([]interfaces.Tool, error) { + return tools, nil + } } // newActivityTestEnv returns a [testsuite.TestActivityEnvironment] for isolated activity tests. 
@@ -80,7 +91,7 @@ func TestAgentWorkflow_StreamingPath_UsesStreamActivity(t *testing.T) { var suite testsuite.WorkflowTestSuite env := suite.NewTestWorkflowEnvironment() rt := testRuntimeForWorkflow(t) - rt.AgentExecution.LLM.Client = streamCapableStubLLM{} + rt.AgentConfig.LLM.Client = streamCapableStubLLM{} env.RegisterWorkflow(rt.AgentWorkflow) env.OnActivity(rt.AgentLLMStreamActivity, mock.Anything, mock.Anything).Return(func(ctx context.Context, in AgentLLMInput) (*AgentLLMResult, error) { @@ -192,7 +203,7 @@ func TestAgentLLMActivity_MockLLM_TextOnly(t *testing.T) { rt := &TemporalRuntime{ Runtime: base.Runtime{ AgentSpec: sdkruntime.AgentSpec{Name: "ActTest"}, - AgentExecution: sdkruntime.AgentExecution{ + AgentConfig: sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: mockLLM}, }, Tracer: observability.DefaultNoopTracer, @@ -200,6 +211,7 @@ func TestAgentLLMActivity_MockLLM_TextOnly(t *testing.T) { }, logger: logger.NoopLogger(), } + wireTestToolsResolver(rt, nil) actEnv := newActivityTestEnv(t) actEnv.RegisterActivity(rt.AgentLLMActivity) @@ -242,15 +254,16 @@ func TestAgentLLMActivity_MockLLM_ToolCalls(t *testing.T) { rt := &TemporalRuntime{ Runtime: base.Runtime{ AgentSpec: sdkruntime.AgentSpec{Name: "ActTest"}, - AgentExecution: sdkruntime.AgentExecution{ - LLM: sdkruntime.AgentLLM{Client: mockLLM}, - Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{mockTool}, ApprovalPolicy: policy}, + AgentConfig: sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: mockLLM}, + ToolApprovalPolicy: policy, }, Tracer: observability.DefaultNoopTracer, Metrics: observability.DefaultNoopMetrics, }, logger: logger.NoopLogger(), } + wireTestToolsResolver(rt, []interfaces.Tool{mockTool}) actEnv := newActivityTestEnv(t) actEnv.RegisterActivity(rt.AgentLLMActivity) @@ -285,15 +298,15 @@ func TestAgentLLMActivity_MockLLM_UnknownToolError(t *testing.T) { rt := &TemporalRuntime{ Runtime: base.Runtime{ - AgentExecution: sdkruntime.AgentExecution{ - LLM: 
sdkruntime.AgentLLM{Client: mockLLM}, - Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{}}, + AgentConfig: sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: mockLLM}, }, Tracer: observability.DefaultNoopTracer, Metrics: observability.DefaultNoopMetrics, }, logger: logger.NoopLogger(), } + wireTestToolsResolver(rt, nil) actEnv := newActivityTestEnv(t) actEnv.RegisterActivity(rt.AgentLLMActivity) @@ -321,7 +334,7 @@ func TestAgentLLMActivity_MockConversationAndLLM(t *testing.T) { rt := &TemporalRuntime{ Runtime: base.Runtime{ - AgentExecution: sdkruntime.AgentExecution{ + AgentConfig: sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: mockLLM}, Session: sdkruntime.AgentSession{ Conversation: mockConv, @@ -333,6 +346,7 @@ func TestAgentLLMActivity_MockConversationAndLLM(t *testing.T) { }, logger: logger.NoopLogger(), } + wireTestToolsResolver(rt, nil) actEnv := newActivityTestEnv(t) actEnv.RegisterActivity(rt.AgentLLMActivity) @@ -358,7 +372,7 @@ func TestAgentLLMActivity_ConversationNotConfigured(t *testing.T) { rt := &TemporalRuntime{ Runtime: base.Runtime{ - AgentExecution: sdkruntime.AgentExecution{ + AgentConfig: sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: mockLLM}, Session: sdkruntime.AgentSession{Conversation: nil}, }, @@ -367,6 +381,7 @@ func TestAgentLLMActivity_ConversationNotConfigured(t *testing.T) { }, logger: logger.NoopLogger(), } + wireTestToolsResolver(rt, nil) actEnv := newActivityTestEnv(t) actEnv.RegisterActivity(rt.AgentLLMActivity) @@ -393,7 +408,7 @@ func TestAgentLLMStreamActivity_MockLLM_FallbackToGenerate(t *testing.T) { rt := &TemporalRuntime{ Runtime: base.Runtime{ AgentSpec: sdkruntime.AgentSpec{Name: "StreamAct"}, - AgentExecution: sdkruntime.AgentExecution{ + AgentConfig: sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: mockLLM}, }, Tracer: observability.DefaultNoopTracer, @@ -401,6 +416,7 @@ func TestAgentLLMStreamActivity_MockLLM_FallbackToGenerate(t *testing.T) { }, logger: 
logger.NoopLogger(), } + wireTestToolsResolver(rt, nil) actEnv := newActivityTestEnv(t) actEnv.RegisterActivity(rt.AgentLLMStreamActivity) @@ -495,10 +511,10 @@ func makeRetrieverRuntime(t *testing.T, retrievers []interfaces.Retriever, mode mockLLM := mocks.NewMockLLMClient(gomock.NewController(t)) mockLLM.EXPECT().GetModel().Return("test-model").AnyTimes() mockLLM.EXPECT().GetProvider().Return(interfaces.LLMProviderOpenAI).AnyTimes() - return &TemporalRuntime{ + rt := &TemporalRuntime{ Runtime: base.Runtime{ AgentSpec: sdkruntime.AgentSpec{Name: "RetrieverTest"}, - AgentExecution: sdkruntime.AgentExecution{ + AgentConfig: sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: mockLLM}, Retrievers: sdkruntime.AgentRetrievers{ Retrievers: retrievers, @@ -510,6 +526,8 @@ func makeRetrieverRuntime(t *testing.T, retrievers []interfaces.Retriever, mode }, logger: logger.NoopLogger(), } + wireTestToolsResolver(rt, nil) + return rt } func TestAgentRetrieverActivity_NoRetrievers(t *testing.T) { @@ -666,7 +684,7 @@ func TestAgentWorkflow_PrefetchMode_CallsRetrieverActivityFirst(t *testing.T) { rt := &TemporalRuntime{ Runtime: base.Runtime{ AgentSpec: sdkruntime.AgentSpec{Name: "PrefetchAgent", SystemPrompt: "base prompt"}, - AgentExecution: sdkruntime.AgentExecution{ + AgentConfig: sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLM{}}, Limits: sdkruntime.AgentLimits{MaxIterations: 5}, Retrievers: sdkruntime.AgentRetrievers{ @@ -679,6 +697,7 @@ func TestAgentWorkflow_PrefetchMode_CallsRetrieverActivityFirst(t *testing.T) { }, logger: logger.NoopLogger(), } + wireTestToolsResolver(rt, nil) env.RegisterWorkflow(rt.AgentWorkflow) diff --git a/internal/runtime/temporal/config_test.go b/internal/runtime/temporal/config_test.go index 19a88d6..74e1dfb 100644 --- a/internal/runtime/temporal/config_test.go +++ b/internal/runtime/temporal/config_test.go @@ -90,7 +90,7 @@ func TestBuildTemporalRuntime_userProvidedTemporalClient_otelTracer_warns(t *tes WithLogger(log), 
WithTracer(newTestOTelTracer()), WithAgentSpec(sdkruntime.AgentSpec{Name: "x"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) @@ -109,7 +109,7 @@ func TestBuildTemporalRuntime_userProvidedTemporalClient_defaultTracer_noManualI WithTemporalClient(tc, "tq"), WithLogger(log), WithAgentSpec(sdkruntime.AgentSpec{Name: "x"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) @@ -129,7 +129,7 @@ func TestBuildTemporalRuntime_userProvidedTemporalClient_explicitNoopTracer_noMa WithLogger(log), WithTracer(observability.DefaultNoopTracer), WithAgentSpec(sdkruntime.AgentSpec{Name: "x"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) @@ -149,7 +149,7 @@ func TestBuildTemporalRuntime_RequiresTemporalOrClient(t *testing.T) { WithPolicyFingerprint("test"), WithMCPFingerprint("test"), WithAgentSpec(sdkruntime.AgentSpec{Name: "test"}), - WithAgentExecution(sdkruntime.AgentExecution{ + WithAgentConfig(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLM{}}, }), } @@ -165,7 +165,7 @@ func TestBuildTemporalRuntime_RequiresLLMClient(t *testing.T) { WithTemporalClient(tc, "tq"), WithLogger(logger.NoopLogger()), WithAgentSpec(sdkruntime.AgentSpec{Name: "x"}), - WithAgentExecution(sdkruntime.AgentExecution{}), + WithAgentConfig(sdkruntime.AgentConfig{}), ) if err == nil || !strings.Contains(err.Error(), "llm client is required") { t.Fatalf("got %v", err) @@ -179,7 +179,7 @@ func TestBuildTemporalRuntime_InstanceIdSuffix(t *testing.T) { WithInstanceId("pod1"), 
WithLogger(logger.NoopLogger()), WithAgentSpec(sdkruntime.AgentSpec{Name: "x"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) diff --git a/internal/runtime/temporal/fingerprint.go b/internal/runtime/temporal/fingerprint.go index 63b8eeb..e026dc9 100644 --- a/internal/runtime/temporal/fingerprint.go +++ b/internal/runtime/temporal/fingerprint.go @@ -180,14 +180,17 @@ func ToolNamesFromTools(tools []interfaces.Tool) []string { return names } -func computeAgentFingerprintFromRuntime(rt *TemporalRuntime) string { +func computeAgentFingerprintFromRuntime(rt *TemporalRuntime, tools []interfaces.Tool) string { + if rt == nil { + return "" + } mat := BuildAgentFingerprintPayload( rt.AgentSpec, - ToolNamesFromTools(rt.AgentExecution.Tools.Tools), + ToolNamesFromTools(tools), rt.policyFingerprint, - rt.AgentExecution.LLM.Sampling, - rt.AgentExecution.Session.ConversationSize, - rt.AgentExecution.Limits, + rt.AgentConfig.LLM.Sampling, + rt.AgentConfig.Session.ConversationSize, + rt.AgentConfig.Limits, rt.mcpFingerprint, rt.a2aFingerprint, rt.observabilityFingerprint, diff --git a/internal/runtime/temporal/fingerprint_test.go b/internal/runtime/temporal/fingerprint_test.go index 3eddb7b..a4e4c0d 100644 --- a/internal/runtime/temporal/fingerprint_test.go +++ b/internal/runtime/temporal/fingerprint_test.go @@ -2,9 +2,11 @@ package temporal import ( "context" + "strings" "testing" sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/runtime/base" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" ) @@ -18,30 +20,9 @@ func (f fpTool) Execute(ctx context.Context, args map[string]any) (any, error) { return nil, nil } -func TestComputeAgentFingerprint_stableAndToolOrder(t *testing.T) { +func TestComputeAgentFingerprint_toolOrderStable(t *testing.T) { spec := 
sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} lim := sdkruntime.AgentLimits{MaxIterations: 3} - - m := BuildAgentFingerprintPayload( - spec, - []string{"z", "a"}, - "auto", - nil, - 10, - lim, - "", - "", - "", - "", - "", - "", - ) - h1 := ComputeAgentFingerprint(m) - h2 := ComputeAgentFingerprint(m) - if len(h1) != 64 || h1 != h2 { - t.Fatalf("fingerprint len=%d h1=%q h2=%q", len(h1), h1, h2) - } - hA := ComputeAgentFingerprint(BuildAgentFingerprintPayload(spec, []string{"a", "b", "c"}, "auto", nil, 0, lim, "", "", "", "", "", "")) hB := ComputeAgentFingerprint(BuildAgentFingerprintPayload(spec, []string{"c", "a", "b"}, "auto", nil, 0, lim, "", "", "", "", "", "")) if hA != hB { @@ -49,7 +30,7 @@ func TestComputeAgentFingerprint_stableAndToolOrder(t *testing.T) { } } -func TestComputeAgentFingerprint_agentModeChangesDigest(t *testing.T) { +func TestComputeAgentFingerprint_stableWithoutTools(t *testing.T) { spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} lim := sdkruntime.AgentLimits{MaxIterations: 3} interactive := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "", "") @@ -108,59 +89,37 @@ func TestComputeAgentFingerprint_observabilityFingerprintChangesDigest(t *testin } } -func newFingerprintRT(spec sdkruntime.AgentSpec, exec sdkruntime.AgentExecution, policyFP string, opts ...func(*TemporalRuntime)) *TemporalRuntime { - rt := &TemporalRuntime{} - rt.AgentSpec = spec - rt.AgentExecution = exec - rt.policyFingerprint = policyFP - for _, o := range opts { - o(rt) - } - rt.agentFingerprint = computeAgentFingerprintFromRuntime(rt) - return rt -} - func TestVerifyAgentFingerprint_mismatch(t *testing.T) { - rt := newFingerprintRT( - sdkruntime.AgentSpec{Name: "x"}, - sdkruntime.AgentExecution{}, - "require_all", - ) - err := rt.verifyAgentFingerprint("deadbeef") + rt := &TemporalRuntime{ + resolveToolsFn: func(context.Context) ([]interfaces.Tool, error) { + return nil, nil + }, + } + err := 
rt.verifyAgentFingerprint(context.Background(), "caller-fp", nil) if err == nil { t.Fatal("expected mismatch error") } } -func TestVerifyAgentFingerprint_bothEmptyOK(t *testing.T) { - rt := &TemporalRuntime{} - if err := rt.verifyAgentFingerprint(""); err != nil { - t.Fatal(err) +func TestVerifyAgentFingerprint_emptyCallerFingerprintSkipsCheck(t *testing.T) { + rt := &TemporalRuntime{ + resolveToolsFn: func(context.Context) ([]interfaces.Tool, error) { + return nil, nil + }, } -} - -func TestVerifyAgentFingerprint_emptyWantWhenWorkerHasFingerprint(t *testing.T) { - rt := newFingerprintRT( - sdkruntime.AgentSpec{Name: "x"}, - sdkruntime.AgentExecution{}, - "require_all", - ) - if err := rt.verifyAgentFingerprint(""); err == nil { - t.Fatal("expected mismatch when caller fingerprint is empty but worker has one") + if err := rt.verifyAgentFingerprint(context.Background(), "", nil); err != nil { + t.Fatal(err) } } func TestVerifyAgentFingerprint_disableCheckAllowsMismatch(t *testing.T) { - rt := newFingerprintRT( - sdkruntime.AgentSpec{Name: "x"}, - sdkruntime.AgentExecution{}, - "require_all", - func(rt *TemporalRuntime) { - rt.disableFingerprintCheck = true - rt.ToolExecutionMode = "sequential" + rt := &TemporalRuntime{ + disableFingerprintCheck: true, + resolveToolsFn: func(context.Context) ([]interfaces.Tool, error) { + return nil, nil }, - ) - if err := rt.verifyAgentFingerprint("definitely-different"); err != nil { + } + if err := rt.verifyAgentFingerprint(context.Background(), "caller-fp", nil); err != nil { t.Fatalf("expected bypass when skip is enabled, got: %v", err) } } @@ -210,3 +169,47 @@ func TestBuildAgentFingerprintPayload_responseFormatAndSampling(t *testing.T) { t.Fatalf("payload temperature should stay 0.2 after original changes: %+v", p.Sampling.Temperature) } } + +func TestFetchTools_requiresResolver(t *testing.T) { + rt := &TemporalRuntime{} + _, err := rt.fetchTools(context.Background()) + if err == nil || !strings.Contains(err.Error(), "tools 
resolver is not configured") { + t.Fatalf("fetchTools() = %v, want resolver not configured error", err) + } +} + +func TestFetchTools_delegatesToResolver(t *testing.T) { + want := []interfaces.Tool{fpTool{name: "t1"}} + rt := &TemporalRuntime{ + resolveToolsFn: func(context.Context) ([]interfaces.Tool, error) { + return want, nil + }, + } + got, err := rt.fetchTools(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(got) != 1 || got[0].Name() != "t1" { + t.Fatalf("fetchTools() = %+v, want t1", got) + } +} + +func TestVerifyAgentFingerprint_usesPreFetchedTools(t *testing.T) { + tools := []interfaces.Tool{fpTool{name: "a"}} + rt := &TemporalRuntime{ + Runtime: base.Runtime{ + AgentSpec: sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"}, + AgentConfig: sdkruntime.AgentConfig{ + Limits: sdkruntime.AgentLimits{MaxIterations: 3}, + }, + }, + resolveToolsFn: func(context.Context) ([]interfaces.Tool, error) { + t.Fatal("fetchTools should not run when tools are pre-fetched") + return nil, nil + }, + } + fp := computeAgentFingerprintFromRuntime(rt, tools) + if err := rt.verifyAgentFingerprint(context.Background(), fp, tools); err != nil { + t.Fatal(err) + } +} diff --git a/internal/runtime/temporal/options.go b/internal/runtime/temporal/options.go index 78f1a31..5494039 100644 --- a/internal/runtime/temporal/options.go +++ b/internal/runtime/temporal/options.go @@ -70,10 +70,9 @@ func WithAgentSpec(spec sdkruntime.AgentSpec) Option { return func(rt *TemporalRuntime) { rt.AgentSpec = spec } } -// WithAgentExecution sets LLM, tools, session, and limits -// (same shape as [sdkruntime.ExecuteRequest.AgentExecution]). -func WithAgentExecution(exec sdkruntime.AgentExecution) Option { - return func(rt *TemporalRuntime) { rt.AgentExecution = exec } +// WithAgentConfig sets static LLM, session, limits, and tool approval policy on the worker runtime. 
+func WithAgentConfig(cfg sdkruntime.AgentConfig) Option { + return func(rt *TemporalRuntime) { rt.AgentConfig = cfg } } // WithPolicyFingerprint sets the opaque policy digest used with [ComputeAgentFingerprint]. @@ -119,6 +118,11 @@ func WithRetrieverFingerprint(fp string) Option { return func(rt *TemporalRuntime) { rt.retrieverFingerprint = fp } } +// WithToolsResolver sets the callback that resolves tools at activity time on the worker runtime. +func WithToolsResolver(fn ToolsResolver) Option { + return func(rt *TemporalRuntime) { rt.resolveToolsFn = fn } +} + // WithDisableLocalWorker mirrors pkg/agent DisableLocalWorker. When false, the client // embeds a worker and the runtime skips DescribeTaskQueue poller checks before starting // workflows. @@ -147,8 +151,7 @@ func WithMetrics(m interfaces.Metrics) Option { // buildTemporalRuntime applies options onto a fresh [TemporalRuntime], validates required // fields, and dials the Temporal client when [WithTemporalConfig] is used. The returned -// runtime is fully configured but does not yet have an agentFingerprint or eventbus — -// those are set by [NewTemporalRuntime]. +// runtime is fully configured but does not yet have an eventbus — that is set by [NewTemporalRuntime]. 
func buildTemporalRuntime(opts ...Option) (*TemporalRuntime, error) { rt := &TemporalRuntime{logger: logger.NoopLogger()} for _, opt := range opts { @@ -177,7 +180,7 @@ func buildTemporalRuntime(opts ...Option) (*TemporalRuntime, error) { rt.taskQueue = rt.taskQueue + "-" + rt.instanceId } - if rt.AgentExecution.LLM.Client == nil { + if rt.AgentConfig.LLM.Client == nil { return nil, fmt.Errorf("llm client is required") } @@ -193,15 +196,15 @@ func buildTemporalRuntime(opts ...Option) (*TemporalRuntime, error) { slog.String("agentName", rt.AgentSpec.Name), slog.String("taskQueue", rt.taskQueue), slog.String("instanceId", rt.instanceId), - slog.Int("maxIterations", rt.AgentExecution.Limits.MaxIterations), + slog.Int("maxIterations", rt.AgentConfig.Limits.MaxIterations), slog.Bool("remoteWorker", rt.remoteWorker), slog.String("agentMode", rt.agentMode), slog.String("toolExecutionMode", string(rt.ToolExecutionMode)), slog.Bool("enableRemoteWorkers", rt.enableRemoteWorkers), slog.Bool("disableFingerprintCheck", rt.disableFingerprintCheck), - slog.Duration("timeout", rt.AgentExecution.Limits.Timeout), - slog.Duration("approvalTimeout", rt.AgentExecution.Limits.ApprovalTimeout), - slog.Bool("hasConversation", rt.AgentExecution.Session.Conversation != nil), + slog.Duration("timeout", rt.AgentConfig.Limits.Timeout), + slog.Duration("approvalTimeout", rt.AgentConfig.Limits.ApprovalTimeout), + slog.Bool("hasConversation", rt.AgentConfig.Session.Conversation != nil), slog.Bool("hasTracer", rt.Tracer != nil), slog.Bool("hasMetrics", rt.Metrics != nil)) diff --git a/internal/runtime/temporal/runtime.go b/internal/runtime/temporal/runtime.go index 7f08454..2bca17b 100644 --- a/internal/runtime/temporal/runtime.go +++ b/internal/runtime/temporal/runtime.go @@ -16,6 +16,7 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/runtime" "github.com/agenticenv/agent-sdk-go/internal/runtime/base" "github.com/agenticenv/agent-sdk-go/internal/types" + 
"github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/logger" "github.com/google/uuid" enumspb "go.temporal.io/api/enums/v1" @@ -43,15 +44,18 @@ const ( // ErrAgentAlreadyRunning is returned when Execute, ExecuteStream, or RunAsync is called while a run is already in progress. var ErrAgentAlreadyRunning = errors.New("agent already has an active run") -// ErrAgentFingerprintMismatch is returned when workflow input fingerprint does not match the worker. -var ErrAgentFingerprintMismatch = errors.New("temporal: agent fingerprint mismatch (caller vs worker); redeploy worker or align agent config") +// ErrAgentFingerprintMismatch is returned when the per-run agent fingerprint does not match the worker. +var ErrAgentFingerprintMismatch = errors.New("temporal: agent fingerprint mismatch (caller vs worker); redeploy worker or align config/registries or retry run") + +// ToolsResolver resolves per-run tools from registries at activity entry (worker runtime). +type ToolsResolver func(ctx context.Context) ([]interfaces.Tool, error) // TemporalRuntime implements [runtime.WorkerRuntime] and [runtime.EventBusRuntime] using // Temporal workflows and activities as the execution backend. -// It embeds [base.Runtime] for the common agent fields (AgentSpec, AgentExecution, Tracer, Metrics, +// It embeds [base.Runtime] for the common agent fields (AgentSpec, AgentConfig, Tracer, Metrics, // ToolExecutionMode) and holds all Temporal-specific connection and fingerprint state as flat fields. type TemporalRuntime struct { - base.Runtime // AgentSpec, AgentExecution, Tracer, Metrics, ToolExecutionMode + base.Runtime // AgentSpec, AgentConfig, Tracer, Metrics, ToolExecutionMode // Temporal connection temporalConfig *TemporalConfig @@ -66,7 +70,7 @@ type TemporalRuntime struct { logger logger.Logger - // Fingerprint inputs captured at construction and consumed by computeAgentFingerprintFromRuntime. 
+ // Fingerprint inputs captured at construction; per-run digest from [computeAgentFingerprintFromRuntime]. policyFingerprint string mcpFingerprint string a2aFingerprint string @@ -83,8 +87,8 @@ type TemporalRuntime struct { // Break-glass only: keep false in production for rollout/config safety. disableFingerprintCheck bool - // agentFingerprint is ComputeAgentFingerprint(BuildAgentFingerprintPayload(...)) at NewTemporalRuntime; immutable. - agentFingerprint string + // resolveToolsFn resolves tools from registries at activity time (worker runtime). + resolveToolsFn ToolsResolver eventbus eventbus.EventBus runMu sync.Mutex @@ -118,19 +122,34 @@ func NewTemporalRuntime(opts ...Option) (*TemporalRuntime, error) { slog.String("name", rt.AgentSpec.Name), slog.String("taskQueue", rt.taskQueue)) } - rt.agentFingerprint = computeAgentFingerprintFromRuntime(rt) rt.eventbus = eventbus.NewInmem(rt.logger) return rt, nil } -// verifyAgentFingerprint returns an error when want does not equal the runtime's agent fingerprint -// (computed at [NewTemporalRuntime]). -func (rt *TemporalRuntime) verifyAgentFingerprint(want string) error { - if rt.disableFingerprintCheck { +// fetchTools resolves tools from registries at activity time via [resolveToolsFn]. +func (rt *TemporalRuntime) fetchTools(ctx context.Context) ([]interfaces.Tool, error) { + if rt.resolveToolsFn == nil { + return nil, fmt.Errorf("temporal: tools resolver is not configured") + } + return rt.resolveToolsFn(ctx) +} + +// verifyAgentFingerprint compares caller vs worker config digest when fingerprint check is enabled and callerFingerprint is non-empty. +// Pass nil tools to fetch via [fetchTools] internally; pass pre-fetched tools when the activity already resolved them. 
+func (rt *TemporalRuntime) verifyAgentFingerprint(ctx context.Context, callerFingerprint string, tools []interfaces.Tool) error { + if rt.disableFingerprintCheck || callerFingerprint == "" { return nil } - if rt.agentFingerprint != want { - return fmt.Errorf("%w: worker=%q caller=%q", ErrAgentFingerprintMismatch, rt.agentFingerprint, want) + if tools == nil { + var err error + tools, err = rt.fetchTools(ctx) + if err != nil { + return err + } + } + got := computeAgentFingerprintFromRuntime(rt, tools) + if got != callerFingerprint { + return fmt.Errorf("%w: worker=%q caller=%q", ErrAgentFingerprintMismatch, got, callerFingerprint) } return nil } @@ -261,18 +280,18 @@ func (rt *TemporalRuntime) Approve(ctx context.Context, approvalToken string, st return rt.temporalClient.CompleteActivity(ctx, taskToken, status, nil) } -func agentNameFromExecuteRequest(req *runtime.ExecuteRequest) string { - if req == nil || req.AgentSpec == nil { +func agentNameFromRuntime(rt *TemporalRuntime) string { + if rt == nil { return "" } - return req.AgentSpec.Name + return rt.AgentSpec.Name } func (rt *TemporalRuntime) Execute(ctx context.Context, req *runtime.ExecuteRequest) (*types.AgentRunResult, error) { - rt.logger.Debug(ctx, "runtime run dispatch", slog.String("scope", "runtime"), slog.String("agent", agentNameFromExecuteRequest(req)), slog.Int("inputLen", len(req.UserPrompt))) + rt.logger.Debug(ctx, "runtime run dispatch", slog.String("scope", "runtime"), slog.String("agent", agentNameFromRuntime(rt)), slog.Int("inputLen", len(req.UserPrompt))) runCtx := ctx - d := rt.AgentExecution.Limits.Timeout + d := rt.AgentConfig.Limits.Timeout if _, ok := ctx.Deadline(); !ok && d > 0 { var cancel context.CancelFunc runCtx, cancel = context.WithTimeout(ctx, d) @@ -289,7 +308,7 @@ func (rt *TemporalRuntime) Execute(ctx context.Context, req *runtime.ExecuteRequ threadID = runID } } - workflowID := rt.getWorkflowID(runID, agentNameFromExecuteRequest(req), false) + workflowID := 
rt.getWorkflowID(runID, agentNameFromRuntime(rt), false) rt.logger.Debug(runCtx, "runtime identifiers", slog.String("scope", "runtime"), slog.String("runID", runID), slog.String("threadID", threadID), slog.String("workflowID", workflowID)) @@ -305,7 +324,7 @@ func (rt *TemporalRuntime) Execute(ctx context.Context, req *runtime.ExecuteRequ EventWorkflowID: "", LocalChannelName: eventChannelName(workflowID), ConversationID: conversationID, - AgentFingerprint: rt.agentFingerprint, + AgentFingerprint: computeAgentFingerprintFromRuntime(rt, req.Tools), EventTypes: []events.AgentEventType{}, SubAgentDepth: 0, SubAgentRoutes: buildSubAgentRoutes(req.SubAgents), @@ -317,9 +336,9 @@ func (rt *TemporalRuntime) Execute(ctx context.Context, req *runtime.ExecuteRequ rt.logger.Error(runCtx, "runtime event worker creation failed", slog.String("scope", "runtime"), slog.String("taskQueue", rt.taskQueue), slog.Any("error", err)) return nil, err } - wfInput.EventWorkflowID, wfInput.EventTaskQueue, err = rt.resolveEventPipeline(runCtx, agentNameFromExecuteRequest(req)) + wfInput.EventWorkflowID, wfInput.EventTaskQueue, err = rt.resolveEventPipeline(runCtx, agentNameFromRuntime(rt)) if err != nil { - rt.logger.Error(runCtx, "runtime event pipeline resolution failed", slog.String("scope", "runtime"), slog.String("agent", agentNameFromExecuteRequest(req)), slog.Any("error", err)) + rt.logger.Error(runCtx, "runtime event pipeline resolution failed", slog.String("scope", "runtime"), slog.String("agent", agentNameFromRuntime(rt)), slog.Any("error", err)) return nil, err } } @@ -425,7 +444,7 @@ func (rt *TemporalRuntime) Execute(ctx context.Context, req *runtime.ExecuteRequ approvalResponseCh <- approvalResponse{approvalToken: token, status: status} return nil } - approvalCtx, cancel := context.WithTimeout(runCtx, rt.AgentExecution.Limits.ApprovalTimeout) + approvalCtx, cancel := context.WithTimeout(runCtx, rt.AgentConfig.Limits.ApprovalTimeout) req.ApprovalHandler(approvalCtx, apprReq) 
cancel() case resp := <-approvalResponseCh: @@ -438,7 +457,7 @@ func (rt *TemporalRuntime) Execute(ctx context.Context, req *runtime.ExecuteRequ } func (rt *TemporalRuntime) ExecuteStream(ctx context.Context, req *runtime.ExecuteRequest) (<-chan events.AgentEvent, error) { - rt.logger.Debug(ctx, "runtime stream run dispatch", slog.String("scope", "runtime"), slog.String("agent", agentNameFromExecuteRequest(req)), slog.Int("inputLen", len(req.UserPrompt))) + rt.logger.Debug(ctx, "runtime stream run dispatch", slog.String("scope", "runtime"), slog.String("agent", agentNameFromRuntime(rt)), slog.Int("inputLen", len(req.UserPrompt))) conversationID := base.GetConversationID(req) runID := uuid.New().String() @@ -450,7 +469,7 @@ func (rt *TemporalRuntime) ExecuteStream(ctx context.Context, req *runtime.Execu threadID = runID } } - workflowID := rt.getWorkflowID(runID, agentNameFromExecuteRequest(req), true) + workflowID := rt.getWorkflowID(runID, agentNameFromRuntime(rt), true) rt.logger.Debug(ctx, "runtime identifiers", slog.String("scope", "runtime"), slog.String("runID", runID), slog.String("threadID", threadID), slog.String("workflowID", workflowID)) @@ -471,9 +490,9 @@ func (rt *TemporalRuntime) ExecuteStream(ctx context.Context, req *runtime.Execu rt.logger.Error(ctx, "runtime event worker creation failed", slog.String("scope", "runtime"), slog.String("taskQueue", rt.taskQueue), slog.Any("error", err)) return nil, err } - eventWorkflowID, eventTaskQueue, err = rt.resolveEventPipeline(ctx, agentNameFromExecuteRequest(req)) + eventWorkflowID, eventTaskQueue, err = rt.resolveEventPipeline(ctx, agentNameFromRuntime(rt)) if err != nil { - rt.logger.Error(ctx, "runtime event pipeline resolution failed", slog.String("scope", "runtime"), slog.String("agent", agentNameFromExecuteRequest(req)), slog.Any("error", err)) + rt.logger.Error(ctx, "runtime event pipeline resolution failed", slog.String("scope", "runtime"), slog.String("agent", agentNameFromRuntime(rt)), 
slog.Any("error", err)) return nil, err } } @@ -489,7 +508,7 @@ func (rt *TemporalRuntime) ExecuteStream(ctx context.Context, req *runtime.Execu LocalChannelName: eventChannelName(workflowID), StreamingEnabled: req.StreamingEnabled, ConversationID: conversationID, - AgentFingerprint: rt.agentFingerprint, + AgentFingerprint: computeAgentFingerprintFromRuntime(rt, req.Tools), EventTypes: streamEventTypes, SubAgentDepth: 0, SubAgentRoutes: buildSubAgentRoutes(req.SubAgents), @@ -498,7 +517,7 @@ func (rt *TemporalRuntime) ExecuteStream(ctx context.Context, req *runtime.Execu runCtx := ctx var runCancel context.CancelFunc - d := rt.AgentExecution.Limits.Timeout + d := rt.AgentConfig.Limits.Timeout if _, ok := ctx.Deadline(); !ok && d > 0 { runCtx, runCancel = context.WithTimeout(ctx, d) } @@ -549,7 +568,7 @@ func (rt *TemporalRuntime) ExecuteStream(ctx context.Context, req *runtime.Execu wfErrCh := make(chan error, 1) workflowResultCh := make(chan *types.AgentRunResult, 1) localChannel := wfInput.LocalChannelName - rootName := agentNameFromExecuteRequest(req) + rootName := agentNameFromRuntime(rt) // eventCh → outCh only: all RUN_* and workflow events pass through the local bus (publish then forward). 
go func() { diff --git a/internal/runtime/temporal/runtime_test.go b/internal/runtime/temporal/runtime_test.go index 0567188..be46b02 100644 --- a/internal/runtime/temporal/runtime_test.go +++ b/internal/runtime/temporal/runtime_test.go @@ -54,16 +54,14 @@ func TestGetEventTaskQueue(t *testing.T) { } } -func TestAgentNameFromExecuteRequest(t *testing.T) { - if agentNameFromExecuteRequest(nil) != "" { - t.Fatal("nil req") +func TestAgentNameFromRuntime(t *testing.T) { + if agentNameFromRuntime(nil) != "" { + t.Fatal("nil rt") } - if agentNameFromExecuteRequest(&sdkruntime.ExecuteRequest{}) != "" { - t.Fatal("nil AgentSpec") + rt := &TemporalRuntime{ + Runtime: base.Runtime{AgentSpec: sdkruntime.AgentSpec{Name: "n"}}, } - if got := agentNameFromExecuteRequest(&sdkruntime.ExecuteRequest{ - AgentSpec: &sdkruntime.AgentSpec{Name: "n"}, - }); got != "n" { + if got := agentNameFromRuntime(rt); got != "n" { t.Fatalf("got %q", got) } } @@ -320,13 +318,13 @@ func TestTemporalRuntime_Run_Success(t *testing.T) { WithTemporalClient(tc, "tq"), WithDisableLocalWorker(true), WithAgentSpec(sdkruntime.AgentSpec{Name: "agent-a"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) } - resp, err := rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{UserPrompt: "hi", AgentSpec: &sdkruntime.AgentSpec{Name: "agent-a"}}) + resp, err := rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{UserPrompt: "hi"}) if err != nil { t.Fatalf("Run: %v", err) } @@ -344,7 +342,7 @@ func TestTemporalRuntime_Run_NoWorkers(t *testing.T) { WithTemporalClient(tc, "tq"), WithDisableLocalWorker(true), WithAgentSpec(sdkruntime.AgentSpec{Name: "agent-a"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) 
if err != nil { t.Fatal(err) @@ -353,7 +351,7 @@ func TestTemporalRuntime_Run_NoWorkers(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() - _, err = rt.Execute(ctx, &sdkruntime.ExecuteRequest{UserPrompt: "hi", AgentSpec: &sdkruntime.AgentSpec{Name: "agent-a"}}) + _, err = rt.Execute(ctx, &sdkruntime.ExecuteRequest{UserPrompt: "hi"}) if err == nil { t.Fatal("expected error when no workers") } @@ -373,13 +371,13 @@ func TestTemporalRuntime_Run_ExecuteWorkflowError(t *testing.T) { WithTemporalClient(tc, "tq"), WithDisableLocalWorker(true), WithAgentSpec(sdkruntime.AgentSpec{Name: "agent-a"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) } - _, err = rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{UserPrompt: "hi", AgentSpec: &sdkruntime.AgentSpec{Name: "agent-a"}}) + _, err = rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{UserPrompt: "hi"}) if err == nil || err.Error() != "start failed" { t.Fatalf("got %v, want start failed", err) } @@ -398,13 +396,13 @@ func TestTemporalRuntime_Run_WorkflowGetError(t *testing.T) { WithTemporalClient(tc, "tq"), WithDisableLocalWorker(true), WithAgentSpec(sdkruntime.AgentSpec{Name: "agent-a"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) } - _, err = rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{UserPrompt: "hi", AgentSpec: &sdkruntime.AgentSpec{Name: "agent-a"}}) + _, err = rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{UserPrompt: "hi"}) if err == nil || err.Error() != "workflow failed" { t.Fatalf("got %v, want workflow failed", err) } @@ -434,14 +432,14 @@ func 
TestTemporalRuntime_ExecuteStream_Success(t *testing.T) { WithTemporalClient(tc, "tq"), WithDisableLocalWorker(true), WithAgentSpec(sdkruntime.AgentSpec{Name: "root"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) } ctx := context.Background() - outCh, err := rt.ExecuteStream(ctx, &sdkruntime.ExecuteRequest{UserPrompt: "hi", AgentSpec: &sdkruntime.AgentSpec{Name: "root"}}) + outCh, err := rt.ExecuteStream(ctx, &sdkruntime.ExecuteRequest{UserPrompt: "hi"}) if err != nil { t.Fatalf("ExecuteStream: %v", err) } @@ -474,13 +472,13 @@ func TestTemporalRuntime_ExecuteStream_WorkflowGetError(t *testing.T) { WithTemporalClient(tc, "tq"), WithDisableLocalWorker(true), WithAgentSpec(sdkruntime.AgentSpec{Name: "root"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) } - outCh, err := rt.ExecuteStream(context.Background(), &sdkruntime.ExecuteRequest{UserPrompt: "hi", AgentSpec: &sdkruntime.AgentSpec{Name: "root"}}) + outCh, err := rt.ExecuteStream(context.Background(), &sdkruntime.ExecuteRequest{UserPrompt: "hi"}) if err != nil { t.Fatalf("ExecuteStream: %v", err) } @@ -502,7 +500,7 @@ func TestTemporalRuntime_Start_Idempotent(t *testing.T) { rt, err := NewTemporalRuntime( WithTemporalClient(tc, "tq"), WithAgentSpec(sdkruntime.AgentSpec{Name: "agent-a"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) @@ -526,7 +524,7 @@ func TestTemporalRuntime_Stop_RemoteOwnedClient(t *testing.T) { WithTemporalClient(tc, "tq"), WithRemoteWorker(true), WithAgentSpec(sdkruntime.AgentSpec{Name: 
"agent-a"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) @@ -550,7 +548,7 @@ func TestTemporalRuntime_Stop_RemoteOwnedClientNoAgentWorker(t *testing.T) { WithTemporalClient(tc, "tq"), WithRemoteWorker(true), WithAgentSpec(sdkruntime.AgentSpec{Name: "agent-a"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) @@ -568,7 +566,7 @@ func TestTemporalRuntime_Stop_LocalEmbed(t *testing.T) { WithTemporalClient(tc, "tq"), WithRemoteWorker(false), WithAgentSpec(sdkruntime.AgentSpec{Name: "agent-a"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) @@ -581,7 +579,7 @@ func TestTemporalRuntime_Close_Minimal(t *testing.T) { rt, err := NewTemporalRuntime( WithTemporalClient(tc, "tq"), WithAgentSpec(sdkruntime.AgentSpec{Name: "agent-a"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) @@ -596,7 +594,7 @@ func TestTemporalRuntime_Close_OwnsTemporalClient(t *testing.T) { rt, err := NewTemporalRuntime( WithTemporalClient(tc, "tq"), WithAgentSpec(sdkruntime.AgentSpec{Name: "agent-a"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) @@ -611,7 +609,7 @@ func TestTemporalRuntime_Close_StopsWorkers(t *testing.T) { rt, err := NewTemporalRuntime( WithTemporalClient(tc, 
"tq"), WithAgentSpec(sdkruntime.AgentSpec{Name: "agent-a"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) @@ -639,7 +637,7 @@ func TestTemporalRuntime_Close_ActiveWorkflows(t *testing.T) { rt, err := NewTemporalRuntime( WithTemporalClient(tc, "tq"), WithAgentSpec(sdkruntime.AgentSpec{Name: "agent-a"}), - WithAgentExecution(sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), + WithAgentConfig(sdkruntime.AgentConfig{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}), ) if err != nil { t.Fatal(err) diff --git a/internal/runtime/temporal/subagent.go b/internal/runtime/temporal/subagent.go index 48e97bf..9723f6e 100644 --- a/internal/runtime/temporal/subagent.go +++ b/internal/runtime/temporal/subagent.go @@ -13,7 +13,8 @@ type SubAgentRoute struct { // buildSubAgentRoutes converts the runtime-agnostic SubAgentSpec tree (from ExecuteRequest) // into a Temporal-specific SubAgentRoute map. Each spec's Runtime is type-asserted to -// *TemporalRuntime to extract the task queue and agent fingerprint. +// *TemporalRuntime to extract the task queue and per-run agent fingerprint (static runtime +// digests + resolved spec.Tools). 
func buildSubAgentRoutes(specs []*sdkruntime.SubAgentSpec) map[string]SubAgentRoute { if len(specs) == 0 { return nil @@ -26,7 +27,7 @@ func buildSubAgentRoutes(specs []*sdkruntime.SubAgentSpec) map[string]SubAgentRo route := SubAgentRoute{Name: spec.Name} if tr, ok := spec.Runtime.(*TemporalRuntime); ok { route.TaskQueue = tr.taskQueue - route.AgentFingerprint = tr.agentFingerprint + route.AgentFingerprint = computeAgentFingerprintFromRuntime(tr, spec.Tools) } route.ChildRoutes = buildSubAgentRoutes(spec.Children) out[spec.ToolName] = route diff --git a/internal/runtime/temporal/subagent_test.go b/internal/runtime/temporal/subagent_test.go new file mode 100644 index 0000000..a242bf3 --- /dev/null +++ b/internal/runtime/temporal/subagent_test.go @@ -0,0 +1,123 @@ +package temporal + +import ( + "testing" + + sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/runtime/base" + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +func testTemporalRuntime(name, taskQueue string) *TemporalRuntime { + return &TemporalRuntime{ + Runtime: base.Runtime{ + AgentSpec: sdkruntime.AgentSpec{Name: name, SystemPrompt: "p"}, + AgentConfig: sdkruntime.AgentConfig{ + Limits: sdkruntime.AgentLimits{MaxIterations: 3}, + }, + ToolExecutionMode: types.AgentToolExecutionModeParallel, + }, + taskQueue: taskQueue, + policyFingerprint: "policy", + } +} + +func TestBuildSubAgentRoutes_setsAgentFingerprint(t *testing.T) { + subRT := testTemporalRuntime("sub", "sub-queue") + subTools := []interfaces.Tool{fpTool{name: "sub_tool"}} + want := computeAgentFingerprintFromRuntime(subRT, subTools) + + routes := buildSubAgentRoutes([]*sdkruntime.SubAgentSpec{{ + Name: "Sub", + ToolName: "subagent_Sub", + Runtime: subRT, + Tools: subTools, + }}) + route, ok := routes["subagent_Sub"] + if !ok { + t.Fatal("missing route") + } + if route.TaskQueue != "sub-queue" { + t.Fatalf("task queue: 
got %q", route.TaskQueue) + } + if route.AgentFingerprint != want { + t.Fatalf("fingerprint: got %q want %q", route.AgentFingerprint, want) + } + if route.AgentFingerprint == "" { + t.Fatal("expected non-empty sub-agent fingerprint") + } +} + +func TestBuildSubAgentRoutes_nestedChildFingerprint(t *testing.T) { + childRT := testTemporalRuntime("child", "child-queue") + childTools := []interfaces.Tool{fpTool{name: "child_tool"}} + wantChild := computeAgentFingerprintFromRuntime(childRT, childTools) + + parentRT := testTemporalRuntime("parent", "parent-queue") + routes := buildSubAgentRoutes([]*sdkruntime.SubAgentSpec{{ + Name: "Parent", + ToolName: "subagent_Parent", + Runtime: parentRT, + Tools: []interfaces.Tool{fpTool{name: "parent_tool"}}, + Children: []*sdkruntime.SubAgentSpec{{ + Name: "Child", + ToolName: "subagent_Child", + Runtime: childRT, + Tools: childTools, + }}, + }}) + + parentRoute := routes["subagent_Parent"] + childRoute, ok := parentRoute.ChildRoutes["subagent_Child"] + if !ok { + t.Fatal("missing nested child route") + } + if childRoute.AgentFingerprint != wantChild { + t.Fatalf("child fingerprint: got %q want %q", childRoute.AgentFingerprint, wantChild) + } +} + +func TestBuildSubAgentRoutes_parentAndSubFingerprintsDiffer(t *testing.T) { + subRT := testTemporalRuntime("sub", "sub-queue") + subRT.AgentSpec.SystemPrompt = "sub prompt" + parentRT := testTemporalRuntime("parent", "parent-queue") + parentRT.AgentSpec.SystemPrompt = "parent prompt" + + routes := buildSubAgentRoutes([]*sdkruntime.SubAgentSpec{{ + Name: "Parent", + ToolName: "subagent_Parent", + Runtime: parentRT, + Tools: []interfaces.Tool{fpTool{name: "parent_tool"}}, + }, { + Name: "Sub", + ToolName: "subagent_Sub", + Runtime: subRT, + Tools: []interfaces.Tool{fpTool{name: "sub_tool"}}, + }}) + + parentFP := routes["subagent_Parent"].AgentFingerprint + subFP := routes["subagent_Sub"].AgentFingerprint + if parentFP == "" || subFP == "" { + t.Fatal("expected non-empty fingerprints") + } 
+ if parentFP == subFP { + t.Fatalf("parent and sub fingerprints must differ: %q", parentFP) + } +} + +func TestBuildSubAgentRoutes_nonTemporalSkipsFingerprint(t *testing.T) { + routes := buildSubAgentRoutes([]*sdkruntime.SubAgentSpec{{ + Name: "Local", + ToolName: "subagent_Local", + Runtime: nil, + Tools: []interfaces.Tool{fpTool{name: "t"}}, + }}) + route := routes["subagent_Local"] + if route.AgentFingerprint != "" { + t.Fatalf("non-temporal route should not set fingerprint: %q", route.AgentFingerprint) + } + if route.TaskQueue != "" { + t.Fatalf("non-temporal route should not set task queue: %q", route.TaskQueue) + } +} diff --git a/pkg/agent/a2a.go b/pkg/agent/a2a.go index 03b7dc1..99f12dc 100644 --- a/pkg/agent/a2a.go +++ b/pkg/agent/a2a.go @@ -23,6 +23,7 @@ var ( const defaultA2AToolTimeout = types.DefaultA2ATimeout var _ interfaces.Tool = (*A2ATool)(nil) +var _ interfaces.ToolKindProvider = (*A2ATool)(nil) // NOTE: A2ATools for the same server share one A2AClient. The default pkg/a2a/client is safe // for concurrent use; custom A2AClient implementations should document concurrency behaviour. @@ -75,6 +76,9 @@ func NewA2ATool(serverName string, spec interfaces.ToolSpec, skillSpec interface return &A2ATool{ServerName: serverName, Spec: spec, SkillSpec: skillSpec, Client: client} } +// ToolKind implements [interfaces.ToolKindProvider]. +func (t *A2ATool) ToolKind() string { return "a2a" } + // Name implements [interfaces.Tool]. 
func (t *A2ATool) Name() string { if t == nil { diff --git a/pkg/agent/a2a_registry.go b/pkg/agent/a2a_registry.go new file mode 100644 index 0000000..fcb84a2 --- /dev/null +++ b/pkg/agent/a2a_registry.go @@ -0,0 +1,133 @@ +package agent + +import ( + "fmt" + "strings" + "sync" + + a2aclient "github.com/agenticenv/agent-sdk-go/pkg/a2a/client" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" +) + +var _ A2ARegistry = (*a2aRegistryImpl)(nil) + +type a2aRegistryImpl struct { + mu sync.RWMutex + logger logger.Logger + clients map[string]interfaces.A2AClient + order []string +} + +// NewA2ARegistry returns an empty A2A client registry for use with [WithA2ARegistry]. +// logger is used when [Register] builds a client from [A2AConfig]. +func NewA2ARegistry(l logger.Logger) A2ARegistry { + if l == nil { + l = NoopLogger() + } + return &a2aRegistryImpl{ + logger: l, + clients: make(map[string]interfaces.A2AClient), + } +} + +func (r *a2aRegistryImpl) Register(name string, config A2AConfig) error { + cl, err := newA2AClient(name, config, r.logger) + if err != nil { + return err + } + r.mu.Lock() + defer r.mu.Unlock() + return r.registerClientLocked(cl) +} + +func (r *a2aRegistryImpl) RegisterClient(client interfaces.A2AClient) error { + if client == nil { + return ErrRegistryNilEntry + } + r.mu.Lock() + defer r.mu.Unlock() + return r.registerClientLocked(client) +} + +func (r *a2aRegistryImpl) registerClientLocked(client interfaces.A2AClient) error { + name := strings.TrimSpace(client.Name()) + if name == "" { + return ErrRegistryInvalidName + } + if _, exists := r.clients[name]; exists { + return ErrRegistryDuplicate + } + r.order = append(r.order, name) + r.clients[name] = client + return nil +} + +func (r *a2aRegistryImpl) Unregister(name string) error { + name = strings.TrimSpace(name) + if name == "" { + return ErrRegistryInvalidName + } + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.clients[name]; !ok { + 
return ErrRegistryNotFound + } + delete(r.clients, name) + r.order = removeFromOrder(r.order, name) + return nil +} + +func (r *a2aRegistryImpl) Get(name string) (interfaces.A2AClient, error) { + name = strings.TrimSpace(name) + if name == "" { + return nil, ErrRegistryInvalidName + } + r.mu.RLock() + defer r.mu.RUnlock() + c, ok := r.clients[name] + if !ok { + return nil, ErrRegistryNotFound + } + return c, nil +} + +func (r *a2aRegistryImpl) List() []interfaces.A2AClient { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]interfaces.A2AClient, 0, len(r.order)) + for _, name := range r.order { + if c, ok := r.clients[name]; ok { + out = append(out, c) + } + } + return out +} + +func newA2AClient(name string, cfg A2AConfig, log logger.Logger) (interfaces.A2AClient, error) { + name = strings.TrimSpace(name) + if name == "" { + return nil, ErrRegistryInvalidName + } + if strings.TrimSpace(cfg.URL) == "" { + return nil, fmt.Errorf("a2a %q: URL is required", name) + } + if log == nil { + log = NoopLogger() + } + a2aOpts := []a2aclient.Option{ + a2aclient.WithLogger(log), + a2aclient.WithTimeout(cfg.Timeout), + a2aclient.WithToken(cfg.Token), + a2aclient.WithHeaders(cfg.Headers), + a2aclient.WithSkillFilter(cfg.SkillFilter), + } + if cfg.SkipTLSVerify { + a2aOpts = append(a2aOpts, a2aclient.WithSkipTLSVerify(true)) + } + cl, err := a2aclient.NewClient(name, cfg.URL, a2aOpts...) 
+ if err != nil { + return nil, fmt.Errorf("a2a %q: new client: %w", name, err) + } + return cl, nil +} diff --git a/pkg/agent/a2a_registry_test.go b/pkg/agent/a2a_registry_test.go new file mode 100644 index 0000000..6abd8fb --- /dev/null +++ b/pkg/agent/a2a_registry_test.go @@ -0,0 +1,67 @@ +package agent + +import ( + "context" + "testing" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +type registryMockA2AClient struct { + name string +} + +func (c *registryMockA2AClient) Name() string { return c.name } +func (c *registryMockA2AClient) Close() error { return nil } +func (c *registryMockA2AClient) Ping(context.Context) error { + return nil +} +func (c *registryMockA2AClient) ResolveCard(context.Context) (interfaces.A2AAgentCard, error) { + return interfaces.A2AAgentCard{}, nil +} +func (c *registryMockA2AClient) ListSkills(context.Context) ([]interfaces.A2ASkillSpec, error) { + return nil, nil +} +func (c *registryMockA2AClient) SendMessage(context.Context, interfaces.A2ASendMessageRequest) (interfaces.A2ASendMessageResult, error) { + return interfaces.A2ASendMessageResult{}, nil +} + +func TestA2ARegistry_RegisterClient(t *testing.T) { + r := NewA2ARegistry(nil) + cl := &registryMockA2AClient{name: "agent1"} + if err := r.RegisterClient(cl); err != nil { + t.Fatal(err) + } + if _, err := r.Get("agent1"); err != nil { + t.Fatalf("Get(agent1) err = %v", err) + } +} + +func TestA2ARegistry_RegisterMissingURL(t *testing.T) { + r := NewA2ARegistry(nil) + if err := r.Register("remote", A2AConfig{}); err == nil { + t.Fatal("expected error for missing URL") + } +} + +func TestA2ARegistry_UnregisterNotFound(t *testing.T) { + r := NewA2ARegistry(nil) + if err := r.Unregister("missing"); err != ErrRegistryNotFound { + t.Errorf("err = %v, want ErrRegistryNotFound", err) + } +} + +func TestNormalizeA2ARegistry_fromWithA2AClients(t *testing.T) { + cl := &registryMockA2AClient{name: "agent1"} + c := &agentConfig{a2aClients: []interfaces.A2AClient{cl}} + if err := 
c.buildA2ARegistry(); err != nil { + t.Fatal(err) + } + if c.a2aRegistry == nil { + t.Fatal("expected a2aRegistry after buildA2ARegistry") + } + got, err := c.a2aRegistry.Get("agent1") + if err != nil || got != cl { + t.Fatalf("Get(agent1) = %v, %v", got, err) + } +} diff --git a/pkg/agent/a2a_server.go b/pkg/agent/a2a_server.go index ea20eba..b136189 100644 --- a/pkg/agent/a2a_server.go +++ b/pkg/agent/a2a_server.go @@ -357,7 +357,10 @@ func (a *Agent) buildSDKAgentCard() *a2a.AgentCard { // Examples from the remote agent card. All other tools are mapped generically using the // tool's Name, DisplayName, and Description. Tools with an empty Name are skipped. func (a *Agent) deriveSDKSkills() []a2a.AgentSkill { - tools := a.toolsList() + tools, err := a.resolveTools(context.Background()) + if err != nil { + return nil + } skills := make([]a2a.AgentSkill, 0, len(tools)) for _, t := range tools { if t == nil { diff --git a/pkg/agent/a2a_server_test.go b/pkg/agent/a2a_server_test.go index 67e79aa..17603a3 100644 --- a/pkg/agent/a2a_server_test.go +++ b/pkg/agent/a2a_server_test.go @@ -99,6 +99,7 @@ func TestBuildSDKAgentCard(t *testing.T) { a2aServerConfig: &A2AServerConfig{Hostname: "127.0.0.1", Port: 9000}, }, } + mustTestRegistries(t, &a.agentConfig) card := a.buildSDKAgentCard() if card.Name != "CardAgent" || card.Description != "desc" || card.Version != a2aServerVersion { t.Fatalf("card metadata: %+v", card) @@ -121,6 +122,7 @@ func TestBuildSDKAgentCard(t *testing.T) { a2aServerConfig: &A2AServerConfig{Hostname: "localhost", Port: 1}, }, } + mustTestRegistries(t, &a2.agentConfig) c2 := a2.buildSDKAgentCard() if !c2.Capabilities.Streaming { t.Fatal("Streaming should be true when stream enabled and LLM supports it") @@ -133,6 +135,7 @@ func TestBuildSDKAgentCard(t *testing.T) { a2aServerConfig: &A2AServerConfig{Hostname: "h", Port: 9, BearerTokens: []string{"secret"}}, }, } + mustTestRegistries(t, &a3.agentConfig) c3 := a3.buildSDKAgentCard() if 
len(c3.SecuritySchemes) == 0 || len(c3.SecurityRequirements) == 0 { t.Fatalf("expected security on card when BearerTokens set: schemes=%v reqs=%v", @@ -182,8 +185,6 @@ func TestDeriveSDKSkills(t *testing.T) { a := &Agent{ agentConfig: agentConfig{ tools: []interfaces.Tool{ - nil, - serverTestTool{name: "", display: "x", desc: "y"}, serverTestTool{name: "alpha", display: "Alpha", desc: "generic tool"}, NewA2ATool("remote", interfaces.ToolSpec{Name: "sk1", Description: "d"}, interfaces.A2ASkillSpec{ @@ -194,9 +195,12 @@ func TestDeriveSDKSkills(t *testing.T) { }, }, } + if err := a.buildRegistries(); err != nil { + t.Fatal(err) + } sk := a.deriveSDKSkills() if len(sk) != 2 { - t.Fatalf("want 2 skills (nil and empty name skipped), got %d: %+v", len(sk), sk) + t.Fatalf("want 2 skills, got %d: %+v", len(sk), sk) } if sk[0].ID != "alpha" || sk[0].Name != "Alpha" { t.Fatalf("generic skill: %+v", sk[0]) @@ -562,6 +566,7 @@ func TestRunA2A_ServesAgentCardAndSendMessage(t *testing.T) { }, runtime: mockRT, } + mustTestRegistries(t, &a.agentConfig) ctx, cancel := context.WithCancel(context.Background()) done := make(chan struct{}) @@ -683,6 +688,7 @@ func TestRunA2A_JSONRPCSendStreamingMessage_ReturnsSSE(t *testing.T) { }, runtime: mockRT, } + mustTestRegistries(t, &a.agentConfig) ctx, cancel := context.WithCancel(context.Background()) done := make(chan struct{}) @@ -787,6 +793,7 @@ func TestAgentCardProducer_CardJSON_InjectsURLField(t *testing.T) { a2aServerConfig: &A2AServerConfig{Hostname: "localhost", Port: 8080}, }, } + mustTestRegistries(t, &a.agentConfig) p := (*agentCardProducer)(a) b, err := p.CardJSON(context.Background()) if err != nil { @@ -812,6 +819,7 @@ func TestAgentCardProducer_Card_ReturnsTypedCard(t *testing.T) { a2aServerConfig: &A2AServerConfig{Hostname: "localhost", Port: 8080}, }, } + mustTestRegistries(t, &a.agentConfig) p := (*agentCardProducer)(a) card, err := p.Card(context.Background()) if err != nil { @@ -833,6 +841,7 @@ func 
TestAgentCardHandler_MethodNotAllowed(t *testing.T) { a2aServerConfig: &A2AServerConfig{Hostname: "localhost", Port: 8080}, }, } + mustTestRegistries(t, &a.agentConfig) srv := httptest.NewServer(a2asrv.NewAgentCardHandler((*agentCardProducer)(a))) defer srv.Close() diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index b69dacb..b1c5d15 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -71,38 +71,9 @@ func buildAgent(opts []Option) (*Agent, error) { a.localAgentWorker = &AgentWorker{agentConfig: *cfg, runtime: rt} } - // Sub-agents share the parent's in-memory pub/sub when the runtime implements [runtime.EventBusRuntime] - // (e.g. Temporal). A custom [runtime.Runtime] from a future WithRuntime need only implement [Runtime]; - // wiring is skipped when the assert to EventBusRuntime fails. - if ir, ok := a.runtime.(runtime.EventBusRuntime); ok { - bus := ir.GetEventBus() - for _, sub := range a.subAgents { - if sub != nil { - wireInMemoryEventChannelToSubAgents(bus, sub) - } - } - } - return a, nil } -func wireInMemoryEventChannelToSubAgents(bus eventbus.EventBus, agent *Agent) { - if agent == nil || bus == nil { - return - } - if ir, ok := agent.runtime.(runtime.EventBusRuntime); ok { - ir.SetEventBus(bus) - } - if agent.localAgentWorker != nil { - if ir, ok := agent.localAgentWorker.runtime.(runtime.EventBusRuntime); ok { - ir.SetEventBus(bus) - } - } - for _, child := range agent.subAgents { - wireInMemoryEventChannelToSubAgents(bus, child) - } -} - // NewAgent creates an Agent with the given options. // Background runtime workers (when used) start lazily when [Agent.Stream] runs or when approvals need them. func NewAgent(opts ...Option) (*Agent, error) { @@ -150,11 +121,19 @@ func (a *Agent) Close() { // When using [WithConversation], pass the conversation ID; agent and worker must use the same ID. 
func (a *Agent) Run(ctx context.Context, input string, opts *AgentRunOptions) (*AgentRunResult, error) { a.logger.Debug(ctx, "agent run started", slog.String("scope", "agent"), slog.String("name", a.Name), slog.Int("inputLen", len(input))) + return a.runInternal(ctx, input, opts, false) +} +func (a *Agent) runInternal(ctx context.Context, input string, opts *AgentRunOptions, runAsync bool) (*AgentRunResult, error) { conversationID := conversationIDFromOpts(opts) + spanName := "agent.run" + if runAsync { + spanName = "agent.run.async" + } + start := time.Now() - ctx, sp := a.tracer.StartSpan(ctx, "agent.run", + ctx, sp := a.tracer.StartSpan(ctx, spanName, interfaces.Attribute{Key: "agent.name", Value: a.Name}, interfaces.Attribute{Key: "conversation.id", Value: conversationID}, interfaces.Attribute{Key: "input.length", Value: len(input)}, @@ -169,7 +148,15 @@ func (a *Agent) Run(ctx context.Context, input string, opts *AgentRunOptions) (* return nil, err } - if a.hasApprovalTools() && a.approvalHandler == nil { + tools, err := a.resolveTools(ctx) + if err != nil { + sp.RecordError(err) + a.metrics.IncrementCounter(ctx, types.MetricRunFailed, interfaces.Attribute{Key: "error", Value: "tools_list_failed"}) + a.metrics.RecordHistogram(ctx, types.MetricRunDurationMs, float64(time.Since(start).Milliseconds())) + return nil, err + } + + if a.hasApprovalTools(tools) && a.approvalHandler == nil { err := fmt.Errorf("tools require approval but WithApprovalHandler was not set (required for Run)") sp.RecordError(err) a.metrics.IncrementCounter(ctx, types.MetricRunFailed, interfaces.Attribute{Key: "error", Value: "missing_approval_handler"}) @@ -177,7 +164,16 @@ func (a *Agent) Run(ctx context.Context, input string, opts *AgentRunOptions) (* return nil, err } - req := a.executeRequest(input, opts, false) + subAgents, err := a.resolveSubAgentSpecs(ctx) + if err != nil { + sp.RecordError(err) + a.metrics.IncrementCounter(ctx, types.MetricRunFailed, interfaces.Attribute{Key: 
"error", Value: "build_sub_agent_specs_failed"}) + a.metrics.RecordHistogram(ctx, types.MetricRunDurationMs, float64(time.Since(start).Milliseconds())) + return nil, err + } + a.shareEventBusWithSubAgents() + + req := a.executeRequest(input, opts, false, tools, subAgents) result, err := a.runtime.Execute(ctx, req) if err != nil { @@ -191,58 +187,23 @@ func (a *Agent) Run(ctx context.Context, input string, opts *AgentRunOptions) (* return result, nil } -// RunAsync starts the run in a goroutine and returns two channels: -// - resultCh: receives exactly one RunAsyncResult, then closes. -// - approvalCh: receives each pending tool approval; call req.Respond. Channel closes when the run ends. -// -// For each approval, call req.Respond(Approved|Rejected) exactly once. -// -// WithApprovalHandler is temporarily replaced for the duration of the run; restore happens when the run finishes. -// If tools do not require approval, approvalCh is still closed immediately with no values. -func (a *Agent) RunAsync(ctx context.Context, input string, opts *AgentRunOptions) (resultCh <-chan AgentRunAsyncResult, approvalCh <-chan *ApprovalRequest, err error) { +// RunAsync starts the run in a goroutine and returns a channel that receives exactly one +// [AgentRunAsyncResult], then closes. Use [WithApprovalHandler] when tools require approval +// (same as [Agent.Run]). 
+func (a *Agent) RunAsync(ctx context.Context, input string, opts *AgentRunOptions) (<-chan AgentRunAsyncResult, error) { a.logger.Debug(ctx, "agent run async started", slog.String("scope", "agent"), slog.String("name", a.Name), slog.Int("inputLen", len(input))) - conversationID := conversationIDFromOpts(opts) - - if err := a.validateConversationID(conversationID); err != nil { - return nil, nil, err - } - resCh := make(chan AgentRunAsyncResult, 1) - apprCh := make(chan *ApprovalRequest, 16) - go func() { - defer close(apprCh) defer close(resCh) - - var saved ApprovalHandler - if a.hasApprovalTools() { - saved = a.approvalHandler - a.approvalHandler = func(handlerCtx context.Context, req *ApprovalRequest) { - out := &ApprovalRequest{ - Name: req.Name, - Value: req.Value, - Respond: req.Respond, - } - select { - case apprCh <- out: - default: - // Avoid blocking Run's event loop if consumer is slow. - go func(p *ApprovalRequest) { apprCh <- p }(out) - } - } - defer func() { a.approvalHandler = saved }() - } - - resp, runErr := a.Run(ctx, input, opts) - if runErr != nil { - resCh <- AgentRunAsyncResult{Error: runErr} + resp, err := a.runInternal(ctx, input, opts, true) + if err != nil { + resCh <- AgentRunAsyncResult{Error: err} return } resCh <- AgentRunAsyncResult{Result: resp} }() - - return resCh, apprCh, nil + return resCh, nil } func copyApprovalArgs(src map[string]any) map[string]any { @@ -284,7 +245,23 @@ func (a *Agent) Stream(ctx context.Context, input string, opts *AgentRunOptions) return nil, err } - req := a.executeRequest(input, opts, true) + tools, err := a.resolveTools(ctx) + if err != nil { + sp.RecordError(err) + a.metrics.IncrementCounter(ctx, types.MetricStreamFailed, interfaces.Attribute{Key: "error", Value: "tools_list_failed"}) + a.metrics.RecordHistogram(ctx, types.MetricStreamDurationMs, float64(time.Since(start).Milliseconds())) + return nil, err + } + subAgents, err := a.resolveSubAgentSpecs(ctx) + if err != nil { + sp.RecordError(err) + 
a.metrics.IncrementCounter(ctx, types.MetricStreamFailed, interfaces.Attribute{Key: "error", Value: "build_sub_agent_specs_failed"}) + a.metrics.RecordHistogram(ctx, types.MetricStreamDurationMs, float64(time.Since(start).Milliseconds())) + return nil, err + } + a.shareEventBusWithSubAgents() + + req := a.executeRequest(input, opts, true, tools, subAgents) streamCh, err := a.runtime.ExecuteStream(ctx, req) if err != nil { @@ -315,39 +292,71 @@ func (a *Agent) validateConversationID(conversationID string) error { return nil } -// executeRequest builds [runtime.ExecuteRequest] with per-run fields plus AgentSpec and AgentExecution for custom Runtime implementations. -func (a *Agent) executeRequest(userPrompt string, opts *AgentRunOptions, streaming bool) *runtime.ExecuteRequest { +// executeRequest builds [runtime.ExecuteRequest] with per-run fields for Run, Stream, and RunAsync. +func (a *Agent) executeRequest(userPrompt string, opts *AgentRunOptions, streaming bool, tools []interfaces.Tool, subAgents []*runtime.SubAgentSpec) *runtime.ExecuteRequest { return &runtime.ExecuteRequest{ UserPrompt: userPrompt, RunOptions: opts, StreamingEnabled: streaming, - SubAgents: a.buildSubAgentSpecs(), + SubAgents: subAgents, MaxSubAgentDepth: a.maxSubAgentDepth, ApprovalHandler: a.approvalHandler, - AgentSpec: a.agentSpec(), - AgentExecution: a.agentExecution(), + Tools: tools, } } -func (a *Agent) agentSpec() *runtime.AgentSpec { - s := a.runtimeAgentSpec() - return &s +// Sub-agents share the parent's in-memory pub/sub when the runtime implements [runtime.EventBusRuntime] +// (e.g. Temporal). A custom [runtime.Runtime] from a future WithRuntime need only implement [Runtime]; +// wiring is skipped when the assert to EventBusRuntime fails. 
+func (a *Agent) shareEventBusWithSubAgents() { + if a == nil { + return + } + ir, ok := a.runtime.(runtime.EventBusRuntime) + if !ok || a.subAgentRegistry == nil { + return + } + bus := ir.GetEventBus() + for _, sub := range a.subAgentRegistry.List() { + if sub != nil { + shareEventBusWithSubAgent(bus, sub) + } + } } -func (a *Agent) agentExecution() *runtime.AgentExecution { - e := a.runtimeAgentExecution() - return &e +func shareEventBusWithSubAgent(bus eventbus.EventBus, agent *Agent) { + if agent == nil || bus == nil { + return + } + if ir, ok := agent.runtime.(runtime.EventBusRuntime); ok { + ir.SetEventBus(bus) + } + if agent.localAgentWorker != nil { + if ir, ok := agent.localAgentWorker.runtime.(runtime.EventBusRuntime); ok { + ir.SetEventBus(bus) + } + } + if agent.subAgentRegistry == nil { + return + } + for _, child := range agent.subAgentRegistry.List() { + shareEventBusWithSubAgent(bus, child) + } } -// buildSubAgentSpecs builds the runtime-agnostic sub-agent spec tree for this agent. +// resolveSubAgentSpecs builds the runtime-agnostic sub-agent spec tree for this agent. // Each runtime receives this tree via ExecuteRequest.SubAgents and constructs its own // internal routing structures (local: *LocalRuntime refs; temporal: task queue + fingerprint). 
-func (a *Agent) buildSubAgentSpecs() []*runtime.SubAgentSpec { - if a == nil || len(a.subAgents) == 0 { - return nil +func (a *Agent) resolveSubAgentSpecs(ctx context.Context) ([]*runtime.SubAgentSpec, error) { + if a == nil || a.subAgentRegistry == nil { + return nil, nil + } + subs := a.subAgentRegistry.List() + if len(subs) == 0 { + return nil, nil } - out := make([]*runtime.SubAgentSpec, 0, len(a.subAgents)) - for _, sub := range a.subAgents { + out := make([]*runtime.SubAgentSpec, 0, len(subs)) + for _, sub := range subs { if sub == nil { continue } @@ -355,15 +364,24 @@ func (a *Agent) buildSubAgentSpecs() []*runtime.SubAgentSpec { if err != nil || toolName == "" { continue } + tools, err := sub.resolveTools(ctx) + if err != nil { + return nil, err + } + children, err := sub.resolveSubAgentSpecs(ctx) + if err != nil { + return nil, err + } out = append(out, &runtime.SubAgentSpec{ Name: sub.Name, ToolName: toolName, Runtime: sub.runtime, - Children: sub.buildSubAgentSpecs(), + Children: children, + Tools: tools, }) } if len(out) == 0 { - return nil + return nil, nil } if a.logger != nil { names := make([]string, 0, len(out)) @@ -376,5 +394,37 @@ func (a *Agent) buildSubAgentSpecs() []*runtime.SubAgentSpec { slog.Any("subAgentToolNames", names), slog.Int("specCount", len(out))) } - return out + return out, nil +} + +// ToolRegistry returns the agent's tool registry. +func (a *Agent) ToolRegistry() ToolRegistry { + if a == nil { + return nil + } + return a.toolRegistry +} + +// MCPRegistry returns the agent's MCP client registry. +func (a *Agent) MCPRegistry() MCPRegistry { + if a == nil { + return nil + } + return a.mcpRegistry +} + +// A2ARegistry returns the agent's A2A client registry. +func (a *Agent) A2ARegistry() A2ARegistry { + if a == nil { + return nil + } + return a.a2aRegistry +} + +// SubAgentRegistry returns the agent's sub-agent registry. 
+func (a *Agent) SubAgentRegistry() SubAgentRegistry { + if a == nil { + return nil + } + return a.subAgentRegistry } diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 0154cd5..6593a28 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -20,15 +20,26 @@ import ( ) func testAgentWithRuntime(rt runtime.Runtime) *Agent { + cfg := agentConfig{ + Name: "TestAgent", + logger: logger.DefaultLogger("error"), + maxSubAgentDepth: 2, + tracer: observability.DefaultNoopTracer, + metrics: observability.DefaultNoopMetrics, + } + if err := cfg.buildRegistries(); err != nil { + panic(err) + } return &Agent{ - agentConfig: agentConfig{ - Name: "TestAgent", - logger: logger.DefaultLogger("error"), - maxSubAgentDepth: 2, - tracer: observability.DefaultNoopTracer, - metrics: observability.DefaultNoopMetrics, - }, - runtime: rt, + agentConfig: cfg, + runtime: rt, + } +} + +func mustTestRegistries(t *testing.T, cfg *agentConfig) { + t.Helper() + if err := cfg.buildRegistries(); err != nil { + t.Fatal(err) } } @@ -43,10 +54,7 @@ func TestAgent_Run_ForwardsRequestAndReturnsResponse(t *testing.T) { if req.UserPrompt != "hello" { t.Errorf("UserPrompt = %q", req.UserPrompt) } - name := "" - if req.AgentSpec != nil { - name = req.AgentSpec.Name - } + name := "TestAgent" return &types.AgentRunResult{Content: "reply", AgentName: name, Model: "m1"}, nil }) @@ -68,11 +76,7 @@ func TestAgent_Stream_SetsStreamingEnabled(t *testing.T) { mockRT.EXPECT().ExecuteStream(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, req *runtime.ExecuteRequest) (<-chan events.AgentEvent, error) { streamReq = req ch := make(chan events.AgentEvent, 2) - evName := "" - if req.AgentSpec != nil { - evName = req.AgentSpec.Name - } - ch <- events.NewAgentRunFinishedEvent("", "", &types.AgentRunResult{AgentName: evName, Content: "done"}) + ch <- events.NewAgentRunFinishedEvent("", "", &types.AgentRunResult{AgentName: "TestAgent", Content: "done"}) close(ch) var recv <-chan 
events.AgentEvent = ch return recv, nil @@ -123,7 +127,7 @@ func TestAgent_RunAsync_DeliversResult(t *testing.T) { mockRT.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(&types.AgentRunResult{Content: "mock", AgentName: "TestAgent", Model: "stub"}, nil) a := testAgentWithRuntime(mockRT) - resCh, apprCh, err := a.RunAsync(context.Background(), "async", nil) + resCh, err := a.RunAsync(context.Background(), "async", nil) if err != nil { t.Fatal(err) } @@ -138,9 +142,6 @@ func TestAgent_RunAsync_DeliversResult(t *testing.T) { case <-time.After(3 * time.Second): t.Fatal("timeout waiting for RunAsync result") } - for range apprCh { - t.Fatal("unexpected approval request") - } } func TestAgent_Stream_CustomStreamFn(t *testing.T) { @@ -296,9 +297,16 @@ func (s *stubRuntime) Close() func TestBuildSubAgentSpecs_flat(t *testing.T) { childRT := &stubRuntime{} child := &Agent{agentConfig: agentConfig{Name: "Child"}, runtime: childRT} - parent := &Agent{agentConfig: agentConfig{Name: "Parent", subAgents: []*Agent{child}}, runtime: &stubRuntime{}} + mustTestRegistries(t, &child.agentConfig) + parentReg := NewSubAgentRegistry() + _ = parentReg.Register(child) + parent := &Agent{agentConfig: agentConfig{Name: "Parent", subAgentRegistry: parentReg}, runtime: &stubRuntime{}} + mustTestRegistries(t, &parent.agentConfig) - got := parent.buildSubAgentSpecs() + got, err := parent.resolveSubAgentSpecs(context.Background()) + if err != nil { + t.Fatal(err) + } if len(got) != 1 { t.Fatalf("want 1 spec, got %d", len(got)) } @@ -324,11 +332,21 @@ func TestBuildSubAgentSpecs_flat(t *testing.T) { func TestBuildSubAgentSpecs_nested(t *testing.T) { leafRT := &stubRuntime{} leaf := &Agent{agentConfig: agentConfig{Name: "Leaf"}, runtime: leafRT} + mustTestRegistries(t, &leaf.agentConfig) midRT := &stubRuntime{} - mid := &Agent{agentConfig: agentConfig{Name: "Mid", subAgents: []*Agent{leaf}}, runtime: midRT} - root := &Agent{agentConfig: agentConfig{Name: "Root", subAgents: []*Agent{mid}}, 
runtime: &stubRuntime{}} - - got := root.buildSubAgentSpecs() + midReg := NewSubAgentRegistry() + _ = midReg.Register(leaf) + mid := &Agent{agentConfig: agentConfig{Name: "Mid", subAgentRegistry: midReg}, runtime: midRT} + mustTestRegistries(t, &mid.agentConfig) + rootReg := NewSubAgentRegistry() + _ = rootReg.Register(mid) + root := &Agent{agentConfig: agentConfig{Name: "Root", subAgentRegistry: rootReg}, runtime: &stubRuntime{}} + mustTestRegistries(t, &root.agentConfig) + + got, err := root.resolveSubAgentSpecs(context.Background()) + if err != nil { + t.Fatal(err) + } if len(got) != 1 { t.Fatalf("want 1 top-level spec, got %d", len(got)) } @@ -351,9 +369,16 @@ func TestBuildSubAgentSpecs_nested(t *testing.T) { func TestBuildSubAgentSpecs_noRuntimeStillBuilds(t *testing.T) { // Sub-agent with no runtime still gets a spec — runtime decides what to do with it. sub := &Agent{agentConfig: agentConfig{Name: "X"}} - parent := &Agent{agentConfig: agentConfig{subAgents: []*Agent{sub}}} + mustTestRegistries(t, &sub.agentConfig) + parentReg := NewSubAgentRegistry() + _ = parentReg.Register(sub) + parent := &Agent{agentConfig: agentConfig{subAgentRegistry: parentReg}} + mustTestRegistries(t, &parent.agentConfig) - got := parent.buildSubAgentSpecs() + got, err := parent.resolveSubAgentSpecs(context.Background()) + if err != nil { + t.Fatal(err) + } if len(got) != 1 { t.Fatalf("want 1 spec, got %v", got) } @@ -395,6 +420,9 @@ func TestAgent_Run_RequiresApprovalHandlerWhenToolsNeedApproval(t *testing.T) { }, runtime: mockRT, } + if err := a.buildToolRegistry(); err != nil { + t.Fatal(err) + } _, err := a.Run(context.Background(), "hi", nil) if err == nil || !strings.Contains(err.Error(), "WithApprovalHandler") { t.Fatalf("got %v", err) @@ -428,9 +456,176 @@ func TestWireInMemoryEventChannelToSubAgents(t *testing.T) { agentConfig: agentConfig{Name: "Child", taskQueue: "q-c"}, runtime: childRT, } + parentReg := NewSubAgentRegistry() + _ = parentReg.Register(child) parent := 
&Agent{ - agentConfig: agentConfig{Name: "Parent", taskQueue: "q-p", subAgents: []*Agent{child}}, + agentConfig: agentConfig{Name: "Parent", taskQueue: "q-p", subAgentRegistry: parentReg}, runtime: parentRT, } - wireInMemoryEventChannelToSubAgents(bus, parent) + shareEventBusWithSubAgent(bus, parent) +} + +func TestToolsList_picksUpRegistryChange(t *testing.T) { + child := &Agent{agentConfig: agentConfig{Name: "Child"}} + mustTestRegistries(t, &child.agentConfig) + parentReg := NewSubAgentRegistry() + parent := &Agent{ + agentConfig: agentConfig{ + Name: "Parent", + toolRegistry: NewToolRegistry(), + subAgentRegistry: parentReg, + logger: NoopLogger(), + }, + runtime: &stubRuntime{}, + } + if err := parent.buildMCPRegistry(); err != nil { + t.Fatal(err) + } + if err := parent.buildA2ARegistry(); err != nil { + t.Fatal(err) + } + _ = parent.toolRegistry.Register(mockTool{name: "echo"}) + + tools1, err := parent.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(tools1) != 1 { + t.Fatalf("tools1 = %d, want 1", len(tools1)) + } + + _ = parent.subAgentRegistry.Register(child) + tools2, err := parent.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(tools2) != 2 { + t.Fatalf("tools2 = %d, want 2 after sub-agent register", len(tools2)) + } + subAgents, err := parent.resolveSubAgentSpecs(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(subAgents) != 1 { + t.Fatalf("subAgents = %d, want 1", len(subAgents)) + } + if agentConfigFingerprintTools(&parent.agentConfig, tools1) == agentConfigFingerprintTools(&parent.agentConfig, tools2) { + t.Fatal("run fingerprint should change when tools change") + } +} + +func TestToolsList_stableOrder(t *testing.T) { + c := &agentConfig{ + toolRegistry: NewToolRegistry(), + } + _ = c.toolRegistry.Register(mockTool{name: "b"}) + _ = c.toolRegistry.Register(mockTool{name: "a"}) + tools1, err := c.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + 
} + fp1 := agentConfigFingerprintTools(c, tools1) + tools2, err := c.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + fp2 := agentConfigFingerprintTools(c, tools2) + if fp1 != fp2 { + t.Fatalf("fingerprints differ: %q vs %q", fp1, fp2) + } +} + +func TestAgent_RegistryAccessors(t *testing.T) { + if (*Agent)(nil).ToolRegistry() != nil { + t.Fatal("nil agent ToolRegistry should be nil") + } + if (*Agent)(nil).MCPRegistry() != nil { + t.Fatal("nil agent MCPRegistry should be nil") + } + if (*Agent)(nil).A2ARegistry() != nil { + t.Fatal("nil agent A2ARegistry should be nil") + } + if (*Agent)(nil).SubAgentRegistry() != nil { + t.Fatal("nil agent SubAgentRegistry should be nil") + } + + toolReg := NewToolRegistry() + mcpReg := NewMCPRegistry(nil) + a2aReg := NewA2ARegistry(nil) + subReg := NewSubAgentRegistry() + child := &Agent{agentConfig: agentConfig{Name: "Child", taskQueue: "q-child"}} + mustTestRegistries(t, &child.agentConfig) + if err := subReg.Register(child); err != nil { + t.Fatal(err) + } + + a := &Agent{ + agentConfig: agentConfig{ + Name: "Parent", + toolRegistry: toolReg, + mcpRegistry: mcpReg, + a2aRegistry: a2aReg, + subAgentRegistry: subReg, + }, + } + if a.ToolRegistry() != toolReg { + t.Fatal("ToolRegistry accessor should return configured registry") + } + if a.MCPRegistry() != mcpReg { + t.Fatal("MCPRegistry accessor should return configured registry") + } + if a.A2ARegistry() != a2aReg { + t.Fatal("A2ARegistry accessor should return configured registry") + } + if a.SubAgentRegistry() != subReg { + t.Fatal("SubAgentRegistry accessor should return configured registry") + } +} + +func TestAgent_Run_resolvesToolsPerRun(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockRT := rtmocks.NewMockRuntime(ctrl) + + reg := NewToolRegistry() + if err := reg.Register(mockTool{name: "first"}); err != nil { + t.Fatal(err) + } + + cfg := agentConfig{ + Name: "TestAgent", + toolRegistry: reg, + logger: 
logger.DefaultLogger("error"), + maxSubAgentDepth: 2, + tracer: observability.DefaultNoopTracer, + metrics: observability.DefaultNoopMetrics, + toolApprovalPolicy: AutoToolApprovalPolicy(), + } + mustTestRegistries(t, &cfg) + a := &Agent{agentConfig: cfg, runtime: mockRT} + + var toolCounts []int + gomock.InOrder( + mockRT.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, req *runtime.ExecuteRequest) (*types.AgentRunResult, error) { + toolCounts = append(toolCounts, len(req.Tools)) + return &types.AgentRunResult{Content: "ok"}, nil + }), + mockRT.EXPECT().Execute(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, req *runtime.ExecuteRequest) (*types.AgentRunResult, error) { + toolCounts = append(toolCounts, len(req.Tools)) + return &types.AgentRunResult{Content: "ok"}, nil + }), + ) + + if _, err := a.Run(context.Background(), "one", nil); err != nil { + t.Fatal(err) + } + if err := a.ToolRegistry().Register(mockTool{name: "second"}); err != nil { + t.Fatal(err) + } + if _, err := a.Run(context.Background(), "two", nil); err != nil { + t.Fatal(err) + } + if len(toolCounts) != 2 || toolCounts[0] != 1 || toolCounts[1] != 2 { + t.Fatalf("tool counts per run = %v, want [1 2]", toolCounts) + } } diff --git a/pkg/agent/config.go b/pkg/agent/config.go index 9cf91f3..f25dd7a 100644 --- a/pkg/agent/config.go +++ b/pkg/agent/config.go @@ -16,10 +16,8 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/runtime" "github.com/agenticenv/agent-sdk-go/internal/runtime/temporal" "github.com/agenticenv/agent-sdk-go/internal/types" - a2aclient "github.com/agenticenv/agent-sdk-go/pkg/a2a/client" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/logger" - mcpclient "github.com/agenticenv/agent-sdk-go/pkg/mcp/client" "github.com/agenticenv/agent-sdk-go/pkg/observability" "github.com/google/uuid" "go.temporal.io/sdk/client" @@ -185,6 +183,7 @@ type ObservabilityConfig struct { // - AgentWorker 
only: (none; worker inherits options passed to NewAgentWorker) // - Both: WithName, WithDescription, WithSystemPrompt, WithTemporalConfig, WithTemporalClient, // WithInstanceId, WithLLMClient, WithToolApprovalPolicy, WithTools, WithToolRegistry, +// WithMCPRegistry, WithA2ARegistry, WithSubAgentRegistry, // WithMaxIterations, WithStream, WithLogger, WithLogLevel, WithConversation, WithConversationSize, EnableConversationSaveOnIteration, // WithResponseFormat, WithLLMSampling, WithSubAgents, WithMaxSubAgentDepth, // WithMCPConfig, WithMCPClients, WithA2AConfig, WithA2AClients, WithRetrievers, WithRetrieverMode, WithAgentMode, WithDisableFingerprintCheck, WithAgentToolExecutionMode, @@ -202,8 +201,11 @@ type agentConfig struct { instanceId string taskQueue string LLMClient interfaces.LLMClient - tools []interfaces.Tool - toolRegistry interfaces.ToolRegistry + tools []interfaces.Tool // staging for [WithTools]; consumed when the agent is created + toolRegistry ToolRegistry + mcpRegistry MCPRegistry + a2aRegistry A2ARegistry + subAgentRegistry SubAgentRegistry toolApprovalPolicy interfaces.AgentToolApprovalPolicy maxIterations int streamEnabled bool @@ -230,26 +232,21 @@ type agentConfig struct { // break-glass: disable caller-vs-worker fingerprint guard at activity entry. disableFingerprintCheck bool - // Sub-agents: direct children exposed to the LLM; subAgentTools is filled by buildSubAgentTools (graph + name checks), merged in toolsList with base and MCP tools. maxSubAgentDepth caps nesting from this agent (direct children = 1; default 2 when unset or <= 0). - subAgents []*Agent - subAgentTools []interfaces.Tool + // Sub-agents: direct children in [subAgentRegistry]. + subAgents []*Agent // staging for [WithSubAgents]; consumed when the agent is created maxSubAgentDepth int - // MCP: optional server configs and/or explicit clients; merged at build into mcpTools (see buildMCPTools). 
+ // MCP: [WithMCPConfig] / [WithMCPClients] populate [mcpRegistry]; tools resolved on each run. mcpServers MCPServers mcpClients []interfaces.MCPClient - mcpTools []interfaces.Tool - // A2A: optional server configs and/or explicit clients; merged at build into a2aTools (see buildA2ATools). + // A2A: [WithA2AConfig] / [WithA2AClients] populate [a2aRegistry]; skills resolved on each run. a2aServers A2AServers a2aClients []interfaces.A2AClient - a2aTools []interfaces.Tool // Retrievers: optional vector/document backends (e.g. Weaviate) for RAG; validated at build. - // retrieverTools is filled by buildRetrieverTools for agentic/hybrid modes (see [RetrieverTool]). - retrievers []interfaces.Retriever - retrieverMode RetrieverMode - retrieverTools []interfaces.Tool + retrievers []interfaces.Retriever + retrieverMode RetrieverMode //A2A Server: optional server config; merged at build into a2aServer (see RunA2A). a2aServerConfig *A2AServerConfig @@ -349,16 +346,36 @@ func WithToolApprovalPolicy(policy interfaces.AgentToolApprovalPolicy) Option { return func(c *agentConfig) { c.toolApprovalPolicy = policy } } -// WithTools registers tools with the agent. Applies to Agent and AgentWorker. +// WithTools sets tools at agent creation. See [WithToolRegistry] to change tools later. +// Applies to Agent and AgentWorker. func WithTools(tools ...interfaces.Tool) Option { return func(c *agentConfig) { c.tools = tools } } -// WithToolRegistry sets a tool registry. Applies to Agent and AgentWorker. -func WithToolRegistry(reg interfaces.ToolRegistry) Option { +// WithToolRegistry sets the tool registry. Use Register and Unregister before Run, Stream, or RunAsync. +// Applies to Agent and AgentWorker. +func WithToolRegistry(reg ToolRegistry) Option { return func(c *agentConfig) { c.toolRegistry = reg } } +// WithMCPRegistry sets the MCP client registry. Use Register, RegisterClient, and Unregister before Run, Stream, or RunAsync. +// Applies to Agent and AgentWorker. 
+func WithMCPRegistry(reg MCPRegistry) Option { + return func(c *agentConfig) { c.mcpRegistry = reg } +} + +// WithA2ARegistry sets the A2A client registry. Use Register, RegisterClient, and Unregister before Run, Stream, or RunAsync. +// Applies to Agent and AgentWorker. +func WithA2ARegistry(reg A2ARegistry) Option { + return func(c *agentConfig) { c.a2aRegistry = reg } +} + +// WithSubAgentRegistry sets the sub-agent registry. Use Register and Unregister before Run, Stream, or RunAsync. +// Applies to Agent and AgentWorker. +func WithSubAgentRegistry(reg SubAgentRegistry) Option { + return func(c *agentConfig) { c.subAgentRegistry = reg } +} + // WithMaxIterations sets the max number of LLM rounds. Applies to Agent and AgentWorker. func WithMaxIterations(n int) Option { return func(c *agentConfig) { c.maxIterations = n } @@ -386,7 +403,7 @@ func WithLogLevel(level string) Option { return func(c *agentConfig) { c.logLevel = level } } -// WithApprovalHandler sets the approval callback for Run. Required when tools need approval. +// WithApprovalHandler sets the approval callback for Run and RunAsync. Required when tools need approval. // The callback receives req with req.Respond set; call req.Respond(Approved|Rejected). Agent only; Stream uses OnApproval on events. func WithApprovalHandler(fn types.ApprovalHandler) Option { return func(c *agentConfig) { c.approvalHandler = fn } @@ -468,9 +485,8 @@ func WithLLMSampling(s *LLMSampling) Option { return func(c *agentConfig) { c.llmSampling = s } } -// WithSubAgents registers sub-agents. Each is exposed to the parent LLM as a tool (AgentTool). -// Delegation runs through the execution runtime (child run), not Tool.Execute. -// The sub-agent graph is validated at agent build: no cycles, depth <= WithMaxSubAgentDepth (default 2). +// WithSubAgents sets sub-agents at agent creation. See [WithSubAgentRegistry] to change them later. +// Applies to Agent and AgentWorker. 
func WithSubAgents(subAgents ...*Agent) Option { return func(c *agentConfig) { c.subAgents = subAgents } } @@ -481,10 +497,8 @@ func WithMaxSubAgentDepth(depth int) Option { return func(c *agentConfig) { c.maxSubAgentDepth = depth } } -// WithMCPConfig registers MCP servers by stable key (used in tool names and default client naming). -// Each [MCPConfig] must set [MCPConfig.Transport] using transport types from [github.com/agenticenv/agent-sdk-go/pkg/mcp]; -// the agent wires a default MCP client internally (no modelcontextprotocol/go-sdk usage in application code). -// Tools are discovered and merged with [WithMCPClients]. Applies to Agent and AgentWorker. +// WithMCPConfig registers MCP servers by key. See [WithMCPRegistry] to change clients later. +// Applies to Agent and AgentWorker. func WithMCPConfig(servers MCPServers) Option { return func(c *agentConfig) { c.mcpServers = servers } } @@ -503,10 +517,8 @@ func WithMCPClients(clients ...interfaces.MCPClient) Option { } } -// WithA2AConfig registers A2A agent servers by stable key (used in tool names and default client naming). -// Each [A2AConfig] must set [A2AConfig.URL]; the agent wires a default A2A client internally -// (no a2aproject/a2a-go/v2 usage in application code). -// Skills are discovered and merged with [WithA2AClients]. Applies to Agent and AgentWorker. +// WithA2AConfig registers remote A2A agents by key. See [WithA2ARegistry] to change clients later. +// Applies to Agent and AgentWorker. 
func WithA2AConfig(servers A2AServers) Option { return func(c *agentConfig) { c.a2aServers = servers } } @@ -716,10 +728,7 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { if c.maxSubAgentDepth <= 0 { c.maxSubAgentDepth = defaultMaxSubAgentDepth } - if err := c.buildMCPTools(); err != nil { - return nil, err - } - if err := c.buildA2ATools(); err != nil { + if err := c.buildRegistries(); err != nil { return nil, err } if err := validateRetrievers(c.retrievers); err != nil { @@ -730,15 +739,11 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { return nil, err } c.retrieverMode = mode - if err := c.buildRetrieverTools(); err != nil { - return nil, err - } - if err := c.buildSubAgentTools(); err != nil { - return nil, err - } - if err := c.validateToolNames(); err != nil { + // Fail fast at NewAgent: merge registries, discover MCP/A2A tools, validate names (same path as each run). + if _, err := c.resolveTools(context.Background()); err != nil { return nil, err } + if c.timeout == 0 { switch c.agentMode { case AgentModeAutonomous: @@ -748,18 +753,14 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { } } - // Validate approvalTimeout when any tool requires approval (approvalTimeout must be < timeout) - if c.hasApprovalTools() { - c.logger.Debug(context.Background(), "tools require approval", slog.String("scope", "agent"), slog.String("name", c.Name)) - if c.approvalTimeout == 0 { - c.approvalTimeout = c.timeout - 30*time.Second - } - if c.approvalTimeout >= c.timeout { - return nil, fmt.Errorf("approvalTimeout (%v) must be less than agent timeout (%v)", c.approvalTimeout, c.timeout) - } - if c.approvalTimeout > types.MaxApprovalTimeout { - return nil, fmt.Errorf("approvalTimeout (%v) exceeds max (%v)", c.approvalTimeout, types.MaxApprovalTimeout) - } + if c.approvalTimeout == 0 { + c.approvalTimeout = c.timeout - 30*time.Second + } + if c.approvalTimeout >= c.timeout { + return nil, fmt.Errorf("approvalTimeout (%v) must be less 
than agent timeout (%v)", c.approvalTimeout, c.timeout) + } + if c.approvalTimeout > types.MaxApprovalTimeout { + return nil, fmt.Errorf("approvalTimeout (%v) exceeds max (%v)", c.approvalTimeout, types.MaxApprovalTimeout) } if c.maxIterations <= 0 { @@ -849,12 +850,11 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { slog.Duration("timeout", c.timeout), slog.Duration("approvalTimeout", c.approvalTimeout), slog.String("logLevel", c.logLevel), - slog.Int("toolCount", len(c.toolsList())), - slog.Int("subAgentToolCount", len(c.subAgentTools)), - slog.Int("mcpToolCount", len(c.mcpTools)), - slog.Int("a2aToolCount", len(c.a2aTools)), + slog.Int("toolRegistryCount", len(c.toolRegistry.List())), + slog.Int("mcpRegistryCount", len(c.mcpRegistry.List())), + slog.Int("a2aRegistryCount", len(c.a2aRegistry.List())), + slog.Int("subAgentRegistryCount", len(c.subAgentRegistry.List())), slog.Int("retrieverCount", len(c.retrievers)), - slog.Int("retrieverToolCount", len(c.retrieverTools)), slog.String("retrieverMode", string(c.retrieverMode)), slog.Bool("hasConversation", c.conversation != nil), slog.Bool("hasObservability", c.observabilityConfig != nil), @@ -890,67 +890,133 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { return c, nil } -// buildAgentRuntime constructs the execution backend from agentConfig. -// Defaults to the local in-process runtime when no Temporal backend is configured. -// Extend with additional branches when new [runtime.Runtime] implementations are added. -func (cfg *agentConfig) buildAgentRuntime(remoteWorker bool) (runtime.Runtime, error) { - if cfg.hasTemporalRuntime() { - return cfg.buildTemporalRuntime(remoteWorker) +// buildRegistries wires registries from agent options during [buildAgentConfig]. 
+func (c *agentConfig) buildRegistries() error { + if err := c.buildToolRegistry(); err != nil { + return err } - return cfg.buildLocalRuntime() + if err := c.buildMCPRegistry(); err != nil { + return err + } + if err := c.buildA2ARegistry(); err != nil { + return err + } + if err := c.buildSubAgentRegistry(); err != nil { + return err + } + return nil } -// toolsList returns WithTools or registry tools, merged MCP tools ([mcpTools]), A2A tools ([a2aTools]), -// retriever tools ([retrieverTools]), then [subAgentTools] from [buildSubAgentTools]. -func (c *agentConfig) toolsList() []interfaces.Tool { - var base []interfaces.Tool - if c.toolRegistry != nil { - base = c.toolRegistry.Tools() - } else { - base = c.tools - } - if len(c.mcpTools) > 0 { - merged := make([]interfaces.Tool, len(base)+len(c.mcpTools)) - copy(merged, base) - copy(merged[len(base):], c.mcpTools) - base = merged - } - if len(c.a2aTools) > 0 { - merged := make([]interfaces.Tool, len(base)+len(c.a2aTools)) - copy(merged, base) - copy(merged[len(base):], c.a2aTools) - base = merged - } - if len(c.retrieverTools) > 0 { - merged := make([]interfaces.Tool, len(base)+len(c.retrieverTools)) - copy(merged, base) - copy(merged[len(base):], c.retrieverTools) - base = merged - } - if len(c.subAgentTools) > 0 { - merged := make([]interfaces.Tool, len(base)+len(c.subAgentTools)) - copy(merged, base) - copy(merged[len(base):], c.subAgentTools) - base = merged - } - return base -} - -// buildSubAgentTools sets subAgentTools from [agentConfig.subAgents] using [NewSubAgentTool], -// and validates roots (non-nil, no duplicate agent pointer, no duplicate derived tool name) and the nested graph (cycles, max depth). 
-func (c *agentConfig) buildSubAgentTools() error { - if len(c.subAgents) == 0 { - c.subAgentTools = nil +func (c *agentConfig) buildToolRegistry() error { + reg := c.toolRegistry + if reg == nil { + reg = NewToolRegistry() + } + for _, tool := range c.tools { + if err := reg.Register(tool); err != nil { + return fmt.Errorf("WithTools: %w", err) + } + } + c.toolRegistry = reg + c.tools = nil + return nil +} + +func (c *agentConfig) buildMCPRegistry() error { + reg := c.mcpRegistry + if reg == nil { + reg = NewMCPRegistry(c.logger) + } + if len(c.mcpServers) > 0 { + keys := make([]string, 0, len(c.mcpServers)) + for k := range c.mcpServers { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + if err := reg.Register(k, c.mcpServers[k]); err != nil { + return fmt.Errorf("WithMCPConfig: %w", err) + } + } + } + for _, cl := range c.mcpClients { + if cl == nil { + return fmt.Errorf("WithMCPClients: mcp client must not be nil") + } + if err := reg.RegisterClient(cl); err != nil { + return fmt.Errorf("WithMCPClients: %w", err) + } + } + if err := validateMCPClients(reg.List()); err != nil { + return err + } + c.mcpRegistry = reg + return nil +} + +func (c *agentConfig) buildA2ARegistry() error { + reg := c.a2aRegistry + if reg == nil { + reg = NewA2ARegistry(c.logger) + } + if len(c.a2aServers) > 0 { + keys := make([]string, 0, len(c.a2aServers)) + for k := range c.a2aServers { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + if err := reg.Register(k, c.a2aServers[k]); err != nil { + return fmt.Errorf("WithA2AConfig: %w", err) + } + } + } + for _, cl := range c.a2aClients { + if cl == nil { + return fmt.Errorf("WithA2AClients: a2a client must not be nil") + } + if err := reg.RegisterClient(cl); err != nil { + return fmt.Errorf("WithA2AClients: %w", err) + } + } + if err := validateA2AClients(reg.List()); err != nil { + return err + } + c.a2aRegistry = reg + return nil +} + +func (c *agentConfig) 
buildSubAgentRegistry() error { + reg := c.subAgentRegistry + if reg == nil { + reg = NewSubAgentRegistry() + } + for _, sa := range c.subAgents { + if err := reg.Register(sa); err != nil { + return fmt.Errorf("WithSubAgents: %w", err) + } + } + if err := validateSubAgentRegistry(c, reg); err != nil { + return err + } + c.subAgentRegistry = reg + c.subAgents = nil + return nil +} + +// validateSubAgentRegistry checks roots and nested graph (cycles, max depth) for reg.List(). +func validateSubAgentRegistry(c *agentConfig, reg SubAgentRegistry) error { + agents := reg.List() + if len(agents) == 0 { return nil } maxDepth := c.maxSubAgentDepth if maxDepth <= 0 { maxDepth = defaultMaxSubAgentDepth } - seen := make(map[*Agent]struct{}, len(c.subAgents)) - seenNames := make(map[string]struct{}, len(c.subAgents)) - out := make([]interfaces.Tool, 0, len(c.subAgents)) - for _, sa := range c.subAgents { + seen := make(map[*Agent]struct{}, len(agents)) + seenNames := make(map[string]struct{}, len(agents)) + for _, sa := range agents { if sa == nil { return fmt.Errorf("sub-agent must not be nil") } @@ -958,20 +1024,16 @@ func (c *agentConfig) buildSubAgentTools() error { return fmt.Errorf("duplicate sub-agent %q in WithSubAgents", sa.Name) } seen[sa] = struct{}{} - n, err := subAgentToolName(sa.Name) + toolName, err := subAgentToolName(sa.Name) if err != nil { return fmt.Errorf("WithSubAgents: %w", err) } - if _, dup := seenNames[n]; dup { - return fmt.Errorf("duplicate sub-agent tool name %q", n) - } - seenNames[n] = struct{}{} - if st := NewSubAgentTool(sa); st != nil { - out = append(out, st) + if _, dup := seenNames[toolName]; dup { + return fmt.Errorf("duplicate sub-agent tool name %q", toolName) } + seenNames[toolName] = struct{}{} } - c.subAgentTools = out - for _, s := range c.subAgents { + for _, s := range agents { path := map[*Agent]struct{}{s: {}} if err := dfsSubAgentDepth(s, path, 1, maxDepth); err != nil { return err @@ -980,13 +1042,83 @@ func (c *agentConfig) 
buildSubAgentTools() error { return nil } +// buildAgentRuntime constructs the execution backend from agentConfig. +// Defaults to the local in-process runtime when no Temporal backend is configured. +// Extend with additional branches when new [runtime.Runtime] implementations are added. +func (cfg *agentConfig) buildAgentRuntime(remoteWorker bool) (runtime.Runtime, error) { + if cfg.hasTemporalRuntime() { + return cfg.buildTemporalRuntime(remoteWorker) + } + return cfg.buildLocalRuntime() +} + +// resolveTools builds the merged tool list for one run from registries and resolution. +func (c *agentConfig) resolveTools(ctx context.Context) ([]interfaces.Tool, error) { + tools := c.toolRegistry.List() + + mcpTools, err := c.resolveMCPTools(ctx) + if err != nil { + return nil, err + } + tools = append(tools, mcpTools...) + + a2aTools, err := c.resolveA2ATools(ctx) + if err != nil { + return nil, err + } + tools = append(tools, a2aTools...) + + subAgentTools, err := c.resolveSubAgentTools() + if err != nil { + return nil, err + } + tools = append(tools, subAgentTools...) + + retrieverTools, err := c.resolveRetrieverTools() + if err != nil { + return nil, err + } + tools = append(tools, retrieverTools...) + + if err := validateToolNames(tools); err != nil { + return nil, err + } + if c.subAgentRegistry != nil { + if err := validateSubAgentRegistry(c, c.subAgentRegistry); err != nil { + return nil, err + } + } + return tools, nil +} + +// resolveSubAgentTools returns sub-agent delegation tools from [subAgentRegistry]. 
+func (c *agentConfig) resolveSubAgentTools() ([]interfaces.Tool, error) { + if c.subAgentRegistry == nil { + return nil, nil + } + agents := c.subAgentRegistry.List() + if len(agents) == 0 { + return nil, nil + } + out := make([]interfaces.Tool, 0, len(agents)) + for _, sa := range agents { + if st := NewSubAgentTool(sa); st != nil { + out = append(out, st) + } + } + return out, nil +} + func dfsSubAgentDepth(a *Agent, path map[*Agent]struct{}, depth, maxDepth int) error { if depth > maxDepth { return fmt.Errorf("sub-agent depth exceeds max (%d): at %q", maxDepth, a.Name) } - for _, child := range a.subAgents { + if a == nil || a.subAgentRegistry == nil { + return nil + } + for _, child := range a.subAgentRegistry.List() { if child == nil { - return fmt.Errorf("sub-agent %q has a nil entry in WithSubAgents", a.Name) + return fmt.Errorf("sub-agent %q has a nil entry in sub-agent registry", a.Name) } if _, cycle := path[child]; cycle { return fmt.Errorf("sub-agent cycle detected involving %q and %q", a.Name, child.Name) @@ -1000,62 +1132,19 @@ func dfsSubAgentDepth(a *Agent, path map[*Agent]struct{}, depth, maxDepth int) e return nil } -// validateToolNames ensures tool names are unique across WithTools/registry, MCP tools, A2A tools, -// retriever tools, and [subAgentTools]. 
-func (c *agentConfig) validateToolNames() error { - var base []interfaces.Tool - if c.toolRegistry != nil { - base = c.toolRegistry.Tools() - } else { - base = c.tools - } - names := make(map[string]struct{}) - for _, t := range base { - n := t.Name() - if _, ok := names[n]; ok { - return fmt.Errorf("duplicate tool name %q in WithTools or registry", n) - } - names[n] = struct{}{} - } - for _, t := range c.mcpTools { - if t == nil { - return fmt.Errorf("mcp tool must not be nil") - } - n := t.Name() - if _, ok := names[n]; ok { - return fmt.Errorf("duplicate tool name %q: MCP tool conflicts with an existing tool", n) - } - names[n] = struct{}{} - } - for _, t := range c.a2aTools { - if t == nil { - return fmt.Errorf("a2a tool must not be nil") - } - n := t.Name() - if _, ok := names[n]; ok { - return fmt.Errorf("duplicate tool name %q: A2A tool conflicts with an existing tool", n) - } - names[n] = struct{}{} - } - for _, t := range c.retrieverTools { - if t == nil { - return fmt.Errorf("retriever tool must not be nil") - } - n := t.Name() - if _, ok := names[n]; ok { - return fmt.Errorf("duplicate tool name %q: retriever tool conflicts with an existing tool", n) - } - names[n] = struct{}{} - } - for _, t := range c.subAgentTools { - if t == nil { - return fmt.Errorf("sub-agent tool must not be nil") +// validateToolNames ensures tool names are unique across registry, MCP, A2A, retriever, and sub-agent tools. 
+func validateToolNames(tools []interfaces.Tool) error { + seen := make(map[string]string, len(tools)) + for _, tool := range tools { + if tool == nil { + return fmt.Errorf("tool must not be nil") } - n := t.Name() - if _, ok := names[n]; ok { - return fmt.Errorf("sub-agent tool name %q conflicts with an existing tool", n) + name := tool.Name() + kind := interfaces.KindOf(tool) + if prev, ok := seen[name]; ok { + return fmt.Errorf("duplicate tool name %q: %s tool conflicts with an existing %s tool", name, kind, prev) } - names[n] = struct{}{} + seen[name] = kind } return nil } @@ -1069,7 +1158,7 @@ func (c *agentConfig) responseFormatForLLM() *interfaces.ResponseFormat { return &interfaces.ResponseFormat{Type: interfaces.ResponseFormatText} } -// runtimeAgentSpec matches [runtime.ExecuteRequest.AgentSpec] / [temporal.TemporalRuntimeConfig.AgentSpec]. +// runtimeAgentSpec is static agent identity wired onto the runtime at construction. // ResponseFormat uses [agentConfig.responseFormatForLLM] so unset format defaults to text (same as LLM requests). func (c *agentConfig) runtimeAgentSpec() runtime.AgentSpec { return runtime.AgentSpec{ @@ -1080,17 +1169,13 @@ func (c *agentConfig) runtimeAgentSpec() runtime.AgentSpec { } } -// runtimeAgentExecution matches [runtime.ExecuteRequest.AgentExecution] / [temporal.TemporalRuntimeConfig.AgentExecution]. -func (c *agentConfig) runtimeAgentExecution() runtime.AgentExecution { - d := runtime.AgentExecution{ +// runtimeAgentConfig is static wiring copied onto the runtime at construction. 
+func (c *agentConfig) runtimeAgentConfig() runtime.AgentConfig { + d := runtime.AgentConfig{ LLM: runtime.AgentLLM{ Client: c.LLMClient, }, - Tools: runtime.AgentTools{ - Tools: c.toolsList(), - Registry: c.toolRegistry, - ApprovalPolicy: c.toolApprovalPolicy, - }, + ToolApprovalPolicy: c.toolApprovalPolicy, Retrievers: runtime.AgentRetrievers{ Retrievers: c.retrievers, Mode: c.retrieverMode, @@ -1230,21 +1315,21 @@ func (c *agentConfig) applySamplingToRequest(req *interfaces.LLMRequest) { } } -func (c *agentConfig) requiresApproval(t interfaces.Tool) bool { +func (c *agentConfig) requiresApproval(tool interfaces.Tool) bool { if c.toolApprovalPolicy == nil { // No policy: honor tool's ApprovalRequired - if ar, ok := t.(interfaces.ToolApproval); ok && ar.ApprovalRequired() { + if ar, ok := tool.(interfaces.ToolApproval); ok && ar.ApprovalRequired() { return true } return false } // Policy set: policy decides (can override tool default) - return c.toolApprovalPolicy.RequiresApproval(t) + return c.toolApprovalPolicy.RequiresApproval(tool) } -func (c *agentConfig) hasApprovalTools() bool { - for _, t := range c.toolsList() { - if c.requiresApproval(t) { +func (c *agentConfig) hasApprovalTools(tools []interfaces.Tool) bool { + for _, tool := range tools { + if c.requiresApproval(tool) { return true } } @@ -1258,70 +1343,43 @@ func validateMCPClients(clients []interfaces.MCPClient) error { if cl == nil { return fmt.Errorf("mcp client must not be nil") } - n := strings.TrimSpace(cl.Name()) - if n == "" { + name := strings.TrimSpace(cl.Name()) + if name == "" { return fmt.Errorf("mcp client name must not be empty") } - if _, dup := seen[n]; dup { - return fmt.Errorf("duplicate mcp client name %q", n) + if _, dup := seen[name]; dup { + return fmt.Errorf("duplicate mcp client name %q", name) } - seen[n] = struct{}{} + seen[name] = struct{}{} } return nil } -// buildMCPTools merges [agentConfig.mcpServers] (default SDK client per key) with [agentConfig.mcpClients], -// 
validates names, lists tools from each client (tool allow/block filtering runs inside [mcpclient.Client.ListTools] when configured), and appends [MCPTool] to [agentConfig.mcpTools]. -func (c *agentConfig) buildMCPTools() error { - c.mcpTools = []interfaces.Tool{} - if len(c.mcpServers) == 0 && len(c.mcpClients) == 0 { - return nil +// resolveMCPTools lists tools from [mcpRegistry] clients. +func (c *agentConfig) resolveMCPTools(ctx context.Context) ([]interfaces.Tool, error) { + if c.mcpRegistry == nil { + return nil, nil } - - keys := make([]string, 0, len(c.mcpServers)) - for k := range c.mcpServers { - keys = append(keys, k) - } - sort.Strings(keys) - - clients := make([]interfaces.MCPClient, 0, len(keys)+len(c.mcpClients)) - for _, k := range keys { - cfg := c.mcpServers[k] - if cfg.Transport == nil { - return fmt.Errorf("mcp %q: Transport is required", k) - } - mcpOpts := []mcpclient.Option{ - mcpclient.WithLogger(c.logger), - mcpclient.WithTimeout(cfg.Timeout), - mcpclient.WithRetryAttempts(cfg.RetryAttempts), - mcpclient.WithToolFilter(cfg.ToolFilter), - } - cl, err := mcpclient.NewClient(k, cfg.Transport, mcpOpts...) - if err != nil { - return fmt.Errorf("mcp %q: new client: %w", k, err) - } - clients = append(clients, cl) + clients := c.mcpRegistry.List() + if len(clients) == 0 { + return nil, nil } - clients = append(clients, c.mcpClients...) 
- if err := validateMCPClients(clients); err != nil { - return err + if ctx == nil { + ctx = context.Background() } - - ctx := context.Background() - var tools []interfaces.Tool + tools := make([]interfaces.Tool, 0) for _, cl := range clients { - sk := strings.TrimSpace(cl.Name()) + serverKey := strings.TrimSpace(cl.Name()) specs, err := cl.ListTools(ctx) if err != nil { - return fmt.Errorf("mcp %q: list tools: %w", sk, err) + return nil, fmt.Errorf("mcp %q: list tools: %w", serverKey, err) } - for _, sp := range specs { - tools = append(tools, NewMCPTool(sk, sp, cl)) + for _, spec := range specs { + tools = append(tools, NewMCPTool(serverKey, spec, cl)) } } - c.mcpTools = tools - return nil + return tools, nil } // validateRetrievers checks for nil entries in [WithRetrievers]. @@ -1349,36 +1407,34 @@ func validateRetrieverMode(mode RetrieverMode) (RetrieverMode, error) { } } -// buildRetrieverTools registers a [RetrieverTool] per [WithRetrievers] entry when mode is -// [RetrieverModeAgentic] or [RetrieverModeHybrid], and appends to [agentConfig.retrieverTools]. +// resolveRetrieverTools returns a [RetrieverTool] per [WithRetrievers] entry when mode is +// [RetrieverModeAgentic] or [RetrieverModeHybrid]. // [RetrieverModePrefetch] does not expose tools (context is injected before the first LLM call). 
-func (c *agentConfig) buildRetrieverTools() error { - c.retrieverTools = nil +func (c *agentConfig) resolveRetrieverTools() ([]interfaces.Tool, error) { if c.retrieverMode != RetrieverModeAgentic && c.retrieverMode != RetrieverModeHybrid { - return nil + return nil, nil } if len(c.retrievers) == 0 { - return nil + return nil, nil } seen := make(map[string]struct{}, len(c.retrievers)) tools := make([]interfaces.Tool, 0, len(c.retrievers)) - for _, r := range c.retrievers { - n := strings.TrimSpace(r.Name()) - if n == "" { - return fmt.Errorf("retriever name must not be empty") + for _, retriever := range c.retrievers { + name := strings.TrimSpace(retriever.Name()) + if name == "" { + return nil, fmt.Errorf("retriever name must not be empty") } - if _, dup := seen[n]; dup { - return fmt.Errorf("duplicate retriever name %q", n) + if _, dup := seen[name]; dup { + return nil, fmt.Errorf("duplicate retriever name %q", name) } - seen[n] = struct{}{} - tool := NewRetrieverTool(r) + seen[name] = struct{}{} + tool := NewRetrieverTool(retriever) if tool == nil { - return fmt.Errorf("retriever %q: failed to build tool", n) + return nil, fmt.Errorf("retriever %q: failed to build tool", name) } tools = append(tools, tool) } - c.retrieverTools = tools - return nil + return tools, nil } // validateA2AClients checks for nil clients, empty names, and duplicate [interfaces.A2AClient.Name] values. 
@@ -1388,76 +1444,44 @@ func validateA2AClients(clients []interfaces.A2AClient) error { if cl == nil { return fmt.Errorf("a2a client must not be nil") } - n := strings.TrimSpace(cl.Name()) - if n == "" { + name := strings.TrimSpace(cl.Name()) + if name == "" { return fmt.Errorf("a2a client name must not be empty") } - if _, dup := seen[n]; dup { - return fmt.Errorf("duplicate a2a client name %q", n) + if _, dup := seen[name]; dup { + return fmt.Errorf("duplicate a2a client name %q", name) } - seen[n] = struct{}{} + seen[name] = struct{}{} } return nil } -// buildA2ATools merges [agentConfig.a2aServers] (default SDK client per key) with [agentConfig.a2aClients], -// validates names, lists skills from each client (skill allow/block filtering runs inside -// [a2aclient.Client.ListSkills] when [WithSkillFilter] is configured), and appends [A2ATool] to [agentConfig.a2aTools]. -func (c *agentConfig) buildA2ATools() error { - c.a2aTools = []interfaces.Tool{} - if len(c.a2aServers) == 0 && len(c.a2aClients) == 0 { - return nil - } - - keys := make([]string, 0, len(c.a2aServers)) - for k := range c.a2aServers { - keys = append(keys, k) +// resolveA2ATools lists skills from [a2aRegistry] clients. +func (c *agentConfig) resolveA2ATools(ctx context.Context) ([]interfaces.Tool, error) { + if c.a2aRegistry == nil { + return nil, nil } - sort.Strings(keys) - - clients := make([]interfaces.A2AClient, 0, len(keys)+len(c.a2aClients)) - for _, k := range keys { - cfg := c.a2aServers[k] - if strings.TrimSpace(cfg.URL) == "" { - return fmt.Errorf("a2a %q: URL is required", k) - } - a2aOpts := []a2aclient.Option{ - a2aclient.WithLogger(c.logger), - a2aclient.WithTimeout(cfg.Timeout), - a2aclient.WithToken(cfg.Token), - a2aclient.WithHeaders(cfg.Headers), - a2aclient.WithSkillFilter(cfg.SkillFilter), - } - if cfg.SkipTLSVerify { - a2aOpts = append(a2aOpts, a2aclient.WithSkipTLSVerify(true)) - } - cl, err := a2aclient.NewClient(k, cfg.URL, a2aOpts...) 
- if err != nil { - return fmt.Errorf("a2a %q: new client: %w", k, err) - } - clients = append(clients, cl) + clients := c.a2aRegistry.List() + if len(clients) == 0 { + return nil, nil } - clients = append(clients, c.a2aClients...) - if err := validateA2AClients(clients); err != nil { - return err + if ctx == nil { + ctx = context.Background() } - - ctx := context.Background() - var tools []interfaces.Tool + tools := make([]interfaces.Tool, 0) for _, cl := range clients { - sk := strings.TrimSpace(cl.Name()) + serverKey := strings.TrimSpace(cl.Name()) skills, err := cl.ListSkills(ctx) if err != nil { - return fmt.Errorf("a2a %q: list skills: %w", sk, err) + return nil, fmt.Errorf("a2a %q: list skills: %w", serverKey, err) } - for _, sp := range skills { - tools = append(tools, NewA2ATool(sk, interfaces.ToolSpec{ - Name: sp.ID, - Description: sp.Description, - }, sp, cl)) + for _, skill := range skills { + tools = append(tools, NewA2ATool(serverKey, interfaces.ToolSpec{ + Name: skill.ID, + Description: skill.Description, + }, skill, cl)) } } - c.a2aTools = tools - return nil + return tools, nil } diff --git a/pkg/agent/config_test.go b/pkg/agent/config_test.go index 26374cc..1fdb58d 100644 --- a/pkg/agent/config_test.go +++ b/pkg/agent/config_test.go @@ -21,13 +21,19 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -// agentConfigFingerprint is a test helper that mirrors the fingerprint computed by the temporal -// runtime for a given agent config. Lives here (not in production code) since it is only used -// to assert fingerprint stability in tests. +// agentConfigFingerprint is a test helper for Temporal per-run fingerprint payloads. 
func agentConfigFingerprint(c *agentConfig) string { - mat := temporal.BuildAgentFingerprintPayload( + tools, err := c.resolveTools(context.Background()) + if err != nil { + panic(err) + } + return agentConfigFingerprintTools(c, tools) +} + +func agentConfigFingerprintTools(c *agentConfig, tools []interfaces.Tool) string { + return temporal.ComputeAgentFingerprint(temporal.BuildAgentFingerprintPayload( c.runtimeAgentSpec(), - temporal.ToolNamesFromTools(c.toolsList()), + temporal.ToolNamesFromTools(tools), toolPolicyFingerprint(c.toolApprovalPolicy), llmSamplingRuntimeView(c.llmSampling), c.conversationSize, @@ -42,8 +48,7 @@ func agentConfigFingerprint(c *agentConfig) string { string(c.agentMode), c.agentToolExecutionMode, retrieverConfigFingerprint(c.retrieverMode, c.retrievers), - ) - return temporal.ComputeAgentFingerprint(mat) + )) } func TestBuildAgentConfig_NeitherTemporalConfigNorClient_UsesLocalRuntime(t *testing.T) { @@ -142,7 +147,7 @@ func TestBuildAgentConfig_WithMCP(t *testing.T) { } defer func() { _ = srvSess.Close() }() - cfg, err := buildAgentConfig([]Option{ + _, err = buildAgentConfig([]Option{ WithName("test"), WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), WithLLMClient(stubLLM{}), @@ -154,9 +159,7 @@ func TestBuildAgentConfig_WithMCP(t *testing.T) { if err != nil { t.Fatal(err) } - if len(cfg.mcpTools) != 1 || cfg.mcpTools[0].Name() != "mcp_srv_keep" { - t.Fatalf("mcpTools = %v", cfg.mcpTools) - } + // buildAgentConfig calls resolveTools; success means MCP discovery + filter produced valid tools. 
} func TestBuildAgentConfig_MCPClients_toolFilter(t *testing.T) { @@ -180,7 +183,7 @@ func TestBuildAgentConfig_MCPClients_toolFilter(t *testing.T) { if err != nil { t.Fatal(err) } - cfg, err := buildAgentConfig([]Option{ + _, err = buildAgentConfig([]Option{ WithName("test"), WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), WithLLMClient(stubLLM{}), @@ -189,9 +192,7 @@ func TestBuildAgentConfig_MCPClients_toolFilter(t *testing.T) { if err != nil { t.Fatal(err) } - if len(cfg.mcpTools) != 1 || cfg.mcpTools[0].Name() != "mcp_s_keep" { - t.Fatalf("mcpTools = %v", cfg.mcpTools) - } + // buildAgentConfig calls resolveTools; success means MCP discovery + filter produced valid tools. } func TestBuildAgentConfig_MCP_duplicateClientName(t *testing.T) { @@ -208,7 +209,7 @@ func TestBuildAgentConfig_MCP_duplicateClientName(t *testing.T) { }}), WithMCPClients(cl), }) - if err == nil || !strings.Contains(err.Error(), "duplicate mcp client name") { + if err == nil || !strings.Contains(err.Error(), "duplicate mcp client name") && !strings.Contains(err.Error(), "already exists") { t.Fatalf("got %v", err) } } @@ -216,14 +217,23 @@ func TestBuildAgentConfig_MCP_duplicateClientName(t *testing.T) { func TestAgentConfig_ToolsList(t *testing.T) { tool := mockTool{name: "t1"} c := &agentConfig{tools: []interfaces.Tool{tool}} - list := c.toolsList() + if err := c.buildToolRegistry(); err != nil { + t.Fatal(err) + } + list, err := c.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } if len(list) != 1 || list[0].Name() != "t1" { t.Errorf("toolsList = %v, want [t1]", list) } reg := &mockRegistry{tools: []interfaces.Tool{tool, mockTool{name: "t2"}}} c2 := &agentConfig{toolRegistry: reg} - list2 := c2.toolsList() + list2, err := c2.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } if len(list2) != 2 { t.Errorf("toolsList with registry = %v, want 2 tools", list2) } @@ -288,81 +298,81 @@ func TestAgentConfig_RequiresApproval(t *testing.T) { } } 
-func TestAgentConfig_buildSubAgentTools_duplicateRootSubs(t *testing.T) { +func TestAgentConfig_resolveSubAgentTools_duplicateRootSubs(t *testing.T) { s := &Agent{agentConfig: agentConfig{Name: "Same"}} c := &agentConfig{subAgents: []*Agent{s, s}, maxSubAgentDepth: 3} - err := c.buildSubAgentTools() - if err == nil || !strings.Contains(err.Error(), "duplicate") { + err := c.buildSubAgentRegistry() + if err == nil || (!strings.Contains(err.Error(), "duplicate") && !strings.Contains(err.Error(), "already exists")) { t.Fatalf("want duplicate error, got %v", err) } } -func TestAgentConfig_buildSubAgentTools_duplicateDerivedToolName(t *testing.T) { +func TestAgentConfig_resolveSubAgentTools_duplicateDerivedToolName(t *testing.T) { a := &Agent{agentConfig: agentConfig{Name: "Dup"}} b := &Agent{agentConfig: agentConfig{Name: "Dup"}} c := &agentConfig{subAgents: []*Agent{a, b}, maxSubAgentDepth: 3} - err := c.buildSubAgentTools() - if err == nil || !strings.Contains(err.Error(), "duplicate sub-agent tool name") { + err := c.buildSubAgentRegistry() + if err == nil || (!strings.Contains(err.Error(), "duplicate sub-agent tool name") && !strings.Contains(err.Error(), "already exists")) { t.Fatalf("want duplicate sub-agent tool name error, got %v", err) } } -func TestAgentConfig_buildSubAgentTools_nilSubAgent(t *testing.T) { +func TestAgentConfig_resolveSubAgentTools_nilSubAgent(t *testing.T) { c := &agentConfig{subAgents: []*Agent{nil}, maxSubAgentDepth: 3} - err := c.buildSubAgentTools() + err := c.buildSubAgentRegistry() if err == nil || !strings.Contains(err.Error(), "nil") { t.Fatalf("want nil sub-agent error, got %v", err) } } -func TestAgentConfig_buildSubAgentTools_invalidSubAgentName(t *testing.T) { +func TestAgentConfig_resolveSubAgentTools_invalidSubAgentName(t *testing.T) { emptyName := &Agent{agentConfig: agentConfig{Name: "", ID: "id-only"}} c := &agentConfig{subAgents: []*Agent{emptyName}, maxSubAgentDepth: 3} - if err := c.buildSubAgentTools(); err == nil { + 
if err := c.buildSubAgentRegistry(); err == nil { t.Fatal("expected error for empty sub-agent name") } symbolsOnly := &Agent{agentConfig: agentConfig{Name: "@@@"}} c2 := &agentConfig{subAgents: []*Agent{symbolsOnly}, maxSubAgentDepth: 3} - if err := c2.buildSubAgentTools(); err == nil { + if err := c2.buildSubAgentRegistry(); err == nil { t.Fatal("expected error for sub-agent name with no alphanumeric characters") } } -func TestAgentConfig_buildSubAgentTools_cycleAB(t *testing.T) { - a := &Agent{agentConfig: agentConfig{Name: "A"}} - b := &Agent{agentConfig: agentConfig{Name: "B"}} - a.subAgents = []*Agent{b} - b.subAgents = []*Agent{a} +func TestAgentConfig_resolveSubAgentTools_cycleAB(t *testing.T) { + a := &Agent{agentConfig: agentConfig{Name: "A", subAgentRegistry: NewSubAgentRegistry()}} + b := &Agent{agentConfig: agentConfig{Name: "B", subAgentRegistry: NewSubAgentRegistry()}} + _ = a.subAgentRegistry.Register(b) + _ = b.subAgentRegistry.Register(a) c := &agentConfig{subAgents: []*Agent{a}, maxSubAgentDepth: 5} - err := c.buildSubAgentTools() + err := c.buildSubAgentRegistry() if err == nil || !strings.Contains(err.Error(), "cycle") { t.Fatalf("want cycle error, got %v", err) } } -func TestAgentConfig_buildSubAgentTools_depthExceeded(t *testing.T) { - d1 := &Agent{agentConfig: agentConfig{Name: "d1"}} - d2 := &Agent{agentConfig: agentConfig{Name: "d2"}} - d3 := &Agent{agentConfig: agentConfig{Name: "d3"}} - d4 := &Agent{agentConfig: agentConfig{Name: "d4"}} - d1.subAgents = []*Agent{d2} - d2.subAgents = []*Agent{d3} - d3.subAgents = []*Agent{d4} +func TestAgentConfig_resolveSubAgentTools_depthExceeded(t *testing.T) { + d4 := &Agent{agentConfig: agentConfig{Name: "d4", subAgentRegistry: NewSubAgentRegistry()}} + d3 := &Agent{agentConfig: agentConfig{Name: "d3", subAgentRegistry: NewSubAgentRegistry()}} + d2 := &Agent{agentConfig: agentConfig{Name: "d2", subAgentRegistry: NewSubAgentRegistry()}} + d1 := &Agent{agentConfig: agentConfig{Name: "d1", 
subAgentRegistry: NewSubAgentRegistry()}} + _ = d3.subAgentRegistry.Register(d4) + _ = d2.subAgentRegistry.Register(d3) + _ = d1.subAgentRegistry.Register(d2) c := &agentConfig{subAgents: []*Agent{d1}, maxSubAgentDepth: 3} - err := c.buildSubAgentTools() + err := c.buildSubAgentRegistry() if err == nil || !strings.Contains(err.Error(), "depth") { t.Fatalf("want depth error, got %v", err) } } -func TestAgentConfig_buildSubAgentTools_okWithinDepth(t *testing.T) { - d1 := &Agent{agentConfig: agentConfig{Name: "d1"}} - d2 := &Agent{agentConfig: agentConfig{Name: "d2"}} - d3 := &Agent{agentConfig: agentConfig{Name: "d3"}} - d1.subAgents = []*Agent{d2} - d2.subAgents = []*Agent{d3} +func TestAgentConfig_resolveSubAgentTools_okWithinDepth(t *testing.T) { + d3 := &Agent{agentConfig: agentConfig{Name: "d3", subAgentRegistry: NewSubAgentRegistry()}} + d2 := &Agent{agentConfig: agentConfig{Name: "d2", subAgentRegistry: NewSubAgentRegistry()}} + d1 := &Agent{agentConfig: agentConfig{Name: "d1", subAgentRegistry: NewSubAgentRegistry()}} + _ = d2.subAgentRegistry.Register(d3) + _ = d1.subAgentRegistry.Register(d2) c := &agentConfig{subAgents: []*Agent{d1}, maxSubAgentDepth: 3} - if err := c.buildSubAgentTools(); err != nil { + if err := c.buildSubAgentRegistry(); err != nil { t.Fatal(err) } } @@ -373,10 +383,15 @@ func TestAgentConfig_validateToolNames_conflict(t *testing.T) { tools: []interfaces.Tool{mockTool{name: "subagent_Math"}}, subAgents: []*Agent{sub}, } - if err := c.buildSubAgentTools(); err != nil { + if err := c.buildRegistries(); err != nil { + t.Fatal(err) + } + subs, err := c.resolveSubAgentTools() + if err != nil { t.Fatal(err) } - err := c.validateToolNames() + tools := append(c.toolRegistry.List(), subs...) 
+ err = validateToolNames(tools) if err == nil || (!strings.Contains(err.Error(), "duplicate tool name") && !strings.Contains(err.Error(), "conflicts")) { t.Fatalf("want duplicate / conflict error, got %v", err) } @@ -388,10 +403,13 @@ func TestAgentConfig_toolsList_includesSubAgents(t *testing.T) { tools: []interfaces.Tool{mockTool{name: "echo"}}, subAgents: []*Agent{sub}, } - if err := c.buildSubAgentTools(); err != nil { + if err := c.buildRegistries(); err != nil { + t.Fatal(err) + } + list, err := c.resolveTools(context.Background()) + if err != nil { t.Fatal(err) } - list := c.toolsList() if len(list) != 2 { t.Fatalf("toolsList len = %d, want 2", len(list)) } @@ -412,7 +430,10 @@ func TestAgentConfig_HasApprovalTools(t *testing.T) { tools: []interfaces.Tool{mockToolWithApproval{mockTool: mockTool{name: "x"}, needApproval: true}}, toolApprovalPolicy: RequireAllToolApprovalPolicy{}, } - if !c.hasApprovalTools() { + if err := c.buildToolRegistry(); err != nil { + t.Fatal(err) + } + if !c.hasApprovalTools(c.toolRegistry.List()) { t.Error("hasApprovalTools should be true when tools require approval") } @@ -420,25 +441,56 @@ func TestAgentConfig_HasApprovalTools(t *testing.T) { tools: []interfaces.Tool{mockToolWithApproval{mockTool: mockTool{name: "x"}, needApproval: false}}, toolApprovalPolicy: AutoToolApprovalPolicy(), } - if c2.hasApprovalTools() { + if err := c2.buildToolRegistry(); err != nil { + t.Fatal(err) + } + if c2.hasApprovalTools(c2.toolRegistry.List()) { t.Error("hasApprovalTools should be false when no tool requires approval") } } +func TestBuildAgentConfig_approvalTimeoutValidatedWithoutApprovalTools(t *testing.T) { + _, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithTimeout(5 * time.Minute), + WithApprovalTimeout(6 * time.Minute), + }) + if err == nil || !strings.Contains(err.Error(), "approvalTimeout") { + t.Fatalf("got %v", err) + } + + cfg, err := 
buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithTimeout(5 * time.Minute), + WithApprovalTimeout(2 * time.Minute), + }) + if err != nil { + t.Fatal(err) + } + if cfg.approvalTimeout != 2*time.Minute { + t.Fatalf("approvalTimeout = %v", cfg.approvalTimeout) + } +} + type mockRegistry struct { tools []interfaces.Tool } -func (m *mockRegistry) Register(interfaces.Tool) {} -func (m *mockRegistry) Get(name string) (interfaces.Tool, bool) { +func (m *mockRegistry) Register(interfaces.Tool) error { return nil } +func (m *mockRegistry) Unregister(string) error { return ErrRegistryNotFound } +func (m *mockRegistry) Get(name string) (interfaces.Tool, error) { for _, t := range m.tools { if t.Name() == name { - return t, true + return t, nil } } - return nil, false + return nil, ErrRegistryNotFound } -func (m *mockRegistry) Tools() []interfaces.Tool { return m.tools } +func (m *mockRegistry) List() []interfaces.Tool { return m.tools } type mockToolWithApproval struct { mockTool @@ -565,11 +617,12 @@ func TestBuildRetrieverTools(t *testing.T) { retrieverMode: RetrieverModeAgentic, retrievers: []interfaces.Retriever{namedStubRetriever("kb")}, } - if err := c.buildRetrieverTools(); err != nil { + tools, err := c.resolveRetrieverTools() + if err != nil { t.Fatal(err) } - if len(c.retrieverTools) != 1 || c.retrieverTools[0].Name() != "retriever_kb" { - t.Fatalf("retrieverTools = %v", c.retrieverTools) + if len(tools) != 1 || tools[0].Name() != "retriever_kb" { + t.Fatalf("retrieverTools = %v", tools) } }) t.Run("hybrid_builds_tools", func(t *testing.T) { @@ -577,11 +630,12 @@ func TestBuildRetrieverTools(t *testing.T) { retrieverMode: RetrieverModeHybrid, retrievers: []interfaces.Retriever{stubRetriever{}}, } - if err := c.buildRetrieverTools(); err != nil { + tools, err := c.resolveRetrieverTools() + if err != nil { t.Fatal(err) } - if len(c.retrieverTools) != 1 { - t.Fatalf("len = %d", 
len(c.retrieverTools)) + if len(tools) != 1 { + t.Fatalf("len = %d", len(tools)) } }) t.Run("prefetch_skips_tools", func(t *testing.T) { @@ -589,20 +643,22 @@ func TestBuildRetrieverTools(t *testing.T) { retrieverMode: RetrieverModePrefetch, retrievers: []interfaces.Retriever{stubRetriever{}}, } - if err := c.buildRetrieverTools(); err != nil { + tools, err := c.resolveRetrieverTools() + if err != nil { t.Fatal(err) } - if c.retrieverTools != nil { - t.Fatalf("retrieverTools = %v, want nil", c.retrieverTools) + if len(tools) != 0 { + t.Fatalf("retrieverTools = %v, want none", tools) } }) t.Run("no_retrievers", func(t *testing.T) { c := &agentConfig{retrieverMode: RetrieverModeAgentic} - if err := c.buildRetrieverTools(); err != nil { + tools, err := c.resolveRetrieverTools() + if err != nil { t.Fatal(err) } - if c.retrieverTools != nil { - t.Fatalf("retrieverTools = %v, want nil", c.retrieverTools) + if len(tools) != 0 { + t.Fatalf("retrieverTools = %v, want none", tools) } }) t.Run("duplicate_name", func(t *testing.T) { @@ -610,7 +666,7 @@ func TestBuildRetrieverTools(t *testing.T) { retrieverMode: RetrieverModeAgentic, retrievers: []interfaces.Retriever{namedStubRetriever("x"), namedStubRetriever("x")}, } - err := c.buildRetrieverTools() + _, err := c.resolveRetrieverTools() if err == nil || !strings.Contains(err.Error(), "duplicate retriever name") { t.Fatalf("got %v", err) } @@ -620,7 +676,7 @@ func TestBuildRetrieverTools(t *testing.T) { retrieverMode: RetrieverModeAgentic, retrievers: []interfaces.Retriever{namedStubRetriever(" ")}, } - err := c.buildRetrieverTools() + _, err := c.resolveRetrieverTools() if err == nil || !strings.Contains(err.Error(), "must not be empty") { t.Fatalf("got %v", err) } @@ -641,8 +697,12 @@ func TestBuildAgentConfig_WithRetrievers(t *testing.T) { if len(cfg.retrievers) != 2 { t.Fatalf("retrievers len = %d", len(cfg.retrievers)) } - if len(cfg.retrieverTools) != 2 { - t.Fatalf("retrieverTools len = %d, want 2 (default agentic 
mode)", len(cfg.retrieverTools)) + tools, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(tools) != 2 { + t.Fatalf("resolved tools len = %d, want 2 (default agentic mode)", len(tools)) } } @@ -657,8 +717,14 @@ func TestBuildAgentConfig_RetrieverMode_prefetchNoTools(t *testing.T) { if err != nil { t.Fatal(err) } - if len(cfg.retrieverTools) != 0 { - t.Fatalf("retrieverTools len = %d, want 0 for prefetch", len(cfg.retrieverTools)) + tools, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + for _, tool := range tools { + if tool != nil && strings.HasPrefix(tool.Name(), "retriever_") { + t.Fatalf("prefetch mode should not expose retriever tools, got %q", tool.Name()) + } } } @@ -673,9 +739,23 @@ func TestBuildAgentConfig_RetrieverMode_agenticBuildsTools(t *testing.T) { if err != nil { t.Fatal(err) } - if len(cfg.retrieverTools) != 1 || cfg.retrieverTools[0].Name() != "retriever_stub" { - t.Fatalf("retrieverTools = %v", cfg.retrieverTools) + tools, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(tools) != 1 || tools[0].Name() != "retriever_stub" { + t.Fatalf("resolved tools = %v", toolNames(tools)) + } +} + +func toolNames(tools []interfaces.Tool) []string { + out := make([]string, 0, len(tools)) + for _, t := range tools { + if t != nil { + out = append(out, t.Name()) + } } + return out } func TestBuildAgentConfig_AgenticNoRetrievers_NoTools(t *testing.T) { @@ -688,8 +768,12 @@ func TestBuildAgentConfig_AgenticNoRetrievers_NoTools(t *testing.T) { if err != nil { t.Fatal(err) } - if len(cfg.retrieverTools) != 0 { - t.Fatalf("retrieverTools len = %d, want 0", len(cfg.retrieverTools)) + tools, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(tools) != 0 { + t.Fatalf("resolved tools = %v, want none", toolNames(tools)) } } @@ -704,8 +788,12 @@ func TestBuildAgentConfig_RetrieverMode_hybridBuildsTools(t 
*testing.T) { if err != nil { t.Fatal(err) } - if len(cfg.retrieverTools) != 1 || cfg.retrieverTools[0].Name() != "retriever_stub" { - t.Fatalf("retrieverTools = %v", cfg.retrieverTools) + tools, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(tools) != 1 || tools[0].Name() != "retriever_stub" { + t.Fatalf("resolved tools = %v", toolNames(tools)) } } @@ -732,7 +820,10 @@ func TestBuildAgentConfig_toolsList_includesRetrieverTools(t *testing.T) { if err != nil { t.Fatal(err) } - list := cfg.toolsList() + list, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } if len(list) != 2 { t.Fatalf("toolsList len = %d, want 2", len(list)) } @@ -743,21 +834,27 @@ func TestBuildAgentConfig_toolsList_includesRetrieverTools(t *testing.T) { func TestBuildAgentConfig_validateToolNames_RetrieverConflict(t *testing.T) { c := &agentConfig{ - tools: []interfaces.Tool{mockTool{name: "retriever_stub"}}, - retrieverTools: []interfaces.Tool{ - NewRetrieverTool(stubRetriever{}), - }, + tools: []interfaces.Tool{mockTool{name: "retriever_stub"}}, + retrievers: []interfaces.Retriever{stubRetriever{}}, + retrieverMode: RetrieverModeAgentic, + } + if err := c.buildToolRegistry(); err != nil { + t.Fatal(err) } - err := c.validateToolNames() - if err == nil || !strings.Contains(err.Error(), "retriever tool conflicts") { + retr, err := c.resolveRetrieverTools() + if err != nil { + t.Fatal(err) + } + tools := append(c.toolRegistry.List(), retr...) 
+ err = validateToolNames(tools) + if err == nil || !strings.Contains(err.Error(), "conflicts") { t.Fatalf("got %v", err) } } func TestBuildAgentConfig_validateToolNames_nilRetrieverTool(t *testing.T) { - c := &agentConfig{retrieverTools: []interfaces.Tool{nil}} - err := c.validateToolNames() - if err == nil || !strings.Contains(err.Error(), "retriever tool must not be nil") { + err := validateToolNames([]interfaces.Tool{nil}) + if err == nil || !strings.Contains(err.Error(), "tool must not be nil") { t.Fatalf("got %v", err) } } @@ -788,8 +885,12 @@ func TestBuildAgentConfig_WithRetrievers_emptyClears(t *testing.T) { if cfg.retrievers != nil { t.Fatalf("retrievers = %v, want nil", cfg.retrievers) } - if len(cfg.retrieverTools) != 0 { - t.Fatalf("retrieverTools len = %d, want 0", len(cfg.retrieverTools)) + tools, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(tools) != 0 { + t.Fatalf("resolved tools = %v, want none", toolNames(tools)) } } @@ -871,7 +972,10 @@ func TestBuildAgentConfig_toolsList_includesRetrieverTools_hybrid(t *testing.T) if err != nil { t.Fatal(err) } - list := cfg.toolsList() + list, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } if len(list) != 2 { t.Fatalf("toolsList len = %d, want 2 (base tool + retriever tool)", len(list)) } @@ -953,14 +1057,18 @@ func TestBuildAgentConfig_WithA2AConfig(t *testing.T) { if err != nil { t.Fatal(err) } - if len(cfg.a2aTools) != 2 { - t.Fatalf("a2aTools len = %d, want 2", len(cfg.a2aTools)) + tools, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(tools) != 2 { + t.Fatalf("tools len = %d, want 2", len(tools)) } - if cfg.a2aTools[0].Name() != "a2a_agent_search" { - t.Errorf("tool[0].Name = %q, want a2a_agent_search", cfg.a2aTools[0].Name()) + if tools[0].Name() != "a2a_agent_search" { + t.Errorf("tool[0].Name = %q, want a2a_agent_search", tools[0].Name()) } - if cfg.a2aTools[1].Name() != 
"a2a_agent_summarize" { - t.Errorf("tool[1].Name = %q, want a2a_agent_summarize", cfg.a2aTools[1].Name()) + if tools[1].Name() != "a2a_agent_summarize" { + t.Errorf("tool[1].Name = %q, want a2a_agent_summarize", tools[1].Name()) } } @@ -981,8 +1089,12 @@ func TestBuildAgentConfig_WithA2AConfig_SkillFilter(t *testing.T) { if err != nil { t.Fatal(err) } - if len(cfg.a2aTools) != 1 || cfg.a2aTools[0].Name() != "a2a_agent_keep" { - t.Fatalf("a2aTools = %v, want [a2a_agent_keep]", cfg.a2aTools) + tools, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(tools) != 1 || tools[0].Name() != "a2a_agent_keep" { + t.Fatalf("tools = %v, want [a2a_agent_keep]", tools) } } @@ -1000,8 +1112,12 @@ func TestBuildAgentConfig_WithA2AClients(t *testing.T) { if err != nil { t.Fatal(err) } - if len(cfg.a2aTools) != 1 || cfg.a2aTools[0].Name() != "a2a_agent1_echo" { - t.Fatalf("a2aTools = %v, want [a2a_agent1_echo]", cfg.a2aTools) + tools, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(tools) != 1 || tools[0].Name() != "a2a_agent1_echo" { + t.Fatalf("tools = %v, want [a2a_agent1_echo]", tools) } } @@ -1167,7 +1283,7 @@ func TestBuildAgentConfig_A2A_duplicateClientName(t *testing.T) { WithA2AConfig(A2AServers{"dup": A2AConfig{URL: "http://127.0.0.1:1"}}), WithA2AClients(cl), }) - if err == nil || !strings.Contains(err.Error(), "duplicate a2a client name") { + if err == nil || !strings.Contains(err.Error(), "duplicate a2a client name") && !strings.Contains(err.Error(), "already exists") { t.Fatalf("got %v", err) } } @@ -1176,10 +1292,12 @@ func TestAgentConfig_toolsList_includesA2ATools(t *testing.T) { echo := mockTool{name: "echo"} a2aTool := NewA2ATool("agent1", interfaces.ToolSpec{Name: "search", Description: "d"}, interfaces.A2ASkillSpec{}, nil) c := &agentConfig{ - tools: []interfaces.Tool{echo}, - a2aTools: []interfaces.Tool{a2aTool}, + tools: []interfaces.Tool{echo}, + } + if err := 
c.buildToolRegistry(); err != nil { + t.Fatal(err) } - list := c.toolsList() + list := append(c.toolRegistry.List(), a2aTool) if len(list) != 2 { t.Fatalf("toolsList len = %d, want 2", len(list)) } @@ -1194,10 +1312,13 @@ func TestAgentConfig_toolsList_includesA2ATools(t *testing.T) { func TestAgentConfig_validateToolNames_A2AConflict(t *testing.T) { a2aTool := NewA2ATool("srv", interfaces.ToolSpec{Name: "s", Description: "d"}, interfaces.A2ASkillSpec{}, nil) c := &agentConfig{ - tools: []interfaces.Tool{mockTool{name: a2aTool.Name()}}, - a2aTools: []interfaces.Tool{a2aTool}, + tools: []interfaces.Tool{mockTool{name: a2aTool.Name()}}, } - err := c.validateToolNames() + if err := c.buildToolRegistry(); err != nil { + t.Fatal(err) + } + tools := append(c.toolRegistry.List(), a2aTool) + err := validateToolNames(tools) if err == nil || (!strings.Contains(err.Error(), "duplicate tool name") && !strings.Contains(err.Error(), "conflicts")) { t.Fatalf("want duplicate/conflict error, got %v", err) } @@ -1805,3 +1926,64 @@ func TestBuildAgentConfig_WithObservabilityConfig_customLogger_warnsAboutLogs(t t.Fatalf("expected warning about custom WithLogger and OTLP logs; buf=%q", out) } } + +func TestBuildAgentConfig_WithExplicitRegistryOptions(t *testing.T) { + toolReg := NewToolRegistry() + if err := toolReg.Register(mockTool{name: "native"}); err != nil { + t.Fatal(err) + } + mcpReg := NewMCPRegistry(nil) + if err := mcpReg.RegisterClient(®istryMockMCPClient{name: "mcp-srv"}); err != nil { + t.Fatal(err) + } + a2aReg := NewA2ARegistry(nil) + if err := a2aReg.RegisterClient(®istryMockA2AClient{name: "a2a-srv"}); err != nil { + t.Fatal(err) + } + subReg := NewSubAgentRegistry() + child := &Agent{agentConfig: agentConfig{Name: "Child", taskQueue: "q-child"}} + if err := child.buildRegistries(); err != nil { + t.Fatal(err) + } + if err := subReg.Register(child); err != nil { + t.Fatal(err) + } + + cfg, err := buildAgentConfig([]Option{ + WithName("parent"), + 
WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithToolRegistry(toolReg), + WithMCPRegistry(mcpReg), + WithA2ARegistry(a2aReg), + WithSubAgentRegistry(subReg), + WithToolApprovalPolicy(AutoToolApprovalPolicy()), + }) + if err != nil { + t.Fatal(err) + } + if cfg.toolRegistry != toolReg { + t.Fatal("WithToolRegistry should preserve user registry") + } + if cfg.mcpRegistry != mcpReg { + t.Fatal("WithMCPRegistry should preserve user registry") + } + if cfg.a2aRegistry != a2aReg { + t.Fatal("WithA2ARegistry should preserve user registry") + } + if cfg.subAgentRegistry != subReg { + t.Fatal("WithSubAgentRegistry should preserve user registry") + } + + a := &Agent{agentConfig: *cfg} + if a.ToolRegistry() != toolReg || a.MCPRegistry() != mcpReg || a.A2ARegistry() != a2aReg || a.SubAgentRegistry() != subReg { + t.Fatal("registry accessors should return configured registries") + } + tools, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(tools) < 2 { + t.Fatalf("resolveTools len = %d, want at least native + sub-agent tools", len(tools)) + } +} diff --git a/pkg/agent/mcp.go b/pkg/agent/mcp.go index 23c239a..8db4daa 100644 --- a/pkg/agent/mcp.go +++ b/pkg/agent/mcp.go @@ -19,6 +19,7 @@ var ( ) var _ interfaces.Tool = (*MCPTool)(nil) +var _ interfaces.ToolKindProvider = (*MCPTool)(nil) // NOTE: MCPTools for the same server share one MCPClient. The default pkg/mcp/client serializes // RPCs on that client with a mutex; custom MCPClient implementations should document concurrency behavior. @@ -62,6 +63,9 @@ func NewMCPTool(serverName string, spec interfaces.ToolSpec, client interfaces.M } } +// ToolKind implements [interfaces.ToolKindProvider]. +func (t *MCPTool) ToolKind() string { return "mcp" } + // Name implements interfaces.Tool. 
func (t *MCPTool) Name() string { if t == nil { diff --git a/pkg/agent/mcp_registry.go b/pkg/agent/mcp_registry.go new file mode 100644 index 0000000..10cc3e4 --- /dev/null +++ b/pkg/agent/mcp_registry.go @@ -0,0 +1,133 @@ +package agent + +import ( + "fmt" + "strings" + "sync" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + mcpclient "github.com/agenticenv/agent-sdk-go/pkg/mcp/client" +) + +var _ MCPRegistry = (*mcpRegistryImpl)(nil) + +type mcpRegistryImpl struct { + mu sync.RWMutex + logger logger.Logger + clients map[string]interfaces.MCPClient + order []string +} + +// NewMCPRegistry returns an empty MCP client registry for use with [WithMCPRegistry]. +// logger is used when [Register] builds a client from [MCPConfig]. +func NewMCPRegistry(l logger.Logger) MCPRegistry { + if l == nil { + l = NoopLogger() + } + return &mcpRegistryImpl{ + logger: l, + clients: make(map[string]interfaces.MCPClient), + } +} + +func (r *mcpRegistryImpl) Register(name string, config MCPConfig) error { + cl, err := newMCPClient(name, config, r.logger) + if err != nil { + return err + } + r.mu.Lock() + defer r.mu.Unlock() + return r.registerClientLocked(cl) +} + +func (r *mcpRegistryImpl) RegisterClient(client interfaces.MCPClient) error { + if client == nil { + return ErrRegistryNilEntry + } + name := strings.TrimSpace(client.Name()) + if name == "" { + return ErrRegistryInvalidName + } + r.mu.Lock() + defer r.mu.Unlock() + return r.registerClientLocked(client) +} + +func (r *mcpRegistryImpl) registerClientLocked(client interfaces.MCPClient) error { + name := strings.TrimSpace(client.Name()) + if name == "" { + return ErrRegistryInvalidName + } + if _, exists := r.clients[name]; exists { + return ErrRegistryDuplicate + } + r.order = append(r.order, name) + r.clients[name] = client + return nil +} + +func (r *mcpRegistryImpl) Unregister(name string) error { + name = strings.TrimSpace(name) + if name == "" { + return 
ErrRegistryInvalidName + } + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.clients[name]; !ok { + return ErrRegistryNotFound + } + delete(r.clients, name) + r.order = removeFromOrder(r.order, name) + return nil +} + +func (r *mcpRegistryImpl) Get(name string) (interfaces.MCPClient, error) { + name = strings.TrimSpace(name) + if name == "" { + return nil, ErrRegistryInvalidName + } + r.mu.RLock() + defer r.mu.RUnlock() + c, ok := r.clients[name] + if !ok { + return nil, ErrRegistryNotFound + } + return c, nil +} + +func (r *mcpRegistryImpl) List() []interfaces.MCPClient { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]interfaces.MCPClient, 0, len(r.order)) + for _, name := range r.order { + if c, ok := r.clients[name]; ok { + out = append(out, c) + } + } + return out +} + +func newMCPClient(name string, cfg MCPConfig, log logger.Logger) (interfaces.MCPClient, error) { + name = strings.TrimSpace(name) + if name == "" { + return nil, ErrRegistryInvalidName + } + if cfg.Transport == nil { + return nil, fmt.Errorf("mcp %q: Transport is required", name) + } + if log == nil { + log = NoopLogger() + } + mcpOpts := []mcpclient.Option{ + mcpclient.WithLogger(log), + mcpclient.WithTimeout(cfg.Timeout), + mcpclient.WithRetryAttempts(cfg.RetryAttempts), + mcpclient.WithToolFilter(cfg.ToolFilter), + } + cl, err := mcpclient.NewClient(name, cfg.Transport, mcpOpts...) 
+ if err != nil { + return nil, fmt.Errorf("mcp %q: new client: %w", name, err) + } + return cl, nil +} diff --git a/pkg/agent/mcp_registry_test.go b/pkg/agent/mcp_registry_test.go new file mode 100644 index 0000000..b521388 --- /dev/null +++ b/pkg/agent/mcp_registry_test.go @@ -0,0 +1,77 @@ +package agent + +import ( + "context" + "encoding/json" + "testing" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +type registryMockMCPClient struct { + name string +} + +func (c *registryMockMCPClient) Name() string { return c.name } +func (c *registryMockMCPClient) Ping(context.Context) error { + return nil +} +func (c *registryMockMCPClient) ListTools(context.Context) ([]interfaces.ToolSpec, error) { + return nil, nil +} +func (c *registryMockMCPClient) CallTool(context.Context, string, json.RawMessage) (json.RawMessage, error) { + return nil, nil +} +func (c *registryMockMCPClient) Close() error { return nil } + +func TestMCPRegistry_RegisterClient(t *testing.T) { + r := NewMCPRegistry(nil) + cl := ®istryMockMCPClient{name: "srv"} + if err := r.RegisterClient(cl); err != nil { + t.Fatal(err) + } + got, err := r.Get("srv") + if err != nil || got != cl { + t.Fatalf("Get(srv) = %v, %v", got, err) + } + if len(r.List()) != 1 { + t.Fatalf("List len = %d, want 1", len(r.List())) + } + if err := r.Unregister("srv"); err != nil { + t.Fatal(err) + } +} + +func TestMCPRegistry_RegisterConfigMissingTransport(t *testing.T) { + r := NewMCPRegistry(nil) + err := r.Register("bad", MCPConfig{}) + if err == nil { + t.Fatal("expected error for missing transport") + } +} + +func TestMCPRegistry_RegisterDuplicate(t *testing.T) { + r := NewMCPRegistry(nil) + cl := ®istryMockMCPClient{name: "srv"} + if err := r.RegisterClient(cl); err != nil { + t.Fatal(err) + } + if err := r.RegisterClient(®istryMockMCPClient{name: "srv"}); err != ErrRegistryDuplicate { + t.Errorf("duplicate RegisterClient err = %v, want ErrRegistryDuplicate", err) + } +} + +func 
TestNormalizeMCPRegistry_fromWithMCPClients(t *testing.T) { + cl := ®istryMockMCPClient{name: "srv"} + c := &agentConfig{mcpClients: []interfaces.MCPClient{cl}} + if err := c.buildMCPRegistry(); err != nil { + t.Fatal(err) + } + if c.mcpRegistry == nil { + t.Fatal("expected mcpRegistry after buildMCPRegistry") + } + got, err := c.mcpRegistry.Get("srv") + if err != nil || got != cl { + t.Fatalf("Get(srv) = %v, %v", got, err) + } +} diff --git a/pkg/agent/mocks/mock_registry.go b/pkg/agent/mocks/mock_registry.go new file mode 100644 index 0000000..92d4ec3 --- /dev/null +++ b/pkg/agent/mocks/mock_registry.go @@ -0,0 +1,361 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/agenticenv/agent-sdk-go/pkg/agent (interfaces: ToolRegistry,MCPRegistry,A2ARegistry,SubAgentRegistry) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + agent "github.com/agenticenv/agent-sdk-go/pkg/agent" + interfaces "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + gomock "github.com/golang/mock/gomock" +) + +// MockToolRegistry is a mock of ToolRegistry interface. +type MockToolRegistry struct { + ctrl *gomock.Controller + recorder *MockToolRegistryMockRecorder +} + +// MockToolRegistryMockRecorder is the mock recorder for MockToolRegistry. +type MockToolRegistryMockRecorder struct { + mock *MockToolRegistry +} + +// NewMockToolRegistry creates a new mock instance. +func NewMockToolRegistry(ctrl *gomock.Controller) *MockToolRegistry { + mock := &MockToolRegistry{ctrl: ctrl} + mock.recorder = &MockToolRegistryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockToolRegistry) EXPECT() *MockToolRegistryMockRecorder { + return m.recorder +} + +// Get mocks base method. 
+func (m *MockToolRegistry) Get(arg0 string) (interfaces.Tool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(interfaces.Tool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockToolRegistryMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockToolRegistry)(nil).Get), arg0) +} + +// List mocks base method. +func (m *MockToolRegistry) List() []interfaces.Tool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List") + ret0, _ := ret[0].([]interfaces.Tool) + return ret0 +} + +// List indicates an expected call of List. +func (mr *MockToolRegistryMockRecorder) List() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockToolRegistry)(nil).List)) +} + +// Register mocks base method. +func (m *MockToolRegistry) Register(arg0 interfaces.Tool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Register", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Register indicates an expected call of Register. +func (mr *MockToolRegistryMockRecorder) Register(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Register", reflect.TypeOf((*MockToolRegistry)(nil).Register), arg0) +} + +// Unregister mocks base method. +func (m *MockToolRegistry) Unregister(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Unregister", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Unregister indicates an expected call of Unregister. +func (mr *MockToolRegistryMockRecorder) Unregister(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unregister", reflect.TypeOf((*MockToolRegistry)(nil).Unregister), arg0) +} + +// MockMCPRegistry is a mock of MCPRegistry interface. 
+type MockMCPRegistry struct { + ctrl *gomock.Controller + recorder *MockMCPRegistryMockRecorder +} + +// MockMCPRegistryMockRecorder is the mock recorder for MockMCPRegistry. +type MockMCPRegistryMockRecorder struct { + mock *MockMCPRegistry +} + +// NewMockMCPRegistry creates a new mock instance. +func NewMockMCPRegistry(ctrl *gomock.Controller) *MockMCPRegistry { + mock := &MockMCPRegistry{ctrl: ctrl} + mock.recorder = &MockMCPRegistryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMCPRegistry) EXPECT() *MockMCPRegistryMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockMCPRegistry) Get(arg0 string) (interfaces.MCPClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(interfaces.MCPClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockMCPRegistryMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockMCPRegistry)(nil).Get), arg0) +} + +// List mocks base method. +func (m *MockMCPRegistry) List() []interfaces.MCPClient { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List") + ret0, _ := ret[0].([]interfaces.MCPClient) + return ret0 +} + +// List indicates an expected call of List. +func (mr *MockMCPRegistryMockRecorder) List() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockMCPRegistry)(nil).List)) +} + +// Register mocks base method. +func (m *MockMCPRegistry) Register(arg0 string, arg1 agent.MCPConfig) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Register", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Register indicates an expected call of Register. 
+func (mr *MockMCPRegistryMockRecorder) Register(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Register", reflect.TypeOf((*MockMCPRegistry)(nil).Register), arg0, arg1) +} + +// RegisterClient mocks base method. +func (m *MockMCPRegistry) RegisterClient(arg0 interfaces.MCPClient) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterClient", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterClient indicates an expected call of RegisterClient. +func (mr *MockMCPRegistryMockRecorder) RegisterClient(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterClient", reflect.TypeOf((*MockMCPRegistry)(nil).RegisterClient), arg0) +} + +// Unregister mocks base method. +func (m *MockMCPRegistry) Unregister(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Unregister", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Unregister indicates an expected call of Unregister. +func (mr *MockMCPRegistryMockRecorder) Unregister(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unregister", reflect.TypeOf((*MockMCPRegistry)(nil).Unregister), arg0) +} + +// MockA2ARegistry is a mock of A2ARegistry interface. +type MockA2ARegistry struct { + ctrl *gomock.Controller + recorder *MockA2ARegistryMockRecorder +} + +// MockA2ARegistryMockRecorder is the mock recorder for MockA2ARegistry. +type MockA2ARegistryMockRecorder struct { + mock *MockA2ARegistry +} + +// NewMockA2ARegistry creates a new mock instance. +func NewMockA2ARegistry(ctrl *gomock.Controller) *MockA2ARegistry { + mock := &MockA2ARegistry{ctrl: ctrl} + mock.recorder = &MockA2ARegistryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. 
+func (m *MockA2ARegistry) EXPECT() *MockA2ARegistryMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockA2ARegistry) Get(arg0 string) (interfaces.A2AClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(interfaces.A2AClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockA2ARegistryMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockA2ARegistry)(nil).Get), arg0) +} + +// List mocks base method. +func (m *MockA2ARegistry) List() []interfaces.A2AClient { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List") + ret0, _ := ret[0].([]interfaces.A2AClient) + return ret0 +} + +// List indicates an expected call of List. +func (mr *MockA2ARegistryMockRecorder) List() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockA2ARegistry)(nil).List)) +} + +// Register mocks base method. +func (m *MockA2ARegistry) Register(arg0 string, arg1 agent.A2AConfig) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Register", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Register indicates an expected call of Register. +func (mr *MockA2ARegistryMockRecorder) Register(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Register", reflect.TypeOf((*MockA2ARegistry)(nil).Register), arg0, arg1) +} + +// RegisterClient mocks base method. +func (m *MockA2ARegistry) RegisterClient(arg0 interfaces.A2AClient) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterClient", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterClient indicates an expected call of RegisterClient. 
+func (mr *MockA2ARegistryMockRecorder) RegisterClient(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterClient", reflect.TypeOf((*MockA2ARegistry)(nil).RegisterClient), arg0) +} + +// Unregister mocks base method. +func (m *MockA2ARegistry) Unregister(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Unregister", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Unregister indicates an expected call of Unregister. +func (mr *MockA2ARegistryMockRecorder) Unregister(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unregister", reflect.TypeOf((*MockA2ARegistry)(nil).Unregister), arg0) +} + +// MockSubAgentRegistry is a mock of SubAgentRegistry interface. +type MockSubAgentRegistry struct { + ctrl *gomock.Controller + recorder *MockSubAgentRegistryMockRecorder +} + +// MockSubAgentRegistryMockRecorder is the mock recorder for MockSubAgentRegistry. +type MockSubAgentRegistryMockRecorder struct { + mock *MockSubAgentRegistry +} + +// NewMockSubAgentRegistry creates a new mock instance. +func NewMockSubAgentRegistry(ctrl *gomock.Controller) *MockSubAgentRegistry { + mock := &MockSubAgentRegistry{ctrl: ctrl} + mock.recorder = &MockSubAgentRegistryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSubAgentRegistry) EXPECT() *MockSubAgentRegistryMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockSubAgentRegistry) Get(arg0 string) (*agent.Agent, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(*agent.Agent) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. 
+func (mr *MockSubAgentRegistryMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSubAgentRegistry)(nil).Get), arg0) +} + +// List mocks base method. +func (m *MockSubAgentRegistry) List() []*agent.Agent { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List") + ret0, _ := ret[0].([]*agent.Agent) + return ret0 +} + +// List indicates an expected call of List. +func (mr *MockSubAgentRegistryMockRecorder) List() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockSubAgentRegistry)(nil).List)) +} + +// Register mocks base method. +func (m *MockSubAgentRegistry) Register(arg0 *agent.Agent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Register", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Register indicates an expected call of Register. +func (mr *MockSubAgentRegistryMockRecorder) Register(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Register", reflect.TypeOf((*MockSubAgentRegistry)(nil).Register), arg0) +} + +// Unregister mocks base method. +func (m *MockSubAgentRegistry) Unregister(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Unregister", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Unregister indicates an expected call of Unregister. 
// MCPRegistry stores MCP clients for an agent.
//
// The notes on each method describe the implementation returned by
// [NewMCPRegistry]; names are compared after trimming surrounding whitespace.
type MCPRegistry interface {
	// Register builds a client for name from config and adds it.
	// It fails when the client cannot be built (for example, a missing
	// Transport) or when the name is blank or already registered.
	Register(name string, config MCPConfig) error
	// RegisterClient adds an already-constructed client under client.Name().
	// It returns ErrRegistryNilEntry for a nil client, ErrRegistryInvalidName
	// for a blank name and ErrRegistryDuplicate when the name is taken.
	RegisterClient(client interfaces.MCPClient) error
	// Get returns the client registered under name, or ErrRegistryNotFound.
	Get(name string) (interfaces.MCPClient, error)
	// List returns the registered clients in registration order.
	List() []interfaces.MCPClient
	// Unregister removes the client registered under name, or returns
	// ErrRegistryNotFound. The entry is only dropped; the client is not closed.
	Unregister(name string) error
}
// SubAgentRegistry stores sub-agents for a parent agent.
//
// The notes on each method describe the implementation returned by
// [NewSubAgentRegistry]; names are compared after trimming surrounding whitespace.
type SubAgentRegistry interface {
	// Register adds sub under its Name. It returns ErrRegistryNilEntry for a
	// nil agent, ErrRegistryInvalidName when the name is blank, and
	// ErrRegistryDuplicate when the name is already registered.
	Register(sub *Agent) error
	// Get returns the sub-agent registered under name, or ErrRegistryNotFound.
	Get(name string) (*Agent, error)
	// List returns the registered sub-agents in registration order.
	List() []*Agent
	// Unregister removes the sub-agent registered under name, or returns
	// ErrRegistryNotFound.
	Unregister(name string) error
}
func (t *RetrieverTool) Name() string { if t == nil { diff --git a/pkg/agent/runtime_factory.go b/pkg/agent/runtime_factory.go index 6c24c9f..6be46d0 100644 --- a/pkg/agent/runtime_factory.go +++ b/pkg/agent/runtime_factory.go @@ -14,7 +14,7 @@ func (cfg *agentConfig) buildTemporalRuntime(remoteWorker bool) (*temporal.Tempo options := []temporal.Option{ temporal.WithLogger(cfg.logger), temporal.WithAgentSpec(cfg.runtimeAgentSpec()), - temporal.WithAgentExecution(cfg.runtimeAgentExecution()), + temporal.WithAgentConfig(cfg.runtimeAgentConfig()), temporal.WithPolicyFingerprint(toolPolicyFingerprint(cfg.toolApprovalPolicy)), temporal.WithMCPFingerprint(mcpConfigFingerprint(cfg.mcpServers, mcpExtraClientNames(cfg.mcpClients))), temporal.WithA2AFingerprint(a2aConfigFingerprint(cfg.a2aServers, a2aExtraClientNames(cfg.a2aClients))), @@ -28,6 +28,7 @@ func (cfg *agentConfig) buildTemporalRuntime(remoteWorker bool) (*temporal.Tempo // Never allow fingerprint bypass on remote worker runtime. temporal.WithDisableFingerprintCheck(cfg.disableFingerprintCheck && !remoteWorker), temporal.WithRemoteWorker(remoteWorker), + temporal.WithToolsResolver(cfg.resolveTools), } if cfg.temporalConfig != nil { options = append(options, temporal.WithTemporalConfig(cfg.temporalConfig)) @@ -37,7 +38,6 @@ func (cfg *agentConfig) buildTemporalRuntime(remoteWorker bool) (*temporal.Tempo if cfg.instanceId != "" { options = append(options, temporal.WithInstanceId(cfg.instanceId)) } - // Event pipeline runs only on the client runtime; always set so worker runtimes get false explicitly. enableRemote := !remoteWorker && cfg.enableRemoteWorkers options = append(options, temporal.WithEnableRemoteWorkers(enableRemote)) return temporal.NewTemporalRuntime(options...) 
@@ -48,7 +48,7 @@ func (cfg *agentConfig) buildLocalRuntime() (*local.LocalRuntime, error) { local.WithLogger(cfg.logger), local.WithToolExecutionMode(cfg.agentToolExecutionMode), local.WithAgentSpec(cfg.runtimeAgentSpec()), - local.WithAgentExecution(cfg.runtimeAgentExecution()), + local.WithAgentConfig(cfg.runtimeAgentConfig()), local.WithTracer(cfg.tracer), local.WithMetrics(cfg.metrics), } diff --git a/pkg/agent/subagent.go b/pkg/agent/subagent.go index 781502c..6be6237 100644 --- a/pkg/agent/subagent.go +++ b/pkg/agent/subagent.go @@ -14,6 +14,7 @@ import ( var _ AgentTool = (*subAgentTool)(nil) var _ interfaces.Tool = (*subAgentTool)(nil) +var _ interfaces.ToolKindProvider = (*subAgentTool)(nil) // Sub-agent tool names must be identifier-like for LLM tool APIs; normalize display names accordingly. var subAgentToolNameNonIdent = regexp.MustCompile(`[^a-zA-Z0-9]+`) @@ -32,9 +33,8 @@ var ErrSubAgentNameInvalid = errors.New("sub-agent name invalid for delegation t // AgentTool marks a tool that represents sub-agent delegation (child AgentWorkflow), not normal Tool.Execute. // -// AgentWorkflow chooses delegation vs AgentToolExecuteActivity using SubAgentRoutes keyed by tool name, not by -// asserting AgentTool in workflow code. AgentTool is still used elsewhere (e.g. toolApprovalMetadata walks -// toolsList and asserts AgentTool to set delegation fields on approval events). +// The runtime routes tool calls to sub-agents using sub-agent routes keyed by tool name. +// AgentTool is also used when building approval metadata for delegation tools. type AgentTool interface { interfaces.Tool // SubAgent returns the sub-agent this tool delegates to. @@ -118,3 +118,6 @@ func (t *subAgentTool) Execute(_ context.Context, _ map[string]any) (any, error) } func (t *subAgentTool) SubAgent() *Agent { return t.agent } + +// ToolKind implements [interfaces.ToolKindProvider]. 
+func (t *subAgentTool) ToolKind() string { return "sub-agent" } diff --git a/pkg/agent/subagent_registry.go b/pkg/agent/subagent_registry.go new file mode 100644 index 0000000..038c721 --- /dev/null +++ b/pkg/agent/subagent_registry.go @@ -0,0 +1,80 @@ +package agent + +import ( + "strings" + "sync" +) + +var _ SubAgentRegistry = (*subAgentRegistryImpl)(nil) + +type subAgentRegistryImpl struct { + mu sync.RWMutex + agents map[string]*Agent + order []string +} + +// NewSubAgentRegistry returns an empty sub-agent registry for use with [WithSubAgentRegistry]. +func NewSubAgentRegistry() SubAgentRegistry { + return &subAgentRegistryImpl{ + agents: make(map[string]*Agent), + } +} + +func (r *subAgentRegistryImpl) Register(sub *Agent) error { + if sub == nil { + return ErrRegistryNilEntry + } + name := strings.TrimSpace(sub.Name) + if name == "" { + return ErrRegistryInvalidName + } + r.mu.Lock() + defer r.mu.Unlock() + if _, exists := r.agents[name]; exists { + return ErrRegistryDuplicate + } + r.order = append(r.order, name) + r.agents[name] = sub + return nil +} + +func (r *subAgentRegistryImpl) Unregister(name string) error { + name = strings.TrimSpace(name) + if name == "" { + return ErrRegistryInvalidName + } + r.mu.Lock() + defer r.mu.Unlock() + if _, ok := r.agents[name]; !ok { + return ErrRegistryNotFound + } + delete(r.agents, name) + r.order = removeFromOrder(r.order, name) + return nil +} + +func (r *subAgentRegistryImpl) Get(name string) (*Agent, error) { + name = strings.TrimSpace(name) + if name == "" { + return nil, ErrRegistryInvalidName + } + r.mu.RLock() + defer r.mu.RUnlock() + a, ok := r.agents[name] + if !ok { + return nil, ErrRegistryNotFound + } + return a, nil +} + +func (r *subAgentRegistryImpl) List() []*Agent { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]*Agent, 0, len(r.order)) + for _, name := range r.order { + if a, ok := r.agents[name]; ok { + out = append(out, a) + } + } + return out +} diff --git 
a/pkg/agent/subagent_registry_test.go b/pkg/agent/subagent_registry_test.go new file mode 100644 index 0000000..1250d3b --- /dev/null +++ b/pkg/agent/subagent_registry_test.go @@ -0,0 +1,49 @@ +package agent + +import "testing" + +func TestSubAgentRegistry_RegisterGet(t *testing.T) { + r := NewSubAgentRegistry() + sub := &Agent{agentConfig: agentConfig{Name: "Math"}} + if err := r.Register(sub); err != nil { + t.Fatal(err) + } + + got, err := r.Get("Math") + if err != nil || got != sub { + t.Fatalf("Get(Math) = %v, %v", got, err) + } + if len(r.List()) != 1 { + t.Fatalf("List len = %d, want 1", len(r.List())) + } + if err := r.Unregister("Math"); err != nil { + t.Fatal(err) + } +} + +func TestSubAgentRegistry_RegisterNilAndEmptyName(t *testing.T) { + r := NewSubAgentRegistry() + if err := r.Register(nil); err != ErrRegistryNilEntry { + t.Errorf("Register(nil) err = %v", err) + } + if err := r.Register(&Agent{agentConfig: agentConfig{Name: " "}}); err != ErrRegistryInvalidName { + t.Errorf("empty name err = %v", err) + } + if len(r.List()) != 0 { + t.Fatalf("List len = %d, want 0", len(r.List())) + } +} + +func TestNormalizeSubAgentRegistry_fromWithSubAgents(t *testing.T) { + sub := &Agent{agentConfig: agentConfig{Name: "Helper"}} + c := &agentConfig{subAgents: []*Agent{sub}, maxSubAgentDepth: 3} + if err := c.buildSubAgentRegistry(); err != nil { + t.Fatal(err) + } + if len(c.subAgents) != 0 { + t.Fatal("WithSubAgents entries should be cleared after buildSubAgentRegistry") + } + if len(c.subAgentRegistry.List()) != 1 { + t.Fatalf("registry len = %d, want 1", len(c.subAgentRegistry.List())) + } +} diff --git a/pkg/agent/tool_registry.go b/pkg/agent/tool_registry.go new file mode 100644 index 0000000..6d55e57 --- /dev/null +++ b/pkg/agent/tool_registry.go @@ -0,0 +1,102 @@ +package agent + +import ( + "strings" + "sync" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +var _ ToolRegistry = (*toolRegistryImpl)(nil) + +type toolRegistryImpl struct { + mu 
// removeFromOrder deletes the first element equal to name from order and
// returns the shortened slice. The removal is done in place: later elements
// are shifted left inside the same backing array, so callers must use the
// returned slice. When name is absent, order is returned unchanged.
func removeFromOrder(order []string, name string) []string {
	idx := -1
	for i := range order {
		if order[i] == name {
			idx = i
			break
		}
	}
	if idx < 0 {
		return order
	}
	// Slide the tail over the removed slot, then trim the now-duplicated last element.
	copy(order[idx:], order[idx+1:])
	return order[:len(order)-1]
}
+func RegisterTools(reg ToolRegistry, tools ...interfaces.Tool) error { + for _, t := range tools { + if err := reg.Register(t); err != nil { + return err + } + } + return nil +} diff --git a/pkg/agent/tool_registry_test.go b/pkg/agent/tool_registry_test.go new file mode 100644 index 0000000..4646a1c --- /dev/null +++ b/pkg/agent/tool_registry_test.go @@ -0,0 +1,135 @@ +package agent + +import ( + "context" + "testing" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +type registryMockTool struct { + name string +} + +func (m registryMockTool) Name() string { return m.name } +func (m registryMockTool) DisplayName() string { return "Mock" } +func (m registryMockTool) Description() string { return "mock" } +func (m registryMockTool) Parameters() interfaces.JSONSchema { return interfaces.JSONSchema{} } +func (m registryMockTool) Execute(ctx context.Context, args map[string]any) (any, error) { + return nil, nil +} + +func TestNewToolRegistry(t *testing.T) { + r := NewToolRegistry() + if r == nil { + t.Fatal("NewToolRegistry should not return nil") + } + if len(r.List()) != 0 { + t.Errorf("new registry should have 0 tools, got %d", len(r.List())) + } +} + +func TestToolRegistry_RegisterGet(t *testing.T) { + r := NewToolRegistry() + if err := r.Register(registryMockTool{name: "mock1"}); err != nil { + t.Fatal(err) + } + + tool, err := r.Get("mock1") + if err != nil || tool == nil { + t.Fatalf("Get(mock1) = %v, %v", tool, err) + } + if tool.Name() != "mock1" { + t.Errorf("tool.Name() = %q, want mock1", tool.Name()) + } + + if _, err := r.Get("nonexistent"); err != ErrRegistryNotFound { + t.Errorf("Get(nonexistent) err = %v, want ErrRegistryNotFound", err) + } +} + +func TestToolRegistry_RegisterDuplicate(t *testing.T) { + r := NewToolRegistry() + if err := r.Register(registryMockTool{name: "mock1"}); err != nil { + t.Fatal(err) + } + if err := r.Register(registryMockTool{name: "mock1"}); err != ErrRegistryDuplicate { + t.Errorf("second Register err = %v, want 
ErrRegistryDuplicate", err) + } +} + +func TestToolRegistry_RegisterNil(t *testing.T) { + r := NewToolRegistry() + if err := r.Register(nil); err != ErrRegistryNilEntry { + t.Errorf("Register(nil) err = %v, want ErrRegistryNilEntry", err) + } +} + +func TestToolRegistry_Unregister(t *testing.T) { + r := NewToolRegistry() + if err := r.Register(registryMockTool{name: "a"}); err != nil { + t.Fatal(err) + } + if err := r.Unregister("a"); err != nil { + t.Fatalf("Unregister(a) err = %v", err) + } + if _, err := r.Get("a"); err != ErrRegistryNotFound { + t.Error("Get(a) after Unregister should be ErrRegistryNotFound") + } + if err := r.Unregister("missing"); err != ErrRegistryNotFound { + t.Errorf("Unregister(missing) err = %v, want ErrRegistryNotFound", err) + } +} + +func TestToolRegistry_ListOrder(t *testing.T) { + r := NewToolRegistry() + _ = r.Register(registryMockTool{name: "a"}) + _ = r.Register(registryMockTool{name: "b"}) + + tools := r.List() + if len(tools) != 2 { + t.Fatalf("got %d tools, want 2", len(tools)) + } + if tools[0].Name() != "a" || tools[1].Name() != "b" { + t.Errorf("List order = %q, %q; want a, b", tools[0].Name(), tools[1].Name()) + } +} + +func TestRegisterTools(t *testing.T) { + r := NewToolRegistry() + if err := RegisterTools(r, registryMockTool{name: "a"}, registryMockTool{name: "b"}); err != nil { + t.Fatal(err) + } + if len(r.List()) != 2 { + t.Fatalf("List len = %d, want 2", len(r.List())) + } +} + +func TestNormalizeToolRegistry_fromWithTools(t *testing.T) { + c := &agentConfig{tools: []interfaces.Tool{registryMockTool{name: "a"}}} + if err := c.buildToolRegistry(); err != nil { + t.Fatal(err) + } + if len(c.tools) != 0 { + t.Fatal("WithTools entries should be cleared after buildToolRegistry") + } + if len(c.toolRegistry.List()) != 1 { + t.Fatalf("registry len = %d, want 1", len(c.toolRegistry.List())) + } +} + +func TestNormalizeToolRegistry_userRegistryWins(t *testing.T) { + reg := NewToolRegistry() + _ = 
reg.Register(registryMockTool{name: "existing"}) + c := &agentConfig{ + toolRegistry: reg, + tools: []interfaces.Tool{registryMockTool{name: "from_with_tools"}}, + } + if err := c.buildToolRegistry(); err != nil { + t.Fatal(err) + } + tools := c.toolRegistry.List() + if len(tools) != 2 { + t.Fatalf("registry len = %d, want 2", len(tools)) + } +} diff --git a/pkg/interfaces/mocks/mock_tool.go b/pkg/interfaces/mocks/mock_tool.go index 1fa7263..068c279 100644 --- a/pkg/interfaces/mocks/mock_tool.go +++ b/pkg/interfaces/mocks/mock_tool.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/agenticenv/agent-sdk-go/pkg/interfaces (interfaces: Tool,ToolRegistry,ToolApproval,ToolAuthorizer) +// Source: github.com/agenticenv/agent-sdk-go/pkg/interfaces (interfaces: Tool,ToolApproval,ToolAuthorizer,ToolKindProvider) // Package mocks is a generated GoMock package. package mocks @@ -107,70 +107,6 @@ func (mr *MockToolMockRecorder) Parameters() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Parameters", reflect.TypeOf((*MockTool)(nil).Parameters)) } -// MockToolRegistry is a mock of ToolRegistry interface. -type MockToolRegistry struct { - ctrl *gomock.Controller - recorder *MockToolRegistryMockRecorder -} - -// MockToolRegistryMockRecorder is the mock recorder for MockToolRegistry. -type MockToolRegistryMockRecorder struct { - mock *MockToolRegistry -} - -// NewMockToolRegistry creates a new mock instance. -func NewMockToolRegistry(ctrl *gomock.Controller) *MockToolRegistry { - mock := &MockToolRegistry{ctrl: ctrl} - mock.recorder = &MockToolRegistryMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockToolRegistry) EXPECT() *MockToolRegistryMockRecorder { - return m.recorder -} - -// Get mocks base method. 
-func (m *MockToolRegistry) Get(arg0 string) (interfaces.Tool, bool) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0) - ret0, _ := ret[0].(interfaces.Tool) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// Get indicates an expected call of Get. -func (mr *MockToolRegistryMockRecorder) Get(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockToolRegistry)(nil).Get), arg0) -} - -// Register mocks base method. -func (m *MockToolRegistry) Register(arg0 interfaces.Tool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Register", arg0) -} - -// Register indicates an expected call of Register. -func (mr *MockToolRegistryMockRecorder) Register(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Register", reflect.TypeOf((*MockToolRegistry)(nil).Register), arg0) -} - -// Tools mocks base method. -func (m *MockToolRegistry) Tools() []interfaces.Tool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Tools") - ret0, _ := ret[0].([]interfaces.Tool) - return ret0 -} - -// Tools indicates an expected call of Tools. -func (mr *MockToolRegistryMockRecorder) Tools() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tools", reflect.TypeOf((*MockToolRegistry)(nil).Tools)) -} - // MockToolApproval is a mock of ToolApproval interface. type MockToolApproval struct { ctrl *gomock.Controller @@ -245,3 +181,40 @@ func (mr *MockToolAuthorizerMockRecorder) Authorize(arg0, arg1 interface{}) *gom mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authorize", reflect.TypeOf((*MockToolAuthorizer)(nil).Authorize), arg0, arg1) } + +// MockToolKindProvider is a mock of ToolKindProvider interface. 
+type MockToolKindProvider struct { + ctrl *gomock.Controller + recorder *MockToolKindProviderMockRecorder +} + +// MockToolKindProviderMockRecorder is the mock recorder for MockToolKindProvider. +type MockToolKindProviderMockRecorder struct { + mock *MockToolKindProvider +} + +// NewMockToolKindProvider creates a new mock instance. +func NewMockToolKindProvider(ctrl *gomock.Controller) *MockToolKindProvider { + mock := &MockToolKindProvider{ctrl: ctrl} + mock.recorder = &MockToolKindProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockToolKindProvider) EXPECT() *MockToolKindProviderMockRecorder { + return m.recorder +} + +// ToolKind mocks base method. +func (m *MockToolKindProvider) ToolKind() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ToolKind") + ret0, _ := ret[0].(string) + return ret0 +} + +// ToolKind indicates an expected call of ToolKind. +func (mr *MockToolKindProviderMockRecorder) ToolKind() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ToolKind", reflect.TypeOf((*MockToolKindProvider)(nil).ToolKind)) +} diff --git a/pkg/interfaces/tool.go b/pkg/interfaces/tool.go index 0127ed0..b8aa00d 100644 --- a/pkg/interfaces/tool.go +++ b/pkg/interfaces/tool.go @@ -6,7 +6,7 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/types" ) -//go:generate mockgen -destination=./mocks/mock_tool.go -package=mocks github.com/agenticenv/agent-sdk-go/pkg/interfaces Tool,ToolRegistry,ToolApproval,ToolAuthorizer +//go:generate mockgen -destination=./mocks/mock_tool.go -package=mocks github.com/agenticenv/agent-sdk-go/pkg/interfaces Tool,ToolApproval,ToolAuthorizer,ToolKindProvider // ToolApproval is an optional interface for tools that require interactive human approval before execution. // When implemented, the agent honors ApprovalRequired() when no agent-level approval policy is set. 
@@ -74,14 +74,17 @@ func ToolsToSpecs(tools []Tool) []ToolSpec { return specs } -// ToolRegistry manages a collection of tools. Use for registering and looking up tools by name. -type ToolRegistry interface { - // Register adds a tool. Overwrites if a tool with the same name exists. - Register(tool Tool) - - // Get returns the tool by name, or (nil, false) if not found. - Get(name string) (Tool, bool) +// ToolKindProvider is an optional interface for tools that report their origin. +type ToolKindProvider interface { + ToolKind() string +} - // Tools returns all registered tools in registration order. - Tools() []Tool +// KindOf returns t's ToolKind() when t implements ToolKindProvider and the kind is non-empty, or "native" otherwise. +func KindOf(t Tool) string { + if k, ok := t.(ToolKindProvider); ok { + if s := k.ToolKind(); s != "" { + return s + } + } + return "native" } diff --git a/pkg/interfaces/tool_test.go b/pkg/interfaces/tool_test.go new file mode 100644 index 0000000..73eda79 --- /dev/null +++ b/pkg/interfaces/tool_test.go @@ -0,0 +1,40 @@ +package interfaces + +import ( + "context" + "testing" +) + +type stubKindTool struct{ kind string } + +func (s stubKindTool) ToolKind() string { return s.kind } +func (stubKindTool) Name() string { return "x" } +func (stubKindTool) DisplayName() string { return "x" } +func (stubKindTool) Description() string { return "" } +func (stubKindTool) Parameters() JSONSchema { return JSONSchema{"type": "object"} } +func (stubKindTool) Execute(_ context.Context, _ map[string]any) (any, error) { return nil, nil } + +type stubNativeTool struct{} + +func (stubNativeTool) Name() string { return "n" } +func (stubNativeTool) DisplayName() string { return "n" } +func (stubNativeTool) Description() string { return "" } +func (stubNativeTool) Parameters() JSONSchema { return JSONSchema{"type": "object"} } +func (stubNativeTool) Execute(_ context.Context, _ map[string]any) (any, error) { + return nil, nil +} + +func TestKindOf(t *testing.T) { + if KindOf(nil) != "native" { +
t.Fatalf("nil = %q", KindOf(nil)) + } + if KindOf(stubNativeTool{}) != "native" { + t.Fatal("native tool without provider") + } + if KindOf(stubKindTool{kind: "mcp"}) != "mcp" { + t.Fatal("mcp kind") + } + if KindOf(stubKindTool{kind: ""}) != "native" { + t.Fatal("empty kind falls back to native") + } +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go deleted file mode 100644 index 12c9042..0000000 --- a/pkg/tools/registry.go +++ /dev/null @@ -1,59 +0,0 @@ -package tools - -import ( - "sync" - - "github.com/agenticenv/agent-sdk-go/pkg/interfaces" -) - -var _ interfaces.ToolRegistry = (*Registry)(nil) - -// Registry is an in-memory ToolRegistry implementation. -type Registry struct { - mu sync.RWMutex - tools map[string]interfaces.Tool - order []string // preserve registration order for Tools() -} - -// NewRegistry returns a new empty ToolRegistry. -func NewRegistry() *Registry { - return &Registry{ - tools: make(map[string]interfaces.Tool), - order: nil, - } -} - -// Register adds a tool. Overwrites if a tool with the same name exists. -func (r *Registry) Register(tool interfaces.Tool) { - if tool == nil { - return - } - r.mu.Lock() - defer r.mu.Unlock() - name := tool.Name() - if _, exists := r.tools[name]; !exists { - r.order = append(r.order, name) - } - r.tools[name] = tool -} - -// Get returns the tool by name, or (nil, false) if not found. -func (r *Registry) Get(name string) (interfaces.Tool, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - t, ok := r.tools[name] - return t, ok -} - -// Tools returns all registered tools in registration order. 
-func (r *Registry) Tools() []interfaces.Tool { - r.mu.RLock() - defer r.mu.RUnlock() - result := make([]interfaces.Tool, 0, len(r.order)) - for _, name := range r.order { - if t, ok := r.tools[name]; ok { - result = append(result, t) - } - } - return result -} diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go deleted file mode 100644 index 6c72c98..0000000 --- a/pkg/tools/registry_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package tools - -import ( - "context" - "testing" - - "github.com/agenticenv/agent-sdk-go/pkg/interfaces" -) - -// mockTool for registry tests (avoids import cycle with calculator/echo). -type mockTool struct { - name string -} - -func (m mockTool) Name() string { return m.name } -func (m mockTool) DisplayName() string { return "Mock" } -func (m mockTool) Description() string { return "mock" } -func (m mockTool) Parameters() interfaces.JSONSchema { return interfaces.JSONSchema{} } -func (m mockTool) Execute(ctx context.Context, args map[string]any) (any, error) { - return nil, nil -} - -func TestNewRegistry(t *testing.T) { - r := NewRegistry() - if r == nil { - t.Fatal("NewRegistry should not return nil") - } - tools := r.Tools() - if len(tools) != 0 { - t.Errorf("new registry should have 0 tools, got %d", len(tools)) - } -} - -func TestRegistry_RegisterGet(t *testing.T) { - r := NewRegistry() - mt := mockTool{name: "mock1"} - r.Register(mt) - - tool, ok := r.Get("mock1") - if !ok || tool == nil { - t.Fatal("Get(mock1) should return tool") - } - if tool.Name() != "mock1" { - t.Errorf("tool.Name() = %q, want mock1", tool.Name()) - } - - _, ok = r.Get("nonexistent") - if ok { - t.Error("Get(nonexistent) should return false") - } -} - -func TestRegistry_RegisterOverwrite(t *testing.T) { - r := NewRegistry() - mt := mockTool{name: "mock1"} - r.Register(mt) - r.Register(mt) // same tool again - - tools := r.Tools() - if len(tools) != 1 { - t.Errorf("overwrite should keep 1 tool, got %d", len(tools)) - } -} - -func 
TestRegistry_RegisterNil(t *testing.T) { - r := NewRegistry() - r.Register(nil) - - tools := r.Tools() - if len(tools) != 0 { - t.Error("Register(nil) should be ignored") - } -} - -func TestRegistry_ToolsOrder(t *testing.T) { - r := NewRegistry() - r.Register(mockTool{name: "a"}) - r.Register(mockTool{name: "b"}) - r.Register(mockTool{name: "a"}) // overwrite a - - tools := r.Tools() - if len(tools) != 2 { - t.Fatalf("got %d tools, want 2", len(tools)) - } - if tools[0].Name() != "a" || tools[1].Name() != "b" { - t.Errorf("Tools order = %q, %q; want a, b", tools[0].Name(), tools[1].Name()) - } -} diff --git a/pkg/tools/schema.go b/pkg/tools/schema.go index 4374682..028a7af 100644 --- a/pkg/tools/schema.go +++ b/pkg/tools/schema.go @@ -1,4 +1,4 @@ -// Package tools provides ToolRegistry implementation and schema helpers for building +// Package tools provides schema helpers for building // JSON Schema parameter definitions in a type-safe way. // // Example: diff --git a/taskfiles/examples.yml b/taskfiles/examples.yml index 8d60104..d8839d5 100644 --- a/taskfiles/examples.yml +++ b/taskfiles/examples.yml @@ -193,6 +193,7 @@ tasks: - simple_agent - agent_with_tools/basic - agent_with_tools/custom + - agent_with_tools/dynamic_registry - agent_with_tools/authorizer - agent_with_json_response - agent_with_stream