diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8f0972b..97792d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -49,3 +49,31 @@ jobs: - name: Build run: make build + + # examples:local — off until if: true and EXAMPLES_* repo secrets (examples/.env.defaults). + examples: + if: false + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Install Task + run: go install github.com/go-task/task/v3/cmd/task@latest + + - name: Configure examples/.env + run: | + cat > examples/.env < **Most agent frameworks live and die inside a single process: if your server restarts, the run is lost.** Here, every agent run is a Temporal workflow end to end. Runs survive crashes and deploys, respect timeouts and retries, and are observable as real service operations. There is no execution path outside Temporal. +Every agent run on the **Temporal runtime** is a durable workflow: it survives process crashes and deploys, supports horizontal scaling, and is observable as a real service operation. This is the recommended runtime for production workloads where run durability matters. A **running Temporal server** is required. -`pkg/agent` exposes three entry points — `Run`, `Stream`, and `RunAsync` — each mapped directly to a Temporal workflow. Connect via `WithTemporalConfig` or `WithTemporalClient` to your cluster. See [Getting Started](#getting-started) to set up, or [Temporal Runtime](#temporal-runtime) for deeper detail on workers, queues, and streaming. +The **in-process runtime** runs the agent loop directly in your process with no external dependencies — ideal for development, testing, and deployments where Temporal is not available. It has full feature parity: tools, MCP, A2A, sub-agents, streaming, AG-UI, approvals, conversation, and observability. 
-## Capabilities +`pkg/agent` exposes three entry points — `Run`, `Stream`, and `RunAsync`. Add `WithTemporalConfig` or `WithTemporalClient` for the Temporal runtime; omit it for in-process. See [Getting Started](#getting-started) or [Runtimes](#runtimes). -> Every agent run is a Temporal workflow: durable, replay-safe, and observable. No in-memory execution path. +## Capabilities - **LLM providers** — OpenAI, Anthropic, and Gemini out of the box; bring your own via `interfaces.LLMClient`. - **Tools** — Register built-in or custom tools via `interfaces.Tool`; optional **parallel vs sequential** execution for multiple tool calls in one LLM round (`WithAgentToolExecutionMode`). - **Human-in-the-loop** — Approval gates on tool calls and delegation across `Run`, `RunAsync`, and `Stream`. -- **Conversation** — Persist multi-turn message history across runs via `WithConversation`; built-in in-memory and Redis stores, or bring your own. -- **Sub-agents** — Delegate to specialist agents via `WithSubAgents`. +- **Conversation** — Persist multi-turn message history via `WithConversation`; in-memory store for in-process; in-memory and Redis for Temporal. Bring your own via `interfaces.Conversation`. +- **Sub-agents** — Delegate to specialist agents via `WithSubAgents`; recursive delegation with depth limiting; all sub-agent events fan in to the parent stream on both runtimes. - **MCP** — Extend agent capabilities by connecting any MCP server as a tool source via `WithMCPConfig` or `WithMCPClients`. - **A2A** — Connect remote [Agent-to-Agent](https://github.com/a2aproject/A2A) agents as tool providers via `WithA2AConfig` or `WithA2AClients`; or expose the agent itself as an A2A server via `WithA2ADefaultServer` / `WithA2AServer` and `RunA2A`. - **Retrieval (RAG)** — Ground agent responses in external knowledge bases via a pluggable `Retriever` interface with built-in Weaviate and pgvector support; extend with your own implementation. 
@@ -79,8 +79,11 @@ - **AG-UI** — Stream events conform to the [AG-UI protocol](https://docs.ag-ui.com); agents work out of the box with any AG-UI compatible frontend such as [CopilotKit](https://copilotkit.ai). - **Reasoning** — Extended thinking / chain-of-thought where supported (Anthropic, Gemini). - **Token usage** — Track input, output, and reasoning token counts per run. -- **Scale** — Add Temporal workers to scale agent execution horizontally. -- **Observability** — OpenTelemetry traces, metrics, and structured logs across all agent execution paths; export to any OTLP-compatible backend. +- **Observability** — OpenTelemetry traces, metrics, and structured logs; export to any OTLP-compatible backend. +- **Durable execution** ★ — Runs survive process crashes and restarts; Temporal workflow history ensures no step is lost. +- **Scale** ★ — Add Temporal workers to scale agent execution horizontally; split agent and worker across separate processes. + +> ★ Temporal runtime only. ## Reference apps @@ -88,7 +91,9 @@ Demo applications that use **agent-sdk-go** end-to-end: - **[Agent Chat](https://github.com/agenticenv/agent-chat)** — Web chat demo with durable conversations; a good reference for wiring the SDK into an HTTP-backed app. -## Temporal Runtime +## Runtimes + +### Temporal **Temporal** powers agents through three moving parts: a **Temporal client** that launches agent workflows, **workers** (typically `NewAgentWorker`) that poll task queues and execute workflow and activity code, and **workflow history** that makes each run durable. Workers are stateless — they replay and advance history, not hold state themselves. @@ -130,13 +135,29 @@ Stream events and approval events cross two boundaries: **Temporal** (durable wo - **Your responsibility.** Keep worker processes supervised and restarting on crash, maintain a stable connection to your Temporal cluster, and ensure stream subscribers can reconnect. 
- **Client reconnection and UX.** For interactive apps, if the process serving `Stream` crashes, the workflow continues in Temporal but your client loses the connection. Once a stream is lost, reconnecting to that specific run is not supported — the recommended approach is to block the user from sending a new prompt until the current one completes, then fetch the final response and display it. This keeps conversation turns sequential and avoids out-of-order state. For autonomous agents, this is a non-issue since the caller waits for completion and the workflow finishes regardless. +### In-Process + +Runs the agent loop directly in your process — zero setup, zero infrastructure. + +**When to use:** +- Development, testing, and prototyping +- Deployments where Temporal is not needed +- Short-lived runs where crash recovery is not required + +All capabilities listed above apply except ★ items. Conversation uses the in-memory store only. If the process crashes, the run is lost — no replay, no remote workers. + +**Switching to Temporal:** add `WithTemporalConfig` (or `WithTemporalClient`) to any `NewAgent` call — no other code changes required. + ## Getting Started How to **use** the SDK—agents, LLMs, Temporal connection, examples. ### Prerequisites -**agent-sdk-go** runs agents on the **[Temporal](https://temporal.io)** runtime (durable workflows and activities), so a **running Temporal server** is required. See **[Temporal setup](temporal-setup.md)**. Also **Go 1.26+** (see `go.mod`) and credentials for your LLM provider. +**Go 1.26+** (see `go.mod`) and credentials for your LLM provider are always required. + +- **In-process runtime** (default): no additional setup — just `go get` and an LLM API key. +- **Temporal runtime**: a running Temporal server is required. See **[Temporal setup](temporal-setup.md)**. 
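
The choice between the two runtimes listed above comes down to a single switch at startup. As a minimal sketch (standard library only, not SDK code), the following mirrors the normalization that `UseTemporalRuntime` in `cmd/config.go` applies to the `runtime` setting: the value is trimmed and lower-cased, and anything other than `temporal` falls back to the in-process runtime. The names `useTemporal` and `runtimeFromEnv`, and the `lookup` parameter, are hypothetical and exist only for illustration.

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

// useTemporal reports whether a runtime value selects the Temporal backend.
// Whitespace and letter case are ignored; every other value means "local".
func useTemporal(runtime string) bool {
	return strings.ToLower(strings.TrimSpace(runtime)) == "temporal"
}

// runtimeFromEnv is a hypothetical helper: it reads AGENT_RUNTIME through the
// supplied lookup function and returns either "temporal" or "local".
func runtimeFromEnv(lookup func(string) (string, bool)) string {
	if v, ok := lookup("AGENT_RUNTIME"); ok && useTemporal(v) {
		return "temporal"
	}
	return "local"
}

func main() {
	// With AGENT_RUNTIME unset this reports the in-process default.
	fmt.Println("runtime:", runtimeFromEnv(os.LookupEnv))
}
```

Passing the lookup function in, rather than calling `os.LookupEnv` directly, keeps the decision easy to exercise in tests without touching the real process environment.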
**Module:** `github.com/agenticenv/agent-sdk-go` @@ -146,6 +167,8 @@ go get github.com/agenticenv/agent-sdk-go@latest ### Create an agent and run +**Local runtime** (no Temporal required): + ```go import ( "github.com/agenticenv/agent-sdk-go/pkg/agent" @@ -158,6 +181,19 @@ llmClient, _ := openai.NewClient( llm.WithModel("gpt-4o"), ) +a, _ := agent.NewAgent( + agent.WithSystemPrompt("You are a helpful assistant."), + agent.WithLLMClient(llmClient), +) +defer a.Close() + +result, err := a.Run(ctx, "Hello", "") +// result.Content, result.AgentName, result.Model +``` + +**Temporal runtime** (durable execution): + +```go a, _ := agent.NewAgent( agent.WithTemporalConfig(&agent.TemporalConfig{ Host: "localhost", Port: 7233, @@ -169,7 +205,6 @@ a, _ := agent.NewAgent( defer a.Close() result, err := a.Run(ctx, "Hello", "") -// result.Content, result.AgentName, result.Model ``` [examples/simple_agent](examples/simple_agent) @@ -426,7 +461,7 @@ defer a.Close() You may use **Option 1** for some servers and **Option 2** for others on the same agent; keep server names unique across both. -[examples/agent_with_mcp_config](examples/agent_with_mcp_config) and [examples/agent_with_mcp_client](examples/agent_with_mcp_client) show MCP from env (`stdio` or streamable HTTP, URL-only OK, optional bearer/OAuth). Variables: [examples/env.sample](examples/env.sample). Running examples from `examples/`: [examples/README.md](examples/README.md). **MCP transports and testing against real servers:** [examples/agent_with_mcp_config/README.md](examples/agent_with_mcp_config/README.md). +[examples/agent_with_mcp_config](examples/agent_with_mcp_config) and [examples/agent_with_mcp_client](examples/agent_with_mcp_client) show MCP from env (`stdio` or streamable HTTP, URL-only OK, optional bearer/OAuth). Variables: [examples/.env.defaults](examples/.env.defaults). Running examples from `examples/`: [examples/README.md](examples/README.md). 
**MCP transports and testing against real servers:** [examples/agent_with_mcp_config/README.md](examples/agent_with_mcp_config/README.md). ### A2A (Agent-to-Agent) @@ -474,7 +509,7 @@ if err := a.RunA2A(ctx); err != nil { } ``` -[examples/agent_with_a2a_server](examples/agent_with_a2a_server) shows a full server example with env-based config (`A2A_SERVER_HOST`, `A2A_SERVER_PORT`, `A2A_SERVER_BEARER_TOKENS`). Variables: [examples/env.sample](examples/env.sample). Running examples from `examples/`: [examples/README.md](examples/README.md). **Inbound server — curl, `a2a` CLI, bearer, cross-test with `agent_with_a2a_config`:** [examples/agent_with_a2a_server/README.md](examples/agent_with_a2a_server/README.md). +[examples/agent_with_a2a_server](examples/agent_with_a2a_server) shows a full server example with env-based config (`A2A_SERVER_HOST`, `A2A_SERVER_PORT`, `A2A_SERVER_BEARER_TOKENS`). Variables: [examples/.env.defaults](examples/.env.defaults). Running examples from `examples/`: [examples/README.md](examples/README.md). **Inbound server — curl, `a2a` CLI, bearer, cross-test with `agent_with_a2a_config`:** [examples/agent_with_a2a_server/README.md](examples/agent_with_a2a_server/README.md). #### A2A Client @@ -555,7 +590,7 @@ defer a.Close() You may use **Option 1** for some remote agents and **Option 2** for others on the same agent; keep connection names unique across both. -[examples/agent_with_a2a_config](examples/agent_with_a2a_config) and [examples/agent_with_a2a_client](examples/agent_with_a2a_client) show A2A from env (`A2A_URL`, optional bearer/headers/filter). Variables: [examples/env.sample](examples/env.sample). Running examples from `examples/`: [examples/README.md](examples/README.md). **Remote agent setup (e.g. `a2a-samples` helloworld), curl checks:** [examples/agent_with_a2a_config/README.md](examples/agent_with_a2a_config/README.md). 
+[examples/agent_with_a2a_config](examples/agent_with_a2a_config) and [examples/agent_with_a2a_client](examples/agent_with_a2a_client) show A2A from env (`A2A_URL`, optional bearer/headers/filter). Variables: [examples/.env.defaults](examples/.env.defaults). Running examples from `examples/`: [examples/README.md](examples/README.md). **Remote agent setup (e.g. `a2a-samples` helloworld), curl checks:** [examples/agent_with_a2a_config/README.md](examples/agent_with_a2a_config/README.md). ### Retrieval (RAG) @@ -760,7 +795,7 @@ if res.Err != nil { /* handle */ } For **Run** / **RunAsync**, use `req.Respond` only. For **Stream**, use `**OnApproval`** as in the snippet above—the activity token string is `**ApprovalToken**` from `**ParseCustomEventApproval**` / `**ParseCustomEventDelegation**` (not a field on the `**AgentEvent**` interface). -[examples/agent_with_tools_approval](examples/agent_with_tools_approval) +[examples/agent_with_tools/approval](examples/agent_with_tools/approval) [examples/agent_with_run_async](examples/agent_with_run_async) @@ -802,7 +837,7 @@ result, err := a.Run(context.Background(), "Hello", "") Implement `interfaces.Tool`: `Name()`, `Description()`, `Parameters()`, `Execute()`. Register with `agent.WithTools(tool1, tool2)`. -[examples/agent_with_custom_tools](examples/agent_with_custom_tools) +[examples/agent_with_tools/custom](examples/agent_with_tools/custom) ### Response format @@ -993,7 +1028,7 @@ Agent stream events follow the [AG-UI open protocol](https://docs.ag-ui.com), ma Events like `RUN_STARTED`, `TEXT_MESSAGE_CONTENT`, `TOOL_CALL_START`, and `REASONING_MESSAGE_CONTENT` are emitted in the correct AG-UI sequence during every `Stream()` call. Serialize any event with `event.ToJSON()` and forward it over SSE, WebSocket, or Redis to a TypeScript/React frontend using the AG-UI client SDK. 
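
Forwarding serialized events over SSE, as described above, can be sketched with the standard library alone. This is an illustration under stated assumptions, not the SDK's server: the `events` channel of single-line JSON strings stands in for whatever `event.ToJSON()` produces, and `sseHandler`, `render`, the `/agui` path used in the demo request, and the sample payloads are hypothetical.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// sseHandler writes each JSON payload from events as one SSE "data:" frame.
// Payloads are assumed to be single-line JSON, already serialized elsewhere.
func sseHandler(events <-chan string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		flusher, ok := w.(http.Flusher)
		if !ok {
			http.Error(w, "streaming unsupported", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		for {
			select {
			case <-r.Context().Done():
				return // client went away
			case payload, open := <-events:
				if !open {
					return // producer closed the channel: stream is complete
				}
				fmt.Fprintf(w, "data: %s\n\n", payload)
				flusher.Flush() // push the frame to the client immediately
			}
		}
	}
}

// render drives the handler against an in-memory recorder and returns the body.
func render(payloads ...string) string {
	ch := make(chan string, len(payloads))
	for _, p := range payloads {
		ch <- p
	}
	close(ch)
	rec := httptest.NewRecorder()
	sseHandler(ch)(rec, httptest.NewRequest(http.MethodPost, "/agui", nil))
	return rec.Body.String()
}

func main() {
	// Illustrative payloads only; real events carry more fields.
	fmt.Print(render(`{"type":"RUN_STARTED"}`, `{"type":"RUN_FINISHED"}`))
}
```

Flushing after every frame matters for interactive use: without it, buffered writers may hold partial text until the run ends, which defeats the purpose of streaming.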
-For a complete server + UI reference, see `[examples/agent_copilotkit](examples/agent_copilotkit)` (Go SSE server in `server/main.go`, Next.js + CopilotKit bridge in `ui/app/api/copilotkit/route.ts`). +For a complete server + UI reference, see [examples/agent_with_agui](examples/agent_with_agui) (Go SSE server in `server/main.go`, Next.js + CopilotKit bridge in `ui/app/api/copilotkit/route.ts`). ```go ch, err := a.Stream(ctx, prompt, conversationID) @@ -1131,7 +1166,7 @@ agent.WithLogLevel("debug") // show all SDK internal log lines ## Configuration -A Temporal connection is **required** — one of `WithTemporalConfig` or `WithTemporalClient` must be set; the agent does not run with LLM-only config. +A Temporal connection (`WithTemporalConfig` or `WithTemporalClient`) is **optional** — omit it to use the **local runtime** (in-process, no Temporal). Set it to use the **Temporal runtime** (durable execution). All other options work the same on both runtimes. - **WithTemporalConfig**: Temporal connection (Host, Port, Namespace, TaskQueue). Use for simple setups. See [Temporal connection](#temporal-connection). - **WithTemporalClient**: Pre-configured Temporal client. Use for TLS, API key auth, Temporal Cloud. Requires `WithTaskQueue`. Agent does not close the client. @@ -1183,11 +1218,13 @@ Coverage reports (PR and default branch) are on **[Codecov](https://app.codecov. ```bash git clone cd agent-sdk-go -cp examples/env.sample examples/.env -# Edit examples/.env: set LLM_APIKEY, LLM_MODEL +export LLM_APIKEY=your-key +export LLM_PROVIDER=your-provider +export LLM_MODEL=your-model +# LLM_PROVIDER: openai, anthropic, or gemini. Optional overrides in examples/.env ``` -See **[examples/README.md](examples/README.md)** for how to run examples, env vars ([examples/env.sample](examples/env.sample)), and optional **README.md** files inside specific example directories. 
+See **[examples/README.md](examples/README.md)** for how to run examples, env vars ([examples/.env.defaults](examples/.env.defaults)), and optional **README.md** files inside specific example directories.
 
 ### CLI configuration
 
diff --git a/Taskfile.yml b/Taskfile.yml
new file mode 100644
index 0000000..f5b9bdd
--- /dev/null
+++ b/Taskfile.yml
@@ -0,0 +1,19 @@
+version: '3'
+
+dotenv: ['examples/.env']
+
+includes:
+  examples:
+    taskfile: taskfiles/examples.yml
+    flatten: true
+  reports:
+    taskfile: taskfiles/reports.yml
+    flatten: true
+
+tasks:
+  # ── Default ───────────────────────────────────────────────
+
+  default:
+    desc: Show available tasks
+    cmds:
+      - task --list
\ No newline at end of file
diff --git a/cmd/config.go b/cmd/config.go
index c82ea00..2e1c1f9 100644
--- a/cmd/config.go
+++ b/cmd/config.go
@@ -21,12 +21,40 @@ import (
 )
 
 type Config struct {
+	// Runtime selects the agent execution backend: "local" (default) or "temporal".
+	// Override with AGENT_RUNTIME env var or set runtime: temporal in config.yaml.
+	Runtime  string          `mapstructure:"runtime"`
 	Temporal *TemporalConfig `mapstructure:"temporal"`
 	LLM      *LLMConfig      `mapstructure:"llm"`
 	Logger   *LoggerConfig   `mapstructure:"logger"`
 	MCP      *MCPRootConfig  `mapstructure:"mcp"`
 }
 
+// UseTemporalRuntime reports whether the temporal backend is selected.
+func (c *Config) UseTemporalRuntime() bool {
+	return c != nil && strings.ToLower(strings.TrimSpace(c.Runtime)) == "temporal"
+}
+
+// RuntimeOption returns [agent.WithTemporalConfig] when runtime is "temporal", or nil for
+// the local runtime. Spread into the options slice:
+//
+//	opts = append(opts, RuntimeOption(cfg)...)
+//
+// To hard-code a runtime regardless of config, call [agent.WithTemporalConfig] directly.
+func RuntimeOption(cfg *Config) []agent.Option {
+	if !cfg.UseTemporalRuntime() || cfg.Temporal == nil {
+		return nil
+	}
+	return []agent.Option{
+		agent.WithTemporalConfig(&agent.TemporalConfig{
+			Host:      cfg.Temporal.Host,
+			Port:      cfg.Temporal.Port,
+			Namespace: cfg.Temporal.Namespace,
+			TaskQueue: cfg.Temporal.TaskQueue,
+		}),
+	}
+}
+
 // MCPRootConfig holds optional MCP server definitions for the CLI (see config.sample.yaml).
 type MCPRootConfig struct {
 	Servers []MCPServerYAML `mapstructure:"servers"`
@@ -202,6 +230,7 @@ func LoadConfig(path string) (*Config, error) {
 	v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
 
 	// Explicit BindEnv so AGENT_* env vars reliably override (AutomaticEnv can be inconsistent with nested keys)
+	_ = v.BindEnv("runtime", "AGENT_RUNTIME")
 	_ = v.BindEnv("temporal.host", "AGENT_TEMPORAL_HOST")
 	_ = v.BindEnv("temporal.port", "AGENT_TEMPORAL_PORT")
 	_ = v.BindEnv("temporal.namespace", "AGENT_TEMPORAL_NAMESPACE")
@@ -217,6 +246,7 @@ func LoadConfig(path string) (*Config, error) {
 	_ = v.BindEnv("logger.tee_stderr", "AGENT_LOGGER_TEE_STDERR")
 
 	// Set defaults so env can override even when file is missing or key absent
+	v.SetDefault("runtime", "local")
 	v.SetDefault("temporal.host", "localhost")
 	v.SetDefault("temporal.port", 7233)
 	v.SetDefault("temporal.namespace", "default")
diff --git a/cmd/config.sample.yaml b/cmd/config.sample.yaml
index 493fc06..5288572 100644
--- a/cmd/config.sample.yaml
+++ b/cmd/config.sample.yaml
@@ -9,6 +9,12 @@
 # Environment variables with AGENT_ prefix override file values.
 # Example: AGENT_LLM_APIKEY=sk-xxx AGENT_LLM_PROVIDER=openai go run ./cmd
 
+# Runtime: local (default) or temporal.
+# local — runs entirely in-process; no Temporal server required.
+# temporal — durable execution via Temporal; requires a running Temporal server.
+# Override per-run: AGENT_RUNTIME=temporal go run ./cmd
+runtime: local
+
 # Temporal connection (for durable agent execution)
 temporal:
   host: localhost # Temporal server host; use AGENT_TEMPORAL_HOST to override
diff --git a/cmd/main.go b/cmd/main.go
index a2d5205..4fc40c3 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -87,12 +87,6 @@ func main() {
 	opts := []agent.Option{
 		agent.WithName("agentctl"),
 		agent.WithSystemPrompt("You are a helpful assistant."),
-		agent.WithTemporalConfig(&agent.TemporalConfig{
-			Host:      cfg.Temporal.Host,
-			Port:      cfg.Temporal.Port,
-			Namespace: cfg.Temporal.Namespace,
-			TaskQueue: cfg.Temporal.TaskQueue,
-		}),
 		agent.WithLLMClient(llmClient),
 		agent.WithStream(true),
 		agent.WithToolRegistry(reg),
@@ -100,6 +94,7 @@ func main() {
 		agent.WithConversationSize(20),
 		agent.WithLogger(lgr),
 	}
+	opts = append(opts, RuntimeOption(cfg)...)
 	if len(mcpServers) > 0 {
 		opts = append(opts,
 			agent.WithMCPConfig(mcpServers),
diff --git a/examples/.env.defaults b/examples/.env.defaults
new file mode 100644
index 0000000..125dcbf
--- /dev/null
+++ b/examples/.env.defaults
@@ -0,0 +1,78 @@
+# Committed defaults for examples. Loaded automatically by examples/config.go.
+# Optional overrides: create examples/.env (gitignored) for secrets and local changes.
+# Load order: .env.defaults → examples/.env → process environment (export / Task / CI).
+# Put LLM_APIKEY and EMBEDDING_OPENAI_APIKEY in examples/.env (never commit keys).
+
+# --- Logging ---
+# error | warn | info | debug
+LOG_LEVEL=error
+
+# --- Agent runtime ---
+# local = in-process (default). temporal = requires Temporal server (see temporal-setup.md).
+AGENT_RUNTIME=local
+
+# --- Temporal (when AGENT_RUNTIME=temporal) ---
+# TEMPORAL_TASKQUEUE is a base prefix; each example appends its suffix (e.g. agent-sdk-go-simple_agent).
+TEMPORAL_HOST=localhost
+TEMPORAL_PORT=7233
+TEMPORAL_NAMESPACE=default
+TEMPORAL_TASKQUEUE=agent-sdk-go
+
+# --- LLM (all examples) ---
+# Provider: openai | anthropic | gemini. Set LLM_APIKEY in examples/.env.
+LLM_PROVIDER=openai
+LLM_APIKEY=
+LLM_MODEL=gpt-4o
+# Used for OpenAI client and as embedding base URL fallback when EMBEDDING_OPENAI_BASEURL is unset.
+LLM_BASEURL=https://api.openai.com/v1
+
+# --- MCP (agent_with_mcp_config, agent_with_mcp_client) ---
+# Default: stdio + local filesystem sandbox (matches Task infra). Override in .env for streamable_http:
+# MCP_TRANSPORT=streamable_http
+# MCP_STREAMABLE_HTTP_URL=https://your-mcp-host/mcp
+MCP_TRANSPORT=stdio
+MCP_STDIO_COMMAND=npx
+MCP_STDIO_ARGS=["-y","@modelcontextprotocol/server-filesystem","./mcp-filesystem-sandbox"]
+
+# --- A2A outbound (agent_with_a2a_config, agent_with_a2a_client) ---
+# Base URL of remote agent (no path). Default matches agent_with_a2a_server from task infra:a2a:up.
+A2A_URL=http://localhost:9999
+
+# --- A2A inbound server (agent_with_a2a_server) ---
+A2A_SERVER_HOST=localhost
+A2A_SERVER_PORT=9999
+
+# --- Retriever mode (weaviate and pgvector examples) ---
+# agentic | prefetch | hybrid
+RETRIEVER_MODE=agentic
+
+# --- Weaviate (agent_with_retriever/weaviate) ---
+# Run task infra:weaviate:up from examples/ first.
+WEAVIATE_HOST=localhost:8080
+WEAVIATE_SCHEME=http
+WEAVIATE_CLASS=Document
+WEAVIATE_RETRIEVER_NAME=weaviate-kb
+WEAVIATE_CONTENT_FIELD=content
+WEAVIATE_SOURCE_FIELD=source
+
+# --- pgvector (agent_with_retriever/pgvector) ---
+# Run task infra:pgvector:up from examples/ first.
+PGVECTOR_DSN=postgres://postgres:secret@localhost:5432/vectordb?sslmode=disable
+PGVECTOR_TABLE=documents
+PGVECTOR_RETRIEVER_NAME=pgvector-kb
+PGVECTOR_CONTENT_COL=content
+PGVECTOR_SOURCE_COL=source
+PGVECTOR_EMBEDDING_COL=embedding
+PGVECTOR_MIN_SCORE=0.35
+
+# --- OpenAI-compatible embeddings (pgvector client-side; Weaviate text2vec-openai in Docker) ---
+# Set EMBEDDING_OPENAI_APIKEY in examples/.env (separate from chat LLM_APIKEY for Anthropic/Gemini).
+EMBEDDING_OPENAI_APIKEY= +EMBEDDING_OPENAI_MODEL=text-embedding-3-small +EMBEDDING_OPENAI_BASEURL=https://api.openai.com/v1 + +# --- OpenTelemetry (agent_with_observability; LGTM from task infra:lgtm:up) --- +# Host:port without scheme. grpc uses port 4317; http often 4318. +OTEL_EXPORTER_OTLP_ENDPOINT=localhost:4317 +OTLP_PROTOCOL=grpc +OTLP_INSECURE=true diff --git a/examples/README.md b/examples/README.md index 68e16a0..3ca4ffc 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,49 +1,76 @@ # Examples -These programs exercise **agent-sdk-go** (`github.com/agenticenv/agent-sdk-go`). Agents run as Temporal workflows, so a running Temporal service is mandatory for every example below. +These programs exercise **agent-sdk-go** (`github.com/agenticenv/agent-sdk-go`). By default examples run on the **local** runtime (in-process, no external services). Set `AGENT_RUNTIME=temporal` in `.env` for durable Temporal execution. -**Prerequisite:** These examples use **agent-sdk-go** on the **Temporal** runtime. A **running Temporal server** is required before you run them. See **[Temporal setup](../temporal-setup.md)** for Docker, Temporal CLI, ports, Cloud, and self-hosted options. +## Runtime -## Default connection +| Mode | How to enable | Requirement | +|------|--------------|-------------| +| `local` (default) | `AGENT_RUNTIME=local` (or unset) | Nothing — runs in-process | +| `temporal` | `AGENT_RUNTIME=temporal` | **`task infra:temporal:up`** + **`infra:temporal:wait`** from `examples/`, or **[Temporal setup](../temporal-setup.md)** | -The examples use `TEMPORAL_HOST`, `TEMPORAL_PORT`, and `TEMPORAL_NAMESPACE` from `.env` (default: localhost, 7233, default). Adjust if your Temporal runs elsewhere. +When using Temporal the examples read `TEMPORAL_HOST`, `TEMPORAL_PORT`, and `TEMPORAL_NAMESPACE` from `.env` (default: localhost, 7233, default). 
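
How the `TEMPORAL_*` values above end up with their effective settings follows from the load order stated in `examples/.env.defaults`: committed defaults first, then `examples/.env`, then the process environment, with later layers winning. The sketch below illustrates that layering with the standard library only. It is an assumption about the behavior, not a copy of `examples/config.go` (which this diff does not show); `parseDotenv` and `resolve` are hypothetical helpers, and the override values `temporal.internal` and `staging` are made up.

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

// parseDotenv reads simple KEY=VALUE lines, skipping blank lines and # comments.
// Quoting and variable expansion are deliberately out of scope for this sketch.
func parseDotenv(text string) map[string]string {
	out := map[string]string{}
	sc := bufio.NewScanner(strings.NewReader(text))
	for sc.Scan() {
		line := strings.TrimSpace(sc.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		key, val, ok := strings.Cut(line, "=")
		if !ok {
			continue
		}
		out[strings.TrimSpace(key)] = strings.TrimSpace(val)
	}
	return out
}

// resolve merges layers in order; a key in a later layer overrides earlier ones.
func resolve(layers ...map[string]string) map[string]string {
	merged := map[string]string{}
	for _, layer := range layers {
		for k, v := range layer {
			merged[k] = v
		}
	}
	return merged
}

func main() {
	defaults := parseDotenv("TEMPORAL_HOST=localhost\nTEMPORAL_PORT=7233\nTEMPORAL_NAMESPACE=default\n")
	dotenv := parseDotenv("# local override (hypothetical)\nTEMPORAL_HOST=temporal.internal\n")
	process := map[string]string{"TEMPORAL_NAMESPACE": "staging"} // stands in for exported variables

	cfg := resolve(defaults, dotenv, process)
	fmt.Printf("%s:%s namespace=%s\n",
		cfg["TEMPORAL_HOST"], cfg["TEMPORAL_PORT"], cfg["TEMPORAL_NAMESPACE"])
}
```

Note that in this sketch an empty value in a later layer still overrides an earlier one; whether the real loader treats empty values that way is not stated in the source.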
## Examples overview -| Example | What it demonstrates | -|---------|---------------------| -| `simple_agent` | Minimal agent, no tools — Temporal config, system prompt, LLM client, single `Run()`; prints `AgentResponse.Usage` (token counts) when the provider reports them | -| `agent_with_temporal_client` | Caller-owned Temporal client — `WithTemporalClient` + `WithTaskQueue`; create and close client yourself (TLS, API key, Cloud) | -| `agent_with_conversation` | In-memory conversation with `WithConversation` — multi-turn context, same `conversationID` for `Run` | -| `agent_with_tools` | Built-in tools (echo, calculator, weather, wikipedia, search) with auto-approval | -| `agent_with_stream` | Streaming with `Stream` — **`TEXT_MESSAGE_*`**, **`TOOL_CALL_*`**, **`RUN_FINISHED`**; prints token usage from **`RUN_FINISHED`** result when present | -| `agent_copilotkit` | Go **`POST /agui` SSE** + **Next.js + CopilotKit** ([`agent_copilotkit/README.md`](agent_copilotkit/README.md)) — two processes: agent server, then `ui/` dev server | -| `agent_with_stream_conversation` | Stream + conversation; avoid printing the same text twice (**`TEXT_MESSAGE_CONTENT`** deltas vs **`RUN_FINISHED`** body) | -| `agent_with_tools_approval` | Tools + `WithApprovalHandler` — user approves or rejects each tool run (Run only) | -| `agent_with_run_async` | `RunAsync` — `resultCh` + `approvalCh`; use `req.Respond` (no `WithApprovalHandler`) | -| `agent_with_custom_tools` | Custom tools via `WithTools` — implementing `interfaces.Tool` | -| `agent_with_tool_authorizer` | Custom tool authorization via `interfaces.ToolAuthorizer` — denied calls surface as `tool_result` with `denied` status | -| `multiple_agents` | Multiple agents with `WithInstanceId` — sequential or concurrent | -| `agent_with_subagents` | Main agent + math specialist — `WithSubAgents`, separate task queues; prints **`STEP_STARTED` / `STEP_FINISHED`** (sub-agent name) around each child run when using `Stream` | -| 
`agent_with_json_response` | Structured LLM output — `WithResponseFormat` + `interfaces.JSONSchema` (JSON with schema; no tools) | -| `agent_with_reasoning` | Generic `interfaces.LLMReasoning` via `WithLLMSampling` — `Stream` to observe `thinking_delta` (e.g. Anthropic) | -| `agent_with_worker` | Agent and worker in separate processes — `DisableLocalWorker` + `NewAgentWorker`; agent uses **`Stream`** | -| `durable_agent` | Same split-process layout — agent uses **`Stream`** (`WithStream`); durability scenarios: [`durable_agent/README.md`](durable_agent/README.md) | -| `agent_with_mcp_config` | MCP via `WithMCPConfig` — transport from env; see **`env.sample`** — **[README](agent_with_mcp_config/README.md)** (testing & sample servers) | -| `agent_with_mcp_client` | Same as above via `mcpclient.NewClient` + `WithMCPClients` — **[README](agent_with_mcp_client/README.md)** | -| `agent_with_a2a_config` | Outbound A2A via `WithA2AConfig` — **`A2A_URL`** etc.; **[README](agent_with_a2a_config/README.md)** | -| `agent_with_a2a_client` | Same env, explicit **`pkg/a2a/client`** — **[README](agent_with_a2a_client/README.md)** | -| `agent_with_a2a_server` | **Inbound** A2A server — **`A2A_SERVER_*`**; **[README](agent_with_a2a_server/README.md)** (curl, **`a2a` CLI**, client example) | -| `agent_with_observability` | OpenTelemetry OTLP exports — two runnable programs: **`config/`** ([`WithObservabilityConfig`](../pkg/agent/config.go)) vs **`objects/`** (pre-built [`pkg/observability`](../pkg/observability/) tracer/metrics + [`WithTracer`](../pkg/agent/config.go) / [`WithMetrics`](../pkg/agent/config.go)); shared **`setup/`** helper package — **[README](agent_with_observability/README.md)** (collector endpoint, ports **`4317`**/**`4318`**) | -| `agent_with_retriever` | Vector retrievers — **`weaviate/`** or **`pgvector/`** backends; shared **`common/`**; modes **`agentic`**, **`prefetch`**, **`hybrid`** via **`RETRIEVER_MODE`** — **[README](agent_with_retriever/README.md)** 
(Weaviate / Postgres setup in subfolder READMEs) | +From **`examples/`**, use [Task](https://taskfile.dev) and [`Taskfile.yml`](Taskfile.yml) for infra in the table below, then **`go run ./`**. Install **`task`**, infra targets, and batch runs: [Setup](#setup). Third-party MCP/A2A servers stay manual — see each example’s README. + +### Works with both runtimes + +These examples run with `AGENT_RUNTIME=local` (default) or `AGENT_RUNTIME=temporal`. + +**Temporal runtime:** set `AGENT_RUNTIME=temporal` in `.env`, then run **`task infra:temporal:up`** and **`task infra:temporal:wait`** before `go run` (for every row below, in addition to the infra in the third column). + +| Example | What it demonstrates | Infra (Task, from `examples/`) | +|---------|---------------------|--------------------------------| +| `simple_agent` | Minimal agent, no tools — system prompt, LLM client, single `Run()`; prints `AgentResponse.Usage` (token counts) when the provider reports them | — | +| `agent_with_conversation` | In-memory conversation with `WithConversation` — multi-turn context, same `conversationID` for `Run` | — | +| `agent_with_tools/basic` | Built-in tools (echo, calculator, weather, wikipedia, search) with auto-approval | — | +| `agent_with_tools/approval` | Tools + `WithApprovalHandler` — user approves or rejects each tool run (`Run` only) | — | +| `agent_with_tools/authorizer` | Custom tool authorization via `interfaces.ToolAuthorizer` — denied calls surface as `tool_result` with `denied` status | — | +| `agent_with_tools/custom` | Custom tools via `WithTools` — implementing `interfaces.Tool` | — | +| `agent_with_stream` | Streaming with `Stream` — **`TEXT_MESSAGE_*`**, **`TOOL_CALL_*`**, **`RUN_FINISHED`**; prints token usage from **`RUN_FINISHED`** result when present | — | +| `agent_with_agui` | Go **`POST /agui` SSE** + **Next.js + CopilotKit** ([`agent_with_agui/README.md`](agent_with_agui/README.md)) — agent server, then `ui/` dev server | UI manual (`npm run dev` 
in `ui/`) | +| `agent_with_stream_conversation` | Stream + conversation; avoid printing the same text twice (**`TEXT_MESSAGE_CONTENT`** deltas vs **`RUN_FINISHED`** body) | — | +| `agent_with_run_async` | `RunAsync` — `resultCh` + `approvalCh`; use `req.Respond` (no `WithApprovalHandler`) | — | +| `multiple_agents` | Multiple agents with `WithInstanceId` — sequential or concurrent | — | +| `agent_with_subagents` | Main agent + math specialist — `WithSubAgents`; prints **`STEP_STARTED` / `STEP_FINISHED`** (sub-agent name) around each child run when using `Stream` | — | +| `agent_with_json_response` | Structured LLM output — `WithResponseFormat` + `interfaces.JSONSchema` (JSON with schema; no tools) | — | +| `agent_with_reasoning` | Generic `interfaces.LLMReasoning` via `WithLLMSampling` — `Stream` to observe `thinking_delta` (e.g. Anthropic) | — | +| `agent_with_mcp_config` | MCP via `WithMCPConfig` — transport from env; **[README](agent_with_mcp_config/README.md)** | stdio: — (`.env.defaults`); remote MCP: manual | +| `agent_with_mcp_client` | Same via `mcpclient.NewClient` + `WithMCPClients` — **[README](agent_with_mcp_client/README.md)** | same as `mcp_config` | +| `agent_with_a2a_config` | Outbound A2A via `WithA2AConfig` — **`A2A_URL`**; **[README](agent_with_a2a_config/README.md)** | `infra:a2a:up` or external A2A (manual) | +| `agent_with_a2a_client` | Same env, explicit **`pkg/a2a/client`** | same as `a2a_config` | +| `agent_with_a2a_server` | **Inbound** A2A server — **`A2A_SERVER_*`**; **[README](agent_with_a2a_server/README.md)** | `go run` or `infra:a2a:up` | +| `agent_with_observability` | OTLP — **`config/`** vs **`objects/`**; **[README](agent_with_observability/README.md)** | `infra:lgtm:up` (or manual collector) | +| `agent_with_retriever` | **`weaviate/`** or **`pgvector/`**; **`RETRIEVER_MODE`** — **[README](agent_with_retriever/README.md)** | `infra:weaviate:up` or `infra:pgvector:up` | + +### Temporal only + +Set **`AGENT_RUNTIME=temporal`**. 
Start **`task infra:temporal:up`** and **`task infra:temporal:wait`** before `go run`. + +| Example | What it demonstrates | Infra (Task, from `examples/`) | +|---------|---------------------|--------------------------------| +| `agent_with_temporal_client` | Caller-owned Temporal client — `WithTemporalClient` + `WithTaskQueue`; TLS, API key, Cloud | `infra:temporal:up`, `infra:temporal:wait` | +| `agent_with_worker` | Agent and worker in **separate processes** — `DisableLocalWorker` + `NewAgentWorker`; **`Stream`** | `infra:temporal:up`, `infra:temporal:wait` | +| `durable_agent` | Split-process durability scenarios — **[README](durable_agent/README.md)** | `infra:temporal:up`, `infra:temporal:wait` | ## Setup +**`.env.defaults`** is loaded automatically: valid values for local Task infra (stdio MCP, A2A on `:9999`, Weaviate/pgvector ports, OTLP to LGTM). Create optional **`examples/.env`** (gitignored) for secrets and overrides: + ```bash -cp env.sample .env -# Edit .env: set LLM_APIKEY, LLM_MODEL (see LLM_PROVIDER: openai, anthropic, or gemini) +# From examples/ — at minimum set keys; override anything else as needed +cat >> .env <<'EOF' +LLM_APIKEY=your-key +EMBEDDING_OPENAI_APIKEY=your-openai-embeddings-key +EOF ``` +Override **`LLM_PROVIDER`** / **`LLM_MODEL`**, **`MCP_TRANSPORT=streamable_http`** + **`MCP_STREAMABLE_HTTP_URL`**, a remote **`A2A_URL`**, or retriever vars when not using the default local stack. Process environment (export / root **`Taskfile.yml`** `dotenv`) wins over both files. See [env vars](#env-vars) and **`examples/.env.defaults`**. + +**Task** — not installed by default; install via **[Task installation](https://taskfile.dev/installation/)** (platform-specific). Not needed for **`go run ./`** when the overview table has no infra. Compose infra also needs **Docker**. From **`examples/`**: **`task infra:status`**, **`infra:deps:up`** / **`down`**, **`infra:*:up`** / **`down`**. 
From **repo root**: **`task examples:local`**, **`task examples:temporal`**. **`task --dry`** only prints commands (no report file). To preview the report layout without running examples or infra, use **`task examples:local:plan`**, **`task examples:temporal:plan`**, or **`task examples:all:plan`**. + ## Run examples ### Minimal agent (no tools) @@ -52,14 +79,6 @@ cp env.sample .env go run ./simple_agent "Hello, what can you do?" ``` -### Agent with caller-owned Temporal client - -Uses `WithTemporalClient` and `WithTaskQueue`. The example creates the Temporal client, passes it to the agent, and closes it when done. Use this pattern for TLS, Temporal Cloud API keys, or other connection options. - -```bash -go run ./agent_with_temporal_client "Hello, what can you do?" -``` - ### Agent with conversation (multi-turn) Uses in-memory conversation. Run **interactive mode** (no args) for multi-turn in one process—history is shared across turns. With args, runs a single turn (useful for testing). @@ -75,7 +94,10 @@ go run ./agent_with_conversation "Hello, remember I'm Alice" ### Agent with tools ```bash -go run ./agent_with_tools "What's the weather in Tokyo?" +go run ./agent_with_tools/basic "What's the weather in Tokyo?" +go run ./agent_with_tools/approval "What is 15 + 27?" +go run ./agent_with_tools/authorizer "Get the protected note for roadmap." +go run ./agent_with_tools/custom "Reverse 'hello world'" ``` ### Streaming (partial content as tokens arrive) @@ -84,9 +106,21 @@ go run ./agent_with_tools "What's the weather in Tokyo?" go run ./agent_with_stream "What's the current time and what's 17 * 23?" ``` -### Structured JSON response (`WithResponseFormat`) +### AG-UI / CopilotKit (`agent_with_agui`) + +Go SSE server + Next.js frontend. Two processes: -Uses `agent.WithResponseFormat` with `interfaces.ResponseFormatJSON`, `Name`, and `interfaces.JSONSchema`. No tools—keeps the run in structured-output mode. Prints validated, indented JSON. 
+```bash +# Terminal 1: Go agent server (listens on :8787) +go run ./agent_with_agui/server + +# Terminal 2: Next.js UI +cd agent_with_agui/ui && npm install && npm run dev +``` + +See **[agent_with_agui/README.md](agent_with_agui/README.md)** for curl testing and UI setup. + +### Structured JSON response (`WithResponseFormat`) ```bash go run ./agent_with_json_response @@ -95,8 +129,6 @@ go run ./agent_with_json_response "What is the capital of Japan?" ### Reasoning / thinking (`WithLLMSampling` + `LLMReasoning`) -Sets `WithLLMSampling` with `Reasoning: &interfaces.LLMReasoning{Enabled, Effort, BudgetTokens}` and uses **`Stream`** so you can see **`thinking_delta`** events when the provider emits them (e.g. Anthropic extended thinking). Pick a model that supports reasoning/thinking for your `LLM_PROVIDER`. - ```bash go run ./agent_with_reasoning go run ./agent_with_reasoning "Why is the sky blue? One short paragraph." @@ -104,8 +136,6 @@ go run ./agent_with_reasoning "Why is the sky blue? One short paragraph." ### Streaming + conversation (event handling pattern) -Interactive multi-turn with `Stream`. Uses **`AgentEventTypeTextMessageContent`** (deltas) and **`AgentEventTypeRunFinished`** (final body) so the same answer is not printed twice. - ```bash go run ./agent_with_stream_conversation go run ./agent_with_stream_conversation "What is 5 * 8?" @@ -113,36 +143,21 @@ go run ./agent_with_stream_conversation "What is 5 * 8?" ### Sub-agents (main agent + specialist) -Two agents in one process: main agent with a math specialist registered via `WithSubAgents`. Requires workers on **both** task queues (each `NewAgent` starts its own embedded worker). Main agent uses default tool approval (**RequireAll**): delegating to the specialist prompts on **stdin** (`y` / `n`). Specialist uses **AutoToolApprovalPolicy** so calculator does not prompt. Same stdin pattern as [agent_with_tools_approval](#tools--approval-custom-tools-multiple-agents-worker-split). 
- ```bash go run ./agent_with_subagents "What is 987 times 654?" ``` -### Tools + approval, custom tools, multiple agents, worker split +### RunAsync + multiple agents ```bash -go run ./agent_with_tools_approval "What is 15 + 27?" go run ./agent_with_run_async "What is 15 + 27?" -go run ./agent_with_custom_tools "Reverse 'hello world'" -go run ./agent_with_tool_authorizer "Get the protected note for roadmap." go run ./multiple_agents "What is 7 times 8?" go run ./multiple_agents concurrent "What is 7 times 8?" - -# Agent and worker in separate processes: two terminals — worker in terminal 1, -# agent in terminal 2 (after the worker is up). -go run ./agent_with_worker/worker # terminal 1: worker -go run ./agent_with_worker/agent "Hello from remote agent!" # terminal 2: agent - -# durable_agent: same two-terminal flow; streaming REPL. Scenarios: -# durable_agent/README.md -go run ./durable_agent/worker # terminal 1 -go run ./durable_agent/agent "Hello from remote agent!" # terminal 2 ``` ### MCP (`agent_with_mcp_config`, `agent_with_mcp_client`) -Same **`MCP_*`** env (see **`env.sample`**); differs only in **`WithMCPConfig`** vs **`mcpclient.NewClient`** + **`WithMCPClients`**. +Same **`MCP_*`** env (see **`.env.defaults`**); differs only in **`WithMCPConfig`** vs **`mcpclient.NewClient`** + **`WithMCPClients`**. ```bash go run ./agent_with_mcp_config @@ -151,11 +166,11 @@ go run ./agent_with_mcp_client go run ./agent_with_mcp_client "List tools you can call." ``` -**Configure transports, test against real MCP servers (streamable HTTP walkthrough, stdio, links):** **[agent_with_mcp_config/README.md](agent_with_mcp_config/README.md)**. +**Configure transports, test against real MCP servers:** **[agent_with_mcp_config/README.md](agent_with_mcp_config/README.md)**. ### A2A client (`agent_with_a2a_config`, `agent_with_a2a_client`) -Outbound A2A tools — set **`A2A_URL`** (and optional **`A2A_*`** in **`env.sample`**). 
+Outbound A2A tools — set **`A2A_URL`** (and optional **`A2A_*`** in **`.env.defaults`**). ```bash go run ./agent_with_a2a_config @@ -164,7 +179,7 @@ go run ./agent_with_a2a_client go run ./agent_with_a2a_client "What tools do you have available?" ``` -**Run a sample remote agent (e.g. `a2a-samples` helloworld), curl checks:** **[agent_with_a2a_config/README.md](agent_with_a2a_config/README.md)**. +**Run a sample remote agent, curl checks:** **[agent_with_a2a_config/README.md](agent_with_a2a_config/README.md)**. ### A2A server (`agent_with_a2a_server`) @@ -178,33 +193,62 @@ go run ./agent_with_a2a_server ### Observability OTLP (`agent_with_observability`) -Requires a reachable OTLP **collector** (**`OTEL_EXPORTER_OTLP_ENDPOINT`**, typically **`localhost:4317`** for gRPC or **`localhost:4318`** for HTTP — see **`env.sample`**). Same Temporal + LLM setup as other examples. +Requires a reachable OTLP **collector** (**`OTEL_EXPORTER_OTLP_ENDPOINT`**, typically **`localhost:4317`** for gRPC or **`localhost:4318`** for HTTP). ```bash go run ./agent_with_observability/config/ go run ./agent_with_observability/objects/ - -go run ./agent_with_observability/config/ "Say hello in one sentence" -go run ./agent_with_observability/objects/ "Say hello in one sentence" ``` -Details, env semantics, and collector notes: **[agent_with_observability/README.md](agent_with_observability/README.md)**. +Details and collector notes: **[agent_with_observability/README.md](agent_with_observability/README.md)**. ### Vector retriever (`agent_with_retriever`) -Requires a running vector store (Weaviate **or** Postgres with pgvector) plus Temporal and LLM env. Set backend-specific vars in **`env.sample`** (`WEAVIATE_*` or **`PGVECTOR_DSN`**). +Requires a running vector store (Weaviate **or** Postgres with pgvector). Set backend-specific vars in **`.env.defaults`**. 
```bash -# Weaviate (run ./agent_with_retriever/weaviate/setup.sh; ./cleanup.sh when done) +# Weaviate (task infra:weaviate:up first; task infra:weaviate:down when done) go run ./agent_with_retriever/weaviate "What is the return policy?" -# pgvector (run ./agent_with_retriever/pgvector/setup.sh; ./cleanup.sh when done) +# pgvector (task infra:pgvector:up first; task infra:pgvector:down when done) go run ./agent_with_retriever/pgvector "What is the return policy?" RETRIEVER_MODE=prefetch go run ./agent_with_retriever/weaviate "What are the return and shipping rules?" ``` -Setup guides: **[agent_with_retriever/README.md](agent_with_retriever/README.md)**, **[weaviate/README.md](agent_with_retriever/weaviate/README.md)**, **[pgvector/README.md](agent_with_retriever/pgvector/README.md)**. +Setup guides: **[agent_with_retriever/README.md](agent_with_retriever/README.md)**. + +--- + +### Temporal-only examples + +> These require `AGENT_RUNTIME=temporal` and a running Temporal server. + +#### Caller-owned Temporal client + +Creates and manages the Temporal client directly — for TLS, Temporal Cloud API keys, or custom connection options. + +```bash +AGENT_RUNTIME=temporal go run ./agent_with_temporal_client "Hello, what can you do?" +``` + +#### Agent + worker in separate processes (`agent_with_worker`) + +```bash +AGENT_RUNTIME=temporal go run ./agent_with_worker/worker # terminal 1: worker +AGENT_RUNTIME=temporal go run ./agent_with_worker/agent "Hello from remote agent!" # terminal 2: agent +``` + +#### Durable agent — workflow replay and failure scenarios (`durable_agent`) + +```bash +AGENT_RUNTIME=temporal go run ./durable_agent/worker # terminal 1 +AGENT_RUNTIME=temporal go run ./durable_agent/agent "Hello from remote agent!" # terminal 2 +``` + +See **[durable_agent/README.md](durable_agent/README.md)** for durability and failure scenarios. 
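The `AGENT_RUNTIME` switch used in the commands above could be read along the following lines. This is a minimal sketch under stated assumptions, not the SDK's `config.RuntimeOption`: `resolveRuntime` is a hypothetical helper, and trimming whitespace and ignoring case are assumptions. Only the `local` default and the `temporal` value come from the env-var table.

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

// Runtime names as documented for AGENT_RUNTIME in the env-var table.
const (
	runtimeLocal    = "local"
	runtimeTemporal = "temporal"
)

// resolveRuntime is a hypothetical helper (not part of the SDK). It maps a raw
// AGENT_RUNTIME value to a runtime name and treats an empty value as the
// documented default, "local". Case folding and trimming are assumptions.
func resolveRuntime(raw string) (string, error) {
	switch v := strings.ToLower(strings.TrimSpace(raw)); v {
	case "", runtimeLocal:
		return runtimeLocal, nil
	case runtimeTemporal:
		return runtimeTemporal, nil
	default:
		return "", fmt.Errorf("unsupported AGENT_RUNTIME %q (want %q or %q)",
			raw, runtimeLocal, runtimeTemporal)
	}
}

func main() {
	rt, err := resolveRuntime(os.Getenv("AGENT_RUNTIME"))
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Println("runtime:", rt)
}
```

Rejecting unknown values, rather than silently falling back to `local`, is a design choice of this sketch: a typo such as `AGENT_RUNTIME=temporl` would otherwise run a Temporal-only example in-process without any warning.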
+ +--- ## Logging @@ -227,8 +271,9 @@ Examples send conversation (user prompt, assistant response) to **stdout** and i | Env var | Description | |---------|-------------| -| `TEMPORAL_HOST`, `TEMPORAL_PORT`, `TEMPORAL_NAMESPACE`, `TEMPORAL_TASKQUEUE` | Temporal connection | -| `LLM_PROVIDER` | `openai`, `anthropic`, or `gemini` (see `env.sample`) | +| `AGENT_RUNTIME` | `local` (default) or `temporal` — selects the execution backend | +| `TEMPORAL_HOST`, `TEMPORAL_PORT`, `TEMPORAL_NAMESPACE`, `TEMPORAL_TASKQUEUE` | Temporal connection (used when `AGENT_RUNTIME=temporal`) | +| `LLM_PROVIDER` | `openai`, `anthropic`, or `gemini` (see `.env.defaults`) | | `LLM_APIKEY` | API key | | `LLM_MODEL` | e.g. `gpt-4o`, `claude-3-5-sonnet-20241022` | | `LLM_BASEURL` | Optional (custom/proxy endpoints) | @@ -256,9 +301,9 @@ Examples send conversation (user prompt, assistant response) to **stdout** and i | `A2A_SERVER_HOST` | Optional bind hostname for **`agent_with_a2a_server`** (empty → default **localhost**) | | `A2A_SERVER_PORT` | Optional TCP port for **`agent_with_a2a_server`** (0 → default **9999**) | | `A2A_SERVER_BEARER_TOKENS` | Optional comma-separated bearer secrets for inbound JSON-RPC on **`agent_with_a2a_server`** | -| `OTEL_EXPORTER_OTLP_ENDPOINT` | **Required** for **`agent_with_observability`** examples: OTLP collector **`host:port`** only (no `http://` scheme), e.g. **`localhost:4317`** (gRPC) or **`localhost:4318`** (HTTP) | -| `OTLP_PROTOCOL` | Optional for **`agent_with_observability`**: **`grpc`** (default) or **`http`** — must match how the collector listens | -| `OTLP_INSECURE` | Optional: set to **`true`** for plaintext export (typical for local collectors without TLS) | +| `OTEL_EXPORTER_OTLP_ENDPOINT` | **Required** for **`agent_with_observability`**: OTLP collector **`host:port`** (no `http://` scheme), e.g. 
**`localhost:4317`** (gRPC) or **`localhost:4318`** (HTTP) | +| `OTLP_PROTOCOL` | Optional: **`grpc`** (default) or **`http`** — must match how the collector listens | +| `OTLP_INSECURE` | Optional: **`true`** for plaintext export (typical for local collectors without TLS) | | `RETRIEVER_MODE` | For **`agent_with_retriever`**: **`agentic`** (default), **`prefetch`**, or **`hybrid`** | -| `WEAVIATE_HOST`, `WEAVIATE_SCHEME`, `WEAVIATE_CLASS`, … | Weaviate backend — see **`env.sample`** and **[agent_with_retriever/weaviate/README.md](agent_with_retriever/weaviate/README.md)** | -| `PGVECTOR_DSN`, `PGVECTOR_TABLE`, `EMBEDDING_MODEL`, … | pgvector backend — **`PGVECTOR_DSN` required**; see **[agent_with_retriever/pgvector/README.md](agent_with_retriever/pgvector/README.md)** | +| `WEAVIATE_HOST`, `WEAVIATE_SCHEME`, `WEAVIATE_CLASS`, … | Weaviate backend — **`.env.defaults`** and **[agent_with_retriever/README.md#weaviate](agent_with_retriever/README.md#weaviate)** | +| `PGVECTOR_DSN`, `PGVECTOR_TABLE`, `EMBEDDING_OPENAI_MODEL`, … | pgvector backend — **`PGVECTOR_DSN` required**; **[agent_with_retriever/README.md#pgvector](agent_with_retriever/README.md#pgvector)** | diff --git a/examples/Taskfile.yml b/examples/Taskfile.yml new file mode 100644 index 0000000..d494c62 --- /dev/null +++ b/examples/Taskfile.yml @@ -0,0 +1,209 @@ +version: '3' + +# Load secrets for compose interpolation and seed scripts (paths relative to this file). 
+dotenv: ['.env'] + +vars: + A2A_HEALTH_URL: '{{.A2A_HEALTH_URL | default "http://localhost:9999/.well-known/agent-card.json"}}' + A2A_PID_FILE: '{{.TASKFILE_DIR}}/.a2a-server.pid' + MCP_TRANSPORT: stdio + MCP_STDIO_COMMAND: node + MCP_STDIO_ARGS: '["-y","@modelcontextprotocol/server-filesystem","{{.TASKFILE_DIR}}/mcp-filesystem-sandbox"]' + +tasks: + # ── Prerequisites ────────────────────────────────────────── + + prereq:check: + internal: true + silent: true + desc: Check tools required for examples infra and integration deps + cmds: + - | + for c in docker go node curl jq; do + command -v "$c" >/dev/null || { echo "❌ $c not installed"; exit 1; } + done + + # ── Infra ────────────────────────────────────────────────── + + infra:up: + desc: Start all infrastructure + deps: [prereq:check] + cmds: + - task infra:temporal:up + - task infra:lgtm:up + + infra:down: + desc: Stop all infrastructure + deps: [prereq:check] + cmds: + - task infra:temporal:down + - task infra:lgtm:down + + infra:status: + desc: Show status of example infra + silent: true + cmds: + - | + set -euo pipefail + COMPOSE="docker/docker-compose.yml" + compose_svc() { + local svc="$1" + if [ -n "$(docker compose -f "$COMPOSE" ps -q "$svc" 2>/dev/null)" ]; then + echo "✅ $svc (compose)" + else + echo "❌ $svc (compose)" + fi + } + + echo "Examples infra" + echo "──────────────" + A2A_URL="{{.A2A_HEALTH_URL}}" + + compose_svc temporal + compose_svc otel-lgtm + compose_svc weaviate + compose_svc pgvector + if curl -sf "$A2A_URL" >/dev/null 2>&1; then + echo "✅ A2A server" + else + echo "❌ A2A server" + fi + + infra:temporal:up: + desc: Start Temporal + deps: [prereq:check] + cmds: + - docker compose -f docker/docker-compose.yml up -d temporal + + infra:temporal:down: + desc: Stop Temporal + deps: [prereq:check] + cmds: + - docker compose -f docker/docker-compose.yml down temporal + + infra:temporal:wait: + desc: Wait for Temporal to be ready + deps: [prereq:check] + cmds: + - echo "⏳ waiting for 
Temporal..." + - | + for i in $(seq 1 30); do + if docker compose -f docker/docker-compose.yml exec temporal temporal operator cluster health 2>/dev/null; then + echo "✅ Temporal ready" + exit 0 + fi + sleep 2 + done + echo "❌ Temporal not ready after 60s" + exit 1 + + infra:lgtm:up: + desc: Start LGTM (OpenTelemetry collector) + deps: [prereq:check] + cmds: + - docker compose -f docker/docker-compose.yml up -d otel-lgtm + + infra:lgtm:down: + desc: Stop LGTM + deps: [prereq:check] + cmds: + - docker compose -f docker/docker-compose.yml down otel-lgtm + + # ── Shared deps (local + temporal examples) ───────────────── + + infra:deps:up: + desc: Start shared example deps (LGTM, Weaviate, pgvector, A2A server) + deps: [prereq:check] + cmds: + - task: infra:lgtm:up + - task: infra:weaviate:up + - task: infra:pgvector:up + - task: infra:a2a:up + + infra:deps:down: + desc: Stop shared example deps + deps: [prereq:check] + cmds: + - task: infra:a2a:down + - task: infra:pgvector:down + - task: infra:weaviate:down + - task: infra:lgtm:down + + infra:weaviate:up: + desc: Start Weaviate (compose) and seed sample data + deps: [prereq:check] + preconditions: + - sh: test -n "$EMBEDDING_OPENAI_APIKEY" + msg: "EMBEDDING_OPENAI_APIKEY is required in examples/.env (or your environment) before starting Weaviate." + cmds: + - docker compose -f docker/docker-compose.yml up -d --wait weaviate + - docker/weaviate/seed.sh + + infra:weaviate:down: + desc: Stop Weaviate (compose) + deps: [prereq:check] + cmds: + - docker compose -f docker/docker-compose.yml down weaviate + + infra:pgvector:up: + desc: Start pgvector Postgres (compose) and seed sample data + deps: [prereq:check] + preconditions: + - sh: test -n "$EMBEDDING_OPENAI_APIKEY" + msg: "EMBEDDING_OPENAI_APIKEY is required in examples/.env (or your environment) before seeding pgvector." 
+ cmds: + - docker compose -f docker/docker-compose.yml up -d --wait pgvector + - docker/pgvector/seed.sh + + infra:pgvector:down: + desc: Stop pgvector (compose) + deps: [prereq:check] + cmds: + - docker compose -f docker/docker-compose.yml down pgvector + + infra:a2a:up: + desc: Start agent_with_a2a_server in background + deps: [prereq:check] + cmds: + - | + if curl -sf "{{.A2A_HEALTH_URL}}" >/dev/null 2>&1; then + echo "✅ A2A server already running ({{.A2A_HEALTH_URL}})" + exit 0 + fi + mkdir -p "{{.TASKFILE_DIR}}/logs" + echo "🚀 Starting A2A server..." + go run ./agent_with_a2a_server >> "{{.TASKFILE_DIR}}/logs/a2a-server.log" 2>&1 & + printf '%s\n' "$!" > "{{.A2A_PID_FILE}}" + - task: infra:a2a:wait + + infra:a2a:wait: + desc: Wait for A2A server health endpoint + deps: [prereq:check] + cmds: + - echo "⏳ waiting for A2A server..." + - | + for i in $(seq 1 30); do + if curl -sf "{{.A2A_HEALTH_URL}}" >/dev/null; then + echo "✅ A2A server ready ({{.A2A_HEALTH_URL}})" + exit 0 + fi + sleep 2 + done + echo "❌ A2A server not ready after 60s" + exit 1 + + infra:a2a:down: + desc: Stop agent_with_a2a_server started by infra:a2a:up + cmds: + - | + if [ -f "{{.A2A_PID_FILE}}" ]; then + PID="$(cat "{{.A2A_PID_FILE}}" | tr -d '[:space:]')" + if [ -n "$PID" ] && kill -0 "$PID" 2>/dev/null; then + kill "$PID" 2>/dev/null || true + sleep 1 + kill -9 "$PID" 2>/dev/null || true + fi + rm -f "{{.A2A_PID_FILE}}" + fi + pkill -f 'go run ./agent_with_a2a_server' 2>/dev/null || true + pkill -f 'agent_with_a2a_server' 2>/dev/null || true diff --git a/examples/agent_with_a2a_client/README.md b/examples/agent_with_a2a_client/README.md deleted file mode 100644 index 7c20108..0000000 --- a/examples/agent_with_a2a_client/README.md +++ /dev/null @@ -1,12 +0,0 @@ -# `agent_with_a2a_client` - -Same **`A2A_*`** settings as **`agent_with_a2a_config`**, but registers the client with **`a2aclient.NewClient`** + **`WithA2AClients`**. 
- -**How to run a remote A2A server for testing, env vars, and curl checks:** see **[../agent_with_a2a_config/README.md](../agent_with_a2a_config/README.md)**. - -From **`examples/`**: - -```bash -go run ./agent_with_a2a_client -go run ./agent_with_a2a_client "What tools do you have available?" -``` diff --git a/examples/agent_with_a2a_client/main.go b/examples/agent_with_a2a_client/main.go index 441336d..f8782d8 100644 --- a/examples/agent_with_a2a_client/main.go +++ b/examples/agent_with_a2a_client/main.go @@ -53,18 +53,13 @@ func main() { agent.WithName("agent-with-a2a-client"), agent.WithDescription("Agent with A2A from env (WithA2AClients)"), agent.WithSystemPrompt("You are a helpful assistant. Use A2A tools from your tool list when they help answer the user."), - agent.WithTemporalConfig(&agent.TemporalConfig{ - Host: cfg.Host, - Port: cfg.Port, - Namespace: cfg.Namespace, - TaskQueue: cfg.TaskQueue, - }), agent.WithLLMClient(llmClient), agent.WithA2AClients(cl), agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), agent.WithLogger(logCfg), agent.WithLogLevel(cfg.LogLevel), } + opts = append(opts, config.RuntimeOption(cfg)...) a, err := agent.NewAgent(opts...) if err != nil { diff --git a/examples/agent_with_a2a_config/README.md b/examples/agent_with_a2a_config/README.md index a1d4ed0..96432fa 100644 --- a/examples/agent_with_a2a_config/README.md +++ b/examples/agent_with_a2a_config/README.md @@ -3,11 +3,11 @@ Outbound A2A: your agent calls **remote** A2A agents as tools. - **`agent_with_a2a_config`** — `agent.WithA2AConfig(agent.A2AServers{: cfg})`; SDK builds the default **[pkg/a2a/client](https://pkg.go.dev/github.com/agenticenv/agent-sdk-go/pkg/a2a/client)** per server. -- **`agent_with_a2a_client`** — `a2aclient.NewClient(...)` + **`WithA2AClients`**; same env and testing flow — see **[../agent_with_a2a_client/README.md](../agent_with_a2a_client/README.md)**. 
+- **`agent_with_a2a_client`** — `a2aclient.NewClient(...)` + **`WithA2AClients`**; same env and testing flow as above. ## Prerequisites -- **Temporal** + **LLM** in **`examples/.env`** (see **`../env.sample`**). +- **LLM** in **`examples/.env`** (see **`../.env.defaults`**). **Temporal** only when **`AGENT_RUNTIME=temporal`**. **Required for A2A:** **`A2A_URL`** — base URL of the remote agent (scheme + host + port, **no path**). Optional: **`A2A_SERVER_NAME`**, **`A2A_TOKEN`**, **`A2A_HEADERS`** (JSON), **`A2A_TIMEOUT_SECONDS`**, **`A2A_SKIP_TLS_VERIFY`** (dev), **`A2A_ALLOW_SKILLS`** / **`A2A_BLOCK_SKILLS`** (comma-separated; mutually exclusive). @@ -65,4 +65,4 @@ curl -sS "http://localhost:9999/.well-known/agent-card.json" | head ## Env vars (A2A client) -See **`A2A_*`** rows in **[examples/README.md](../README.md#env-vars)** and **`../env.sample`**. +See **`A2A_*`** rows in **[examples/README.md](../README.md#env-vars)** and **`../.env.defaults`**. diff --git a/examples/agent_with_a2a_config/main.go b/examples/agent_with_a2a_config/main.go index 135a50c..e4a3865 100644 --- a/examples/agent_with_a2a_config/main.go +++ b/examples/agent_with_a2a_config/main.go @@ -28,18 +28,13 @@ func main() { agent.WithName("agent-with-a2a-config"), agent.WithDescription("Agent with A2A from env (WithA2AConfig)"), agent.WithSystemPrompt("You are a helpful assistant. Use A2A tools from your tool list when they help answer the user."), - agent.WithTemporalConfig(&agent.TemporalConfig{ - Host: cfg.Host, - Port: cfg.Port, - Namespace: cfg.Namespace, - TaskQueue: cfg.TaskQueue, - }), agent.WithLLMClient(llmClient), agent.WithA2AConfig(agent.A2AServers{serverName: a2aCfg}), agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), agent.WithLogLevel(cfg.LogLevel), } + opts = append(opts, config.RuntimeOption(cfg)...) a, err := agent.NewAgent(opts...) 
if err != nil { diff --git a/examples/agent_with_a2a_server/README.md b/examples/agent_with_a2a_server/README.md index e70c84d..8147420 100644 --- a/examples/agent_with_a2a_server/README.md +++ b/examples/agent_with_a2a_server/README.md @@ -4,7 +4,7 @@ Runs **your** agent as an **inbound** A2A HTTP server (dynamic agent card + **JS ## Prerequisites -- **Temporal** + **LLM** in **`examples/.env`** (see **`../env.sample`**). +- **LLM** in **`examples/.env`** (see **`../.env.defaults`**). **Temporal** only when **`AGENT_RUNTIME=temporal`**. - Optional **`A2A_SERVER_HOST`**, **`A2A_SERVER_PORT`** (defaults **localhost:9999**). - Optional **`A2A_SERVER_BEARER_TOKENS`** — comma-separated secrets; JSON-RPC calls must send **`Authorization: Bearer `** (agent card GET stays unauthenticated). @@ -52,4 +52,4 @@ Full flags: **`a2a help`**; reference: [**a2a-go** `cmd/README.md`](https://gith ## Env vars (inbound server) -See **`A2A_SERVER_*`** rows in **[examples/README.md](../README.md#env-vars)** and **`../env.sample`**. +See **`A2A_SERVER_*`** rows in **[examples/README.md](../README.md#env-vars)** and **`../.env.defaults`**. diff --git a/examples/agent_with_a2a_server/main.go b/examples/agent_with_a2a_server/main.go index cc65603..fb9bad5 100644 --- a/examples/agent_with_a2a_server/main.go +++ b/examples/agent_with_a2a_server/main.go @@ -31,12 +31,6 @@ func main() { agent.WithName("agent-with-a2a-server"), agent.WithDescription("Example agent exposed as an A2A HTTP server (agent card + JSON-RPC)."), agent.WithSystemPrompt("You are a helpful assistant. 
You have an echo tool when the user asks to repeat text."), - agent.WithTemporalConfig(&agent.TemporalConfig{ - Host: cfg.Host, - Port: cfg.Port, - Namespace: cfg.Namespace, - TaskQueue: cfg.TaskQueue, - }), agent.WithLLMClient(llmClient), agent.WithToolRegistry(reg), agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), @@ -45,6 +39,7 @@ func main() { agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), agent.WithLogLevel(cfg.LogLevel), } + opts = append(opts, config.RuntimeOption(cfg)...) a, err := agent.NewAgent(opts...) if err != nil { diff --git a/examples/agent_copilotkit/README.md b/examples/agent_with_agui/README.md similarity index 89% rename from examples/agent_copilotkit/README.md rename to examples/agent_with_agui/README.md index ea9a40f..66b9748 100644 --- a/examples/agent_copilotkit/README.md +++ b/examples/agent_with_agui/README.md @@ -1,4 +1,4 @@ -# CopilotKit + agent streaming (minimal) +# AG-UI / CopilotKit agent streaming (minimal) A tiny **Go HTTP server** streams agent events as **SSE** (`POST /agui`). A **Next.js** UI uses **CopilotKit** with the usual **`/api/copilotkit`** runtime route; that route forwards requests to the Go agent via **`@ag-ui/client` `HttpAgent`**. @@ -13,7 +13,7 @@ CopilotKit expects its **runtime** handler, not a raw Go URL, in `runtimeUrl`— From the **agent-sdk-go** repo root (or this example’s parent with `go` module in path): ```bash -go run ./examples/agent_copilotkit/server +go run ./examples/agent_with_agui/server ``` Listens on **`:8787`** by default (override with `PORT=` — avoids conflicting with apps on 8080). Health: `GET http://localhost:8787/health`. @@ -22,7 +22,7 @@ Stream: `POST http://localhost:8787/agui` with JSON `{"prompt":"Hello"}` or `{"m ## 2) Start the UI ```bash -cd examples/agent_copilotkit/ui +cd examples/agent_with_agui/ui npm install npm run dev ``` @@ -49,5 +49,5 @@ You should see `data: {...}` lines (AG-UI-style JSON from `event.ToJSON()`). 
## Notes - **`ui/node_modules`** and **`ui/.next`** are listed in the repo root `.gitignore` — run `npm install` in `ui/` after clone; do not commit those directories. -- Keep **Temporal** and the **Go server** running before using the chat UI. +- Keep the **Go server** running before using the chat UI. Temporal is only needed when `AGENT_RUNTIME=temporal`. - CopilotKit / `@ag-ui/client` versions may need to stay compatible; if the UI errors, check [CopilotKit AG-UI docs](https://docs.copilotkit.ai) and align package versions. diff --git a/examples/agent_copilotkit/server/main.go b/examples/agent_with_agui/server/main.go similarity index 94% rename from examples/agent_copilotkit/server/main.go rename to examples/agent_with_agui/server/main.go index 10287f1..ee2255c 100644 --- a/examples/agent_copilotkit/server/main.go +++ b/examples/agent_with_agui/server/main.go @@ -49,16 +49,10 @@ func main() { reg.Register(echo.New()) reg.Register(calculator.New()) - a, err := agent.NewAgent( - agent.WithName("copilotkit-demo-agent"), - agent.WithDescription("Streaming demo for CopilotKit / AG-UI"), + agentOpts := []agent.Option{ + agent.WithName("agui-demo-agent"), + agent.WithDescription("Streaming demo for AG-UI / CopilotKit"), agent.WithSystemPrompt("You are a helpful assistant. Be concise."), - agent.WithTemporalConfig(&agent.TemporalConfig{ - Host: cfg.Host, - Port: cfg.Port, - Namespace: cfg.Namespace, - TaskQueue: cfg.TaskQueue, - }), agent.WithLLMClient(llmClient), agent.WithStream(true), agent.WithLLMSampling(&agent.LLMSampling{ @@ -71,7 +65,9 @@ func main() { agent.WithToolRegistry(reg), agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), - ) + } + agentOpts = append(agentOpts, config.RuntimeOption(cfg)...) + a, err := agent.NewAgent(agentOpts...) 
if err != nil { log.Fatal(config.FormatNewAgentError("agent", err)) } diff --git a/examples/agent_copilotkit/server/main_test.go b/examples/agent_with_agui/server/main_test.go similarity index 100% rename from examples/agent_copilotkit/server/main_test.go rename to examples/agent_with_agui/server/main_test.go diff --git a/examples/agent_copilotkit/ui/app/api/copilotkit/route.ts b/examples/agent_with_agui/ui/app/api/copilotkit/route.ts similarity index 100% rename from examples/agent_copilotkit/ui/app/api/copilotkit/route.ts rename to examples/agent_with_agui/ui/app/api/copilotkit/route.ts diff --git a/examples/agent_copilotkit/ui/app/layout.tsx b/examples/agent_with_agui/ui/app/layout.tsx similarity index 100% rename from examples/agent_copilotkit/ui/app/layout.tsx rename to examples/agent_with_agui/ui/app/layout.tsx diff --git a/examples/agent_copilotkit/ui/app/page.tsx b/examples/agent_with_agui/ui/app/page.tsx similarity index 100% rename from examples/agent_copilotkit/ui/app/page.tsx rename to examples/agent_with_agui/ui/app/page.tsx diff --git a/examples/agent_copilotkit/ui/next-env.d.ts b/examples/agent_with_agui/ui/next-env.d.ts similarity index 100% rename from examples/agent_copilotkit/ui/next-env.d.ts rename to examples/agent_with_agui/ui/next-env.d.ts diff --git a/examples/agent_copilotkit/ui/next.config.ts b/examples/agent_with_agui/ui/next.config.ts similarity index 100% rename from examples/agent_copilotkit/ui/next.config.ts rename to examples/agent_with_agui/ui/next.config.ts diff --git a/examples/agent_copilotkit/ui/package-lock.json b/examples/agent_with_agui/ui/package-lock.json similarity index 100% rename from examples/agent_copilotkit/ui/package-lock.json rename to examples/agent_with_agui/ui/package-lock.json diff --git a/examples/agent_copilotkit/ui/package.json b/examples/agent_with_agui/ui/package.json similarity index 100% rename from examples/agent_copilotkit/ui/package.json rename to examples/agent_with_agui/ui/package.json diff --git 
a/examples/agent_copilotkit/ui/tsconfig.json b/examples/agent_with_agui/ui/tsconfig.json similarity index 100% rename from examples/agent_copilotkit/ui/tsconfig.json rename to examples/agent_with_agui/ui/tsconfig.json diff --git a/examples/agent_with_conversation/main.go b/examples/agent_with_conversation/main.go index 48a6017..384a6cb 100644 --- a/examples/agent_with_conversation/main.go +++ b/examples/agent_with_conversation/main.go @@ -35,12 +35,6 @@ func main() { agent.WithName("agent-with-conversation"), agent.WithDescription("Agent with in-memory conversation and tools for multi-turn context"), agent.WithSystemPrompt("You are a helpful assistant. Remember the conversation context. Use tools when helpful: echo for repeating, calculator for math."), - agent.WithTemporalConfig(&agent.TemporalConfig{ - Host: cfg.Host, - Port: cfg.Port, - Namespace: cfg.Namespace, - TaskQueue: cfg.TaskQueue, - }), agent.WithLLMClient(llmClient), agent.WithToolRegistry(reg), agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), @@ -48,6 +42,7 @@ func main() { agent.WithConversationSize(20), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), } + opts = append(opts, config.RuntimeOption(cfg)...) a, err := agent.NewAgent(opts...) if err != nil { diff --git a/examples/agent_with_json_response/main.go b/examples/agent_with_json_response/main.go index 2249142..58e41ea 100644 --- a/examples/agent_with_json_response/main.go +++ b/examples/agent_with_json_response/main.go @@ -47,16 +47,11 @@ func main() { agent.WithName("agent-json-response"), agent.WithDescription("Example agent constrained to JSON output via ResponseFormat / JSONSchema"), agent.WithSystemPrompt("You are a precise assistant. Respond only with JSON that matches the configured schema. 
No markdown fences or extra text."), - agent.WithTemporalConfig(&agent.TemporalConfig{ - Host: cfg.Host, - Port: cfg.Port, - Namespace: cfg.Namespace, - TaskQueue: cfg.TaskQueue, - }), agent.WithLLMClient(llmClient), agent.WithResponseFormat(rf), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), } + opts = append(opts, config.RuntimeOption(cfg)...) a, err := agent.NewAgent(opts...) if err != nil { @@ -66,7 +61,7 @@ func main() { prompt := strings.Join(os.Args[1:], " ") if prompt == "" { - prompt = "What is the capital of France?" + prompt = "Hi" } fmt.Println("user:", prompt) diff --git a/examples/agent_with_mcp_client/main.go b/examples/agent_with_mcp_client/main.go index d80e7f0..6f4f48d 100644 --- a/examples/agent_with_mcp_client/main.go +++ b/examples/agent_with_mcp_client/main.go @@ -48,18 +48,13 @@ func main() { agent.WithName("agent-with-mcp-client"), agent.WithDescription("Agent with MCP from env: stdio or streamable HTTP (WithMCPClients)"), agent.WithSystemPrompt("You are a helpful assistant. Use MCP or other tools from your tool list when they help answer the user."), - agent.WithTemporalConfig(&agent.TemporalConfig{ - Host: cfg.Host, - Port: cfg.Port, - Namespace: cfg.Namespace, - TaskQueue: cfg.TaskQueue, - }), agent.WithLLMClient(llmClient), agent.WithMCPClients(mcpClient), agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), agent.WithLogLevel(cfg.LogLevel), } + opts = append(opts, config.RuntimeOption(cfg)...) a, err := agent.NewAgent(opts...) 
if err != nil { diff --git a/examples/agent_with_mcp_config/README.md b/examples/agent_with_mcp_config/README.md index 1e9d656..25b6abf 100644 --- a/examples/agent_with_mcp_config/README.md +++ b/examples/agent_with_mcp_config/README.md @@ -7,19 +7,19 @@ These two programs use the **same env-driven MCP transport** but wire the agent ## Prerequisites -- A **running Temporal** server (see repo root **[temporal-setup.md](../../temporal-setup.md)**). -- **`examples/.env`** copied from **`examples/env.sample`** with **`LLM_*`** set (same as other examples). +- **`examples/.env`** with **`LLM_*`** set (see **`../.env.defaults`**; defaults load automatically). +- **`AGENT_RUNTIME=temporal`** only if you want durable workflows — then Temporal per **[temporal-setup.md](../../temporal-setup.md)** or `task -t examples/Taskfile.yml infra:temporal:up`. ## Configure MCP -**Transport** must be set with **`MCP_TRANSPORT`**: `stdio` or `streamable_http` (aliases in **`env.sample`**). +**Transport** must be set with **`MCP_TRANSPORT`**: `stdio` or `streamable_http` (aliases in **`.env.defaults`**). - **Remote — `streamable_http`:** set **`MCP_STREAMABLE_HTTP_URL`**. Auth optional: **`MCP_BEARER_TOKEN`**, or OAuth trio **`MCP_CLIENT_ID`** + **`MCP_CLIENT_SECRET`** + **`MCP_TOKEN_URL`** (OAuth wins over bearer when all three are set). **`MCP_SKIP_TLS_VERIFY=true`** for dev TLS only. - **Local — `stdio`:** set **`MCP_STDIO_COMMAND`** and optional **`MCP_STDIO_ARGS`** (JSON string array) and **`MCP_STDIO_ENV`** (JSON string→string object). Shared optional knobs: **`MCP_SERVER_NAME`**, **`MCP_TIMEOUT_SECONDS`**, **`MCP_RETRY_ATTEMPTS`**, **`MCP_ALLOW_TOOLS`** / **`MCP_BLOCK_TOOLS`** (comma-separated; only one list type). -See **`../env.sample`** for every variable. +See **`../.env.defaults`** for every variable. ## Run @@ -35,7 +35,7 @@ go run ./agent_with_mcp_client "List tools you can call." 
## Testing against real MCP servers -This repo does **not** start an MCP server—you point **`examples/.env`** at **your** server(s). Pick **stdio** or **streamable_http**, set **`MCP_TRANSPORT`**, then fill in **`env.sample`** under **MCP**. +This repo does **not** start an MCP server—you point **`examples/.env`** at **your** server(s). Pick **stdio** or **streamable_http**, set **`MCP_TRANSPORT`**, then fill in **`.env.defaults`** under **MCP**. ### Worked example — TypeScript streamable HTTP (`mcp-streamable-http`) @@ -77,15 +77,15 @@ curl -sS -o /dev/null -w "%{http_code}\n" "http://localhost:8123/mcp" | **[invariantlabs-ai/mcp-streamable-http](https://github.com/invariantlabs-ai/mcp-streamable-http)** | Reference **streamable HTTP** server (TypeScript example above). Default **8123**, path **`/mcp`**; set **`MCP_STREAMABLE_HTTP_URL`** to the full endpoint URL. | | **[modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers)** | Reference implementations (filesystem, git, fetch, etc.). Often **npx** / **uvx** / **docker** per each server’s README; map into **`MCP_STDIO_COMMAND`** + **`MCP_STDIO_ARGS`**. | | **[Model Context Protocol](https://modelcontextprotocol.io)** | Protocol docs; third-party hosts list streamable-HTTP endpoints you can point **`MCP_STREAMABLE_HTTP_URL`** at. | -| **Your own MCP server** | Any compliant implementation—the examples need **`stdio`** or **`streamable_http`** as wired in **`env.sample`**. | +| **Your own MCP server** | Any compliant implementation—the examples need **`stdio`** or **`streamable_http`** as wired in **`.env.defaults`**. | ### Quick checks before running - **`streamable_http`:** Confirm the URL is reachable from the machine running the example. Example: `curl -sS -o /dev/null -w "%{http_code}\n" "$MCP_STREAMABLE_HTTP_URL"` — status depends on the implementation. 
- **`stdio`:** Run the same command line as **`MCP_STDIO_COMMAND`** / **`MCP_STDIO_ARGS`** in a terminal once to ensure the binary starts. -You still need **Temporal** and **LLM** credentials in **`examples/.env`**. +You still need **LLM** credentials in **`examples/.env`** (and Temporal when using **`AGENT_RUNTIME=temporal`**). ## Env vars (MCP) -See the **MCP_*** rows in **[examples/README.md](../README.md#env-vars)** and **`../env.sample`**. +See the **MCP_*** rows in **[examples/README.md](../README.md#env-vars)** and **`../.env.defaults`**. diff --git a/examples/agent_with_mcp_config/main.go b/examples/agent_with_mcp_config/main.go index 559123b..2d540e9 100644 --- a/examples/agent_with_mcp_config/main.go +++ b/examples/agent_with_mcp_config/main.go @@ -44,18 +44,13 @@ func main() { agent.WithName("agent-with-mcp-config"), agent.WithDescription("Agent with MCP from env: stdio or streamable HTTP (WithMCPConfig)"), agent.WithSystemPrompt("You are a helpful assistant. Use MCP or other tools from your tool list when they help answer the user."), - agent.WithTemporalConfig(&agent.TemporalConfig{ - Host: cfg.Host, - Port: cfg.Port, - Namespace: cfg.Namespace, - TaskQueue: cfg.TaskQueue, - }), agent.WithLLMClient(llmClient), agent.WithMCPConfig(agent.MCPServers{serverName: mcpCfg}), agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), agent.WithLogLevel(cfg.LogLevel), } + opts = append(opts, config.RuntimeOption(cfg)...) a, err := agent.NewAgent(opts...) if err != nil { diff --git a/examples/agent_with_observability/README.md b/examples/agent_with_observability/README.md index ee2bc15..1fe3b3d 100644 --- a/examples/agent_with_observability/README.md +++ b/examples/agent_with_observability/README.md @@ -13,9 +13,9 @@ Do not combine **`WithObservabilityConfig`** with injected **`WithTracer` / `Wit ## Prerequisites -1. 
**Temporal** — same variables as other examples (`TEMPORAL_HOST`, `TEMPORAL_PORT`, `TEMPORAL_NAMESPACE`, task queue). See [`temporal-setup.md`](../../temporal-setup.md) at the repository root. +1. **LLM** — `LLM_PROVIDER`, `LLM_APIKEY`, `LLM_MODEL`, optional `LLM_BASEURL` per [`../.env.defaults`](../.env.defaults). Add secrets in **`examples/.env`**. -2. **LLM** — `LLM_PROVIDER`, `LLM_APIKEY`, `LLM_MODEL`, optional `LLM_BASEURL` per [`../env.sample`](../env.sample). Optional: copy to `examples/.env`. +2. **Temporal** — only when **`AGENT_RUNTIME=temporal`** (`TEMPORAL_HOST`, `TEMPORAL_PORT`, `TEMPORAL_NAMESPACE`, task queue). See [`temporal-setup.md`](../../temporal-setup.md). 3. **OTLP collector** — accepts OpenTelemetry Protocol on **gRPC** (default) or **HTTP/protobuf**. Use **host:port only** (no `http://` scheme). See the table below. diff --git a/examples/agent_with_observability/setup/setup.go b/examples/agent_with_observability/setup/setup.go index 1abc016..428a70e 100644 --- a/examples/agent_with_observability/setup/setup.go +++ b/examples/agent_with_observability/setup/setup.go @@ -52,22 +52,17 @@ func BaseAgentOptions(cfg *excfg.Config, llm interfaces.LLMClient) []agent.Optio reg := tools.NewRegistry() reg.Register(calculator.New()) - return []agent.Option{ + opts := []agent.Option{ agent.WithName("observability-example-agent"), agent.WithDescription("Agent demonstrating OTLP wiring (see examples/agent_with_observability)."), agent.WithSystemPrompt("You are a concise assistant."), - agent.WithTemporalConfig(&agent.TemporalConfig{ - Host: cfg.Host, - Port: cfg.Port, - Namespace: cfg.Namespace, - TaskQueue: cfg.TaskQueue, - }), agent.WithLLMClient(llm), agent.WithToolRegistry(reg), agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), agent.WithLogLevel(cfg.LogLevel), //agent.WithLogger(excfg.NewLoggerFromLogConfig(cfg)), } + return append(opts, excfg.RuntimeOption(cfg)...) 
} // UserPrompt returns command-line text after the program name, or a default line if empty. diff --git a/examples/agent_with_reasoning/main.go b/examples/agent_with_reasoning/main.go index 232ebe8..dc4cfff 100644 --- a/examples/agent_with_reasoning/main.go +++ b/examples/agent_with_reasoning/main.go @@ -39,12 +39,6 @@ func main() { agent.WithName("agent-with-reasoning"), agent.WithDescription("Example: WithLLMSampling + generic LLMReasoning"), agent.WithSystemPrompt("You are a helpful assistant. Be concise."), - agent.WithTemporalConfig(&agent.TemporalConfig{ - Host: cfg.Host, - Port: cfg.Port, - Namespace: cfg.Namespace, - TaskQueue: cfg.TaskQueue, - }), agent.WithLLMClient(llmClient), agent.WithStream(true), agent.WithLLMSampling(&agent.LLMSampling{ @@ -56,6 +50,7 @@ func main() { }), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), } + opts = append(opts, config.RuntimeOption(cfg)...) a, err := agent.NewAgent(opts...) if err != nil { diff --git a/examples/agent_with_retriever/README.md b/examples/agent_with_retriever/README.md index 6d4e70b..0eaaafa 100644 --- a/examples/agent_with_retriever/README.md +++ b/examples/agent_with_retriever/README.md @@ -2,37 +2,25 @@ Examples that wire a **vector retriever** into **agent-sdk-go**. Pick **one backend** per run. -| Backend | Directory | Guide | -|---------|-----------|--------| -| Weaviate | [`weaviate/`](weaviate/) | [`weaviate/README.md`](weaviate/README.md) | -| PostgreSQL + pgvector | [`pgvector/`](pgvector/) | [`pgvector/README.md`](pgvector/README.md) | +| Backend | Package | Example entrypoint | +|---------|---------|-------------------| +| Weaviate | [`pkg/retriever/weaviate`](../../pkg/retriever/weaviate) | `go run ./agent_with_retriever/weaviate` | +| PostgreSQL + pgvector | [`pkg/retriever/pgvector`](../../pkg/retriever/pgvector) | `go run ./agent_with_retriever/pgvector` | -Shared sample data: [`common/sample-documents.json`](common/sample-documents.json). 
+Sample KB JSON (edit per backend, then re-seed): [`docker/weaviate/sample-documents.json`](../docker/weaviate/sample-documents.json), [`docker/pgvector/sample-documents.json`](../docker/pgvector/sample-documents.json). Infra: [`../docker/`](../docker/) (compose + seed scripts). ## Prerequisites -- **Temporal** — [`temporal-setup.md`](../../temporal-setup.md) -- **LLM** — `LLM_APIKEY`, `LLM_MODEL` in `examples/.env` ([`env.sample`](../env.sample)) -- **Vector store** — set up via `./setup.sh` in the backend folder you choose +- **Runtime** — **`AGENT_RUNTIME=local`** (default): in-process, no Temporal. Optional **`AGENT_RUNTIME=temporal`**: from `examples/`, run `task infra:temporal:up` (and `task infra:temporal:wait` if the example fails to connect). That starts the compose dev server on `localhost:7233`. For Temporal CLI, Cloud, or other hosts, see [`temporal-setup.md`](../../temporal-setup.md). +- **`examples/.env`** — `LLM_APIKEY`, `LLM_MODEL`, and **`EMBEDDING_OPENAI_APIKEY`** (see **`.env.defaults`**) +- **Task** (`go-task`) and **Docker** for the vector store you use (`task infra:weaviate:up` or `task infra:pgvector:up`) -## Quick start +From `examples/`: ```bash -cd examples -cp env.sample .env -# Edit .env: LLM keys and backend vars (see env.sample) - -# Weaviate -cd agent_with_retriever/weaviate && ./setup.sh && cd ../.. -go run ./agent_with_retriever/weaviate "What is the return policy?" - -# pgvector -cd agent_with_retriever/pgvector && ./setup.sh && cd ../.. -go run ./agent_with_retriever/pgvector "What is the return policy?" +task infra:status # see what is up ``` -Cleanup: `./cleanup.sh` in the backend folder when done. 
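The prerequisites in this README change make the runtime a configuration switch (`AGENT_RUNTIME=local` by default, `temporal` optionally), and the Go examples in this patch now append `config.RuntimeOption(cfg)...` instead of hard-coding `agent.WithTemporalConfig`. The helper's body is not part of this diff, so the following is only a minimal sketch of how such a switch could work. `Option`, `Config`, `runtimeOptions`, and the namespace and task queue values are hypothetical stand-ins, not the SDK's types or defaults.

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

// Option is a hypothetical stand-in for the SDK's agent.Option; here it is
// only a label so the sketch stays self-contained.
type Option string

// Config is a hypothetical, trimmed-down version of the examples config.
type Config struct {
	Runtime   string // "local" (default) or "temporal"
	Host      string
	Port      int
	Namespace string
	TaskQueue string
}

// runtimeOptions returns no extra options for the in-process runtime and one
// Temporal option when the runtime is "temporal". This mirrors the rule of
// adding a Temporal option for Temporal and omitting it for in-process; the
// real helper may behave differently.
func runtimeOptions(cfg Config) []Option {
	if strings.EqualFold(strings.TrimSpace(cfg.Runtime), "temporal") {
		return []Option{Option(fmt.Sprintf("temporal://%s:%d/%s?queue=%s",
			cfg.Host, cfg.Port, cfg.Namespace, cfg.TaskQueue))}
	}
	return nil
}

func main() {
	cfg := Config{
		Runtime:   os.Getenv("AGENT_RUNTIME"), // empty means in-process
		Host:      "localhost",
		Port:      7233,
		Namespace: "default",  // assumed value for illustration
		TaskQueue: "examples", // assumed value for illustration
	}
	opts := []Option{"name", "llm-client"}
	// Appending a nil slice is a no-op, so call sites need no branching.
	opts = append(opts, runtimeOptions(cfg)...)
	fmt.Println(len(opts), opts)
}
```

Returning a slice rather than a single option is what lets every example use the same one-line `append` regardless of which runtime is selected.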
- ## Retriever modes Set `RETRIEVER_MODE` in `.env` (default `agentic`): @@ -43,30 +31,109 @@ Set `RETRIEVER_MODE` in `.env` (default `agentic`): | `prefetch` | Search runs once before the first LLM call; context injected into system prompt | | `hybrid` | Prefetch and retriever tools | +Prefetch/hybrid embed your **exact user message** — use concrete questions aligned with the sample KB (returns, shipping, warranty, etc.). + +--- + +## Weaviate + +Weaviate embeds queries via **nearText** (`text2vec-openai` in Docker). Chat LLM can differ from the embedding provider. + +### Setup + ```bash +cd examples +task infra:weaviate:up +task infra:weaviate:down # when finished +``` + +Compose: [`docker/docker-compose.yml`](../docker/docker-compose.yml). Seed: [`docker/weaviate/seed.sh`](../docker/weaviate/seed.sh). + +`EMBEDDING_OPENAI_APIKEY` must be set in `examples/.env` **before** `up` (baked into the container). After a key change: `task infra:weaviate:down && task infra:weaviate:up`. + +### Environment + +```bash +WEAVIATE_HOST=localhost:8080 +WEAVIATE_SCHEME=http +WEAVIATE_CLASS=Document +WEAVIATE_RETRIEVER_NAME=weaviate-kb +RETRIEVER_MODE=agentic +# WEAVIATE_MIN_SCORE=0.5 # optional; SDK default 0.75 +``` + +### Run + +```bash +go run ./agent_with_retriever/weaviate "What is the return policy?" RETRIEVER_MODE=prefetch go run ./agent_with_retriever/weaviate "What is the return policy?" 
``` -## Troubleshooting +### Weaviate troubleshooting -| Issue | Where to look | -|-------|----------------| -| Weaviate setup, search, vectorizer | [`weaviate/README.md`](weaviate/README.md#troubleshooting) | -| pgvector setup, embeddings, `minScore` | [`pgvector/README.md`](pgvector/README.md#troubleshooting) | +| Symptom | What to do | +|---------|------------| +| Compose / API key errors | Set `EMBEDDING_OPENAI_APIKEY`, then `task infra:weaviate:down && task infra:weaviate:up` | +| Connection refused `:8080` | `task infra:status`, `curl -s http://localhost:8080/v1/.well-known/ready`, `docker logs weaviate` | +| Empty search / no relevant docs | Re-seed with `task infra:weaviate:up`; check `WEAVIATE_CLASS=Document`; list objects: `curl -s "http://localhost:8080/v1/objects?class=Document&limit=5"`; try `RETRIEVER_MODE=prefetch` | +| Port 8080 / 50051 in use | `task infra:weaviate:down`; set `WEAVIATE_HTTP_PORT` / `WEAVIATE_GRPC_PORT` before `up` | +| LLM ignores KB (agentic) | Confirm objects exist; use prefetch mode | -**Common checks (all examples):** +```bash +LOG_LEVEL=debug go run ./agent_with_retriever/weaviate "What is the return policy?" +``` + +--- + +## pgvector + +Client-side **OpenAI-compatible** embeddings, then cosine search in Postgres ([pgvector](https://github.com/pgvector/pgvector)). + +### Setup + +```bash +cd examples +task infra:pgvector:up +task infra:pgvector:down # when finished +``` + +Schema: [`docker/pgvector/setup.sql`](../docker/pgvector/setup.sql). Seed: [`docker/pgvector/seed.sh`](../docker/pgvector/seed.sh). 
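The pgvector path in this README embeds the query client-side, runs a cosine search in Postgres, and drops results that score below `PGVECTOR_MIN_SCORE`. The effect of that threshold can be sketched in plain Go. This is an illustration under stated assumptions, not the retriever's implementation: `Doc`, `cosine`, and `filterByMinScore` are hypothetical names, the vectors are toy two-dimensional values, and whether the SDK compares with `>=` or `>` is not shown in this diff.

```go
package main

import (
	"fmt"
	"math"
)

// Doc is a hypothetical stored row: a source label plus its embedding.
type Doc struct {
	Source    string
	Embedding []float64
}

// cosine returns the cosine similarity of a and b. It returns 0 when the
// lengths differ or either vector has zero magnitude.
func cosine(a, b []float64) float64 {
	if len(a) != len(b) || len(a) == 0 {
		return 0
	}
	var dot, na, nb float64
	for i := range a {
		dot += a[i] * b[i]
		na += a[i] * a[i]
		nb += b[i] * b[i]
	}
	if na == 0 || nb == 0 {
		return 0
	}
	return dot / (math.Sqrt(na) * math.Sqrt(nb))
}

// filterByMinScore keeps the sources whose similarity to the query is at
// least minScore (the >= comparison is an assumption for this sketch).
func filterByMinScore(query []float64, docs []Doc, minScore float64) []string {
	var kept []string
	for _, d := range docs {
		if cosine(query, d.Embedding) >= minScore {
			kept = append(kept, d.Source)
		}
	}
	return kept
}

func main() {
	docs := []Doc{
		{Source: "returns", Embedding: []float64{1, 0}},  // similarity 1.0
		{Source: "shipping", Embedding: []float64{1, 1}}, // similarity ~0.707
		{Source: "warranty", Embedding: []float64{0, 1}}, // similarity 0.0
	}
	query := []float64{1, 0}

	fmt.Println(filterByMinScore(query, docs, 0.35)) // prints [returns shipping]
	fmt.Println(filterByMinScore(query, docs, 0.75)) // prints [returns]
}
```

With the lower threshold the secondary document survives, which is the behaviour the example default of `0.35` is meant to preserve on combined queries.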
-- **Temporal** running — see [`temporal-setup.md`](../../temporal-setup.md) -- **`examples/.env`** — `LLM_APIKEY`, `LLM_MODEL`, and backend vars from [`env.sample`](../env.sample) -- **Vector store up** — `./setup.sh` in `weaviate/` or `pgvector/` before `go run` -- **Retriever mode** — `RETRIEVER_MODE=agentic|prefetch|hybrid` in `.env` -- **Debug** — `LOG_LEVEL=debug go run ./agent_with_retriever/ "..."` +Default DSN (in **`.env.defaults`**): `postgres://postgres:secret@localhost:5432/vectordb?sslmode=disable` -**pgvector + Anthropic/Gemini chat:** set `EMBEDDING_APIKEY` (OpenAI) in `.env`; chat `LLM_APIKEY` is not used for embeddings. +### Environment -**Clean restart a backend:** +```bash +PGVECTOR_DSN=postgres://postgres:secret@localhost:5432/vectordb?sslmode=disable +PGVECTOR_TABLE=documents +PGVECTOR_RETRIEVER_NAME=pgvector-kb +EMBEDDING_OPENAI_MODEL=text-embedding-3-small +EMBEDDING_OPENAI_APIKEY=sk-... +PGVECTOR_MIN_SCORE=0.35 +RETRIEVER_MODE=agentic +``` + +With **Anthropic/Gemini** chat, `EMBEDDING_OPENAI_APIKEY` is still required for search (not `LLM_APIKEY`). + +### Run ```bash -cd agent_with_retriever/weaviate # or pgvector -./cleanup.sh && ./setup.sh +go run ./agent_with_retriever/pgvector "What is the return policy?" +RETRIEVER_MODE=prefetch go run ./agent_with_retriever/pgvector "What is the return policy?" 
``` + +### pgvector troubleshooting + +| Symptom | What to do | +|---------|------------| +| `no relevant documents found` | `task infra:status`; row count: `docker exec pgvector psql -U postgres -d vectordb -t -c "SELECT COUNT(*) FROM documents;"`; re-seed or lower `PGVECTOR_MIN_SCORE` | +| `embedding config` / Anthropic chat | Set `EMBEDDING_OPENAI_APIKEY`; re-seed: `task infra:pgvector:down && task infra:pgvector:up` | +| `PGVECTOR_DSN is required` | Use default DSN or match compose `PGVECTOR_*` vars | +| Dimension / SQL errors | Model must match `vector(1536)` in `setup.sql`; re-seed after model change | +| Port 5432 in use | `task infra:pgvector:down`; set `PGVECTOR_PORT` and update `PGVECTOR_DSN` | + +```bash +LOG_LEVEL=debug go run ./agent_with_retriever/pgvector "What is the return policy?" +``` + +Look for `pgvector search done` with `docs=0` vs embedding errors. diff --git a/examples/agent_with_retriever/common/config.go b/examples/agent_with_retriever/common/config.go index 2ad3f96..19ade30 100644 --- a/examples/agent_with_retriever/common/config.go +++ b/examples/agent_with_retriever/common/config.go @@ -91,16 +91,13 @@ func LoadSettings() (*Settings, error) { PGTopK: getEnvInt("PGVECTOR_TOP_K", 0), // Example default 0.35 — sample KB often scores 0.3–0.6 per topic; 0.5 drops secondary docs on combined queries. 
PGMinScore: getEnvFloat("PGVECTOR_MIN_SCORE", 0.35), - EmbeddingModel: getEnv("EMBEDDING_MODEL", "text-embedding-3-small"), - EmbeddingBaseURL: strings.TrimSpace(getEnv("EMBEDDING_BASEURL", "")), - EmbeddingAPIKey: strings.TrimSpace(getEnv("EMBEDDING_APIKEY", "")), + EmbeddingModel: getEnv("EMBEDDING_OPENAI_MODEL", "text-embedding-3-small"), + EmbeddingBaseURL: strings.TrimSpace(getEnv("EMBEDDING_OPENAI_BASEURL", "")), + EmbeddingAPIKey: strings.TrimSpace(getEnv("EMBEDDING_OPENAI_APIKEY", "")), } if s.EmbeddingBaseURL == "" { s.EmbeddingBaseURL = strings.TrimSpace(getEnv("LLM_BASEURL", "https://api.openai.com/v1")) } - if s.EmbeddingAPIKey == "" { - s.EmbeddingAPIKey = strings.TrimSpace(getEnv("LLM_APIKEY", "")) - } return s, nil } diff --git a/examples/agent_with_retriever/common/embed_openai.go b/examples/agent_with_retriever/common/embed_openai.go index 90b453f..e98a0ce 100644 --- a/examples/agent_with_retriever/common/embed_openai.go +++ b/examples/agent_with_retriever/common/embed_openai.go @@ -19,11 +19,11 @@ func OpenAIEmbedFunc(settings *Settings) (pgretriever.EmbedFunc, error) { return nil, fmt.Errorf("embed: settings is nil") } if settings.EmbeddingAPIKey == "" { - return nil, fmt.Errorf("embed: EMBEDDING_APIKEY or LLM_APIKEY is required for pgvector") + return nil, fmt.Errorf("embed: EMBEDDING_OPENAI_APIKEY is required for pgvector") } model := strings.TrimSpace(settings.EmbeddingModel) if model == "" { - return nil, fmt.Errorf("embed: EMBEDDING_MODEL is required") + return nil, fmt.Errorf("embed: EMBEDDING_OPENAI_MODEL is required") } base := strings.TrimRight(strings.TrimSpace(settings.EmbeddingBaseURL), "/") client := &http.Client{Timeout: 60 * time.Second} diff --git a/examples/agent_with_retriever/common/embedding.go b/examples/agent_with_retriever/common/embedding.go index 2fd038d..fdea99e 100644 --- a/examples/agent_with_retriever/common/embedding.go +++ b/examples/agent_with_retriever/common/embedding.go @@ -1,35 +1,15 @@ package common -import ( 
- "fmt" - "os" - "strings" - - "github.com/agenticenv/agent-sdk-go/pkg/interfaces" -) +import "fmt" // ValidateEmbeddingConfig ensures pgvector can call an OpenAI-compatible embeddings API. -// When LLM_PROVIDER is not openai, EMBEDDING_APIKEY (or OPENAI_APIKEY) must be set explicitly. -func ValidateEmbeddingConfig(provider interfaces.LLMProvider, settings *Settings) error { +// Requires EMBEDDING_OPENAI_APIKEY (separate from chat LLM_APIKEY). +func ValidateEmbeddingConfig(settings *Settings) error { if settings == nil { return fmt.Errorf("settings is nil") } if settings.EmbeddingAPIKey == "" { - return fmt.Errorf("EMBEDDING_APIKEY or LLM_APIKEY is required for pgvector embeddings") - } - explicit := strings.TrimSpace(os.Getenv("EMBEDDING_APIKEY")) != "" || - strings.TrimSpace(os.Getenv("OPENAI_APIKEY")) != "" - if explicit { - return nil - } - switch provider { - case interfaces.LLMProviderOpenAI, "": - return nil - default: - return fmt.Errorf( - "pgvector embeddings need an OpenAI-compatible API key in EMBEDDING_APIKEY (or OPENAI_APIKEY); "+ - "LLM_PROVIDER=%s cannot use LLM_APIKEY for /embeddings", - provider, - ) + return fmt.Errorf("EMBEDDING_OPENAI_APIKEY is required for pgvector embeddings (OpenAI-compatible; separate from LLM_APIKEY)") } + return nil } diff --git a/examples/agent_with_retriever/common/opts.go b/examples/agent_with_retriever/common/opts.go index 5290c6c..e91f44b 100644 --- a/examples/agent_with_retriever/common/opts.go +++ b/examples/agent_with_retriever/common/opts.go @@ -3,17 +3,16 @@ package common import ( "fmt" + excfg "github.com/agenticenv/agent-sdk-go/examples" "github.com/agenticenv/agent-sdk-go/pkg/agent" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/logger" ) -// AgentOptions builds shared agent options: Temporal, LLM, retriever mode, and system prompt. +// AgentOptions builds shared agent options: runtime, LLM, retriever mode, and system prompt. 
// backendLabel is shown in the agent name/description (e.g. "weaviate" or "pgvector"). func AgentOptions( - host string, - port int, - namespace, taskQueue string, + cfg *excfg.Config, llmClient interfaces.LLMClient, log logger.Logger, settings *Settings, @@ -28,19 +27,14 @@ func AgentOptions( backendLabel, mode, ) - return []agent.Option{ + opts := []agent.Option{ agent.WithName(fmt.Sprintf("agent-with-retriever-%s", backendLabel)), agent.WithDescription(fmt.Sprintf("Agent with %s retriever (%s)", backendLabel, mode)), agent.WithSystemPrompt(prompt), - agent.WithTemporalConfig(&agent.TemporalConfig{ - Host: host, - Port: port, - Namespace: namespace, - TaskQueue: taskQueue, - }), agent.WithLLMClient(llmClient), agent.WithLogger(log), agent.WithRetrieverMode(mode), agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), } + return append(opts, excfg.RuntimeOption(cfg)...) } diff --git a/examples/agent_with_retriever/pgvector/README.md b/examples/agent_with_retriever/pgvector/README.md deleted file mode 100644 index c33c84a..0000000 --- a/examples/agent_with_retriever/pgvector/README.md +++ /dev/null @@ -1,136 +0,0 @@ -# pgvector retriever example - -This program uses [`pkg/retriever/pgvector`](../../../pkg/retriever/pgvector): queries are embedded with an **OpenAI-compatible API**, then searched in PostgreSQL with [**pgvector**](https://github.com/pgvector/pgvector). - -Parent overview: [`../README.md`](../README.md). - -## Quick setup - -```bash -cd examples/agent_with_retriever/pgvector -chmod +x setup.sh cleanup.sh verify.sh -./setup.sh -``` - -Requires **Docker**, **curl**, **jq**, and an OpenAI-compatible key for embeddings (`EMBEDDING_APIKEY`, `OPENAI_APIKEY`, or `LLM_APIKEY` in `examples/.env`). - -**[`setup.sh`](setup.sh)** starts Postgres, applies [`setup.sql`](setup.sql), embeds [`../common/sample-documents.json`](../common/sample-documents.json), and prints `PGVECTOR_DSN` for `.env`. 
- -```bash -./cleanup.sh # when finished -``` - -## Configure `.env` - -From `examples/` (after `./setup.sh`): - -```bash -# Temporal + LLM (required) -LLM_APIKEY=sk-... -LLM_MODEL=gpt-4o - -# Postgres -PGVECTOR_DSN=postgres://postgres:secret@localhost:5432/vectordb?sslmode=disable -PGVECTOR_TABLE=documents -PGVECTOR_RETRIEVER_NAME=pgvector-kb - -# Embeddings (must match ./setup.sh) -EMBEDDING_MODEL=text-embedding-3-small -EMBEDDING_APIKEY=sk-... # required when LLM_PROVIDER is not openai -# PGVECTOR_MIN_SCORE=0.35 # example default; see env.sample - -# Optional: agentic | prefetch | hybrid -RETRIEVER_MODE=agentic -``` - -Embeddings use **OpenAI** (or `EMBEDDING_*`). Chat can use another provider (e.g. Anthropic). - -## Run the example - -```bash -cd examples -go run ./agent_with_retriever/pgvector "What is the return policy?" -go run ./agent_with_retriever/pgvector "How long does standard shipping take in the US?" - -RETRIEVER_MODE=prefetch go run ./agent_with_retriever/pgvector "What is the return policy?" - -RETRIEVER_MODE=hybrid go run ./agent_with_retriever/pgvector "What are Pro and Enterprise support hours?" -``` - -Sample prompts match the customer-support articles in [`../common/sample-documents.json`](../common/sample-documents.json) (returns, shipping, warranty, support hours, etc.). - -## Verify search (optional) - -```bash -./verify.sh "What is the return policy?" -``` - -Shows row count and similarity scores without running the agent. - -## Troubleshooting - -### `no relevant documents found` - -The retriever ran but no rows passed the similarity filter. - -1. Check data and scores: - ```bash - ./verify.sh "What is the return policy?" - ``` - - **`COUNT` is 0** → run `./setup.sh` again, or fix `PGVECTOR_DSN` in `examples/.env`. - - **Rows exist but low `score`** → lower the threshold in `examples/.env`: - ```bash - PGVECTOR_MIN_SCORE=0.35 - ``` - Re-run the example (startup line shows `minScore: 0.35`). - -2. 
**Embeddings key** — search uses OpenAI `/embeddings`, not your chat LLM. If `LLM_PROVIDER=anthropic` or `gemini`, set: - ```bash - EMBEDDING_APIKEY=sk-... - EMBEDDING_BASEURL=https://api.openai.com/v1 - ``` - Re-run `./setup.sh` so stored vectors use the same model as queries. - -3. **Prefetch / hybrid** — your full user message is embedded as the search query. Use a concrete question (e.g. *“What is the return policy?”*), not *“Summarize the knowledge base”*. - -### `embedding config: ... LLM_PROVIDER=anthropic` - -Set `EMBEDDING_APIKEY` (OpenAI-compatible) in `examples/.env`. `LLM_APIKEY` alone is not enough when chat uses Anthropic/Gemini. - -### `PGVECTOR_DSN is required` - -Copy `PGVECTOR_DSN` from `./setup.sh` output into `examples/.env`. - -### `dimension mismatch` or SQL errors - -`EMBEDDING_MODEL` must match `vector(1536)` in `setup.sql` (default `text-embedding-3-small`). After changing the model, `./cleanup.sh`, `./setup.sh`, and update `.env`. - -### Connection / port errors - -```bash -./cleanup.sh -./setup.sh -docker logs pgvector -docker ps -``` - -Port **5432** already in use → stop other Postgres or set `PGVECTOR_PORT` and update `PGVECTOR_DSN`. - -### Weak or incomplete answers (prefetch) - -Only documents above `PGVECTOR_MIN_SCORE` are injected. Run `./verify.sh` with your exact prompt; lower `PGVECTOR_MIN_SCORE` if needed docs are below the threshold. - -### Debug logs - -```bash -LOG_LEVEL=debug go run ./agent_with_retriever/pgvector "What is the return policy?" -``` - -Look for `pgvector search done` with `docs=0` vs embed/query errors. - -### Clean reset - -```bash -./cleanup.sh && ./setup.sh -./verify.sh "What is the return policy?" 
-``` diff --git a/examples/agent_with_retriever/pgvector/cleanup.sh b/examples/agent_with_retriever/pgvector/cleanup.sh deleted file mode 100755 index 79f6feb..0000000 --- a/examples/agent_with_retriever/pgvector/cleanup.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env bash -# Stop and remove the local pgvector Postgres Docker container for this example. -# -# Usage (from this directory): -# ./cleanup.sh -# -# Environment: -# PGVECTOR_CONTAINER_NAME default pgvector -set -euo pipefail - -CONTAINER_NAME="${PGVECTOR_CONTAINER_NAME:-pgvector}" - -if ! command -v docker >/dev/null 2>&1; then - echo "error: docker is required but not installed" >&2 - exit 1 -fi - -if docker ps -a --format '{{.Names}}' | grep -qx "$CONTAINER_NAME"; then - echo "Stopping and removing '${CONTAINER_NAME}'..." - docker rm -f "$CONTAINER_NAME" >/dev/null - echo "Done." -else - echo "No container named '${CONTAINER_NAME}' found." -fi diff --git a/examples/agent_with_retriever/pgvector/main.go b/examples/agent_with_retriever/pgvector/main.go index cad1b05..86d39d6 100644 --- a/examples/agent_with_retriever/pgvector/main.go +++ b/examples/agent_with_retriever/pgvector/main.go @@ -4,7 +4,7 @@ // // go run ./examples/agent_with_retriever/pgvector "What do you know about our docs?" // -// See ../README.md and ./README.md for Postgres/pgvector setup and env vars. +// See ../README.md for setup and env vars. 
package main import ( @@ -27,9 +27,9 @@ func main() { log.Fatalf("retriever config: %v", err) } if retrieverCfg.PGDSN == "" { - log.Fatal("PGVECTOR_DSN is required for the pgvector example; see ./README.md") + log.Fatal("PGVECTOR_DSN is required for the pgvector example; see ../README.md") } - if err := common.ValidateEmbeddingConfig(cfg.Provider, retrieverCfg); err != nil { + if err := common.ValidateEmbeddingConfig(retrieverCfg); err != nil { log.Fatalf("embedding config: %v", err) } @@ -62,10 +62,7 @@ func main() { log.Fatalf("pgvector retriever: %v", err) } - opts := common.AgentOptions( - cfg.Host, cfg.Port, cfg.Namespace, cfg.TaskQueue, - llmClient, logr, retrieverCfg, "pgvector", - ) + opts := common.AgentOptions(cfg, llmClient, logr, retrieverCfg, "pgvector") opts = append(opts, agent.WithRetrievers(retriever)) a, err := agent.NewAgent(opts...) diff --git a/examples/agent_with_retriever/pgvector/setup.sh b/examples/agent_with_retriever/pgvector/setup.sh deleted file mode 100755 index 68711e9..0000000 --- a/examples/agent_with_retriever/pgvector/setup.sh +++ /dev/null @@ -1,223 +0,0 @@ -#!/usr/bin/env bash -# One-shot pgvector setup for the agent_with_retriever/pgvector example: -# - starts PostgreSQL with pgvector in Docker (or reuses existing container) -# - waits until Postgres is ready -# - applies setup.sql (extension, table, index) -# - embeds sample-documents.json via OpenAI-compatible API and inserts rows -# -# Usage (from this directory): -# ./setup.sh -# -# Teardown: ./cleanup.sh -# -# Environment: -# OPENAI_APIKEY / LLM_APIKEY from env or examples/.env (required for embeddings) -# EMBEDDING_MODEL default text-embedding-3-small (1536 dimensions) -# EMBEDDING_BASEURL default LLM_BASEURL from .env or https://api.openai.com/v1 -# PGVECTOR_CONTAINER_NAME default pgvector -# PGVECTOR_PORT default 5432 -set -euo pipefail - -CONTAINER_NAME="${PGVECTOR_CONTAINER_NAME:-pgvector}" -PG_IMAGE="${PGVECTOR_IMAGE:-pgvector/pgvector:pg16}" 
-PG_PORT="${PGVECTOR_PORT:-5432}" -PG_USER="${PGVECTOR_USER:-postgres}" -PG_PASSWORD="${PGVECTOR_PASSWORD:-secret}" -PG_DB="${PGVECTOR_DB:-vectordb}" -PG_TABLE="${PGVECTOR_TABLE:-documents}" -EMBEDDING_MODEL="${EMBEDDING_MODEL:-text-embedding-3-small}" -READY_TIMEOUT_SEC="${PGVECTOR_READY_TIMEOUT_SEC:-120}" - -PG_DSN="postgres://${PG_USER}:${PG_PASSWORD}@localhost:${PG_PORT}/${PG_DB}?sslmode=disable" - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -ENV_FILE="${SCRIPT_DIR}/../../.env" -DOCS_FILE="${SCRIPT_DIR}/../common/sample-documents.json" -SQL_FILE="${SCRIPT_DIR}/setup.sql" - -read_env_value() { - local key="$1" file="$2" - [[ -f "$file" ]] || return 1 - local line - line="$(grep -E "^${key}=" "$file" | tail -1 || true)" - [[ -n "$line" ]] || return 1 - line="${line#${key}=}" - line="${line%$'\r'}" - line="${line#\"}"; line="${line%\"}" - line="${line#\'}"; line="${line%\'}" - printf '%s' "$line" -} - -require_cmd() { - if ! command -v "$1" >/dev/null 2>&1; then - echo "error: '$1' is required but not installed" >&2 - exit 1 - fi -} - -sql_escape() { - printf '%s' "$1" | sed "s/'/''/g" -} - -resolve_openai_api_key() { - if [[ -n "${OPENAI_APIKEY:-}" ]]; then - return 0 - fi - if key="$(read_env_value OPENAI_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then - export OPENAI_APIKEY="$key" - echo "Using OPENAI_APIKEY from ${ENV_FILE}" - return 0 - fi - if key="$(read_env_value EMBEDDING_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then - export OPENAI_APIKEY="$key" - echo "Using EMBEDDING_APIKEY from ${ENV_FILE}" - return 0 - fi - if key="$(read_env_value LLM_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then - export OPENAI_APIKEY="$key" - echo "Using LLM_APIKEY from ${ENV_FILE} for embeddings" - return 0 - fi - echo "error: set OPENAI_APIKEY / EMBEDDING_APIKEY / LLM_APIKEY for embedding seed data" >&2 - exit 1 -} - -resolve_embedding_base_url() { - if [[ -n "${EMBEDDING_BASEURL:-}" ]]; then - return 0 - fi - if 
url="$(read_env_value EMBEDDING_BASEURL "$ENV_FILE" 2>/dev/null)" && [[ -n "$url" ]]; then - export EMBEDDING_BASEURL="$url" - return 0 - fi - if url="$(read_env_value LLM_BASEURL "$ENV_FILE" 2>/dev/null)" && [[ -n "$url" ]]; then - export EMBEDDING_BASEURL="$url" - return 0 - fi - export EMBEDDING_BASEURL="https://api.openai.com/v1" -} - -container_running() { - docker ps --format '{{.Names}}' | grep -qx "$CONTAINER_NAME" -} - -container_exists() { - docker ps -a --format '{{.Names}}' | grep -qx "$CONTAINER_NAME" -} - -wait_for_postgres() { - local deadline=$((SECONDS + READY_TIMEOUT_SEC)) - echo "Waiting for Postgres in '${CONTAINER_NAME}' (timeout ${READY_TIMEOUT_SEC}s)..." - while (( SECONDS < deadline )); do - if docker exec "$CONTAINER_NAME" pg_isready -U "$PG_USER" -d "$PG_DB" >/dev/null 2>&1; then - echo "Postgres is ready." - return 0 - fi - sleep 2 - done - echo "error: Postgres did not become ready within ${READY_TIMEOUT_SEC}s" >&2 - echo "Check logs: docker logs ${CONTAINER_NAME}" >&2 - exit 1 -} - -start_postgres() { - if container_running; then - echo "Container '${CONTAINER_NAME}' is already running." - return 0 - fi - if container_exists; then - echo "Starting existing container '${CONTAINER_NAME}'..." - docker start "$CONTAINER_NAME" >/dev/null - return 0 - fi - - echo "Creating and starting '${CONTAINER_NAME}' (${PG_IMAGE})..." - docker run -d --name "$CONTAINER_NAME" \ - -e POSTGRES_PASSWORD="$PG_PASSWORD" \ - -e POSTGRES_DB="$PG_DB" \ - -p "${PG_PORT}:5432" \ - "$PG_IMAGE" >/dev/null -} - -apply_schema() { - if [[ ! -f "$SQL_FILE" ]]; then - echo "error: missing ${SQL_FILE}" >&2 - exit 1 - fi - echo "Applying schema from setup.sql..." 
- docker exec -i "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -v ON_ERROR_STOP=1 < "$SQL_FILE" -} - -embed_text() { - local text="$1" - local body response - body=$(jq -n --arg input "$text" --arg model "$EMBEDDING_MODEL" '{input: $input, model: $model}') - response=$(curl -sf "${EMBEDDING_BASEURL%/}/embeddings" \ - -H "Authorization: Bearer ${OPENAI_APIKEY}" \ - -H "Content-Type: application/json" \ - -d "$body") - echo "$response" | jq -c '.data[0].embedding' -} - -seed_documents() { - if [[ ! -f "$DOCS_FILE" ]]; then - echo "error: missing ${DOCS_FILE}" >&2 - exit 1 - fi - - echo "Clearing existing rows in ${PG_TABLE}..." - docker exec "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -v ON_ERROR_STOP=1 \ - -c "TRUNCATE ${PG_TABLE} RESTART IDENTITY;" >/dev/null - - local count=0 row content source vec content_sql source_sql - while IFS= read -r row; do - content="$(echo "$row" | jq -r '.content')" - source="$(echo "$row" | jq -r '.source')" - echo "Embedding document $((count + 1)): ${source}" - vec="$(embed_text "$content")" - if [[ -z "$vec" || "$vec" == "null" ]]; then - echo "error: empty embedding for ${source}" >&2 - exit 1 - fi - content_sql="$(sql_escape "$content")" - source_sql="$(sql_escape "$source")" - docker exec "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -v ON_ERROR_STOP=1 \ - -c "INSERT INTO ${PG_TABLE} (content, source, embedding) VALUES ('${content_sql}', '${source_sql}', '${vec}'::vector);" >/dev/null - count=$((count + 1)) - done < <(jq -c '.[]' "$DOCS_FILE") - - echo "Inserted ${count} documents from sample-documents.json" -} - -require_cmd docker -require_cmd curl -require_cmd jq -resolve_openai_api_key -resolve_embedding_base_url -start_postgres -wait_for_postgres -apply_schema -seed_documents - -cat </dev/null 2>&1 || { echo "error: need $1" >&2; exit 1; } -} - -resolve_openai_api_key() { - [[ -n "${OPENAI_APIKEY:-}" ]] && return 0 - if key="$(read_env_value OPENAI_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then - 
export OPENAI_APIKEY="$key"; return 0 - fi - if key="$(read_env_value EMBEDDING_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then - export OPENAI_APIKEY="$key"; return 0 - fi - if key="$(read_env_value LLM_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then - export OPENAI_APIKEY="$key"; return 0 - fi - echo "error: set OPENAI_APIKEY or EMBEDDING_APIKEY" >&2 - exit 1 -} - -resolve_embedding_base_url() { - [[ -n "${EMBEDDING_BASEURL:-}" ]] && return 0 - if url="$(read_env_value EMBEDDING_BASEURL "$ENV_FILE" 2>/dev/null)" && [[ -n "$url" ]]; then - export EMBEDDING_BASEURL="$url"; return 0 - fi - if url="$(read_env_value LLM_BASEURL "$ENV_FILE" 2>/dev/null)" && [[ -n "$url" ]]; then - export EMBEDDING_BASEURL="$url"; return 0 - fi - export EMBEDDING_BASEURL="https://api.openai.com/v1" -} - -embed_text() { - local text="$1" body response - body=$(jq -n --arg input "$text" --arg model "${EMBEDDING_MODEL:-text-embedding-3-small}" \ - '{input: $input, model: $model}') - response=$(curl -sf "${EMBEDDING_BASEURL%/}/embeddings" \ - -H "Authorization: Bearer ${OPENAI_APIKEY}" \ - -H "Content-Type: application/json" \ - -d "$body") - echo "$response" | jq -c '.data[0].embedding' -} - -require_cmd docker -require_cmd curl -require_cmd jq -resolve_openai_api_key -resolve_embedding_base_url - -echo "=== row count ===" -docker exec "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -t -c "SELECT COUNT(*) FROM ${PG_TABLE};" - -echo "=== top matches (no min_score filter) for: ${QUERY} ===" -vec="$(embed_text "$QUERY")" -docker exec "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -c \ - "SELECT source, LEFT(content, 60) AS preview, - ROUND((1 - (embedding <=> '${vec}'::vector))::numeric, 4) AS score - FROM ${PG_TABLE} - ORDER BY embedding <=> '${vec}'::vector - LIMIT 5;" - -echo "" -echo "If expected docs are missing, lower PGVECTOR_MIN_SCORE in examples/.env (example default 0.35)" -echo "If COUNT is 0, re-run ./setup.sh" diff --git 
a/examples/agent_with_retriever/weaviate/README.md b/examples/agent_with_retriever/weaviate/README.md deleted file mode 100644 index 0deb56a..0000000 --- a/examples/agent_with_retriever/weaviate/README.md +++ /dev/null @@ -1,136 +0,0 @@ -# Weaviate retriever example - -This program uses [`pkg/retriever/weaviate`](../../../pkg/retriever/weaviate): Weaviate embeds queries via **nearText** (no client-side embedding). - -Parent overview: [`../README.md`](../README.md). - -## Quick setup - -```bash -cd examples/agent_with_retriever/weaviate -chmod +x setup.sh cleanup.sh -export OPENAI_APIKEY=sk-your-key # or set in examples/.env -./setup.sh -``` - -Requires **Docker**, **curl**, **jq**, and an OpenAI API key for Weaviate’s `text2vec-openai` module. - -**[`setup.sh`](setup.sh)** starts Weaviate, creates the schema, and loads [`../common/sample-documents.json`](../common/sample-documents.json). - -```bash -./cleanup.sh # when finished -``` - -## Configure `.env` - -From `examples/`: - -```bash -# Temporal + LLM (required) -LLM_APIKEY=sk-... -LLM_MODEL=gpt-4o - -# Weaviate (defaults shown) -WEAVIATE_HOST=localhost:8080 -WEAVIATE_SCHEME=http -WEAVIATE_CLASS=Document -WEAVIATE_RETRIEVER_NAME=weaviate-kb - -# Optional: agentic | prefetch | hybrid -RETRIEVER_MODE=agentic -``` - -Weaviate uses **OpenAI** inside Docker for vectors. Chat can use another provider (e.g. Anthropic). - -## Run the example - -```bash -cd examples -go run ./agent_with_retriever/weaviate "What is the return policy?" -go run ./agent_with_retriever/weaviate "How long does standard shipping take in the US?" - -RETRIEVER_MODE=prefetch go run ./agent_with_retriever/weaviate "What is the return policy?" - -RETRIEVER_MODE=hybrid go run ./agent_with_retriever/weaviate "What are Pro and Enterprise support hours?" -``` - -Prompts match articles in [`../common/sample-documents.json`](../common/sample-documents.json). 
- -## Troubleshooting - -### `OPENAI_APIKEY` error from setup.sh - -Weaviate’s `text2vec-openai` module needs an OpenAI key in the container: - -```bash -export OPENAI_APIKEY=sk-your-key -./setup.sh -``` - -Or add `OPENAI_APIKEY` / `LLM_APIKEY` to `examples/.env` before running `./setup.sh`. - -### Connection refused on `:8080` - -Weaviate is not running or `WEAVIATE_HOST` is wrong. - -```bash -docker ps -./setup.sh -curl -s http://localhost:8080/v1/.well-known/ready -``` - -### Empty search or `no relevant documents found` - -1. Re-seed the sample KB: `./setup.sh` -2. Confirm class name matches `.env`: `WEAVIATE_CLASS=Document` -3. Optional: lower certainty — `WEAVIATE_MIN_SCORE=0.5` in `.env` (SDK default is **0.75**) -4. List objects: - ```bash - curl -s "http://localhost:8080/v1/objects?class=Document&limit=5" - ``` - -### Vectorizer / OpenAI errors in logs - -`OPENAI_APIKEY` must be set when the container starts. Fix and recreate: - -```bash -./cleanup.sh -export OPENAI_APIKEY=sk-your-key -./setup.sh -docker logs weaviate -``` - -### Port already in use (`8080` or `50051`) - -Another process or old container is using the port: - -```bash -./cleanup.sh -./setup.sh -``` - -### Answers ignore the knowledge base - -- Run `./setup.sh` and confirm objects exist (curl above). -- **Agentic mode** — LLM must call `retriever_weaviate-kb`; try **prefetch** to force retrieval: - ```bash - RETRIEVER_MODE=prefetch go run ./agent_with_retriever/weaviate "What is the return policy?" - ``` -- Check `WEAVIATE_HOST`, `WEAVIATE_CLASS`, and `WEAVIATE_SCHEME` in `.env`. - -### Prefetch / hybrid returns little context - -Prefetch searches with your **exact user message**. Use topic questions aligned with the sample KB (returns, shipping, warranty, etc.). - -### Debug logs - -```bash -LOG_LEVEL=debug go run ./agent_with_retriever/weaviate "What is the return policy?" 
-docker logs weaviate -``` - -### Clean reset - -```bash -./cleanup.sh && ./setup.sh -``` diff --git a/examples/agent_with_retriever/weaviate/cleanup.sh b/examples/agent_with_retriever/weaviate/cleanup.sh deleted file mode 100755 index 53d5e18..0000000 --- a/examples/agent_with_retriever/weaviate/cleanup.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env bash -# Stop and remove the local Weaviate Docker container for this example. -# -# Usage (from this directory): -# ./cleanup.sh -# -# Environment: -# WEAVIATE_CONTAINER_NAME default weaviate -set -euo pipefail - -CONTAINER_NAME="${WEAVIATE_CONTAINER_NAME:-weaviate}" - -if ! command -v docker >/dev/null 2>&1; then - echo "error: docker is required but not installed" >&2 - exit 1 -fi - -if docker ps -a --format '{{.Names}}' | grep -qx "$CONTAINER_NAME"; then - echo "Stopping and removing '${CONTAINER_NAME}'..." - docker rm -f "$CONTAINER_NAME" >/dev/null - echo "Done." -else - echo "No container named '${CONTAINER_NAME}' found." -fi diff --git a/examples/agent_with_retriever/weaviate/main.go b/examples/agent_with_retriever/weaviate/main.go index 5c948c2..614766c 100644 --- a/examples/agent_with_retriever/weaviate/main.go +++ b/examples/agent_with_retriever/weaviate/main.go @@ -4,7 +4,7 @@ // // go run ./examples/agent_with_retriever/weaviate "What do you know about our docs?" // -// See ../README.md and ./README.md for Weaviate setup and env vars. +// See ../README.md for setup and env vars. package main import ( @@ -53,10 +53,7 @@ func main() { log.Fatalf("weaviate retriever: %v", err) } - opts := common.AgentOptions( - cfg.Host, cfg.Port, cfg.Namespace, cfg.TaskQueue, - llmClient, logr, retrieverCfg, "weaviate", - ) + opts := common.AgentOptions(cfg, llmClient, logr, retrieverCfg, "weaviate") opts = append(opts, agent.WithRetrievers(retriever)) a, err := agent.NewAgent(opts...) 
diff --git a/examples/agent_with_retriever/weaviate/setup.sh b/examples/agent_with_retriever/weaviate/setup.sh deleted file mode 100755 index ff91d84..0000000 --- a/examples/agent_with_retriever/weaviate/setup.sh +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env bash -# One-shot Weaviate setup for the agent_with_retriever/weaviate example: -# - starts Docker (or reuses an existing weaviate container) -# - waits until the API is ready -# - creates schema and loads sample-documents.json -# -# Usage (from this directory): -# ./setup.sh -# -# Teardown: ./cleanup.sh -# -# Environment: -# OPENAI_APIKEY required for text2vec-openai (falls back to LLM_APIKEY from examples/.env) -# WEAVIATE_URL default http://localhost:8080 -# WEAVIATE_CLASS default Document -set -euo pipefail - -CONTAINER_NAME="${WEAVIATE_CONTAINER_NAME:-weaviate}" -WEAVIATE_IMAGE="${WEAVIATE_IMAGE:-cr.weaviate.io/semitechnologies/weaviate:1.27.0}" -WEAVIATE_HTTP_PORT="${WEAVIATE_HTTP_PORT:-8080}" -WEAVIATE_GRPC_PORT="${WEAVIATE_GRPC_PORT:-50051}" -WEAVIATE_URL="${WEAVIATE_URL:-http://localhost:${WEAVIATE_HTTP_PORT}}" -WEAVIATE_CLASS="${WEAVIATE_CLASS:-Document}" -READY_TIMEOUT_SEC="${WEAVIATE_READY_TIMEOUT_SEC:-120}" - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -ENV_FILE="${SCRIPT_DIR}/../../.env" -DOCS_FILE="${SCRIPT_DIR}/../common/sample-documents.json" - -read_env_value() { - local key="$1" file="$2" - [[ -f "$file" ]] || return 1 - local line - line="$(grep -E "^${key}=" "$file" | tail -1 || true)" - [[ -n "$line" ]] || return 1 - line="${line#${key}=}" - line="${line%$'\r'}" - line="${line#\"}"; line="${line%\"}" - line="${line#\'}"; line="${line%\'}" - printf '%s' "$line" -} - -require_cmd() { - if ! 
command -v "$1" >/dev/null 2>&1; then - echo "error: '$1' is required but not installed" >&2 - exit 1 - fi -} - -resolve_openai_api_key() { - if [[ -n "${OPENAI_APIKEY:-}" ]]; then - return 0 - fi - if key="$(read_env_value OPENAI_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then - export OPENAI_APIKEY="$key" - echo "Using OPENAI_APIKEY from ${ENV_FILE}" - return 0 - fi - if key="$(read_env_value LLM_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then - export OPENAI_APIKEY="$key" - echo "Using LLM_APIKEY from ${ENV_FILE} for Weaviate text2vec-openai" - return 0 - fi - echo "error: set OPENAI_APIKEY (Weaviate vectorizer) or add OPENAI_APIKEY / LLM_APIKEY to ${ENV_FILE}" >&2 - exit 1 -} - -container_running() { - docker ps --format '{{.Names}}' | grep -qx "$CONTAINER_NAME" -} - -container_exists() { - docker ps -a --format '{{.Names}}' | grep -qx "$CONTAINER_NAME" -} - -wait_for_ready() { - local deadline=$((SECONDS + READY_TIMEOUT_SEC)) - echo "Waiting for Weaviate at ${WEAVIATE_URL} (timeout ${READY_TIMEOUT_SEC}s)..." - while (( SECONDS < deadline )); do - if curl -sf "${WEAVIATE_URL}/v1/.well-known/ready" >/dev/null 2>&1; then - echo "Weaviate is ready." - return 0 - fi - sleep 2 - done - echo "error: Weaviate did not become ready within ${READY_TIMEOUT_SEC}s" >&2 - echo "Check logs: docker logs ${CONTAINER_NAME}" >&2 - exit 1 -} - -start_weaviate() { - if container_running; then - echo "Container '${CONTAINER_NAME}' is already running." - return 0 - fi - if container_exists; then - echo "Starting existing container '${CONTAINER_NAME}'..." - docker start "$CONTAINER_NAME" >/dev/null - return 0 - fi - - echo "Creating and starting '${CONTAINER_NAME}' (${WEAVIATE_IMAGE})..." 
- docker run -d --name "$CONTAINER_NAME" \ - -p "${WEAVIATE_HTTP_PORT}:8080" \ - -p "${WEAVIATE_GRPC_PORT}:50051" \ - -e AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true \ - -e DEFAULT_VECTORIZER_MODULE=text2vec-openai \ - -e ENABLE_MODULES=text2vec-openai \ - -e OPENAI_APIKEY="${OPENAI_APIKEY}" \ - "$WEAVIATE_IMAGE" >/dev/null -} - -seed_documents() { - if [[ ! -f "$DOCS_FILE" ]]; then - echo "error: missing ${DOCS_FILE}" >&2 - exit 1 - fi - - echo "Creating class ${WEAVIATE_CLASS} at ${WEAVIATE_URL} (ignored if it already exists)..." - curl -sf -X POST "${WEAVIATE_URL}/v1/schema" \ - -H 'Content-Type: application/json' \ - -d "{ - \"class\": \"${WEAVIATE_CLASS}\", - \"vectorizer\": \"text2vec-openai\", - \"properties\": [ - {\"name\": \"content\", \"dataType\": [\"text\"]}, - {\"name\": \"source\", \"dataType\": [\"text\"]} - ] - }" >/dev/null || true - - local count=0 row payload - while IFS= read -r row; do - payload=$(jq -n \ - --arg class "$WEAVIATE_CLASS" \ - --arg content "$(echo "$row" | jq -r '.content')" \ - --arg source "$(echo "$row" | jq -r '.source')" \ - '{class: $class, properties: {content: $content, source: $source}}') - curl -sf -X POST "${WEAVIATE_URL}/v1/objects" \ - -H 'Content-Type: application/json' \ - -d "$payload" >/dev/null - count=$((count + 1)) - done < <(jq -c '.[]' "$DOCS_FILE") - - echo "Inserted ${count} documents from sample-documents.json" -} - -require_cmd docker -require_cmd curl -require_cmd jq -resolve_openai_api_key -start_weaviate -wait_for_ready -seed_documents - -cat </dev/null 2>&1 || exit 1"] + interval: 2s + timeout: 5s + retries: 30 + start_period: 10s + + pgvector: + container_name: pgvector + image: pgvector/pgvector:pg16 + ports: + - "${PGVECTOR_PORT:-5432}:5432" + environment: + POSTGRES_USER: ${PGVECTOR_USER:-postgres} + POSTGRES_PASSWORD: ${PGVECTOR_PASSWORD:-secret} + POSTGRES_DB: ${PGVECTOR_DB:-vectordb} + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${PGVECTOR_USER:-postgres} -d ${PGVECTOR_DB:-vectordb}"] + 
interval: 2s + timeout: 5s + retries: 30 + start_period: 5s diff --git a/examples/agent_with_retriever/common/sample-documents.json b/examples/docker/pgvector/sample-documents.json similarity index 100% rename from examples/agent_with_retriever/common/sample-documents.json rename to examples/docker/pgvector/sample-documents.json diff --git a/examples/docker/pgvector/seed.sh b/examples/docker/pgvector/seed.sh new file mode 100755 index 0000000..48f0aeb --- /dev/null +++ b/examples/docker/pgvector/seed.sh @@ -0,0 +1,126 @@ +#!/usr/bin/env bash +# Seed pgvector after compose service is up (schema + sample documents). +set -euo pipefail + +CONTAINER_NAME="${PGVECTOR_CONTAINER_NAME:-pgvector}" +PG_USER="${PGVECTOR_USER:-postgres}" +PG_PASSWORD="${PGVECTOR_PASSWORD:-secret}" +PG_DB="${PGVECTOR_DB:-vectordb}" +PG_TABLE="${PGVECTOR_TABLE:-documents}" +EMBEDDING_OPENAI_MODEL="${EMBEDDING_OPENAI_MODEL:-text-embedding-3-small}" +READY_TIMEOUT_SEC="${PGVECTOR_READY_TIMEOUT_SEC:-120}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXAMPLES_DIR="${SCRIPT_DIR}/../.." 
+ENV_FILE="${EXAMPLES_DIR}/.env" +ROOT_ENV_FILE="${EXAMPLES_DIR}/../.env" +DOCS_FILE="${SCRIPT_DIR}/sample-documents.json" +SQL_FILE="${SCRIPT_DIR}/setup.sql" + +read_env_value() { + local key="$1" file="$2" + [[ -f "$file" ]] || return 1 + local line + line="$(grep -E "^${key}=" "$file" | tail -1 || true)" + [[ -n "$line" ]] || return 1 + line="${line#${key}=}" + line="${line%$'\r'}" + line="${line#\"}"; line="${line%\"}" + line="${line#\'}"; line="${line%\'}" + printf '%s' "$line" +} + +require_cmd() { + command -v "$1" >/dev/null || { echo "error: '$1' is required" >&2; exit 1; } +} + +sql_escape() { + printf '%s' "$1" | sed "s/'/''/g" +} + +resolve_embedding_api_key() { + local f key + if [[ -n "${EMBEDDING_OPENAI_APIKEY:-}" ]]; then + return 0 + fi + for f in "$ENV_FILE" "$ROOT_ENV_FILE"; do + if key="$(read_env_value EMBEDDING_OPENAI_APIKEY "$f" 2>/dev/null)" && [[ -n "$key" ]]; then + export EMBEDDING_OPENAI_APIKEY="$key" + return 0 + fi + done + echo "error: EMBEDDING_OPENAI_APIKEY is required (examples/.env or environment)" >&2 + exit 1 +} + +resolve_embedding_base_url() { + if [[ -n "${EMBEDDING_OPENAI_BASEURL:-}" ]]; then + return 0 + fi + if url="$(read_env_value EMBEDDING_OPENAI_BASEURL "$ENV_FILE" 2>/dev/null)" && [[ -n "$url" ]]; then + export EMBEDDING_OPENAI_BASEURL="$url" + return 0 + fi + if url="$(read_env_value LLM_BASEURL "$ENV_FILE" 2>/dev/null)" && [[ -n "$url" ]]; then + export EMBEDDING_OPENAI_BASEURL="$url" + return 0 + fi + export EMBEDDING_OPENAI_BASEURL="https://api.openai.com/v1" +} + +wait_for_postgres() { + local deadline=$((SECONDS + READY_TIMEOUT_SEC)) + echo "Waiting for Postgres in '${CONTAINER_NAME}'..." + while (( SECONDS < deadline )); do + if docker exec "$CONTAINER_NAME" pg_isready -U "$PG_USER" -d "$PG_DB" >/dev/null 2>&1; then + echo "Postgres is ready." 
+ return 0 + fi + sleep 2 + done + echo "error: Postgres not ready within ${READY_TIMEOUT_SEC}s (docker logs pgvector)" >&2 + exit 1 +} + +embed_text() { + local text="$1" body response + body=$(jq -n --arg input "$text" --arg model "$EMBEDDING_OPENAI_MODEL" '{input: $input, model: $model}') + response=$(curl -sf "${EMBEDDING_OPENAI_BASEURL%/}/embeddings" \ + -H "Authorization: Bearer ${EMBEDDING_OPENAI_APIKEY}" \ + -H "Content-Type: application/json" \ + -d "$body") + echo "$response" | jq -c '.data[0].embedding' +} + +require_cmd docker +require_cmd curl +require_cmd jq +resolve_embedding_api_key +resolve_embedding_base_url +wait_for_postgres + +echo "Applying schema from setup.sql..." +docker exec -i "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -v ON_ERROR_STOP=1 < "$SQL_FILE" + +echo "Clearing existing rows in ${PG_TABLE}..." +docker exec "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -v ON_ERROR_STOP=1 \ + -c "TRUNCATE ${PG_TABLE} RESTART IDENTITY;" >/dev/null + +count=0 +while IFS= read -r row; do + content="$(echo "$row" | jq -r '.content')" + source="$(echo "$row" | jq -r '.source')" + echo "Embedding document $((count + 1)): ${source}" + vec="$(embed_text "$content")" + if [[ -z "$vec" || "$vec" == "null" ]]; then + echo "error: empty embedding for ${source}" >&2 + exit 1 + fi + content_sql="$(sql_escape "$content")" + source_sql="$(sql_escape "$source")" + docker exec "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -v ON_ERROR_STOP=1 \ + -c "INSERT INTO ${PG_TABLE} (content, source, embedding) VALUES ('${content_sql}', '${source_sql}', '${vec}'::vector);" >/dev/null + count=$((count + 1)) +done < <(jq -c '.[]' "$DOCS_FILE") + +echo "Inserted ${count} documents from sample-documents.json" diff --git a/examples/agent_with_retriever/pgvector/setup.sql b/examples/docker/pgvector/setup.sql similarity index 65% rename from examples/agent_with_retriever/pgvector/setup.sql rename to examples/docker/pgvector/setup.sql index 4335b23..4dac62e 100644 --- 
a/examples/agent_with_retriever/pgvector/setup.sql +++ b/examples/docker/pgvector/setup.sql @@ -1,5 +1,5 @@ --- Optional DDL for the pgvector example (run against PGVECTOR_DSN database). --- Embedding dimension must match EMBEDDING_MODEL (text-embedding-3-small → 1536). +-- Schema for the pgvector example (applied by docker/pgvector/seed.sh). +-- Embedding dimension must match EMBEDDING_OPENAI_MODEL (text-embedding-3-small → 1536). CREATE EXTENSION IF NOT EXISTS vector; diff --git a/examples/docker/weaviate/sample-documents.json b/examples/docker/weaviate/sample-documents.json new file mode 100644 index 0000000..6d6cb92 --- /dev/null +++ b/examples/docker/weaviate/sample-documents.json @@ -0,0 +1,26 @@ +[ + { + "content": "Standard shipping within the continental United States takes 3–5 business days after the order ships. Express shipping (1–2 business days) is available at checkout for an additional fee. Orders placed after 2 p.m. local time ship the next business day.", + "source": "kb/shipping-and-delivery" + }, + { + "content": "Most items can be returned within 30 days of delivery if they are unused and in original packaging. Refunds are issued to the original payment method within 5–7 business days after we receive and inspect the return. Clearance items are final sale unless defective.", + "source": "kb/returns-and-refunds" + }, + { + "content": "To reset your account password, open the sign-in page and choose Forgot password. Enter the email on your account; you will receive a link that expires in 24 hours. If you do not see the email, check spam or contact support with your order number.", + "source": "kb/account/password-reset" + }, + { + "content": "Hardware products include a one-year limited warranty covering manufacturing defects. The warranty does not cover accidental damage, water damage, or normal wear. 
To start a claim, open a support ticket with your serial number and a short description of the issue.", + "source": "kb/warranty/hardware" + }, + { + "content": "Pro and Enterprise plans include priority email support with a target first response within one business day. Business hours are Monday–Friday, 9 a.m.–6 p.m. Eastern Time, excluding U.S. federal holidays. Phone support is available on Enterprise plans only.", + "source": "kb/support/hours-and-sla" + }, + { + "content": "Invoices for subscription plans are emailed on the first of each month. You can download past invoices from Billing → Invoice history in the customer portal. Tax is calculated based on your billing address and local regulations.", + "source": "kb/billing/invoices" + } +] diff --git a/examples/docker/weaviate/seed.sh b/examples/docker/weaviate/seed.sh new file mode 100755 index 0000000..c6db35b --- /dev/null +++ b/examples/docker/weaviate/seed.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash +# Seed Weaviate after compose service is up (schema + sample documents). +set -euo pipefail + +WEAVIATE_URL="${WEAVIATE_URL:-http://localhost:8080}" +WEAVIATE_CLASS="${WEAVIATE_CLASS:-Document}" +READY_TIMEOUT_SEC="${WEAVIATE_READY_TIMEOUT_SEC:-120}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXAMPLES_DIR="${SCRIPT_DIR}/../.." 
+ENV_FILE="${EXAMPLES_DIR}/.env" +ROOT_ENV_FILE="${EXAMPLES_DIR}/../.env" +DOCS_FILE="${SCRIPT_DIR}/sample-documents.json" + +read_env_value() { + local key="$1" file="$2" + [[ -f "$file" ]] || return 1 + local line + line="$(grep -E "^${key}=" "$file" | tail -1 || true)" + [[ -n "$line" ]] || return 1 + line="${line#${key}=}" + line="${line%$'\r'}" + line="${line#\"}"; line="${line%\"}" + line="${line#\'}"; line="${line%\'}" + printf '%s' "$line" +} + +require_cmd() { + command -v "$1" >/dev/null || { echo "error: '$1' is required" >&2; exit 1; } +} + +resolve_embedding_api_key() { + local f key + if [[ -n "${EMBEDDING_OPENAI_APIKEY:-}" ]]; then + return 0 + fi + for f in "$ENV_FILE" "$ROOT_ENV_FILE"; do + if key="$(read_env_value EMBEDDING_OPENAI_APIKEY "$f" 2>/dev/null)" && [[ -n "$key" ]]; then + export EMBEDDING_OPENAI_APIKEY="$key" + return 0 + fi + done + echo "error: EMBEDDING_OPENAI_APIKEY is required (examples/.env or environment)" >&2 + exit 1 +} + +wait_for_ready() { + local deadline=$((SECONDS + READY_TIMEOUT_SEC)) + echo "Waiting for Weaviate at ${WEAVIATE_URL}..." + while (( SECONDS < deadline )); do + if curl -sf "${WEAVIATE_URL}/v1/.well-known/ready" >/dev/null 2>&1; then + echo "Weaviate is ready." + return 0 + fi + sleep 2 + done + echo "error: Weaviate not ready within ${READY_TIMEOUT_SEC}s (docker logs weaviate)" >&2 + exit 1 +} + +require_cmd curl +require_cmd jq +resolve_embedding_api_key +wait_for_ready + +echo "Creating class ${WEAVIATE_CLASS}..." 
+curl -sf -X POST "${WEAVIATE_URL}/v1/schema" \ + -H 'Content-Type: application/json' \ + -d "{ + \"class\": \"${WEAVIATE_CLASS}\", + \"vectorizer\": \"text2vec-openai\", + \"properties\": [ + {\"name\": \"content\", \"dataType\": [\"text\"]}, + {\"name\": \"source\", \"dataType\": [\"text\"]} + ] + }" >/dev/null || true + +count=0 +while IFS= read -r row; do + payload=$(jq -n \ + --arg class "$WEAVIATE_CLASS" \ + --arg content "$(echo "$row" | jq -r '.content')" \ + --arg source "$(echo "$row" | jq -r '.source')" \ + '{class: $class, properties: {content: $content, source: $source}}') + if ! curl -sf -X POST "${WEAVIATE_URL}/v1/objects" \ + -H 'Content-Type: application/json' \ + -d "$payload" >/dev/null; then + echo "error: failed to insert into Weaviate (recreate: task infra:weaviate:down && task infra:weaviate:up if API key changed)" >&2 + exit 1 + fi + count=$((count + 1)) +done < <(jq -c '.[]' "$DOCS_FILE") + +echo "Inserted ${count} documents from sample-documents.json" diff --git a/examples/durable_agent/README.md b/examples/durable_agent/README.md index f216303..6a0c1f0 100644 --- a/examples/durable_agent/README.md +++ b/examples/durable_agent/README.md @@ -3,7 +3,7 @@ Separate **agent** and **worker** processes: `DisableLocalWorker` / `EnableRemoteWorkers` on the agent, `NewAgentWorker` on the worker, shared options in [`opts/opts.go`](opts/opts.go). The agent uses **`Stream`** with `WithStream(true)`; the sample prints streaming event types (`AgentEventTypeTextMessageContent`, tool events, **`AgentEventTypeCustom`** for approvals, **`AgentEventTypeRunFinished`**, etc.—see [`agent/main.go`](agent/main.go)). ```bash -# From examples/ (after cp env.sample .env and Temporal is up) +# From examples/ (LLM_* in .env; task -t Taskfile.yml infra:temporal:up or Temporal running) go run ./durable_agent/worker # Run on terminal 1 go run ./durable_agent/agent "Hello from remote agent!" 
# Run on terminal 2
 ```
diff --git a/examples/env.sample b/examples/env.sample
deleted file mode 100644
index 79f0a12..0000000
--- a/examples/env.sample
+++ /dev/null
@@ -1,141 +0,0 @@
-# Copy to .env for examples (from examples/ dir: cp env.sample .env)
-
-# Logging: LOG_LEVEL=error|warn|info|debug (default: error = minimal output)
-# LOG_LEVEL=error
-
-# Temporal (TaskQueue optional; defaults to per-example queue like agent-sdk-go-simple_agent.
-# If TEMPORAL_TASKQUEUE is set, it is treated as a base prefix and the example suffix is appended
-# (e.g. TEMPORAL_TASKQUEUE=agent-sdk-go -> agent-sdk-go-agent_with_tools_approval).
-TEMPORAL_HOST=localhost
-TEMPORAL_PORT=7233
-TEMPORAL_NAMESPACE=default
-TEMPORAL_TASKQUEUE=agent-sdk-go
-
-# LLM (openai | anthropic | gemini)
-LLM_PROVIDER=openai
-LLM_APIKEY=
-LLM_MODEL=gpt-4o
-LLM_BASEURL=https://api.openai.com/v1
-
-# For Gemini: LLM_PROVIDER=gemini, LLM_MODEL=gemini-2.5-flash, API key from ai.google.dev
-
-# Optional: Serper API (for search tool - 2,500 free queries at serper.dev)
-# SERPER_API_KEY=
-
-
-# --- MCP (examples/agent_with_mcp_config, agent_with_mcp_client) ---
-# Required: stdio (local subprocess) or streamable_http (remote URL). Aliases: local→stdio; http|remote|streamable→streamable_http
-MCP_TRANSPORT=streamable_http
-
-# Optional stable server id (defaults: local for stdio, remote for HTTP). Used when wiring the MCP client.
-# MCP_SERVER_NAME=remote
-
-# Optional: cap connect+RPC time (seconds). Zero = SDK default (see pkg/mcp/client BuildConfig).
-# MCP_TIMEOUT_SECONDS=
-
-# Optional: max connect+RPC attempts per operation. Zero = SDK default.
-# MCP_RETRY_ATTEMPTS=
-
-# Optional tool filter (comma-separated server tool names). Set only one of allow or block.
-# MCP_ALLOW_TOOLS=
-# MCP_BLOCK_TOOLS=
-
-# --- MCP streamable HTTP (remote) ---
-# Base URL for the MCP server (required for streamable_http). Auth is optional.
-MCP_STREAMABLE_HTTP_URL=
-
-# Optional static bearer (Authorization: Bearer ...). Ignored when OAuth trio below is all set.
-# MCP_BEARER_TOKEN=
-
-# Optional OAuth2 client credentials (set all three together for MCP HTTP + token endpoint).
-# MCP_CLIENT_ID=
-# MCP_CLIENT_SECRET=
-# MCP_TOKEN_URL=
-
-# Optional: skip TLS verify for MCP and token HTTP (dev only).
-# MCP_SKIP_TLS_VERIFY=true
-
-# --- MCP stdio (local subprocess) ---
-# Executable to run (required when using stdio transport).
-# MCP_STDIO_COMMAND=node
-
-# Optional: JSON array of argv strings after the command, e.g.:
-# MCP_STDIO_ARGS=["-y","@modelcontextprotocol/server-filesystem","/path/to/examples/mcp-filesystem-sandbox"]
-
-# Optional: JSON object of extra environment variables for the subprocess, e.g.:
-# MCP_STDIO_ENV={"NODE_ENV":"production"}
-
-
-# --- A2A (examples/agent_with_a2a_config, agent_with_a2a_client) ---
-# Required: base URL of the remote A2A agent (agent card + JSON-RPC endpoint as advertised by the card).
-A2A_URL=
-
-# Optional stable connection id (default: remote). Used as the server key in tool names (a2a__).
-# A2A_SERVER_NAME=remote
-
-# Optional: per-operation HTTP timeout in seconds. Zero = SDK default (see pkg/a2a/client).
-# A2A_TIMEOUT_SECONDS=
-
-# Optional bearer token (Authorization: Bearer ...).
-# A2A_TOKEN=
-
-# Optional extra headers as JSON object, e.g. {"X-Api-Key":"..."}.
-# A2A_HEADERS=
-
-# Optional; set to true to skip TLS verification (development only).
-# A2A_SKIP_TLS_VERIFY=true
-
-# Optional skill filter (comma-separated skill IDs). Set only one of allow or block.
-# A2A_ALLOW_SKILLS=
-# A2A_BLOCK_SKILLS=
-
-
-# --- A2A inbound server (examples/agent_with_a2a_server) ---
-# Built-in HTTP server: agent card + JSON-RPC. Defaults: localhost:9999 when unset.
-# A2A_SERVER_HOST=localhost
-# A2A_SERVER_PORT=9999
-
-# Optional comma-separated bearer tokens for JSON-RPC (agent card URL stays public).
-# A2A_SERVER_BEARER_TOKENS=my-dev-token,other-token - - -# --- Retriever (examples/agent_with_retriever/weaviate or .../pgvector) --- -# Mode: agentic (default) | prefetch | hybrid -# RETRIEVER_MODE=agentic - -# --- Weaviate (examples/agent_with_retriever/weaviate) --- -# WEAVIATE_HOST=localhost:8080 -# WEAVIATE_SCHEME=http -# WEAVIATE_CLASS=Document -# WEAVIATE_RETRIEVER_NAME=weaviate-kb -# WEAVIATE_CONTENT_FIELD=content -# WEAVIATE_SOURCE_FIELD=source -# WEAVIATE_TOP_K= -# WEAVIATE_MIN_SCORE= - -# --- pgvector (examples/agent_with_retriever/pgvector) --- -# Required for pgvector example: -# PGVECTOR_DSN=postgres://postgres:secret@localhost:5432/vectordb?sslmode=disable -# PGVECTOR_TABLE=documents -# PGVECTOR_RETRIEVER_NAME=pgvector-kb -# PGVECTOR_CONTENT_COL=content -# PGVECTOR_SOURCE_COL=source -# PGVECTOR_EMBEDDING_COL=embedding -# PGVECTOR_TOP_K= -# Cosine similarity threshold (pgvector example default 0.35; SDK default 0.75 if unset in app code) -PGVECTOR_MIN_SCORE=0.35 -# EMBEDDING_MODEL=text-embedding-3-small -# Required when LLM_PROVIDER is anthropic or gemini (OpenAI-compatible /embeddings API): -# EMBEDDING_APIKEY=sk-... -# EMBEDDING_BASEURL=https://api.openai.com/v1 - - -# --- OTLP (examples/agent_with_observability/config or .../objects) --- -# Collector host:port without scheme (gRPC default port 4317; HTTP often 4318). -# OTEL_EXPORTER_OTLP_ENDPOINT=localhost:4317 -# -# Optional: grpc (default) or http. -# OTLP_PROTOCOL=grpc -# -# Optional: dev collectors without TLS. -# OTLP_INSECURE=true diff --git a/examples/multiple_agents/main.go b/examples/multiple_agents/main.go index 1e672fc..b7588ae 100644 --- a/examples/multiple_agents/main.go +++ b/examples/multiple_agents/main.go @@ -22,34 +22,31 @@ func main() { } // TaskQueue must be unique per agent. Use WithInstanceId when running multiple agents in same process. 
- temporalCfg := &agent.TemporalConfig{ - Host: cfg.Host, - Port: cfg.Port, - Namespace: cfg.Namespace, - TaskQueue: cfg.TaskQueue, - } + temporalOpts := config.RuntimeOption(cfg) - agent1, err := agent.NewAgent( + agent1Opts := []agent.Option{ agent.WithName("agent-1"), agent.WithSystemPrompt("You are a helpful math assistant. Keep answers brief."), - agent.WithTemporalConfig(temporalCfg), agent.WithInstanceId("agent-1"), agent.WithLLMClient(llmClient), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), - ) + } + agent1Opts = append(agent1Opts, temporalOpts...) + agent1, err := agent.NewAgent(agent1Opts...) if err != nil { log.Fatal(config.FormatNewAgentError("failed to create agent 1", err)) } defer agent1.Close() - agent2, err := agent.NewAgent( + agent2Opts := []agent.Option{ agent.WithName("agent-2"), agent.WithSystemPrompt("You are a creative writing assistant. Be expressive."), - agent.WithTemporalConfig(temporalCfg), agent.WithInstanceId("agent-2"), agent.WithLLMClient(llmClient), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), - ) + } + agent2Opts = append(agent2Opts, temporalOpts...) + agent2, err := agent.NewAgent(agent2Opts...) if err != nil { log.Fatal(config.FormatNewAgentError("failed to create agent 2", err)) } diff --git a/examples/simple_agent/main.go b/examples/simple_agent/main.go index a6dd777..a7aa74e 100644 --- a/examples/simple_agent/main.go +++ b/examples/simple_agent/main.go @@ -23,15 +23,10 @@ func main() { agent.WithName("simple-agent"), agent.WithDescription("Simple agent with built-in worker"), agent.WithSystemPrompt("You are a helpful assistant that can generate text."), - agent.WithTemporalConfig(&agent.TemporalConfig{ - Host: cfg.Host, - Port: cfg.Port, - Namespace: cfg.Namespace, - TaskQueue: cfg.TaskQueue, - }), agent.WithLLMClient(llmClient), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), } + opts = append(opts, config.RuntimeOption(cfg)...) a, err := agent.NewAgent(opts...) 
if err != nil { @@ -41,7 +36,7 @@ func main() { prompt := strings.Join(os.Args[1:], " ") if prompt == "" { - prompt = "Hello, what can you do?" + prompt = "Hi" } fmt.Println("user:", prompt) result, err := a.Run(context.Background(), prompt, "") diff --git a/internal/runtime/base/runtime.go b/internal/runtime/base/runtime.go new file mode 100644 index 0000000..9c0c0ae --- /dev/null +++ b/internal/runtime/base/runtime.go @@ -0,0 +1,529 @@ +// Package base provides the shared runtime struct and core execution methods used by +// both the local and temporal runtime backends. It has no dependency on any backend-specific +// SDK (no Temporal, no workflow/activity imports). +package base + +import ( + "context" + "fmt" + "log/slog" + "strings" + "sync" + "time" + + "github.com/agenticenv/agent-sdk-go/internal/events" + "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/google/uuid" +) + +// Runtime holds the execution inputs shared by all runtime backends. +// Local and Temporal runtimes embed this struct and call its methods directly. +type Runtime struct { + AgentSpec runtime.AgentSpec + AgentExecution runtime.AgentExecution + Tracer interfaces.Tracer + Metrics interfaces.Metrics + // ToolExecutionMode controls whether tool calls in one LLM round are executed + // in parallel or sequentially. Defaults to parallel when empty. + ToolExecutionMode types.AgentToolExecutionMode +} + +// BuildLLMRequest constructs an LLMRequest from the given messages and options. +// When retrieverContext is non-empty it is appended to the system prompt (prefetch/hybrid mode). +// Returns the request and the resolved tools slice for later use in response parsing. 
+func (rt *Runtime) BuildLLMRequest(messages []interfaces.Message, skipTools bool, retrieverContext string) (*interfaces.LLMRequest, []interfaces.Tool) { + tools := rt.AgentExecution.Tools.Tools + systemMessage := rt.AgentSpec.SystemPrompt + if retrieverContext != "" { + systemMessage = fmt.Sprintf("%s\n\nRelevant Context:\n%s", rt.AgentSpec.SystemPrompt, retrieverContext) + } + req := &interfaces.LLMRequest{ + SystemMessage: systemMessage, + ResponseFormat: rt.AgentSpec.ResponseFormat, + Messages: messages, + } + ApplyLLMSampling(rt.AgentExecution.LLM.Sampling, req) + if skipTools { + req.Tools = []interfaces.ToolSpec{} + } else { + req.Tools = interfaces.ToolsToSpecs(tools) + } + return req, tools +} + +// RequiresApproval reports whether t requires human approval before execution. +// When no approval policy is configured the tool's own ApprovalRequired flag is used. +func (rt *Runtime) RequiresApproval(t interfaces.Tool) bool { + if rt.AgentExecution.Tools.ApprovalPolicy == nil { + if ar, ok := t.(interfaces.ToolApproval); ok && ar.ApprovalRequired() { + return true + } + return false + } + return rt.AgentExecution.Tools.ApprovalPolicy.RequiresApproval(t) +} + +// FetchConversationMessages loads prior messages from the conversation store. +// Returns an error when no conversation is configured or the store call fails. 
+func (rt *Runtime) FetchConversationMessages(ctx context.Context, log logger.Logger, conversationID string) ([]interfaces.Message, error) { + log.Debug(ctx, "runtime: loading conversation history", slog.String("scope", "runtime"), slog.String("conversationID", conversationID)) + + if rt.AgentExecution.Session.Conversation == nil { + return nil, fmt.Errorf("conversation is not configured") + } + + limit := rt.AgentExecution.Session.ConversationSize + if limit <= 0 { + limit = 20 + } + + ctx, sp := rt.Tracer.StartSpan(ctx, "conversation.get_messages", + interfaces.Attribute{Key: "conversation.id", Value: conversationID}, + interfaces.Attribute{Key: "limit", Value: limit}, + ) + defer sp.End() + + messages, err := rt.AgentExecution.Session.Conversation.ListMessages(ctx, conversationID, interfaces.WithLimit(limit)) + if err != nil { + sp.RecordError(err) + return nil, fmt.Errorf("failed to list conversation messages: %w", err) + } + + sp.SetAttribute("message.count", len(messages)) + log.Debug(ctx, "runtime: conversation history loaded", slog.String("scope", "runtime"), slog.Int("messageCount", len(messages))) + return messages, nil +} + +// llmResponseToResult converts an LLMResponse into an LLMResult, resolving tool metadata +// (display name, approval flag) from the registered tools list. 
+func (rt *Runtime) llmResponseToResult(resp *interfaces.LLMResponse, tools []interfaces.Tool) (*LLMResult, error) { + result := &LLMResult{Content: resp.Content, Usage: CloneLLMUsage(resp.Usage)} + for _, tc := range resp.ToolCalls { + if tc == nil { + continue + } + tool, ok := FindToolByName(tools, tc.ToolName) + if !ok { + return nil, fmt.Errorf("unknown tool: %s", tc.ToolName) + } + displayName := tool.DisplayName() + if displayName == "" { + displayName = tc.ToolName + } + result.ToolCalls = append(result.ToolCalls, ToolCallRequest{ + ToolCallID: tc.ToolCallID, + ToolName: tc.ToolName, + ToolDisplayName: displayName, + Args: tc.Args, + NeedsApproval: rt.RequiresApproval(tool), + }) + } + return result, nil +} + +// emitEvent calls fn safely; a nil fn is a no-op. +func emitEvent(fn func(events.AgentEvent), ev events.AgentEvent) { + if fn != nil { + fn(ev) + } +} + +// ExecuteLLM calls the LLM in non-streaming mode, records metrics and traces, emits +// TEXT_MESSAGE_START / TEXT_MESSAGE_CONTENT / TEXT_MESSAGE_END events, and returns LLMResult. +// messageID and agentName are used only for event construction; emit may be nil. 
+func (rt *Runtime) ExecuteLLM( + ctx context.Context, + log logger.Logger, + agentName, messageID string, + messages []interfaces.Message, + skipTools bool, + retrieverContext string, + emit func(events.AgentEvent), +) (*LLMResult, error) { + req, tools := rt.BuildLLMRequest(messages, skipTools, retrieverContext) + + llmClient := rt.AgentExecution.LLM.Client + model := llmClient.GetModel() + provider := string(llmClient.GetProvider()) + modelAttr := interfaces.Attribute{Key: types.MetricAttrModel, Value: model} + providerAttr := interfaces.Attribute{Key: types.MetricAttrProvider, Value: provider} + + log.Debug(ctx, "runtime: LLM generate started", slog.String("scope", "runtime"), slog.Int("messageCount", len(messages))) + + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallStarted, modelAttr, providerAttr) + llmStart := time.Now() + + ctx, sp := rt.Tracer.StartSpan(ctx, "llm.generate", + interfaces.Attribute{Key: "agent.name", Value: strings.TrimSpace(agentName)}, + interfaces.Attribute{Key: "message.count", Value: len(messages)}, + modelAttr, + providerAttr, + ) + resp, err := llmClient.Generate(ctx, req) + llmLatency := float64(time.Since(llmStart).Milliseconds()) + if err != nil { + sp.RecordError(err) + sp.End() + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) + return nil, err + } + sp.End() + + rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallCompleted, modelAttr, providerAttr) + if resp.Usage != nil { + rt.Metrics.RecordHistogram(ctx, types.MetricLLMTokensInput, float64(resp.Usage.PromptTokens), modelAttr, providerAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricLLMTokensOutput, float64(resp.Usage.CompletionTokens), modelAttr, providerAttr) + } + + log.Debug(ctx, "runtime: LLM generate completed", slog.String("scope", 
"runtime"), slog.Int("messageCount", len(messages))) + + result, err := rt.llmResponseToResult(resp, tools) + if err != nil { + return nil, err + } + + emitEvent(emit, events.NewAgentTextMessageStartEvent(messageID, string(interfaces.MessageRoleAssistant))) + emitEvent(emit, events.NewAgentTextMessageContentEvent(messageID, result.Content)) + emitEvent(emit, events.NewAgentTextMessageEndEvent(messageID)) + return result, nil +} + +// ExecuteLLMStream calls the LLM in streaming mode. When the LLM client does not support streaming +// it falls back to Generate automatically. Delta events (text content, reasoning) are emitted via +// emit as chunks arrive; a final TEXT_MESSAGE_START/CONTENT/END triple is emitted for non-streaming +// fallback. emit may be nil. +func (rt *Runtime) ExecuteLLMStream( + ctx context.Context, + log logger.Logger, + agentName, messageID string, + messages []interfaces.Message, + skipTools bool, + retrieverContext string, + emit func(events.AgentEvent), +) (*LLMResult, error) { + req, tools := rt.BuildLLMRequest(messages, skipTools, retrieverContext) + + llmClient := rt.AgentExecution.LLM.Client + model := llmClient.GetModel() + provider := string(llmClient.GetProvider()) + modelAttr := interfaces.Attribute{Key: types.MetricAttrModel, Value: model} + providerAttr := interfaces.Attribute{Key: types.MetricAttrProvider, Value: provider} + isStreamSupported := llmClient.IsStreamSupported() + + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallStarted, modelAttr, providerAttr) + llmStart := time.Now() + + ctx, sp := rt.Tracer.StartSpan(ctx, "llm.stream", + interfaces.Attribute{Key: "agent.name", Value: strings.TrimSpace(agentName)}, + interfaces.Attribute{Key: "message.count", Value: len(messages)}, + interfaces.Attribute{Key: "streaming", Value: isStreamSupported}, + modelAttr, + providerAttr, + ) + defer sp.End() + + // Helpers to track open/close state for text message and reasoning events. 
+ textMsgOpen := false + openTextMsg := func() { + if textMsgOpen { + return + } + emitEvent(emit, events.NewAgentTextMessageStartEvent(messageID, string(interfaces.MessageRoleAssistant))) + textMsgOpen = true + } + closeTextMsg := func() { + if !textMsgOpen { + return + } + emitEvent(emit, events.NewAgentTextMessageEndEvent(messageID)) + textMsgOpen = false + } + // If the model never sent text chunks still emit one assistant turn (empty for tool-only). + finalizeAssistantText := func(result *LLMResult) { + if textMsgOpen { + closeTextMsg() + return + } + openTextMsg() + emitEvent(emit, events.NewAgentTextMessageContentEvent(messageID, result.Content)) + closeTextMsg() + } + + // Non-streaming fallback: use Generate and emit a complete text message. + if !isStreamSupported { + log.Debug(ctx, "runtime: LLM stream unsupported, using generate", slog.String("scope", "runtime")) + resp, err := llmClient.Generate(ctx, req) + llmLatency := float64(time.Since(llmStart).Milliseconds()) + if err != nil { + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) + return nil, err + } + result, err := rt.llmResponseToResult(resp, tools) + if err != nil { + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) + return nil, err + } + rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallCompleted, modelAttr, providerAttr) + if resp.Usage != nil { + rt.Metrics.RecordHistogram(ctx, types.MetricLLMTokensInput, float64(resp.Usage.PromptTokens), modelAttr, providerAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricLLMTokensOutput, float64(resp.Usage.CompletionTokens), modelAttr, providerAttr) + 
} + finalizeAssistantText(result) + return result, nil + } + + stream, err := llmClient.GenerateStream(ctx, req) + if err != nil { + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, float64(time.Since(llmStart).Milliseconds()), modelAttr, providerAttr) + return nil, err + } + + // Reasoning AG-UI order: REASONING_START → REASONING_MESSAGE_START → REASONING_MESSAGE_CONTENT* → + // REASONING_MESSAGE_END → REASONING_END (flushed before the first assistant text delta, or at stream end). + var reasoningMID string + reasoningPhaseOpen := false + reasoningMsgOpen := false + flushReasoning := func() { + if reasoningMsgOpen { + emitEvent(emit, events.NewAgentReasoningMessageEndEvent(reasoningMID)) + reasoningMsgOpen = false + } + if reasoningPhaseOpen { + emitEvent(emit, events.NewAgentReasoningEndEvent(reasoningMID)) + reasoningPhaseOpen = false + } + } + openReasoning := func() { + if reasoningPhaseOpen { + return + } + reasoningMID = uuid.New().String() + emitEvent(emit, events.NewAgentReasoningStartEvent(reasoningMID)) + reasoningPhaseOpen = true + emitEvent(emit, events.NewAgentReasoningMessageStartEvent(reasoningMID, string(interfaces.MessageRoleReasoning))) + reasoningMsgOpen = true + } + + for stream.Next() { + chunk := stream.Current() + if chunk == nil { + continue + } + if chunk.ContentDelta != "" { + flushReasoning() + openTextMsg() + emitEvent(emit, events.NewAgentTextMessageContentEvent(messageID, chunk.ContentDelta)) + } + if chunk.ThinkingDelta != "" { + openReasoning() + emitEvent(emit, events.NewAgentReasoningMessageContentEvent(reasoningMID, chunk.ThinkingDelta)) + } + } + flushReasoning() + + llmLatency := float64(time.Since(llmStart).Milliseconds()) + if err := stream.Err(); err != nil { + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) + rt.Metrics.RecordHistogram(ctx, 
types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) + return nil, err + } + + resp := stream.GetResult() + if resp == nil { + err := fmt.Errorf("stream completed without result") + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) + return nil, err + } + + result, err := rt.llmResponseToResult(resp, tools) + if err != nil { + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) + return nil, err + } + + rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallCompleted, modelAttr, providerAttr) + if resp.Usage != nil { + rt.Metrics.RecordHistogram(ctx, types.MetricLLMTokensInput, float64(resp.Usage.PromptTokens), modelAttr, providerAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricLLMTokensOutput, float64(resp.Usage.CompletionTokens), modelAttr, providerAttr) + } + + log.Debug(ctx, "runtime: LLM stream completed", slog.String("scope", "runtime")) + finalizeAssistantText(result) + return result, nil +} + +// ExecuteTool finds the named tool and executes it, recording tracing and metrics. +// Returns the string representation of the tool result. 
+func (rt *Runtime) ExecuteTool(ctx context.Context, log logger.Logger, toolName string, args map[string]any) (string, error) { + log.Debug(ctx, "runtime: tool execute started", slog.String("scope", "runtime"), slog.String("tool", toolName), slog.Int("argCount", len(args))) + + tool, ok := FindToolByName(rt.AgentExecution.Tools.Tools, toolName) + if !ok { + log.Warn(ctx, "runtime: unknown tool", slog.String("scope", "runtime"), slog.String("tool", toolName)) + return "", fmt.Errorf("unknown tool: %s", toolName) + } + + toolAttr := interfaces.Attribute{Key: types.MetricAttrTool, Value: toolName} + rt.Metrics.IncrementCounter(ctx, types.MetricToolCallStarted, toolAttr) + toolStart := time.Now() + + ctx, sp := rt.Tracer.StartSpan(ctx, "tool.execute", + interfaces.Attribute{Key: "tool.name", Value: toolName}, + interfaces.Attribute{Key: "arg.count", Value: len(args)}, + ) + defer sp.End() + + result, err := tool.Execute(ctx, args) + toolLatency := float64(time.Since(toolStart).Milliseconds()) + if err != nil { + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricToolCallFailed, toolAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricToolLatencyMs, toolLatency, toolAttr) + return "", err + } + + rt.Metrics.RecordHistogram(ctx, types.MetricToolLatencyMs, toolLatency, toolAttr) + rt.Metrics.IncrementCounter(ctx, types.MetricToolCallCompleted, toolAttr) + log.Debug(ctx, "runtime: tool execute completed", slog.String("scope", "runtime"), slog.String("tool", toolName)) + return fmt.Sprintf("%v", result), nil +} + +// AuthorizeTool checks programmatic authorization for a tool before approval/execution. +// Tools that do not implement interfaces.ToolAuthorizer are allowed by default. 
+func (rt *Runtime) AuthorizeTool(ctx context.Context, log logger.Logger, toolName string, args map[string]any) (AuthorizeResult, error) { + log.Debug(ctx, "runtime: tool authorize started", slog.String("scope", "runtime"), slog.String("tool", toolName), slog.Int("argCount", len(args))) + + tool, ok := FindToolByName(rt.AgentExecution.Tools.Tools, toolName) + if !ok { + log.Warn(ctx, "runtime: unknown tool in authorization", slog.String("scope", "runtime"), slog.String("tool", toolName)) + return AuthorizeResult{}, fmt.Errorf("unknown tool: %s", toolName) + } + + authorizer, ok := tool.(interfaces.ToolAuthorizer) + if !ok { + log.Debug(ctx, "runtime: tool has no authorizer; allow by default", slog.String("scope", "runtime"), slog.String("tool", toolName)) + return AuthorizeResult{Allowed: true}, nil + } + + ctx, sp := rt.Tracer.StartSpan(ctx, "tool.authorize", + interfaces.Attribute{Key: "tool.name", Value: toolName}, + interfaces.Attribute{Key: "arg.count", Value: len(args)}, + ) + defer sp.End() + + decision, err := authorizer.Authorize(ctx, args) + if err != nil { + sp.RecordError(err) + log.Warn(ctx, "runtime: tool authorization failed", slog.String("scope", "runtime"), slog.String("tool", toolName), slog.Any("error", err)) + return AuthorizeResult{}, err + } + + if decision.Allow { + sp.SetAttribute("decision", "allowed") + log.Debug(ctx, "runtime: tool authorization allowed", slog.String("scope", "runtime"), slog.String("tool", toolName)) + return AuthorizeResult{Allowed: true}, nil + } + + reason := strings.TrimSpace(decision.Reason) + sp.SetAttribute("decision", "denied") + sp.SetAttribute("deny.reason", reason) + log.Info(ctx, "runtime: tool authorization denied", slog.String("scope", "runtime"), slog.String("tool", toolName), slog.String("reason", reason)) + return AuthorizeResult{Allowed: false, Reason: reason}, nil +} + +// ExecuteRetrievers runs all configured retrievers in parallel for the given query and +// returns a combined document context string 
for injection into the LLM system prompt. +// Partial failures are logged and skipped; all retrievers failing returns an error. +func (rt *Runtime) ExecuteRetrievers(ctx context.Context, log logger.Logger, query string) (string, error) { + retrievers := rt.AgentExecution.Retrievers.Retrievers + if len(retrievers) == 0 { + return "", nil + } + + log.Debug(ctx, "runtime: retriever prefetch started", slog.String("scope", "runtime"), slog.Int("retrieverCount", len(retrievers)), slog.String("query", query)) + + type retrieverResult struct { + name string + docs []interfaces.Document + err error + } + + results := make([]retrieverResult, len(retrievers)) + var wg sync.WaitGroup + for i, r := range retrievers { + wg.Add(1) + go func(idx int, ret interfaces.Retriever) { + defer wg.Done() + name := ret.Name() + retrieverAttr := interfaces.Attribute{Key: types.MetricAttrRetriever, Value: name} + rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallStarted, retrieverAttr) + start := time.Now() + + searchCtx, sp := rt.Tracer.StartSpan(ctx, "retriever.search", + interfaces.Attribute{Key: "retriever.name", Value: name}, + interfaces.Attribute{Key: "query", Value: query}, + ) + docs, err := ret.Search(searchCtx, query) + latency := float64(time.Since(start).Milliseconds()) + if err != nil { + sp.RecordError(err) + sp.End() + rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallFailed, retrieverAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricRetrieverLatencyMs, latency, retrieverAttr) + } else { + sp.End() + rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallCompleted, retrieverAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricRetrieverLatencyMs, latency, retrieverAttr) + } + results[idx] = retrieverResult{name: name, docs: docs, err: err} + }(i, r) + } + wg.Wait() + + multipleRetrievers := len(retrievers) > 1 + var sb strings.Builder + failedCount := 0 + for _, res := range results { + if res.err != nil { + failedCount++ + log.Error(ctx, "runtime: retriever 
search failed, skipping", slog.String("scope", "runtime"), slog.String("retriever", res.name), slog.Any("error", res.err)) + continue + } + if len(res.docs) == 0 { + continue + } + if multipleRetrievers { + fmt.Fprintf(&sb, "## %s\n", res.name) + } + sb.WriteString(FormatRetrieverDocs(res.docs)) + } + + if failedCount == len(retrievers) { + return "", fmt.Errorf("retriever prefetch: all %d retriever(s) failed", len(retrievers)) + } + if failedCount > 0 { + log.Warn(ctx, "runtime: some retrievers failed, continuing with partial context", slog.String("scope", "runtime"), slog.Int("failed", failedCount), slog.Int("total", len(retrievers))) + } + + retrieverContext := strings.TrimSpace(sb.String()) + log.Debug(ctx, "runtime: retriever prefetch completed", slog.String("scope", "runtime"), slog.Int("retrieverCount", len(retrievers)), slog.Bool("hasContext", retrieverContext != "")) + return retrieverContext, nil +} diff --git a/internal/runtime/base/runtime_test.go b/internal/runtime/base/runtime_test.go new file mode 100644 index 0000000..f59e296 --- /dev/null +++ b/internal/runtime/base/runtime_test.go @@ -0,0 +1,746 @@ +package base + +import ( + "context" + "errors" + "testing" + + "github.com/agenticenv/agent-sdk-go/internal/events" + sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + ifmocks "github.com/agenticenv/agent-sdk-go/pkg/interfaces/mocks" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/observability" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" +) + +// newTestRuntime returns a Runtime wired with noop tracer/metrics and the provided execution. 
+func newTestRuntime(exec sdkruntime.AgentExecution) *Runtime { + return &Runtime{ + AgentSpec: sdkruntime.AgentSpec{ + Name: "test-agent", + SystemPrompt: "you are helpful", + }, + AgentExecution: exec, + Tracer: observability.DefaultNoopTracer, + Metrics: observability.DefaultNoopMetrics, + } +} + +func noopLog() logger.Logger { return logger.NoopLogger() } + +// stubLLMClient is a minimal LLMClient that returns a fixed response. +type stubLLMClient struct { + resp *interfaces.LLMResponse + err error +} + +func (s stubLLMClient) Generate(_ context.Context, _ *interfaces.LLMRequest) (*interfaces.LLMResponse, error) { + return s.resp, s.err +} +func (stubLLMClient) GenerateStream(_ context.Context, _ *interfaces.LLMRequest) (interfaces.LLMStream, error) { + return nil, errors.New("stream not implemented in stub") +} +func (stubLLMClient) GetModel() string { return "stub" } +func (stubLLMClient) GetProvider() interfaces.LLMProvider { return interfaces.LLMProviderOpenAI } +func (stubLLMClient) IsStreamSupported() bool { return false } + +// --- BuildLLMRequest --- + +func TestBuildLLMRequest_Basic(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{}}, + }) + msgs := []interfaces.Message{{Role: interfaces.MessageRoleUser, Content: "hello"}} + req, tools := rt.BuildLLMRequest(msgs, false, "") + + require.Equal(t, "you are helpful", req.SystemMessage) + require.Equal(t, msgs, req.Messages) + require.Empty(t, tools) +} + +func TestBuildLLMRequest_WithRetrieverContext(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{}}, + }) + req, _ := rt.BuildLLMRequest(nil, false, "extra context") + require.Contains(t, req.SystemMessage, "you are helpful") + require.Contains(t, req.SystemMessage, "extra context") +} + +func TestBuildLLMRequest_SkipTools(t *testing.T) { + ctrl := gomock.NewController(t) + tool := ifmocks.NewMockTool(ctrl) + 
tool.EXPECT().Name().Return("t").AnyTimes() + tool.EXPECT().Description().Return("").AnyTimes() + tool.EXPECT().Parameters().Return(nil).AnyTimes() + + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{}}, + Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, + }) + req, _ := rt.BuildLLMRequest(nil, true, "") + require.Empty(t, req.Tools) +} + +// approvalToolStub is a Tool that also implements ToolApproval. +type approvalToolStub struct { + name string + approvalRequired bool +} + +func (a approvalToolStub) Name() string { return a.name } +func (a approvalToolStub) DisplayName() string { return a.name } +func (a approvalToolStub) Description() string { return "" } +func (a approvalToolStub) Parameters() interfaces.JSONSchema { return nil } +func (a approvalToolStub) Execute(_ context.Context, _ map[string]any) (any, error) { return nil, nil } +func (a approvalToolStub) ApprovalRequired() bool { return a.approvalRequired } + +// authorizerToolStub is a Tool that also implements ToolAuthorizer. 
+type authorizerToolStub struct { + name string + allow bool + reason string + err error +} + +func (a authorizerToolStub) Name() string { return a.name } +func (a authorizerToolStub) DisplayName() string { return a.name } +func (a authorizerToolStub) Description() string { return "" } +func (a authorizerToolStub) Parameters() interfaces.JSONSchema { return nil } +func (a authorizerToolStub) Execute(_ context.Context, _ map[string]any) (any, error) { + return nil, nil +} +func (a authorizerToolStub) Authorize(_ context.Context, _ map[string]any) (interfaces.ToolAuthorizationDecision, error) { + return interfaces.ToolAuthorizationDecision{Allow: a.allow, Reason: a.reason}, a.err +} + +// --- RequiresApproval --- + +func TestRequiresApproval_NoPolicyToolHasApproval(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{}) + tool := approvalToolStub{name: "t", approvalRequired: true} + require.True(t, rt.RequiresApproval(tool)) +} + +func TestRequiresApproval_NoPolicyToolNoApproval(t *testing.T) { + ctrl := gomock.NewController(t) + tool := ifmocks.NewMockTool(ctrl) + rt := newTestRuntime(sdkruntime.AgentExecution{}) + require.False(t, rt.RequiresApproval(tool)) +} + +// --- FetchConversationMessages --- + +func TestFetchConversationMessages_NoConversation(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + Session: sdkruntime.AgentSession{Conversation: nil}, + }) + _, err := rt.FetchConversationMessages(context.Background(), noopLog(), "conv-1") + require.Error(t, err) + require.Contains(t, err.Error(), "conversation is not configured") +} + +func TestFetchConversationMessages_Success(t *testing.T) { + ctrl := gomock.NewController(t) + conv := ifmocks.NewMockConversation(ctrl) + msgs := []interfaces.Message{{Role: interfaces.MessageRoleUser, Content: "hi"}} + conv.EXPECT().ListMessages(gomock.Any(), "conv-1", gomock.Any()).Return(msgs, nil) + + rt := newTestRuntime(sdkruntime.AgentExecution{ + Session: sdkruntime.AgentSession{Conversation: 
conv, ConversationSize: 10}, + }) + got, err := rt.FetchConversationMessages(context.Background(), noopLog(), "conv-1") + require.NoError(t, err) + require.Equal(t, msgs, got) +} + +func TestFetchConversationMessages_Error(t *testing.T) { + ctrl := gomock.NewController(t) + conv := ifmocks.NewMockConversation(ctrl) + conv.EXPECT().ListMessages(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("store down")) + + rt := newTestRuntime(sdkruntime.AgentExecution{ + Session: sdkruntime.AgentSession{Conversation: conv}, + }) + _, err := rt.FetchConversationMessages(context.Background(), noopLog(), "c") + require.Error(t, err) + require.Contains(t, err.Error(), "store down") +} + +// --- ExecuteTool --- + +func TestExecuteTool_UnknownTool(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + Tools: sdkruntime.AgentTools{Tools: nil}, + }) + _, err := rt.ExecuteTool(context.Background(), noopLog(), "missing", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "unknown tool") +} + +func TestExecuteTool_Success(t *testing.T) { + ctrl := gomock.NewController(t) + tool := ifmocks.NewMockTool(ctrl) + tool.EXPECT().Name().Return("calc").AnyTimes() + tool.EXPECT().Execute(gomock.Any(), gomock.Any()).Return("42", nil) + + rt := newTestRuntime(sdkruntime.AgentExecution{ + Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, + }) + result, err := rt.ExecuteTool(context.Background(), noopLog(), "calc", map[string]any{"x": 1}) + require.NoError(t, err) + require.Equal(t, "42", result) +} + +func TestExecuteTool_ToolError(t *testing.T) { + ctrl := gomock.NewController(t) + tool := ifmocks.NewMockTool(ctrl) + tool.EXPECT().Name().Return("fail-tool").AnyTimes() + tool.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(nil, errors.New("tool failed")) + + rt := newTestRuntime(sdkruntime.AgentExecution{ + Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, + }) + _, err := rt.ExecuteTool(context.Background(), noopLog(), 
"fail-tool", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "tool failed") +} + +// --- AuthorizeTool --- + +func TestAuthorizeTool_UnknownTool(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{}) + _, err := rt.AuthorizeTool(context.Background(), noopLog(), "ghost", nil) + require.Error(t, err) +} + +func TestAuthorizeTool_NoAuthorizer_AllowedByDefault(t *testing.T) { + ctrl := gomock.NewController(t) + tool := ifmocks.NewMockTool(ctrl) + tool.EXPECT().Name().Return("plain").AnyTimes() + + rt := newTestRuntime(sdkruntime.AgentExecution{ + Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, + }) + result, err := rt.AuthorizeTool(context.Background(), noopLog(), "plain", nil) + require.NoError(t, err) + require.True(t, result.Allowed) +} + +func TestAuthorizeTool_Allowed(t *testing.T) { + tool := authorizerToolStub{name: "secure", allow: true} + rt := newTestRuntime(sdkruntime.AgentExecution{ + Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, + }) + result, err := rt.AuthorizeTool(context.Background(), noopLog(), "secure", nil) + require.NoError(t, err) + require.True(t, result.Allowed) +} + +func TestAuthorizeTool_Denied(t *testing.T) { + tool := authorizerToolStub{name: "gated", allow: false, reason: "not allowed"} + rt := newTestRuntime(sdkruntime.AgentExecution{ + Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, + }) + result, err := rt.AuthorizeTool(context.Background(), noopLog(), "gated", nil) + require.NoError(t, err) + require.False(t, result.Allowed) + require.Equal(t, "not allowed", result.Reason) +} + +// --- ExecuteRetrievers --- + +func TestExecuteRetrievers_NoRetrievers(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{}) + got, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "query") + require.NoError(t, err) + require.Equal(t, "", got) +} + +func TestExecuteRetrievers_AllFail(t *testing.T) { + ctrl := gomock.NewController(t) + r := 
ifmocks.NewMockRetriever(ctrl) + r.EXPECT().Name().Return("r1").AnyTimes() + r.EXPECT().Search(gomock.Any(), gomock.Any()).Return(nil, errors.New("down")) + + rt := newTestRuntime(sdkruntime.AgentExecution{ + Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{r}}, + }) + _, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "q") + require.Error(t, err) + require.Contains(t, err.Error(), "all") +} + +func TestExecuteRetrievers_Success(t *testing.T) { + ctrl := gomock.NewController(t) + r := ifmocks.NewMockRetriever(ctrl) + r.EXPECT().Name().Return("kb").AnyTimes() + r.EXPECT().Search(gomock.Any(), "my query").Return([]interfaces.Document{ + {Content: "doc content", Source: "src", Score: 0.95}, + }, nil) + + rt := newTestRuntime(sdkruntime.AgentExecution{ + Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{r}}, + }) + got, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "my query") + require.NoError(t, err) + require.Contains(t, got, "doc content") +} + +// --- ExecuteLLM --- + +func TestExecuteLLM_LLMError(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{err: errors.New("llm unavailable")}}, + }) + _, err := rt.ExecuteLLM(context.Background(), noopLog(), "agent", "msg-1", nil, false, "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "llm unavailable") +} + +func TestExecuteLLM_Success_NoTools(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{Content: "hello world"}, + }}, + }) + result, err := rt.ExecuteLLM(context.Background(), noopLog(), "agent", "msg-1", nil, false, "", nil) + require.NoError(t, err) + require.Equal(t, "hello world", result.Content) + require.Empty(t, result.ToolCalls) +} + +func TestExecuteLLM_EmitsTextMessageEvents(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: 
sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{Content: "response text"}, + }}, + }) + + var emitted []events.AgentEventType + emit := func(ev events.AgentEvent) { + emitted = append(emitted, ev.Type()) + } + + _, err := rt.ExecuteLLM(context.Background(), noopLog(), "agent", "msg-1", nil, false, "", emit) + require.NoError(t, err) + require.Equal(t, []events.AgentEventType{ + events.AgentEventTypeTextMessageStart, + events.AgentEventTypeTextMessageContent, + events.AgentEventTypeTextMessageEnd, + }, emitted) +} + +func TestExecuteLLM_NilEmitDoesNotPanic(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{Content: "ok"}, + }}, + }) + require.NotPanics(t, func() { + _, _ = rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil) + }) +} + +func TestExecuteLLM_UnknownToolCallReturnsError(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{ + Content: "", + ToolCalls: []*interfaces.ToolCall{ + {ToolCallID: "1", ToolName: "nonexistent", Args: nil}, + }, + }, + }}, + }) + _, err := rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "unknown tool") +} + +func TestExecuteLLM_WithUsageMetrics(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{ + Content: "ok", + Usage: &interfaces.LLMUsage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15}, + }, + }}, + }) + result, err := rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil) + require.NoError(t, err) + require.NotNil(t, result.Usage) + require.EqualValues(t, 10, result.Usage.PromptTokens) +} + +func TestExecuteLLM_ToolCallWithEmptyDisplayName(t *testing.T) { + ctrl := 
gomock.NewController(t) + tool := ifmocks.NewMockTool(ctrl) + tool.EXPECT().Name().Return("my-tool").AnyTimes() + tool.EXPECT().Description().Return("").AnyTimes() + tool.EXPECT().Parameters().Return(nil).AnyTimes() + tool.EXPECT().DisplayName().Return("").AnyTimes() // empty → falls back to tool name + + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{ + ToolCalls: []*interfaces.ToolCall{ + {ToolCallID: "tc1", ToolName: "my-tool"}, + }, + }, + }}, + Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, + }) + result, err := rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil) + require.NoError(t, err) + require.Len(t, result.ToolCalls, 1) + require.Equal(t, "my-tool", result.ToolCalls[0].ToolDisplayName) +} + +func TestExecuteLLM_NilToolCallInResponse(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{ + Content: "answer", + ToolCalls: []*interfaces.ToolCall{nil}, // nil entry must be skipped + }, + }}, + }) + result, err := rt.ExecuteLLM(context.Background(), noopLog(), "a", "m", nil, false, "", nil) + require.NoError(t, err) + require.Empty(t, result.ToolCalls) +} + +// --- RequiresApproval with policy --- + +func TestRequiresApproval_PolicyOverrides(t *testing.T) { + ctrl := gomock.NewController(t) + tool := ifmocks.NewMockTool(ctrl) + policy := ifmocks.NewMockAgentToolApprovalPolicy(ctrl) + policy.EXPECT().RequiresApproval(tool).Return(true) + + rt := newTestRuntime(sdkruntime.AgentExecution{ + Tools: sdkruntime.AgentTools{ApprovalPolicy: policy}, + }) + require.True(t, rt.RequiresApproval(tool)) +} + +// --- AuthorizeTool error path --- + +func TestAuthorizeTool_AuthorizerError(t *testing.T) { + tool := authorizerToolStub{name: "err-tool", err: errors.New("auth backend down")} + rt := newTestRuntime(sdkruntime.AgentExecution{ + Tools: 
sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, + }) + _, err := rt.AuthorizeTool(context.Background(), noopLog(), "err-tool", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "auth backend down") +} + +// --- ExecuteRetrievers partial failure --- + +func TestExecuteRetrievers_PartialFailure(t *testing.T) { + ctrl := gomock.NewController(t) + good := ifmocks.NewMockRetriever(ctrl) + good.EXPECT().Name().Return("good").AnyTimes() + good.EXPECT().Search(gomock.Any(), gomock.Any()).Return([]interfaces.Document{ + {Content: "useful", Source: "s", Score: 0.8}, + }, nil) + + bad := ifmocks.NewMockRetriever(ctrl) + bad.EXPECT().Name().Return("bad").AnyTimes() + bad.EXPECT().Search(gomock.Any(), gomock.Any()).Return(nil, errors.New("timeout")) + + rt := newTestRuntime(sdkruntime.AgentExecution{ + Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{good, bad}}, + }) + got, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "q") + require.NoError(t, err) // partial is ok + require.Contains(t, got, "useful") +} + +// --- ExecuteLLMStream --- + +// streamCapableLLMClient wraps a stubLLMClient and sets IsStreamSupported=true. +type streamCapableLLMClient struct { + stubLLMClient + stream interfaces.LLMStream + streamErr error +} + +func (s streamCapableLLMClient) IsStreamSupported() bool { return true } +func (s streamCapableLLMClient) GenerateStream(_ context.Context, _ *interfaces.LLMRequest) (interfaces.LLMStream, error) { + return s.stream, s.streamErr +} + +// fixedStream is a simple in-memory LLMStream backed by a slice of chunks. 
+type fixedStream struct { + chunks []*interfaces.LLMStreamChunk + pos int + result *interfaces.LLMResponse + err error +} + +func newFixedStream(chunks []*interfaces.LLMStreamChunk, result *interfaces.LLMResponse) *fixedStream { + return &fixedStream{chunks: chunks, result: result} +} + +func (s *fixedStream) Next() bool { + s.pos++ + return s.pos <= len(s.chunks) +} +func (s *fixedStream) Current() *interfaces.LLMStreamChunk { + if s.pos < 1 || s.pos > len(s.chunks) { + return nil + } + return s.chunks[s.pos-1] +} +func (s *fixedStream) Err() error { return s.err } +func (s *fixedStream) GetResult() *interfaces.LLMResponse { return s.result } + +func TestExecuteLLMStream_FallbackGenerate_Success(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{Content: "fallback answer"}, + }}, + }) + result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + require.NoError(t, err) + require.Equal(t, "fallback answer", result.Content) +} + +func TestExecuteLLMStream_FallbackGenerate_LLMError(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{err: errors.New("llm down")}}, + }) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "llm down") +} + +func TestExecuteLLMStream_FallbackGenerate_EmitsEvents(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{Content: "hi"}, + }}, + }) + var emitted []events.AgentEventType + emit := func(ev events.AgentEvent) { emitted = append(emitted, ev.Type()) } + + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", emit) + require.NoError(t, err) + require.Equal(t, []events.AgentEventType{ + 
events.AgentEventTypeTextMessageStart, + events.AgentEventTypeTextMessageContent, + events.AgentEventTypeTextMessageEnd, + }, emitted) +} + +func TestExecuteLLMStream_GenerateStreamError(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{ + streamErr: errors.New("stream init failed"), + }}, + }) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "stream init failed") +} + +func TestExecuteLLMStream_StreamError_AfterChunks(t *testing.T) { + s := newFixedStream([]*interfaces.LLMStreamChunk{ + {ContentDelta: "partial"}, + }, nil) + s.err = errors.New("connection reset") + + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, + }) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "connection reset") +} + +func TestExecuteLLMStream_StreamNilResult(t *testing.T) { + s := newFixedStream(nil, nil) // no chunks, GetResult() returns nil + + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, + }) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "stream completed without result") +} + +func TestExecuteLLMStream_TextChunks_EmitsCorrectEvents(t *testing.T) { + s := newFixedStream([]*interfaces.LLMStreamChunk{ + {ContentDelta: "hello"}, + {ContentDelta: " world"}, + }, &interfaces.LLMResponse{Content: "hello world"}) + + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, + }) + + var emitted []events.AgentEventType + emit := func(ev events.AgentEvent) { emitted = append(emitted, 
ev.Type()) } + + result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", emit) + require.NoError(t, err) + require.Equal(t, "hello world", result.Content) + require.Equal(t, events.AgentEventTypeTextMessageStart, emitted[0]) + require.Equal(t, events.AgentEventTypeTextMessageContent, emitted[1]) + require.Equal(t, events.AgentEventTypeTextMessageContent, emitted[2]) + require.Equal(t, events.AgentEventTypeTextMessageEnd, emitted[3]) +} + +func TestExecuteLLMStream_ReasoningChunks_EmitsReasoningEvents(t *testing.T) { + s := newFixedStream([]*interfaces.LLMStreamChunk{ + {ThinkingDelta: "let me think"}, + {ContentDelta: "answer"}, + }, &interfaces.LLMResponse{Content: "answer"}) + + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, + }) + + var emitted []events.AgentEventType + emit := func(ev events.AgentEvent) { emitted = append(emitted, ev.Type()) } + + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", emit) + require.NoError(t, err) + + // Reasoning events must appear before text events + require.Equal(t, events.AgentEventTypeReasoningStart, emitted[0]) + require.Equal(t, events.AgentEventTypeReasoningMessageStart, emitted[1]) + require.Equal(t, events.AgentEventTypeReasoningMessageContent, emitted[2]) + // flush reasoning before text + require.Equal(t, events.AgentEventTypeReasoningMessageEnd, emitted[3]) + require.Equal(t, events.AgentEventTypeReasoningEnd, emitted[4]) + require.Equal(t, events.AgentEventTypeTextMessageStart, emitted[5]) +} + +func TestExecuteLLMStream_ToolOnlyResponse_EmitsEmptyAssistantTurn(t *testing.T) { + ctrl := gomock.NewController(t) + tool := ifmocks.NewMockTool(ctrl) + tool.EXPECT().Name().Return("search").AnyTimes() + tool.EXPECT().Description().Return("").AnyTimes() + tool.EXPECT().Parameters().Return(nil).AnyTimes() + tool.EXPECT().DisplayName().Return("Search").AnyTimes() + + 
s := newFixedStream(nil, &interfaces.LLMResponse{ + ToolCalls: []*interfaces.ToolCall{{ToolCallID: "1", ToolName: "search"}}, + }) + + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, + Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{tool}}, + }) + + var emitted []events.AgentEventType + emit := func(ev events.AgentEvent) { emitted = append(emitted, ev.Type()) } + + result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", emit) + require.NoError(t, err) + require.Len(t, result.ToolCalls, 1) + // finalizeAssistantText emits a start/content/end even when no text chunks arrived + require.Contains(t, emitted, events.AgentEventTypeTextMessageStart) + require.Contains(t, emitted, events.AgentEventTypeTextMessageEnd) +} + +func TestExecuteLLMStream_WithUsageMetrics(t *testing.T) { + s := newFixedStream(nil, &interfaces.LLMResponse{ + Content: "done", + Usage: &interfaces.LLMUsage{PromptTokens: 8, CompletionTokens: 4, TotalTokens: 12}, + }) + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, + }) + result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + require.NoError(t, err) + require.NotNil(t, result.Usage) + require.EqualValues(t, 8, result.Usage.PromptTokens) +} + +func TestExecuteLLMStream_NilChunkSkipped(t *testing.T) { + s := newFixedStream([]*interfaces.LLMStreamChunk{ + nil, + {ContentDelta: "text"}, + }, &interfaces.LLMResponse{Content: "text"}) + + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, + }) + result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + require.NoError(t, err) + require.Equal(t, "text", result.Content) +} + +func TestExecuteLLMStream_FallbackGenerate_WithUsage(t *testing.T) { + rt := 
newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{ + Content: "done", + Usage: &interfaces.LLMUsage{PromptTokens: 5, CompletionTokens: 3, TotalTokens: 8}, + }, + }}, + }) + result, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + require.NoError(t, err) + require.NotNil(t, result.Usage) + require.EqualValues(t, 5, result.Usage.PromptTokens) +} + +func TestExecuteLLMStream_FallbackGenerate_UnknownToolCallError(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{ + ToolCalls: []*interfaces.ToolCall{{ToolCallID: "1", ToolName: "ghost"}}, + }, + }}, + }) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "unknown tool") +} + +func TestExecuteLLMStream_Stream_UnknownToolCallError(t *testing.T) { + s := newFixedStream(nil, &interfaces.LLMResponse{ + ToolCalls: []*interfaces.ToolCall{{ToolCallID: "1", ToolName: "ghost"}}, + }) + rt := newTestRuntime(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: streamCapableLLMClient{stream: s}}, + }) + _, err := rt.ExecuteLLMStream(context.Background(), noopLog(), "agent", "msg", nil, false, "", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "unknown tool") +} + +func TestExecuteRetrievers_EmptyDocsSkipped(t *testing.T) { + ctrl := gomock.NewController(t) + r := ifmocks.NewMockRetriever(ctrl) + r.EXPECT().Name().Return("empty-kb").AnyTimes() + r.EXPECT().Search(gomock.Any(), gomock.Any()).Return([]interfaces.Document{}, nil) // no docs + + rt := newTestRuntime(sdkruntime.AgentExecution{ + Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{r}}, + }) + got, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "q") + require.NoError(t, err) + 
require.Equal(t, "", got) +} + +// --- ApplyLLMSampling Reasoning field --- + +func TestApplyLLMSampling_Reasoning(t *testing.T) { + req := &interfaces.LLMRequest{} + ApplyLLMSampling(&types.LLMSampling{ + Reasoning: &types.LLMReasoning{Effort: "medium", Enabled: true}, + }, req) + require.NotNil(t, req.Reasoning) + require.Equal(t, "medium", req.Reasoning.Effort) + require.True(t, req.Reasoning.Enabled) +} diff --git a/internal/runtime/base/types.go b/internal/runtime/base/types.go new file mode 100644 index 0000000..c631f69 --- /dev/null +++ b/internal/runtime/base/types.go @@ -0,0 +1,30 @@ +package base + +import "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + +// LLMResult is the result of a successful LLM call. +// Content holds the assistant text; ToolCalls holds any tool invocations resolved against +// the registered tools list (NeedsApproval pre-computed from the approval policy). +type LLMResult struct { + Content string + ToolCalls []ToolCallRequest + Usage *interfaces.LLMUsage +} + +// ToolCallRequest describes one tool call returned by the LLM. +// NeedsApproval is pre-computed from the tool approval policy so orchestration loops +// (local agent loop, temporal workflow) do not need to re-evaluate the policy. +type ToolCallRequest struct { + ToolCallID string + ToolName string + ToolDisplayName string + Args map[string]any + NeedsApproval bool +} + +// AuthorizeResult is the outcome of a programmatic tool authorization check. +// When Allowed is false, Reason carries the denial message for logging/events. 
+type AuthorizeResult struct { + Allowed bool + Reason string +} diff --git a/internal/runtime/base/utils.go b/internal/runtime/base/utils.go new file mode 100644 index 0000000..86a2abe --- /dev/null +++ b/internal/runtime/base/utils.go @@ -0,0 +1,93 @@ +package base + +import ( + "fmt" + "strings" + + "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +// SubAgentQuery extracts the query string from a sub-agent tool call's args map. +func SubAgentQuery(args map[string]any) string { + if args == nil { + return "" + } + q, _ := args[runtime.SubAgentToolParamQuery].(string) + return q +} + +// FindToolByName returns the first tool whose Name() matches toolName. +func FindToolByName(tools []interfaces.Tool, toolName string) (interfaces.Tool, bool) { + for _, t := range tools { + if t.Name() == toolName { + return t, true + } + } + return nil, false +} + +// FormatRetrieverDocs formats a list of documents for injection into the LLM system prompt. +// Each entry is rendered as "[N] content\n(source: s, score: 0.XX)\n\n". +func FormatRetrieverDocs(docs []interfaces.Document) string { + if len(docs) == 0 { + return "" + } + var sb strings.Builder + for i, doc := range docs { + fmt.Fprintf(&sb, types.RetrieverDocFormat, i+1, doc.Content, doc.Source, doc.Score) + } + return sb.String() +} + +// MergeLLMUsage accumulates add into acc and returns the result. +// Either argument may be nil; when both are nil, nil is returned. 
+func MergeLLMUsage(acc, add *interfaces.LLMUsage) *interfaces.LLMUsage { + if add == nil { + return acc + } + if acc == nil { + return CloneLLMUsage(add) + } + return &interfaces.LLMUsage{ + PromptTokens: acc.PromptTokens + add.PromptTokens, + CompletionTokens: acc.CompletionTokens + add.CompletionTokens, + TotalTokens: acc.TotalTokens + add.TotalTokens, + CachedPromptTokens: acc.CachedPromptTokens + add.CachedPromptTokens, + ReasoningTokens: acc.ReasoningTokens + add.ReasoningTokens, + } +} + +// CloneLLMUsage returns a shallow copy of u, or nil when u is nil. +func CloneLLMUsage(u *interfaces.LLMUsage) *interfaces.LLMUsage { + if u == nil { + return nil + } + c := *u + return &c +} + +// ApplyLLMSampling copies non-zero sampling fields from s onto req. +// A nil sampling value is a no-op. +func ApplyLLMSampling(s *types.LLMSampling, req *interfaces.LLMRequest) { + if s == nil { + return + } + if s.Temperature != nil { + req.Temperature = s.Temperature + } + if s.MaxTokens > 0 { + req.MaxTokens = s.MaxTokens + } + if s.TopP != nil { + req.TopP = s.TopP + } + if s.TopK != nil { + req.TopK = s.TopK + } + if s.Reasoning != nil { + r := *s.Reasoning + req.Reasoning = &r + } +} diff --git a/internal/runtime/base/utils_test.go b/internal/runtime/base/utils_test.go new file mode 100644 index 0000000..f3eced1 --- /dev/null +++ b/internal/runtime/base/utils_test.go @@ -0,0 +1,147 @@ +package base + +import ( + "testing" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + + ifmocks "github.com/agenticenv/agent-sdk-go/pkg/interfaces/mocks" +) + +// --- FindToolByName --- + +func TestFindToolByName_Found(t *testing.T) { + ctrl := gomock.NewController(t) + tool := ifmocks.NewMockTool(ctrl) + tool.EXPECT().Name().Return("my-tool").AnyTimes() + + got, ok := FindToolByName([]interfaces.Tool{tool}, "my-tool") + require.True(t, ok) + 
require.Equal(t, tool, got) +} + +func TestFindToolByName_NotFound(t *testing.T) { + ctrl := gomock.NewController(t) + tool := ifmocks.NewMockTool(ctrl) + tool.EXPECT().Name().Return("other").AnyTimes() + + _, ok := FindToolByName([]interfaces.Tool{tool}, "missing") + require.False(t, ok) +} + +func TestFindToolByName_EmptyList(t *testing.T) { + _, ok := FindToolByName(nil, "any") + require.False(t, ok) +} + +// --- FormatRetrieverDocs --- + +func TestFormatRetrieverDocs_Empty(t *testing.T) { + require.Equal(t, "", FormatRetrieverDocs(nil)) + require.Equal(t, "", FormatRetrieverDocs([]interfaces.Document{})) +} + +func TestFormatRetrieverDocs_SingleDoc(t *testing.T) { + docs := []interfaces.Document{{Content: "hello", Source: "src", Score: 0.9}} + got := FormatRetrieverDocs(docs) + require.Contains(t, got, "[1]") + require.Contains(t, got, "hello") + require.Contains(t, got, "src") + require.Contains(t, got, "0.90") +} + +func TestFormatRetrieverDocs_MultipleDocs(t *testing.T) { + docs := []interfaces.Document{ + {Content: "first", Source: "s1", Score: 0.8}, + {Content: "second", Source: "s2", Score: 0.6}, + } + got := FormatRetrieverDocs(docs) + require.Contains(t, got, "[1]") + require.Contains(t, got, "[2]") + require.Contains(t, got, "first") + require.Contains(t, got, "second") +} + +// --- MergeLLMUsage --- + +func TestMergeLLMUsage_BothNil(t *testing.T) { + require.Nil(t, MergeLLMUsage(nil, nil)) +} + +func TestMergeLLMUsage_AddNil(t *testing.T) { + acc := &interfaces.LLMUsage{PromptTokens: 5} + got := MergeLLMUsage(acc, nil) + require.Equal(t, acc, got) +} + +func TestMergeLLMUsage_AccNil(t *testing.T) { + add := &interfaces.LLMUsage{PromptTokens: 3, TotalTokens: 3} + got := MergeLLMUsage(nil, add) + require.EqualValues(t, 3, got.PromptTokens) + require.NotSame(t, add, got) // must be a copy +} + +func TestMergeLLMUsage_BothNonNil(t *testing.T) { + acc := &interfaces.LLMUsage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15} + add := 
&interfaces.LLMUsage{PromptTokens: 2, CompletionTokens: 3, TotalTokens: 5} + got := MergeLLMUsage(acc, add) + require.EqualValues(t, 12, got.PromptTokens) + require.EqualValues(t, 8, got.CompletionTokens) + require.EqualValues(t, 20, got.TotalTokens) +} + +// --- CloneLLMUsage --- + +func TestCloneLLMUsage_Nil(t *testing.T) { + require.Nil(t, CloneLLMUsage(nil)) +} + +func TestCloneLLMUsage_MutationIsolation(t *testing.T) { + orig := &interfaces.LLMUsage{PromptTokens: 7} + clone := CloneLLMUsage(orig) + require.Equal(t, orig.PromptTokens, clone.PromptTokens) + clone.PromptTokens = 99 + require.EqualValues(t, 7, orig.PromptTokens) // original unchanged +} + +// --- ApplyLLMSampling --- + +func TestApplyLLMSampling_Nil(t *testing.T) { + req := &interfaces.LLMRequest{} + ApplyLLMSampling(nil, req) // must not panic + require.Nil(t, req.Temperature) +} + +func TestApplyLLMSampling_Temperature(t *testing.T) { + temp := 0.7 + req := &interfaces.LLMRequest{} + ApplyLLMSampling(&types.LLMSampling{Temperature: &temp}, req) + require.NotNil(t, req.Temperature) + require.InDelta(t, 0.7, *req.Temperature, 0.001) +} + +func TestApplyLLMSampling_AllFields(t *testing.T) { + temp := 0.5 + topP := 0.9 + topK := 40 + req := &interfaces.LLMRequest{} + ApplyLLMSampling(&types.LLMSampling{ + Temperature: &temp, + MaxTokens: 512, + TopP: &topP, + TopK: &topK, + }, req) + require.InDelta(t, 0.5, *req.Temperature, 0.001) + require.Equal(t, 512, req.MaxTokens) + require.InDelta(t, 0.9, *req.TopP, 0.001) + require.Equal(t, 40, *req.TopK) +} + +func TestApplyLLMSampling_MaxTokensZeroNotApplied(t *testing.T) { + req := &interfaces.LLMRequest{MaxTokens: 100} + ApplyLLMSampling(&types.LLMSampling{MaxTokens: 0}, req) + require.Equal(t, 100, req.MaxTokens) // unchanged when zero +} diff --git a/internal/runtime/local/agent_loop.go b/internal/runtime/local/agent_loop.go new file mode 100644 index 0000000..5f86a9d --- /dev/null +++ b/internal/runtime/local/agent_loop.go @@ -0,0 +1,523 @@ +package 
local + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + "sync" + + "github.com/agenticenv/agent-sdk-go/internal/events" + "github.com/agenticenv/agent-sdk-go/internal/runtime/base" + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/google/uuid" +) + +const ( + msgToolRejected = "Tool execution was rejected by the user." + msgToolApprovalUnavailable = "Tool approval could not be completed because no approval handler is configured; continuing without running the tool." + msgToolUnauthorized = "Tool execution was denied by authorization policy." +) + +// AgentLoopInput holds per-run execution inputs for one local agent run. +// Mirrors AgentWorkflowInput (Temporal) for in-process execution — same fields, same semantics. +// Agent-level configuration (ToolExecutionMode, LLM, tools, limits) lives on the runtime itself. +type AgentLoopInput struct { + UserPrompt string + ConversationID string + StreamingEnabled bool + // ChannelName is the eventbus channel events are published to during this run. + // Sub-agents receive the parent's ChannelName so their events go directly to the parent stream. + // Empty = no event fanout. + ChannelName string + // ApprovalHandler is called when a tool requires human approval. May be nil (approval → unavailable). + ApprovalHandler types.ApprovalHandler + // SubAgentRoutes maps sub-agent tool name → local route. Built by the local runtime from + // ExecuteRequest.SubAgents before RunAgentLoop is called. Mirrors AgentWorkflowInput.SubAgentRoutes. + SubAgentRoutes map[string]subAgentRoute + // SubAgentDepth is the current nesting depth (0 = top-level, 1 = direct sub-agent, etc.). + SubAgentDepth int + // MaxSubAgentDepth caps recursive delegation. Mirrors AgentWorkflowInput.MaxSubAgentDepth. + MaxSubAgentDepth int +} + +// AgentLoopResult is the outcome of a completed local agent run. 
+type AgentLoopResult struct { + Content string + Usage *interfaces.LLMUsage +} + +// publishEventToChannel marshals ev and publishes it on channelName via the runtime eventbus. +// A nil eventbus, empty channel, or nil event is a no-op. +func (rt *LocalRuntime) publishEventToChannel(ctx context.Context, channelName string, ev events.AgentEvent) { + if ev == nil || strings.TrimSpace(channelName) == "" || rt.eventbus == nil { + return + } + data, err := json.Marshal(ev) + if err != nil { + rt.logger.Warn(ctx, "local: failed to marshal agent event", + slog.String("scope", "loop"), + slog.Any("error", err)) + return + } + if err := rt.eventbus.Publish(ctx, channelName, data); err != nil { + rt.logger.Warn(ctx, "local: failed to publish agent event", + slog.String("scope", "loop"), + slog.String("channel", channelName), + slog.Any("error", err)) + } +} + +// RunAgentLoop executes the full agent loop in-process using base.Runtime core methods. +// It mirrors the orchestration logic of AgentWorkflow but calls base methods directly +// instead of dispatching to Temporal activities. +// Events are published to rt.eventbus on input.ChannelName; callers subscribe to that channel. +func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) (*AgentLoopResult, error) { + log := rt.logger + agentName := rt.AgentSpec.Name + model := rt.AgentExecution.LLM.Client.GetModel() + + maxIter := rt.AgentExecution.Limits.MaxIterations + if maxIter <= 0 { + maxIter = 10 + } + + toolExecMode := rt.ToolExecutionMode + if toolExecMode == "" { + toolExecMode = types.AgentToolExecutionModeParallel + } + + // Internal emit: publish events to the eventbus channel for this run. + emit := func(ev events.AgentEvent) { + rt.publishEventToChannel(ctx, input.ChannelName, ev) + } + + // Build initial message list from user prompt. 
+ messages := []interfaces.Message{ + {Role: interfaces.MessageRoleUser, Content: input.UserPrompt}, + } + + // Prepend conversation history when a conversation ID is provided. + if input.ConversationID != "" { + convMsgs, err := rt.FetchConversationMessages(ctx, log, input.ConversationID) + if err != nil { + log.Warn(ctx, "local: failed to load conversation history, continuing without it", + slog.String("scope", "loop"), + slog.String("conversationID", input.ConversationID), + slog.Any("error", err)) + } else { + messages = append(convMsgs, messages...) + } + } + + // Pre-fetch retriever context for prefetch/hybrid modes. + retrieverContext := "" + retrieverMode := rt.AgentExecution.Retrievers.Mode + if (retrieverMode == types.RetrieverModePrefetch || retrieverMode == types.RetrieverModeHybrid) && + len(rt.AgentExecution.Retrievers.Retrievers) > 0 { + log.Debug(ctx, "local: retriever prefetch started", + slog.String("scope", "loop"), + slog.String("mode", string(retrieverMode)), + slog.Int("retrieverCount", len(rt.AgentExecution.Retrievers.Retrievers))) + rc, err := rt.ExecuteRetrievers(ctx, log, input.UserPrompt) + if err != nil { + return nil, fmt.Errorf("retriever prefetch: %w", err) + } + retrieverContext = rc + log.Debug(ctx, "local: retriever prefetch done", + slog.String("scope", "loop"), + slog.Bool("hasContext", retrieverContext != "")) + } + + var runUsage *interfaces.LLMUsage + lastContent := "" + + for iter := 0; iter < maxIter; iter++ { + messageID := uuid.New().String() + + log.Debug(ctx, "local: LLM call started", + slog.String("scope", "loop"), + slog.Int("iteration", iter), + slog.Int("messageCount", len(messages))) + + var llmResult *base.LLMResult + var err error + if input.StreamingEnabled { + llmResult, err = rt.ExecuteLLMStream(ctx, log, agentName, messageID, messages, false, retrieverContext, emit) + } else { + llmResult, err = rt.ExecuteLLM(ctx, log, agentName, messageID, messages, false, retrieverContext, emit) + } + if err != nil { + 
return nil, fmt.Errorf("llm call (iter %d): %w", iter, err) + } + + runUsage = base.MergeLLMUsage(runUsage, llmResult.Usage) + + // Final response: no tool calls → done. + if len(llmResult.ToolCalls) == 0 { + messages = append(messages, interfaces.Message{ + Role: interfaces.MessageRoleAssistant, + Content: llmResult.Content, + }) + lastContent = llmResult.Content + break + } + + // Max iterations: re-run without tools for a final answer. + if iter == maxIter-1 { + log.Info(ctx, "local: max iterations reached, forcing final LLM call without tools", + slog.String("scope", "loop"), + slog.Int("iteration", iter)) + finalMessageID := uuid.New().String() + if input.StreamingEnabled { + llmResult, err = rt.ExecuteLLMStream(ctx, log, agentName, finalMessageID, messages, true, retrieverContext, emit) + } else { + llmResult, err = rt.ExecuteLLM(ctx, log, agentName, finalMessageID, messages, true, retrieverContext, emit) + } + if err != nil { + return nil, fmt.Errorf("llm final call (iter %d): %w", iter, err) + } + runUsage = base.MergeLLMUsage(runUsage, llmResult.Usage) + messages = append(messages, interfaces.Message{ + Role: interfaces.MessageRoleAssistant, + Content: llmResult.Content, + }) + lastContent = llmResult.Content + break + } + + // Append assistant message with tool call metadata for next iteration. + assistantMsg := interfaces.Message{ + Role: interfaces.MessageRoleAssistant, + Content: llmResult.Content, + ToolCalls: make([]*interfaces.ToolCall, len(llmResult.ToolCalls)), + } + for i, tc := range llmResult.ToolCalls { + assistantMsg.ToolCalls[i] = &interfaces.ToolCall{ + ToolCallID: tc.ToolCallID, + ToolName: tc.ToolName, + Args: tc.Args, + } + } + messages = append(messages, assistantMsg) + + // Execute tools according to the requested execution mode. 
+ var toolResults []interfaces.Message + switch toolExecMode { + case types.AgentToolExecutionModeParallel: + toolResults, err = rt.executeToolsParallel(ctx, input, messageID, llmResult.ToolCalls, emit) + case types.AgentToolExecutionModeSequential: + toolResults, err = rt.executeToolsSequential(ctx, input, messageID, llmResult.ToolCalls, emit) + default: + return nil, fmt.Errorf("invalid tool execution mode %q: use %q or %q", + toolExecMode, + types.AgentToolExecutionModeParallel, + types.AgentToolExecutionModeSequential) + } + if err != nil { + return nil, err + } + + messages = append(messages, toolResults...) + } + + // Persist all accumulated messages to conversation when a conversation ID is set. + if input.ConversationID != "" && rt.AgentExecution.Session.Conversation != nil { + if err := persistConversationMessages(ctx, rt, input.ConversationID, messages); err != nil { + log.Warn(ctx, "local: persist conversation failed", + slog.String("scope", "loop"), + slog.String("conversationID", input.ConversationID), + slog.Any("error", err)) + } + } + + log.Info(ctx, "local: agent run completed", + slog.String("scope", "loop"), + slog.String("agentName", agentName), + slog.String("model", model), + slog.Int("contentLen", len(lastContent))) + + return &AgentLoopResult{Content: lastContent, Usage: runUsage}, nil +} + +// executeToolsParallel runs all tool calls concurrently and collects results in submission order. +// Errors from individual tools are returned as synthetic tool messages so the LLM can handle +// partial failures gracefully (same behaviour as the Temporal parallel branch). 
+func (rt *LocalRuntime) executeToolsParallel( + ctx context.Context, + input AgentLoopInput, + messageID string, + toolCalls []base.ToolCallRequest, + emit func(events.AgentEvent), +) ([]interfaces.Message, error) { + rt.logger.Info(ctx, "local: tool execution (parallel)", + slog.String("scope", "loop"), + slog.Int("toolCount", len(toolCalls))) + + results := make([]interfaces.Message, len(toolCalls)) + var wg sync.WaitGroup + + for i := range toolCalls { + wg.Add(1) + go func(idx int, tc base.ToolCallRequest) { + defer wg.Done() + msg, err := rt.executeSingleTool(ctx, input, messageID, tc, emit) + if err != nil { + results[idx] = interfaces.Message{ + Role: interfaces.MessageRoleTool, + Content: "Tool execution failed: " + err.Error(), + ToolName: tc.ToolName, + ToolCallID: tc.ToolCallID, + } + return + } + results[idx] = msg + }(i, toolCalls[i]) + } + wg.Wait() + + return results, nil +} + +// executeToolsSequential runs tool calls one at a time and returns on the first hard error. +func (rt *LocalRuntime) executeToolsSequential( + ctx context.Context, + input AgentLoopInput, + messageID string, + toolCalls []base.ToolCallRequest, + emit func(events.AgentEvent), +) ([]interfaces.Message, error) { + rt.logger.Info(ctx, "local: tool execution (sequential)", + slog.String("scope", "loop"), + slog.Int("toolCount", len(toolCalls))) + + results := make([]interfaces.Message, 0, len(toolCalls)) + for i, tc := range toolCalls { + msg, err := rt.executeSingleTool(ctx, input, messageID, tc, emit) + if err != nil { + rt.logger.Info(ctx, "local: sequential tool failed", + slog.String("scope", "loop"), + slog.Int("toolIndex", i), + slog.String("toolName", tc.ToolName), + slog.Any("error", err)) + return nil, err + } + results = append(results, msg) + } + return results, nil +} + +// executeSingleTool runs the full lifecycle for one tool call: +// authorize → approval (if needed) → execute, emitting TOOL_CALL_* events throughout. 
+func (rt *LocalRuntime) executeSingleTool( + ctx context.Context, + input AgentLoopInput, + messageID string, + tc base.ToolCallRequest, + emit func(events.AgentEvent), +) (interfaces.Message, error) { + log := rt.logger + + emitToolEndThenResult := func(toolCallID, content string) { + emit(events.NewAgentToolCallEndEvent(toolCallID)) + emit(events.NewAgentToolCallResultEvent(messageID, toolCallID, content, string(interfaces.MessageRoleTool))) + } + + // TOOL_CALL_START + emit(events.NewAgentToolCallStartEvent(tc.ToolCallID, tc.ToolName, messageID)) + + // TOOL_CALL_ARGS (only when non-trivial args) + if argsJSON, err := json.Marshal(tc.Args); err == nil { + s := string(argsJSON) + if s != "" && s != "null" && s != "{}" { + emit(events.NewAgentToolCallArgsEvent(tc.ToolCallID, s)) + } + } + + // Authorization check. + authResult, err := rt.AuthorizeTool(ctx, log, tc.ToolName, tc.Args) + if err != nil { + return interfaces.Message{}, fmt.Errorf("tool authorization error for %q: %w", tc.ToolName, err) + } + if !authResult.Allowed { + content := msgToolUnauthorized + if authResult.Reason != "" { + content = fmt.Sprintf("%s Reason: %s", content, authResult.Reason) + } + log.Info(ctx, "local: tool authorization denied", + slog.String("scope", "loop"), + slog.String("toolName", tc.ToolName), + slog.String("reason", authResult.Reason)) + emitToolEndThenResult(tc.ToolCallID, content) + return interfaces.Message{ + Role: interfaces.MessageRoleTool, + Content: content, + ToolName: tc.ToolName, + ToolCallID: tc.ToolCallID, + }, nil + } + + // Determine whether this tool call is a sub-agent delegation. + subAgentRoute, isSubAgent := input.SubAgentRoutes[tc.ToolName] + + // Approval gate when required. + approvalStatus := types.ApprovalStatusApproved + if tc.NeedsApproval { + // No channel (non-streaming Execute) and no handler: skip approval. 
+ if input.ChannelName == "" && input.ApprovalHandler == nil { + approvalStatus = types.ApprovalStatusUnavailable + } else { + // Generate a token and register a resolve channel so either path can unblock: + // - Streaming: caller receives CUSTOM event with token, calls rt.Approve(token, status) + // - Non-streaming with handler: handler calls approvalReq.Respond(status) directly + token := uuid.New().String() + resultCh := make(chan types.ApprovalStatus, 1) + rt.pendingApprovals.Store(token, resultCh) + defer rt.pendingApprovals.Delete(token) + + if isSubAgent { + approvalReq := &types.ApprovalRequest{ + Name: types.ApprovalRequestNameSubAgent, + Value: types.SubAgentDelegationApprovalRequestValue{ + AgentName: rt.AgentSpec.Name, + SubAgentName: tc.ToolDisplayName, + Args: tc.Args, + ApprovalToken: token, + }, + Respond: func(status types.ApprovalStatus) error { resultCh <- status; return nil }, + } + emit(events.NewAgentCustomEvent(string(events.AgentCustomEventNameSubAgentDelegation), + events.AgentCustomEventDelegationValue{ + AgentName: rt.AgentSpec.Name, + SubAgentName: tc.ToolDisplayName, + Args: tc.Args, + ApprovalToken: token, + })) + if input.ApprovalHandler != nil { + input.ApprovalHandler(ctx, approvalReq) + } + } else { + approvalReq := &types.ApprovalRequest{ + Name: types.ApprovalRequestNameTool, + Value: types.ToolApprovalRequestValue{ + AgentName: rt.AgentSpec.Name, + ToolCallID: tc.ToolCallID, + ToolName: tc.ToolName, + ToolDisplayName: tc.ToolDisplayName, + Args: tc.Args, + ApprovalToken: token, + }, + Respond: func(status types.ApprovalStatus) error { resultCh <- status; return nil }, + } + emit(events.NewAgentCustomEvent(string(events.AgentCustomEventNameToolApproval), + events.AgentCustomEventApprovalValue{ + AgentName: rt.AgentSpec.Name, + ToolCallID: tc.ToolCallID, + ToolName: tc.ToolName, + ToolDisplayName: tc.ToolDisplayName, + Args: tc.Args, + ApprovalToken: token, + })) + if input.ApprovalHandler != nil { + input.ApprovalHandler(ctx, 
approvalReq) + } + } + // Streaming path: handler is nil; caller calls rt.Approve(token, status) → resultCh. + + select { + case status := <-resultCh: + approvalStatus = status + case <-ctx.Done(): + return interfaces.Message{}, ctx.Err() + } + } + } + + var content string + switch approvalStatus { + case types.ApprovalStatusApproved: + if isSubAgent { + stepName := strings.TrimSpace(subAgentRoute.name) + if stepName == "" { + stepName = tc.ToolName + } + if input.SubAgentDepth >= input.MaxSubAgentDepth { + log.Warn(ctx, "local: sub-agent delegation refused (max depth)", + slog.String("scope", "loop"), + slog.Int("depth", input.SubAgentDepth), + slog.Int("maxDepth", input.MaxSubAgentDepth), + slog.String("toolName", tc.ToolName)) + content = fmt.Sprintf("Sub-agent delegation refused: maximum nesting depth (%d) reached.", input.MaxSubAgentDepth) + } else if subAgentRoute.runtime != nil { + query := base.SubAgentQuery(tc.Args) + log.Info(ctx, "local: delegating to sub-agent", + slog.String("scope", "loop"), + slog.String("toolName", tc.ToolName), + slog.String("stepName", stepName), + slog.Int("depth", input.SubAgentDepth+1)) + emit(events.NewAgentStepStartedEvent(stepName)) + subResult, execErr := subAgentRoute.runtime.RunAgentLoop(ctx, AgentLoopInput{ + UserPrompt: query, + StreamingEnabled: input.StreamingEnabled, + ChannelName: input.ChannelName, + ApprovalHandler: input.ApprovalHandler, + SubAgentRoutes: subAgentRoute.children, + SubAgentDepth: input.SubAgentDepth + 1, + MaxSubAgentDepth: input.MaxSubAgentDepth, + }) + emit(events.NewAgentStepFinishedEvent(stepName)) + if execErr != nil { + content = "Sub-agent execution failed: " + execErr.Error() + } else { + content = subResult.Content + } + } else { + content = "Sub-agent delegation not available for this runtime." 
+ } + } else { + log.Info(ctx, "local: executing tool", + slog.String("scope", "loop"), + slog.String("tool", tc.ToolName), + slog.String("toolCallID", tc.ToolCallID)) + result, execErr := rt.ExecuteTool(ctx, log, tc.ToolName, tc.Args) + if execErr != nil { + content = "Tool execution failed: " + execErr.Error() + } else { + content = result + } + } + case types.ApprovalStatusRejected: + content = msgToolRejected + case types.ApprovalStatusUnavailable: + content = msgToolApprovalUnavailable + default: + return interfaces.Message{}, fmt.Errorf("unexpected approval status %q for tool %q", approvalStatus, tc.ToolName) + } + + emitToolEndThenResult(tc.ToolCallID, content) + return interfaces.Message{ + Role: interfaces.MessageRoleTool, + Content: content, + ToolName: tc.ToolName, + ToolCallID: tc.ToolCallID, + }, nil +} + +// persistConversationMessages stores the run's accumulated messages in the conversation store; per-message failures are logged and skipped, so it currently always returns nil. +func persistConversationMessages(ctx context.Context, rt *LocalRuntime, conversationID string, messages []interfaces.Message) error { + conv := rt.AgentExecution.Session.Conversation + if conv == nil { + return nil + } + for _, msg := range messages { + if err := conv.AddMessage(ctx, conversationID, msg); err != nil { + rt.logger.Warn(ctx, "local: add conversation message failed", + slog.String("scope", "loop"), + slog.String("conversationID", conversationID), + slog.Any("error", err)) + } + } + return nil +} diff --git a/internal/runtime/local/agent_loop_test.go b/internal/runtime/local/agent_loop_test.go new file mode 100644 index 0000000..96b4d64 --- /dev/null +++ b/internal/runtime/local/agent_loop_test.go @@ -0,0 +1,742 @@ +package local + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/agenticenv/agent-sdk-go/internal/events" + sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/runtime/base" + "github.com/agenticenv/agent-sdk-go/internal/types" + 
"github.com/agenticenv/agent-sdk-go/pkg/interfaces" + ifmocks "github.com/agenticenv/agent-sdk-go/pkg/interfaces/mocks" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newLoopRT builds a LocalRuntime with the given LLM client and tools. +// Unlike newLocalRT it also accepts MaxIterations so loop tests can control iterations precisely. +func newLoopRT(t *testing.T, maxIter int, client interfaces.LLMClient, tools ...interfaces.Tool) *LocalRuntime { + t.Helper() + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentSpec(sdkruntime.AgentSpec{Name: "loop-agent", SystemPrompt: "sys"}), + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: client}, + Tools: sdkruntime.AgentTools{Tools: tools}, + Limits: sdkruntime.AgentLimits{MaxIterations: maxIter, Timeout: 10 * time.Second}, + }), + ) + require.NoError(t, err) + return rt +} + +// noopEmit discards all events. +func noopEmit(_ events.AgentEvent) {} + +// captureEmit returns an emit function and a pointer to the captured events slice. 
+func captureEmit() (func(events.AgentEvent), *[]events.AgentEvent) { + var evs []events.AgentEvent + return func(ev events.AgentEvent) { evs = append(evs, ev) }, &evs +} + +// --------------------------------------------------------------------------- +// RunAgentLoop — basic paths +// --------------------------------------------------------------------------- + +func TestRunAgentLoop_SimpleTextResponse(t *testing.T) { + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{{Content: "hello world"}}, + } + rt := newLoopRT(t, 5, client) + + result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "hi"}) + require.NoError(t, err) + require.Equal(t, "hello world", result.Content) +} + +func TestRunAgentLoop_LLMError(t *testing.T) { + client := &seqLLMClient{errs: []error{errors.New("llm fail")}} + rt := newLoopRT(t, 5, client) + + _, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "hi"}) + require.Error(t, err) + require.Contains(t, err.Error(), "llm fail") +} + +func TestRunAgentLoop_DefaultMaxIterations(t *testing.T) { + // When MaxIterations = 0 the loop defaults to 10. + // The client returns a text response on the first call so it exits immediately. 
+ client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{{Content: "early exit"}}, + } + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: client}, + Limits: sdkruntime.AgentLimits{MaxIterations: 0, Timeout: 10 * time.Second}, + }), + ) + require.NoError(t, err) + + result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "hi"}) + require.NoError(t, err) + require.Equal(t, "early exit", result.Content) +} + +func TestRunAgentLoop_ToolCallThenFinalAnswer(t *testing.T) { + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{ + {ToolCalls: []*interfaces.ToolCall{{ToolCallID: "c1", ToolName: "add"}}}, + {Content: "sum is 7"}, + }, + } + tool := stubTool{name: "add", result: "7"} + rt := newLoopRT(t, 5, client, tool) + + result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "add"}) + require.NoError(t, err) + require.Equal(t, "sum is 7", result.Content) +} + +func TestRunAgentLoop_MaxIterationsForcesFinalCall(t *testing.T) { + // With maxIter=1 and the first LLM response returning a tool call, the loop skips tool + // execution and makes a second, forced final LLM call (skipTools=true), returning its content. 
+ client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{ + {ToolCalls: []*interfaces.ToolCall{{ToolCallID: "c1", ToolName: "add"}}}, + {Content: "forced final answer"}, + }, + } + tool := stubTool{name: "add", result: "7"} + rt := newLoopRT(t, 1, client, tool) + + result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "add"}) + require.NoError(t, err) + require.Equal(t, "forced final answer", result.Content) +} + +// --------------------------------------------------------------------------- +// RunAgentLoop — tool execution modes +// --------------------------------------------------------------------------- + +func TestRunAgentLoop_SequentialMode(t *testing.T) { + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{ + {ToolCalls: []*interfaces.ToolCall{ + {ToolCallID: "c1", ToolName: "t1"}, + {ToolCallID: "c2", ToolName: "t2"}, + }}, + {Content: "sequential done"}, + }, + } + tool1 := stubTool{name: "t1", result: "r1"} + tool2 := stubTool{name: "t2", result: "r2"} + rt := newLoopRT(t, 5, client, tool1, tool2) + rt.ToolExecutionMode = types.AgentToolExecutionModeSequential + + result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "go"}) + require.NoError(t, err) + require.Equal(t, "sequential done", result.Content) +} + +func TestRunAgentLoop_InvalidToolMode(t *testing.T) { + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{ + {ToolCalls: []*interfaces.ToolCall{{ToolCallID: "c1", ToolName: "t1"}}}, + }, + } + tool := stubTool{name: "t1", result: "r"} + rt := newLoopRT(t, 5, client, tool) + rt.ToolExecutionMode = "invalid-mode" + + _, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "go"}) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid tool execution mode") +} + +// --------------------------------------------------------------------------- +// RunAgentLoop — conversation +// 
--------------------------------------------------------------------------- + +func TestRunAgentLoop_WithConversationID(t *testing.T) { + ctrl := gomock.NewController(t) + conv := ifmocks.NewMockConversation(ctrl) + + history := []interfaces.Message{{Role: interfaces.MessageRoleUser, Content: "old message"}} + conv.EXPECT().ListMessages(gomock.Any(), "conv-x", gomock.Any()).Return(history, nil) + // The loop persists the whole message list, so history is re-saved along with the new user and assistant messages; AnyTimes avoids pinning the count. + conv.EXPECT().AddMessage(gomock.Any(), "conv-x", gomock.Any()).Return(nil).AnyTimes() + + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{{Content: "with history"}}, + } + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: client}, + Session: sdkruntime.AgentSession{Conversation: conv, ConversationSize: 10}, + Limits: sdkruntime.AgentLimits{MaxIterations: 5, Timeout: 5 * time.Second}, + }), + ) + require.NoError(t, err) + + result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{ + UserPrompt: "new question", + ConversationID: "conv-x", + }) + require.NoError(t, err) + require.Equal(t, "with history", result.Content) +} + +func TestRunAgentLoop_ConversationFetchErrorContinues(t *testing.T) { + ctrl := gomock.NewController(t) + conv := ifmocks.NewMockConversation(ctrl) + conv.EXPECT().ListMessages(gomock.Any(), "bad-conv", gomock.Any()).Return(nil, errors.New("store down")) + // The fetch failure is non-fatal: the run still persists the new messages afterwards, so AddMessage calls are allowed. 
+ conv.EXPECT().AddMessage(gomock.Any(), "bad-conv", gomock.Any()).Return(nil).AnyTimes() + + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{{Content: "continued without history"}}, + } + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: client}, + Session: sdkruntime.AgentSession{Conversation: conv}, + Limits: sdkruntime.AgentLimits{MaxIterations: 5, Timeout: 5 * time.Second}, + }), + ) + require.NoError(t, err) + + result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{ + UserPrompt: "hi", + ConversationID: "bad-conv", + }) + // Must not fail — just warns and continues. + require.NoError(t, err) + require.Equal(t, "continued without history", result.Content) +} + +// --------------------------------------------------------------------------- +// RunAgentLoop — retrievers +// --------------------------------------------------------------------------- + +func TestRunAgentLoop_RetrieverPrefetch(t *testing.T) { + ctrl := gomock.NewController(t) + ret := ifmocks.NewMockRetriever(ctrl) + ret.EXPECT().Name().Return("kb").AnyTimes() + ret.EXPECT().Search(gomock.Any(), "fetch me").Return([]interfaces.Document{ + {Content: "relevant doc", Source: "kb", Score: 0.9}, + }, nil) + + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{{Content: "answer with context"}}, + } + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: client}, + Retrievers: sdkruntime.AgentRetrievers{ + Mode: types.RetrieverModePrefetch, + Retrievers: []interfaces.Retriever{ret}, + }, + Limits: sdkruntime.AgentLimits{MaxIterations: 5, Timeout: 5 * time.Second}, + }), + ) + require.NoError(t, err) + + result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "fetch me"}) + require.NoError(t, err) + require.Equal(t, "answer with context", result.Content) 
+} + +func TestRunAgentLoop_RetrieverPrefetchError(t *testing.T) { + ctrl := gomock.NewController(t) + ret := ifmocks.NewMockRetriever(ctrl) + ret.EXPECT().Name().Return("kb").AnyTimes() + ret.EXPECT().Search(gomock.Any(), gomock.Any()).Return(nil, errors.New("kb down")) + + client := &seqLLMClient{} + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: client}, + Retrievers: sdkruntime.AgentRetrievers{ + Mode: types.RetrieverModePrefetch, + Retrievers: []interfaces.Retriever{ret}, + }, + Limits: sdkruntime.AgentLimits{MaxIterations: 5, Timeout: 5 * time.Second}, + }), + ) + require.NoError(t, err) + + _, err = rt.RunAgentLoop(context.Background(), AgentLoopInput{UserPrompt: "fetch"}) + require.Error(t, err) + require.Contains(t, err.Error(), "retriever prefetch") +} + +// --------------------------------------------------------------------------- +// RunAgentLoop — event emission +// --------------------------------------------------------------------------- + +func TestRunAgentLoop_ToolEventsEmittedToChannel(t *testing.T) { + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{ + {ToolCalls: []*interfaces.ToolCall{{ToolCallID: "c1", ToolName: "calc"}}}, + {Content: "done"}, + }, + } + tool := stubTool{name: "calc", result: "99"} + rt := newLoopRT(t, 5, client, tool) + + ctx := context.Background() + channel := "test-tool-events" + eventCh, closeFn, err := rt.subscribeToAgentEvents(ctx, channel) + require.NoError(t, err) + + // close only once + var closeOnce sync.Once + safeClose := func() { closeOnce.Do(func() { _ = closeFn() }) } + defer safeClose() + + // Run the loop in a goroutine; close the subscription after it finishes so eventCh drains. 
+ go func() { + _, _ = rt.RunAgentLoop(ctx, AgentLoopInput{ + UserPrompt: "compute", + ChannelName: channel, + }) + safeClose() + }() + + var collected []events.AgentEvent + timeout := time.After(5 * time.Second) + for { + select { + case ev, ok := <-eventCh: + if !ok { + goto done + } + if ev != nil { + collected = append(collected, ev) + } + case <-timeout: + t.Fatal("timed out waiting for events") + } + } +done: + etypes := eventTypes(collected) + assert.Contains(t, etypes, events.AgentEventTypeToolCallStart) + assert.Contains(t, etypes, events.AgentEventTypeToolCallEnd) + assert.Contains(t, etypes, events.AgentEventTypeToolCallResult) +} + +// --------------------------------------------------------------------------- +// executeToolsParallel +// --------------------------------------------------------------------------- + +func TestExecuteToolsParallel_AllSucceed(t *testing.T) { + t1 := stubTool{name: "t1", result: "r1"} + t2 := stubTool{name: "t2", result: "r2"} + rt := newLoopRT(t, 5, &seqLLMClient{}, t1, t2) + + calls := []base.ToolCallRequest{ + {ToolCallID: "c1", ToolName: "t1"}, + {ToolCallID: "c2", ToolName: "t2"}, + } + + msgs, err := rt.executeToolsParallel(context.Background(), AgentLoopInput{}, "msg-1", calls, noopEmit) + require.NoError(t, err) + require.Len(t, msgs, 2) + // Order must match submission order (parallel but results are indexed). + require.Equal(t, "r1", msgs[0].Content) + require.Equal(t, "r2", msgs[1].Content) +} + +func TestExecuteToolsParallel_ToolErrorInMessage(t *testing.T) { + // Parallel: individual tool errors become synthetic error messages, not hard failures. 
+ failing := stubTool{name: "bad", execErr: errors.New("boom")} + rt := newLoopRT(t, 5, &seqLLMClient{}, failing) + + calls := []base.ToolCallRequest{{ToolCallID: "c1", ToolName: "bad"}} + msgs, err := rt.executeToolsParallel(context.Background(), AgentLoopInput{}, "msg", calls, noopEmit) + require.NoError(t, err) // parallel swallows into message + require.Len(t, msgs, 1) + require.Contains(t, msgs[0].Content, "boom") +} + +func TestExecuteToolsParallel_ResultsOrderPreserved(t *testing.T) { + // Three tools; verify result order matches submission order despite concurrency. + tools := []interfaces.Tool{ + stubTool{name: "a", result: "A"}, + stubTool{name: "b", result: "B"}, + stubTool{name: "c", result: "C"}, + } + rt := newLoopRT(t, 5, &seqLLMClient{}, tools...) + + calls := []base.ToolCallRequest{ + {ToolCallID: "1", ToolName: "a"}, + {ToolCallID: "2", ToolName: "b"}, + {ToolCallID: "3", ToolName: "c"}, + } + msgs, err := rt.executeToolsParallel(context.Background(), AgentLoopInput{}, "m", calls, noopEmit) + require.NoError(t, err) + require.Equal(t, []string{"A", "B", "C"}, []string{msgs[0].Content, msgs[1].Content, msgs[2].Content}) +} + +// --------------------------------------------------------------------------- +// executeToolsSequential +// --------------------------------------------------------------------------- + +func TestExecuteToolsSequential_AllSucceed(t *testing.T) { + t1 := stubTool{name: "s1", result: "v1"} + t2 := stubTool{name: "s2", result: "v2"} + rt := newLoopRT(t, 5, &seqLLMClient{}, t1, t2) + + calls := []base.ToolCallRequest{ + {ToolCallID: "c1", ToolName: "s1"}, + {ToolCallID: "c2", ToolName: "s2"}, + } + msgs, err := rt.executeToolsSequential(context.Background(), AgentLoopInput{}, "msg", calls, noopEmit) + require.NoError(t, err) + require.Len(t, msgs, 2) + require.Equal(t, "v1", msgs[0].Content) + require.Equal(t, "v2", msgs[1].Content) +} + +func TestExecuteToolsSequential_HardErrorOnContextCancel(t *testing.T) { + // A tool that 
is not registered makes executeSingleTool return a hard error. + // Sequential should propagate that error. + rt := newLoopRT(t, 5, &seqLLMClient{}) + // No tools are registered, so the call below is expected to fail rather than yield a + // synthetic tool message (an approval-unavailable tool would not produce an error). + // The context is also cancelled before the call; the assertion only checks that an error is returned. + ctx, cancel := context.WithCancel(context.Background()) + cancel() // pre-cancelled + + calls := []base.ToolCallRequest{{ToolCallID: "c1", ToolName: "missing-tool"}} + _, err := rt.executeToolsSequential(ctx, AgentLoopInput{}, "msg", calls, noopEmit) + // AuthorizeTool returns error for unknown tool. + require.Error(t, err) +} + +// --------------------------------------------------------------------------- +// executeSingleTool +// --------------------------------------------------------------------------- + +func TestExecuteSingleTool_Approved(t *testing.T) { + tool := stubTool{name: "my-tool", result: "hello"} + rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + + emit, evs := captureEmit() + msg, err := rt.executeSingleTool(context.Background(), AgentLoopInput{}, "msg-1", + base.ToolCallRequest{ToolCallID: "c1", ToolName: "my-tool"}, emit) + + require.NoError(t, err) + require.Equal(t, "hello", msg.Content) + require.Equal(t, interfaces.MessageRoleTool, msg.Role) + require.Equal(t, "my-tool", msg.ToolName) + + etypes := eventTypes(*evs) + require.Contains(t, etypes, events.AgentEventTypeToolCallStart) + require.Contains(t, etypes, events.AgentEventTypeToolCallEnd) + require.Contains(t, etypes, events.AgentEventTypeToolCallResult) +} + +func TestExecuteSingleTool_ToolExecError(t *testing.T) { + tool := stubTool{name: "boom", execErr: errors.New("exec failed")} + rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + + msg, err := rt.executeSingleTool(context.Background(), AgentLoopInput{}, "msg", + base.ToolCallRequest{ToolCallID: "c1", ToolName: "boom"}, noopEmit) + require.NoError(t, err) // tool 
errors become a content message, not a hard error + require.Contains(t, msg.Content, "exec failed") +} + +func TestExecuteSingleTool_UnknownToolErrors(t *testing.T) { + rt := newLoopRT(t, 5, &seqLLMClient{}) // no tools registered + + _, err := rt.executeSingleTool(context.Background(), AgentLoopInput{}, "msg", + base.ToolCallRequest{ToolCallID: "c1", ToolName: "ghost"}, noopEmit) + require.Error(t, err) + require.Contains(t, err.Error(), "ghost") +} + +func TestExecuteSingleTool_AuthorizationDenied(t *testing.T) { + tool := struct { + stubTool + allow bool + reason string + }{ + stubTool: stubTool{name: "restricted"}, + allow: false, + reason: "policy denied", + } + + // Use an authorizerToolStub from the runtime_test helpers (same package). + authTool := authorizerStubLocal{name: "restricted", allow: false, reason: "policy denied"} + rt := newLoopRT(t, 5, &seqLLMClient{}, authTool) + + msg, err := rt.executeSingleTool(context.Background(), AgentLoopInput{}, "msg", + base.ToolCallRequest{ToolCallID: "c1", ToolName: "restricted"}, noopEmit) + require.NoError(t, err) + require.Contains(t, msg.Content, msgToolUnauthorized) + _ = tool +} + +func TestExecuteSingleTool_AuthorizationError(t *testing.T) { + authTool := authorizerStubLocal{name: "err-tool", allow: false, authErr: errors.New("auth backend down")} + rt := newLoopRT(t, 5, &seqLLMClient{}, authTool) + + _, err := rt.executeSingleTool(context.Background(), AgentLoopInput{}, "msg", + base.ToolCallRequest{ToolCallID: "c1", ToolName: "err-tool"}, noopEmit) + require.Error(t, err) + require.Contains(t, err.Error(), "auth backend down") +} + +func TestExecuteSingleTool_ApprovalUnavailable(t *testing.T) { + // No channel, no handler → approval status = unavailable, tool not run. 
+ tool := stubTool{name: "guarded", result: "secret", needsApproval: true} + rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + + msg, err := rt.executeSingleTool(context.Background(), + AgentLoopInput{ChannelName: "", ApprovalHandler: nil}, "msg", + base.ToolCallRequest{ToolCallID: "c1", ToolName: "guarded", NeedsApproval: true}, noopEmit) + require.NoError(t, err) + require.Contains(t, msg.Content, msgToolApprovalUnavailable) +} + +func TestExecuteSingleTool_ApprovalHandlerApproves(t *testing.T) { + tool := stubTool{name: "guarded", result: "ok", needsApproval: true} + rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + + handler := func(_ context.Context, req *types.ApprovalRequest) { + _ = req.Respond(types.ApprovalStatusApproved) + } + + msg, err := rt.executeSingleTool(context.Background(), + AgentLoopInput{ApprovalHandler: handler}, "msg", + base.ToolCallRequest{ToolCallID: "c1", ToolName: "guarded", NeedsApproval: true}, noopEmit) + require.NoError(t, err) + require.Equal(t, "ok", msg.Content) +} + +func TestExecuteSingleTool_ApprovalHandlerRejects(t *testing.T) { + tool := stubTool{name: "guarded", result: "secret", needsApproval: true} + rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + + handler := func(_ context.Context, req *types.ApprovalRequest) { + _ = req.Respond(types.ApprovalStatusRejected) + } + + msg, err := rt.executeSingleTool(context.Background(), + AgentLoopInput{ApprovalHandler: handler}, "msg", + base.ToolCallRequest{ToolCallID: "c1", ToolName: "guarded", NeedsApproval: true}, noopEmit) + require.NoError(t, err) + require.Equal(t, msgToolRejected, msg.Content) +} + +func TestExecuteSingleTool_StreamingApproveUnblocks(t *testing.T) { + // Streaming path: ChannelName set, no ApprovalHandler. + // We call rt.Approve from a goroutine to unblock executeSingleTool. + tool := stubTool{name: "guarded", result: "stream-ok", needsApproval: true} + rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + + // Capture the approval token from the emitted CUSTOM event. 
+ var capturedToken string + var mu sync.Mutex + tokenSet := make(chan struct{}) + + emit := func(ev events.AgentEvent) { + if ev == nil || ev.Type() != events.AgentEventTypeCustom { + return + } + customEv, ok := ev.(*events.AgentCustomEvent) + if !ok { + return + } + val, err := events.ParseCustomEventApproval(customEv) + if err != nil || val.ApprovalToken == "" { + return + } + mu.Lock() + capturedToken = val.ApprovalToken + mu.Unlock() + select { + case <-tokenSet: + default: + close(tokenSet) + } + } + + done := make(chan struct{}) + var ( + resultMsg interfaces.Message + resultErr error + ) + go func() { + defer close(done) + resultMsg, resultErr = rt.executeSingleTool( + context.Background(), + AgentLoopInput{ChannelName: "some-channel"}, // streaming path + "msg", + base.ToolCallRequest{ToolCallID: "c1", ToolName: "guarded", NeedsApproval: true}, + emit, + ) + }() + + // Wait for the token, then approve. + select { + case <-tokenSet: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for approval token") + } + mu.Lock() + tok := capturedToken + mu.Unlock() + + require.NoError(t, rt.Approve(context.Background(), tok, types.ApprovalStatusApproved)) + + <-done + require.NoError(t, resultErr) + require.Equal(t, "stream-ok", resultMsg.Content) +} + +func TestExecuteSingleTool_ApprovalContextCancel(t *testing.T) { + tool := stubTool{name: "guarded", result: "should not run", needsApproval: true} + rt := newLoopRT(t, 5, &seqLLMClient{}, tool) + + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + defer close(done) + time.Sleep(20 * time.Millisecond) + cancel() + }() + + _, err := rt.executeSingleTool(ctx, + AgentLoopInput{ChannelName: "some-channel"}, "msg", + base.ToolCallRequest{ToolCallID: "c1", ToolName: "guarded", NeedsApproval: true}, noopEmit) + + <-done + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) +} + +// 
--------------------------------------------------------------------------- +// publishEventToChannel +// --------------------------------------------------------------------------- + +func TestPublishEventToChannel_NoOpWhenChannelEmpty(t *testing.T) { + rt := newLoopRT(t, 5, &seqLLMClient{}) + require.NotPanics(t, func() { + rt.publishEventToChannel(context.Background(), "", events.NewAgentRunErrorEvent("x")) + }) +} + +func TestPublishEventToChannel_NoOpWhenNilEvent(t *testing.T) { + rt := newLoopRT(t, 5, &seqLLMClient{}) + require.NotPanics(t, func() { + rt.publishEventToChannel(context.Background(), "ch", nil) + }) +} + +func TestPublishEventToChannel_NoOpWhenNilEventbus(t *testing.T) { + rt := &LocalRuntime{ + Runtime: base.Runtime{ + AgentSpec: sdkruntime.AgentSpec{Name: "a"}, + }, + logger: logger.NoopLogger(), + // eventbus is nil + } + require.NotPanics(t, func() { + rt.publishEventToChannel(context.Background(), "ch", events.NewAgentRunErrorEvent("x")) + }) +} + +// --------------------------------------------------------------------------- +// persistConversationMessages +// --------------------------------------------------------------------------- + +func TestPersistConversationMessages_NilConversation(t *testing.T) { + rt := newLoopRT(t, 5, &seqLLMClient{}) + // No conversation configured — must not panic or error. 
+ err := persistConversationMessages(context.Background(), rt, "c", []interfaces.Message{ + {Role: interfaces.MessageRoleUser, Content: "hi"}, + }) + require.NoError(t, err) +} + +func TestPersistConversationMessages_StoresAllMessages(t *testing.T) { + ctrl := gomock.NewController(t) + conv := ifmocks.NewMockConversation(ctrl) + conv.EXPECT().AddMessage(gomock.Any(), "conv-1", gomock.Any()).Return(nil).Times(3) + + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: &seqLLMClient{}}, + Session: sdkruntime.AgentSession{Conversation: conv}, + Limits: sdkruntime.AgentLimits{Timeout: 5 * time.Second}, + }), + ) + require.NoError(t, err) + + msgs := []interfaces.Message{ + {Role: interfaces.MessageRoleUser, Content: "1"}, + {Role: interfaces.MessageRoleAssistant, Content: "2"}, + {Role: interfaces.MessageRoleTool, Content: "3"}, + } + err = persistConversationMessages(context.Background(), rt, "conv-1", msgs) + require.NoError(t, err) +} + +func TestPersistConversationMessages_AddMessageErrorWarnsOnly(t *testing.T) { + ctrl := gomock.NewController(t) + conv := ifmocks.NewMockConversation(ctrl) + conv.EXPECT().AddMessage(gomock.Any(), "c", gomock.Any()).Return(errors.New("store err")).AnyTimes() + + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: &seqLLMClient{}}, + Session: sdkruntime.AgentSession{Conversation: conv}, + Limits: sdkruntime.AgentLimits{Timeout: 5 * time.Second}, + }), + ) + require.NoError(t, err) + + // persistConversationMessages returns nil even when AddMessage fails (warns only). 
+ err = persistConversationMessages(context.Background(), rt, "c", []interfaces.Message{ + {Role: interfaces.MessageRoleUser, Content: "hi"}, + }) + require.NoError(t, err) +} + +// --------------------------------------------------------------------------- +// authorizerStubLocal — tool with configurable authorization for loop tests +// --------------------------------------------------------------------------- + +type authorizerStubLocal struct { + name string + allow bool + reason string + authErr error +} + +func (a authorizerStubLocal) Name() string { return a.name } +func (a authorizerStubLocal) DisplayName() string { return a.name } +func (a authorizerStubLocal) Description() string { return "" } +func (a authorizerStubLocal) Parameters() interfaces.JSONSchema { return nil } +func (a authorizerStubLocal) Execute(_ context.Context, _ map[string]any) (any, error) { + return "auth-result", nil +} +func (a authorizerStubLocal) Authorize(_ context.Context, _ map[string]any) (interfaces.ToolAuthorizationDecision, error) { + return interfaces.ToolAuthorizationDecision{Allow: a.allow, Reason: a.reason}, a.authErr +} diff --git a/internal/runtime/local/options.go b/internal/runtime/local/options.go new file mode 100644 index 0000000..74640a9 --- /dev/null +++ b/internal/runtime/local/options.go @@ -0,0 +1,79 @@ +package local + +import ( + "context" + "fmt" + "log/slog" + + sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/observability" +) + +type Option func(*LocalRuntime) + +func WithLogger(l logger.Logger) Option { + return func(r *LocalRuntime) { + if l != nil { + r.logger = l + } + } +} + +func WithAgentSpec(spec sdkruntime.AgentSpec) Option { + return func(r *LocalRuntime) { + r.AgentSpec = spec + } +} + +func WithAgentExecution(execution 
sdkruntime.AgentExecution) Option { + return func(r *LocalRuntime) { + r.AgentExecution = execution + } +} + +func WithTracer(tracer interfaces.Tracer) Option { + return func(r *LocalRuntime) { + r.Tracer = tracer + } +} + +func WithMetrics(metrics interfaces.Metrics) Option { + return func(r *LocalRuntime) { + r.Metrics = metrics + } +} + +func WithToolExecutionMode(mode types.AgentToolExecutionMode) Option { + return func(r *LocalRuntime) { + r.ToolExecutionMode = mode + } +} + +func buildLocalRuntime(opts ...Option) (*LocalRuntime, error) { + r := &LocalRuntime{logger: logger.NoopLogger()} + for _, opt := range opts { + opt(r) + } + + if r.AgentExecution.LLM.Client == nil { + return nil, fmt.Errorf("llm client is required") + } + + if r.Tracer == nil { + r.Tracer = observability.DefaultNoopTracer + } + if r.Metrics == nil { + r.Metrics = observability.DefaultNoopMetrics + } + + r.logger.Debug(context.Background(), "runtime config resolved", + slog.String("scope", "runtime"), + slog.String("agentName", r.AgentSpec.Name), + slog.Bool("hasTracer", r.Tracer != nil), + slog.Bool("hasMetrics", r.Metrics != nil), + ) + return r, nil +} diff --git a/internal/runtime/local/runtime.go b/internal/runtime/local/runtime.go new file mode 100644 index 0000000..444d5c8 --- /dev/null +++ b/internal/runtime/local/runtime.go @@ -0,0 +1,272 @@ +package local + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + "sync" + + "github.com/agenticenv/agent-sdk-go/internal/eventbus" + "github.com/agenticenv/agent-sdk-go/internal/events" + sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/runtime/base" + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/google/uuid" +) + +var _ sdkruntime.Runtime = (*LocalRuntime)(nil) +var _ sdkruntime.EventBusRuntime = (*LocalRuntime)(nil) + +// LocalRuntime executes the agent loop in-process, embedding 
base.Runtime for shared +// core methods and holding local-specific fields (logger, eventbus). +type LocalRuntime struct { + base.Runtime + + logger logger.Logger + eventbus eventbus.EventBus + + // pendingApprovals holds token → resolve channel for tools awaiting human approval. + // Used by Approve() to unblock executeSingleTool when the caller responds via OnApproval + // (streaming path). Thread-safe: parallel tool calls each register their own token. + pendingApprovals sync.Map // key: string token, value: chan types.ApprovalStatus +} + +// NewLocalRuntime constructs a LocalRuntime from functional options. +func NewLocalRuntime(opts ...Option) (*LocalRuntime, error) { + r, err := buildLocalRuntime(opts...) + if err != nil { + return nil, err + } + r.logger.Info(context.Background(), "runtime created", + slog.String("scope", "runtime"), + slog.String("name", r.AgentSpec.Name)) + r.eventbus = eventbus.NewInmem(r.logger) + return r, nil +} + +// localChannelName returns the eventbus channel name for one run. +func localChannelName(runID string) string { + return "agent-event-" + runID +} + +// subscribeToAgentEvents subscribes to the run channel and returns a typed event channel +// plus a close function. Events are decoded from the raw JSON published by publishEventToChannel. 
+func (rt *LocalRuntime) subscribeToAgentEvents(ctx context.Context, channel string) (<-chan events.AgentEvent, func() error, error) { + rawCh, closeFn, err := rt.eventbus.Subscribe(ctx, channel) + if err != nil { + return nil, nil, fmt.Errorf("local: subscribe to channel %q: %w", channel, err) + } + outCh := make(chan events.AgentEvent, 64) + go func() { + defer close(outCh) + for data := range rawCh { + ev, err := events.EventFromJSON(data) + if err != nil { + rt.logger.Warn(ctx, "local: failed to decode agent event", + slog.String("scope", "runtime"), + slog.Any("error", err)) + continue + } + if ev != nil { + outCh <- ev + } + } + }() + return outCh, closeFn, nil +} + +// publishLifecycleEvent publishes a lifecycle event (RUN_STARTED, RUN_FINISHED, RUN_ERROR) to the +// run channel. Uses context.Background so a cancelled runCtx never drops the terminal event. +func (rt *LocalRuntime) publishLifecycleEvent(channel string, ev events.AgentEvent) { + if rt.eventbus == nil || channel == "" || ev == nil { + return + } + data, err := json.Marshal(ev) + if err != nil { + return + } + if err := rt.eventbus.Publish(context.Background(), channel, data); err != nil { + rt.logger.Warn(context.Background(), "local: lifecycle event publish failed", + slog.String("scope", "runtime"), + slog.String("channel", channel), + slog.String("type", string(ev.Type())), + slog.Any("error", err)) + } +} + +// Execute runs the agent loop synchronously and returns the final result. +// Approval is handled inline via req.ApprovalHandler (no out-of-band tokens). +func (rt *LocalRuntime) Execute(ctx context.Context, req *sdkruntime.ExecuteRequest) (*types.AgentRunResult, error) { + agentName := agentNameFromRequest(req) + rt.logger.Debug(ctx, "runtime execute", + slog.String("scope", "runtime"), + slog.String("agent", agentName), + slog.Int("inputLen", len(req.UserPrompt))) + + // Apply agent timeout when the caller has not set a deadline. 
+ runCtx := ctx + if d := rt.AgentExecution.Limits.Timeout; d > 0 { + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + runCtx, cancel = context.WithTimeout(ctx, d) + defer cancel() + } + } + + runID := uuid.New().String() + loopResult, err := rt.RunAgentLoop(runCtx, AgentLoopInput{ + UserPrompt: req.UserPrompt, + ConversationID: req.ConversationID, + StreamingEnabled: false, + ChannelName: "", + ApprovalHandler: req.ApprovalHandler, + SubAgentRoutes: buildSubAgentRoutes(req.SubAgents), + SubAgentDepth: 0, + MaxSubAgentDepth: req.MaxSubAgentDepth, + }) + if err != nil { + return nil, err + } + + _ = runID + return &types.AgentRunResult{ + Content: loopResult.Content, + AgentName: strings.TrimSpace(agentName), + Model: rt.AgentExecution.LLM.Client.GetModel(), + Metadata: map[string]any{}, + Usage: loopResult.Usage, + }, nil +} + +// ExecuteStream starts the agent loop in a goroutine and returns a channel of AgentEvent. +// RUN_STARTED is emitted before the loop begins; RUN_FINISHED or RUN_ERROR closes the channel. +func (rt *LocalRuntime) ExecuteStream(ctx context.Context, req *sdkruntime.ExecuteRequest) (<-chan events.AgentEvent, error) { + agentName := agentNameFromRequest(req) + rt.logger.Debug(ctx, "runtime execute stream", + slog.String("scope", "runtime"), + slog.String("agent", agentName), + slog.Int("inputLen", len(req.UserPrompt))) + + runID := uuid.New().String() + threadID := req.ConversationID + if threadID == "" { + threadID = runID + } + channel := localChannelName(runID) + + // Apply agent timeout. + runCtx := ctx + var runCancel context.CancelFunc + if d := rt.AgentExecution.Limits.Timeout; d > 0 { + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + runCtx, runCancel = context.WithTimeout(ctx, d) + } + } + + // Subscribe before launching the loop so no events are lost. 
+ eventCh, closeSub, err := rt.subscribeToAgentEvents(runCtx, channel) + if err != nil { + if runCancel != nil { + runCancel() + } + return nil, err + } + + outCh := make(chan events.AgentEvent, 64) + + // Forward subscription events to the caller's channel. + go func() { + defer close(outCh) + for ev := range eventCh { + if ev != nil { + outCh <- ev + } + } + }() + + // Emit RUN_STARTED before the loop so callers always see the lifecycle preamble. + rt.publishLifecycleEvent(channel, events.NewAgentRunStartedEvent(threadID, runID)) + + // Run the agent loop in a goroutine; emit lifecycle terminal event on completion. + go func() { + defer func() { + if runCancel != nil { + runCancel() + } + _ = closeSub() + }() + + result, loopErr := rt.RunAgentLoop(runCtx, AgentLoopInput{ + UserPrompt: req.UserPrompt, + ConversationID: req.ConversationID, + StreamingEnabled: req.StreamingEnabled, + ChannelName: channel, + ApprovalHandler: req.ApprovalHandler, + SubAgentRoutes: buildSubAgentRoutes(req.SubAgents), + SubAgentDepth: 0, + MaxSubAgentDepth: req.MaxSubAgentDepth, + }) + + if loopErr != nil { + rt.logger.Error(runCtx, "runtime stream run failed", + slog.String("scope", "runtime"), + slog.String("runID", runID), + slog.Any("error", loopErr)) + rt.publishLifecycleEvent(channel, events.NewAgentRunErrorEvent(loopErr.Error())) + return + } + + agentRunResult := &types.AgentRunResult{ + Content: result.Content, + AgentName: strings.TrimSpace(agentName), + Model: rt.AgentExecution.LLM.Client.GetModel(), + Metadata: map[string]any{}, + Usage: result.Usage, + } + rt.publishLifecycleEvent(channel, events.NewAgentRunFinishedEvent(threadID, runID, agentRunResult)) + }() + + return outCh, nil +} + +// Approve resolves a pending tool approval registered during a streaming run. +// When a tool requires approval, executeSingleTool registers a token and blocks; the +// caller receives a CUSTOM event on the stream with that token and calls Approve to unblock. 
+func (rt *LocalRuntime) Approve(_ context.Context, approvalToken string, status types.ApprovalStatus) error { + val, ok := rt.pendingApprovals.LoadAndDelete(approvalToken) + if !ok { + return fmt.Errorf("local: no pending approval for token %q", approvalToken) + } + ch := val.(chan types.ApprovalStatus) + ch <- status + return nil +} + +// Close releases runtime resources. +func (rt *LocalRuntime) Close() { + rt.logger.Info(context.Background(), "runtime closed", + slog.String("scope", "runtime"), + slog.String("name", rt.AgentSpec.Name)) +} + +// GetEventBus returns the runtime's in-process event bus so pkg/agent can wire sub-agents +// to the same bus for streaming fan-in and delegation events. +func (rt *LocalRuntime) GetEventBus() eventbus.EventBus { + return rt.eventbus +} + +// SetEventBus replaces the runtime's event bus. Called by pkg/agent when wiring a sub-agent +// tree so all agents in the tree share the parent's bus. +func (rt *LocalRuntime) SetEventBus(bus eventbus.EventBus) { + rt.eventbus = bus +} + +func agentNameFromRequest(req *sdkruntime.ExecuteRequest) string { + if req == nil || req.AgentSpec == nil { + return "" + } + return req.AgentSpec.Name +} diff --git a/internal/runtime/local/runtime_test.go b/internal/runtime/local/runtime_test.go new file mode 100644 index 0000000..f8cadd4 --- /dev/null +++ b/internal/runtime/local/runtime_test.go @@ -0,0 +1,653 @@ +package local + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/agenticenv/agent-sdk-go/internal/events" + sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/runtime/base" + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + ifmocks "github.com/agenticenv/agent-sdk-go/pkg/interfaces/mocks" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/observability" + "github.com/golang/mock/gomock" + 
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Shared test stubs +// --------------------------------------------------------------------------- + +// seqLLMClient returns LLM responses from a pre-loaded sequence. +// Once the sequence is exhausted it returns a plain "done" response. +type seqLLMClient struct { + mu sync.Mutex + responses []*interfaces.LLMResponse + errs []error + call int +} + +func (s *seqLLMClient) Generate(_ context.Context, _ *interfaces.LLMRequest) (*interfaces.LLMResponse, error) { + s.mu.Lock() + defer s.mu.Unlock() + i := s.call + s.call++ + if i < len(s.errs) && s.errs[i] != nil { + return nil, s.errs[i] + } + if i < len(s.responses) { + return s.responses[i], nil + } + return &interfaces.LLMResponse{Content: "done"}, nil +} +func (s *seqLLMClient) GenerateStream(_ context.Context, _ *interfaces.LLMRequest) (interfaces.LLMStream, error) { + return nil, errors.New("stream not implemented in seqLLMClient") +} +func (s *seqLLMClient) GetModel() string { return "test-model" } +func (s *seqLLMClient) GetProvider() interfaces.LLMProvider { return interfaces.LLMProviderOpenAI } +func (s *seqLLMClient) IsStreamSupported() bool { return false } + +// stubTool is a minimal Tool with configurable execute result and optional approval. +type stubTool struct { + name string + result string + execErr error + needsApproval bool +} + +func (t stubTool) Name() string { return t.name } +func (t stubTool) DisplayName() string { return t.name } +func (t stubTool) Description() string { return "" } +func (t stubTool) Parameters() interfaces.JSONSchema { return nil } +func (t stubTool) Execute(_ context.Context, _ map[string]any) (any, error) { + return t.result, t.execErr +} +func (t stubTool) ApprovalRequired() bool { return t.needsApproval } + +// newLocalRT constructs a LocalRuntime suitable for tests. 
+func newLocalRT(t *testing.T, client interfaces.LLMClient, tools ...interfaces.Tool) *LocalRuntime { + t.Helper() + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentSpec(sdkruntime.AgentSpec{Name: "test-agent", SystemPrompt: "you are helpful"}), + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: client}, + Tools: sdkruntime.AgentTools{Tools: tools}, + Limits: sdkruntime.AgentLimits{ + MaxIterations: 5, + Timeout: 30 * time.Second, + }, + }), + ) + require.NoError(t, err) + return rt +} + +// collectEvents drains an event channel until it is closed or timeout elapses, +// returning all events received. +func collectEvents(t *testing.T, ch <-chan events.AgentEvent, timeout time.Duration) []events.AgentEvent { + t.Helper() + var collected []events.AgentEvent + deadline := time.After(timeout) + for { + select { + case ev, ok := <-ch: + if !ok { + return collected + } + if ev != nil { + collected = append(collected, ev) + } + case <-deadline: + t.Fatalf("collectEvents: timed out after %s waiting for channel to close", timeout) + return collected + } + } +} + +// eventTypes extracts the AgentEventType from each collected event. 
+func eventTypes(evs []events.AgentEvent) []events.AgentEventType { + out := make([]events.AgentEventType, len(evs)) + for i, ev := range evs { + out[i] = ev.Type() + } + return out +} + +// --------------------------------------------------------------------------- +// NewLocalRuntime +// --------------------------------------------------------------------------- + +func TestNewLocalRuntime_MissingLLMClient(t *testing.T) { + _, err := NewLocalRuntime( + WithAgentSpec(sdkruntime.AgentSpec{Name: "agent"}), + WithAgentExecution(sdkruntime.AgentExecution{}), + ) + require.Error(t, err) + require.Contains(t, err.Error(), "llm client is required") +} + +func TestNewLocalRuntime_DefaultNoopObservability(t *testing.T) { + rt, err := NewLocalRuntime( + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: &seqLLMClient{}}, + }), + ) + require.NoError(t, err) + require.Equal(t, observability.DefaultNoopTracer, rt.Tracer) + require.Equal(t, observability.DefaultNoopMetrics, rt.Metrics) +} + +func TestNewLocalRuntime_WithAllOptions(t *testing.T) { + ctrl := gomock.NewController(t) + tracer := ifmocks.NewMockTracer(ctrl) + metrics := ifmocks.NewMockMetrics(ctrl) + + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentSpec(sdkruntime.AgentSpec{Name: "my-agent"}), + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: &seqLLMClient{}}, + }), + WithTracer(tracer), + WithMetrics(metrics), + WithToolExecutionMode(types.AgentToolExecutionModeSequential), + ) + require.NoError(t, err) + require.Equal(t, "my-agent", rt.AgentSpec.Name) + require.Equal(t, tracer, rt.Tracer) + require.Equal(t, metrics, rt.Metrics) + require.Equal(t, types.AgentToolExecutionModeSequential, rt.ToolExecutionMode) +} + +func TestNewLocalRuntime_EventBusInitialised(t *testing.T) { + rt := newLocalRT(t, &seqLLMClient{}) + require.NotNil(t, rt.eventbus, "eventbus should be initialised by NewLocalRuntime") +} + +// 
--------------------------------------------------------------------------- +// agentNameFromRequest +// --------------------------------------------------------------------------- + +func TestAgentNameFromRequest_NilRequest(t *testing.T) { + require.Equal(t, "", agentNameFromRequest(nil)) +} + +func TestAgentNameFromRequest_NilSpec(t *testing.T) { + require.Equal(t, "", agentNameFromRequest(&sdkruntime.ExecuteRequest{})) +} + +func TestAgentNameFromRequest_WithName(t *testing.T) { + req := &sdkruntime.ExecuteRequest{AgentSpec: &sdkruntime.AgentSpec{Name: "hello"}} + require.Equal(t, "hello", agentNameFromRequest(req)) +} + +// --------------------------------------------------------------------------- +// Execute +// --------------------------------------------------------------------------- + +func TestExecute_SimpleTextResponse(t *testing.T) { + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{ + {Content: "Hello from the agent"}, + }, + } + rt := newLocalRT(t, client) + + result, err := rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{ + UserPrompt: "hi", + AgentSpec: &sdkruntime.AgentSpec{Name: "test-agent"}, + }) + + require.NoError(t, err) + require.Equal(t, "Hello from the agent", result.Content) + require.Equal(t, "test-agent", result.AgentName) + require.Equal(t, "test-model", result.Model) +} + +func TestExecute_PropagatesLLMError(t *testing.T) { + client := &seqLLMClient{ + errs: []error{errors.New("llm unavailable")}, + } + rt := newLocalRT(t, client) + + _, err := rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{UserPrompt: "hi"}) + require.Error(t, err) + require.Contains(t, err.Error(), "llm unavailable") +} + +func TestExecute_AppliesTimeoutWhenNoDeadline(t *testing.T) { + // Build a runtime with a very short timeout. 
+ blocking := &blockingLLMClient{block: make(chan struct{})} + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: blocking}, + Limits: sdkruntime.AgentLimits{ + MaxIterations: 1, + Timeout: 50 * time.Millisecond, + }, + }), + ) + require.NoError(t, err) + + // Pass context.Background() — no deadline — so Execute applies the runtime timeout. + start := time.Now() + _, err = rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{UserPrompt: "hi"}) + elapsed := time.Since(start) + + // Should have been cancelled by the 50ms runtime timeout. + require.Error(t, err) + assert.Less(t, elapsed, 2*time.Second, "runtime timeout should fire well before 2s") +} + +// blockingLLMClient blocks until its context is cancelled. +type blockingLLMClient struct { + block chan struct{} +} + +func (b *blockingLLMClient) Generate(ctx context.Context, _ *interfaces.LLMRequest) (*interfaces.LLMResponse, error) { + <-ctx.Done() + return nil, ctx.Err() +} +func (b *blockingLLMClient) GenerateStream(_ context.Context, _ *interfaces.LLMRequest) (interfaces.LLMStream, error) { + return nil, errors.New("not supported") +} +func (b *blockingLLMClient) GetModel() string { return "blocking" } +func (b *blockingLLMClient) GetProvider() interfaces.LLMProvider { return interfaces.LLMProviderOpenAI } +func (b *blockingLLMClient) IsStreamSupported() bool { return false } + +func TestExecute_WithApprovalHandler(t *testing.T) { + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{ + { + ToolCalls: []*interfaces.ToolCall{ + {ToolCallID: "c1", ToolName: "approve-tool"}, + }, + }, + {Content: "tool done"}, + }, + } + tool := stubTool{name: "approve-tool", result: "executed", needsApproval: true} + rt := newLocalRT(t, client, tool) + + handlerCalled := false + handler := func(_ context.Context, req *types.ApprovalRequest) { + handlerCalled = true + _ = req.Respond(types.ApprovalStatusApproved) + } 
+ + result, err := rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{ + UserPrompt: "run tool", + ApprovalHandler: handler, + }) + + require.NoError(t, err) + require.True(t, handlerCalled, "approval handler must be called") + require.Equal(t, "tool done", result.Content) +} + +// --------------------------------------------------------------------------- +// ExecuteStream +// --------------------------------------------------------------------------- + +func TestExecuteStream_EmitsRunStartedAndFinished(t *testing.T) { + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{{Content: "stream answer"}}, + } + rt := newLocalRT(t, client) + + ch, err := rt.ExecuteStream(context.Background(), &sdkruntime.ExecuteRequest{ + UserPrompt: "hello", + AgentSpec: &sdkruntime.AgentSpec{Name: "test-agent"}, + }) + require.NoError(t, err) + + evs := collectEvents(t, ch, 5*time.Second) + types := eventTypes(evs) + + require.Contains(t, types, events.AgentEventTypeRunStarted) + require.Contains(t, types, events.AgentEventTypeRunFinished) + + // RUN_STARTED must come first, RUN_FINISHED last. 
+ first := types[0] + last := types[len(types)-1] + require.Equal(t, events.AgentEventTypeRunStarted, first) + require.Equal(t, events.AgentEventTypeRunFinished, last) +} + +func TestExecuteStream_EmitsRunError(t *testing.T) { + client := &seqLLMClient{ + errs: []error{errors.New("llm down")}, + } + rt := newLocalRT(t, client) + + ch, err := rt.ExecuteStream(context.Background(), &sdkruntime.ExecuteRequest{ + UserPrompt: "hi", + }) + require.NoError(t, err) // subscribe succeeds synchronously + + evs := collectEvents(t, ch, 5*time.Second) + types := eventTypes(evs) + + require.Contains(t, types, events.AgentEventTypeRunStarted) + require.Contains(t, types, events.AgentEventTypeRunError) +} + +func TestExecuteStream_ChannelClosedAfterTerminalEvent(t *testing.T) { + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{{Content: "done"}}, + } + rt := newLocalRT(t, client) + + ch, err := rt.ExecuteStream(context.Background(), &sdkruntime.ExecuteRequest{UserPrompt: "hi"}) + require.NoError(t, err) + + // Channel must close eventually. + timeout := time.After(5 * time.Second) + for { + select { + case _, ok := <-ch: + if !ok { + return // channel closed — success + } + case <-timeout: + t.Fatal("channel never closed") + } + } +} + +func TestExecuteStream_ContextCancelledAborts(t *testing.T) { + blocking := &blockingLLMClient{block: make(chan struct{})} + rt := newLocalRT(t, blocking) + + ctx, cancel := context.WithCancel(context.Background()) + + ch, err := rt.ExecuteStream(ctx, &sdkruntime.ExecuteRequest{UserPrompt: "hi"}) + require.NoError(t, err) + + // Give the goroutine a moment to start, then cancel. + time.Sleep(20 * time.Millisecond) + cancel() + + evs := collectEvents(t, ch, 3*time.Second) + types := eventTypes(evs) + require.Contains(t, types, events.AgentEventTypeRunStarted) + // Channel must close (error or finished). + // Verifying closure is enough; collectEvents blocks until close. 
+ _ = types +} + +// --------------------------------------------------------------------------- +// Approve +// --------------------------------------------------------------------------- + +func TestApprove_UnknownToken(t *testing.T) { + rt := newLocalRT(t, &seqLLMClient{}) + err := rt.Approve(context.Background(), "nonexistent-token", types.ApprovalStatusApproved) + require.Error(t, err) + require.Contains(t, err.Error(), "no pending approval for token") +} + +func TestApprove_ResolvesRegisteredChannel(t *testing.T) { + rt := newLocalRT(t, &seqLLMClient{}) + + const token = "test-token-123" + resultCh := make(chan types.ApprovalStatus, 1) + rt.pendingApprovals.Store(token, resultCh) + + err := rt.Approve(context.Background(), token, types.ApprovalStatusApproved) + require.NoError(t, err) + + select { + case status := <-resultCh: + require.Equal(t, types.ApprovalStatusApproved, status) + case <-time.After(time.Second): + t.Fatal("expected status on channel, got timeout") + } + + // Token should have been removed by LoadAndDelete. + _, loaded := rt.pendingApprovals.Load(token) + require.False(t, loaded, "token must be removed after Approve") +} + +func TestApprove_RejectsViaSameToken(t *testing.T) { + rt := newLocalRT(t, &seqLLMClient{}) + + const token = "reject-token" + resultCh := make(chan types.ApprovalStatus, 1) + rt.pendingApprovals.Store(token, resultCh) + + err := rt.Approve(context.Background(), token, types.ApprovalStatusRejected) + require.NoError(t, err) + + status := <-resultCh + require.Equal(t, types.ApprovalStatusRejected, status) +} + +func TestApprove_DoubleApproveSecondErrors(t *testing.T) { + rt := newLocalRT(t, &seqLLMClient{}) + + const token = "double-token" + resultCh := make(chan types.ApprovalStatus, 1) + rt.pendingApprovals.Store(token, resultCh) + + require.NoError(t, rt.Approve(context.Background(), token, types.ApprovalStatusApproved)) + // Second call: token already removed by LoadAndDelete. 
+ err := rt.Approve(context.Background(), token, types.ApprovalStatusApproved) + require.Error(t, err) +} + +func TestApprove_StreamingEndToEnd(t *testing.T) { + // LLM: first call returns a tool call needing approval, second returns final text. + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{ + {ToolCalls: []*interfaces.ToolCall{ + {ToolCallID: "c1", ToolName: "guarded-tool"}, + }}, + {Content: "approved result"}, + }, + } + tool := stubTool{name: "guarded-tool", result: "ran!", needsApproval: true} + rt := newLocalRT(t, client, tool) + + ch, err := rt.ExecuteStream(context.Background(), &sdkruntime.ExecuteRequest{ + UserPrompt: "run guarded tool", + }) + require.NoError(t, err) + + // Collect events until we see a CUSTOM approval event, then approve. + var approvalToken string + var allEvents []events.AgentEvent + + timeout := time.After(5 * time.Second) +outer: + for { + select { + case ev, ok := <-ch: + if !ok { + break outer + } + if ev == nil { + continue + } + allEvents = append(allEvents, ev) + if ev.Type() == events.AgentEventTypeCustom { + val, parseErr := events.ParseCustomEventApproval(ev.(*events.AgentCustomEvent)) + if parseErr == nil && val.ApprovalToken != "" { + approvalToken = val.ApprovalToken + // Approve in a separate goroutine to unblock the loop. 
+ go func(tok string) { + _ = rt.Approve(context.Background(), tok, types.ApprovalStatusApproved) + }(approvalToken) + } + } + case <-timeout: + t.Fatal("timed out waiting for streaming events") + } + } + + types := eventTypes(allEvents) + require.NotEmpty(t, approvalToken, "expected an approval token in CUSTOM event") + require.Contains(t, types, events.AgentEventTypeRunFinished) +} + +// --------------------------------------------------------------------------- +// Close +// --------------------------------------------------------------------------- + +func TestClose_NoError(t *testing.T) { + rt := newLocalRT(t, &seqLLMClient{}) + require.NotPanics(t, rt.Close) +} + +// --------------------------------------------------------------------------- +// publishLifecycleEvent +// --------------------------------------------------------------------------- + +func TestPublishLifecycleEvent_NilEventbus(t *testing.T) { + rt := &LocalRuntime{ + Runtime: base.Runtime{ + AgentSpec: sdkruntime.AgentSpec{Name: "a"}, + }, + logger: logger.NoopLogger(), + } + // eventbus is nil — must not panic. + require.NotPanics(t, func() { + rt.publishLifecycleEvent("some-channel", events.NewAgentRunErrorEvent("oops")) + }) +} + +func TestPublishLifecycleEvent_EmptyChannel(t *testing.T) { + rt := newLocalRT(t, &seqLLMClient{}) + // empty channel — must not panic. 
+ require.NotPanics(t, func() { + rt.publishLifecycleEvent("", events.NewAgentRunErrorEvent("oops")) + }) +} + +func TestPublishLifecycleEvent_NilEvent(t *testing.T) { + rt := newLocalRT(t, &seqLLMClient{}) + require.NotPanics(t, func() { + rt.publishLifecycleEvent("ch", nil) + }) +} + +// --------------------------------------------------------------------------- +// EventBusRuntime interface +// --------------------------------------------------------------------------- + +func TestGetEventBus_ReturnsInitialisedBus(t *testing.T) { + rt := newLocalRT(t, &seqLLMClient{}) + require.NotNil(t, rt.GetEventBus(), "GetEventBus must return the bus initialised by NewLocalRuntime") +} + +func TestSetEventBus_ReplacesBus(t *testing.T) { + rt := newLocalRT(t, &seqLLMClient{}) + original := rt.GetEventBus() + + // Build a second runtime and swap its bus into the first. + rt2 := newLocalRT(t, &seqLLMClient{}) + newBus := rt2.GetEventBus() + + rt.SetEventBus(newBus) + require.Same(t, newBus, rt.GetEventBus(), "GetEventBus should return the new bus") + require.NotSame(t, original, rt.GetEventBus(), "bus should have changed after SetEventBus") +} + +// --------------------------------------------------------------------------- +// localChannelName +// --------------------------------------------------------------------------- + +func TestLocalChannelName(t *testing.T) { + name := localChannelName("run-42") + require.Equal(t, "agent-event-run-42", name) +} + +// --------------------------------------------------------------------------- +// subscribeToAgentEvents +// --------------------------------------------------------------------------- + +func TestSubscribeToAgentEvents_DecodesEvents(t *testing.T) { + rt := newLocalRT(t, &seqLLMClient{}) + ctx := context.Background() + ch, closeFn, err := rt.subscribeToAgentEvents(ctx, "test-channel") + require.NoError(t, err) + defer func() { _ = closeFn() }() + + // Publish a raw lifecycle event. 
+ ev := events.NewAgentRunStartedEvent("thread-1", "run-1") + rt.publishLifecycleEvent("test-channel", ev) + + select { + case received := <-ch: + require.Equal(t, events.AgentEventTypeRunStarted, received.Type()) + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for event") + } +} + +// --------------------------------------------------------------------------- +// Execute with tool call (two-turn) +// --------------------------------------------------------------------------- + +func TestExecute_ToolCallThenFinalAnswer(t *testing.T) { + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{ + { + ToolCalls: []*interfaces.ToolCall{ + {ToolCallID: "c1", ToolName: "calc"}, + }, + }, + {Content: "the answer is 42"}, + }, + } + tool := stubTool{name: "calc", result: "42"} + rt := newLocalRT(t, client, tool) + + result, err := rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{ + UserPrompt: "compute", + }) + require.NoError(t, err) + require.Equal(t, "the answer is 42", result.Content) +} + +// --------------------------------------------------------------------------- +// Execute — conversation persistence +// --------------------------------------------------------------------------- + +func TestExecute_PersistsConversationMessages(t *testing.T) { + ctrl := gomock.NewController(t) + conv := ifmocks.NewMockConversation(ctrl) + + // ListMessages returns empty history for "conv-1". + conv.EXPECT().ListMessages(gomock.Any(), "conv-1", gomock.Any()).Return(nil, nil) + // AddMessage is called for each message (user + assistant = 2). 
+ conv.EXPECT().AddMessage(gomock.Any(), "conv-1", gomock.Any()).Return(nil).Times(2) + + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{{Content: "persisted"}}, + } + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentSpec(sdkruntime.AgentSpec{Name: "agent"}), + WithAgentExecution(sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: client}, + Session: sdkruntime.AgentSession{Conversation: conv, ConversationSize: 20}, + Limits: sdkruntime.AgentLimits{MaxIterations: 5, Timeout: 5 * time.Second}, + }), + ) + require.NoError(t, err) + + _, err = rt.Execute(context.Background(), &sdkruntime.ExecuteRequest{ + UserPrompt: "remember this", + ConversationID: "conv-1", + }) + require.NoError(t, err) +} diff --git a/internal/runtime/local/subagent.go b/internal/runtime/local/subagent.go new file mode 100644 index 0000000..e57db4c --- /dev/null +++ b/internal/runtime/local/subagent.go @@ -0,0 +1,39 @@ +package local + +import sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + +// subAgentRoute is the local runtime's internal representation of a delegatable sub-agent. +// Built from ExecuteRequest.SubAgents by buildSubAgentRoutes; not shared with any other package. +type subAgentRoute struct { + name string + runtime *LocalRuntime + children map[string]subAgentRoute +} + +// buildSubAgentRoutes converts the runtime-agnostic SubAgentSpec tree (from ExecuteRequest) +// into local-specific routes. Each spec's Runtime is type-asserted to *LocalRuntime; +// specs with a non-local runtime are skipped (mixed-runtime delegation not supported). 
+func buildSubAgentRoutes(specs []*sdkruntime.SubAgentSpec) map[string]subAgentRoute { + if len(specs) == 0 { + return nil + } + out := make(map[string]subAgentRoute, len(specs)) + for _, spec := range specs { + if spec == nil { + continue + } + lr, ok := spec.Runtime.(*LocalRuntime) + if !ok { + continue + } + out[spec.ToolName] = subAgentRoute{ + name: spec.Name, + runtime: lr, + children: buildSubAgentRoutes(spec.Children), + } + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index e1280ea..000487a 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -68,6 +68,19 @@ type EventBusRuntime interface { GetEventBus() eventbus.EventBus } +// SubAgentToolParamQuery is the tool/JSON parameter name for the query sent to a sub-agent. +const SubAgentToolParamQuery = "query" + +// SubAgentSpec describes one sub-agent in the delegation tree passed from pkg/agent to a runtime. +// The runtime builds its own internal routing structures from this tree; no runtime-specific fields +// are present here. ToolName is the sanitised tool name derived from Name and used as the map key. +type SubAgentSpec struct { + Name string // human-readable agent name + ToolName string // tool name used to invoke this sub-agent (key in runtime route maps) + Runtime Runtime // the sub-agent's runtime instance + Children []*SubAgentSpec +} + // AgentSpec describes agent identity and structured-output preferences for one run. // It is attached to [ExecuteRequest.AgentSpec] so custom Runtime implementations can read name, prompts, // and response format without importing pkg/agent. @@ -139,7 +152,7 @@ type ExecuteRequest struct { StreamingEnabled bool // EventTypes filters streamed events; empty means default (implementation-defined, often all types). 
EventTypes []events.AgentEventType - SubAgentRoutes map[string]types.SubAgentRoute + SubAgents []*SubAgentSpec MaxSubAgentDepth int ApprovalHandler types.ApprovalHandler diff --git a/internal/runtime/temporal/agent_workflow.go b/internal/runtime/temporal/agent_workflow.go index bfa1e34..a9dd3f9 100644 --- a/internal/runtime/temporal/agent_workflow.go +++ b/internal/runtime/temporal/agent_workflow.go @@ -11,6 +11,7 @@ import ( "time" "github.com/agenticenv/agent-sdk-go/internal/events" + "github.com/agenticenv/agent-sdk-go/internal/runtime/base" "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/google/uuid" @@ -116,18 +117,18 @@ func (rt *TemporalRuntime) sendAgentEventWorkflowUpdate(ctx context.Context, eve // EventTypes is set by the SDK; a single "*" element means emit all event kinds (used for Stream). // AgentFingerprint is the SHA-256 hex digest of the worker-local agent config; activities reject on mismatch. type AgentWorkflowInput struct { - UserPrompt string `json:"user_prompt,omitempty"` - EventWorkflowID string `json:"event_workflow_id,omitempty"` - EventTaskQueue string `json:"event_task_queue,omitempty"` - LocalChannelName string `json:"local_channel_name,omitempty"` - StreamingEnabled bool `json:"streaming_enabled,omitempty"` - ConversationID string `json:"conversation_id,omitempty"` - AgentFingerprint string `json:"agent_fingerprint,omitempty"` - EventTypes []events.AgentEventType `json:"event_types,omitempty"` - SubAgentDepth int `json:"sub_agent_depth,omitempty"` - SubAgentRoutes map[string]types.SubAgentRoute `json:"sub_agent_routes,omitempty"` - MaxSubAgentDepth int `json:"max_sub_agent_depth,omitempty"` - State *AgentWorkflowState `json:"state,omitempty"` + UserPrompt string `json:"user_prompt,omitempty"` + EventWorkflowID string `json:"event_workflow_id,omitempty"` + EventTaskQueue string `json:"event_task_queue,omitempty"` + LocalChannelName string 
`json:"local_channel_name,omitempty"` + StreamingEnabled bool `json:"streaming_enabled,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + AgentFingerprint string `json:"agent_fingerprint,omitempty"` + EventTypes []events.AgentEventType `json:"event_types,omitempty"` + SubAgentDepth int `json:"sub_agent_depth,omitempty"` + SubAgentRoutes map[string]SubAgentRoute `json:"sub_agent_routes,omitempty"` + MaxSubAgentDepth int `json:"max_sub_agent_depth,omitempty"` + State *AgentWorkflowState `json:"state,omitempty"` } // AgentWorkflowState is the state of the agent workflow. @@ -174,6 +175,26 @@ type AgentLLMResult struct { Usage *interfaces.LLMUsage `json:"usage,omitempty"` } +// baseLLMResultToActivity converts a [base.LLMResult] (no JSON tags) to an [AgentLLMResult] +// (with JSON tags required for Temporal serialization). ToolCallRequests are copied field by field +// so the two types stay independent (temporal adds JSON tags, base does not). +func baseLLMResultToActivity(r *base.LLMResult) *AgentLLMResult { + out := &AgentLLMResult{ + Content: r.Content, + Usage: base.CloneLLMUsage(r.Usage), + } + for _, tc := range r.ToolCalls { + out.ToolCalls = append(out.ToolCalls, ToolCallRequest{ + ToolCallID: tc.ToolCallID, + ToolName: tc.ToolName, + ToolDisplayName: tc.ToolDisplayName, + Args: tc.Args, + NeedsApproval: tc.NeedsApproval, + }) + } + return out +} + // ToolCallRequest is a tool invocation with approval flag. NeedsApproval is set by AgentLLMActivity. 
type ToolCallRequest struct { ToolCallID string `json:"tool_call_id"` // from LLM; used to match tool results @@ -424,7 +445,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl return nil, err } - runUsage = mergeLLMUsage(runUsage, llmResult.Usage) + runUsage = base.MergeLLMUsage(runUsage, llmResult.Usage) if len(llmResult.ToolCalls) == 0 { // Final response: accumulate assistant message for conversation @@ -448,7 +469,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl } return nil, err } - runUsage = mergeLLMUsage(runUsage, llmResult.Usage) + runUsage = base.MergeLLMUsage(runUsage, llmResult.Usage) messages = append(messages, interfaces.Message{Role: interfaces.MessageRoleAssistant, Content: llmResult.Content}) lastContent = llmResult.Content break @@ -471,7 +492,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl var toolResults []interfaces.Message - toolExecMode := rt.AgentToolExecutionMode + toolExecMode := rt.ToolExecutionMode if toolExecMode == "" { toolExecMode = types.AgentToolExecutionModeParallel } @@ -891,266 +912,28 @@ func (rt *TemporalRuntime) AgentLLMStreamActivity(ctx context.Context, input Age } stopHB := startLongActivityHeartbeats(ctx) defer stopHB() - logger := activity.GetLogger(ctx) - info := activity.GetInfo(ctx) - agentWorkflowID := info.WorkflowExecution.ID + + actLog := newActivityLogger(activity.GetLogger(ctx)) agentName := strings.TrimSpace(input.AgentName) messages := input.Messages if input.ConversationID != "" { - convMessages, err := rt.fetchConversationMessages(ctx, input.ConversationID) + convMessages, err := rt.FetchConversationMessages(ctx, actLog, input.ConversationID) if err != nil { return nil, err } messages = append(convMessages, messages...) 
} - logger.Debug("activity: LLM stream started", "scope", "activity", "runID", agentWorkflowID, "messageCount", len(messages)) - - req, tools := rt.buildLLMRequest(messages, input.SkipTools, input.RetrieverContext) - - emitDelta := func(ev events.AgentEvent) { + emit := func(ev events.AgentEvent) { rt.publishAgentEventToStream(ctx, agentName, input.LocalChannelName, input.EventWorkflowID, input.EventTaskQueue, ev) } - textMsgOpen := false - openTextMsg := func() { - if textMsgOpen { - return - } - emitDelta(events.NewAgentTextMessageStartEvent(input.MessageID, string(interfaces.MessageRoleAssistant))) - textMsgOpen = true - } - closeTextMsg := func() { - if !textMsgOpen { - return - } - emitDelta(events.NewAgentTextMessageEndEvent(input.MessageID)) - textMsgOpen = false - } - // If the model never sent text chunks, still emit one text message (empty for tool-only) to match one activity = one assistant turn. - finalizeAssistantTextMessage := func(result *AgentLLMResult) { - if textMsgOpen { - closeTextMsg() - return - } - openTextMsg() - emitDelta(events.NewAgentTextMessageContentEvent(input.MessageID, result.Content)) - closeTextMsg() - } - - llmClient := rt.AgentExecution.LLM.Client - model := llmClient.GetModel() - provider := string(llmClient.GetProvider()) - modelAttr := interfaces.Attribute{Key: types.MetricAttrModel, Value: model} - providerAttr := interfaces.Attribute{Key: types.MetricAttrProvider, Value: provider} - - isLLMStreamSupported := llmClient.IsStreamSupported() - - rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallStarted, modelAttr, providerAttr) - llmStart := time.Now() - - ctx, sp := rt.Tracer.StartSpan(ctx, "llm.stream", - interfaces.Attribute{Key: "agent.name", Value: agentName}, - interfaces.Attribute{Key: "message.count", Value: len(messages)}, - interfaces.Attribute{Key: "streaming", Value: isLLMStreamSupported}, - modelAttr, - providerAttr, - ) - defer sp.End() - - if !isLLMStreamSupported { - logger.Debug("activity: LLM stream 
unsupported, using generate", "scope", "activity") - resp, err := llmClient.Generate(ctx, req) - llmLatency := float64(time.Since(llmStart).Milliseconds()) - if err != nil { - sp.RecordError(err) - rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) - rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) - return nil, err - } - result, err := rt.llmResponseToResult(resp, tools) - if err != nil { - sp.RecordError(err) - rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) - rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) - return nil, err - } - rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) - rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallCompleted, modelAttr, providerAttr) - if resp.Usage != nil { - rt.Metrics.RecordHistogram(ctx, types.MetricLLMTokensInput, float64(resp.Usage.PromptTokens), modelAttr, providerAttr) - rt.Metrics.RecordHistogram(ctx, types.MetricLLMTokensOutput, float64(resp.Usage.CompletionTokens), modelAttr, providerAttr) - } - finalizeAssistantTextMessage(result) - return result, nil - } - - stream, err := llmClient.GenerateStream(ctx, req) - if err != nil { - sp.RecordError(err) - rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) - rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, float64(time.Since(llmStart).Milliseconds()), modelAttr, providerAttr) - return nil, err - } - - // Reasoning AG-UI order: REASONING_START → REASONING_MESSAGE_START → REASONING_MESSAGE_CONTENT* → - // REASONING_MESSAGE_END → REASONING_END (flushed before the first assistant text delta, or at stream end). - // reasoningMID is a new UUID per reasoning phase (regenerated after a prior phase is flushed). 
- var reasoningMID string - reasoningPhaseOpen := false - reasoningMsgOpen := false - flushReasoning := func() { - if reasoningMsgOpen { - emitDelta(events.NewAgentReasoningMessageEndEvent(reasoningMID)) - reasoningMsgOpen = false - } - if reasoningPhaseOpen { - emitDelta(events.NewAgentReasoningEndEvent(reasoningMID)) - reasoningPhaseOpen = false - } - } - openReasoning := func() { - if reasoningPhaseOpen { - return - } - reasoningMID = uuid.New().String() - emitDelta(events.NewAgentReasoningStartEvent(reasoningMID)) - reasoningPhaseOpen = true - emitDelta(events.NewAgentReasoningMessageStartEvent(reasoningMID, string(interfaces.MessageRoleReasoning))) - reasoningMsgOpen = true - } - - for stream.Next() { - chunk := stream.Current() - if chunk == nil { - continue - } - if chunk.ContentDelta != "" { - flushReasoning() - openTextMsg() - //TEXT_MESSAGE_CONTENT - emitDelta(events.NewAgentTextMessageContentEvent(input.MessageID, chunk.ContentDelta)) - } - if chunk.ThinkingDelta != "" { - openReasoning() - //REASONING_MESSAGE_CONTENT - emitDelta(events.NewAgentReasoningMessageContentEvent(reasoningMID, chunk.ThinkingDelta)) - } - } - flushReasoning() - llmLatency := float64(time.Since(llmStart).Milliseconds()) - if err := stream.Err(); err != nil { - sp.RecordError(err) - rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) - rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) - return nil, err - } - - resp := stream.GetResult() - if resp == nil { - err := fmt.Errorf("stream completed without result") - sp.RecordError(err) - rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) - rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) - return nil, err - } - logger.Debug("activity: LLM stream completed", "scope", "activity", "runID", agentWorkflowID) - result, err := rt.llmResponseToResult(resp, tools) + result, err := 
rt.ExecuteLLMStream(ctx, actLog, agentName, input.MessageID, messages, input.SkipTools, input.RetrieverContext, emit) if err != nil { - sp.RecordError(err) - rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) - rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) return nil, err } - rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) - rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallCompleted, modelAttr, providerAttr) - if resp.Usage != nil { - rt.Metrics.RecordHistogram(ctx, types.MetricLLMTokensInput, float64(resp.Usage.PromptTokens), modelAttr, providerAttr) - rt.Metrics.RecordHistogram(ctx, types.MetricLLMTokensOutput, float64(resp.Usage.CompletionTokens), modelAttr, providerAttr) - } - finalizeAssistantTextMessage(result) - return result, nil -} - -// buildLLMRequest builds an LLMRequest from messages, skipTools, and optional retrieverContext. -// When retrieverContext is non-empty (prefetch / hybrid mode) it is appended to the system prompt so the -// LLM sees pre-fetched documents on every call in the run. Returns the request and tools list. 
-func (rt *TemporalRuntime) buildLLMRequest(messages []interfaces.Message, skipTools bool, retrieverContext string) (*interfaces.LLMRequest, []interfaces.Tool) { - tools := rt.AgentExecution.Tools.Tools - systemMessage := rt.AgentSpec.SystemPrompt - if retrieverContext != "" { - systemMessage = fmt.Sprintf("%s\n\nRelevant Context:\n%s", rt.AgentSpec.SystemPrompt, retrieverContext) - } - req := &interfaces.LLMRequest{ - SystemMessage: systemMessage, - ResponseFormat: rt.AgentSpec.ResponseFormat, - Messages: messages, - } - applyLLMSampling(rt.AgentExecution.LLM.Sampling, req) - if skipTools { - req.Tools = []interfaces.ToolSpec{} - } else { - req.Tools = interfaces.ToolsToSpecs(tools) - } - return req, tools -} - -// fetchConversationMessages fetches messages for the LLM: fetches from conversation when ConversationID is set, -func (rt *TemporalRuntime) fetchConversationMessages(ctx context.Context, conversationID string) ([]interfaces.Message, error) { - logger := activity.GetLogger(ctx) - logger.Debug("activity: loading conversation history", "scope", "activity", "conversationID", conversationID) - - if rt.AgentExecution.Session.Conversation == nil { - return nil, fmt.Errorf("conversation is not configured") - } - - limit := rt.AgentExecution.Session.ConversationSize - if limit <= 0 { - limit = 20 - } - - ctx, sp := rt.Tracer.StartSpan(ctx, "conversation.get_messages", - interfaces.Attribute{Key: "conversation.id", Value: conversationID}, - interfaces.Attribute{Key: "limit", Value: limit}, - ) - defer sp.End() - - messages, err := rt.AgentExecution.Session.Conversation.ListMessages(ctx, conversationID, interfaces.WithLimit(limit)) - if err != nil { - sp.RecordError(err) - return nil, fmt.Errorf("failed to list conversation messages: %w", err) - } - - sp.SetAttribute("message.count", len(messages)) - logger.Debug("activity: conversation history loaded", "scope", "activity", "messageCount", len(messages)) - return messages, nil -} - -func (rt *TemporalRuntime) 
llmResponseToResult(resp *interfaces.LLMResponse, tools []interfaces.Tool) (*AgentLLMResult, error) { - result := &AgentLLMResult{Content: resp.Content, Usage: cloneLLMUsagePtr(resp.Usage)} - for _, tc := range resp.ToolCalls { - if tc == nil { - continue - } - tool, ok := findToolByName(tools, tc.ToolName) - if !ok { - return nil, fmt.Errorf("unknown tool: %s", tc.ToolName) - } - needsApproval := rt.requiresApproval(tool) - displayName := tool.DisplayName() - if displayName == "" { - displayName = tc.ToolName - } - result.ToolCalls = append(result.ToolCalls, ToolCallRequest{ - ToolCallID: tc.ToolCallID, - ToolName: tc.ToolName, - ToolDisplayName: displayName, - Args: tc.Args, - NeedsApproval: needsApproval, - }) - } - return result, nil + return baseLLMResultToActivity(result), nil } // AgentRetrieverActivity runs all configured retrievers in parallel using input.UserPrompt as the query, @@ -1162,162 +945,41 @@ func (rt *TemporalRuntime) AgentRetrieverActivity(ctx context.Context, input Age if err := rt.verifyAgentFingerprint(input.AgentFingerprint); err != nil { return nil, err } - - retrievers := rt.AgentExecution.Retrievers.Retrievers - if len(retrievers) == 0 { - return &AgentRetrieverResult{}, nil - } - - logger := activity.GetLogger(ctx) - logger.Debug("activity: retriever prefetch started", "scope", "activity", "retrieverCount", len(retrievers), "query", input.UserPrompt) - - type retrieverResult struct { - name string - docs []interfaces.Document - err error - } - - results := make([]retrieverResult, len(retrievers)) - var wg sync.WaitGroup - for i, r := range retrievers { - wg.Add(1) - go func(idx int, ret interfaces.Retriever) { - defer wg.Done() - name := ret.Name() - retrieverAttr := interfaces.Attribute{Key: types.MetricAttrRetriever, Value: name} - rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallStarted, retrieverAttr) - start := time.Now() - - searchCtx, sp := rt.Tracer.StartSpan(ctx, "retriever.search", - interfaces.Attribute{Key: 
"retriever.name", Value: name}, - interfaces.Attribute{Key: "query", Value: input.UserPrompt}, - ) - docs, err := ret.Search(searchCtx, input.UserPrompt) - latency := float64(time.Since(start).Milliseconds()) - if err != nil { - sp.RecordError(err) - sp.End() - rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallFailed, retrieverAttr) - rt.Metrics.RecordHistogram(ctx, types.MetricRetrieverLatencyMs, latency, retrieverAttr) - } else { - sp.End() - rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallCompleted, retrieverAttr) - rt.Metrics.RecordHistogram(ctx, types.MetricRetrieverLatencyMs, latency, retrieverAttr) - } - results[idx] = retrieverResult{name: name, docs: docs, err: err} - }(i, r) - } - wg.Wait() - - multipleRetrievers := len(retrievers) > 1 - var sb strings.Builder - failedCount := 0 - for _, res := range results { - if res.err != nil { - failedCount++ - logger.Error("activity: retriever search failed, skipping", "scope", "activity", "retriever", res.name, "error", res.err) - continue - } - if len(res.docs) == 0 { - continue - } - if multipleRetrievers { - fmt.Fprintf(&sb, "## %s\n", res.name) - } - sb.WriteString(formatRetrieverDocs(res.docs)) - } - - if failedCount == len(retrievers) { - return nil, fmt.Errorf("retriever prefetch: all %d retriever(s) failed", len(retrievers)) - } - if failedCount > 0 { - logger.Warn("activity: some retrievers failed, continuing with partial context", "scope", "activity", "failed", failedCount, "total", len(retrievers)) + actLog := newActivityLogger(activity.GetLogger(ctx)) + retrieverContext, err := rt.ExecuteRetrievers(ctx, actLog, input.UserPrompt) + if err != nil { + return nil, err } - - retrieverContext := strings.TrimSpace(sb.String()) - logger.Debug("activity: retriever prefetch completed", "scope", "activity", "retrieverCount", len(retrievers), "hasContext", retrieverContext != "") return &AgentRetrieverResult{RetrieverContext: retrieverContext}, nil } -// formatRetrieverDocs formats a list of 
documents for injection into the LLM system prompt. -// Format: "[N] content\n(source: s, score: 0.XX)\n\n" for each document. -func formatRetrieverDocs(docs []interfaces.Document) string { - if len(docs) == 0 { - return "" - } - var sb strings.Builder - for i, doc := range docs { - fmt.Fprintf(&sb, types.RetrieverDocFormat, i+1, doc.Content, doc.Source, doc.Score) - } - return sb.String() -} - // AgentLLMActivity calls the LLM and returns content plus any tool calls. // When input.ConversationID is set, fetches from store and adds assistant message on completion. func (rt *TemporalRuntime) AgentLLMActivity(ctx context.Context, input AgentLLMInput) (*AgentLLMResult, error) { if err := rt.verifyAgentFingerprint(input.AgentFingerprint); err != nil { return nil, err } - logger := activity.GetLogger(ctx) + actLog := newActivityLogger(activity.GetLogger(ctx)) + agentName := strings.TrimSpace(input.AgentName) messages := input.Messages if input.ConversationID != "" { - convMessages, err := rt.fetchConversationMessages(ctx, input.ConversationID) + convMessages, err := rt.FetchConversationMessages(ctx, actLog, input.ConversationID) if err != nil { return nil, err } messages = append(convMessages, messages...) 
 	}
-	logger.Debug("activity: LLM generate started", "scope", "activity", "messageCount", len(messages))
-	req, tools := rt.buildLLMRequest(messages, input.SkipTools, input.RetrieverContext)
-
-	llmClient := rt.AgentExecution.LLM.Client
-	model := llmClient.GetModel()
-	provider := string(llmClient.GetProvider())
-	modelAttr := interfaces.Attribute{Key: types.MetricAttrModel, Value: model}
-	providerAttr := interfaces.Attribute{Key: types.MetricAttrProvider, Value: provider}
-
-	rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallStarted, modelAttr, providerAttr)
-	llmStart := time.Now()
-
-	ctx, sp := rt.Tracer.StartSpan(ctx, "llm.generate",
-		interfaces.Attribute{Key: "agent.name", Value: strings.TrimSpace(input.AgentName)},
-		interfaces.Attribute{Key: "message.count", Value: len(messages)},
-		modelAttr,
-		providerAttr,
-	)
-	resp, err := llmClient.Generate(ctx, req)
-	llmLatency := float64(time.Since(llmStart).Milliseconds())
-	if err != nil {
-		sp.RecordError(err)
-		sp.End()
-		rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr)
-		rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr)
-		return nil, err
-	}
-	sp.End()
-
-	rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr)
-	rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallCompleted, modelAttr, providerAttr)
-	if resp.Usage != nil {
-		rt.Metrics.RecordHistogram(ctx, types.MetricLLMTokensInput, float64(resp.Usage.PromptTokens), modelAttr, providerAttr)
-		rt.Metrics.RecordHistogram(ctx, types.MetricLLMTokensOutput, float64(resp.Usage.CompletionTokens), modelAttr, providerAttr)
+	emit := func(ev events.AgentEvent) {
+		rt.publishAgentEventToStream(ctx, agentName, input.LocalChannelName, input.EventWorkflowID, input.EventTaskQueue, ev)
 	}
-	logger.Debug("activity: LLM generate completed", "scope", "activity", "messageCount", len(messages))
-	result, err := rt.llmResponseToResult(resp, tools)
+	result, err := rt.ExecuteLLM(ctx, actLog, agentName, input.MessageID, messages, input.SkipTools, input.RetrieverContext, emit)
 	if err != nil {
 		return nil, err
 	}
-	agentNameTrim := strings.TrimSpace(input.AgentName)
-	publish := func(ev events.AgentEvent) {
-		rt.publishAgentEventToStream(ctx, agentNameTrim, input.LocalChannelName, input.EventWorkflowID, input.EventTaskQueue, ev)
-	}
-	publish(events.NewAgentTextMessageStartEvent(input.MessageID, string(interfaces.MessageRoleAssistant)))
-	publish(events.NewAgentTextMessageContentEvent(input.MessageID, result.Content))
-	publish(events.NewAgentTextMessageEndEvent(input.MessageID))
-	return result, nil
+	return baseLLMResultToActivity(result), nil
 }
 
 // AgentToolApprovalActivity blocks until the driver completes it via CompleteActivity.
@@ -1478,40 +1140,8 @@ func (rt *TemporalRuntime) AgentToolExecuteActivity(ctx context.Context, input A
 	}
 	stopHB := startLongActivityHeartbeats(ctx)
 	defer stopHB()
-	toolName := input.ToolName
-	args := input.Args
-	logger := activity.GetLogger(ctx)
-	logger.Debug("activity: tool execute started", "scope", "activity", "tool", toolName, "argCount", len(args))
-	tools := rt.AgentExecution.Tools.Tools
-	tool, ok := findToolByName(tools, toolName)
-	if !ok {
-		logger.Warn("activity: unknown tool", "scope", "activity", "tool", toolName)
-		return "", fmt.Errorf("unknown tool: %s", toolName)
-	}
-
-	toolAttr := interfaces.Attribute{Key: types.MetricAttrTool, Value: toolName}
-	rt.Metrics.IncrementCounter(ctx, types.MetricToolCallStarted, toolAttr)
-	toolStart := time.Now()
-
-	ctx, sp := rt.Tracer.StartSpan(ctx, "tool.execute",
-		interfaces.Attribute{Key: "tool.name", Value: toolName},
-		interfaces.Attribute{Key: "arg.count", Value: len(args)},
-	)
-	defer sp.End()
-
-	result, err := tool.Execute(ctx, args)
-	toolLatency := float64(time.Since(toolStart).Milliseconds())
-	if err != nil {
-		sp.RecordError(err)
-		rt.Metrics.IncrementCounter(ctx, types.MetricToolCallFailed, toolAttr)
-		rt.Metrics.RecordHistogram(ctx, types.MetricToolLatencyMs, toolLatency, toolAttr)
-		return "", err
-	}
-	rt.Metrics.RecordHistogram(ctx, types.MetricToolLatencyMs, toolLatency, toolAttr)
-	rt.Metrics.IncrementCounter(ctx, types.MetricToolCallCompleted, toolAttr)
-	content := fmt.Sprintf("%v", result)
-	logger.Debug("activity: tool execute completed", "scope", "activity", "tool", toolName)
-	return content, nil
+	actLog := newActivityLogger(activity.GetLogger(ctx))
+	return rt.ExecuteTool(ctx, actLog, input.ToolName, input.Args)
 }
 
 // AgentToolAuthorizeActivity checks optional programmatic authorization before approval/execute.
@@ -1519,59 +1149,15 @@ func (rt *TemporalRuntime) AgentToolAuthorizeActivity(ctx context.Context, input
 	if err := rt.verifyAgentFingerprint(input.AgentFingerprint); err != nil {
 		return AgentToolAuthorizeResult{}, err
 	}
-	toolName := input.ToolName
-	args := input.Args
-	logger := activity.GetLogger(ctx)
-	logger.Debug("activity: tool authorize started", "scope", "activity", "tool", toolName, "argCount", len(args))
-	tools := rt.AgentExecution.Tools.Tools
-	tool, ok := findToolByName(tools, toolName)
-	if !ok {
-		logger.Warn("activity: unknown tool in authorization", "scope", "activity", "tool", toolName)
-		return AgentToolAuthorizeResult{}, fmt.Errorf("unknown tool: %s", toolName)
-	}
-	authorizer, ok := tool.(interfaces.ToolAuthorizer)
-	if !ok {
-		logger.Debug("activity: tool has no authorizer; allow by default", "scope", "activity", "tool", toolName)
-		return AgentToolAuthorizeResult{Allowed: true}, nil
-	}
-
-	ctx, sp := rt.Tracer.StartSpan(ctx, "tool.authorize",
-		interfaces.Attribute{Key: "tool.name", Value: toolName},
-		interfaces.Attribute{Key: "arg.count", Value: len(args)},
-	)
-	defer sp.End()
-
-	decision, err := authorizer.Authorize(ctx, args)
+	actLog := newActivityLogger(activity.GetLogger(ctx))
+	authResult, err := rt.AuthorizeTool(ctx, actLog, input.ToolName, input.Args)
 	if err != nil {
-		sp.RecordError(err)
-		logger.Warn("activity: tool authorization failed", "scope", "activity", "tool", toolName, "error", err)
 		return AgentToolAuthorizeResult{}, err
 	}
-	if decision.Allow {
-		sp.SetAttribute("decision", "allowed")
-		logger.Debug("activity: tool authorization allowed", "scope", "activity", "tool", toolName)
-		return AgentToolAuthorizeResult{Allowed: true}, nil
-	}
-	reason := strings.TrimSpace(decision.Reason)
-	sp.SetAttribute("decision", "denied")
-	sp.SetAttribute("deny.reason", reason)
-	logger.Info("activity: tool authorization denied", "scope", "activity", "tool", toolName, "reason", reason)
-	return AgentToolAuthorizeResult{
-		Allowed: false,
-		Reason: reason,
-	}, nil
+	return AgentToolAuthorizeResult{Allowed: authResult.Allowed, Reason: authResult.Reason}, nil
 }
 
-func findToolByName(tools []interfaces.Tool, toolName string) (interfaces.Tool, bool) {
-	for _, t := range tools {
-		if t.Name() == toolName {
-			return t, true
-		}
-	}
-	return nil, false
-}
-
-func (rt *TemporalRuntime) delegateToSubAgent(ctx workflow.Context, input AgentWorkflowInput, tc ToolCallRequest, route types.SubAgentRoute, emitEvent func(events.AgentEvent) error) (string, error) {
+func (rt *TemporalRuntime) delegateToSubAgent(ctx workflow.Context, input AgentWorkflowInput, tc ToolCallRequest, route SubAgentRoute, emitEvent func(events.AgentEvent) error) (string, error) {
 	logger := workflow.GetLogger(ctx)
 	if strings.TrimSpace(route.TaskQueue) == "" {
 		logger.Warn("workflow: sub-agent delegation skipped (empty task queue)",
@@ -1591,7 +1177,7 @@ func (rt *TemporalRuntime) delegateToSubAgent(ctx workflow.Context, input AgentW
 		return fmt.Sprintf("Sub-agent delegation refused: maximum nesting depth (%d) reached for this agent.", maxDepth), nil
 	}
 
-	query := subAgentQueryFromArgs(tc.Args)
+	query := base.SubAgentQuery(tc.Args)
 	childInput := AgentWorkflowInput{
 		UserPrompt: query,
 		EventWorkflowID: input.EventWorkflowID,
@@ -1669,26 +1255,6 @@ func (rt *TemporalRuntime) delegateToSubAgent(ctx workflow.Context, input AgentW
 	return childResult.Content, nil
 }
 
-func (rt *TemporalRuntime) requiresApproval(t interfaces.Tool) bool {
-	if rt.AgentExecution.Tools.ApprovalPolicy == nil {
-		// No policy: honor tool's ApprovalRequired
-		if ar, ok := t.(interfaces.ToolApproval); ok && ar.ApprovalRequired() {
-			return true
-		}
-		return false
-	}
-	// Policy set: policy decides (can override tool default)
-	return rt.AgentExecution.Tools.ApprovalPolicy.RequiresApproval(t)
-}
-
-func subAgentQueryFromArgs(args map[string]any) string {
-	if args == nil {
-		return ""
-	}
-	q, _ := args[types.SubAgentToolParamQuery].(string)
-	return q
-}
-
 // subAgentChildWorkflowTimeout caps how long the main agent waits on a delegated sub-agent run.
 // Uses the main agent worker's agent timeout (same package as delegateToSubAgent); sub-agent workers may define
 // their own limits separately, but this bounds the child execution from the main agent's perspective.
@@ -1696,30 +1262,6 @@ func (rt *TemporalRuntime) subAgentChildWorkflowTimeout() time.Duration {
 	return rt.AgentExecution.Limits.Timeout
 }
 
-func mergeLLMUsage(acc *interfaces.LLMUsage, add *interfaces.LLMUsage) *interfaces.LLMUsage {
-	if add == nil {
-		return acc
-	}
-	if acc == nil {
-		return cloneLLMUsagePtr(add)
-	}
-	return &interfaces.LLMUsage{
-		PromptTokens: acc.PromptTokens + add.PromptTokens,
-		CompletionTokens: acc.CompletionTokens + add.CompletionTokens,
-		TotalTokens: acc.TotalTokens + add.TotalTokens,
-		CachedPromptTokens: acc.CachedPromptTokens + add.CachedPromptTokens,
-		ReasoningTokens: acc.ReasoningTokens + add.ReasoningTokens,
-	}
-}
-
-func cloneLLMUsagePtr(u *interfaces.LLMUsage) *interfaces.LLMUsage {
-	if u == nil {
-		return nil
-	}
-	c := *u
-	return &c
-}
-
 func retryPolicy(maxAttempts int32) *temporal.RetryPolicy {
 	return &temporal.RetryPolicy{
 		InitialInterval: time.Second,
@@ -1728,26 +1270,3 @@ func retryPolicy(maxAttempts int32) *temporal.RetryPolicy {
 		MaximumAttempts: maxAttempts,
 	}
 }
-
-func applyLLMSampling(sampling *types.LLMSampling, req *interfaces.LLMRequest) {
-	if sampling == nil {
-		return
-	}
-	s := sampling
-	if s.Temperature != nil {
-		req.Temperature = s.Temperature
-	}
-	if s.MaxTokens > 0 {
-		req.MaxTokens = s.MaxTokens
-	}
-	if s.TopP != nil {
-		req.TopP = s.TopP
-	}
-	if s.TopK != nil {
-		req.TopK = s.TopK
-	}
-	if s.Reasoning != nil {
-		r := *s.Reasoning
-		req.Reasoning = &r
-	}
-}
diff --git a/internal/runtime/temporal/agent_workflow_test.go b/internal/runtime/temporal/agent_workflow_test.go
index 9f6f302..397724e 100644
--- a/internal/runtime/temporal/agent_workflow_test.go
+++ b/internal/runtime/temporal/agent_workflow_test.go
@@ -12,6 +12,7 @@ import (
 	"go.temporal.io/sdk/workflow"
 
 	sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime"
+	"github.com/agenticenv/agent-sdk-go/internal/runtime/base"
 	"github.com/agenticenv/agent-sdk-go/internal/types"
 	"github.com/agenticenv/agent-sdk-go/pkg/interfaces"
 	"github.com/agenticenv/agent-sdk-go/pkg/interfaces/mocks"
@@ -22,7 +23,7 @@ import (
 func testRuntimeForWorkflow(t *testing.T) *TemporalRuntime {
 	t.Helper()
 	return &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{
+		Runtime: base.Runtime{
 			AgentSpec: sdkruntime.AgentSpec{Name: "WorkflowTestAgent"},
 			AgentExecution: sdkruntime.AgentExecution{
 				LLM: sdkruntime.AgentLLM{Client: stubLLM{}},
@@ -30,10 +31,10 @@ func testRuntimeForWorkflow(t *testing.T) *TemporalRuntime {
 				Tools: sdkruntime.AgentTools{Tools: nil},
 				Session: sdkruntime.AgentSession{},
 			},
-			logger: logger.NoopLogger(),
 			Tracer: observability.DefaultNoopTracer,
 			Metrics: observability.DefaultNoopMetrics,
 		},
+		logger: logger.NoopLogger(),
 	}
 }
 
@@ -189,15 +190,15 @@ func TestAgentLLMActivity_MockLLM_TextOnly(t *testing.T) {
 	}, nil)
 
 	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{
+		Runtime: base.Runtime{
 			AgentSpec: sdkruntime.AgentSpec{Name: "ActTest"},
 			AgentExecution: sdkruntime.AgentExecution{
 				LLM: sdkruntime.AgentLLM{Client: mockLLM},
 			},
-			logger: logger.NoopLogger(),
 			Tracer: observability.DefaultNoopTracer,
 			Metrics: observability.DefaultNoopMetrics,
 		},
+		logger: logger.NoopLogger(),
 	}
 
 	actEnv := newActivityTestEnv(t)
@@ -239,16 +240,16 @@ func TestAgentLLMActivity_MockLLM_ToolCalls(t *testing.T) {
 	}, nil)
 
 	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{
+		Runtime: base.Runtime{
 			AgentSpec: sdkruntime.AgentSpec{Name: "ActTest"},
 			AgentExecution: sdkruntime.AgentExecution{
 				LLM: sdkruntime.AgentLLM{Client: mockLLM},
 				Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{mockTool}, ApprovalPolicy: policy},
 			},
-			logger: logger.NoopLogger(),
 			Tracer: observability.DefaultNoopTracer,
 			Metrics: observability.DefaultNoopMetrics,
 		},
+		logger: logger.NoopLogger(),
 	}
 
 	actEnv := newActivityTestEnv(t)
@@ -283,15 +284,15 @@ func TestAgentLLMActivity_MockLLM_UnknownToolError(t *testing.T) {
 	}, nil)
 
 	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{
+		Runtime: base.Runtime{
 			AgentExecution: sdkruntime.AgentExecution{
 				LLM: sdkruntime.AgentLLM{Client: mockLLM},
 				Tools: sdkruntime.AgentTools{Tools: []interfaces.Tool{}},
 			},
-			logger: logger.NoopLogger(),
 			Tracer: observability.DefaultNoopTracer,
 			Metrics: observability.DefaultNoopMetrics,
 		},
+		logger: logger.NoopLogger(),
 	}
 
 	actEnv := newActivityTestEnv(t)
@@ -319,7 +320,7 @@ func TestAgentLLMActivity_MockConversationAndLLM(t *testing.T) {
 	mockLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(&interfaces.LLMResponse{Content: "answer"}, nil)
 
 	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{
+		Runtime: base.Runtime{
 			AgentExecution: sdkruntime.AgentExecution{
 				LLM: sdkruntime.AgentLLM{Client: mockLLM},
 				Session: sdkruntime.AgentSession{
@@ -327,10 +328,10 @@ func TestAgentLLMActivity_MockConversationAndLLM(t *testing.T) {
 					ConversationSize: 10,
 				},
 			},
-			logger: logger.NoopLogger(),
 			Tracer: observability.DefaultNoopTracer,
 			Metrics: observability.DefaultNoopMetrics,
 		},
+		logger: logger.NoopLogger(),
 	}
 
 	actEnv := newActivityTestEnv(t)
@@ -353,15 +354,15 @@ func TestAgentLLMActivity_ConversationNotConfigured(t *testing.T) {
 	mockLLM := mocks.NewMockLLMClient(ctrl)
 
 	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{
+		Runtime: base.Runtime{
 			AgentExecution: sdkruntime.AgentExecution{
 				LLM: sdkruntime.AgentLLM{Client: mockLLM},
 				Session: sdkruntime.AgentSession{Conversation: nil},
 			},
-			logger: logger.NoopLogger(),
 			Tracer: observability.DefaultNoopTracer,
 			Metrics: observability.DefaultNoopMetrics,
 		},
+		logger: logger.NoopLogger(),
 	}
 
 	actEnv := newActivityTestEnv(t)
@@ -385,15 +386,15 @@ func TestAgentLLMStreamActivity_MockLLM_FallbackToGenerate(t *testing.T) {
 	mockLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(&interfaces.LLMResponse{Content: "gen"}, nil)
 
 	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{
+		Runtime: base.Runtime{
 			AgentSpec: sdkruntime.AgentSpec{Name: "StreamAct"},
 			AgentExecution: sdkruntime.AgentExecution{
 				LLM: sdkruntime.AgentLLM{Client: mockLLM},
 			},
-			logger: logger.NoopLogger(),
 			Tracer: observability.DefaultNoopTracer,
 			Metrics: observability.DefaultNoopMetrics,
 		},
+		logger: logger.NoopLogger(),
 	}
 
 	actEnv := newActivityTestEnv(t)
@@ -490,7 +491,7 @@ func makeRetrieverRuntime(t *testing.T, retrievers []interfaces.Retriever, mode
 	mockLLM.EXPECT().GetModel().Return("test-model").AnyTimes()
 	mockLLM.EXPECT().GetProvider().Return(interfaces.LLMProviderOpenAI).AnyTimes()
 	return &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{
+		Runtime: base.Runtime{
 			AgentSpec: sdkruntime.AgentSpec{Name: "RetrieverTest"},
 			AgentExecution: sdkruntime.AgentExecution{
 				LLM: sdkruntime.AgentLLM{Client: mockLLM},
@@ -499,10 +500,10 @@ func makeRetrieverRuntime(t *testing.T, retrievers []interfaces.Retriever, mode
 					Mode: mode,
 				},
 			},
-			logger: logger.NoopLogger(),
 			Tracer: observability.DefaultNoopTracer,
 			Metrics: observability.DefaultNoopMetrics,
 		},
+		logger: logger.NoopLogger(),
 	}
 }
 
@@ -644,30 +645,6 @@ func TestAgentRetrieverActivity_EmptyDocs_EmptyContext(t *testing.T) {
 // buildLLMRequest RAG context tests
 // ---------------------------------------------------------------------------
 
-func TestBuildLLMRequest_WithRagContext_AugmentsSystemPrompt(t *testing.T) {
-	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{
-			AgentSpec: sdkruntime.AgentSpec{Name: "Test", SystemPrompt: "You are helpful."},
-			AgentExecution: sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}},
-		},
-	}
-	req, _ := rt.buildLLMRequest(nil, false, "doc context")
-	require.Contains(t, req.SystemMessage, "You are helpful.")
-	require.Contains(t, req.SystemMessage, "Relevant Context:")
-	require.Contains(t, req.SystemMessage, "doc context")
-}
-
-func TestBuildLLMRequest_NoRagContext_UnchangedSystemPrompt(t *testing.T) {
-	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{
-			AgentSpec: sdkruntime.AgentSpec{Name: "Test", SystemPrompt: "You are helpful."},
-			AgentExecution: sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}},
-		},
-	}
-	req, _ := rt.buildLLMRequest(nil, false, "")
-	require.Equal(t, "You are helpful.", req.SystemMessage)
-}
-
 // ---------------------------------------------------------------------------
 // AgentWorkflow + prefetch mode integration
 // ---------------------------------------------------------------------------
@@ -682,7 +659,7 @@ func TestAgentWorkflow_PrefetchMode_CallsRetrieverActivityFirst(t *testing.T) {
 	var suite testsuite.WorkflowTestSuite
 	env := suite.NewTestWorkflowEnvironment()
 	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{
+		Runtime: base.Runtime{
 			AgentSpec: sdkruntime.AgentSpec{Name: "PrefetchAgent", SystemPrompt: "base prompt"},
 			AgentExecution: sdkruntime.AgentExecution{
 				LLM: sdkruntime.AgentLLM{Client: stubLLM{}},
@@ -692,10 +669,10 @@ func TestAgentWorkflow_PrefetchMode_CallsRetrieverActivityFirst(t *testing.T) {
 					Mode: types.RetrieverModePrefetch,
 				},
 			},
-			logger: logger.NoopLogger(),
 			Tracer: observability.DefaultNoopTracer,
 			Metrics: observability.DefaultNoopMetrics,
 		},
+		logger: logger.NoopLogger(),
 	}
 
 	env.RegisterWorkflow(rt.AgentWorkflow)
@@ -745,23 +722,3 @@ func TestAgentWorkflow_AgenticMode_SkipsRetrieverActivity(t *testing.T) {
 	require.True(t, env.IsWorkflowCompleted())
 	require.NoError(t, env.GetWorkflowError())
 }
-
-func TestMergeLLMUsage(t *testing.T) {
-	a := &interfaces.LLMUsage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15}
-	b := &interfaces.LLMUsage{PromptTokens: 3, CompletionTokens: 7, TotalTokens: 10, CachedPromptTokens: 2, ReasoningTokens: 1}
-
-	got := mergeLLMUsage(a, b)
-	if got.PromptTokens != 13 || got.CompletionTokens != 12 || got.TotalTokens != 25 {
-		t.Fatalf("mergeLLMUsage: got %+v", got)
-	}
-	if got.CachedPromptTokens != 2 || got.ReasoningTokens != 1 {
-		t.Fatalf("mergeLLMUsage optional fields: got %+v", got)
-	}
-
-	if mergeLLMUsage(nil, nil) != nil {
-		t.Fatal("nil + nil should be nil")
-	}
-	if x := mergeLLMUsage(nil, b); x.PromptTokens != b.PromptTokens {
-		t.Fatal("nil + b should copy b")
-	}
-}
diff --git a/internal/runtime/temporal/config.go b/internal/runtime/temporal/config.go
index 34f1198..44ee974 100644
--- a/internal/runtime/temporal/config.go
+++ b/internal/runtime/temporal/config.go
@@ -7,16 +7,16 @@ import (
 	"strconv"
 	"time"
 
-	sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime"
 	"github.com/agenticenv/agent-sdk-go/internal/types"
 	"github.com/agenticenv/agent-sdk-go/pkg/interfaces"
 	"github.com/agenticenv/agent-sdk-go/pkg/logger"
-	"github.com/agenticenv/agent-sdk-go/pkg/observability"
 	"go.temporal.io/sdk/client"
 	"go.temporal.io/sdk/contrib/opentelemetry"
 	"go.temporal.io/sdk/interceptor"
 )
 
+// TemporalConfig holds the Temporal server connection parameters.
+// Pass it to [WithTemporalConfig] when the runtime should dial its own client.
 type TemporalConfig struct {
 	Host string
 	Port int
@@ -24,255 +24,6 @@ type TemporalConfig struct {
 	TaskQueue string
 }
 
-// TemporalRuntimeConfig holds connection settings plus the same [sdkruntime.AgentSpec] /
-// [sdkruntime.AgentExecution] shape as [sdkruntime.ExecuteRequest], so workers and pkg/agent share one layout.
-type TemporalRuntimeConfig struct {
-	temporalConfig *TemporalConfig
-	temporalClient client.Client
-	taskQueue string
-	instanceId string
-	ownsTemporalClient bool
-	// enableRemoteWorkers: start event worker + event workflow in Execute/ExecuteStream (client agent runtime).
-	enableRemoteWorkers bool
-	// remoteWorker: true for NewAgentWorker (polls activities); false for client Agent runtime.
-	remoteWorker bool
-
-	logger logger.Logger
-
-	AgentSpec sdkruntime.AgentSpec
-	AgentExecution sdkruntime.AgentExecution
-	PolicyFingerprint string // from pkg/agent toolPolicyFingerprint; must match caller temporal.ComputeAgentFingerprint inputs
-	MCPFingerprint string // from pkg/agent mcpConfigFingerprint; must match caller temporal.ComputeAgentFingerprint inputs
-	A2AFingerprint string // from pkg/agent a2aConfigFingerprint; must match caller temporal.ComputeAgentFingerprint inputs
-	// ObservabilityFingerprint is from pkg/agent observabilityConfigFingerprint; must match caller temporal.ComputeAgentFingerprint inputs.
-	ObservabilityFingerprint string
-	// AgentMode is the string form of [types.AgentMode] (e.g. "interactive", "autonomous"); must match pkg/agent WithAgentMode.
-	AgentMode string
-	// AgentToolExecutionMode is the [types.AgentToolExecutionMode] (e.g. "sequential", "parallel"); must match pkg/agent WithAgentToolExecutionMode.
-	AgentToolExecutionMode types.AgentToolExecutionMode
-	// RetrieverFingerprint is from pkg/agent retrieverConfigFingerprint; must match caller temporal.ComputeAgentFingerprint inputs.
-	RetrieverFingerprint string
-	// DisableLocalWorker mirrors pkg/agent [DisableLocalWorker]: when false, the client embeds a worker
-	// so Execute/ExecuteStream skip DescribeTaskQueue poller checks. ([NewAgentWorker] never calls those methods.)
-	DisableLocalWorker bool
-	// DisableFingerprintCheck disables caller-vs-worker agent fingerprint verification at activity entry.
-	// Break-glass only: keep false in production for rollout/config safety.
-	DisableFingerprintCheck bool
-
-	// Tracer and Metrics are optional clients from pkg/agent (WithObservabilityConfig / WithTracer / WithMetrics).
-	// When the runtime owns the Temporal client ([WithTemporalConfig]), [interfaces.OTelTracer] is used to attach
-	// the Temporal OpenTelemetry client interceptor. Workers use the same tracer for worker interceptors.
-	Tracer interfaces.Tracer
-	Metrics interfaces.Metrics
-}
-
-// Option configures a TemporalRuntime.
-type Option func(*TemporalRuntimeConfig)
-
-// WithTemporalConfig sets the Temporal config.
-func WithTemporalConfig(config *TemporalConfig) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.temporalConfig = config
-		c.taskQueue = config.TaskQueue
-		c.ownsTemporalClient = true
-	}
-}
-
-// WithTemporalClient sets the Temporal client.
-func WithTemporalClient(client client.Client, taskQueue string) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.temporalClient = client
-		c.taskQueue = taskQueue
-		c.ownsTemporalClient = false
-	}
-}
-
-func WithInstanceId(instanceId string) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.instanceId = instanceId
-	}
-}
-
-func WithEnableRemoteWorkers(enableRemoteWorkers bool) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.enableRemoteWorkers = enableRemoteWorkers
-	}
-}
-
-func WithRemoteWorker(remoteWorker bool) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.remoteWorker = remoteWorker
-	}
-}
-
-func WithLogger(logger logger.Logger) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.logger = logger
-	}
-}
-
-// WithAgentSpec sets identity and response format (same as [sdkruntime.ExecuteRequest.AgentSpec]).
-func WithAgentSpec(spec sdkruntime.AgentSpec) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.AgentSpec = spec
-	}
-}
-
-// WithAgentExecution sets LLM, tools, session, and limits (same as [sdkruntime.ExecuteRequest.AgentExecution]).
-func WithAgentExecution(exec sdkruntime.AgentExecution) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.AgentExecution = exec
-	}
-}
-
-// WithPolicyFingerprint sets the opaque policy digest used with [ComputeAgentFingerprint].
-// Must match pkg/agent's toolPolicyFingerprint for the same agent options.
-func WithPolicyFingerprint(fp string) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.PolicyFingerprint = fp
-	}
-}
-
-// WithMCPFingerprint sets the MCP wiring digest used with [ComputeAgentFingerprint].
-// Must match pkg/agent's mcpConfigFingerprint for the same WithMCPConfig / WithMCPClients wiring.
-func WithMCPFingerprint(fp string) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.MCPFingerprint = fp
-	}
-}
-
-// WithA2AFingerprint sets the A2A wiring digest used with [ComputeAgentFingerprint].
-// Must match pkg/agent's a2aConfigFingerprint for the same WithA2AConfig / WithA2AClients wiring.
-func WithA2AFingerprint(fp string) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.A2AFingerprint = fp
-	}
-}
-
-// WithObservabilityFingerprint sets the OTLP observability digest used with [ComputeAgentFingerprint].
-// Must match pkg/agent observabilityConfigFingerprint for the same WithObservabilityConfig wiring.
-func WithObservabilityFingerprint(fp string) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.ObservabilityFingerprint = fp
-	}
-}
-
-// WithAgentMode sets the agent mode string used with [ComputeAgentFingerprint].
-// Must match pkg/agent [WithAgentMode] for the same agent (caller process and worker process).
-func WithAgentMode(mode string) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.AgentMode = mode
-	}
-}
-
-// WithAgentToolExecutionMode sets the agent tool execution mode string used with [ComputeAgentFingerprint].
-// Must match pkg/agent [WithAgentToolExecutionMode] for the same agent (caller process and worker process).
-func WithAgentToolExecutionMode(mode types.AgentToolExecutionMode) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.AgentToolExecutionMode = mode
-	}
-}
-
-// WithRetrieverFingerprint sets the retriever wiring digest (mode + retriever names).
-// Must match pkg/agent [retrieverConfigFingerprint] for the same agent.
-func WithRetrieverFingerprint(fp string) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.RetrieverFingerprint = fp
-	}
-}
-
-// WithDisableLocalWorker mirrors pkg/agent [DisableLocalWorker]. When false, the client embeds a worker
-// and the runtime skips DescribeTaskQueue poller checks before starting workflows.
-func WithDisableLocalWorker(disable bool) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.DisableLocalWorker = disable
-	}
-}
-
-// WithDisableFingerprintCheck disables activity-time caller-vs-worker fingerprint verification.
-// Break-glass only: use temporarily during rollout incidents; default is strict verification.
-func WithDisableFingerprintCheck(disable bool) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.DisableFingerprintCheck = disable
-	}
-}
-
-// WithTracer sets the optional [interfaces.Tracer] for this runtime (from pkg/agent build).
-// When the runtime dials its own Temporal client ([WithTemporalConfig]) and the tracer implements
-// [interfaces.OTelTracer], a Temporal OpenTelemetry client interceptor is attached.
-// For [WithTemporalClient], the SDK cannot modify the client; if the tracer implements [interfaces.OTelTracer],
-// a warning is logged so callers can register the interceptor on their client.
-func WithTracer(t interfaces.Tracer) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.Tracer = t
-	}
-}
-
-// WithMetrics sets the optional [interfaces.Metrics] for this runtime (from pkg/agent build).
-// The Temporal runtime stores it for consistency with agent config; worker/client paths do not emit metrics yet.
-func WithMetrics(m interfaces.Metrics) Option {
-	return func(c *TemporalRuntimeConfig) {
-		c.Metrics = m
-	}
-}
-
-func buildTemporalRuntimeConfig(opts ...Option) (*TemporalRuntimeConfig, error) {
-	c := &TemporalRuntimeConfig{logger: logger.NoopLogger()}
-	for _, opt := range opts {
-		opt(c)
-	}
-
-	if c.temporalConfig == nil && c.temporalClient == nil {
-		return nil, fmt.Errorf("temporal config or client is required")
-	}
-
-	if c.temporalConfig != nil {
-		tc, err := newTemporalClient(c.temporalConfig, c.logger, c.Tracer)
-		if err != nil {
-			return nil, err
-		}
-		c.temporalClient = tc
-	} else { // user provided Temporal client
-		if _, ok := c.Tracer.(interfaces.OTelTracer); ok {
-			c.logger.Warn(context.Background(), "user provided Temporal client — add OTel interceptor manually for tracing", slog.String("scope", "runtime"))
-		}
-	}
-
-	if c.instanceId != "" {
-		c.taskQueue = c.taskQueue + "-" + c.instanceId
-	}
-
-	if c.AgentExecution.LLM.Client == nil {
-		return nil, fmt.Errorf("llm client is required")
-	}
-
-	if c.Tracer == nil {
-		c.Tracer = observability.DefaultNoopTracer
-	}
-	if c.Metrics == nil {
-		c.Metrics = observability.DefaultNoopMetrics
-	}
-
-	c.logger.Debug(context.Background(), "runtime config resolved",
-		slog.String("scope", "runtime"),
-		slog.String("agentName", c.AgentSpec.Name),
-		slog.String("taskQueue", c.taskQueue),
-		slog.String("instanceId", c.instanceId),
-		slog.Int("maxIterations", c.AgentExecution.Limits.MaxIterations),
-		slog.Bool("remoteWorker", c.remoteWorker),
-		slog.String("agentMode", c.AgentMode),
-		slog.String("agentToolExecutionMode", string(c.AgentToolExecutionMode)),
-		slog.Bool("enableRemoteWorkers", c.enableRemoteWorkers),
-		slog.Bool("disableFingerprintCheck", c.DisableFingerprintCheck),
-		slog.Duration("timeout", c.AgentExecution.Limits.Timeout),
-		slog.Duration("approvalTimeout", c.AgentExecution.Limits.ApprovalTimeout),
-		slog.Bool("hasConversation", c.AgentExecution.Session.Conversation != nil),
-		slog.Bool("hasTracer", c.Tracer != nil),
-		slog.Bool("hasMetrics", c.Metrics != nil))
-
-	return c, nil
-}
-
 func newTemporalClient(config *TemporalConfig, sdkLog logger.Logger, tracer interfaces.Tracer) (client.Client, error) {
 	ctx := context.Background()
 	sdkLog.Info(ctx, "runtime connecting to temporal server", slog.String("scope", "runtime"), slog.String("host", config.Host), slog.Int("port", config.Port))
@@ -296,7 +47,6 @@ func newTemporalClient(config *TemporalConfig, sdkLog logger.Logger, tracer inte
 	defer ticker.Stop()
 
 	connectionTimeout := 10 * time.Second
-
 	timeoutExceeded := time.After(connectionTimeout)
 
 	var c client.Client
diff --git a/internal/runtime/temporal/config_test.go b/internal/runtime/temporal/config_test.go
index cc805b4..19a88d6 100644
--- a/internal/runtime/temporal/config_test.go
+++ b/internal/runtime/temporal/config_test.go
@@ -80,12 +80,12 @@ func TestNewTemporalTracingInterceptor_otelTracer_nonNil(t *testing.T) {
 	}
 }
 
-func TestBuildTemporalRuntimeConfig_userProvidedTemporalClient_otelTracer_warns(t *testing.T) {
+func TestBuildTemporalRuntime_userProvidedTemporalClient_otelTracer_warns(t *testing.T) {
 	var buf bytes.Buffer
 	log := logger.NewWriterLogger(&buf, "warn", "text", false)
 	tc := temporalmocks.NewClient(t)
-	_, err := buildTemporalRuntimeConfig(
+	_, err := buildTemporalRuntime(
 		WithTemporalClient(tc, "tq"),
 		WithLogger(log),
 		WithTracer(newTestOTelTracer()),
@@ -100,12 +100,12 @@ func TestBuildTemporalRuntimeConfig_userProvidedTemporalClient_otelTracer_warns(
 	}
 }
 
-func TestBuildTemporalRuntimeConfig_userProvidedTemporalClient_defaultTracer_noManualInterceptorWarn(t *testing.T) {
+func TestBuildTemporalRuntime_userProvidedTemporalClient_defaultTracer_noManualInterceptorWarn(t *testing.T) {
 	var buf bytes.Buffer
 	log := logger.NewWriterLogger(&buf, "warn", "text", false)
 	tc := temporalmocks.NewClient(t)
-	_, err := buildTemporalRuntimeConfig(
+	_, err := buildTemporalRuntime(
 		WithTemporalClient(tc, "tq"),
 		WithLogger(log),
 		WithAgentSpec(sdkruntime.AgentSpec{Name: "x"}),
@@ -119,12 +119,12 @@ func TestBuildTemporalRuntimeConfig_userProvidedTemporalClient_defaultTracer_noM
 	}
 }
 
-func TestBuildTemporalRuntimeConfig_userProvidedTemporalClient_explicitNoopTracer_noManualInterceptorWarn(t *testing.T) {
+func TestBuildTemporalRuntime_userProvidedTemporalClient_explicitNoopTracer_noManualInterceptorWarn(t *testing.T) {
 	var buf bytes.Buffer
 	log := logger.NewWriterLogger(&buf, "warn", "text", false)
 	tc := temporalmocks.NewClient(t)
-	_, err := buildTemporalRuntimeConfig(
+	_, err := buildTemporalRuntime(
 		WithTemporalClient(tc, "tq"),
 		WithLogger(log),
 		WithTracer(observability.DefaultNoopTracer),
@@ -139,7 +139,7 @@ func TestBuildTemporalRuntimeConfig_userProvidedTemporalClient_explicitNoopTrace
 	}
 }
 
-func TestBuildTemporalRuntimeConfig_RequiresTemporalOrClient(t *testing.T) {
+func TestBuildTemporalRuntime_RequiresTemporalOrClient(t *testing.T) {
 	// Neither WithTemporalConfig nor WithTemporalClient: must fail fast without dialing a server.
 	options := []Option{
 		WithLogger(logger.NoopLogger()),
@@ -153,15 +153,15 @@ func TestBuildTemporalRuntimeConfig_RequiresTemporalOrClient(t *testing.T) {
 			LLM: sdkruntime.AgentLLM{Client: stubLLM{}},
 		}),
 	}
-	_, err := buildTemporalRuntimeConfig(options...)
+	_, err := buildTemporalRuntime(options...)
 	if err == nil || !strings.Contains(err.Error(), "temporal config or client is required") {
 		t.Fatalf("got %v", err)
 	}
 }
 
-func TestBuildTemporalRuntimeConfig_RequiresLLMClient(t *testing.T) {
+func TestBuildTemporalRuntime_RequiresLLMClient(t *testing.T) {
 	tc := temporalmocks.NewClient(t)
-	_, err := buildTemporalRuntimeConfig(
+	_, err := buildTemporalRuntime(
 		WithTemporalClient(tc, "tq"),
 		WithLogger(logger.NoopLogger()),
 		WithAgentSpec(sdkruntime.AgentSpec{Name: "x"}),
@@ -172,9 +172,9 @@ func TestBuildTemporalRuntimeConfig_RequiresLLMClient(t *testing.T) {
 	}
 }
 
-func TestBuildTemporalRuntimeConfig_InstanceIdSuffix(t *testing.T) {
+func TestBuildTemporalRuntime_InstanceIdSuffix(t *testing.T) {
 	tc := temporalmocks.NewClient(t)
-	cfg, err := buildTemporalRuntimeConfig(
+	rt, err := buildTemporalRuntime(
 		WithTemporalClient(tc, "myq"),
 		WithInstanceId("pod1"),
 		WithLogger(logger.NoopLogger()),
@@ -184,7 +184,7 @@ func TestBuildTemporalRuntimeConfig_InstanceIdSuffix(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	if cfg.taskQueue != "myq-pod1" {
-		t.Fatalf("taskQueue = %q, want myq-pod1", cfg.taskQueue)
+	if rt.taskQueue != "myq-pod1" {
+		t.Fatalf("taskQueue = %q, want myq-pod1", rt.taskQueue)
 	}
 }
diff --git a/internal/runtime/temporal/event_workflow_test.go b/internal/runtime/temporal/event_workflow_test.go
index 2b822f9..8557bca 100644
--- a/internal/runtime/temporal/event_workflow_test.go
+++ b/internal/runtime/temporal/event_workflow_test.go
@@ -19,8 +19,8 @@ func TestEventPublishActivity_PublishesToEventBus(t *testing.T) {
 	l := logger.NoopLogger()
 	bus := eventbus.NewInmem(l)
 	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{logger: l},
-		eventbus: bus,
+		logger: l,
+		eventbus: bus,
 	}
 	ctx := context.Background()
 	chName := "agent_event_unit_test"
@@ -65,8 +65,8 @@ func TestEventPublishActivity_PublishesToEventBus(t *testing.T) {
 func TestEventPublishActivity_NilEventErrors(t *testing.T) {
 	l := logger.NoopLogger()
 	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{logger: l},
-		eventbus: eventbus.NewInmem(l),
+		logger: l,
+		eventbus: eventbus.NewInmem(l),
 	}
 	actEnv := newActivityTestEnv(t)
 	actEnv.RegisterActivity(rt.EventPublishActivity)
@@ -80,8 +80,8 @@ func TestSubscribeToAgentEvents_RoundTrip(t *testing.T) {
 	l := logger.NoopLogger()
 	bus := eventbus.NewInmem(l)
 	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: TemporalRuntimeConfig{logger: l},
-		eventbus: bus,
+		logger: l,
+		eventbus: bus,
 	}
 	ctx := context.Background()
 	chName := "agent_event_sub_test"
diff --git a/internal/runtime/temporal/fingerprint.go b/internal/runtime/temporal/fingerprint.go
index d929c62..63b8eeb 100644
--- a/internal/runtime/temporal/fingerprint.go
+++ b/internal/runtime/temporal/fingerprint.go
@@ -180,20 +180,20 @@ func ToolNamesFromTools(tools []interfaces.Tool) []string {
 	return names
 }
 
-func computeAgentFingerprintFromRuntimeConfig(c *TemporalRuntimeConfig) string {
+func computeAgentFingerprintFromRuntime(rt *TemporalRuntime) string {
 	mat := BuildAgentFingerprintPayload(
-		c.AgentSpec,
-		ToolNamesFromTools(c.AgentExecution.Tools.Tools),
-		c.PolicyFingerprint,
-		c.AgentExecution.LLM.Sampling,
-		c.AgentExecution.Session.ConversationSize,
-		c.AgentExecution.Limits,
-		c.MCPFingerprint,
-		c.A2AFingerprint,
-		c.ObservabilityFingerprint,
-		c.AgentMode,
-		c.AgentToolExecutionMode,
-		c.RetrieverFingerprint,
+		rt.AgentSpec,
+		ToolNamesFromTools(rt.AgentExecution.Tools.Tools),
+		rt.policyFingerprint,
+		rt.AgentExecution.LLM.Sampling,
+		rt.AgentExecution.Session.ConversationSize,
+		rt.AgentExecution.Limits,
+		rt.mcpFingerprint,
+		rt.a2aFingerprint,
+		rt.observabilityFingerprint,
+		rt.agentMode,
+		rt.ToolExecutionMode,
+		rt.retrieverFingerprint,
 	)
 	return ComputeAgentFingerprint(mat)
 }
diff --git a/internal/runtime/temporal/fingerprint_test.go b/internal/runtime/temporal/fingerprint_test.go
index 9ec144a..3eddb7b 100644
--- a/internal/runtime/temporal/fingerprint_test.go
+++ b/internal/runtime/temporal/fingerprint_test.go
@@ -108,21 +108,24 @@ func TestComputeAgentFingerprint_observabilityFingerprintChangesDigest(t *testin
 	}
 }
 
+func newFingerprintRT(spec sdkruntime.AgentSpec, exec sdkruntime.AgentExecution, policyFP string, opts ...func(*TemporalRuntime)) *TemporalRuntime {
+	rt := &TemporalRuntime{}
+	rt.AgentSpec = spec
+	rt.AgentExecution = exec
+	rt.policyFingerprint = policyFP
+	for _, o := range opts {
+		o(rt)
+	}
+	rt.agentFingerprint = computeAgentFingerprintFromRuntime(rt)
+	return rt
+}
+
 func TestVerifyAgentFingerprint_mismatch(t *testing.T) {
-	cfg := &TemporalRuntimeConfig{
-		AgentSpec: sdkruntime.AgentSpec{Name: "x"},
-		AgentExecution: sdkruntime.AgentExecution{
-			LLM: sdkruntime.AgentLLM{},
-			Tools: sdkruntime.AgentTools{Tools: nil},
-			Session: sdkruntime.AgentSession{},
-			Limits: sdkruntime.AgentLimits{},
-		},
-		PolicyFingerprint: "require_all",
-	}
-	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: *cfg,
-		agentFingerprint: computeAgentFingerprintFromRuntimeConfig(cfg),
-	}
+	rt := newFingerprintRT(
+		sdkruntime.AgentSpec{Name: "x"},
+		sdkruntime.AgentExecution{},
+		"require_all",
+	)
 	err := rt.verifyAgentFingerprint("deadbeef")
 	if err == nil {
 		t.Fatal("expected mismatch error")
@@ -137,42 +140,26 @@ func TestVerifyAgentFingerprint_bothEmptyOK(t *testing.T) {
 }
 
 func TestVerifyAgentFingerprint_emptyWantWhenWorkerHasFingerprint(t *testing.T) {
-	cfg := &TemporalRuntimeConfig{
-		AgentSpec: sdkruntime.AgentSpec{Name: "x"},
-		AgentExecution: sdkruntime.AgentExecution{
-			LLM: sdkruntime.AgentLLM{},
-			Tools: sdkruntime.AgentTools{Tools: nil},
-			Session: sdkruntime.AgentSession{},
-			Limits: sdkruntime.AgentLimits{},
-		},
-		PolicyFingerprint: "require_all",
-	}
-	rt := &TemporalRuntime{
-		TemporalRuntimeConfig: *cfg,
-		agentFingerprint: computeAgentFingerprintFromRuntimeConfig(cfg),
-	}
+	rt := newFingerprintRT(
+		sdkruntime.AgentSpec{Name: "x"},
+		sdkruntime.AgentExecution{},
+		"require_all",
+	)
 	if err := 
rt.verifyAgentFingerprint(""); err == nil { t.Fatal("expected mismatch when caller fingerprint is empty but worker has one") } } func TestVerifyAgentFingerprint_disableCheckAllowsMismatch(t *testing.T) { - cfg := &TemporalRuntimeConfig{ - AgentSpec: sdkruntime.AgentSpec{Name: "x"}, - DisableFingerprintCheck: true, - PolicyFingerprint: "require_all", - AgentExecution: sdkruntime.AgentExecution{ - LLM: sdkruntime.AgentLLM{}, - Tools: sdkruntime.AgentTools{Tools: nil}, - Session: sdkruntime.AgentSession{}, - Limits: sdkruntime.AgentLimits{}, + rt := newFingerprintRT( + sdkruntime.AgentSpec{Name: "x"}, + sdkruntime.AgentExecution{}, + "require_all", + func(rt *TemporalRuntime) { + rt.disableFingerprintCheck = true + rt.ToolExecutionMode = "sequential" }, - AgentToolExecutionMode: "sequential", - } - rt := &TemporalRuntime{ - TemporalRuntimeConfig: *cfg, - agentFingerprint: computeAgentFingerprintFromRuntimeConfig(cfg), - } + ) if err := rt.verifyAgentFingerprint("definitely-different"); err != nil { t.Fatalf("expected bypass when skip is enabled, got: %v", err) } diff --git a/internal/runtime/temporal/logger.go b/internal/runtime/temporal/logger.go index e8c4f62..642cdc6 100644 --- a/internal/runtime/temporal/logger.go +++ b/internal/runtime/temporal/logger.go @@ -51,3 +51,28 @@ func (a logAdapter) Error(msg string, keyvals ...interface{}) { } var _ tlog.Logger = logAdapter{} + +// activityLogAdapter wraps a Temporal activity/workflow logger (tlog.Logger) so it satisfies +// pkg/logger.Logger. Activities obtain their logger via activity.GetLogger(ctx), which attaches +// workflow and activity metadata automatically; this bridge lets base-package methods reuse it. +type activityLogAdapter struct{ l tlog.Logger } + +// newActivityLogger returns a logger.Logger backed by the Temporal tlog.Logger l. 
+func newActivityLogger(l tlog.Logger) logger.Logger {
+	return activityLogAdapter{l: l}
+}
+
+func (a activityLogAdapter) Debug(ctx context.Context, msg string, args ...any) {
+	a.l.Debug(msg, args...)
+}
+func (a activityLogAdapter) Info(ctx context.Context, msg string, args ...any) {
+	a.l.Info(msg, args...)
+}
+func (a activityLogAdapter) Warn(ctx context.Context, msg string, args ...any) {
+	a.l.Warn(msg, args...)
+}
+func (a activityLogAdapter) Error(ctx context.Context, msg string, args ...any) {
+	a.l.Error(msg, args...)
+}
+
+var _ logger.Logger = activityLogAdapter{}
diff --git a/internal/runtime/temporal/options.go b/internal/runtime/temporal/options.go
new file mode 100644
index 0000000..78f1a31
--- /dev/null
+++ b/internal/runtime/temporal/options.go
@@ -0,0 +1,209 @@
+package temporal
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+
+	sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime"
+	"github.com/agenticenv/agent-sdk-go/internal/types"
+	"github.com/agenticenv/agent-sdk-go/pkg/interfaces"
+	"github.com/agenticenv/agent-sdk-go/pkg/logger"
+	"github.com/agenticenv/agent-sdk-go/pkg/observability"
+	"go.temporal.io/sdk/client"
+)
+
+// Option configures a [TemporalRuntime].
+type Option func(*TemporalRuntime)
+
+// WithTemporalConfig dials a new Temporal client from the supplied connection parameters.
+// The runtime owns the resulting client and closes it on [TemporalRuntime.Close].
+func WithTemporalConfig(config *TemporalConfig) Option {
+	return func(rt *TemporalRuntime) {
+		rt.temporalConfig = config
+		rt.taskQueue = config.TaskQueue
+		rt.ownsTemporalClient = true
+	}
+}
+
+// WithTemporalClient injects a caller-managed Temporal client and task queue.
+// The runtime does NOT close this client on [TemporalRuntime.Close].
+func WithTemporalClient(tc client.Client, taskQueue string) Option {
+	return func(rt *TemporalRuntime) {
+		rt.temporalClient = tc
+		rt.taskQueue = taskQueue
+		rt.ownsTemporalClient = false
+	}
+}
+
+// WithInstanceId appends a suffix to the task queue (e.g. "myq-pod1") so multiple
+// instances of the same agent can run on isolated queues.
+func WithInstanceId(instanceId string) Option {
+	return func(rt *TemporalRuntime) { rt.instanceId = instanceId }
+}
+
+// WithEnableRemoteWorkers starts the event worker and event workflow inside
+// Execute/ExecuteStream (client agent runtime path).
+func WithEnableRemoteWorkers(enable bool) Option {
+	return func(rt *TemporalRuntime) { rt.enableRemoteWorkers = enable }
+}
+
+// WithRemoteWorker marks the runtime as a remote worker (true for [NewAgentWorker],
+// false for client [Agent] runtimes).
+func WithRemoteWorker(remote bool) Option {
+	return func(rt *TemporalRuntime) { rt.remoteWorker = remote }
+}
+
+// WithLogger sets the [logger.Logger] used by the runtime. A nil value is silently
+// ignored so the safe [logger.NoopLogger] default is preserved.
+func WithLogger(l logger.Logger) Option {
+	return func(rt *TemporalRuntime) {
+		if l != nil {
+			rt.logger = l
+		}
+	}
+}
+
+// WithAgentSpec sets identity and response format
+// (same shape as [sdkruntime.ExecuteRequest.AgentSpec]).
+func WithAgentSpec(spec sdkruntime.AgentSpec) Option {
+	return func(rt *TemporalRuntime) { rt.AgentSpec = spec }
+}
+
+// WithAgentExecution sets LLM, tools, session, and limits
+// (same shape as [sdkruntime.ExecuteRequest.AgentExecution]).
+func WithAgentExecution(exec sdkruntime.AgentExecution) Option {
+	return func(rt *TemporalRuntime) { rt.AgentExecution = exec }
+}
+
+// WithPolicyFingerprint sets the opaque policy digest used with [ComputeAgentFingerprint].
+// Must match pkg/agent's toolPolicyFingerprint for the same agent options.
+func WithPolicyFingerprint(fp string) Option {
+	return func(rt *TemporalRuntime) { rt.policyFingerprint = fp }
+}
+
+// WithMCPFingerprint sets the MCP wiring digest used with [ComputeAgentFingerprint].
+// Must match pkg/agent's mcpConfigFingerprint for the same WithMCPConfig / WithMCPClients wiring.
+func WithMCPFingerprint(fp string) Option {
+	return func(rt *TemporalRuntime) { rt.mcpFingerprint = fp }
+}
+
+// WithA2AFingerprint sets the A2A wiring digest used with [ComputeAgentFingerprint].
+// Must match pkg/agent's a2aConfigFingerprint for the same WithA2AConfig / WithA2AClients wiring.
+func WithA2AFingerprint(fp string) Option {
+	return func(rt *TemporalRuntime) { rt.a2aFingerprint = fp }
+}
+
+// WithObservabilityFingerprint sets the OTLP observability digest used with [ComputeAgentFingerprint].
+// Must match pkg/agent observabilityConfigFingerprint for the same WithObservabilityConfig wiring.
+func WithObservabilityFingerprint(fp string) Option {
+	return func(rt *TemporalRuntime) { rt.observabilityFingerprint = fp }
+}
+
+// WithAgentMode sets the agent mode string (e.g. "interactive", "autonomous") used with
+// [ComputeAgentFingerprint]. Must match pkg/agent [WithAgentMode] on both caller and worker.
+func WithAgentMode(mode string) Option {
+	return func(rt *TemporalRuntime) { rt.agentMode = mode }
+}
+
+// WithAgentToolExecutionMode sets the tool execution mode. The value is stored in
+// the embedded [base.Runtime.ToolExecutionMode] so it drives both fingerprinting and
+// workflow-level tool execution.
+func WithAgentToolExecutionMode(mode types.AgentToolExecutionMode) Option {
+	return func(rt *TemporalRuntime) { rt.ToolExecutionMode = mode }
+}
+
+// WithRetrieverFingerprint sets the retriever wiring digest (mode + retriever names).
+// Must match pkg/agent [retrieverConfigFingerprint] for the same agent.
+func WithRetrieverFingerprint(fp string) Option {
+	return func(rt *TemporalRuntime) { rt.retrieverFingerprint = fp }
+}
+
+// WithDisableLocalWorker mirrors pkg/agent DisableLocalWorker. When false, the client
+// embeds a worker and the runtime skips DescribeTaskQueue poller checks before starting
+// workflows.
+func WithDisableLocalWorker(disable bool) Option {
+	return func(rt *TemporalRuntime) { rt.disableLocalWorker = disable }
+}
+
+// WithDisableFingerprintCheck disables activity-time caller-vs-worker fingerprint
+// verification. Break-glass only: use temporarily during rollout incidents; keep false
+// in production for safety.
+func WithDisableFingerprintCheck(disable bool) Option {
+	return func(rt *TemporalRuntime) { rt.disableFingerprintCheck = disable }
+}
+
+// WithTracer sets the optional [interfaces.Tracer]. When the runtime dials its own
+// Temporal client ([WithTemporalConfig]) and the tracer implements [interfaces.OTelTracer],
+// a Temporal OpenTelemetry client interceptor is attached automatically.
+func WithTracer(t interfaces.Tracer) Option {
+	return func(rt *TemporalRuntime) { rt.Tracer = t }
+}
+
+// WithMetrics sets the optional [interfaces.Metrics] for this runtime.
+func WithMetrics(m interfaces.Metrics) Option {
+	return func(rt *TemporalRuntime) { rt.Metrics = m }
+}
+
+// buildTemporalRuntime applies options onto a fresh [TemporalRuntime], validates required
+// fields, and dials the Temporal client when [WithTemporalConfig] is used. The returned
+// runtime is fully configured but does not yet have an agentFingerprint or eventbus —
+// those are set by [NewTemporalRuntime].
+func buildTemporalRuntime(opts ...Option) (*TemporalRuntime, error) {
+	rt := &TemporalRuntime{logger: logger.NoopLogger()}
+	for _, opt := range opts {
+		opt(rt)
+	}
+
+	if rt.temporalConfig == nil && rt.temporalClient == nil {
+		return nil, fmt.Errorf("temporal config or client is required")
+	}
+
+	if rt.temporalConfig != nil {
+		tc, err := newTemporalClient(rt.temporalConfig, rt.logger, rt.Tracer)
+		if err != nil {
+			return nil, err
+		}
+		rt.temporalClient = tc
+	} else { // user-provided Temporal client
+		if _, ok := rt.Tracer.(interfaces.OTelTracer); ok {
+			rt.logger.Warn(context.Background(),
+				"user provided Temporal client — add OTel interceptor manually for tracing",
+				slog.String("scope", "runtime"))
+		}
+	}
+
+	if rt.instanceId != "" {
+		rt.taskQueue = rt.taskQueue + "-" + rt.instanceId
+	}
+
+	if rt.AgentExecution.LLM.Client == nil {
+		return nil, fmt.Errorf("llm client is required")
+	}
+
+	if rt.Tracer == nil {
+		rt.Tracer = observability.DefaultNoopTracer
+	}
+	if rt.Metrics == nil {
+		rt.Metrics = observability.DefaultNoopMetrics
+	}
+
+	rt.logger.Debug(context.Background(), "runtime config resolved",
+		slog.String("scope", "runtime"),
+		slog.String("agentName", rt.AgentSpec.Name),
+		slog.String("taskQueue", rt.taskQueue),
+		slog.String("instanceId", rt.instanceId),
+		slog.Int("maxIterations", rt.AgentExecution.Limits.MaxIterations),
+		slog.Bool("remoteWorker", rt.remoteWorker),
+		slog.String("agentMode", rt.agentMode),
+		slog.String("toolExecutionMode", string(rt.ToolExecutionMode)),
+		slog.Bool("enableRemoteWorkers", rt.enableRemoteWorkers),
+		slog.Bool("disableFingerprintCheck", rt.disableFingerprintCheck),
+		slog.Duration("timeout", rt.AgentExecution.Limits.Timeout),
+		slog.Duration("approvalTimeout", rt.AgentExecution.Limits.ApprovalTimeout),
+		slog.Bool("hasConversation", rt.AgentExecution.Session.Conversation != nil),
+		slog.Bool("hasTracer", rt.Tracer != nil),
+		slog.Bool("hasMetrics", rt.Metrics != nil))
+
+	return rt, nil
+}
diff --git
a/internal/runtime/temporal/runtime.go b/internal/runtime/temporal/runtime.go index 8196fe2..a250cd3 100644 --- a/internal/runtime/temporal/runtime.go +++ b/internal/runtime/temporal/runtime.go @@ -14,7 +14,9 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/eventbus" "github.com/agenticenv/agent-sdk-go/internal/events" "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/runtime/base" "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/logger" "github.com/google/uuid" enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/sdk/activity" @@ -44,10 +46,44 @@ var ErrAgentAlreadyRunning = errors.New("agent already has an active run") // ErrAgentFingerprintMismatch is returned when workflow input fingerprint does not match the worker. var ErrAgentFingerprintMismatch = errors.New("temporal: agent fingerprint mismatch (caller vs worker); redeploy worker or align agent config") +// TemporalRuntime implements [runtime.WorkerRuntime] and [runtime.EventBusRuntime] using +// Temporal workflows and activities as the execution backend. +// It embeds [base.Runtime] for the common agent fields (AgentSpec, AgentExecution, Tracer, Metrics, +// ToolExecutionMode) and holds all Temporal-specific connection and fingerprint state as flat fields. type TemporalRuntime struct { - TemporalRuntimeConfig - - // agentFingerprint is ComputeAgentFingerprint(BuildAgentFingerprintPayload(...)) at NewTemporalRuntime; immutable for this runtime. + base.Runtime // AgentSpec, AgentExecution, Tracer, Metrics, ToolExecutionMode + + // Temporal connection + temporalConfig *TemporalConfig + temporalClient client.Client + taskQueue string + instanceId string + ownsTemporalClient bool + // enableRemoteWorkers: start event worker + event workflow in Execute/ExecuteStream (client agent runtime). + enableRemoteWorkers bool + // remoteWorker: true for NewAgentWorker (polls activities); false for client Agent runtime. 
+ remoteWorker bool + + logger logger.Logger + + // Fingerprint inputs captured at construction and consumed by computeAgentFingerprintFromRuntime. + policyFingerprint string + mcpFingerprint string + a2aFingerprint string + observabilityFingerprint string + // agentMode is the string form of [types.AgentMode] (e.g. "interactive", "autonomous"). + agentMode string + retrieverFingerprint string + + // Temporal-specific flags + // disableLocalWorker mirrors pkg/agent DisableLocalWorker: when false, the client embeds a worker + // so Execute/ExecuteStream skip DescribeTaskQueue poller checks. + disableLocalWorker bool + // disableFingerprintCheck disables activity-time caller-vs-worker fingerprint verification. + // Break-glass only: keep false in production for rollout/config safety. + disableFingerprintCheck bool + + // agentFingerprint is ComputeAgentFingerprint(BuildAgentFingerprintPayload(...)) at NewTemporalRuntime; immutable. agentFingerprint string eventbus eventbus.EventBus @@ -67,30 +103,30 @@ type TemporalRuntime struct { } func NewTemporalRuntime(opts ...Option) (*TemporalRuntime, error) { - cfg, err := buildTemporalRuntimeConfig(opts...) + rt, err := buildTemporalRuntime(opts...) 
if err != nil { return nil, err } - cfg.logger.Info(context.Background(), "runtime created", slog.String("scope", "runtime"), slog.String("name", cfg.AgentSpec.Name), slog.String("taskQueue", cfg.taskQueue)) - if cfg.DisableFingerprintCheck { - cfg.logger.Warn(context.Background(), + rt.logger.Info(context.Background(), "runtime created", + slog.String("scope", "runtime"), + slog.String("name", rt.AgentSpec.Name), + slog.String("taskQueue", rt.taskQueue)) + if rt.disableFingerprintCheck { + rt.logger.Warn(context.Background(), "fingerprint verification is disabled (break-glass mode)", slog.String("scope", "runtime"), - slog.String("name", cfg.AgentSpec.Name), - slog.String("taskQueue", cfg.taskQueue)) - } - fp := computeAgentFingerprintFromRuntimeConfig(cfg) - return &TemporalRuntime{ - TemporalRuntimeConfig: *cfg, - agentFingerprint: fp, - eventbus: eventbus.NewInmem(cfg.logger), - }, nil + slog.String("name", rt.AgentSpec.Name), + slog.String("taskQueue", rt.taskQueue)) + } + rt.agentFingerprint = computeAgentFingerprintFromRuntime(rt) + rt.eventbus = eventbus.NewInmem(rt.logger) + return rt, nil } // verifyAgentFingerprint returns an error when want does not equal the runtime's agent fingerprint // (computed at [NewTemporalRuntime]). 
func (rt *TemporalRuntime) verifyAgentFingerprint(want string) error { - if rt.DisableFingerprintCheck { + if rt.disableFingerprintCheck { return nil } if rt.agentFingerprint != want { @@ -270,7 +306,7 @@ func (rt *TemporalRuntime) Execute(ctx context.Context, req *runtime.ExecuteRequ AgentFingerprint: rt.agentFingerprint, EventTypes: []events.AgentEventType{}, SubAgentDepth: 0, - SubAgentRoutes: req.SubAgentRoutes, + SubAgentRoutes: buildSubAgentRoutes(req.SubAgents), MaxSubAgentDepth: req.MaxSubAgentDepth, } @@ -452,7 +488,7 @@ func (rt *TemporalRuntime) ExecuteStream(ctx context.Context, req *runtime.Execu AgentFingerprint: rt.agentFingerprint, EventTypes: streamEventTypes, SubAgentDepth: 0, - SubAgentRoutes: req.SubAgentRoutes, + SubAgentRoutes: buildSubAgentRoutes(req.SubAgents), MaxSubAgentDepth: req.MaxSubAgentDepth, } @@ -617,20 +653,20 @@ func (rt *TemporalRuntime) resolveEventPipeline(ctx context.Context, agentName s // before starting the workflow. Only these paths call Execute/ExecuteStream (client [Agent]; remoteWorker is always false). // Skip when mode is autonomous, or when an embedded worker polls in-process ([DisableLocalWorker] false). 
func (rt *TemporalRuntime) skipHasWorkersPrecheck() bool { - if rt.AgentMode == string(types.AgentModeAutonomous) { + if rt.agentMode == string(types.AgentModeAutonomous) { return true } - if !rt.DisableLocalWorker { + if !rt.disableLocalWorker { return true } return false } func (rt *TemporalRuntime) hasWorkersPrecheckSkipReason() string { - if rt.AgentMode == string(types.AgentModeAutonomous) { + if rt.agentMode == string(types.AgentModeAutonomous) { return "autonomous_mode" } - if !rt.DisableLocalWorker { + if !rt.disableLocalWorker { return "embedded_local_worker" } return "" diff --git a/internal/runtime/temporal/runtime_test.go b/internal/runtime/temporal/runtime_test.go index 8b0e81c..0567188 100644 --- a/internal/runtime/temporal/runtime_test.go +++ b/internal/runtime/temporal/runtime_test.go @@ -12,8 +12,8 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/eventbus" "github.com/agenticenv/agent-sdk-go/internal/events" sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/runtime/base" "github.com/agenticenv/agent-sdk-go/internal/types" - "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/logger" "github.com/nexus-rpc/sdk-go/nexus" "github.com/stretchr/testify/mock" @@ -110,10 +110,8 @@ func TestSyntheticStreamCompleteEvent(t *testing.T) { func TestResolveEventPipeline(t *testing.T) { l := logger.NoopLogger() rt := &TemporalRuntime{ - TemporalRuntimeConfig: TemporalRuntimeConfig{ - logger: l, - taskQueue: "tq", - }, + logger: l, + taskQueue: "tq", } ewf, etq, err := rt.resolveEventPipeline(context.Background(), "My Agent") if err != nil { @@ -137,21 +135,21 @@ func TestResolveEventPipeline(t *testing.T) { } } -func TestSubAgentQueryFromArgs(t *testing.T) { - if subAgentQueryFromArgs(nil) != "" { +func TestSubAgentQuery(t *testing.T) { + if base.SubAgentQuery(nil) != "" { t.Error("nil args") } - if subAgentQueryFromArgs(map[string]any{}) != "" { + if 
base.SubAgentQuery(map[string]any{}) != "" { t.Error("empty map") } - if got := subAgentQueryFromArgs(map[string]any{"query": "hello"}); got != "hello" { + if got := base.SubAgentQuery(map[string]any{"query": "hello"}); got != "hello" { t.Errorf("got %q", got) } } func TestAgent_BeginRunEndRun(t *testing.T) { l := logger.DefaultLogger("error") - a := &TemporalRuntime{TemporalRuntimeConfig: TemporalRuntimeConfig{logger: l}} + a := &TemporalRuntime{logger: l} cleanup, err := a.beginRun("wf1") if err != nil { @@ -184,58 +182,6 @@ func TestRetryPolicy(t *testing.T) { } } -func TestApplyLLMSampling(t *testing.T) { - req := &interfaces.LLMRequest{} - applyLLMSampling(nil, req) - if req.Temperature != nil || req.MaxTokens != 0 { - t.Error("nil sampling should not modify request") - } - temp := 0.3 - topP := 0.9 - topK := 5 - applyLLMSampling(&types.LLMSampling{ - Temperature: &temp, - MaxTokens: 42, - TopP: &topP, - TopK: &topK, - }, req) - if req.Temperature == nil || *req.Temperature != 0.3 { - t.Errorf("Temperature = %v", req.Temperature) - } - if req.MaxTokens != 42 { - t.Errorf("MaxTokens = %d", req.MaxTokens) - } - if req.TopP == nil || *req.TopP != 0.9 { - t.Errorf("TopP = %v", req.TopP) - } - if req.TopK == nil || *req.TopK != 5 { - t.Errorf("TopK = %v", req.TopK) - } -} - -func TestApplyLLMSampling_reasoning(t *testing.T) { - req := &interfaces.LLMRequest{} - applyLLMSampling(&types.LLMSampling{ - Reasoning: &interfaces.LLMReasoning{ - Enabled: true, - Effort: "medium", - BudgetTokens: 2048, - }, - }, req) - if req.Reasoning == nil { - t.Fatal("expected Reasoning") - } - if req.Reasoning.Effort != "medium" { - t.Errorf("Effort = %q", req.Reasoning.Effort) - } - if req.Reasoning.BudgetTokens != 2048 { - t.Errorf("BudgetTokens = %d", req.Reasoning.BudgetTokens) - } - if !req.Reasoning.Enabled { - t.Error("expected Enabled") - } -} - func TestKeyvalsToAny(t *testing.T) { kv := []interface{}{"k", 1} out := keyvalsToAny(kv) @@ -246,7 +192,7 @@ func TestKeyvalsToAny(t 
*testing.T) { func TestTemporalRuntime_SetEventBus_GetEventBus(t *testing.T) { l := logger.NoopLogger() - rt := &TemporalRuntime{TemporalRuntimeConfig: TemporalRuntimeConfig{logger: l}} + rt := &TemporalRuntime{logger: l} if rt.GetEventBus() != nil { t.Fatal("zero-value runtime should have nil event bus until set") } diff --git a/internal/runtime/temporal/subagent.go b/internal/runtime/temporal/subagent.go new file mode 100644 index 0000000..48e97bf --- /dev/null +++ b/internal/runtime/temporal/subagent.go @@ -0,0 +1,38 @@ +package temporal + +import sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + +// SubAgentRoute tells the Temporal runtime how to delegate to a sub-agent child workflow. +// It is serialised into AgentWorkflowInput and propagated to child workflows unchanged. +type SubAgentRoute struct { + Name string `json:"name"` + TaskQueue string `json:"task_queue,omitempty"` + ChildRoutes map[string]SubAgentRoute `json:"child_routes,omitempty"` + AgentFingerprint string `json:"agent_fingerprint,omitempty"` +} + +// buildSubAgentRoutes converts the runtime-agnostic SubAgentSpec tree (from ExecuteRequest) +// into a Temporal-specific SubAgentRoute map. Each spec's Runtime is type-asserted to +// *TemporalRuntime to extract the task queue and agent fingerprint. 
+func buildSubAgentRoutes(specs []*sdkruntime.SubAgentSpec) map[string]SubAgentRoute {
+	if len(specs) == 0 {
+		return nil
+	}
+	out := make(map[string]SubAgentRoute, len(specs))
+	for _, spec := range specs {
+		if spec == nil {
+			continue
+		}
+		route := SubAgentRoute{Name: spec.Name}
+		if tr, ok := spec.Runtime.(*TemporalRuntime); ok {
+			route.TaskQueue = tr.taskQueue
+			route.AgentFingerprint = tr.agentFingerprint
+		}
+		route.ChildRoutes = buildSubAgentRoutes(spec.Children)
+		out[spec.ToolName] = route
+	}
+	if len(out) == 0 {
+		return nil
+	}
+	return out
+}
diff --git a/internal/types/subagent.go b/internal/types/subagent.go
deleted file mode 100644
index cbc2909..0000000
--- a/internal/types/subagent.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package types
-
-// SubAgentToolParamQuery is the tool/JSON parameter name for the query sent to a sub-agent.
-const SubAgentToolParamQuery = "query"
-
-// SubAgentRoute tells the runtime how to delegate to a sub-agent (child run on TaskQueue),
-// with nested routes for that sub-agent's sub-agents (frozen at parent run start).
-// AgentFingerprint is the agent config digest for that sub-agent (pkg/agent + temporal.ComputeAgentFingerprint)
-// so the child worker can reject runs when its deployed config does not match the caller.
-type SubAgentRoute struct { - Name string `json:"name"` - TaskQueue string `json:"task_queue"` - ChildRoutes map[string]SubAgentRoute `json:"child_routes,omitempty"` - AgentFingerprint string `json:"agent_fingerprint,omitempty"` -} diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 24a6dee..76d5c02 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "sort" - "strings" "time" "log/slog" @@ -46,7 +45,10 @@ func buildAgent(opts []Option) (*Agent, error) { agentConfig: *cfg, } - if a.disableLocalWorker && a.streamEnabled && !a.enableRemoteWorkers { + // This guard is Temporal-specific: streaming on Temporal requires a local worker unless + // remote workers are enabled. LocalRuntime streams in-process via ExecuteStream and needs + // no background worker poll loop, so we skip the guard for the local backend. + if cfg.hasTemporalRuntime() && a.disableLocalWorker && a.streamEnabled && !a.enableRemoteWorkers { return nil, fmt.Errorf("DisableLocalWorker with streaming requires EnableRemoteWorkers()") } @@ -56,7 +58,10 @@ func buildAgent(opts []Option) (*Agent, error) { } a.runtime = rt - if !a.disableLocalWorker { + // Worker poll loop is only needed for backends that implement WorkerRuntime (e.g. Temporal). + // LocalRuntime executes in-process via Execute/ExecuteStream; creating a worker for it would + // log a spurious error because LocalRuntime does not implement WorkerRuntime. 
+ if !a.disableLocalWorker && cfg.hasTemporalRuntime() { a.localAgentWorker = &AgentWorker{agentConfig: *cfg, runtime: rt} } @@ -297,7 +302,7 @@ func (a *Agent) executeRequest(userPrompt, conversationID string, streaming bool UserPrompt: userPrompt, ConversationID: conversationID, StreamingEnabled: streaming, - SubAgentRoutes: a.buildSubAgentRoutes(), + SubAgents: a.buildSubAgentSpecs(), MaxSubAgentDepth: a.maxSubAgentDepth, ApprovalHandler: a.approvalHandler, AgentSpec: a.agentSpec(), @@ -315,44 +320,42 @@ func (a *Agent) agentExecution() *runtime.AgentExecution { return &e } -// buildSubAgentRoutes snapshots sub-agent tool names, task queues, and nested routes for runtime delegation (internal). -func (a *Agent) buildSubAgentRoutes() map[string]types.SubAgentRoute { +// buildSubAgentSpecs builds the runtime-agnostic sub-agent spec tree for this agent. +// Each runtime receives this tree via ExecuteRequest.SubAgents and constructs its own +// internal routing structures (local: *LocalRuntime refs; temporal: task queue + fingerprint). 
+func (a *Agent) buildSubAgentSpecs() []*runtime.SubAgentSpec { if a == nil || len(a.subAgents) == 0 { return nil } - out := make(map[string]types.SubAgentRoute, len(a.subAgents)) + out := make([]*runtime.SubAgentSpec, 0, len(a.subAgents)) for _, sub := range a.subAgents { if sub == nil { continue } - tq := strings.TrimSpace(sub.taskQueue) - if tq == "" { + toolName, err := subAgentToolName(sub.Name) + if err != nil || toolName == "" { continue } - name, err := subAgentToolName(sub.Name) - if err != nil || name == "" { - continue - } - out[name] = types.SubAgentRoute{ - Name: sub.Name, - TaskQueue: tq, - ChildRoutes: sub.buildSubAgentRoutes(), - AgentFingerprint: sub.agentConfigFingerprint(), - } + out = append(out, &runtime.SubAgentSpec{ + Name: sub.Name, + ToolName: toolName, + Runtime: sub.runtime, + Children: sub.buildSubAgentSpecs(), + }) } if len(out) == 0 { return nil } if a.logger != nil { names := make([]string, 0, len(out)) - for k := range out { - names = append(names, k) + for _, s := range out { + names = append(names, s.ToolName) } sort.Strings(names) - a.logger.Debug(context.Background(), "built sub-agent routes for runtime delegation", + a.logger.Debug(context.Background(), "built sub-agent specs for runtime delegation", slog.String("scope", "agent"), slog.Any("subAgentToolNames", names), - slog.Int("routeCount", len(out))) + slog.Int("specCount", len(out))) } return out } diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index a25f286..0a7cc59 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -235,60 +235,87 @@ func (m *mockConversation) ListMessages(ctx context.Context, id string, opts ... 
func (m *mockConversation) Clear(ctx context.Context, id string) error { return nil } func (m *mockConversation) IsDistributed() bool { return false } -func TestBuildWorkflowSubAgentRoutes_flat(t *testing.T) { - child := &Agent{agentConfig: agentConfig{Name: "Child", taskQueue: "q-child"}} - parent := &Agent{agentConfig: agentConfig{Name: "Parent", taskQueue: "q-parent", subAgents: []*Agent{child}}} - got := parent.buildSubAgentRoutes() - if got == nil { - t.Fatal("expected routes") +// stubRuntime is a minimal Runtime implementation for tests. +type stubRuntime struct{} + +func (s *stubRuntime) Execute(_ context.Context, _ *runtime.ExecuteRequest) (*types.AgentRunResult, error) { + return nil, nil +} +func (s *stubRuntime) ExecuteStream(_ context.Context, _ *runtime.ExecuteRequest) (<-chan events.AgentEvent, error) { + return nil, nil +} +func (s *stubRuntime) Approve(_ context.Context, _ string, _ types.ApprovalStatus) error { return nil } +func (s *stubRuntime) Close() {} + +func TestBuildSubAgentSpecs_flat(t *testing.T) { + childRT := &stubRuntime{} + child := &Agent{agentConfig: agentConfig{Name: "Child"}, runtime: childRT} + parent := &Agent{agentConfig: agentConfig{Name: "Parent", subAgents: []*Agent{child}}, runtime: &stubRuntime{}} + + got := parent.buildSubAgentSpecs() + if len(got) != 1 { + t.Fatalf("want 1 spec, got %d", len(got)) } key, err := subAgentToolName(child.Name) if err != nil { t.Fatal(err) } - r, ok := got[key] - if !ok { - t.Fatalf("missing %q in %v", key, got) + spec := got[0] + if spec.ToolName != key { + t.Fatalf("ToolName = %q, want %q", spec.ToolName, key) + } + if spec.Name != child.Name { + t.Fatalf("Name = %q, want %q", spec.Name, child.Name) + } + if spec.Runtime != childRT { + t.Fatal("Runtime mismatch") } - if r.TaskQueue != "q-child" || r.ChildRoutes != nil { - t.Fatalf("route = %+v", r) + if spec.Children != nil { + t.Fatalf("expected no children, got %v", spec.Children) } } -func TestBuildWorkflowSubAgentRoutes_nested(t 
*testing.T) { - leaf := &Agent{agentConfig: agentConfig{Name: "Leaf", taskQueue: "q-leaf"}} - mid := &Agent{agentConfig: agentConfig{Name: "Mid", taskQueue: "q-mid", subAgents: []*Agent{leaf}}} - root := &Agent{agentConfig: agentConfig{Name: "Root", taskQueue: "q-root", subAgents: []*Agent{mid}}} - got := root.buildSubAgentRoutes() - midKey, err := subAgentToolName(mid.Name) - if err != nil { - t.Fatal(err) +func TestBuildSubAgentSpecs_nested(t *testing.T) { + leafRT := &stubRuntime{} + leaf := &Agent{agentConfig: agentConfig{Name: "Leaf"}, runtime: leafRT} + midRT := &stubRuntime{} + mid := &Agent{agentConfig: agentConfig{Name: "Mid", subAgents: []*Agent{leaf}}, runtime: midRT} + root := &Agent{agentConfig: agentConfig{Name: "Root", subAgents: []*Agent{mid}}, runtime: &stubRuntime{}} + + got := root.buildSubAgentSpecs() + if len(got) != 1 { + t.Fatalf("want 1 top-level spec, got %d", len(got)) } - rMid, ok := got[midKey] - if !ok { - t.Fatalf("missing mid %q", midKey) + midSpec := got[0] + if midSpec.Runtime != midRT { + t.Fatal("mid Runtime mismatch") } - if rMid.ChildRoutes == nil { - t.Fatal("expected nested child routes") + if len(midSpec.Children) != 1 { + t.Fatalf("want 1 child spec, got %d", len(midSpec.Children)) } - leafKey, err := subAgentToolName(leaf.Name) - if err != nil { - t.Fatal(err) + leafSpec := midSpec.Children[0] + if leafSpec.Runtime != leafRT { + t.Fatal("leaf Runtime mismatch") } - rLeaf, ok := rMid.ChildRoutes[leafKey] - if !ok { - t.Fatalf("missing leaf %q", leafKey) - } - if rLeaf.TaskQueue != "q-leaf" || len(rLeaf.ChildRoutes) != 0 { - t.Fatalf("leaf route = %+v", rLeaf) + if len(leafSpec.Children) != 0 { + t.Fatalf("leaf should have no children, got %d", len(leafSpec.Children)) } } -func TestBuildWorkflowSubAgentRoutes_skipsEmptyTaskQueue(t *testing.T) { - skip := &Agent{agentConfig: agentConfig{Name: "X", taskQueue: " "}} - parent := &Agent{agentConfig: agentConfig{subAgents: []*Agent{skip}}} - if got := parent.buildSubAgentRoutes(); 
len(got) != 0 { - t.Fatalf("want empty, got %v", got) +func TestBuildSubAgentSpecs_noRuntimeStillBuilds(t *testing.T) { + // Sub-agent with no runtime still gets a spec — runtime decides what to do with it. + sub := &Agent{agentConfig: agentConfig{Name: "X"}} + parent := &Agent{agentConfig: agentConfig{subAgents: []*Agent{sub}}} + + got := parent.buildSubAgentSpecs() + if len(got) != 1 { + t.Fatalf("want 1 spec, got %v", got) + } + if got[0].ToolName != "subagent_X" { + t.Fatalf("ToolName = %q", got[0].ToolName) + } + if got[0].Runtime != nil { + t.Fatalf("expected nil runtime, got %v", got[0].Runtime) } } diff --git a/pkg/agent/config.go b/pkg/agent/config.go index 551021d..053ce60 100644 --- a/pkg/agent/config.go +++ b/pkg/agent/config.go @@ -631,7 +631,7 @@ func otlpLogsClientConfigured(logs interfaces.Logs) bool { } // buildAgentConfig applies options, validates, and sets defaults (logger, timeouts, iterations). -// WithTemporalConfig lets the runtime create a Temporal client from host settings; WithTemporalClient supplies a caller-owned client. +// When neither WithTemporalConfig nor WithTemporalClient is set, the local in-process runtime is used. // remoteWorker is false for Agent; NewAgentWorker sets it to true for worker-side activities. func buildAgentConfig(opts []Option) (*agentConfig, error) { c := &agentConfig{remoteWorker: false, ID: uuid.New().String()} @@ -666,17 +666,13 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { if c.toolApprovalPolicy == nil { c.toolApprovalPolicy = RequireAllToolApprovalPolicy{} } - // Either TemporalConfig or TemporalClient is required, not both. + // Temporal-specific validation: only enforced when the caller explicitly opts in to the + // Temporal backend. When neither is set the local runtime is used as the default backend. 
if c.temporalConfig != nil && c.temporalClient != nil { return nil, errors.New("provide either WithTemporalConfig or WithTemporalClient, not both") } - if c.temporalConfig == nil && c.temporalClient == nil { - return nil, errors.New("temporal connection is required: use WithTemporalConfig or WithTemporalClient") - } - if c.temporalConfig != nil { - if c.temporalConfig.TaskQueue == "" { - return nil, errors.New("TaskQueue is required in TemporalConfig: provide a unique name per agent") - } + if c.temporalConfig != nil && c.temporalConfig.TaskQueue == "" { + return nil, errors.New("TaskQueue is required in TemporalConfig: provide a unique name per agent") } if c.temporalClient != nil && c.taskQueue == "" { return nil, errors.New("taskQueue is required when using WithTemporalClient") @@ -816,20 +812,26 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { } } + runtimeName := "local" + if c.hasTemporalRuntime() { + runtimeName = "temporal" + } + ctx := context.Background() - c.logger.Info(ctx, "agent config built", slog.String("scope", "agent"), slog.String("name", c.Name), slog.String("taskQueue", c.taskQueue)) - // Debug: full config summary for troubleshooting (no sensitive: systemPrompt, API keys) - c.logger.Info(ctx, "agent config detail", + c.logger.Info(ctx, "agent config built", + slog.String("scope", "agent"), + slog.String("name", c.Name), + slog.String("runtime", runtimeName), + ) + + // Full config summary for troubleshooting (no sensitive values: systemPrompt, API keys). + // Fields are split by relevance: common fields always logged; Temporal-specific fields only + // when the Temporal backend is selected. 
+ commonAttrs := []any{ slog.String("scope", "agent"), slog.String("name", c.Name), - slog.String("taskQueue", c.taskQueue), - slog.String("instanceId", c.instanceId), + slog.String("runtime", runtimeName), slog.Int("maxIterations", c.maxIterations), - slog.Bool("streamEnabled", c.streamEnabled), - slog.Bool("disableLocalWorker", c.disableLocalWorker), - slog.Bool("enableRemoteWorkers", c.enableRemoteWorkers), - slog.Bool("remoteWorker", c.remoteWorker), - slog.Bool("disableFingerprintCheck", c.disableFingerprintCheck), slog.String("agentMode", string(c.agentMode)), slog.String("agentToolExecutionMode", string(c.agentToolExecutionMode)), slog.Bool("hasApprovalHandler", c.approvalHandler != nil), @@ -849,7 +851,20 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { slog.Bool("enabledMetrics", c.metrics != nil), slog.Bool("otlpSdkLogsExporter", otlpLogsClientConfigured(c.logs)), slog.Bool("otelLoggerWired", otelLoggerWired), - ) + } + if c.hasTemporalRuntime() { + c.logger.Info(ctx, "agent config detail", append(commonAttrs, + slog.String("taskQueue", c.taskQueue), + slog.String("instanceId", c.instanceId), + slog.Bool("streamEnabled", c.streamEnabled), + slog.Bool("disableLocalWorker", c.disableLocalWorker), + slog.Bool("enableRemoteWorkers", c.enableRemoteWorkers), + slog.Bool("remoteWorker", c.remoteWorker), + slog.Bool("disableFingerprintCheck", c.disableFingerprintCheck), + )...) + } else { + c.logger.Info(ctx, "agent config detail", commonAttrs...) + } if c.tracer == nil { c.tracer = observability.DefaultNoopTracer @@ -864,6 +879,16 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { return c, nil } +// buildAgentRuntime constructs the execution backend from agentConfig. +// Defaults to the local in-process runtime when no Temporal backend is configured. +// Extend with additional branches when new [runtime.Runtime] implementations are added. 
+func (cfg *agentConfig) buildAgentRuntime(remoteWorker bool) (runtime.Runtime, error) {
+	if cfg.hasTemporalRuntime() {
+		return cfg.buildTemporalRuntime(remoteWorker)
+	}
+	return cfg.buildLocalRuntime()
+}
+
 // toolsList returns WithTools or registry tools, merged MCP tools ([mcpTools]), A2A tools ([a2aTools]),
 // retriever tools ([retrieverTools]), then [subAgentTools] from [buildSubAgentTools].
 func (c *agentConfig) toolsList() []interfaces.Tool {
@@ -1147,39 +1172,6 @@ func observabilityOptions(c *agentConfig) []observability.Option {
 	return opts
 }
 
-// agentConfigFingerprint hashes identity, prompts, tools, sampling, limits, approval policy,
-// MCP wiring digest (transports, timeouts, filters, extra MCP client names), A2A wiring digest
-// for outbound clients only ([WithA2AConfig] / [WithA2AClients] via [a2aConfigFingerprint]),
-// observability OTLP wiring ([observabilityConfigFingerprint] from [WithObservabilityConfig]),
-// [WithRetrieverMode], and retriever names ([retrieverConfigFingerprint]).
-// Same inputs as temporal.NewTemporalRuntime agent fingerprint.
-//
-// Inbound [A2AServerConfig] from [WithA2AServer] / [WithA2ADefaultServer] (listen address,
-// [A2AServerConfig.BearerTokens], [A2AServerConfig.AgentCard] overrides for RunA2A) is omitted:
-// it does not change worker-side tool wiring or activity semantics and is typically deployment-specific.
-// Injected [WithTracer] / [WithMetrics] alone (without [WithObservabilityConfig]) are not hashed.
-func (c *agentConfig) agentConfigFingerprint() string { - mat := temporal.BuildAgentFingerprintPayload( - c.runtimeAgentSpec(), - temporal.ToolNamesFromTools(c.toolsList()), - toolPolicyFingerprint(c.toolApprovalPolicy), - llmSamplingRuntimeView(c.llmSampling), - c.conversationSize, - runtime.AgentLimits{ - MaxIterations: c.maxIterations, - Timeout: c.timeout, - ApprovalTimeout: c.approvalTimeout, - }, - mcpConfigFingerprint(c.mcpServers, mcpExtraClientNames(c.mcpClients)), - a2aConfigFingerprint(c.a2aServers, a2aExtraClientNames(c.a2aClients)), - observabilityConfigFingerprint(c.observabilityConfig), - string(c.agentMode), - c.agentToolExecutionMode, - retrieverConfigFingerprint(c.retrieverMode, c.retrievers), - ) - return temporal.ComputeAgentFingerprint(mat) -} - func llmSamplingRuntimeView(s *LLMSampling) *runtime.LLMSampling { if s == nil { return nil diff --git a/pkg/agent/config_test.go b/pkg/agent/config_test.go index 69de6e9..26374cc 100644 --- a/pkg/agent/config_test.go +++ b/pkg/agent/config_test.go @@ -11,6 +11,8 @@ import ( "github.com/a2aproject/a2a-go/v2/a2a" "github.com/a2aproject/a2a-go/v2/a2asrv" + "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/runtime/temporal" "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/logger" @@ -19,13 +21,42 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) -func TestBuildAgentConfig_NeitherTemporalConfigNorClient(t *testing.T) { - _, err := buildAgentConfig([]Option{ +// agentConfigFingerprint is a test helper that mirrors the fingerprint computed by the temporal +// runtime for a given agent config. Lives here (not in production code) since it is only used +// to assert fingerprint stability in tests. 
+func agentConfigFingerprint(c *agentConfig) string { + mat := temporal.BuildAgentFingerprintPayload( + c.runtimeAgentSpec(), + temporal.ToolNamesFromTools(c.toolsList()), + toolPolicyFingerprint(c.toolApprovalPolicy), + llmSamplingRuntimeView(c.llmSampling), + c.conversationSize, + runtime.AgentLimits{ + MaxIterations: c.maxIterations, + Timeout: c.timeout, + ApprovalTimeout: c.approvalTimeout, + }, + mcpConfigFingerprint(c.mcpServers, mcpExtraClientNames(c.mcpClients)), + a2aConfigFingerprint(c.a2aServers, a2aExtraClientNames(c.a2aClients)), + observabilityConfigFingerprint(c.observabilityConfig), + string(c.agentMode), + c.agentToolExecutionMode, + retrieverConfigFingerprint(c.retrieverMode, c.retrievers), + ) + return temporal.ComputeAgentFingerprint(mat) +} + +func TestBuildAgentConfig_NeitherTemporalConfigNorClient_UsesLocalRuntime(t *testing.T) { + // No Temporal config is valid — the local runtime is the default backend. + cfg, err := buildAgentConfig([]Option{ WithName("test"), WithLLMClient(stubLLM{}), }) - if err == nil || !strings.Contains(err.Error(), "temporal connection is required") { - t.Fatalf("got %v", err) + if err != nil { + t.Fatalf("expected success with local backend, got: %v", err) + } + if cfg.hasTemporalRuntime() { + t.Fatal("expected local backend (hasTemporalRuntime should be false)") } } @@ -812,7 +843,7 @@ func TestAgentConfigFingerprint_RetrieverModeChangesDigest(t *testing.T) { if err != nil { t.Fatal(err) } - return cfg.agentConfigFingerprint() + return agentConfigFingerprint(cfg) } fpAgentic := build(RetrieverModeAgentic) fpPrefetch := build(RetrieverModePrefetch) @@ -864,7 +895,7 @@ func TestAgentConfigFingerprint_AgenticRetrieverNamesChangesDigest(t *testing.T) if err != nil { t.Fatal(err) } - if cfgNoR.agentConfigFingerprint() == cfgWithR.agentConfigFingerprint() { + if agentConfigFingerprint(cfgNoR) == agentConfigFingerprint(cfgWithR) { t.Fatal("expected different fingerprints for agentic mode with vs without retriever names") 
} } @@ -1107,9 +1138,9 @@ func TestAgentConfigFingerprint_InboundA2AServerIgnored(t *testing.T) { if err != nil { t.Fatal(err) } - if cfgNoInbound.agentConfigFingerprint() != cfgWithInbound.agentConfigFingerprint() { + if agentConfigFingerprint(cfgNoInbound) != agentConfigFingerprint(cfgWithInbound) { t.Fatalf("inbound A2AServerConfig should not change agent fingerprint: %q vs %q", - cfgNoInbound.agentConfigFingerprint(), cfgWithInbound.agentConfigFingerprint()) + agentConfigFingerprint(cfgNoInbound), agentConfigFingerprint(cfgWithInbound)) } } diff --git a/pkg/agent/retriever_test.go b/pkg/agent/retriever_test.go index 7d3428c..c89b265 100644 --- a/pkg/agent/retriever_test.go +++ b/pkg/agent/retriever_test.go @@ -317,7 +317,7 @@ func TestAgentConfigFingerprint_RetrieverNamesChangesDigest(t *testing.T) { if err != nil { t.Fatal(err) } - if cfgNoR.agentConfigFingerprint() == cfgWithR.agentConfigFingerprint() { + if agentConfigFingerprint(cfgNoR) == agentConfigFingerprint(cfgWithR) { t.Fatal("expected different fingerprints when retriever names are registered") } } diff --git a/pkg/agent/runtime_factory.go b/pkg/agent/runtime_factory.go index b8f446a..6c24c9f 100644 --- a/pkg/agent/runtime_factory.go +++ b/pkg/agent/runtime_factory.go @@ -1,9 +1,7 @@ package agent import ( - "fmt" - - "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/runtime/local" "github.com/agenticenv/agent-sdk-go/internal/runtime/temporal" ) @@ -12,16 +10,7 @@ func (cfg *agentConfig) hasTemporalRuntime() bool { return cfg.temporalConfig != nil || cfg.temporalClient != nil } -// buildAgentRuntime constructs the execution backend from agentConfig. -// Extend with additional branches when new [runtime.Runtime] implementations are added. 
-func (cfg *agentConfig) buildAgentRuntime(remoteWorker bool) (runtime.Runtime, error) {
-	if cfg.hasTemporalRuntime() {
-		return cfg.buildTemporalRuntime(remoteWorker)
-	}
-	return nil, fmt.Errorf("no runtime configured: use WithTemporalConfig or WithTemporalClient")
-}
-
-func (cfg *agentConfig) buildTemporalRuntime(remoteWorker bool) (runtime.Runtime, error) {
+func (cfg *agentConfig) buildTemporalRuntime(remoteWorker bool) (*temporal.TemporalRuntime, error) {
 	options := []temporal.Option{
 		temporal.WithLogger(cfg.logger),
 		temporal.WithAgentSpec(cfg.runtimeAgentSpec()),
@@ -53,3 +42,15 @@ func (cfg *agentConfig) buildTemporalRuntime(remoteWorker bool) (runtime.Runtime
 	options = append(options, temporal.WithEnableRemoteWorkers(enableRemote))
 	return temporal.NewTemporalRuntime(options...)
 }
+
+func (cfg *agentConfig) buildLocalRuntime() (*local.LocalRuntime, error) {
+	options := []local.Option{
+		local.WithLogger(cfg.logger),
+		local.WithToolExecutionMode(cfg.agentToolExecutionMode),
+		local.WithAgentSpec(cfg.runtimeAgentSpec()),
+		local.WithAgentExecution(cfg.runtimeAgentExecution()),
+		local.WithTracer(cfg.tracer),
+		local.WithMetrics(cfg.metrics),
+	}
+	return local.NewLocalRuntime(options...)
+}
diff --git a/pkg/agent/runtime_factory_test.go b/pkg/agent/runtime_factory_test.go
index 2798e1c..452a55c 100644
--- a/pkg/agent/runtime_factory_test.go
+++ b/pkg/agent/runtime_factory_test.go
@@ -16,10 +16,23 @@ func TestHasTemporalRuntime(t *testing.T) {
 	}
 }
 
-func TestBuildAgentRuntime_NoTemporalBackend(t *testing.T) {
+func TestBuildAgentRuntime_NoTemporalBackend_BuildsLocalRuntime(t *testing.T) {
+	// When no Temporal config is set, buildAgentRuntime falls back to LocalRuntime.
cfg := &agentConfig{Name: "n", LLMClient: stubLLM{}} + rt, err := cfg.buildAgentRuntime(false) + if err != nil { + t.Fatalf("expected local runtime to be built, got error: %v", err) + } + if rt == nil { + t.Fatal("expected non-nil runtime") + } +} + +func TestBuildAgentRuntime_NoTemporalBackend_MissingLLMErrors(t *testing.T) { + // Without an LLM client the local runtime builder must return an error. + cfg := &agentConfig{Name: "n"} _, err := cfg.buildAgentRuntime(false) - if err == nil || !strings.Contains(err.Error(), "no runtime configured") { - t.Fatalf("got %v", err) + if err == nil || !strings.Contains(err.Error(), "llm client is required") { + t.Fatalf("expected 'llm client is required', got %v", err) } } diff --git a/pkg/agent/subagent.go b/pkg/agent/subagent.go index 18d3fe9..781502c 100644 --- a/pkg/agent/subagent.go +++ b/pkg/agent/subagent.go @@ -7,7 +7,7 @@ import ( "regexp" "strings" - "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/internal/runtime" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/tools" ) @@ -21,9 +21,10 @@ var subAgentToolNameNonIdent = regexp.MustCompile(`[^a-zA-Z0-9]+`) // defaultMaxSubAgentDepth is the maximum number of sub-agent hops from this agent when unset. const defaultMaxSubAgentDepth = 2 -// ErrSubAgentToolNotExecutable is returned by SubAgentTool.Execute. -// Normal runs delegate via AgentWorkflow child workflows; Execute() is only for misconfigured or direct activity calls. -var ErrSubAgentToolNotExecutable = errors.New("sub-agent tool must be executed via workflow, not Execute()") +// ErrSubAgentToolNotExecutable is returned by SubAgentTool.Execute when called outside a managed runtime. +// Local runtime delegates via SubAgentRunners (in-process); Temporal delegates via child workflows. +// This error only surfaces if the runtime bypasses the normal SubAgentRunners/workflow delegation path. 
+var ErrSubAgentToolNotExecutable = errors.New("sub-agent tool must be delegated via runtime (local: in-process runner, temporal: child workflow)") // ErrSubAgentNameInvalid is returned when computing a sub-agent delegation tool name from a display name fails // for a delegation tool (empty name, or name contains no letters or digits after normalization). @@ -105,8 +106,8 @@ func (t *subAgentTool) Description() string { func (t *subAgentTool) Parameters() interfaces.JSONSchema { return tools.Params(map[string]interfaces.JSONSchema{ - types.SubAgentToolParamQuery: tools.ParamString("Task or question to send to the sub-agent."), - }, types.SubAgentToolParamQuery) + runtime.SubAgentToolParamQuery: tools.ParamString("Task or question to send to the sub-agent."), + }, runtime.SubAgentToolParamQuery) } func (t *subAgentTool) Execute(_ context.Context, _ map[string]any) (any, error) { diff --git a/pkg/agent/worker.go b/pkg/agent/worker.go index 91c788a..b102436 100644 --- a/pkg/agent/worker.go +++ b/pkg/agent/worker.go @@ -18,11 +18,15 @@ type AgentWorker struct { // NewAgentWorker creates an AgentWorker that polls and executes runs for the configured backend. // Same options as [NewAgent]. Use when the agent is created with [DisableLocalWorker]. +// AgentWorker requires a Temporal backend (WithTemporalConfig or WithTemporalClient). 
func NewAgentWorker(opts ...Option) (*AgentWorker, error) { cfg, err := buildAgentConfig(opts) if err != nil { return nil, err } + if !cfg.hasTemporalRuntime() { + return nil, fmt.Errorf("AgentWorker requires a Temporal backend: use WithTemporalConfig or WithTemporalClient") + } cfg.remoteWorker = true if cfg.disableFingerprintCheck { return nil, fmt.Errorf("WithDisableFingerprintCheck is not allowed on AgentWorker (remote worker process)") diff --git a/pkg/agent/worker_test.go b/pkg/agent/worker_test.go index 26ab7eb..36b9c0f 100644 --- a/pkg/agent/worker_test.go +++ b/pkg/agent/worker_test.go @@ -13,7 +13,7 @@ import ( func TestNewAgentWorker_requiresTemporal(t *testing.T) { _, err := NewAgentWorker(WithName("w"), WithLLMClient(stubLLM{})) - if err == nil || !strings.Contains(err.Error(), "temporal connection is required") { + if err == nil || !strings.Contains(err.Error(), "AgentWorker requires a Temporal backend") { t.Fatalf("got %v", err) } } diff --git a/taskfiles/examples.yml b/taskfiles/examples.yml new file mode 100644 index 0000000..8d60104 --- /dev/null +++ b/taskfiles/examples.yml @@ -0,0 +1,494 @@ +version: '3' + +tasks: + # ── Internal ─────────────────────────────────────────────── + + require:llm: + internal: true + desc: Require LLM_APIKEY, LLM_PROVIDER, and LLM_MODEL in the environment + preconditions: + - sh: test -n "$LLM_APIKEY" && test -n "$LLM_PROVIDER" && test -n "$LLM_MODEL" + msg: "LLM_APIKEY, LLM_PROVIDER, and LLM_MODEL are required. Export all three before running example tasks (provider openai, anthropic, or gemini — with a matching model)." + + # Shell wrappers — infra tasks live in examples/Taskfile.yml; structured task+taskfile + # resolves against this flattened file and fails with "does not exist". 
+ _infra:deps:up: + internal: true + cmds: + - task -t "{{.ROOT_DIR}}/examples/Taskfile.yml" infra:deps:up + _infra:deps:down: + internal: true + cmds: + - task -t "{{.ROOT_DIR}}/examples/Taskfile.yml" infra:deps:down + _infra:temporal:up: + internal: true + cmds: + - task -t "{{.ROOT_DIR}}/examples/Taskfile.yml" infra:temporal:up + _infra:temporal:wait: + internal: true + cmds: + - task -t "{{.ROOT_DIR}}/examples/Taskfile.yml" infra:temporal:wait + _infra:temporal:down: + internal: true + cmds: + - task -t "{{.ROOT_DIR}}/examples/Taskfile.yml" infra:temporal:down + + report:header: + internal: true + desc: Write formatted report header (standalone local/temporal or examples:all) + cmds: + - | + { + echo "========================================" + echo "Agent SDK — Examples Run" + echo "========================================" + echo "Date: $(date)" + echo "Runtime: {{.RUNTIME}}" + echo "Logs: {{.LOG_REPORT}}" + {{if eq .PLAN_MODE "true"}}echo "Mode: plan (no go run, no infra)"{{end}} + echo "========================================" + echo "" + } >> "{{.REPORT_FILE}}" + + report:section: + internal: true + desc: Write runtime section divider (examples:all append mode) + cmds: + - | + { + echo "----------------------------------------" + echo "Runtime: {{.RUNTIME}}" + echo "----------------------------------------" + echo "" + } >> "{{.REPORT_FILE}}" + + report:summary: + internal: true + desc: Write formatted SUMMARY block from COUNT_FILE + vars: + FAIL_ON_ERROR: '{{.FAIL_ON_ERROR | default "true"}}' + PLAN_MODE: '{{.PLAN_MODE | default "false"}}' + cmds: + - | + read -r PASS FAIL < "{{.COUNT_FILE}}" + TOTAL=$((PASS + FAIL)) + if [ "{{.PLAN_MODE}}" = "true" ]; then + STATUS_LINE=" Status: ○ PLAN ONLY (not executed)" + elif [ "${FAIL}" -gt 0 ]; then + STATUS_LINE=" Status: ❌ FAILURES (${FAIL})" + else + STATUS_LINE=" Status: ✅ ALL PASSED" + fi + { + echo "" + echo "========================================" + echo "SUMMARY" + echo " Runtime: {{.RUNTIME}}" + echo " 
Total: ${TOTAL}" + } >> "{{.REPORT_FILE}}" + if [ "{{.PLAN_MODE}}" = "true" ]; then + { + echo " Planned: ${PASS}" + echo " Pass: 0" + echo " Fail: 0" + } >> "{{.REPORT_FILE}}" + else + { + echo " Pass: ${PASS}" + echo " Fail: ${FAIL}" + } >> "{{.REPORT_FILE}}" + fi + { + echo "${STATUS_LINE}" + echo "========================================" + } >> "{{.REPORT_FILE}}" + if [ "{{.PLAN_MODE}}" = "true" ]; then + echo "📊 runtime={{.RUNTIME}} plan=${PASS} (not executed)" + else + echo "📊 runtime={{.RUNTIME}} total=${TOTAL} pass=${PASS} fail=${FAIL}" + fi + rm -f "{{.COUNT_FILE}}" + if [ "${FAIL}" -gt 0 ] && [ "{{.FAIL_ON_ERROR}}" = "true" ]; then + echo "❌ Examples failed: ${FAIL}" + exit 1 + fi + + exec:example: + internal: true + desc: Run single example and capture output + dir: '{{.ROOT_DIR}}/examples' + vars: + RUNTIME: '{{.RUNTIME | default "local"}}' + NAME: '{{.NAME}}' + PROMPT: '{{.PROMPT | default ""}}' + SKIP_RUN: '{{.SKIP_RUN | default "false"}}' + cmds: + - | + if [ "{{.SKIP_RUN}}" = "true" ]; then + RESULT="○ PLAN" + if [ -n "{{.PROMPT}}" ]; then + echo "○ Plan {{.NAME}} ({{.RUNTIME}}) — go run ./{{.NAME}} \"{{.PROMPT}}\"" + else + echo "○ Plan {{.NAME}} ({{.RUNTIME}}) — go run ./{{.NAME}}" + fi + if [ -n "{{.REPORT_FILE}}" ]; then + echo "$RESULT {{.NAME}}" >> {{.REPORT_FILE}} + fi + if [ -n "{{.COUNT_FILE}}" ]; then + read -r PASS FAIL < "{{.COUNT_FILE}}" 2>/dev/null || PASS=0 FAIL=0 + PASS=$((PASS + 1)) + printf '%s %s\n' "$PASS" "$FAIL" > "{{.COUNT_FILE}}" + fi + exit 0 + fi + echo "🚀 Running {{.NAME}} with {{.RUNTIME}} runtime..." + set +e + if [ -n "{{.LOG_FILE}}" ]; then + if [ -n "{{.PROMPT}}" ]; then + go run ./{{.NAME}} "{{.PROMPT}}" >> {{.LOG_FILE}} 2>&1 + else + go run ./{{.NAME}} >> {{.LOG_FILE}} 2>&1 + fi + else + if [ -n "{{.PROMPT}}" ]; then + go run ./{{.NAME}} "{{.PROMPT}}" + else + go run ./{{.NAME}} + fi + fi + STATUS=$? 
+ set -e + if [ $STATUS -eq 0 ]; then + RESULT="✅ PASS" + else + RESULT="❌ FAIL" + fi + echo "$RESULT {{.NAME}} runtime={{.RUNTIME}}" + if [ -n "{{.REPORT_FILE}}" ]; then + echo "$RESULT {{.NAME}}" >> {{.REPORT_FILE}} + fi + if [ -n "{{.COUNT_FILE}}" ]; then + read -r PASS FAIL < "{{.COUNT_FILE}}" 2>/dev/null || PASS=0 FAIL=0 + if [ "$STATUS" -eq 0 ]; then + PASS=$((PASS + 1)) + else + FAIL=$((FAIL + 1)) + fi + printf '%s %s\n' "$PASS" "$FAIL" > "{{.COUNT_FILE}}" + fi + env: + AGENT_RUNTIME: '{{.RUNTIME}}' + # Batch only — not in .env.defaults; manual go run leaves unset (interactive y/n). + EXAMPLES_AUTO_APPROVE: '{{.EXAMPLES_AUTO_APPROVE | default "true"}}' + + exec:examples: + internal: true + vars: + RUNTIME: '{{.RUNTIME | default "local"}}' + SKIP_RUN: '{{.SKIP_RUN | default "false"}}' + PLAN_MODE: '{{.SKIP_RUN | default "false"}}' + WRITE_SUMMARY: '{{.WRITE_SUMMARY | default "true"}}' + FAIL_ON_ERROR: '{{.FAIL_ON_ERROR | default "true"}}' + EXAMPLES: + - simple_agent + - agent_with_tools/basic + - agent_with_tools/custom + - agent_with_tools/authorizer + - agent_with_json_response + - agent_with_stream + - agent_with_reasoning + - multiple_agents + - agent_with_mcp_config + - agent_with_mcp_client + - agent_with_a2a_config + - agent_with_a2a_client + - agent_with_retriever/weaviate + - agent_with_retriever/pgvector + - agent_with_observability/config + - agent_with_observability/objects + - agent_with_subagents + - agent_with_tools/approval + - agent_with_run_async + EXAMPLES_WITH_PROMPTS: + - agent_with_conversation + - agent_with_stream_conversation + EXAMPLES_TEMPORAL: + - agent_with_temporal_client + # TODO: enable — split worker + agent REPL + # - durable_agent/worker + # - durable_agent/agent + # - agent_with_worker/worker + # - agent_with_worker/agent + cmds: + - printf '0 0\n' > {{.COUNT_FILE}} + - for: + var: EXAMPLES + ignore_error: true + task: exec:example + vars: + NAME: '{{.ITEM}}' + RUNTIME: '{{.RUNTIME}}' + REPORT_FILE: '{{.REPORT_FILE}}' + 
LOG_FILE: '{{.LOG_FILE}}' + COUNT_FILE: '{{.COUNT_FILE}}' + SKIP_RUN: '{{.SKIP_RUN}}' + - for: + var: EXAMPLES_WITH_PROMPTS + ignore_error: true + task: exec:example + vars: + NAME: '{{.ITEM}}' + PROMPT: "Hi" + RUNTIME: '{{.RUNTIME}}' + REPORT_FILE: '{{.REPORT_FILE}}' + LOG_FILE: '{{.LOG_FILE}}' + COUNT_FILE: '{{.COUNT_FILE}}' + SKIP_RUN: '{{.SKIP_RUN}}' + - for: + var: EXAMPLES_TEMPORAL + if: '{{eq .RUNTIME "temporal"}}' + ignore_error: true + task: exec:example + vars: + NAME: '{{.ITEM}}' + RUNTIME: '{{.RUNTIME}}' + REPORT_FILE: '{{.REPORT_FILE}}' + LOG_FILE: '{{.LOG_FILE}}' + COUNT_FILE: '{{.COUNT_FILE}}' + SKIP_RUN: '{{.SKIP_RUN}}' + - task: report:summary + if: '{{eq .WRITE_SUMMARY "true"}}' + vars: + REPORT_FILE: '{{.REPORT_FILE}}' + COUNT_FILE: '{{.COUNT_FILE}}' + RUNTIME: '{{.RUNTIME}}' + FAIL_ON_ERROR: '{{.FAIL_ON_ERROR}}' + PLAN_MODE: '{{.PLAN_MODE}}' + + # ── Examples ─────────────────────────────────────────────── + + examples:local: + desc: Run all examples with local runtime + deps: + - task: require:llm + if: '{{ne .SKIP_RUN "true"}}' + vars: + TIMESTAMP: '{{now | date "2006-01-02_15-04-05"}}' + REPORT_FILE: '{{.REPORT_FILE | default (printf "%s/reports/examples_local_%s.txt" .ROOT_DIR .TIMESTAMP)}}' + LOG_FILE: '{{.LOG_FILE | default (printf "%s/reports/examples_local_%s.log" .ROOT_DIR .TIMESTAMP)}}' + LOG_REPORT: '{{.LOG_REPORT | default (printf "reports/examples_local_%s.log" .TIMESTAMP)}}' + COUNT_FILE: '{{.COUNT_FILE | default (printf "%s/reports/examples_local_%s.count" .ROOT_DIR .TIMESTAMP)}}' + REPORT_MODE: '{{.REPORT_MODE | default "standalone"}}' + SKIP_RUN: '{{.SKIP_RUN | default "false"}}' + SKIP_INFRA: '{{.SKIP_INFRA | default "false"}}' + WRITE_SUMMARY: '{{.WRITE_SUMMARY | default "true"}}' + FAIL_ON_ERROR: '{{.FAIL_ON_ERROR | default "true"}}' + cmds: + - mkdir -p reports + - touch {{.REPORT_FILE}} + - touch {{.LOG_FILE}} + - touch {{.COUNT_FILE}} + - task: report:section + if: '{{eq .REPORT_MODE "section"}}' + vars: + REPORT_FILE: 
'{{.REPORT_FILE}}' + RUNTIME: local + PLAN_MODE: '{{.SKIP_RUN}}' + - task: report:header + if: '{{eq .REPORT_MODE "standalone"}}' + vars: + REPORT_FILE: '{{.REPORT_FILE}}' + RUNTIME: local + LOG_REPORT: '{{.LOG_REPORT}}' + PLAN_MODE: '{{.SKIP_RUN}}' + - defer: { task: _infra:deps:down } + if: '{{ne .SKIP_INFRA "true"}}' + - task: _infra:deps:up + if: '{{ne .SKIP_INFRA "true"}}' + - task: exec:examples + vars: + RUNTIME: local + REPORT_FILE: '{{.REPORT_FILE}}' + LOG_FILE: '{{.LOG_FILE}}' + COUNT_FILE: '{{.COUNT_FILE}}' + SKIP_RUN: '{{.SKIP_RUN}}' + WRITE_SUMMARY: '{{.WRITE_SUMMARY}}' + FAIL_ON_ERROR: '{{.FAIL_ON_ERROR}}' + - '{{if eq .REPORT_MODE "standalone"}}echo "📄 Report — {{.REPORT_FILE}}"{{end}}' + - '{{if eq .REPORT_MODE "standalone"}}echo "📄 Logs — {{.LOG_FILE}}"{{end}}' + + examples:local:plan: + desc: Plan local examples — write report only (no go run, no infra). Use instead of task --dry to preview the report file. + cmds: + - task: examples:local + vars: + SKIP_RUN: "true" + SKIP_INFRA: "true" + FAIL_ON_ERROR: "false" + + examples:temporal: + desc: Run all examples with temporal runtime + deps: + - task: require:llm + if: '{{ne .SKIP_RUN "true"}}' + vars: + TIMESTAMP: '{{now | date "2006-01-02_15-04-05"}}' + REPORT_FILE: '{{.REPORT_FILE | default (printf "%s/reports/examples_temporal_%s.txt" .ROOT_DIR .TIMESTAMP)}}' + LOG_FILE: '{{.LOG_FILE | default (printf "%s/reports/examples_temporal_%s.log" .ROOT_DIR .TIMESTAMP)}}' + LOG_REPORT: '{{.LOG_REPORT | default (printf "reports/examples_temporal_%s.log" .TIMESTAMP)}}' + COUNT_FILE: '{{.COUNT_FILE | default (printf "%s/reports/examples_temporal_%s.count" .ROOT_DIR .TIMESTAMP)}}' + REPORT_MODE: '{{.REPORT_MODE | default "standalone"}}' + SKIP_RUN: '{{.SKIP_RUN | default "false"}}' + SKIP_INFRA: '{{.SKIP_INFRA | default "false"}}' + WRITE_SUMMARY: '{{.WRITE_SUMMARY | default "true"}}' + FAIL_ON_ERROR: '{{.FAIL_ON_ERROR | default "true"}}' + cmds: + - mkdir -p reports + - touch {{.REPORT_FILE}} + - touch 
{{.LOG_FILE}} + - touch {{.COUNT_FILE}} + - task: report:section + if: '{{eq .REPORT_MODE "section"}}' + vars: + REPORT_FILE: '{{.REPORT_FILE}}' + RUNTIME: temporal + - task: report:header + if: '{{eq .REPORT_MODE "standalone"}}' + vars: + REPORT_FILE: '{{.REPORT_FILE}}' + RUNTIME: temporal + LOG_REPORT: '{{.LOG_REPORT}}' + PLAN_MODE: '{{.SKIP_RUN}}' + - defer: { task: _infra:temporal:down } + if: '{{ne .SKIP_INFRA "true"}}' + - defer: { task: _infra:deps:down } + if: '{{ne .SKIP_INFRA "true"}}' + - task: _infra:temporal:up + if: '{{ne .SKIP_INFRA "true"}}' + - task: _infra:temporal:wait + if: '{{ne .SKIP_INFRA "true"}}' + - task: _infra:deps:up + if: '{{ne .SKIP_INFRA "true"}}' + - task: exec:examples + vars: + RUNTIME: temporal + REPORT_FILE: '{{.REPORT_FILE}}' + LOG_FILE: '{{.LOG_FILE}}' + COUNT_FILE: '{{.COUNT_FILE}}' + SKIP_RUN: '{{.SKIP_RUN}}' + WRITE_SUMMARY: '{{.WRITE_SUMMARY}}' + FAIL_ON_ERROR: '{{.FAIL_ON_ERROR}}' + - '{{if eq .REPORT_MODE "standalone"}}echo "📄 Report — {{.REPORT_FILE}}"{{end}}' + - '{{if eq .REPORT_MODE "standalone"}}echo "📄 Logs — {{.LOG_FILE}}"{{end}}' + + examples:temporal:plan: + desc: Plan temporal examples — write report only (no go run, no infra). Use instead of task --dry to preview the report file. 
+    cmds:
+      - task: examples:temporal
+        vars:
+          SKIP_RUN: "true"
+          SKIP_INFRA: "true"
+          FAIL_ON_ERROR: "false"
+
+  examples:all:
+    desc: Run all examples on both runtimes
+    deps:
+      - task: require:llm
+        if: '{{ne .PLAN_MODE "true"}}'
+    vars:
+      TIMESTAMP: '{{now | date "2006-01-02_15-04-05"}}'
+      REPORT_FILE: '{{.REPORT_FILE | default (printf "%s/reports/examples_all_%s.txt" .ROOT_DIR .TIMESTAMP)}}'
+      LOG_FILE: '{{.LOG_FILE | default (printf "%s/reports/examples_all_%s.log" .ROOT_DIR .TIMESTAMP)}}'
+      LOG_REPORT: '{{printf "reports/examples_all_%s.log" .TIMESTAMP}}'
+      COUNT_FILE_LOCAL: '{{printf "%s/reports/examples_all_%s.local.count" .ROOT_DIR .TIMESTAMP}}'
+      COUNT_FILE_TEMPORAL: '{{printf "%s/reports/examples_all_%s.temporal.count" .ROOT_DIR .TIMESTAMP}}'
+      PLAN_MODE: '{{.PLAN_MODE | default "false"}}'
+    cmds:
+      - mkdir -p reports
+      - touch {{.REPORT_FILE}}
+      - touch {{.LOG_FILE}}
+      - task: report:header
+        vars:
+          REPORT_FILE: '{{.REPORT_FILE}}'
+          RUNTIME: all (local + temporal)
+          LOG_REPORT: '{{.LOG_REPORT}}'
+          PLAN_MODE: '{{.PLAN_MODE}}'
+      - task: examples:local
+        vars:
+          REPORT_FILE: '{{.REPORT_FILE}}'
+          LOG_FILE: '{{.LOG_FILE}}'
+          COUNT_FILE: '{{.COUNT_FILE_LOCAL}}'
+          REPORT_MODE: section
+          SKIP_RUN: '{{.PLAN_MODE}}'
+          SKIP_INFRA: '{{.PLAN_MODE}}'
+          WRITE_SUMMARY: "false"
+          FAIL_ON_ERROR: "false"
+      - task: examples:temporal
+        vars:
+          REPORT_FILE: '{{.REPORT_FILE}}'
+          LOG_FILE: '{{.LOG_FILE}}'
+          COUNT_FILE: '{{.COUNT_FILE_TEMPORAL}}'
+          REPORT_MODE: section
+          SKIP_RUN: '{{.PLAN_MODE}}'
+          SKIP_INFRA: '{{.PLAN_MODE}}'
+          WRITE_SUMMARY: "false"
+          FAIL_ON_ERROR: "false"
+      - |
+        read -r LPASS LFAIL < "{{.COUNT_FILE_LOCAL}}" 2>/dev/null || LPASS=0 LFAIL=0
+        read -r TPASS TFAIL < "{{.COUNT_FILE_TEMPORAL}}" 2>/dev/null || TPASS=0 TFAIL=0
+        PASS=$((LPASS + TPASS))
+        FAIL=$((LFAIL + TFAIL))
+        TOTAL=$((PASS + FAIL))
+        if [ "{{.PLAN_MODE}}" = "true" ]; then
+          STATUS_LINE=" Status: ○ PLAN ONLY (not executed)"
+        elif [ "${FAIL}" -gt 0 ]; then
+          STATUS_LINE=" Status: ❌ FAILURES (${FAIL})"
+        else
+          STATUS_LINE=" Status: ✅ ALL PASSED"
+        fi
+        {
+          echo ""
+          echo "========================================"
+          echo "SUMMARY"
+          echo " Runtimes: local + temporal"
+          echo " Total: ${TOTAL}"
+        } >> "{{.REPORT_FILE}}"
+        if [ "{{.PLAN_MODE}}" = "true" ]; then
+          {
+            echo " Planned: ${PASS}"
+            echo " Pass: 0"
+            echo " Fail: 0"
+            echo " local: planned=${LPASS}"
+            echo " temporal: planned=${TPASS}"
+          } >> "{{.REPORT_FILE}}"
+        else
+          {
+            echo " Pass: ${PASS}"
+            echo " Fail: ${FAIL}"
+            echo " local: pass=${LPASS} fail=${LFAIL}"
+            echo " temporal: pass=${TPASS} fail=${TFAIL}"
+          } >> "{{.REPORT_FILE}}"
+        fi
+        {
+          echo "${STATUS_LINE}"
+          echo "========================================"
+        } >> "{{.REPORT_FILE}}"
+        rm -f "{{.COUNT_FILE_LOCAL}}" "{{.COUNT_FILE_TEMPORAL}}"
+        if [ "{{.PLAN_MODE}}" = "true" ]; then
+          echo "📊 plan=${PASS} (not executed) (local=${LPASS}, temporal=${TPASS})"
+        else
+          echo "📊 total=${TOTAL} pass=${PASS} fail=${FAIL} (local pass=${LPASS} fail=${LFAIL}, temporal pass=${TPASS} fail=${TFAIL})"
+        fi
+        if [ "${FAIL}" -gt 0 ]; then
+          echo "❌ Examples failed: ${FAIL}"
+          exit 1
+        fi
+      - echo "📄 Report — {{.REPORT_FILE}}"
+      - echo "📄 Logs — {{.LOG_FILE}}"
+
+  examples:all:plan:
+    desc: Plan all examples — combined report only (no go run, no infra)
+    cmds:
+      - task: examples:all
+        vars:
+          PLAN_MODE: "true"
diff --git a/taskfiles/reports.yml b/taskfiles/reports.yml
new file mode 100644
index 0000000..be2586a
--- /dev/null
+++ b/taskfiles/reports.yml
@@ -0,0 +1,8 @@
+version: '3'
+
+tasks:
+  reports:clean:
+    desc: Clean all report files
+    cmds:
+      - rm -rf "{{.ROOT_DIR}}/reports/"
+      - echo "🗑️ Reports cleaned"
\ No newline at end of file