diff --git a/README.md b/README.md index 202dd03..315308f 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ - [Sub-agents](#sub-agents) - [Managing Capabilities at Runtime](#managing-capabilities-at-runtime) - [Approvals](#approvals) + - [Hooks](#hooks) - [Timeouts and deadlines](#timeouts-and-deadlines) - [Custom tools](#custom-tools) - [Response format](#response-format) @@ -1082,6 +1083,243 @@ For **Run** / **RunAsync**, use `req.Respond` only. For **Stream**, use `**OnApp - **Stream:** An `AgentEventError` is emitted on the event channel with the error message. - **RunAsync:** `resultCh` receives `AgentRunAsyncResult` with `Error` set. +### Hooks + +Hooks let you register middleware callbacks that fire at key points in the agent execution lifecycle — before and after LLM calls, tool executions, retrievals, and memory operations. Register once via `WithHooks` and the SDK calls your functions automatically across both `Run` and `Stream`. + +#### Hook points + +| Hook | Fires | Mutable outputs | +|------|-------|----------------| +| `BeforeLLMHook` | Before each LLM request | Request (messages, system prompt, tools, params) | +| `AfterLLMHook` | After each LLM response | Response (content, tool calls) | +| `BeforeToolHook` | Before native/MCP tool executes | Args | +| `AfterToolHook` | After native/MCP tool executes | Content, Err | +| `BeforeRetrieveHook` | Before each retriever search | Query | +| `AfterRetrieveHook` | After each retriever search | Documents | +| `BeforeMemoryLoadHook` | Before memory recall | Query, Limit, MinScore, Kinds | +| `AfterMemoryLoadHook` | After memory recall, before injection into prompt | PromptContext | +| `BeforeMemoryStoreHook` | Before each memory record is persisted | Record, ID | +| `AfterMemoryStoreHook` | After memory is persisted | — (error/abort only) | + +Every hook receives `RunMeta` (`RunID`, `Iteration`, `HooksGroup`) for correlation and auditing. Return an error from any hook to abort the run immediately. 
+ +#### Register hooks + +```go +a, err := agent.NewAgent( + agent.WithTemporalConfig(...), + agent.WithLLMClient(llmClient), + agent.WithHooks("guardrails", agent.AgentHooks{ + BeforeLLM: []agent.BeforeLLMHook{ + func(ctx context.Context, in agent.BeforeLLMHookInput) (agent.BeforeLLMHookOutput, error) { + // inspect or modify in.Request before the LLM call + return agent.BeforeLLMHookOutput{Request: in.Request}, nil + }, + }, + AfterLLM: []agent.AfterLLMHook{ + func(ctx context.Context, in agent.AfterLLMHookInput) (agent.AfterLLMHookOutput, error) { + // inspect or modify in.Response after the LLM call + return agent.AfterLLMHookOutput{Response: in.Response}, nil + }, + }, + }), +) +``` + +Call `WithHooks` multiple times to register independent groups. Each group runs with its own name set in `RunMeta.HooksGroup`, so hooks know which group they belong to: + +```go +a, err := agent.NewAgent( + agent.WithHooks("pii-scrubber", piiHooks), + agent.WithHooks("cost-tracker", costHooks), + agent.WithHooks("audit-logger", auditHooks), + // ... 
+) +``` + +#### Hook use cases + +##### Guardrails and safety + +Block unsafe inputs or outputs before they reach the LLM or the user: + +```go +agent.WithHooks("safety", agent.AgentHooks{ + BeforeLLM: []agent.BeforeLLMHook{ + func(ctx context.Context, in agent.BeforeLLMHookInput) (agent.BeforeLLMHookOutput, error) { + if containsInjection(in.Request.Messages) { + return agent.BeforeLLMHookOutput{}, errors.New("prompt injection detected") + } + return agent.BeforeLLMHookOutput{Request: in.Request}, nil + }, + }, + AfterLLM: []agent.AfterLLMHook{ + func(ctx context.Context, in agent.AfterLLMHookInput) (agent.AfterLLMHookOutput, error) { + if isToxic(in.Response.Content) { + return agent.AfterLLMHookOutput{}, errors.New("unsafe output blocked") + } + return agent.AfterLLMHookOutput{Response: in.Response}, nil + }, + }, +}) +``` + +**Hooks for this use case:** `BeforeLLM`, `AfterLLM` + +##### PII and data privacy + +Scrub sensitive data before it leaves your system or gets persisted: + +```go +agent.WithHooks("pii", agent.AgentHooks{ + BeforeLLM: []agent.BeforeLLMHook{ + func(ctx context.Context, in agent.BeforeLLMHookInput) (agent.BeforeLLMHookOutput, error) { + req := in.Request + req.SystemMessage = redactPII(req.SystemMessage) + return agent.BeforeLLMHookOutput{Request: req}, nil + }, + }, + BeforeMemoryStore: []agent.BeforeMemoryStoreHook{ + func(ctx context.Context, in agent.BeforeMemoryStoreHookInput) (agent.BeforeMemoryStoreHookOutput, error) { + rec := in.Record + rec.Text = redactPII(rec.Text) + return agent.BeforeMemoryStoreHookOutput{Record: rec}, nil + }, + }, +}) +``` + +**Hooks for this use case:** `BeforeLLM`, `AfterLLM`, `BeforeTool`, `BeforeRetrieve`, `AfterRetrieve`, `BeforeMemoryStore`, `AfterMemoryLoad` + +##### Token budget and cost tracking + +Track and enforce token budgets across LLM calls: + +```go +agent.WithHooks("cost", agent.AgentHooks{ + AfterLLM: []agent.AfterLLMHook{ + func(ctx context.Context, in agent.AfterLLMHookInput) 
(agent.AfterLLMHookOutput, error) { + if in.Response.Usage != nil { + trackTokens(in.RunMeta.RunID, in.Response.Usage.TotalTokens) + if budgetExceeded(in.RunMeta.RunID) { + return agent.AfterLLMHookOutput{}, errors.New("token budget exceeded") + } + } + return agent.AfterLLMHookOutput{Response: in.Response}, nil + }, + }, +}) +``` + +**Hooks for this use case:** `BeforeLLM` (cache read / short-circuit), `AfterLLM` (token tracking, cache write, model fallback) + +##### Memory tenant isolation + +Enforce that each tenant only reads and writes their own memories: + +```go +agent.WithHooks("tenant-isolation", agent.AgentHooks{ + BeforeMemoryLoad: []agent.BeforeMemoryLoadHook{ + func(ctx context.Context, in agent.BeforeMemoryLoadHookInput) (agent.BeforeMemoryLoadHookOutput, error) { + if in.Scope.TenantID == "" { + return agent.BeforeMemoryLoadHookOutput{}, errors.New("tenant ID required for memory access") + } + return agent.BeforeMemoryLoadHookOutput{ + Query: in.Query, Limit: in.Limit, + MinScore: in.MinScore, Kinds: in.Kinds, + }, nil + }, + }, + BeforeMemoryStore: []agent.BeforeMemoryStoreHook{ + func(ctx context.Context, in agent.BeforeMemoryStoreHookInput) (agent.BeforeMemoryStoreHookOutput, error) { + if in.Scope.TenantID == "" { + return agent.BeforeMemoryStoreHookOutput{}, errors.New("tenant ID required for memory store") + } + return agent.BeforeMemoryStoreHookOutput{Record: in.Record, ID: in.ID}, nil + }, + }, +}) +``` + +**Hooks for this use case:** `BeforeMemoryLoad`, `BeforeMemoryStore` + +##### Logging and auditing + +Log every operation for observability without touching business logic: + +```go +agent.WithHooks("audit", agent.AgentHooks{ + AfterLLM: []agent.AfterLLMHook{ + func(ctx context.Context, in agent.AfterLLMHookInput) (agent.AfterLLMHookOutput, error) { + log.Printf("run=%s iter=%d usage=%+v", in.RunMeta.RunID, in.RunMeta.Iteration, in.Response.Usage) + return agent.AfterLLMHookOutput{Response: in.Response}, nil + }, + }, + 
AfterTool: []agent.AfterToolHook{ + func(ctx context.Context, in agent.AfterToolHookInput) (agent.AfterToolHookOutput, error) { + log.Printf("run=%s tool=%s err=%v", in.RunMeta.RunID, in.Call.Name, in.Err) + return agent.AfterToolHookOutput{Content: in.Content, Err: in.Err}, nil + }, + }, + AfterMemoryStore: []agent.AfterMemoryStoreHook{ + func(ctx context.Context, in agent.AfterMemoryStoreHookInput) (agent.AfterMemoryStoreHookOutput, error) { + log.Printf("run=%s memory stored id=%s", in.RunMeta.RunID, in.ID) + return agent.AfterMemoryStoreHookOutput{}, nil + }, + }, +}) +``` + +**Hooks for this use case:** `BeforeLLM`, `AfterLLM`, `BeforeTool`, `AfterTool`, `BeforeRetrieve`, `AfterRetrieve`, `AfterMemoryLoad`, `AfterMemoryStore` + +##### Retrieval query rewriting and filtering + +Modify the query before search and filter or re-rank results after: + +```go +agent.WithHooks("retrieval", agent.AgentHooks{ + BeforeRetrieve: []agent.BeforeRetrieveHook{ + func(ctx context.Context, in agent.BeforeRetrieveHookInput) (agent.BeforeRetrieveHookOutput, error) { + return agent.BeforeRetrieveHookOutput{Query: expandQuery(in.Query)}, nil + }, + }, + AfterRetrieve: []agent.AfterRetrieveHook{ + func(ctx context.Context, in agent.AfterRetrieveHookInput) (agent.AfterRetrieveHookOutput, error) { + return agent.AfterRetrieveHookOutput{Documents: rerank(in.Documents)}, nil + }, + }, +}) +``` + +**Hooks for this use case:** `BeforeRetrieve` (query rewriting), `AfterRetrieve` (filtering, re-ranking) + +#### Use case reference + +| Goal | Hooks to use | +|------|-------------| +| Prompt injection / input guardrails | `BeforeLLM` | +| Output content filtering | `AfterLLM` | +| PII scrubbing from prompts / responses | `BeforeLLM`, `AfterLLM` | +| PII scrubbing from tool args / results | `BeforeTool`, `AfterTool` | +| PII scrubbing from retrieval | `BeforeRetrieve`, `AfterRetrieve` | +| PII scrubbing from memory | `BeforeMemoryStore`, `AfterMemoryLoad` | +| Token tracking and budget 
enforcement | `AfterLLM` | +| LLM response caching | `BeforeLLM` (read), `AfterLLM` (write) | +| Model fallback on error | `AfterLLM` | +| Tool rate limiting | `BeforeTool` | +| Tool input validation / authorization | `BeforeTool` | +| Log all operations | `BeforeLLM`, `AfterLLM`, `BeforeTool`, `AfterTool`, `BeforeRetrieve`, `AfterRetrieve`, `AfterMemoryLoad`, `AfterMemoryStore` | +| Retrieval query rewriting | `BeforeRetrieve` | +| Document re-ranking / filtering | `AfterRetrieve` | +| Memory query rewriting | `BeforeMemoryLoad` | +| Filter injected memory context | `AfterMemoryLoad` | +| Memory tenant isolation | `BeforeMemoryLoad`, `BeforeMemoryStore` | +| Control what gets persisted | `BeforeMemoryStore` | +| Audit memory reads / writes | `AfterMemoryLoad`, `AfterMemoryStore` | + +See [examples/agent_with_hooks](examples/agent_with_hooks) for a runnable demo that registers every hook point (PII scrubbing, retrieval filtering, memory tenant checks). Hook activity is logged to stderr with a `[hooks]` prefix. + ### Timeouts and deadlines You can limit run duration in two ways: diff --git a/examples/.env.defaults b/examples/.env.defaults index a3aaede..4f57f72 100644 --- a/examples/.env.defaults +++ b/examples/.env.defaults @@ -21,7 +21,7 @@ REDIS_ADDR=localhost:6379 # --- Temporal (when AGENT_RUNTIME=temporal) --- # TEMPORAL_TASKQUEUE is a base prefix; each example appends its suffix (e.g. agent-sdk-go-simple_agent). 
-TEMPORAL_HOST=localhost +TEMPORAL_HOST=127.0.0.1 TEMPORAL_PORT=7233 TEMPORAL_NAMESPACE=default TEMPORAL_TASKQUEUE=agent-sdk-go diff --git a/examples/README.md b/examples/README.md index 3f567cf..950d931 100644 --- a/examples/README.md +++ b/examples/README.md @@ -46,6 +46,7 @@ These examples run with `AGENT_RUNTIME=local` (default) or `AGENT_RUNTIME=tempor | `agent_with_observability` | OTLP — **`config/`** vs **`objects/`**; **[README](agent_with_observability/README.md)** | `infra:lgtm:up` (or manual collector) | | `agent_with_retriever` | **`weaviate/`** or **`pgvector/`**; **`RETRIEVER_MODE`** — **[README](agent_with_retriever/README.md)** | `infra:weaviate:up` or `infra:pgvector:up` | | `agent_with_memory` | **`weaviate/`** or **`pgvector/`** — **[README](agent_with_memory/README.md)**; `MEMORY_STORE_MODE=always\|ondemand` | `infra:weaviate:up` or `infra:pgvector:up` | +| `agent_with_hooks` | All middleware hooks — PII scrubbing, retrieval filtering, memory tenant checks; **[README](agent_with_hooks/README.md)** | — | ### Temporal only @@ -104,6 +105,17 @@ go run ./agent_with_tools/custom "Reverse 'hello world'" go run ./agent_with_tools/dynamic_registry ``` +### Agent with hooks + +Middleware hooks across LLM, tools, retrieval, and memory. Hook activity is logged to stderr (`[hooks]` prefix). + +```bash +go run ./agent_with_hooks +go run ./agent_with_hooks "My email is alice@example.com. What is the return policy?" +``` + +See **[agent_with_hooks/README.md](agent_with_hooks/README.md)**. When using **`AGENT_RUNTIME=temporal`**, register the same hook groups on both the agent starter and the worker. 
+ ### Streaming (partial content as tokens arrive) ```bash diff --git a/examples/agent_with_hooks/README.md b/examples/agent_with_hooks/README.md new file mode 100644 index 0000000..864d059 --- /dev/null +++ b/examples/agent_with_hooks/README.md @@ -0,0 +1,31 @@ +# agent_with_hooks + +Demonstrates **every** middleware hook point in the SDK with realistic transformations: + +| Hook | What this example does | +|------|------------------------| +| `BeforeLLM` / `AfterLLM` | Redact email addresses and SSNs from prompts and responses | +| `BeforeTool` / `AfterTool` | Scrub PII from tool args and results | +| `BeforeRetrieve` / `AfterRetrieve` | Prefix queries with `kb:`; drop documents containing SSNs | +| `BeforeMemoryLoad` / `AfterMemoryLoad` | Require `tenant_id` in scope; wrap recalled context with a scrubbed header | +| `BeforeMemoryStore` / `AfterMemoryStore` | Scrub PII before persist; audit log after store | + +Hook activity is printed to **stderr** with a `[hooks]` prefix so you can see when each hook fires without mixing into the assistant reply. + +## Run + +From `examples/`: + +```bash +go run ./agent_with_hooks +``` + +Default: two-run demo (store memories + prefetch retrieval + tools, then recall). + +```bash +go run ./agent_with_hooks "My email is alice@example.com. What is the return policy?" +``` + +## Temporal + +Hooks are Go functions — they run in the **process that executes activities** (the worker). Register the same hook groups on both the agent starter and the worker via `HookOptions()` (or equivalent `WithHooks` calls). Group **names** are fingerprinted for drift detection; hook **logic** consistency is your responsibility. 
diff --git a/examples/agent_with_hooks/hooks.go b/examples/agent_with_hooks/hooks.go new file mode 100644 index 0000000..b1848e9 --- /dev/null +++ b/examples/agent_with_hooks/hooks.go @@ -0,0 +1,159 @@ +package main + +import ( + "context" + "fmt" + "os" + "regexp" + "strings" + + "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +var ( + emailPattern = regexp.MustCompile(`[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}`) + ssnPattern = regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`) +) + +func redactPII(s string) string { + s = emailPattern.ReplaceAllString(s, "[REDACTED_EMAIL]") + s = ssnPattern.ReplaceAllString(s, "[REDACTED_SSN]") + return s +} + +func hookLog(hook, detail string) { + fmt.Fprintf(os.Stderr, "[hooks] %s: %s\n", hook, detail) +} + +// HookOptions returns [agent.WithHooks] groups that demonstrate every hook point. +// Register the same groups on both the agent starter and Temporal worker when using AGENT_RUNTIME=temporal. 
+func HookOptions() []agent.Option { + return []agent.Option{ + agent.WithHooks("pii-scrubber", agent.AgentHooks{ + BeforeLLM: []agent.BeforeLLMHook{beforeLLMRedact}, + AfterLLM: []agent.AfterLLMHook{afterLLMRedact}, + BeforeTool: []agent.BeforeToolHook{ + func(ctx context.Context, in agent.BeforeToolHookInput) (agent.BeforeToolHookOutput, error) { + args := cloneArgs(in.Call.Args) + for k, v := range args { + if s, ok := v.(string); ok { + args[k] = redactPII(s) + } + } + hookLog("BeforeTool", fmt.Sprintf("group=%s tool=%s args scrubbed", in.RunMeta.HooksGroup, in.Call.Name)) + return agent.BeforeToolHookOutput{Args: args}, nil + }, + }, + AfterTool: []agent.AfterToolHook{ + func(ctx context.Context, in agent.AfterToolHookInput) (agent.AfterToolHookOutput, error) { + content := redactPII(in.Content) + hookLog("AfterTool", fmt.Sprintf("group=%s tool=%s result scrubbed", in.RunMeta.HooksGroup, in.Call.Name)) + return agent.AfterToolHookOutput{Content: content, Err: in.Err}, nil + }, + }, + BeforeRetrieve: []agent.BeforeRetrieveHook{ + func(ctx context.Context, in agent.BeforeRetrieveHookInput) (agent.BeforeRetrieveHookOutput, error) { + q := strings.TrimSpace(in.Query) + if q != "" && !strings.HasPrefix(q, "kb:") { + q = "kb: " + q + } + hookLog("BeforeRetrieve", fmt.Sprintf("group=%s query=%q", in.RunMeta.HooksGroup, q)) + return agent.BeforeRetrieveHookOutput{Query: q}, nil + }, + }, + AfterRetrieve: []agent.AfterRetrieveHook{ + func(ctx context.Context, in agent.AfterRetrieveHookInput) (agent.AfterRetrieveHookOutput, error) { + filtered := make([]interfaces.Document, 0, len(in.Documents)) + for _, doc := range in.Documents { + if ssnPattern.MatchString(doc.Content) { + hookLog("AfterRetrieve", fmt.Sprintf("group=%s dropped doc from %s (SSN)", in.RunMeta.HooksGroup, doc.Source)) + continue + } + doc.Content = redactPII(doc.Content) + filtered = append(filtered, doc) + } + return agent.AfterRetrieveHookOutput{Documents: filtered}, nil + }, + }, + BeforeMemoryLoad: 
[]agent.BeforeMemoryLoadHook{ + func(ctx context.Context, in agent.BeforeMemoryLoadHookInput) (agent.BeforeMemoryLoadHookOutput, error) { + if in.Scope.TenantID == "" { + return agent.BeforeMemoryLoadHookOutput{}, fmt.Errorf("tenant ID required for memory recall") + } + hookLog("BeforeMemoryLoad", fmt.Sprintf("group=%s tenant=%s query=%q", in.RunMeta.HooksGroup, in.Scope.TenantID, in.Query)) + return agent.BeforeMemoryLoadHookOutput{ + Query: in.Query, Limit: in.Limit, MinScore: in.MinScore, Kinds: in.Kinds, + }, nil + }, + }, + AfterMemoryLoad: []agent.AfterMemoryLoadHook{ + func(ctx context.Context, in agent.AfterMemoryLoadHookInput) (agent.AfterMemoryLoadHookOutput, error) { + ctxBlock := redactPII(in.PromptContext) + if strings.TrimSpace(ctxBlock) != "" { + ctxBlock = "## Memories (scrubbed)\n" + ctxBlock + } + hookLog("AfterMemoryLoad", fmt.Sprintf("group=%s context len=%d", in.RunMeta.HooksGroup, len(ctxBlock))) + return agent.AfterMemoryLoadHookOutput{PromptContext: ctxBlock}, nil + }, + }, + BeforeMemoryStore: []agent.BeforeMemoryStoreHook{ + func(ctx context.Context, in agent.BeforeMemoryStoreHookInput) (agent.BeforeMemoryStoreHookOutput, error) { + if in.Scope.TenantID == "" { + return agent.BeforeMemoryStoreHookOutput{}, fmt.Errorf("tenant ID required for memory store") + } + rec := in.Record + rec.Text = redactPII(rec.Text) + hookLog("BeforeMemoryStore", fmt.Sprintf("group=%s tenant=%s text scrubbed", in.RunMeta.HooksGroup, in.Scope.TenantID)) + return agent.BeforeMemoryStoreHookOutput{Record: rec, ID: in.ID}, nil + }, + }, + AfterMemoryStore: []agent.AfterMemoryStoreHook{ + func(ctx context.Context, in agent.AfterMemoryStoreHookInput) (agent.AfterMemoryStoreHookOutput, error) { + hookLog("AfterMemoryStore", fmt.Sprintf("group=%s id=%s", in.RunMeta.HooksGroup, in.ID)) + return agent.AfterMemoryStoreHookOutput{}, nil + }, + }, + }), + agent.WithHooks("audit", agent.AgentHooks{ + AfterLLM: []agent.AfterLLMHook{ + func(ctx context.Context, in 
agent.AfterLLMHookInput) (agent.AfterLLMHookOutput, error) { + tokens := int64(0) + if in.Response.Usage != nil { + tokens = in.Response.Usage.TotalTokens + } + hookLog("AfterLLM", fmt.Sprintf("group=%s run=%s iter=%d tokens=%d", in.RunMeta.HooksGroup, in.RunMeta.RunID, in.RunMeta.Iteration, tokens)) + return agent.AfterLLMHookOutput{Response: in.Response}, nil + }, + }, + }), + } +} + +func beforeLLMRedact(_ context.Context, in agent.BeforeLLMHookInput) (agent.BeforeLLMHookOutput, error) { + req := in.Request + req.SystemMessage = redactPII(req.SystemMessage) + for i := range req.Messages { + req.Messages[i].Content = redactPII(req.Messages[i].Content) + } + hookLog("BeforeLLM", fmt.Sprintf("group=%s run=%s iter=%d messages scrubbed", in.RunMeta.HooksGroup, in.RunMeta.RunID, in.RunMeta.Iteration)) + return agent.BeforeLLMHookOutput{Request: req}, nil +} + +func afterLLMRedact(_ context.Context, in agent.AfterLLMHookInput) (agent.AfterLLMHookOutput, error) { + resp := in.Response + resp.Content = redactPII(resp.Content) + hookLog("AfterLLM", fmt.Sprintf("group=%s response scrubbed", in.RunMeta.HooksGroup)) + return agent.AfterLLMHookOutput{Response: resp}, nil +} + +func cloneArgs(args map[string]any) map[string]any { + if len(args) == 0 { + return nil + } + out := make(map[string]any, len(args)) + for k, v := range args { + out[k] = v + } + return out +} diff --git a/examples/agent_with_hooks/main.go b/examples/agent_with_hooks/main.go new file mode 100644 index 0000000..201457c --- /dev/null +++ b/examples/agent_with_hooks/main.go @@ -0,0 +1,112 @@ +// Example agent demonstrating all middleware hook points. +// +// Run from examples/: +// +// go run ./agent_with_hooks +// go run ./agent_with_hooks "My email is alice@example.com. What is the return policy?" +// +// Hook activity is printed to stderr with a [hooks] prefix. When AGENT_RUNTIME=temporal, +// register the same [HookOptions] on both the agent starter and the worker process. 
+package main + +import ( + "context" + "fmt" + "log" + "os" + "strings" + + config "github.com/agenticenv/agent-sdk-go/examples" + "github.com/agenticenv/agent-sdk-go/examples/shared" + "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/memory" + "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" + "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" +) + +const demoTenantID = "tenant-demo" +const demoUserID = "user-demo" + +func main() { + cfg := config.LoadFromEnv() + + llmClient, err := config.NewLLMClientFromConfig(cfg) + if err != nil { + log.Fatalf("failed to create LLM client: %v", err) + } + + memStore := newDemoMemory() + memCfg := memory.DefaultConfig(memStore) + memCfg.Store.Mode = memory.StoreModeAlways + memCfg.Store.Extract = demoMemoryExtract + memCfg.Recall.Enabled = true + memCfg.Recall.Limit = 5 + + opts := []agent.Option{ + agent.WithName("agent-with-hooks"), + agent.WithDescription("Demonstrates all agent middleware hooks"), + agent.WithSystemPrompt( + "You are a helpful assistant. Use retrieved knowledge when present. " + + "Use tools when asked for calculations or echo. " + + "Answer concisely in bullet points when the user asks about preferences.", + ), + agent.WithLLMClient(llmClient), + agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), + agent.WithTools(echo.New(), calculator.New()), + agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), + agent.WithRetrievers(demoRetriever{}), + agent.WithRetrieverMode(agent.RetrieverModePrefetch), + agent.WithMemory(memCfg), + } + opts = append(opts, config.RuntimeOption(cfg)...) + opts = append(opts, config.ToolApprovalOptions()...) + opts = append(opts, HookOptions()...) + + a, err := agent.NewAgent(opts...) 
+ if err != nil { + log.Fatal(config.FormatNewAgentError("failed to create agent", err)) + } + defer a.Close() + + ctx := memory.WithContextUserID( + memory.WithContextTenantID(context.Background(), demoTenantID), + demoUserID, + ) + + args := os.Args[1:] + if len(args) == 0 { + runDemo(ctx, a) + return + } + runOnce(ctx, a, "custom", strings.Join(args, " ")) +} + +func runDemo(ctx context.Context, a *agent.Agent) { + fmt.Println("=== agent_with_hooks demo (two runs) ===") + fmt.Println("Hook log lines go to stderr. Look for [hooks] BeforeLLM, AfterRetrieve, BeforeMemoryStore, etc.") + fmt.Println() + + runOnce(ctx, a, "run 1 (store + prefetch + tools)", + "My email is alice@example.com. "+ + "What is the return policy according to the knowledge base? "+ + "Echo the phrase hooks-demo-ok and compute 12 * 8.", + ) + + runOnce(ctx, a, "run 2 (memory recall)", + "What answer style do I prefer?", + ) +} + +func runOnce(ctx context.Context, a *agent.Agent, label, prompt string) { + fmt.Printf("\n--- %s ---\n", label) + fmt.Println("user:", prompt) + fmt.Fprintln(os.Stderr, "--- hook activity ---") + + result, err := a.Run(ctx, prompt, nil) + if err != nil { + log.Printf("%s failed: %v", label, err) + return + } + fmt.Println("assistant:", result.Content) + shared.PrintRunFooters(result) +} diff --git a/examples/agent_with_hooks/stubs.go b/examples/agent_with_hooks/stubs.go new file mode 100644 index 0000000..32b8ca7 --- /dev/null +++ b/examples/agent_with_hooks/stubs.go @@ -0,0 +1,150 @@ +package main + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/google/uuid" +) + +// demoRetriever returns fixed knowledge-base documents for prefetch / agentic retrieval demos. 
+type demoRetriever struct{} + +func (demoRetriever) Name() string { return "demo-kb" } + +func (demoRetriever) Search(_ context.Context, query string) ([]interfaces.Document, error) { + _ = query + return []interfaces.Document{ + { + Content: "Return policy: full refund within 30 days. Customer record includes SSN 123-45-6789.", + Source: "kb/returns", + Score: 0.92, + Metadata: map[string]any{"section": "returns"}, + }, + { + Content: "Shipping is free on orders over $50.", + Source: "kb/shipping", + Score: 0.81, + Metadata: map[string]any{"section": "shipping"}, + }, + }, nil +} + +// demoMemory is a minimal in-process [interfaces.Memory] for examples (no external DB). +type demoMemory struct { + mu sync.RWMutex + records map[string]demoMemRecord +} + +type demoMemRecord struct { + entry interfaces.MemoryEntry +} + +func newDemoMemory() *demoMemory { + return &demoMemory{records: make(map[string]demoMemRecord)} +} + +func (m *demoMemory) Store(_ context.Context, scope interfaces.MemoryScope, rec interfaces.MemoryRecord, opts ...interfaces.StoreMemoryOption) (string, error) { + storeOpts := interfaces.StoreMemoryOptions{} + for _, opt := range opts { + opt(&storeOpts) + } + id := strings.TrimSpace(storeOpts.ID) + if id == "" { + id = uuid.NewString() + } + + now := time.Now().UTC() + m.mu.Lock() + defer m.mu.Unlock() + + createdAt := now + if existing, ok := m.records[id]; ok { + createdAt = existing.entry.CreatedAt + } + m.records[id] = demoMemRecord{entry: interfaces.MemoryEntry{ + ID: id, + Text: rec.Text, + Kind: rec.Kind, + Scope: scope, + Metadata: rec.Metadata, + ExpiresAt: rec.ExpiresAt, + CreatedAt: createdAt, + UpdatedAt: now, + }} + return id, nil +} + +func (m *demoMemory) Load(_ context.Context, scope interfaces.MemoryScope, query string, opts ...interfaces.LoadMemoryOption) ([]interfaces.MemoryEntry, error) { + loadOpts := interfaces.LoadMemoryOptions{} + for _, opt := range opts { + opt(&loadOpts) + } + limit := loadOpts.Limit + if limit <= 0 { + 
limit = 10 + } + + m.mu.RLock() + defer m.mu.RUnlock() + + query = strings.ToLower(strings.TrimSpace(query)) + recencyOnly := query == "" + var out []interfaces.MemoryEntry + for _, rec := range m.records { + entry := rec.entry + if entry.Expired() { + continue + } + if !demoScopeMatches(entry.Scope, scope) { + continue + } + if query != "" && !strings.Contains(strings.ToLower(entry.Text), query) { + continue + } + if !recencyOnly { + entry.Score = 1.0 + } + out = append(out, entry) + } + if len(out) > limit { + out = out[:limit] + } + return out, nil +} + +func (m *demoMemory) Clear(_ context.Context, scope interfaces.MemoryScope) error { + m.mu.Lock() + defer m.mu.Unlock() + for id, rec := range m.records { + if demoScopeMatches(rec.entry.Scope, scope) { + delete(m.records, id) + } + } + return nil +} + +func demoScopeMatches(stored, filter interfaces.MemoryScope) bool { + if filter.UserID != "" && stored.UserID != filter.UserID { + return false + } + if filter.TenantID != "" && stored.TenantID != filter.TenantID { + return false + } + if filter.AgentID != "" && stored.AgentID != filter.AgentID { + return false + } + return true +} + +// demoMemoryExtract returns a deterministic memory for run-end store (no extra LLM call). +func demoMemoryExtract(_ context.Context, _ []interfaces.Message) ([]interfaces.MemoryRecord, error) { + return []interfaces.MemoryRecord{{ + Text: "User prefers bullet-point answers. Contact email on file: alice@example.com", + Kind: "preference", + Metadata: map[string]string{"source": "extract"}, + }}, nil +} diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go new file mode 100644 index 0000000..1357182 --- /dev/null +++ b/internal/hooks/hooks.go @@ -0,0 +1,101 @@ +// Package hooks defines agent middleware hook types used by the SDK runtime. +// Register hooks via [github.com/agenticenv/agent-sdk-go/pkg/agent.WithHooks]; types are re-exported +// from the agent package for application use. 
+package hooks + +// AgentHooks defines middleware hooks that fire at key points in the agent execution lifecycle. +// Multiple hooks can be registered per execution point and are chained in declaration order. +// Each hook receives the (possibly modified) output of the previous hook in the chain. +// Any hook returning an error aborts the remaining chain and halts execution. +// +// Common use cases by category: +// +// - Guardrails and safety: block bad inputs, prompt-injection checks ([BeforeLLMHook]); +// output filtering ([AfterLLMHook]). +// - PII and privacy: scrub prompts and responses ([BeforeLLMHook], [AfterLLMHook]); tool args +// ([BeforeToolHook], [AfterToolHook]); retrieval query and documents ([BeforeRetrieveHook], +// [AfterRetrieveHook]); memory load ([BeforeMemoryLoadHook], [AfterMemoryLoadHook]) and +// store ([BeforeMemoryStoreHook], [AfterMemoryStoreHook]). +// - Cost and caching: token tracking and budgets ([AfterLLMHook], [BeforeLLMHook]); LLM response +// cache read/write ([BeforeLLMHook], [AfterLLMHook]). +// - Rate limiting and validation: per-tool input scrubbing and rate limits ([BeforeToolHook]); +// retrieval query validation ([BeforeRetrieveHook]). +// - Logging and audit: LLM, tool, retrieval, and memory operations (before/after hooks on +// each area). +// - Resilience: model fallback ([AfterLLMHook]); retrieval retry or re-rank ([AfterRetrieveHook]). +// - Memory control: scope filtering and tenant isolation ([BeforeMemoryLoadHook], +// [BeforeMemoryStoreHook]); inspect injected context ([AfterMemoryLoadHook]). +type AgentHooks struct { + // LLM hooks fire on every model call. + // BeforeLLM — guardrails, PII redaction, prompt injection detection, caching, input validation + // AfterLLM — cost tracking, PII scrubbing, fallback model swap, cache store, token budget enforcement + BeforeLLM []BeforeLLMHook + AfterLLM []AfterLLMHook + + // Tool hooks fire for native and MCP tools only, after authorization and approval. 
+ // BeforeTool — input scrubbing, rate limiting, arg mutation + // AfterTool — result scrubbing, logging, result transformation + BeforeTool []BeforeToolHook + AfterTool []AfterToolHook + + // Retrieve hooks fire for both prefetch and agentic RAG paths. + // BeforeRetrieve — query rewriting, PII scrubbing, query validation + // AfterRetrieve — result filtering, re-ranking, result logging + BeforeRetrieve []BeforeRetrieveHook + AfterRetrieve []AfterRetrieveHook + + // Memory hooks fire on memory read and write operations. + // BeforeMemoryLoad — query/load-option mutation; scope is read-only context on input + // AfterMemoryLoad — filter or rewrite prompt context injected into the LLM + // BeforeMemoryStore — scrub PII before persisting, control what gets stored + // AfterMemoryStore — audit persisted memories, logging + BeforeMemoryLoad []BeforeMemoryLoadHook + AfterMemoryLoad []AfterMemoryLoadHook + BeforeMemoryStore []BeforeMemoryStoreHook + AfterMemoryStore []AfterMemoryStoreHook +} + +// HookGroup is a named set of middleware hooks registered via +// [github.com/agenticenv/agent-sdk-go/pkg/agent.WithHooks]. +type HookGroup struct { + // Name is the unique hook group identifier used for Temporal fingerprinting and [RunMeta]. + Name string + + // Hooks are the middleware functions in this group. + Hooks AgentHooks +} + +// RunMeta carries read-only execution context shared across hooks in a run. +// Hooks must not modify RunMeta; the runtime populates it when firing hooks. +type RunMeta struct { + // RunID is the stable identifier for the current agent run. + RunID string + + // Iteration is the zero-based LLM loop round (0 for the first model call). + Iteration int + + // HooksGroup is the [github.com/agenticenv/agent-sdk-go/pkg/agent.WithHooks] group name for the + // hook currently executing. The runtime sets this from the validated group name when firing hooks. + HooksGroup string +} + +// Merge combines two AgentHooks by appending each hook slice in order. 
+// Hooks from other are appended after hooks already present in h. +// Nil or empty slices in either value are ignored by append. +func (h AgentHooks) Merge(other AgentHooks) AgentHooks { + return AgentHooks{ + BeforeLLM: append(h.BeforeLLM, other.BeforeLLM...), + AfterLLM: append(h.AfterLLM, other.AfterLLM...), + + BeforeTool: append(h.BeforeTool, other.BeforeTool...), + AfterTool: append(h.AfterTool, other.AfterTool...), + + BeforeRetrieve: append(h.BeforeRetrieve, other.BeforeRetrieve...), + AfterRetrieve: append(h.AfterRetrieve, other.AfterRetrieve...), + + BeforeMemoryLoad: append(h.BeforeMemoryLoad, other.BeforeMemoryLoad...), + AfterMemoryLoad: append(h.AfterMemoryLoad, other.AfterMemoryLoad...), + BeforeMemoryStore: append(h.BeforeMemoryStore, other.BeforeMemoryStore...), + AfterMemoryStore: append(h.AfterMemoryStore, other.AfterMemoryStore...), + } +} diff --git a/internal/hooks/hooks_test.go b/internal/hooks/hooks_test.go new file mode 100644 index 0000000..91bbb8f --- /dev/null +++ b/internal/hooks/hooks_test.go @@ -0,0 +1,473 @@ +package hooks + +import ( + "context" + "errors" + "testing" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +func TestAgentHooks_Merge_Empty(t *testing.T) { + got := (AgentHooks{}).Merge(AgentHooks{}) + if !hooksEmpty(got) { + t.Fatal("expected empty merged hooks") + } +} + +func TestAgentHooks_Merge_AppendsInOrder(t *testing.T) { + h1 := func(context.Context, BeforeLLMHookInput) (BeforeLLMHookOutput, error) { + return BeforeLLMHookOutput{}, nil + } + h2 := func(context.Context, BeforeLLMHookInput) (BeforeLLMHookOutput, error) { + return BeforeLLMHookOutput{}, nil + } + h3 := func(context.Context, AfterToolHookInput) (AfterToolHookOutput, error) { + return AfterToolHookOutput{}, nil + } + + base := AgentHooks{BeforeLLM: []BeforeLLMHook{h1}} + other := AgentHooks{ + BeforeLLM: []BeforeLLMHook{h2}, + AfterTool: []AfterToolHook{h3}, + } + got := base.Merge(other) 
+ + if len(got.BeforeLLM) != 2 { + t.Fatalf("BeforeLLM len = %d, want 2", len(got.BeforeLLM)) + } + if got.BeforeLLM[0] == nil || got.BeforeLLM[1] == nil { + t.Fatal("expected non-nil hook funcs") + } + if len(got.AfterTool) != 1 { + t.Fatalf("AfterTool len = %d, want 1", len(got.AfterTool)) + } + if len(got.BeforeTool) != 0 { + t.Fatal("expected unrelated hook slices to stay empty") + } +} + +func TestAgentHooks_Merge_NilSlices(t *testing.T) { + h := func(context.Context, BeforeMemoryLoadHookInput) (BeforeMemoryLoadHookOutput, error) { + return BeforeMemoryLoadHookOutput{}, nil + } + got := AgentHooks{BeforeMemoryLoad: []BeforeMemoryLoadHook{h}}.Merge(AgentHooks{}) + if len(got.BeforeMemoryLoad) != 1 { + t.Fatalf("BeforeMemoryLoad len = %d, want 1", len(got.BeforeMemoryLoad)) + } +} + +func hooksEmpty(h AgentHooks) bool { + return len(h.BeforeLLM) == 0 && + len(h.AfterLLM) == 0 && + len(h.BeforeTool) == 0 && + len(h.AfterTool) == 0 && + len(h.BeforeRetrieve) == 0 && + len(h.AfterRetrieve) == 0 && + len(h.BeforeMemoryLoad) == 0 && + len(h.AfterMemoryLoad) == 0 && + len(h.BeforeMemoryStore) == 0 && + len(h.AfterMemoryStore) == 0 +} + +func TestRunBeforeLLM_chainAndGroupOrder(t *testing.T) { + var order []string + groups := []HookGroup{ + { + Name: "guardrails", + Hooks: AgentHooks{BeforeLLM: []BeforeLLMHook{ + func(_ context.Context, in BeforeLLMHookInput) (BeforeLLMHookOutput, error) { + order = append(order, "g1:"+in.RunMeta.HooksGroup) + out := in.Request + out.SystemMessage = "step1" + return BeforeLLMHookOutput{Request: out}, nil + }, + }}, + }, + { + Name: "audit", + Hooks: AgentHooks{BeforeLLM: []BeforeLLMHook{ + func(_ context.Context, in BeforeLLMHookInput) (BeforeLLMHookOutput, error) { + order = append(order, "g2:"+in.RunMeta.HooksGroup+":"+in.Request.SystemMessage) + out := in.Request + out.SystemMessage = "step2" + return BeforeLLMHookOutput{Request: out}, nil + }, + }}, + }, + } + meta := RunMeta{RunID: "run-1", Iteration: 3} + got, err := 
RunBeforeLLM(context.Background(), groups, meta, interfaces.LLMRequest{SystemMessage: "orig"}) + if err != nil { + t.Fatal(err) + } + if got.SystemMessage != "step2" { + t.Fatalf("SystemMessage = %q, want step2", got.SystemMessage) + } + if len(order) != 2 || order[0] != "g1:guardrails" || order[1] != "g2:audit:step1" { + t.Fatalf("order = %v", order) + } +} + +func TestRunBeforeLLM_errorAborts(t *testing.T) { + boom := errors.New("blocked") + groups := []HookGroup{{ + Name: "guardrails", + Hooks: AgentHooks{BeforeLLM: []BeforeLLMHook{ + func(context.Context, BeforeLLMHookInput) (BeforeLLMHookOutput, error) { + return BeforeLLMHookOutput{}, boom + }, + func(context.Context, BeforeLLMHookInput) (BeforeLLMHookOutput, error) { + t.Fatal("second hook should not run after error") + return BeforeLLMHookOutput{}, nil + }, + }}, + }} + _, err := RunBeforeLLM(context.Background(), groups, RunMeta{}, interfaces.LLMRequest{}) + if !errors.Is(err, boom) { + t.Fatalf("err = %v, want boom", err) + } +} + +func TestRunAfterLLM_modifiesResponse(t *testing.T) { + groups := []HookGroup{{ + Name: "scrub", + Hooks: AgentHooks{AfterLLM: []AfterLLMHook{ + func(_ context.Context, in AfterLLMHookInput) (AfterLLMHookOutput, error) { + out := in.Response + out.Content = "scrubbed" + return AfterLLMHookOutput{Response: out}, nil + }, + }}, + }} + got, err := RunAfterLLM(context.Background(), groups, RunMeta{HooksGroup: "unused"}, interfaces.LLMResponse{Content: "raw"}) + if err != nil { + t.Fatal(err) + } + if got.Content != "scrubbed" { + t.Fatalf("Content = %q", got.Content) + } +} + +func TestRunBeforeTool_skipsNonEligibleKind(t *testing.T) { + var called bool + groups := []HookGroup{{ + Name: "guard", + Hooks: AgentHooks{BeforeTool: []BeforeToolHook{ + func(context.Context, BeforeToolHookInput) (BeforeToolHookOutput, error) { + called = true + return BeforeToolHookOutput{}, nil + }, + }}, + }} + call := ToolCall{ID: "tc", Name: "retriever", DisplayName: "R", Kind: 
types.ToolKindRetriever} + got, err := RunBeforeTool(context.Background(), groups, RunMeta{}, call) + if err != nil { + t.Fatal(err) + } + if called { + t.Fatal("hooks should not run for retriever kind") + } + if got.Name != "retriever" { + t.Fatalf("call = %#v", got) + } +} + +func TestRunBeforeTool_chainOrder(t *testing.T) { + var order []string + groups := []HookGroup{ + {Name: "a", Hooks: AgentHooks{BeforeTool: []BeforeToolHook{ + func(_ context.Context, in BeforeToolHookInput) (BeforeToolHookOutput, error) { + order = append(order, "a:"+in.RunMeta.HooksGroup) + return BeforeToolHookOutput{Args: map[string]any{"step": "1"}}, nil + }, + }}}, + {Name: "b", Hooks: AgentHooks{BeforeTool: []BeforeToolHook{ + func(_ context.Context, in BeforeToolHookInput) (BeforeToolHookOutput, error) { + order = append(order, "b:"+in.RunMeta.HooksGroup) + return BeforeToolHookOutput{Args: in.Call.Args}, nil + }, + }}}, + } + call := ToolCall{ID: "tc", Name: "tool", DisplayName: "Tool", Kind: types.ToolKindNative, Args: map[string]any{"step": "0"}} + got, err := RunBeforeTool(context.Background(), groups, RunMeta{}, call) + if err != nil { + t.Fatal(err) + } + if got.Args["step"] != "1" { + t.Fatalf("args = %#v", got.Args) + } + if got.Name != "tool" || got.Kind != types.ToolKindNative { + t.Fatalf("read-only fields changed: %#v", got) + } + if len(order) != 2 || order[0] != "a:a" || order[1] != "b:b" { + t.Fatalf("order = %v", order) + } +} + +func TestRunBeforeRetrieve_chainOrder(t *testing.T) { + var order []string + groups := []HookGroup{ + {Name: "a", Hooks: AgentHooks{BeforeRetrieve: []BeforeRetrieveHook{ + func(_ context.Context, in BeforeRetrieveHookInput) (BeforeRetrieveHookOutput, error) { + order = append(order, "a:"+in.RunMeta.HooksGroup) + return BeforeRetrieveHookOutput{Query: "step1"}, nil + }, + }}}, + {Name: "b", Hooks: AgentHooks{BeforeRetrieve: []BeforeRetrieveHook{ + func(_ context.Context, in BeforeRetrieveHookInput) (BeforeRetrieveHookOutput, error) { + order = 
append(order, "b:"+in.RunMeta.HooksGroup) + return BeforeRetrieveHookOutput{Query: "step2"}, nil + }, + }}}, + } + q, err := RunBeforeRetrieve(context.Background(), groups, RunMeta{}, RetrieveCall{Query: "start", Mode: types.RetrieverModePrefetch, RetrieverName: "kb"}) + if err != nil { + t.Fatal(err) + } + if q.Query != "step2" { + t.Fatalf("query = %q", q.Query) + } + if q.Mode != types.RetrieverModePrefetch || q.RetrieverName != "kb" { + t.Fatalf("read-only fields changed: %#v", q) + } + if len(order) != 2 || order[0] != "a:a" || order[1] != "b:b" { + t.Fatalf("order = %v", order) + } +} + +func TestRunBeforeRetrieve_errorAborts(t *testing.T) { + groups := []HookGroup{{ + Name: "block", + Hooks: AgentHooks{BeforeRetrieve: []BeforeRetrieveHook{ + func(context.Context, BeforeRetrieveHookInput) (BeforeRetrieveHookOutput, error) { + return BeforeRetrieveHookOutput{}, errors.New("blocked") + }, + }}, + }} + _, err := RunBeforeRetrieve(context.Background(), groups, RunMeta{}, RetrieveCall{Query: "q", Mode: types.RetrieverModePrefetch}) + if err == nil || err.Error() != "blocked" { + t.Fatalf("err = %v", err) + } +} + +func TestRunAfterRetrieve_modifiesDocuments(t *testing.T) { + groups := []HookGroup{{ + Name: "filter", + Hooks: AgentHooks{AfterRetrieve: []AfterRetrieveHook{ + func(context.Context, AfterRetrieveHookInput) (AfterRetrieveHookOutput, error) { + return AfterRetrieveHookOutput{Documents: []interfaces.Document{{Content: "kept"}}}, nil + }, + }}, + }} + docs, err := RunAfterRetrieve(context.Background(), groups, RunMeta{}, RetrieveCall{Query: "q", Mode: types.RetrieverModeAgentic, RetrieverName: "kb"}, + []interfaces.Document{{Content: "drop"}}) + if err != nil { + t.Fatal(err) + } + if len(docs) != 1 || docs[0].Content != "kept" { + t.Fatalf("docs = %#v", docs) + } +} + +func TestRunAfterRetrieve_errorAborts(t *testing.T) { + groups := []HookGroup{{ + Name: "block", + Hooks: AgentHooks{AfterRetrieve: []AfterRetrieveHook{ + func(context.Context, 
AfterRetrieveHookInput) (AfterRetrieveHookOutput, error) { + return AfterRetrieveHookOutput{}, errors.New("blocked") + }, + }}, + }} + _, err := RunAfterRetrieve(context.Background(), groups, RunMeta{}, RetrieveCall{Query: "q", Mode: types.RetrieverModePrefetch, RetrieverName: "kb"}, + []interfaces.Document{{Content: "x"}}) + if err == nil || err.Error() != "blocked" { + t.Fatalf("err = %v", err) + } +} + +func TestRunBeforeMemoryLoad_chainOrder(t *testing.T) { + var order []string + scope := interfaces.MemoryScope{UserID: "u1"} + groups := []HookGroup{ + {Name: "a", Hooks: AgentHooks{BeforeMemoryLoad: []BeforeMemoryLoadHook{ + func(_ context.Context, in BeforeMemoryLoadHookInput) (BeforeMemoryLoadHookOutput, error) { + order = append(order, "a:"+in.RunMeta.HooksGroup) + return BeforeMemoryLoadHookOutput{Query: "step1", Limit: 3}, nil + }, + }}}, + {Name: "b", Hooks: AgentHooks{BeforeMemoryLoad: []BeforeMemoryLoadHook{ + func(_ context.Context, in BeforeMemoryLoadHookInput) (BeforeMemoryLoadHookOutput, error) { + order = append(order, "b:"+in.RunMeta.HooksGroup+":"+in.Query) + return BeforeMemoryLoadHookOutput{Query: "step2", Limit: in.Limit, MinScore: 0.9}, nil + }, + }}}, + } + got, err := RunBeforeMemoryLoad(context.Background(), groups, RunMeta{}, MemoryLoadCall{ + Scope: scope, Query: "start", Limit: 5, + }) + if err != nil { + t.Fatal(err) + } + if got.Query != "step2" || got.Limit != 3 || got.MinScore != 0.9 { + t.Fatalf("call = %#v", got) + } + if got.Scope.UserID != "u1" { + t.Fatal("scope should be read-only") + } + if len(order) != 2 || order[0] != "a:a" || order[1] != "b:b:step1" { + t.Fatalf("order = %v", order) + } +} + +func TestRunBeforeMemoryLoad_errorAborts(t *testing.T) { + groups := []HookGroup{{ + Name: "block", + Hooks: AgentHooks{BeforeMemoryLoad: []BeforeMemoryLoadHook{ + func(context.Context, BeforeMemoryLoadHookInput) (BeforeMemoryLoadHookOutput, error) { + return BeforeMemoryLoadHookOutput{}, errors.New("blocked") + }, + }}, + }} + _, err 
:= RunBeforeMemoryLoad(context.Background(), groups, RunMeta{}, MemoryLoadCall{Query: "q"}) + if err == nil || err.Error() != "blocked" { + t.Fatalf("err = %v", err) + } +} + +func TestRunAfterMemoryLoad_modifiesPromptContext(t *testing.T) { + groups := []HookGroup{{ + Name: "filter", + Hooks: AgentHooks{AfterMemoryLoad: []AfterMemoryLoadHook{ + func(context.Context, AfterMemoryLoadHookInput) (AfterMemoryLoadHookOutput, error) { + return AfterMemoryLoadHookOutput{PromptContext: "filtered"}, nil + }, + }}, + }} + ctx, err := RunAfterMemoryLoad(context.Background(), groups, RunMeta{}, MemoryLoadCall{Query: "q"}, "raw") + if err != nil { + t.Fatal(err) + } + if ctx != "filtered" { + t.Fatalf("PromptContext = %q", ctx) + } +} + +func TestRunAfterMemoryLoad_errorAborts(t *testing.T) { + groups := []HookGroup{{ + Name: "block", + Hooks: AgentHooks{AfterMemoryLoad: []AfterMemoryLoadHook{ + func(context.Context, AfterMemoryLoadHookInput) (AfterMemoryLoadHookOutput, error) { + return AfterMemoryLoadHookOutput{}, errors.New("blocked") + }, + }}, + }} + _, err := RunAfterMemoryLoad(context.Background(), groups, RunMeta{}, MemoryLoadCall{Query: "q"}, "ctx") + if err == nil || err.Error() != "blocked" { + t.Fatalf("err = %v", err) + } +} + +func TestRunBeforeMemoryStore_chainOrder(t *testing.T) { + var order []string + scope := interfaces.MemoryScope{AgentID: "a1"} + groups := []HookGroup{ + {Name: "scrub", Hooks: AgentHooks{BeforeMemoryStore: []BeforeMemoryStoreHook{ + func(_ context.Context, in BeforeMemoryStoreHookInput) (BeforeMemoryStoreHookOutput, error) { + order = append(order, in.RunMeta.HooksGroup) + return BeforeMemoryStoreHookOutput{ + Record: interfaces.MemoryRecord{Text: "scrubbed"}, + ID: "id-1", + }, nil + }, + }}}, + {Name: "audit", Hooks: AgentHooks{BeforeMemoryStore: []BeforeMemoryStoreHook{ + func(_ context.Context, in BeforeMemoryStoreHookInput) (BeforeMemoryStoreHookOutput, error) { + order = append(order, in.RunMeta.HooksGroup+":"+in.Record.Text) + return 
BeforeMemoryStoreHookOutput{ + Record: interfaces.MemoryRecord{Text: in.Record.Text + "-final"}, + ID: "id-2", + }, nil + }, + }}}, + } + got, err := RunBeforeMemoryStore(context.Background(), groups, RunMeta{}, MemoryStoreCall{ + Scope: scope, Record: interfaces.MemoryRecord{Text: "raw"}, ID: "orig", + }) + if err != nil { + t.Fatal(err) + } + if got.Record.Text != "scrubbed-final" || got.ID != "id-2" { + t.Fatalf("call = %#v", got) + } + if got.Scope.AgentID != "a1" { + t.Fatal("scope should be read-only") + } + if len(order) != 2 || order[0] != "scrub" || order[1] != "audit:scrubbed" { + t.Fatalf("order = %v", order) + } +} + +func TestRunBeforeMemoryStore_errorAborts(t *testing.T) { + groups := []HookGroup{{ + Name: "block", + Hooks: AgentHooks{BeforeMemoryStore: []BeforeMemoryStoreHook{ + func(context.Context, BeforeMemoryStoreHookInput) (BeforeMemoryStoreHookOutput, error) { + return BeforeMemoryStoreHookOutput{}, errors.New("blocked") + }, + }}, + }} + _, err := RunBeforeMemoryStore(context.Background(), groups, RunMeta{}, MemoryStoreCall{ + Record: interfaces.MemoryRecord{Text: "x"}, + }) + if err == nil || err.Error() != "blocked" { + t.Fatalf("err = %v", err) + } +} + +func TestRunAfterMemoryStore_runsInOrder(t *testing.T) { + var order []string + groups := []HookGroup{ + {Name: "a", Hooks: AgentHooks{AfterMemoryStore: []AfterMemoryStoreHook{ + func(_ context.Context, in AfterMemoryStoreHookInput) (AfterMemoryStoreHookOutput, error) { + order = append(order, "a:"+in.RunMeta.HooksGroup) + return AfterMemoryStoreHookOutput{}, nil + }, + }}}, + {Name: "b", Hooks: AgentHooks{AfterMemoryStore: []AfterMemoryStoreHook{ + func(_ context.Context, in AfterMemoryStoreHookInput) (AfterMemoryStoreHookOutput, error) { + order = append(order, "b:"+in.ID) + return AfterMemoryStoreHookOutput{}, nil + }, + }}}, + } + call := MemoryStoreCall{ + Scope: interfaces.MemoryScope{UserID: "u1"}, + Record: interfaces.MemoryRecord{Text: "stored"}, + ID: "mem-1", + } + if err := 
RunAfterMemoryStore(context.Background(), groups, RunMeta{}, call); err != nil { + t.Fatal(err) + } + if len(order) != 2 || order[0] != "a:a" || order[1] != "b:mem-1" { + t.Fatalf("order = %v", order) + } +} + +func TestRunAfterMemoryStore_errorAborts(t *testing.T) { + groups := []HookGroup{{ + Name: "block", + Hooks: AgentHooks{AfterMemoryStore: []AfterMemoryStoreHook{ + func(context.Context, AfterMemoryStoreHookInput) (AfterMemoryStoreHookOutput, error) { + return AfterMemoryStoreHookOutput{}, errors.New("blocked") + }, + }}, + }} + err := RunAfterMemoryStore(context.Background(), groups, RunMeta{}, MemoryStoreCall{ + Record: interfaces.MemoryRecord{Text: "x"}, ID: "id", + }) + if err == nil || err.Error() != "blocked" { + t.Fatalf("err = %v", err) + } +} diff --git a/internal/hooks/llm.go b/internal/hooks/llm.go new file mode 100644 index 0000000..90c879b --- /dev/null +++ b/internal/hooks/llm.go @@ -0,0 +1,87 @@ +package hooks + +import ( + "context" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +// BeforeLLMHookInput is the payload passed to [BeforeLLMHook] before an LLM call. +type BeforeLLMHookInput struct { + RunMeta RunMeta + Request interfaces.LLMRequest +} + +// BeforeLLMHookOutput is the mutable result returned from [BeforeLLMHook]. +type BeforeLLMHookOutput struct { + Request interfaces.LLMRequest +} + +// AfterLLMHookInput is the payload passed to [AfterLLMHook] after an LLM call completes. +type AfterLLMHookInput struct { + RunMeta RunMeta + Response interfaces.LLMResponse +} + +// AfterLLMHookOutput is the mutable result returned from [AfterLLMHook]. +type AfterLLMHookOutput struct { + Response interfaces.LLMResponse +} + +// BeforeLLMHook runs before each LLM request is sent. Return a modified request or an error to abort the run. +type BeforeLLMHook func(ctx context.Context, input BeforeLLMHookInput) (BeforeLLMHookOutput, error) + +// AfterLLMHook runs after each LLM response is received. 
Return a modified response or an error to abort the run. +type AfterLLMHook func(ctx context.Context, input AfterLLMHookInput) (AfterLLMHookOutput, error) + +// RunBeforeLLM runs all BeforeLLM hooks in hook group registration order. Hooks within a group +// run in declaration order; each hook receives the output of the previous hook. The first error +// aborts the remaining chain. Returns req unchanged when groups is empty or no BeforeLLM hooks +// are registered. +func RunBeforeLLM(ctx context.Context, groups []HookGroup, meta RunMeta, req interfaces.LLMRequest) (interfaces.LLMRequest, error) { + current := req + for _, g := range groups { + if len(g.Hooks.BeforeLLM) == 0 { + continue + } + groupMeta := meta + groupMeta.HooksGroup = g.Name + for _, hook := range g.Hooks.BeforeLLM { + if hook == nil { + continue + } + out, err := hook(ctx, BeforeLLMHookInput{RunMeta: groupMeta, Request: current}) + if err != nil { + return interfaces.LLMRequest{}, err + } + current = out.Request + } + } + return current, nil +} + +// RunAfterLLM runs all AfterLLM hooks in hook group registration order. Hooks within a group run +// in declaration order; each hook receives the output of the previous hook. The first error +// aborts the remaining chain. Returns resp unchanged when groups is empty or no AfterLLM hooks +// are registered. 
+func RunAfterLLM(ctx context.Context, groups []HookGroup, meta RunMeta, resp interfaces.LLMResponse) (interfaces.LLMResponse, error) { + current := resp + for _, g := range groups { + if len(g.Hooks.AfterLLM) == 0 { + continue + } + groupMeta := meta + groupMeta.HooksGroup = g.Name + for _, hook := range g.Hooks.AfterLLM { + if hook == nil { + continue + } + out, err := hook(ctx, AfterLLMHookInput{RunMeta: groupMeta, Response: current}) + if err != nil { + return interfaces.LLMResponse{}, err + } + current = out.Response + } + } + return current, nil +} diff --git a/internal/hooks/memory.go b/internal/hooks/memory.go new file mode 100644 index 0000000..aa63814 --- /dev/null +++ b/internal/hooks/memory.go @@ -0,0 +1,236 @@ +package hooks + +import ( + "context" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +// BeforeMemoryLoadHookInput is the payload passed to [BeforeMemoryLoadHook] before memories are loaded. +type BeforeMemoryLoadHookInput struct { + RunMeta RunMeta + Scope interfaces.MemoryScope + Query string + + // Limit is the maximum number of memories to return. Zero means backend default. + Limit int + + // MinScore filters out entries below the given relevance score when Score is applicable. + MinScore float32 + + // Kinds restricts recall to the given memory kinds. Empty means all kinds. + Kinds []interfaces.MemoryKind +} + +// BeforeMemoryLoadHookOutput is the mutable result returned from [BeforeMemoryLoadHook]. +// Only Query and load options may be changed; Scope on input is read-only context. +type BeforeMemoryLoadHookOutput struct { + Query string + Limit int + MinScore float32 + Kinds []interfaces.MemoryKind +} + +// AfterMemoryLoadHookInput is the payload passed to [AfterMemoryLoadHook] after memories are loaded. +type AfterMemoryLoadHookInput struct { + RunMeta RunMeta + Scope interfaces.MemoryScope + Query string + + // PromptContext is the formatted memory block injected into the LLM system prompt. 
+ PromptContext string +} + +// AfterMemoryLoadHookOutput is the mutable result returned from [AfterMemoryLoadHook]. +type AfterMemoryLoadHookOutput struct { + PromptContext string +} + +// BeforeMemoryStoreHookInput is the payload passed to [BeforeMemoryStoreHook] before a memory is stored. +type BeforeMemoryStoreHookInput struct { + RunMeta RunMeta + Scope interfaces.MemoryScope + Record interfaces.MemoryRecord + + // ID upserts the record when non-empty. + ID string +} + +// BeforeMemoryStoreHookOutput is the mutable result returned from [BeforeMemoryStoreHook]. +// Only Record and ID may be changed; Scope on input is read-only context. +type BeforeMemoryStoreHookOutput struct { + Record interfaces.MemoryRecord + ID string +} + +// AfterMemoryStoreHookInput is the payload passed to [AfterMemoryStoreHook] after a memory is stored. +type AfterMemoryStoreHookInput struct { + RunMeta RunMeta + Scope interfaces.MemoryScope + Record interfaces.MemoryRecord + + // ID is the record identifier assigned by the backend after a successful store. + ID string +} + +// AfterMemoryStoreHookOutput is the mutable result returned from [AfterMemoryStoreHook]. +// Store has already completed; hooks use input for audit and may abort via error only. +type AfterMemoryStoreHookOutput struct{} + +// BeforeMemoryLoadHook runs before memory load. Return modified query or load options, or an error to abort the run. +type BeforeMemoryLoadHook func(ctx context.Context, input BeforeMemoryLoadHookInput) (BeforeMemoryLoadHookOutput, error) + +// AfterMemoryLoadHook runs after memory load. Return a modified prompt context or an error to abort the run. +type AfterMemoryLoadHook func(ctx context.Context, input AfterMemoryLoadHookInput) (AfterMemoryLoadHookOutput, error) + +// BeforeMemoryStoreHook runs before memory store. Return modified record or upsert ID, or an error to abort the run. 
+type BeforeMemoryStoreHook func(ctx context.Context, input BeforeMemoryStoreHookInput) (BeforeMemoryStoreHookOutput, error) + +// AfterMemoryStoreHook runs after memory store. Return an error to abort the run. +type AfterMemoryStoreHook func(ctx context.Context, input AfterMemoryStoreHookInput) (AfterMemoryStoreHookOutput, error) + +// MemoryLoadCall is the resolved memory load invocation used by the runtime hook runner. +type MemoryLoadCall struct { + Scope interfaces.MemoryScope + Query string + Limit int + MinScore float32 + Kinds []interfaces.MemoryKind +} + +// MemoryStoreCall is the resolved memory store invocation used by the runtime hook runner. +type MemoryStoreCall struct { + Scope interfaces.MemoryScope + Record interfaces.MemoryRecord + ID string +} + +// RunBeforeMemoryLoad runs all BeforeMemoryLoad hooks in hook group registration order. +// Only Query and load options may be changed; Scope on call is read-only context. +// Returns call unchanged when groups is empty or no BeforeMemoryLoad hooks are registered. +func RunBeforeMemoryLoad(ctx context.Context, groups []HookGroup, meta RunMeta, call MemoryLoadCall) (MemoryLoadCall, error) { + current := call + for _, g := range groups { + if len(g.Hooks.BeforeMemoryLoad) == 0 { + continue + } + groupMeta := meta + groupMeta.HooksGroup = g.Name + for _, hook := range g.Hooks.BeforeMemoryLoad { + if hook == nil { + continue + } + out, err := hook(ctx, BeforeMemoryLoadHookInput{ + RunMeta: groupMeta, + Scope: current.Scope, + Query: current.Query, + Limit: current.Limit, + MinScore: current.MinScore, + Kinds: cloneKinds(current.Kinds), + }) + if err != nil { + return MemoryLoadCall{}, err + } + current.Query = out.Query + current.Limit = out.Limit + current.MinScore = out.MinScore + current.Kinds = cloneKinds(out.Kinds) + } + } + return current, nil +} + +// RunAfterMemoryLoad runs all AfterMemoryLoad hooks in hook group registration order. 
+// Scope and Query on input are read-only context; only PromptContext may change. +// Returns promptContext unchanged when groups is empty or no AfterMemoryLoad hooks are registered. +func RunAfterMemoryLoad(ctx context.Context, groups []HookGroup, meta RunMeta, call MemoryLoadCall, promptContext string) (string, error) { + currentContext := promptContext + for _, g := range groups { + if len(g.Hooks.AfterMemoryLoad) == 0 { + continue + } + groupMeta := meta + groupMeta.HooksGroup = g.Name + for _, hook := range g.Hooks.AfterMemoryLoad { + if hook == nil { + continue + } + out, err := hook(ctx, AfterMemoryLoadHookInput{ + RunMeta: groupMeta, + Scope: call.Scope, + Query: call.Query, + PromptContext: currentContext, + }) + if err != nil { + return "", err + } + currentContext = out.PromptContext + } + } + return currentContext, nil +} + +// RunBeforeMemoryStore runs all BeforeMemoryStore hooks in hook group registration order. +// Only Record and ID may be changed; Scope on call is read-only context. +// Returns call unchanged when groups is empty or no BeforeMemoryStore hooks are registered. +func RunBeforeMemoryStore(ctx context.Context, groups []HookGroup, meta RunMeta, call MemoryStoreCall) (MemoryStoreCall, error) { + current := call + for _, g := range groups { + if len(g.Hooks.BeforeMemoryStore) == 0 { + continue + } + groupMeta := meta + groupMeta.HooksGroup = g.Name + for _, hook := range g.Hooks.BeforeMemoryStore { + if hook == nil { + continue + } + out, err := hook(ctx, BeforeMemoryStoreHookInput{ + RunMeta: groupMeta, + Scope: current.Scope, + Record: current.Record, + ID: current.ID, + }) + if err != nil { + return MemoryStoreCall{}, err + } + current.Record, current.ID = out.Record, out.ID + } + } + return current, nil +} + +// RunAfterMemoryStore runs all AfterMemoryStore hooks in hook group registration order. +// Returns nil when groups is empty or no AfterMemoryStore hooks are registered. 
+func RunAfterMemoryStore(ctx context.Context, groups []HookGroup, meta RunMeta, call MemoryStoreCall) error { + for _, g := range groups { + if len(g.Hooks.AfterMemoryStore) == 0 { + continue + } + groupMeta := meta + groupMeta.HooksGroup = g.Name + for _, hook := range g.Hooks.AfterMemoryStore { + if hook == nil { + continue + } + if _, err := hook(ctx, AfterMemoryStoreHookInput{ + RunMeta: groupMeta, + Scope: call.Scope, + Record: call.Record, + ID: call.ID, + }); err != nil { + return err + } + } + } + return nil +} + +func cloneKinds(kinds []interfaces.MemoryKind) []interfaces.MemoryKind { + if len(kinds) == 0 { + return nil + } + out := make([]interfaces.MemoryKind, len(kinds)) + copy(out, kinds) + return out +} diff --git a/internal/hooks/retriever.go b/internal/hooks/retriever.go new file mode 100644 index 0000000..7cd7b21 --- /dev/null +++ b/internal/hooks/retriever.go @@ -0,0 +1,123 @@ +package hooks + +import ( + "context" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +// RetrieveCall is the resolved retrieval invocation used by the runtime hook runner. +type RetrieveCall struct { + Query string + Mode types.RetrieverMode + RetrieverName string +} + +// BeforeRetrieveHookInput is the payload passed to [BeforeRetrieveHook] before a retrieval runs. +type BeforeRetrieveHookInput struct { + RunMeta RunMeta + Query string + Mode types.RetrieverMode + + // RetrieverName is the target retriever when agentic; empty when prefetch runs all configured retrievers. + RetrieverName string +} + +// BeforeRetrieveHookOutput is the mutable result returned from [BeforeRetrieveHook]. +// Only Query may be changed; Mode and RetrieverName on input are read-only context. +type BeforeRetrieveHookOutput struct { + Query string +} + +// AfterRetrieveHookInput is the payload passed to [AfterRetrieveHook] after documents are retrieved. 
+type AfterRetrieveHookInput struct { + RunMeta RunMeta + Query string + Mode types.RetrieverMode + RetrieverName string + Documents []interfaces.Document +} + +// AfterRetrieveHookOutput is the mutable result returned from [AfterRetrieveHook]. +type AfterRetrieveHookOutput struct { + Documents []interfaces.Document +} + +// BeforeRetrieveHook runs before retrieval. Return a modified query or an error to abort the run. +type BeforeRetrieveHook func(ctx context.Context, input BeforeRetrieveHookInput) (BeforeRetrieveHookOutput, error) + +// AfterRetrieveHook runs after retrieval. Return filtered or re-ranked documents or an error to abort the run. +type AfterRetrieveHook func(ctx context.Context, input AfterRetrieveHookInput) (AfterRetrieveHookOutput, error) + +// RunBeforeRetrieve runs all BeforeRetrieve hooks in hook group registration order. Hooks within a +// group run in declaration order; each hook receives the output of the previous hook. The first +// error aborts the remaining chain. Returns call unchanged when groups is empty or no +// BeforeRetrieve hooks are registered. +func RunBeforeRetrieve(ctx context.Context, groups []HookGroup, meta RunMeta, call RetrieveCall) (RetrieveCall, error) { + current := call + for _, g := range groups { + if len(g.Hooks.BeforeRetrieve) == 0 { + continue + } + groupMeta := meta + groupMeta.HooksGroup = g.Name + for _, hook := range g.Hooks.BeforeRetrieve { + if hook == nil { + continue + } + out, err := hook(ctx, BeforeRetrieveHookInput{ + RunMeta: groupMeta, + Query: current.Query, + Mode: current.Mode, + RetrieverName: current.RetrieverName, + }) + if err != nil { + return RetrieveCall{}, err + } + current.Query = out.Query + } + } + return current, nil +} + +// RunAfterRetrieve runs all AfterRetrieve hooks in hook group registration order. Hooks within a +// group run in declaration order; each hook receives the output of the previous hook. The first +// error aborts the remaining chain. 
Returns documents unchanged when groups is empty or no +// AfterRetrieve hooks are registered. +func RunAfterRetrieve(ctx context.Context, groups []HookGroup, meta RunMeta, call RetrieveCall, documents []interfaces.Document) ([]interfaces.Document, error) { + current := documents + for _, g := range groups { + if len(g.Hooks.AfterRetrieve) == 0 { + continue + } + groupMeta := meta + groupMeta.HooksGroup = g.Name + for _, hook := range g.Hooks.AfterRetrieve { + if hook == nil { + continue + } + out, err := hook(ctx, AfterRetrieveHookInput{ + RunMeta: groupMeta, + Query: call.Query, + Mode: call.Mode, + RetrieverName: call.RetrieverName, + Documents: cloneDocuments(current), + }) + if err != nil { + return nil, err + } + current = cloneDocuments(out.Documents) + } + } + return current, nil +} + +func cloneDocuments(docs []interfaces.Document) []interfaces.Document { + if len(docs) == 0 { + return nil + } + out := make([]interfaces.Document, len(docs)) + copy(out, docs) + return out +} diff --git a/internal/hooks/tools.go b/internal/hooks/tools.go new file mode 100644 index 0000000..7988457 --- /dev/null +++ b/internal/hooks/tools.go @@ -0,0 +1,135 @@ +package hooks + +import ( + "context" + + "github.com/agenticenv/agent-sdk-go/internal/types" +) + +// ToolCall is the resolved tool invocation passed to tool hooks. +type ToolCall struct { + // ID is the tool call identifier from the LLM; used to match tool results in the conversation. + ID string + + // Name is the tool identifier the LLM selected. + Name string + + // DisplayName is the human-readable tool name when available. + DisplayName string + + // Kind classifies the tool implementation (native, MCP, sub-agent, etc.). + Kind types.ToolKind + + // Args are the arguments the LLM produced for this invocation. + Args map[string]any +} + +// BeforeToolHookInput is the payload passed to [BeforeToolHook] before a tool executes. 
+type BeforeToolHookInput struct { + RunMeta RunMeta + Call ToolCall +} + +// BeforeToolHookOutput is the mutable result returned from [BeforeToolHook]. +// Only Args may be changed; identity fields on [BeforeToolHookInput].Call are read-only context. +type BeforeToolHookOutput struct { + Args map[string]any +} + +// AfterToolHookInput is the payload passed to [AfterToolHook] after a tool executes. +type AfterToolHookInput struct { + RunMeta RunMeta + Call ToolCall + + // Content is the serialized tool result passed back to the LLM. + Content string + + // Err is set when tool execution failed. + Err error +} + +// AfterToolHookOutput is the mutable result returned from [AfterToolHook]. +type AfterToolHookOutput struct { + Content string + Err error +} + +// BeforeToolHook runs immediately before tool.Execute for native and MCP tools, after +// programmatic authorization and interactive approval have succeeded. Return modified args +// or an error to abort the run. Not a substitute for the SDK's tool authorization or approval mechanisms. +type BeforeToolHook func(ctx context.Context, input BeforeToolHookInput) (BeforeToolHookOutput, error) + +// AfterToolHook runs after a tool executes. Return a modified result or an error to abort the run. +type AfterToolHook func(ctx context.Context, input AfterToolHookInput) (AfterToolHookOutput, error) + +// RunBeforeTool runs all BeforeTool hooks in hook group registration order for hook-eligible +// tool kinds ([types.ToolKind.HooksEligible]). Hooks within a group run in declaration order. +// The first error aborts the remaining chain. 
+func RunBeforeTool(ctx context.Context, groups []HookGroup, meta RunMeta, call ToolCall) (ToolCall, error) { + if !call.Kind.HooksEligible() { + return call, nil + } + current := call + for _, g := range groups { + if len(g.Hooks.BeforeTool) == 0 { + continue + } + groupMeta := meta + groupMeta.HooksGroup = g.Name + for _, hook := range g.Hooks.BeforeTool { + if hook == nil { + continue + } + out, err := hook(ctx, BeforeToolHookInput{RunMeta: groupMeta, Call: current}) + if err != nil { + return ToolCall{}, err + } + current.Args = cloneToolArgs(out.Args) + } + } + return current, nil +} + +// RunAfterTool runs all AfterTool hooks in hook group registration order for hook-eligible tool +// kinds. Hooks within a group run in declaration order. The first error aborts the chain. +func RunAfterTool(ctx context.Context, groups []HookGroup, meta RunMeta, call ToolCall, content string, execErr error) (string, error, error) { + if !call.Kind.HooksEligible() { + return content, execErr, nil + } + currentContent := content + currentErr := execErr + for _, g := range groups { + if len(g.Hooks.AfterTool) == 0 { + continue + } + groupMeta := meta + groupMeta.HooksGroup = g.Name + for _, hook := range g.Hooks.AfterTool { + if hook == nil { + continue + } + out, err := hook(ctx, AfterToolHookInput{ + RunMeta: groupMeta, + Call: call, + Content: currentContent, + Err: currentErr, + }) + if err != nil { + return currentContent, currentErr, err + } + currentContent, currentErr = out.Content, out.Err + } + } + return currentContent, currentErr, nil +} + +func cloneToolArgs(args map[string]any) map[string]any { + if len(args) == 0 { + return nil + } + out := make(map[string]any, len(args)) + for k, v := range args { + out[k] = v + } + return out +} diff --git a/internal/runtime/base/hooks.go b/internal/runtime/base/hooks.go new file mode 100644 index 0000000..48f6205 --- /dev/null +++ b/internal/runtime/base/hooks.go @@ -0,0 +1,116 @@ +package base + +import ( + "context" + + 
"github.com/agenticenv/agent-sdk-go/internal/hooks" + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +func hookRunMeta(runID string, iteration int) hooks.RunMeta { + return hooks.RunMeta{ + RunID: runID, + Iteration: iteration, + } +} + +func (rt *Runtime) runBeforeLLMRequest(ctx context.Context, input ExecuteLLMInput, req *interfaces.LLMRequest) error { + if req == nil || len(rt.AgentConfig.Hooks) == 0 { + return nil + } + hooked, err := hooks.RunBeforeLLM(ctx, rt.AgentConfig.Hooks, hookRunMeta(input.RunID, input.Iteration), *req) + if err != nil { + return err + } + *req = hooked + return nil +} + +func (rt *Runtime) runAfterLLMResponse(ctx context.Context, input ExecuteLLMInput, resp *interfaces.LLMResponse) error { + if resp == nil || len(rt.AgentConfig.Hooks) == 0 { + return nil + } + hooked, err := hooks.RunAfterLLM(ctx, rt.AgentConfig.Hooks, hookRunMeta(input.RunID, input.Iteration), *resp) + if err != nil { + return err + } + *resp = hooked + return nil +} + +func (rt *Runtime) toolCallForHooks(input ExecuteToolInput, tool interfaces.Tool) hooks.ToolCall { + displayName := tool.DisplayName() + if displayName == "" { + displayName = input.ToolName + } + return hooks.ToolCall{ + ID: input.ToolCallID, + Name: input.ToolName, + DisplayName: displayName, + Kind: types.KindOf(tool), + Args: input.Args, + } +} + +func (rt *Runtime) runBeforeToolHooks(ctx context.Context, input ExecuteToolInput, tool interfaces.Tool) (hooks.ToolCall, error) { + call := rt.toolCallForHooks(input, tool) + if len(rt.AgentConfig.Hooks) == 0 { + return call, nil + } + return hooks.RunBeforeTool(ctx, rt.AgentConfig.Hooks, hookRunMeta(input.RunID, input.Iteration), call) +} + +func (rt *Runtime) runAfterToolHooks(ctx context.Context, input ExecuteToolInput, call hooks.ToolCall, content string, execErr error) (string, error, error) { + if len(rt.AgentConfig.Hooks) == 0 { + return content, execErr, nil + } + return 
hooks.RunAfterTool(ctx, rt.AgentConfig.Hooks, hookRunMeta(input.RunID, input.Iteration), call, content, execErr) +} + +func (rt *Runtime) runBeforeRetrieveHooks(ctx context.Context, input ExecuteRetrieversInput, mode types.RetrieverMode, retrieverName string) (hooks.RetrieveCall, error) { + call := hooks.RetrieveCall{ + Query: input.Query, + Mode: mode, + RetrieverName: retrieverName, + } + if len(rt.AgentConfig.Hooks) == 0 { + return call, nil + } + return hooks.RunBeforeRetrieve(ctx, rt.AgentConfig.Hooks, hookRunMeta(input.RunID, input.Iteration), call) +} + +func (rt *Runtime) runAfterRetrieveHooks(ctx context.Context, input ExecuteRetrieversInput, call hooks.RetrieveCall, docs []interfaces.Document) ([]interfaces.Document, error) { + if len(rt.AgentConfig.Hooks) == 0 { + return docs, nil + } + return hooks.RunAfterRetrieve(ctx, rt.AgentConfig.Hooks, hookRunMeta(input.RunID, input.Iteration), call, docs) +} + +func (rt *Runtime) runBeforeMemoryLoadHooks(ctx context.Context, input ExecuteMemoryRecallInput, call hooks.MemoryLoadCall) (hooks.MemoryLoadCall, error) { + if len(rt.AgentConfig.Hooks) == 0 { + return call, nil + } + return hooks.RunBeforeMemoryLoad(ctx, rt.AgentConfig.Hooks, hookRunMeta(input.RunID, input.Iteration), call) +} + +func (rt *Runtime) runAfterMemoryLoadHooks(ctx context.Context, input ExecuteMemoryRecallInput, call hooks.MemoryLoadCall, promptContext string) (string, error) { + if len(rt.AgentConfig.Hooks) == 0 { + return promptContext, nil + } + return hooks.RunAfterMemoryLoad(ctx, rt.AgentConfig.Hooks, hookRunMeta(input.RunID, input.Iteration), call, promptContext) +} + +func (rt *Runtime) runBeforeMemoryStoreHooks(ctx context.Context, input StoreMemoryRecordsInput, call hooks.MemoryStoreCall) (hooks.MemoryStoreCall, error) { + if len(rt.AgentConfig.Hooks) == 0 { + return call, nil + } + return hooks.RunBeforeMemoryStore(ctx, rt.AgentConfig.Hooks, hookRunMeta(input.RunID, input.Iteration), call) +} + +func (rt *Runtime) 
runAfterMemoryStoreHooks(ctx context.Context, input StoreMemoryRecordsInput, call hooks.MemoryStoreCall) error { + if len(rt.AgentConfig.Hooks) == 0 { + return nil + } + return hooks.RunAfterMemoryStore(ctx, rt.AgentConfig.Hooks, hookRunMeta(input.RunID, input.Iteration), call) +} diff --git a/internal/runtime/base/hooks_test.go b/internal/runtime/base/hooks_test.go new file mode 100644 index 0000000..502967f --- /dev/null +++ b/internal/runtime/base/hooks_test.go @@ -0,0 +1,125 @@ +package base + +import ( + "context" + "testing" + + "github.com/agenticenv/agent-sdk-go/internal/hooks" + sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +func TestRunBeforeMemoryLoadHooks_noHooksPassthrough(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentConfig{}) + call := hooks.MemoryLoadCall{Query: "q", Limit: 5} + got, err := rt.runBeforeMemoryLoadHooks(context.Background(), ExecuteMemoryRecallInput{ + RunID: "run-1", Iteration: 0, + }, call) + if err != nil { + t.Fatal(err) + } + if got.Query != call.Query || got.Limit != call.Limit { + t.Fatalf("got %#v", got) + } +} + +func TestRunBeforeMemoryLoadHooks_delegates(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentConfig{ + Hooks: []hooks.HookGroup{{ + Name: "rewrite", + Hooks: hooks.AgentHooks{BeforeMemoryLoad: []hooks.BeforeMemoryLoadHook{ + func(_ context.Context, in hooks.BeforeMemoryLoadHookInput) (hooks.BeforeMemoryLoadHookOutput, error) { + if in.RunMeta.RunID != "run-1" || in.RunMeta.Iteration != 2 || in.RunMeta.HooksGroup != "rewrite" { + t.Fatalf("RunMeta = %#v", in.RunMeta) + } + return hooks.BeforeMemoryLoadHookOutput{Query: "hooked"}, nil + }, + }}, + }}, + }) + got, err := rt.runBeforeMemoryLoadHooks(context.Background(), ExecuteMemoryRecallInput{ + RunID: "run-1", Iteration: 2, + }, hooks.MemoryLoadCall{Query: "orig"}) + if err != nil { + t.Fatal(err) + } + if got.Query != "hooked" { + t.Fatalf("query = %q", got.Query) + } +} + 
+func TestRunAfterMemoryLoadHooks_delegates(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentConfig{ + Hooks: []hooks.HookGroup{{ + Name: "filter", + Hooks: hooks.AgentHooks{AfterMemoryLoad: []hooks.AfterMemoryLoadHook{ + func(_ context.Context, in hooks.AfterMemoryLoadHookInput) (hooks.AfterMemoryLoadHookOutput, error) { + return hooks.AfterMemoryLoadHookOutput{PromptContext: "filtered"}, nil + }, + }}, + }}, + }) + ctx, err := rt.runAfterMemoryLoadHooks(context.Background(), ExecuteMemoryRecallInput{ + RunID: "run-1", Iteration: 0, + }, hooks.MemoryLoadCall{Query: "q"}, "raw") + if err != nil { + t.Fatal(err) + } + if ctx != "filtered" { + t.Fatalf("ctx = %q", ctx) + } +} + +func TestRunBeforeMemoryStoreHooks_delegates(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentConfig{ + Hooks: []hooks.HookGroup{{ + Name: "scrub", + Hooks: hooks.AgentHooks{BeforeMemoryStore: []hooks.BeforeMemoryStoreHook{ + func(_ context.Context, in hooks.BeforeMemoryStoreHookInput) (hooks.BeforeMemoryStoreHookOutput, error) { + return hooks.BeforeMemoryStoreHookOutput{ + Record: interfaces.MemoryRecord{Text: "scrubbed"}, + ID: "id-1", + }, nil + }, + }}, + }}, + }) + got, err := rt.runBeforeMemoryStoreHooks(context.Background(), StoreMemoryRecordsInput{ + RunID: "run-1", Iteration: 1, + }, hooks.MemoryStoreCall{ + Record: interfaces.MemoryRecord{Text: "raw"}, + }) + if err != nil { + t.Fatal(err) + } + if got.Record.Text != "scrubbed" || got.ID != "id-1" { + t.Fatalf("call = %#v", got) + } +} + +func TestRunAfterMemoryStoreHooks_delegates(t *testing.T) { + var seen bool + rt := newTestRuntime(sdkruntime.AgentConfig{ + Hooks: []hooks.HookGroup{{ + Name: "audit", + Hooks: hooks.AgentHooks{AfterMemoryStore: []hooks.AfterMemoryStoreHook{ + func(_ context.Context, in hooks.AfterMemoryStoreHookInput) (hooks.AfterMemoryStoreHookOutput, error) { + seen = in.ID == "mem-id" + return hooks.AfterMemoryStoreHookOutput{}, nil + }, + }}, + }}, + }) + call := hooks.MemoryStoreCall{ + Record: 
interfaces.MemoryRecord{Text: "stored"}, + ID: "mem-id", + } + if err := rt.runAfterMemoryStoreHooks(context.Background(), StoreMemoryRecordsInput{ + RunID: "run-1", Iteration: 0, + }, call); err != nil { + t.Fatal(err) + } + if !seen { + t.Fatal("after store hook was not called") + } +} diff --git a/internal/runtime/base/memory.go b/internal/runtime/base/memory.go index 8c8547a..427ce0e 100644 --- a/internal/runtime/base/memory.go +++ b/internal/runtime/base/memory.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/agenticenv/agent-sdk-go/internal/hooks" "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/logger" @@ -22,18 +23,19 @@ const memoryExtractSystemPrompt = "Extract durable long-term memories from the c var errMemoryExtractUnavailable = errors.New("memory extract unavailable: StoreMode always requires custom Extract or LLM client") // StoreMemoryRecords persists records through kind policy, dedup, TTL, and the memory backend. 
-func (rt *Runtime) StoreMemoryRecords(ctx context.Context, log logger.Logger, scope interfaces.MemoryScope, records []interfaces.MemoryRecord) error { +func (rt *Runtime) StoreMemoryRecords(ctx context.Context, input StoreMemoryRecordsInput) error { if !rt.MemoryConfigured() { return nil } + log := input.Logger ctx, batchSp := rt.Tracer.StartSpan(ctx, "memory.store.batch", - interfaces.Attribute{Key: "record.count", Value: len(records)}, + interfaces.Attribute{Key: "record.count", Value: len(input.Records)}, ) defer batchSp.End() - for _, rec := range records { - if err := rt.storeRecord(ctx, log, scope, rec); err != nil { + for _, rec := range input.Records { + if err := rt.storeRecord(ctx, log, input, rec); err != nil { batchSp.RecordError(err) return err } @@ -41,9 +43,18 @@ func (rt *Runtime) StoreMemoryRecords(ctx context.Context, log logger.Logger, sc return nil } -func (rt *Runtime) storeRecord(ctx context.Context, log logger.Logger, scope interfaces.MemoryScope, rec interfaces.MemoryRecord) error { +func (rt *Runtime) storeRecord(ctx context.Context, log logger.Logger, input StoreMemoryRecordsInput, rec interfaces.MemoryRecord) error { cfg := rt.AgentConfig.Memory.Config - text := strings.TrimSpace(rec.Text) + scope := input.Scope + + storeCall := hooks.MemoryStoreCall{Scope: scope, Record: rec} + var err error + storeCall, err = rt.runBeforeMemoryStoreHooks(ctx, input, storeCall) + if err != nil { + return err + } + + text := strings.TrimSpace(storeCall.Record.Text) if text == "" { return nil } @@ -51,7 +62,7 @@ func (rt *Runtime) storeRecord(ctx context.Context, log logger.Logger, scope int ctx, sp := rt.Tracer.StartSpan(ctx, "memory.store") defer sp.End() - kind, err := cfg.Store.ResolveKind(rec.Kind) + kind, err := cfg.Store.ResolveKind(storeCall.Record.Kind) if err != nil { sp.RecordError(err) rt.Metrics.IncrementCounter(ctx, types.MetricMemoryStoreFailed) @@ -66,15 +77,22 @@ func (rt *Runtime) storeRecord(ctx context.Context, log logger.Logger, scope 
int rt.Metrics.IncrementCounter(ctx, types.MetricMemoryStoreStarted, kindAttr) start := time.Now() - storeOpts, dedupAction, dedupErr := rt.dedupStoreOptions(ctx, scope, text) - if dedupErr != nil { - latency := float64(time.Since(start).Milliseconds()) - sp.RecordError(dedupErr) - sp.SetAttribute("latency_ms", latency) - rt.Metrics.IncrementCounter(ctx, types.MetricMemoryStoreFailed, kindAttr) - rt.Metrics.RecordHistogram(ctx, types.MetricMemoryStoreLatencyMs, latency, kindAttr) - log.Error(ctx, "runtime: memory dedup lookup failed", slog.String("scope", "runtime"), slog.Any("error", dedupErr)) - return fmt.Errorf("memory store: dedup: %w", dedupErr) + var storeOpts []interfaces.StoreMemoryOption + var dedupAction string + if hookID := strings.TrimSpace(storeCall.ID); hookID != "" { + storeOpts = []interfaces.StoreMemoryOption{interfaces.WithMemoryID(hookID)} + dedupAction = "upsert" + } else { + storeOpts, dedupAction, err = rt.dedupStoreOptions(ctx, scope, text) + if err != nil { + latency := float64(time.Since(start).Milliseconds()) + sp.RecordError(err) + sp.SetAttribute("latency_ms", latency) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryStoreFailed, kindAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryStoreLatencyMs, latency, kindAttr) + log.Error(ctx, "runtime: memory dedup lookup failed", slog.String("scope", "runtime"), slog.Any("error", err)) + return fmt.Errorf("memory store: dedup: %w", err) + } } dedupAttr := interfaces.Attribute{Key: types.MetricAttrMemoryDedup, Value: dedupAction} @@ -84,11 +102,12 @@ func (rt *Runtime) storeRecord(ctx context.Context, log logger.Logger, scope int record := interfaces.MemoryRecord{ Text: text, Kind: kind, - Metadata: rec.Metadata, + Metadata: storeCall.Record.Metadata, ExpiresAt: cfg.ExpiresAtForKind(kind, now), } - if _, err := cfg.Memory.Store(ctx, scope, record, storeOpts...); err != nil { + storedID, err := cfg.Memory.Store(ctx, scope, record, storeOpts...) 
+ if err != nil { latency := float64(time.Since(start).Milliseconds()) sp.RecordError(err) sp.SetAttribute("latency_ms", latency) @@ -98,14 +117,29 @@ func (rt *Runtime) storeRecord(ctx context.Context, log logger.Logger, scope int return fmt.Errorf("memory store: %w", err) } + if err := rt.runAfterMemoryStoreHooks(ctx, input, hooks.MemoryStoreCall{ + Scope: scope, + Record: record, + ID: storedID, + }); err != nil { + latency := float64(time.Since(start).Milliseconds()) + sp.RecordError(err) + sp.SetAttribute("latency_ms", latency) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryStoreFailed, kindAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryStoreLatencyMs, latency, kindAttr) + return err + } + latency := float64(time.Since(start).Milliseconds()) sp.SetAttribute("latency_ms", latency) sp.SetAttribute("dedup.upsert", dedupAction == "upsert") + sp.SetAttribute("memory.id", storedID) rt.Metrics.IncrementCounter(ctx, types.MetricMemoryStoreCompleted, kindAttr, dedupAttr) rt.Metrics.RecordHistogram(ctx, types.MetricMemoryStoreLatencyMs, latency, kindAttr) log.Debug(ctx, "runtime: memory store completed", slog.String("scope", "runtime"), - slog.String("dedup", dedupAction)) + slog.String("dedup", dedupAction), + slog.String("id", storedID)) return nil } @@ -352,3 +386,16 @@ func parseMemoryExtractResponse(content string) ([]interfaces.MemoryRecord, erro } return records, nil } + +func loadOptionsFromCall(call hooks.MemoryLoadCall, withMinScore bool) []interfaces.LoadMemoryOption { + opts := []interfaces.LoadMemoryOption{ + interfaces.WithLoadLimit(call.Limit), + } + if withMinScore && call.MinScore > 0 { + opts = append(opts, interfaces.WithMinScore(call.MinScore)) + } + if len(call.Kinds) > 0 { + opts = append(opts, interfaces.WithLoadKinds(call.Kinds...)) + } + return opts +} diff --git a/internal/runtime/base/runtime.go b/internal/runtime/base/runtime.go index 838842f..6953028 100644 --- a/internal/runtime/base/runtime.go +++ 
b/internal/runtime/base/runtime.go @@ -12,6 +12,7 @@ import ( "time" "github.com/agenticenv/agent-sdk-go/internal/events" + "github.com/agenticenv/agent-sdk-go/internal/hooks" "github.com/agenticenv/agent-sdk-go/internal/runtime" "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" @@ -36,6 +37,8 @@ type ExecuteLLMInput struct { Logger logger.Logger AgentName string MessageID string + RunID string + Iteration int Messages []interfaces.Message SkipTools bool MemoryContext string @@ -152,6 +155,9 @@ func emitEvent(fn func(events.AgentEvent), ev events.AgentEvent) { // messageID and agentName are used only for event construction; emit may be nil. func (rt *Runtime) ExecuteLLM(ctx context.Context, input ExecuteLLMInput) (*LLMResult, error) { req := rt.BuildLLMRequest(input.Messages, input.SkipTools, input.MemoryContext, input.RetrieverContext, input.Tools) + if err := rt.runBeforeLLMRequest(ctx, input, req); err != nil { + return nil, err + } llmClient := rt.AgentConfig.LLM.Client model := llmClient.GetModel() @@ -181,6 +187,10 @@ func (rt *Runtime) ExecuteLLM(ctx context.Context, input ExecuteLLMInput) (*LLMR } sp.End() + if err := rt.runAfterLLMResponse(ctx, input, resp); err != nil { + return nil, err + } + rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallCompleted, modelAttr, providerAttr) if resp.Usage != nil { @@ -207,6 +217,9 @@ func (rt *Runtime) ExecuteLLM(ctx context.Context, input ExecuteLLMInput) (*LLMR // fallback. emit may be nil. 
func (rt *Runtime) ExecuteLLMStream(ctx context.Context, input ExecuteLLMInput) (*LLMResult, error) { req := rt.BuildLLMRequest(input.Messages, input.SkipTools, input.MemoryContext, input.RetrieverContext, input.Tools) + if err := rt.runBeforeLLMRequest(ctx, input, req); err != nil { + return nil, err + } llmClient := rt.AgentConfig.LLM.Client model := llmClient.GetModel() @@ -265,6 +278,12 @@ func (rt *Runtime) ExecuteLLMStream(ctx context.Context, input ExecuteLLMInput) rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) return nil, err } + if err := rt.runAfterLLMResponse(ctx, input, resp); err != nil { + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) + return nil, err + } result, err := rt.llmResponseToResult(resp, input.Tools) if err != nil { sp.RecordError(err) @@ -350,6 +369,13 @@ func (rt *Runtime) ExecuteLLMStream(ctx context.Context, input ExecuteLLMInput) return nil, err } + if err := rt.runAfterLLMResponse(ctx, input, resp); err != nil { + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricLLMCallFailed, modelAttr, providerAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricLLMLatencyMs, llmLatency, modelAttr, providerAttr) + return nil, err + } + result, err := rt.llmResponseToResult(resp, input.Tools) if err != nil { sp.RecordError(err) @@ -370,17 +396,44 @@ func (rt *Runtime) ExecuteLLMStream(ctx context.Context, input ExecuteLLMInput) return result, nil } -// ExecuteTool finds the named tool and executes it, recording tracing and metrics. +// ExecuteTool runs a tool with optional memory scope; save_memory on on-demand store routes to [StoreMemoryRecords]. +// Retriever tools use [Runtime.executeRetrieverTool] inside [Runtime.executeTool]. 
+// For [types.ToolKindNative] and [types.ToolKindMCP] tools, [hooks.BeforeToolHook] and +func (rt *Runtime) ExecuteTool(ctx context.Context, input ExecuteToolInput, memScope interfaces.MemoryScope) (string, error) { + if input.ToolName == types.SaveMemoryToolName && rt.MemoryStoreOnDemand() { + return rt.executeSaveMemoryTool(ctx, input, memScope) + } + return rt.executeTool(ctx, input) +} + +// executeTool finds the named tool and executes it, recording tracing and metrics. // Returns the string representation of the tool result. -func (rt *Runtime) ExecuteTool(ctx context.Context, log logger.Logger, tools []interfaces.Tool, toolName string, args map[string]any) (string, error) { +func (rt *Runtime) executeTool(ctx context.Context, input ExecuteToolInput) (string, error) { + log := input.Logger + toolName := input.ToolName + args := input.Args + log.Debug(ctx, "runtime: tool execute started", slog.String("scope", "runtime"), slog.String("tool", toolName), slog.Int("argCount", len(args))) - tool, ok := FindToolByName(tools, toolName) + tool, ok := FindToolByName(input.Tools, toolName) if !ok { log.Warn(ctx, "runtime: unknown tool", slog.String("scope", "runtime"), slog.String("tool", toolName)) return "", fmt.Errorf("unknown tool: %s", toolName) } + kind := types.KindOf(tool) + runHooks := len(rt.AgentConfig.Hooks) > 0 && kind.HooksEligible() + + var hookedCall hooks.ToolCall + if runHooks { + var err error + hookedCall, err = rt.runBeforeToolHooks(ctx, input, tool) + if err != nil { + return "", err + } + args = hookedCall.Args + } + toolAttr := interfaces.Attribute{Key: types.MetricAttrTool, Value: toolName} rt.Metrics.IncrementCounter(ctx, types.MetricToolCallStarted, toolAttr) toolStart := time.Now() @@ -391,30 +444,99 @@ func (rt *Runtime) ExecuteTool(ctx context.Context, log logger.Logger, tools []i ) defer sp.End() - result, err := tool.Execute(ctx, args) + var content string + var execErr error + switch kind { + case types.ToolKindRetriever: + content, 
execErr = rt.executeRetrieverTool(ctx, input, tool, args) + default: + var result any + result, execErr = tool.Execute(ctx, args) + if execErr == nil { + content = fmt.Sprintf("%v", result) + } + if runHooks { + var err error + content, execErr, err = rt.runAfterToolHooks(ctx, input, hookedCall, content, execErr) + if err != nil { + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricToolCallFailed, toolAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricToolLatencyMs, float64(time.Since(toolStart).Milliseconds()), toolAttr) + return "", err + } + } + } + toolLatency := float64(time.Since(toolStart).Milliseconds()) - if err != nil { - sp.RecordError(err) + + if execErr != nil { + sp.RecordError(execErr) rt.Metrics.IncrementCounter(ctx, types.MetricToolCallFailed, toolAttr) rt.Metrics.RecordHistogram(ctx, types.MetricToolLatencyMs, toolLatency, toolAttr) - return "", err + return "", execErr } rt.Metrics.RecordHistogram(ctx, types.MetricToolLatencyMs, toolLatency, toolAttr) rt.Metrics.IncrementCounter(ctx, types.MetricToolCallCompleted, toolAttr) log.Debug(ctx, "runtime: tool execute completed", slog.String("scope", "runtime"), slog.String("tool", toolName)) - return fmt.Sprintf("%v", result), nil + return content, nil } -// ExecuteToolWithMemoryScope runs a tool; save_memory on on-demand store routes to [StoreMemoryRecords]. -func (rt *Runtime) ExecuteToolWithMemoryScope(ctx context.Context, log logger.Logger, tools []interfaces.Tool, toolName string, args map[string]any, memScope interfaces.MemoryScope) (string, error) { - if toolName == types.SaveMemoryToolName && rt.MemoryStoreOnDemand() { - return rt.executeSaveMemoryTool(ctx, log, memScope, args) +// executeRetrieverTool runs retrieve hooks, tool.Execute, and formats documents for the LLM. +// Metrics and tracing are handled by [Runtime.executeTool]. 
+func (rt *Runtime) executeRetrieverTool(ctx context.Context, input ExecuteToolInput, tool interfaces.Tool, args map[string]any) (string, error) { + retrieverName, ok := types.RetrieverNameFromToolName(tool.Name()) + if !ok { + return "", fmt.Errorf("retriever tool: invalid tool name %q", tool.Name()) + } + + query, err := types.RetrieverToolParamQueryValue(args) + if err != nil { + return "", err + } + + retrieveInput := ExecuteRetrieversInput{ + RunID: input.RunID, + Iteration: input.Iteration, + Query: query, } - return rt.ExecuteTool(ctx, log, tools, toolName, args) + + call, err := rt.runBeforeRetrieveHooks(ctx, retrieveInput, types.RetrieverModeAgentic, retrieverName) + if err != nil { + return "", err + } + + execArgs := map[string]any{types.RetrieverToolParamQuery: call.Query} + result, execErr := tool.Execute(ctx, execArgs) + if execErr != nil { + return "", execErr + } + + docs, ok := result.([]interfaces.Document) + if !ok { + return "", fmt.Errorf("retriever tool: unexpected result type %T", result) + } + + docs, err = rt.runAfterRetrieveHooks(ctx, retrieveInput, call, docs) + if err != nil { + return "", err + } + + content := FormatRetrieverDocs(docs) + if content == "" { + input.Logger.Warn(ctx, "runtime: retriever returned no documents", + slog.String("scope", "runtime"), + slog.String("tool", tool.Name()), + slog.String("retriever", retrieverName), + slog.String("query", call.Query), + ) + } + return content, nil } -func (rt *Runtime) executeSaveMemoryTool(ctx context.Context, log logger.Logger, scope interfaces.MemoryScope, args map[string]any) (string, error) { +func (rt *Runtime) executeSaveMemoryTool(ctx context.Context, input ExecuteToolInput, scope interfaces.MemoryScope) (string, error) { + log := input.Logger + args := input.Args toolAttr := interfaces.Attribute{Key: types.MetricAttrTool, Value: types.SaveMemoryToolName} rt.Metrics.IncrementCounter(ctx, types.MetricToolCallStarted, toolAttr) toolStart := time.Now() @@ -434,7 +556,13 @@ func 
(rt *Runtime) executeSaveMemoryTool(ctx context.Context, log logger.Logger, return "", err } - if err := rt.StoreMemoryRecords(ctx, log, scope, []interfaces.MemoryRecord{record}); err != nil { + if err := rt.StoreMemoryRecords(ctx, StoreMemoryRecordsInput{ + Logger: log, + RunID: input.RunID, + Iteration: input.Iteration, + Scope: scope, + Records: []interfaces.MemoryRecord{record}, + }); err != nil { toolLatency := float64(time.Since(toolStart).Milliseconds()) sp.RecordError(err) rt.Metrics.IncrementCounter(ctx, types.MetricToolCallFailed, toolAttr) @@ -494,12 +622,17 @@ func (rt *Runtime) AuthorizeTool(ctx context.Context, log logger.Logger, tools [ // ExecuteRetrievers runs all configured retrievers in parallel for the given query and // returns a combined document context string for injection into the LLM system prompt. // Partial failures are logged and skipped; all retrievers failing returns an error. -func (rt *Runtime) ExecuteRetrievers(ctx context.Context, log logger.Logger, query string) (*RetrieverResult, error) { +func (rt *Runtime) ExecuteRetrievers(ctx context.Context, input ExecuteRetrieversInput) (*RetrieverResult, error) { + log := input.Logger + query := input.Query + retrievers := rt.AgentConfig.Retrievers.Retrievers if len(retrievers) == 0 { return &RetrieverResult{}, nil } + mode := rt.AgentConfig.Retrievers.Mode + log.Debug(ctx, "runtime: retriever prefetch started", slog.String("scope", "runtime"), slog.Int("retrieverCount", len(retrievers)), slog.String("query", query)) type retrieverResult struct { @@ -517,25 +650,45 @@ func (rt *Runtime) ExecuteRetrievers(ctx context.Context, log logger.Logger, que name := ret.Name() retrieverAttr := interfaces.Attribute{Key: types.MetricAttrRetriever, Value: name} rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallStarted, retrieverAttr) - start := time.Now() + + call, hookErr := rt.runBeforeRetrieveHooks(ctx, input, mode, name) + if hookErr != nil { + rt.Metrics.IncrementCounter(ctx, 
types.MetricRetrieverCallFailed, retrieverAttr) + results[idx] = retrieverResult{name: name, err: hookErr} + return + } searchCtx, sp := rt.Tracer.StartSpan(ctx, "retriever.search", interfaces.Attribute{Key: "retriever.name", Value: name}, - interfaces.Attribute{Key: "query", Value: query}, + interfaces.Attribute{Key: "query", Value: call.Query}, ) - docs, err := ret.Search(searchCtx, query) + + start := time.Now() + docs, err := ret.Search(searchCtx, call.Query) latency := float64(time.Since(start).Milliseconds()) if err != nil { sp.RecordError(err) sp.End() rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallFailed, retrieverAttr) rt.Metrics.RecordHistogram(ctx, types.MetricRetrieverLatencyMs, latency, retrieverAttr) - } else { + results[idx] = retrieverResult{name: name, docs: docs, err: err} + return + } + + docs, hookErr = rt.runAfterRetrieveHooks(searchCtx, input, call, docs) + if hookErr != nil { + sp.RecordError(hookErr) sp.End() - rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallCompleted, retrieverAttr) + rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallFailed, retrieverAttr) rt.Metrics.RecordHistogram(ctx, types.MetricRetrieverLatencyMs, latency, retrieverAttr) + results[idx] = retrieverResult{name: name, err: hookErr} + return } - results[idx] = retrieverResult{name: name, docs: docs, err: err} + + sp.End() + rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallCompleted, retrieverAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricRetrieverLatencyMs, latency, retrieverAttr) + results[idx] = retrieverResult{name: name, docs: docs, err: nil} }(i, r) } wg.Wait() @@ -621,12 +774,16 @@ func FormatMemoryEntries(entries []interfaces.MemoryEntry) string { } // ExecuteMemoryRecall loads scoped memories for query and returns formatted prompt context. 
-func (rt *Runtime) ExecuteMemoryRecall(ctx context.Context, log logger.Logger, scope interfaces.MemoryScope, query string) (*MemoryResult, error) { +func (rt *Runtime) ExecuteMemoryRecall(ctx context.Context, input ExecuteMemoryRecallInput) (*MemoryResult, error) { cfg := rt.AgentConfig.Memory.Config if cfg == nil || cfg.Memory == nil { return &MemoryResult{}, nil } + log := input.Logger + query := input.Query + scope := input.Scope + log.Debug(ctx, "runtime: memory recall started", slog.String("scope", "runtime"), slog.String("query", query)) @@ -639,7 +796,25 @@ func (rt *Runtime) ExecuteMemoryRecall(ctx context.Context, log logger.Logger, s ) defer sp.End() - entries, err := cfg.Memory.Load(ctx, scope, query, cfg.Recall.LoadOptions()...) + recall := cfg.Recall + loadCall := hooks.MemoryLoadCall{ + Scope: scope, + Query: query, + Limit: recall.Limit, + MinScore: recall.MinScore, + Kinds: recall.Kinds, + } + loadCall, err := rt.runBeforeMemoryLoadHooks(ctx, input, loadCall) + if err != nil { + latency := float64(time.Since(start).Milliseconds()) + sp.RecordError(err) + sp.SetAttribute("latency_ms", latency) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryRecallFailed) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryRecallLatencyMs, latency) + return nil, err + } + + entries, err := cfg.Memory.Load(ctx, loadCall.Scope, loadCall.Query, loadOptionsFromCall(loadCall, true)...) if err != nil { latency := float64(time.Since(start).Milliseconds()) sp.RecordError(err) @@ -651,10 +826,10 @@ func (rt *Runtime) ExecuteMemoryRecall(ctx context.Context, log logger.Logger, s } // Semantic recall often misses distilled memories; fall back to scoped recency list. 
- if len(entries) == 0 && strings.TrimSpace(query) != "" { + if len(entries) == 0 && strings.TrimSpace(loadCall.Query) != "" { log.Debug(ctx, "runtime: memory recall semantic empty, trying recency fallback", slog.String("scope", "runtime")) - fallback, fbErr := cfg.Memory.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) + fallback, fbErr := cfg.Memory.Load(ctx, loadCall.Scope, "", cfg.Recall.RecencyLoadOptions()...) if fbErr != nil { latency := float64(time.Since(start).Milliseconds()) sp.RecordError(fbErr) @@ -667,8 +842,18 @@ func (rt *Runtime) ExecuteMemoryRecall(ctx context.Context, log logger.Logger, s entries = fallback } - latency := float64(time.Since(start).Milliseconds()) memoryContext := strings.TrimSpace(FormatMemoryEntries(entries)) + memoryContext, err = rt.runAfterMemoryLoadHooks(ctx, input, loadCall, memoryContext) + if err != nil { + latency := float64(time.Since(start).Milliseconds()) + sp.RecordError(err) + sp.SetAttribute("latency_ms", latency) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryRecallFailed) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryRecallLatencyMs, latency) + return nil, err + } + + latency := float64(time.Since(start).Milliseconds()) sp.SetAttribute("entry.count", len(entries)) sp.SetAttribute("has_context", memoryContext != "") sp.SetAttribute("latency_ms", latency) @@ -687,23 +872,30 @@ func (rt *Runtime) ExecuteMemoryRecall(ctx context.Context, log logger.Logger, s } // ExecuteMemoryStore extracts long-term memories from the run and persists them in scope. 
-func (rt *Runtime) ExecuteMemoryStore(ctx context.Context, log logger.Logger, scope interfaces.MemoryScope, messages []interfaces.Message) error { +func (rt *Runtime) ExecuteMemoryStore(ctx context.Context, input ExecuteMemoryStoreInput) error { if !rt.RunEndMemoryStoreEnabled() { return nil } + log := input.Logger extract := rt.resolveMemoryExtractFunc() if extract == nil { rt.recordMemoryExtractUnavailable(ctx, log) return nil } - records, err := rt.extractMemoryRecords(ctx, log, messages, extract) + records, err := rt.extractMemoryRecords(ctx, log, input.Messages, extract) if err != nil { return err } if len(records) == 0 { return nil } - return rt.StoreMemoryRecords(ctx, log, scope, records) + return rt.StoreMemoryRecords(ctx, StoreMemoryRecordsInput{ + Logger: log, + RunID: input.RunID, + Iteration: input.Iteration, + Scope: input.Scope, + Records: records, + }) } diff --git a/internal/runtime/base/runtime_test.go b/internal/runtime/base/runtime_test.go index 3f58991..6447c45 100644 --- a/internal/runtime/base/runtime_test.go +++ b/internal/runtime/base/runtime_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/agenticenv/agent-sdk-go/internal/events" + "github.com/agenticenv/agent-sdk-go/internal/hooks" sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" testutil "github.com/agenticenv/agent-sdk-go/internal/testing" "github.com/agenticenv/agent-sdk-go/internal/types" @@ -35,6 +36,18 @@ func newTestRuntime(exec sdkruntime.AgentConfig) *Runtime { func noopLog() logger.Logger { return logger.NoopLogger() } +func storeMemoryRecords(rt *Runtime, ctx context.Context, scope interfaces.MemoryScope, records []interfaces.MemoryRecord) error { + return rt.StoreMemoryRecords(ctx, StoreMemoryRecordsInput{ + Logger: noopLog(), Scope: scope, Records: records, + }) +} + +func executeMemoryStore(rt *Runtime, ctx context.Context, scope interfaces.MemoryScope, messages []interfaces.Message) error { + return rt.ExecuteMemoryStore(ctx, ExecuteMemoryStoreInput{ + 
Logger: noopLog(), Scope: scope, Messages: messages, + }) +} + // stubLLMClient is a minimal LLMClient that returns a fixed response. type stubLLMClient struct { resp *interfaces.LLMResponse @@ -51,6 +64,26 @@ func (stubLLMClient) GetModel() string { return "stub" } func (stubLLMClient) GetProvider() interfaces.LLMProvider { return interfaces.LLMProviderOpenAI } func (stubLLMClient) IsStreamSupported() bool { return false } +type captureLLMClient struct { + lastReq *interfaces.LLMRequest + resp *interfaces.LLMResponse + err error +} + +func (c *captureLLMClient) Generate(_ context.Context, req *interfaces.LLMRequest) (*interfaces.LLMResponse, error) { + if req != nil { + copy := *req + c.lastReq = &copy + } + return c.resp, c.err +} +func (captureLLMClient) GenerateStream(context.Context, *interfaces.LLMRequest) (interfaces.LLMStream, error) { + return nil, errors.New("stream not implemented") +} +func (captureLLMClient) GetModel() string { return "stub" } +func (captureLLMClient) GetProvider() interfaces.LLMProvider { return interfaces.LLMProviderOpenAI } +func (captureLLMClient) IsStreamSupported() bool { return false } + // --- BuildLLMRequest --- func TestBuildLLMRequest_Basic(t *testing.T) { @@ -188,7 +221,7 @@ func TestFetchConversationMessages_Error(t *testing.T) { func TestExecuteTool_UnknownTool(t *testing.T) { rt := newTestRuntime(sdkruntime.AgentConfig{}) - _, err := rt.ExecuteTool(context.Background(), noopLog(), nil, "missing", nil) + _, err := rt.ExecuteTool(context.Background(), ExecuteToolInput{Logger: noopLog(), ToolName: "missing"}, interfaces.MemoryScope{}) require.Error(t, err) require.Contains(t, err.Error(), "unknown tool") } @@ -200,7 +233,9 @@ func TestExecuteTool_Success(t *testing.T) { tool.EXPECT().Execute(gomock.Any(), gomock.Any()).Return("42", nil) rt := newTestRuntime(sdkruntime.AgentConfig{}) - result, err := rt.ExecuteTool(context.Background(), noopLog(), []interfaces.Tool{tool}, "calc", map[string]any{"x": 1}) + result, err := 
rt.ExecuteTool(context.Background(), ExecuteToolInput{ + Logger: noopLog(), Tools: []interfaces.Tool{tool}, ToolName: "calc", Args: map[string]any{"x": 1}, + }, interfaces.MemoryScope{}) require.NoError(t, err) require.Equal(t, "42", result) } @@ -212,11 +247,244 @@ func TestExecuteTool_ToolError(t *testing.T) { tool.EXPECT().Execute(gomock.Any(), gomock.Any()).Return(nil, errors.New("tool failed")) rt := newTestRuntime(sdkruntime.AgentConfig{}) - _, err := rt.ExecuteTool(context.Background(), noopLog(), []interfaces.Tool{tool}, "fail-tool", nil) + _, err := rt.ExecuteTool(context.Background(), ExecuteToolInput{ + Logger: noopLog(), Tools: []interfaces.Tool{tool}, ToolName: "fail-tool", + }, interfaces.MemoryScope{}) require.Error(t, err) require.Contains(t, err.Error(), "tool failed") } +func TestExecuteTool_BeforeToolModifiesArgs(t *testing.T) { + ctrl := gomock.NewController(t) + tool := ifmocks.NewMockTool(ctrl) + tool.EXPECT().Name().Return("calc").AnyTimes() + tool.EXPECT().DisplayName().Return("calc").AnyTimes() + tool.EXPECT().Execute(gomock.Any(), map[string]any{"x": 99}).Return("42", nil) + + rt := newTestRuntime(sdkruntime.AgentConfig{ + Hooks: []hooks.HookGroup{{ + Name: "guard", + Hooks: hooks.AgentHooks{BeforeTool: []hooks.BeforeToolHook{ + func(_ context.Context, in hooks.BeforeToolHookInput) (hooks.BeforeToolHookOutput, error) { + return hooks.BeforeToolHookOutput{Args: map[string]any{"x": 99}}, nil + }, + }}, + }}, + }) + _, err := rt.ExecuteTool(context.Background(), ExecuteToolInput{ + Logger: noopLog(), Tools: []interfaces.Tool{tool}, ToolName: "calc", + Args: map[string]any{"x": 1}, ToolCallID: "tc-1", RunID: "run-1", + }, interfaces.MemoryScope{}) + require.NoError(t, err) +} + +func TestExecuteTool_AfterToolModifiesResult(t *testing.T) { + ctrl := gomock.NewController(t) + tool := ifmocks.NewMockTool(ctrl) + tool.EXPECT().Name().Return("calc").AnyTimes() + tool.EXPECT().DisplayName().Return("calc").AnyTimes() + 
tool.EXPECT().Execute(gomock.Any(), gomock.Any()).Return("raw", nil) + + rt := newTestRuntime(sdkruntime.AgentConfig{ + Hooks: []hooks.HookGroup{{ + Name: "scrub", + Hooks: hooks.AgentHooks{AfterTool: []hooks.AfterToolHook{ + func(context.Context, hooks.AfterToolHookInput) (hooks.AfterToolHookOutput, error) { + return hooks.AfterToolHookOutput{Content: "scrubbed"}, nil + }, + }}, + }}, + }) + result, err := rt.ExecuteTool(context.Background(), ExecuteToolInput{ + Logger: noopLog(), Tools: []interfaces.Tool{tool}, ToolName: "calc", + }, interfaces.MemoryScope{}) + require.NoError(t, err) + require.Equal(t, "scrubbed", result) +} + +type memoryKindTool struct{} + +func (memoryKindTool) Name() string { return "mem_tool" } +func (memoryKindTool) DisplayName() string { return "" } +func (memoryKindTool) Description() string { return "" } +func (memoryKindTool) Parameters() interfaces.JSONSchema { return nil } +func (memoryKindTool) Execute(context.Context, map[string]any) (any, error) { + return "ok", nil +} +func (memoryKindTool) ToolKind() types.ToolKind { return types.ToolKindMemory } + +func TestExecuteTool_NonHookEligibleKindSkipsHooks(t *testing.T) { + var beforeCalled bool + rt := newTestRuntime(sdkruntime.AgentConfig{ + Hooks: []hooks.HookGroup{{ + Name: "guard", + Hooks: hooks.AgentHooks{BeforeTool: []hooks.BeforeToolHook{ + func(context.Context, hooks.BeforeToolHookInput) (hooks.BeforeToolHookOutput, error) { + beforeCalled = true + return hooks.BeforeToolHookOutput{}, nil + }, + }}, + }}, + }) + _, err := rt.ExecuteTool(context.Background(), ExecuteToolInput{ + Logger: noopLog(), Tools: []interfaces.Tool{memoryKindTool{}}, ToolName: "mem_tool", + }, interfaces.MemoryScope{}) + require.NoError(t, err) + require.False(t, beforeCalled) +} + +type testRetriever struct { + name string + lastQuery string + docs []interfaces.Document + err error +} + +func (r *testRetriever) Name() string { return r.name } + +func (r *testRetriever) Search(_ context.Context, query 
string) ([]interfaces.Document, error) { + r.lastQuery = query + if r.err != nil { + return nil, r.err + } + return r.docs, nil +} + +type testRetrieverTool struct { + r *testRetriever +} + +func (t *testRetrieverTool) Name() string { return types.RetrieverToolName(t.r.name) } +func (t *testRetrieverTool) DisplayName() string { return types.RetrieverToolDisplayName(t.r.name) } +func (t *testRetrieverTool) Description() string { return "test retriever tool" } +func (t *testRetrieverTool) Parameters() interfaces.JSONSchema { + return interfaces.JSONSchema{"type": "object"} +} +func (t *testRetrieverTool) ToolKind() types.ToolKind { return types.ToolKindRetriever } +func (t *testRetrieverTool) Execute(ctx context.Context, args map[string]any) (any, error) { + raw, ok := args[types.RetrieverToolParamQuery].(string) + if !ok { + return nil, errors.New("query required") + } + query := strings.TrimSpace(raw) + if query == "" { + return nil, errors.New("query empty") + } + return t.r.Search(ctx, query) +} + +func newTestRetrieverTool(r *testRetriever) interfaces.Tool { + return &testRetrieverTool{r: r} +} + +func TestExecuteTool_RetrieverTool_FormatsDocs(t *testing.T) { + stub := &testRetriever{ + name: "kb", + docs: []interfaces.Document{{Content: "doc", Source: "s", Score: 0.9}}, + } + tool := newTestRetrieverTool(stub) + rt := newTestRuntime(sdkruntime.AgentConfig{}) + got, err := rt.ExecuteTool(context.Background(), ExecuteToolInput{ + Logger: noopLog(), Tools: []interfaces.Tool{tool}, ToolName: "retriever_kb", + Args: map[string]any{types.RetrieverToolParamQuery: "q"}, + }, interfaces.MemoryScope{}) + require.NoError(t, err) + require.Contains(t, got, "doc") + require.Equal(t, "q", stub.lastQuery) +} + +func TestExecuteTool_RetrieverTool_BeforeRetrieveRewritesQuery(t *testing.T) { + stub := &testRetriever{ + name: "kb", + docs: []interfaces.Document{{Content: "doc", Source: "s", Score: 0.9}}, + } + tool := newTestRetrieverTool(stub) + rt := 
newTestRuntime(sdkruntime.AgentConfig{ + Hooks: []hooks.HookGroup{{ + Name: "rewrite", + Hooks: hooks.AgentHooks{BeforeRetrieve: []hooks.BeforeRetrieveHook{ + func(_ context.Context, in hooks.BeforeRetrieveHookInput) (hooks.BeforeRetrieveHookOutput, error) { + if in.RunMeta.Iteration != 2 || in.RetrieverName != "kb" || in.Mode != types.RetrieverModeAgentic { + t.Fatalf("input = %#v", in) + } + return hooks.BeforeRetrieveHookOutput{Query: "hooked"}, nil + }, + }}, + }}, + }) + _, err := rt.ExecuteTool(context.Background(), ExecuteToolInput{ + Logger: noopLog(), Tools: []interfaces.Tool{tool}, ToolName: "retriever_kb", + Args: map[string]any{types.RetrieverToolParamQuery: "raw"}, + RunID: "run-1", Iteration: 2, + }, interfaces.MemoryScope{}) + require.NoError(t, err) + require.Equal(t, "hooked", stub.lastQuery) +} + +func TestExecuteTool_RetrieverTool_AfterRetrieveFiltersDocs(t *testing.T) { + stub := &testRetriever{ + name: "kb", + docs: []interfaces.Document{ + {Content: "drop", Source: "s", Score: 0.5}, + {Content: "keep", Source: "s", Score: 0.9}, + }, + } + tool := newTestRetrieverTool(stub) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Hooks: []hooks.HookGroup{{ + Name: "filter", + Hooks: hooks.AgentHooks{AfterRetrieve: []hooks.AfterRetrieveHook{ + func(_ context.Context, in hooks.AfterRetrieveHookInput) (hooks.AfterRetrieveHookOutput, error) { + return hooks.AfterRetrieveHookOutput{Documents: []interfaces.Document{in.Documents[1]}}, nil + }, + }}, + }}, + }) + got, err := rt.ExecuteTool(context.Background(), ExecuteToolInput{ + Logger: noopLog(), Tools: []interfaces.Tool{tool}, ToolName: "retriever_kb", + Args: map[string]any{types.RetrieverToolParamQuery: "q"}, + }, interfaces.MemoryScope{}) + require.NoError(t, err) + require.Contains(t, got, "keep") + require.NotContains(t, got, "drop") +} + +func TestExecuteTool_RetrieverTool_EmptyDocs(t *testing.T) { + stub := &testRetriever{name: "kb"} + tool := newTestRetrieverTool(stub) + rt := 
newTestRuntime(sdkruntime.AgentConfig{}) + got, err := rt.ExecuteTool(context.Background(), ExecuteToolInput{ + Logger: noopLog(), Tools: []interfaces.Tool{tool}, ToolName: "retriever_kb", + Args: map[string]any{types.RetrieverToolParamQuery: "q"}, + }, interfaces.MemoryScope{}) + require.NoError(t, err) + require.Equal(t, "", got) +} + +func TestExecuteTool_RetrieverTool_BeforeRetrieveAbort(t *testing.T) { + stub := &testRetriever{ + name: "kb", + docs: []interfaces.Document{{Content: "doc", Source: "s", Score: 0.9}}, + } + tool := newTestRetrieverTool(stub) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Hooks: []hooks.HookGroup{{ + Name: "block", + Hooks: hooks.AgentHooks{BeforeRetrieve: []hooks.BeforeRetrieveHook{ + func(context.Context, hooks.BeforeRetrieveHookInput) (hooks.BeforeRetrieveHookOutput, error) { + return hooks.BeforeRetrieveHookOutput{}, errors.New("blocked") + }, + }}, + }}, + }) + _, err := rt.ExecuteTool(context.Background(), ExecuteToolInput{ + Logger: noopLog(), Tools: []interfaces.Tool{tool}, ToolName: "retriever_kb", + Args: map[string]any{types.RetrieverToolParamQuery: "q"}, + }, interfaces.MemoryScope{}) + require.Error(t, err) + require.Contains(t, err.Error(), "blocked") +} + // --- AuthorizeTool --- func TestAuthorizeTool_UnknownTool(t *testing.T) { @@ -257,7 +525,7 @@ func TestAuthorizeTool_Denied(t *testing.T) { func TestExecuteRetrievers_NoRetrievers(t *testing.T) { rt := newTestRuntime(sdkruntime.AgentConfig{}) - got, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "query") + got, err := rt.ExecuteRetrievers(context.Background(), ExecuteRetrieversInput{Logger: noopLog(), Query: "query"}) require.NoError(t, err) require.Equal(t, "", got.Context) require.Equal(t, int64(0), got.TotalSearches) @@ -272,7 +540,7 @@ func TestExecuteRetrievers_AllFail(t *testing.T) { rt := newTestRuntime(sdkruntime.AgentConfig{ Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{r}}, }) - got, err := 
rt.ExecuteRetrievers(context.Background(), noopLog(), "q") + got, err := rt.ExecuteRetrievers(context.Background(), ExecuteRetrieversInput{Logger: noopLog(), Query: "q"}) require.NoError(t, err) require.Equal(t, "", got.Context) require.Equal(t, int64(1), got.TotalSearches) @@ -290,15 +558,203 @@ func TestExecuteRetrievers_Success(t *testing.T) { rt := newTestRuntime(sdkruntime.AgentConfig{ Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{r}}, }) - got, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "my query") + got, err := rt.ExecuteRetrievers(context.Background(), ExecuteRetrieversInput{Logger: noopLog(), Query: "my query"}) require.NoError(t, err) require.Contains(t, got.Context, "doc content") require.Equal(t, int64(1), got.TotalSearches) require.Equal(t, int64(0), got.FailedSearches) } +func TestExecuteRetrievers_BeforeRetrieveRewritesQuery(t *testing.T) { + ctrl := gomock.NewController(t) + r := ifmocks.NewMockRetriever(ctrl) + r.EXPECT().Name().Return("kb").AnyTimes() + r.EXPECT().Search(gomock.Any(), "hooked").Return([]interfaces.Document{ + {Content: "doc", Source: "s", Score: 0.9}, + }, nil) + + rt := newTestRuntime(sdkruntime.AgentConfig{ + Retrievers: sdkruntime.AgentRetrievers{ + Mode: types.RetrieverModePrefetch, + Retrievers: []interfaces.Retriever{r}, + }, + Hooks: []hooks.HookGroup{{ + Name: "rewrite", + Hooks: hooks.AgentHooks{BeforeRetrieve: []hooks.BeforeRetrieveHook{ + func(_ context.Context, in hooks.BeforeRetrieveHookInput) (hooks.BeforeRetrieveHookOutput, error) { + if in.RunMeta.RunID != "run-r" || in.RunMeta.Iteration != 0 || in.RunMeta.HooksGroup != "rewrite" { + t.Fatalf("RunMeta = %#v", in.RunMeta) + } + if in.RetrieverName != "kb" || in.Mode != types.RetrieverModePrefetch { + t.Fatalf("input = %#v", in) + } + return hooks.BeforeRetrieveHookOutput{Query: "hooked"}, nil + }, + }}, + }}, + }) + got, err := rt.ExecuteRetrievers(context.Background(), ExecuteRetrieversInput{ + Logger: noopLog(), RunID: 
"run-r", Iteration: 0, Query: "raw", + }) + require.NoError(t, err) + require.Contains(t, got.Context, "doc") +} + +func TestExecuteRetrievers_AfterRetrieveFiltersDocs(t *testing.T) { + ctrl := gomock.NewController(t) + r := ifmocks.NewMockRetriever(ctrl) + r.EXPECT().Name().Return("kb").AnyTimes() + r.EXPECT().Search(gomock.Any(), "q").Return([]interfaces.Document{ + {Content: "drop", Source: "s", Score: 0.5}, + {Content: "keep", Source: "s", Score: 0.9}, + }, nil) + + rt := newTestRuntime(sdkruntime.AgentConfig{ + Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{r}}, + Hooks: []hooks.HookGroup{{ + Name: "filter", + Hooks: hooks.AgentHooks{AfterRetrieve: []hooks.AfterRetrieveHook{ + func(_ context.Context, in hooks.AfterRetrieveHookInput) (hooks.AfterRetrieveHookOutput, error) { + return hooks.AfterRetrieveHookOutput{ + Documents: []interfaces.Document{in.Documents[1]}, + }, nil + }, + }}, + }}, + }) + got, err := rt.ExecuteRetrievers(context.Background(), ExecuteRetrieversInput{Logger: noopLog(), Query: "q"}) + require.NoError(t, err) + require.Contains(t, got.Context, "keep") + require.NotContains(t, got.Context, "drop") +} + +func TestExecuteRetrievers_BeforeRetrieveAbort(t *testing.T) { + ctrl := gomock.NewController(t) + r := ifmocks.NewMockRetriever(ctrl) + r.EXPECT().Name().Return("kb").AnyTimes() + // Search must not be called when BeforeRetrieve aborts. 
+ + rt := newTestRuntime(sdkruntime.AgentConfig{ + Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{r}}, + Hooks: []hooks.HookGroup{{ + Name: "block", + Hooks: hooks.AgentHooks{BeforeRetrieve: []hooks.BeforeRetrieveHook{ + func(context.Context, hooks.BeforeRetrieveHookInput) (hooks.BeforeRetrieveHookOutput, error) { + return hooks.BeforeRetrieveHookOutput{}, errors.New("blocked") + }, + }}, + }}, + }) + got, err := rt.ExecuteRetrievers(context.Background(), ExecuteRetrieversInput{Logger: noopLog(), Query: "q"}) + require.NoError(t, err) + require.Equal(t, "", got.Context) + require.Equal(t, int64(1), got.TotalSearches) + require.Equal(t, int64(1), got.FailedSearches) +} + +func TestExecuteRetrievers_AfterRetrieveAbort(t *testing.T) { + ctrl := gomock.NewController(t) + r := ifmocks.NewMockRetriever(ctrl) + r.EXPECT().Name().Return("kb").AnyTimes() + r.EXPECT().Search(gomock.Any(), "q").Return([]interfaces.Document{ + {Content: "doc", Source: "s", Score: 0.9}, + }, nil) + + rt := newTestRuntime(sdkruntime.AgentConfig{ + Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{r}}, + Hooks: []hooks.HookGroup{{ + Name: "block", + Hooks: hooks.AgentHooks{AfterRetrieve: []hooks.AfterRetrieveHook{ + func(context.Context, hooks.AfterRetrieveHookInput) (hooks.AfterRetrieveHookOutput, error) { + return hooks.AfterRetrieveHookOutput{}, errors.New("blocked") + }, + }}, + }}, + }) + got, err := rt.ExecuteRetrievers(context.Background(), ExecuteRetrieversInput{Logger: noopLog(), Query: "q"}) + require.NoError(t, err) + require.Equal(t, "", got.Context) + require.Equal(t, int64(1), got.TotalSearches) + require.Equal(t, int64(1), got.FailedSearches) +} + // --- ExecuteLLM --- +func TestExecuteLLM_BeforeLLMModifiesRequest(t *testing.T) { + llm := &captureLLMClient{resp: &interfaces.LLMResponse{Content: "ok"}} + rt := newTestRuntime(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: llm}, + Hooks: []hooks.HookGroup{{ + Name: 
"guardrails", + Hooks: hooks.AgentHooks{BeforeLLM: []hooks.BeforeLLMHook{ + func(_ context.Context, in hooks.BeforeLLMHookInput) (hooks.BeforeLLMHookOutput, error) { + if in.RunMeta.RunID != "run-42" || in.RunMeta.Iteration != 2 || in.RunMeta.HooksGroup != "guardrails" { + t.Fatalf("RunMeta = %#v", in.RunMeta) + } + out := in.Request + out.SystemMessage = "hooked" + return hooks.BeforeLLMHookOutput{Request: out}, nil + }, + }}, + }}, + }) + input := ExecuteLLMInput{ + Logger: noopLog(), + AgentName: "agent", + MessageID: "msg-1", + RunID: "run-42", + Iteration: 2, + } + _, err := rt.ExecuteLLM(context.Background(), input) + require.NoError(t, err) + require.NotNil(t, llm.lastReq) + require.Equal(t, "hooked", llm.lastReq.SystemMessage) +} + +func TestExecuteLLM_AfterLLMModifiesResponse(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{Content: "raw"}, + }}, + Hooks: []hooks.HookGroup{{ + Name: "scrub", + Hooks: hooks.AgentHooks{AfterLLM: []hooks.AfterLLMHook{ + func(_ context.Context, in hooks.AfterLLMHookInput) (hooks.AfterLLMHookOutput, error) { + out := in.Response + out.Content = "scrubbed" + return hooks.AfterLLMHookOutput{Response: out}, nil + }, + }}, + }}, + }) + result, err := rt.ExecuteLLM(context.Background(), ExecuteLLMInput{ + Logger: noopLog(), AgentName: "agent", MessageID: "msg-1", + }) + require.NoError(t, err) + require.Equal(t, "scrubbed", result.Content) +} + +func TestExecuteLLM_BeforeLLMErrorAborts(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{ + resp: &interfaces.LLMResponse{Content: "should not run"}, + }}, + Hooks: []hooks.HookGroup{{ + Name: "block", + Hooks: hooks.AgentHooks{BeforeLLM: []hooks.BeforeLLMHook{ + func(context.Context, hooks.BeforeLLMHookInput) (hooks.BeforeLLMHookOutput, error) { + return hooks.BeforeLLMHookOutput{}, errors.New("blocked") + }, + }}, + }}, + }) + _, 
err := rt.ExecuteLLM(context.Background(), ExecuteLLMInput{ + Logger: noopLog(), AgentName: "agent", MessageID: "msg-1", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "blocked") +} + func TestExecuteLLM_LLMError(t *testing.T) { rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{err: errors.New("llm unavailable")}}, @@ -541,7 +997,7 @@ func TestExecuteRetrievers_PartialFailure(t *testing.T) { rt := newTestRuntime(sdkruntime.AgentConfig{ Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{good, bad}}, }) - got, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "q") + got, err := rt.ExecuteRetrievers(context.Background(), ExecuteRetrieversInput{Logger: noopLog(), Query: "q"}) require.NoError(t, err) // partial is ok require.Contains(t, got.Context, "useful") require.Equal(t, int64(2), got.TotalSearches) @@ -561,7 +1017,7 @@ func TestStoreMemoryRecords_appliesTTL(t *testing.T) { ctx := context.Background() before := time.Now().UTC() - require.NoError(t, rt.StoreMemoryRecords(ctx, noopLog(), scope, []interfaces.MemoryRecord{ + require.NoError(t, storeMemoryRecords(rt, ctx, scope, []interfaces.MemoryRecord{ {Text: "User prefers concise answers", Kind: memory.KindNote}, })) @@ -581,7 +1037,7 @@ func TestStoreMemoryRecords_allowlistRejectsKind(t *testing.T) { Memory: sdkruntime.AgentMemory{Config: &cfg}, }) - err := rt.StoreMemoryRecords(context.Background(), noopLog(), interfaces.MemoryScope{UserID: "u1"}, + err := storeMemoryRecords(rt, context.Background(), interfaces.MemoryScope{UserID: "u1"}, []interfaces.MemoryRecord{{Text: "note text", Kind: memory.KindNote}}) require.Error(t, err) } @@ -597,10 +1053,10 @@ func TestStoreMemoryRecords_dedupUpserts(t *testing.T) { ctx := context.Background() text := "favorite color is blue" - require.NoError(t, rt.StoreMemoryRecords(ctx, noopLog(), scope, []interfaces.MemoryRecord{ + require.NoError(t, storeMemoryRecords(rt, ctx, scope, 
[]interfaces.MemoryRecord{ {Text: text, Kind: memory.KindPreference}, })) - require.NoError(t, rt.StoreMemoryRecords(ctx, noopLog(), scope, []interfaces.MemoryRecord{ + require.NoError(t, storeMemoryRecords(rt, ctx, scope, []interfaces.MemoryRecord{ {Text: text, Kind: memory.KindFact}, })) @@ -620,7 +1076,7 @@ func TestStoreMemoryRecords_dedupAppendsDistinctText(t *testing.T) { scope := interfaces.MemoryScope{UserID: "u1"} ctx := context.Background() - require.NoError(t, rt.StoreMemoryRecords(ctx, noopLog(), scope, []interfaces.MemoryRecord{ + require.NoError(t, storeMemoryRecords(rt, ctx, scope, []interfaces.MemoryRecord{ {Text: "favorite color is blue", Kind: memory.KindPreference}, {Text: "prefers concise answers", Kind: memory.KindPreference}, })) @@ -640,7 +1096,7 @@ func TestStoreMemoryRecords_appliesDefaultKind(t *testing.T) { scope := interfaces.MemoryScope{UserID: "u1"} ctx := context.Background() - require.NoError(t, rt.StoreMemoryRecords(ctx, noopLog(), scope, []interfaces.MemoryRecord{ + require.NoError(t, storeMemoryRecords(rt, ctx, scope, []interfaces.MemoryRecord{ {Text: "remember this"}, })) @@ -652,7 +1108,7 @@ func TestStoreMemoryRecords_appliesDefaultKind(t *testing.T) { func TestStoreMemoryRecords_notConfigured(t *testing.T) { rt := newTestRuntime(sdkruntime.AgentConfig{}) - require.NoError(t, rt.StoreMemoryRecords(context.Background(), noopLog(), interfaces.MemoryScope{UserID: "u1"}, + require.NoError(t, storeMemoryRecords(rt, context.Background(), interfaces.MemoryScope{UserID: "u1"}, []interfaces.MemoryRecord{{Text: "x"}})) } @@ -666,7 +1122,7 @@ func TestStoreMemoryRecords_skipsEmptyText(t *testing.T) { scope := interfaces.MemoryScope{UserID: "u1"} ctx := context.Background() - require.NoError(t, rt.StoreMemoryRecords(ctx, noopLog(), scope, []interfaces.MemoryRecord{ + require.NoError(t, storeMemoryRecords(rt, ctx, scope, []interfaces.MemoryRecord{ {Text: " "}, })) @@ -696,7 +1152,7 @@ func TestStoreMemoryRecords_emitsMetrics(t *testing.T) { 
}) rt.Metrics = metrics - require.NoError(t, rt.StoreMemoryRecords(context.Background(), noopLog(), scope, + require.NoError(t, storeMemoryRecords(rt, context.Background(), scope, []interfaces.MemoryRecord{{Text: "hello world", Kind: memory.KindNote}})) } @@ -713,11 +1169,139 @@ func TestStoreMemoryRecords_kindRejectedEmitsFailedMetric(t *testing.T) { }) rt.Metrics = metrics - err := rt.StoreMemoryRecords(context.Background(), noopLog(), interfaces.MemoryScope{UserID: "u1"}, + err := storeMemoryRecords(rt, context.Background(), interfaces.MemoryScope{UserID: "u1"}, []interfaces.MemoryRecord{{Text: "x", Kind: memory.KindNote}}) require.Error(t, err) } +func TestStoreMemoryRecords_BeforeMemoryStoreRewritesRecord(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + Hooks: []hooks.HookGroup{{ + Name: "scrub", + Hooks: hooks.AgentHooks{BeforeMemoryStore: []hooks.BeforeMemoryStoreHook{ + func(_ context.Context, in hooks.BeforeMemoryStoreHookInput) (hooks.BeforeMemoryStoreHookOutput, error) { + if in.RunMeta.RunID != "run-s" || in.RunMeta.Iteration != 3 { + t.Fatalf("RunMeta = %#v", in.RunMeta) + } + return hooks.BeforeMemoryStoreHookOutput{ + Record: interfaces.MemoryRecord{Text: "scrubbed text"}, + }, nil + }, + }}, + }}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + require.NoError(t, rt.StoreMemoryRecords(ctx, StoreMemoryRecordsInput{ + Logger: noopLog(), RunID: "run-s", Iteration: 3, Scope: scope, + Records: []interfaces.MemoryRecord{{Text: "raw secret"}}, + })) + + entries, err := store.Load(ctx, scope, "scrubbed", cfg.Recall.RecencyLoadOptions()...) 
+ require.NoError(t, err) + require.Len(t, entries, 1) + require.Equal(t, "scrubbed text", entries[0].Text) +} + +func TestStoreMemoryRecords_BeforeMemoryStoreHookIDSkipsDedup(t *testing.T) { + ctrl := gomock.NewController(t) + mem := ifmocks.NewMockMemory(ctrl) + scope := interfaces.MemoryScope{UserID: "u1"} + + mem.EXPECT().Load(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + mem.EXPECT().Store(gomock.Any(), scope, gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, _ interfaces.MemoryScope, rec interfaces.MemoryRecord, opts ...interfaces.StoreMemoryOption) (string, error) { + storeOpts := interfaces.StoreMemoryOptions{} + for _, opt := range opts { + opt(&storeOpts) + } + require.Equal(t, "hook-id", storeOpts.ID) + require.Equal(t, "stored text", rec.Text) + return "hook-id", nil + }).Times(1) + + cfg := memory.DefaultConfig(mem) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + Hooks: []hooks.HookGroup{{ + Name: "upsert", + Hooks: hooks.AgentHooks{BeforeMemoryStore: []hooks.BeforeMemoryStoreHook{ + func(context.Context, hooks.BeforeMemoryStoreHookInput) (hooks.BeforeMemoryStoreHookOutput, error) { + return hooks.BeforeMemoryStoreHookOutput{ + Record: interfaces.MemoryRecord{Text: "stored text"}, + ID: "hook-id", + }, nil + }, + }}, + }}, + }) + + require.NoError(t, rt.StoreMemoryRecords(context.Background(), StoreMemoryRecordsInput{ + Logger: noopLog(), Scope: scope, + Records: []interfaces.MemoryRecord{{Text: "original"}}, + })) +} + +func TestStoreMemoryRecords_BeforeMemoryStoreAbort(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + Hooks: []hooks.HookGroup{{ + Name: "block", + Hooks: hooks.AgentHooks{BeforeMemoryStore: []hooks.BeforeMemoryStoreHook{ + func(context.Context, hooks.BeforeMemoryStoreHookInput) (hooks.BeforeMemoryStoreHookOutput, 
error) { + return hooks.BeforeMemoryStoreHookOutput{}, errors.New("blocked") + }, + }}, + }}, + }) + + err := rt.StoreMemoryRecords(context.Background(), StoreMemoryRecordsInput{ + Logger: noopLog(), Scope: interfaces.MemoryScope{UserID: "u1"}, + Records: []interfaces.MemoryRecord{{Text: "x"}}, + }) + require.Error(t, err) + require.Equal(t, "blocked", err.Error()) + + entries, err := store.Load(context.Background(), interfaces.MemoryScope{UserID: "u1"}, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Empty(t, entries) +} + +func TestStoreMemoryRecords_AfterMemoryStoreAbort(t *testing.T) { + ctrl := gomock.NewController(t) + mem := ifmocks.NewMockMemory(ctrl) + scope := interfaces.MemoryScope{UserID: "u1"} + mem.EXPECT().Load(gomock.Any(), scope, gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + mem.EXPECT().Store(gomock.Any(), scope, gomock.Any()).Return("id-1", nil).Times(1) + + cfg := memory.DefaultConfig(mem) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + Hooks: []hooks.HookGroup{{ + Name: "audit", + Hooks: hooks.AgentHooks{AfterMemoryStore: []hooks.AfterMemoryStoreHook{ + func(_ context.Context, in hooks.AfterMemoryStoreHookInput) (hooks.AfterMemoryStoreHookOutput, error) { + require.Equal(t, "id-1", in.ID) + return hooks.AfterMemoryStoreHookOutput{}, errors.New("blocked") + }, + }}, + }}, + }) + + err := rt.StoreMemoryRecords(context.Background(), StoreMemoryRecordsInput{ + Logger: noopLog(), Scope: scope, + Records: []interfaces.MemoryRecord{{Text: "hello"}}, + }) + require.Error(t, err) + require.Equal(t, "blocked", err.Error()) +} + // --- default memory extract (Always mode) --- type stubMemoryExtractLLM struct { @@ -886,21 +1470,23 @@ func TestExecuteMemoryStore_skipsOnDemand(t *testing.T) { Memory: sdkruntime.AgentMemory{Config: &cfg}, }) scope := interfaces.MemoryScope{UserID: "u1"} - require.NoError(t, rt.ExecuteMemoryStore(context.Background(), noopLog(), scope, 
testRunMessages("hello", "world"))) + require.NoError(t, executeMemoryStore(rt, context.Background(), scope, testRunMessages("hello", "world"))) entries, err := store.Load(context.Background(), scope, "", cfg.Recall.RecencyLoadOptions()...) require.NoError(t, err) require.Empty(t, entries) } -func TestExecuteToolWithMemoryScope_saveMemory(t *testing.T) { +func TestExecuteTool_saveMemory(t *testing.T) { store := testutil.NewInmemMemory() cfg := memory.DefaultConfig(store) rt := newTestRuntime(sdkruntime.AgentConfig{ Memory: sdkruntime.AgentMemory{Config: &cfg}, }) scope := interfaces.MemoryScope{UserID: "u1"} - out, err := rt.ExecuteToolWithMemoryScope(context.Background(), noopLog(), nil, types.SaveMemoryToolName, - map[string]any{types.MemoryToolParamText: "favorite color is blue"}, scope) + out, err := rt.ExecuteTool(context.Background(), ExecuteToolInput{ + Logger: noopLog(), ToolName: types.SaveMemoryToolName, + Args: map[string]any{types.MemoryToolParamText: "favorite color is blue"}, + }, scope) require.NoError(t, err) require.Equal(t, "memory saved", out) entries, err := store.Load(context.Background(), scope, "", cfg.Recall.RecencyLoadOptions()...) 
@@ -908,6 +1494,43 @@ func TestExecuteToolWithMemoryScope_saveMemory(t *testing.T) { require.Len(t, entries, 1) } +func TestExecuteTool_saveMemory_runsBeforeStoreHook(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + Hooks: []hooks.HookGroup{{ + Name: "scrub", + Hooks: hooks.AgentHooks{BeforeMemoryStore: []hooks.BeforeMemoryStoreHook{ + func(_ context.Context, in hooks.BeforeMemoryStoreHookInput) (hooks.BeforeMemoryStoreHookOutput, error) { + if in.RunMeta.RunID != "run-t" || in.RunMeta.Iteration != 2 { + t.Fatalf("RunMeta = %#v", in.RunMeta) + } + rec := in.Record + if rec.Metadata == nil { + rec.Metadata = map[string]string{} + } + rec.Metadata["hooked"] = "true" + return hooks.BeforeMemoryStoreHookOutput{Record: rec}, nil + }, + }}, + }}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + _, err := rt.ExecuteTool(context.Background(), ExecuteToolInput{ + Logger: noopLog(), RunID: "run-t", Iteration: 2, + ToolName: types.SaveMemoryToolName, + Args: map[string]any{types.MemoryToolParamText: "favorite color is blue"}, + }, scope) + require.NoError(t, err) + + entries, err := store.Load(context.Background(), scope, "", cfg.Recall.RecencyLoadOptions()...) 
+ require.NoError(t, err) + require.Len(t, entries, 1) + require.Equal(t, "true", entries[0].Metadata["hooked"]) +} + func TestExecuteMemoryRecallAndStore(t *testing.T) { store := testutil.NewInmemMemory() cfg := memoryConfigAlways(store) @@ -919,9 +1542,11 @@ func TestExecuteMemoryRecallAndStore(t *testing.T) { scope := interfaces.MemoryScope{UserID: "u1"} ctx := context.Background() - require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, testRunMessages("hello", "world"))) + require.NoError(t, executeMemoryStore(rt, ctx, scope, testRunMessages("hello", "world"))) - res, err := rt.ExecuteMemoryRecall(ctx, noopLog(), scope, "hello") + res, err := rt.ExecuteMemoryRecall(ctx, ExecuteMemoryRecallInput{ + Logger: noopLog(), RunID: "run-1", Iteration: 0, Scope: scope, Query: "hello", + }) require.NoError(t, err) require.NotEmpty(t, res.Context) } @@ -938,7 +1563,7 @@ func TestExecuteMemoryStore_AppliesTTLFromPolicy(t *testing.T) { ctx := context.Background() before := time.Now().UTC() - require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, testRunMessages("hello", "world"))) + require.NoError(t, executeMemoryStore(rt, ctx, scope, testRunMessages("hello", "world"))) entries, err := store.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) 
require.NoError(t, err) @@ -959,8 +1584,8 @@ func TestExecuteMemoryStore_skipsEmptyMessages(t *testing.T) { scope := interfaces.MemoryScope{UserID: "u1"} ctx := context.Background() - require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, nil)) - require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, []interfaces.Message{ + require.NoError(t, executeMemoryStore(rt, ctx, scope, nil)) + require.NoError(t, executeMemoryStore(rt, ctx, scope, []interfaces.Message{ {Role: interfaces.MessageRoleTool, Content: "noise"}, })) @@ -981,7 +1606,7 @@ func TestExecuteMemoryStore_noExtractorEmitsFailedMetric(t *testing.T) { }) rt.Metrics = metrics - require.NoError(t, rt.ExecuteMemoryStore(context.Background(), noopLog(), interfaces.MemoryScope{UserID: "u1"}, + require.NoError(t, executeMemoryStore(rt, context.Background(), interfaces.MemoryScope{UserID: "u1"}, testRunMessages("hi", "there"))) } @@ -1007,7 +1632,7 @@ func TestExecuteMemoryExtract_EmitsMetrics(t *testing.T) { rt.Metrics = metrics scope := interfaces.MemoryScope{UserID: "u1"} - require.NoError(t, rt.ExecuteMemoryStore(context.Background(), noopLog(), scope, testRunMessages("hello", "world"))) + require.NoError(t, executeMemoryStore(rt, context.Background(), scope, testRunMessages("hello", "world"))) } func TestExecuteMemoryExtract_EmitsFailedMetric(t *testing.T) { @@ -1027,7 +1652,7 @@ func TestExecuteMemoryExtract_EmitsFailedMetric(t *testing.T) { }) rt.Metrics = metrics - err := rt.ExecuteMemoryStore(context.Background(), noopLog(), interfaces.MemoryScope{UserID: "u1"}, testRunMessages("hi", "there")) + err := executeMemoryStore(rt, context.Background(), interfaces.MemoryScope{UserID: "u1"}, testRunMessages("hi", "there")) require.Error(t, err) } @@ -1044,7 +1669,7 @@ func TestExecuteMemoryStore_extractsWithDefaultLLM(t *testing.T) { scope := interfaces.MemoryScope{UserID: "u1"} ctx := context.Background() - require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, testRunMessages("I like 
tea", "noted"))) + require.NoError(t, executeMemoryStore(rt, ctx, scope, testRunMessages("I like tea", "noted"))) entries, err := store.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) require.NoError(t, err) @@ -1065,7 +1690,7 @@ func TestExecuteMemoryStore_setsExtractMetadata(t *testing.T) { scope := interfaces.MemoryScope{UserID: "u1"} ctx := context.Background() - require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, testRunMessages("hi", "there"))) + require.NoError(t, executeMemoryStore(rt, ctx, scope, testRunMessages("hi", "there"))) entries, err := store.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) require.NoError(t, err) @@ -1091,7 +1716,9 @@ func TestExecuteMemoryRecall_OmitsExpired(t *testing.T) { }) require.NoError(t, err) - res, err := rt.ExecuteMemoryRecall(ctx, noopLog(), scope, "concise") + res, err := rt.ExecuteMemoryRecall(ctx, ExecuteMemoryRecallInput{ + Logger: noopLog(), RunID: "run-1", Iteration: 0, Scope: scope, Query: "concise", + }) require.NoError(t, err) require.Empty(t, res.Context) } @@ -1107,14 +1734,97 @@ func TestExecuteMemoryRecall_SemanticMissFallsBackToRecency(t *testing.T) { scope := interfaces.MemoryScope{UserID: "u1"} ctx := context.Background() - require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, testRunMessages( + require.NoError(t, executeMemoryStore(rt, ctx, scope, testRunMessages( "Remember that I prefer concise answers.", "Got it."))) - res, err := rt.ExecuteMemoryRecall(ctx, noopLog(), scope, "What answer style do I prefer?") + res, err := rt.ExecuteMemoryRecall(ctx, ExecuteMemoryRecallInput{ + Logger: noopLog(), RunID: "run-1", Iteration: 0, Scope: scope, Query: "What answer style do I prefer?", + }) require.NoError(t, err) require.Contains(t, res.Context, "concise answers") } +func TestExecuteMemoryRecall_BeforeMemoryLoadRewritesQuery(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memoryConfigAlways(store) + cfg.Store.Extract = testAlwaysStoreExtract + rt := 
newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + Hooks: []hooks.HookGroup{{ + Name: "rewrite", + Hooks: hooks.AgentHooks{BeforeMemoryLoad: []hooks.BeforeMemoryLoadHook{ + func(_ context.Context, in hooks.BeforeMemoryLoadHookInput) (hooks.BeforeMemoryLoadHookOutput, error) { + if in.RunMeta.RunID != "run-m" || in.RunMeta.Iteration != 0 || in.RunMeta.HooksGroup != "rewrite" { + t.Fatalf("RunMeta = %#v", in.RunMeta) + } + return hooks.BeforeMemoryLoadHookOutput{Query: "hooked"}, nil + }, + }}, + }}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + require.NoError(t, executeMemoryStore(rt, ctx, scope, testRunMessages( + "Remember that I prefer concise answers.", "Got it."))) + + res, err := rt.ExecuteMemoryRecall(ctx, ExecuteMemoryRecallInput{ + Logger: noopLog(), RunID: "run-m", Iteration: 0, Scope: scope, Query: "raw", + }) + require.NoError(t, err) + require.Contains(t, res.Context, "concise answers") +} + +func TestExecuteMemoryRecall_AfterMemoryLoadFiltersContext(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memoryConfigAlways(store) + cfg.Store.Extract = testAlwaysStoreExtract + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + Hooks: []hooks.HookGroup{{ + Name: "filter", + Hooks: hooks.AgentHooks{AfterMemoryLoad: []hooks.AfterMemoryLoadHook{ + func(_ context.Context, in hooks.AfterMemoryLoadHookInput) (hooks.AfterMemoryLoadHookOutput, error) { + return hooks.AfterMemoryLoadHookOutput{PromptContext: "filtered"}, nil + }, + }}, + }}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + require.NoError(t, executeMemoryStore(rt, ctx, scope, testRunMessages("hello", "world"))) + + res, err := rt.ExecuteMemoryRecall(ctx, ExecuteMemoryRecallInput{ + Logger: noopLog(), RunID: "run-1", Iteration: 0, Scope: scope, Query: "hello", + }) + require.NoError(t, err) + require.Equal(t, "filtered", 
res.Context) +} + +func TestExecuteMemoryRecall_BeforeMemoryLoadAbort(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + Hooks: []hooks.HookGroup{{ + Name: "block", + Hooks: hooks.AgentHooks{BeforeMemoryLoad: []hooks.BeforeMemoryLoadHook{ + func(context.Context, hooks.BeforeMemoryLoadHookInput) (hooks.BeforeMemoryLoadHookOutput, error) { + return hooks.BeforeMemoryLoadHookOutput{}, errors.New("blocked") + }, + }}, + }}, + }) + + _, err := rt.ExecuteMemoryRecall(context.Background(), ExecuteMemoryRecallInput{ + Logger: noopLog(), RunID: "run-1", Iteration: 0, + Scope: interfaces.MemoryScope{UserID: "u1"}, Query: "q", + }) + require.Error(t, err) + require.Equal(t, "blocked", err.Error()) +} + func TestSubAgentScope(t *testing.T) { parent := interfaces.MemoryScope{ TenantID: "t1", @@ -1182,8 +1892,10 @@ func TestExecuteMemoryRecall_EmitsMetrics(t *testing.T) { rt.Metrics = metrics ctx := context.Background() - require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, testRunMessages("hello", "world"))) - _, err := rt.ExecuteMemoryRecall(ctx, noopLog(), scope, "hello") + require.NoError(t, executeMemoryStore(rt, ctx, scope, testRunMessages("hello", "world"))) + _, err := rt.ExecuteMemoryRecall(ctx, ExecuteMemoryRecallInput{ + Logger: noopLog(), RunID: "run-1", Iteration: 0, Scope: scope, Query: "hello", + }) require.NoError(t, err) } @@ -1592,7 +2304,7 @@ func TestExecuteRetrievers_EmptyDocsSkipped(t *testing.T) { rt := newTestRuntime(sdkruntime.AgentConfig{ Retrievers: sdkruntime.AgentRetrievers{Retrievers: []interfaces.Retriever{r}}, }) - got, err := rt.ExecuteRetrievers(context.Background(), noopLog(), "q") + got, err := rt.ExecuteRetrievers(context.Background(), ExecuteRetrieversInput{Logger: noopLog(), Query: "q"}) require.NoError(t, err) require.Equal(t, "", got.Context) require.Equal(t, int64(1), got.TotalSearches) diff 
--git a/internal/runtime/base/types.go b/internal/runtime/base/types.go index fa21493..c25668a 100644 --- a/internal/runtime/base/types.go +++ b/internal/runtime/base/types.go @@ -3,6 +3,7 @@ package base import ( "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" ) const scopeKeyParentAgentID = "parent_agent_id" @@ -42,9 +43,60 @@ type RetrieverResult struct { FailedSearches int64 } +// ExecuteRetrieversInput holds per-invocation inputs for [Runtime.ExecuteRetrievers]. +// RunID and Iteration populate [hooks.RunMeta] for retrieve middleware hooks. +type ExecuteRetrieversInput struct { + Logger logger.Logger + RunID string + Iteration int + Query string +} + // MemoryResult is the outcome of ExecuteMemoryRecall. type MemoryResult struct { Context string TotalRecalls int64 FailedRecalls int64 } + +// ExecuteMemoryRecallInput holds per-invocation inputs for [Runtime.ExecuteMemoryRecall]. +// RunID and Iteration populate [hooks.RunMeta] for memory load middleware hooks. +type ExecuteMemoryRecallInput struct { + Logger logger.Logger + RunID string + Iteration int + Scope interfaces.MemoryScope + Query string +} + +// StoreMemoryRecordsInput holds per-invocation inputs for [Runtime.StoreMemoryRecords]. +// RunID and Iteration populate [hooks.RunMeta] for memory store middleware hooks. +type StoreMemoryRecordsInput struct { + Logger logger.Logger + RunID string + Iteration int + Scope interfaces.MemoryScope + Records []interfaces.MemoryRecord +} + +// ExecuteMemoryStoreInput holds per-invocation inputs for [Runtime.ExecuteMemoryStore]. +// RunID and Iteration populate [hooks.RunMeta] for memory store middleware hooks. +type ExecuteMemoryStoreInput struct { + Logger logger.Logger + RunID string + Iteration int + Scope interfaces.MemoryScope + Messages []interfaces.Message +} + +// ExecuteToolInput holds per-invocation inputs for [Runtime.ExecuteTool]. 
+// RunID and Iteration populate [hooks.RunMeta] for tool middleware hooks. +type ExecuteToolInput struct { + Logger logger.Logger + Tools []interfaces.Tool + ToolName string + Args map[string]any + ToolCallID string + RunID string + Iteration int +} diff --git a/internal/runtime/local/agent_loop.go b/internal/runtime/local/agent_loop.go index 5d405e7..11e3c12 100644 --- a/internal/runtime/local/agent_loop.go +++ b/internal/runtime/local/agent_loop.go @@ -46,6 +46,8 @@ type AgentLoopInput struct { MaxSubAgentDepth int // MemoryScope is resolved before the run and used for recall/store. MemoryScope interfaces.MemoryScope + // RunID is the stable identifier for this agent run; passed to LLM hooks via [RunMeta]. + RunID string } // AgentLoopResult is the outcome of a completed local agent run. @@ -138,7 +140,13 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) memoryContext := "" if rt.MemoryConfigured() && rt.RecallEnabled() { log.Debug(ctx, "local: memory recall started", slog.String("scope", "loop")) - res, err := rt.ExecuteMemoryRecall(ctx, log, input.MemoryScope, input.UserPrompt) + res, err := rt.ExecuteMemoryRecall(ctx, base.ExecuteMemoryRecallInput{ + Logger: log, + RunID: input.RunID, + Iteration: 0, + Scope: input.MemoryScope, + Query: input.UserPrompt, + }) if err != nil { return nil, fmt.Errorf("memory recall: %w", err) } @@ -159,7 +167,12 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) slog.String("scope", "loop"), slog.String("mode", string(retrieverMode)), slog.Int("retrieverCount", len(rt.AgentConfig.Retrievers.Retrievers))) - res, err := rt.ExecuteRetrievers(ctx, log, input.UserPrompt) + res, err := rt.ExecuteRetrievers(ctx, base.ExecuteRetrieversInput{ + Logger: log, + RunID: input.RunID, + Iteration: 0, + Query: input.UserPrompt, + }) if err != nil { return nil, fmt.Errorf("retriever prefetch: %w", err) } @@ -188,6 +201,8 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, 
input AgentLoopInput) Logger: log, AgentName: agentName, MessageID: messageID, + RunID: input.RunID, + Iteration: iter, Messages: messages, SkipTools: false, MemoryContext: memoryContext, @@ -227,6 +242,8 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) Logger: log, AgentName: agentName, MessageID: finalMessageID, + RunID: input.RunID, + Iteration: iter, Messages: messages, SkipTools: true, MemoryContext: memoryContext, @@ -272,9 +289,9 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) var toolResults []toolResult switch toolExecMode { case types.AgentToolExecutionModeParallel: - toolResults, err = rt.executeToolsParallel(ctx, input, messageID, llmResult.ToolCalls, emit) + toolResults, err = rt.executeToolsParallel(ctx, input, messageID, iter, llmResult.ToolCalls, emit) case types.AgentToolExecutionModeSequential: - toolResults, err = rt.executeToolsSequential(ctx, input, messageID, llmResult.ToolCalls, emit) + toolResults, err = rt.executeToolsSequential(ctx, input, messageID, iter, llmResult.ToolCalls, emit) default: return nil, fmt.Errorf("invalid tool execution mode %q: use %q or %q", toolExecMode, @@ -319,7 +336,13 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) } if rt.RunEndMemoryStoreEnabled() { - if err := rt.ExecuteMemoryStore(ctx, log, input.MemoryScope, messages); err != nil { + if err := rt.ExecuteMemoryStore(ctx, base.ExecuteMemoryStoreInput{ + Logger: log, + RunID: input.RunID, + Iteration: 0, + Scope: input.MemoryScope, + Messages: messages, + }); err != nil { log.Warn(ctx, "local: memory store failed", slog.String("scope", "loop"), slog.Any("error", err)) telemetry.Storage.FailedMemoryStores++ } else { @@ -353,6 +376,7 @@ func (rt *LocalRuntime) executeToolsParallel( ctx context.Context, input AgentLoopInput, messageID string, + iteration int, toolCalls []base.ToolCallRequest, emit func(events.AgentEvent), ) ([]toolResult, error) { @@ -367,7 +391,7 
@@ func (rt *LocalRuntime) executeToolsParallel( wg.Add(1) go func(idx int, tc base.ToolCallRequest) { defer wg.Done() - result, err := rt.executeSingleTool(ctx, input, messageID, tc, emit) + result, err := rt.executeSingleTool(ctx, input, messageID, iteration, tc, emit) if err != nil { rt.logger.Info(ctx, "local: parallel tool failed", slog.String("scope", "loop"), @@ -396,6 +420,7 @@ func (rt *LocalRuntime) executeToolsSequential( ctx context.Context, input AgentLoopInput, messageID string, + iteration int, toolCalls []base.ToolCallRequest, emit func(events.AgentEvent), ) ([]toolResult, error) { @@ -405,7 +430,7 @@ func (rt *LocalRuntime) executeToolsSequential( results := make([]toolResult, len(toolCalls)) for idx, tc := range toolCalls { - result, err := rt.executeSingleTool(ctx, input, messageID, tc, emit) + result, err := rt.executeSingleTool(ctx, input, messageID, iteration, tc, emit) if err != nil { rt.logger.Info(ctx, "local: sequential tool failed", slog.String("scope", "loop"), @@ -432,6 +457,7 @@ func (rt *LocalRuntime) executeSingleTool( ctx context.Context, input AgentLoopInput, messageID string, + iteration int, tc base.ToolCallRequest, emit func(events.AgentEvent), ) (toolResult, error) { @@ -588,6 +614,7 @@ func (rt *LocalRuntime) executeSingleTool( emit(events.NewAgentStepStartedEvent(stepName)) subResult, execErr := subAgentRoute.runtime.RunAgentLoop(ctx, AgentLoopInput{ UserPrompt: query, + RunID: uuid.New().String(), StreamingEnabled: input.StreamingEnabled, ChannelName: input.ChannelName, ApprovalHandler: input.ApprovalHandler, @@ -611,7 +638,15 @@ func (rt *LocalRuntime) executeSingleTool( slog.String("scope", "loop"), slog.String("tool", tc.ToolName), slog.String("toolCallID", tc.ToolCallID)) - result, execErr := rt.ExecuteToolWithMemoryScope(ctx, log, tools, tc.ToolName, tc.Args, input.MemoryScope) + result, execErr := rt.ExecuteTool(ctx, base.ExecuteToolInput{ + Logger: log, + Tools: tools, + ToolName: tc.ToolName, + Args: tc.Args, + 
ToolCallID: tc.ToolCallID, + RunID: input.RunID, + Iteration: iteration, + }, input.MemoryScope) if execErr != nil { content = "Tool execution failed: " + execErr.Error() failed = true diff --git a/internal/runtime/local/agent_loop_test.go b/internal/runtime/local/agent_loop_test.go index cd00207..e25555e 100644 --- a/internal/runtime/local/agent_loop_test.go +++ b/internal/runtime/local/agent_loop_test.go @@ -555,7 +555,7 @@ func TestExecuteToolsParallel_AllSucceed(t *testing.T) { testToolCall("c2", "t2"), } - msgs, err := rt.executeToolsParallel(context.Background(), loopToolsInput(tools), "msg-1", calls, noopEmit) + msgs, err := rt.executeToolsParallel(context.Background(), loopToolsInput(tools), "msg-1", 0, calls, noopEmit) require.NoError(t, err) require.Len(t, msgs, 2) // Order must match submission order (parallel but results are indexed). @@ -569,7 +569,7 @@ func TestExecuteToolsParallel_ToolErrorInMessage(t *testing.T) { rt, tools := newLoopRT(t, 5, &seqLLMClient{}, failing) calls := []base.ToolCallRequest{testToolCall("c1", "bad")} - msgs, err := rt.executeToolsParallel(context.Background(), loopToolsInput(tools), "msg", calls, noopEmit) + msgs, err := rt.executeToolsParallel(context.Background(), loopToolsInput(tools), "msg", 0, calls, noopEmit) require.NoError(t, err) // parallel swallows into message require.Len(t, msgs, 1) require.Contains(t, msgs[0].message.Content, "boom") @@ -590,7 +590,7 @@ func TestExecuteToolsParallel_ResultsOrderPreserved(t *testing.T) { testToolCall("2", "b"), testToolCall("3", "c"), } - msgs, err := rt.executeToolsParallel(context.Background(), loopToolsInput(tools), "m", calls, noopEmit) + msgs, err := rt.executeToolsParallel(context.Background(), loopToolsInput(tools), "m", 0, calls, noopEmit) require.NoError(t, err) require.Equal(t, []string{"A", "B", "C"}, []string{msgs[0].message.Content, msgs[1].message.Content, msgs[2].message.Content}) } @@ -608,7 +608,7 @@ func TestExecuteToolsSequential_AllSucceed(t *testing.T) { 
testToolCall("c1", "s1"), testToolCall("c2", "s2"), } - msgs, err := rt.executeToolsSequential(context.Background(), loopToolsInput(tools), "msg", calls, noopEmit) + msgs, err := rt.executeToolsSequential(context.Background(), loopToolsInput(tools), "msg", 0, calls, noopEmit) require.NoError(t, err) require.Len(t, msgs, 2) require.Equal(t, "v1", msgs[0].message.Content) @@ -621,7 +621,7 @@ func TestExecuteToolsSequential_HardErrorOnContextCancel(t *testing.T) { cancel() // pre-cancelled calls := []base.ToolCallRequest{testToolCall("c1", "missing-tool")} - results, err := rt.executeToolsSequential(ctx, AgentLoopInput{}, "msg", calls, noopEmit) + results, err := rt.executeToolsSequential(ctx, AgentLoopInput{}, "msg", 0, calls, noopEmit) require.NoError(t, err) require.Len(t, results, 1) require.True(t, results[0].failed) @@ -637,7 +637,7 @@ func TestExecuteSingleTool_Approved(t *testing.T) { rt, tools := newLoopRT(t, 5, &seqLLMClient{}, tool) emit, evs := captureEmit() - msg, err := rt.executeSingleTool(context.Background(), loopToolsInput(tools), "msg-1", + msg, err := rt.executeSingleTool(context.Background(), loopToolsInput(tools), "msg-1", 0, testToolCall("c1", "my-tool"), emit) require.NoError(t, err) @@ -655,7 +655,7 @@ func TestExecuteSingleTool_ToolExecError(t *testing.T) { tool := stubTool{name: "boom", execErr: errors.New("exec failed")} rt, tools := newLoopRT(t, 5, &seqLLMClient{}, tool) - msg, err := rt.executeSingleTool(context.Background(), loopToolsInput(tools), "msg", + msg, err := rt.executeSingleTool(context.Background(), loopToolsInput(tools), "msg", 0, testToolCall("c1", "boom"), noopEmit) require.NoError(t, err) // tool errors become a content message, not a hard error require.Contains(t, msg.message.Content, "exec failed") @@ -665,7 +665,7 @@ func TestExecuteSingleTool_ToolExecError(t *testing.T) { func TestExecuteSingleTool_UnknownToolErrors(t *testing.T) { rt, _ := newLoopRT(t, 5, &seqLLMClient{}) // no tools registered - _, err := 
rt.executeSingleTool(context.Background(), AgentLoopInput{}, "msg", + _, err := rt.executeSingleTool(context.Background(), AgentLoopInput{}, "msg", 0, testToolCall("c1", "ghost"), noopEmit) require.Error(t, err) require.Contains(t, err.Error(), "ghost") @@ -686,7 +686,7 @@ func TestExecuteSingleTool_AuthorizationDenied(t *testing.T) { authTool := authorizerStubLocal{name: "restricted", allow: false, reason: "policy denied"} rt, tools := newLoopRT(t, 5, &seqLLMClient{}, authTool) - msg, err := rt.executeSingleTool(context.Background(), loopToolsInput(tools), "msg", + msg, err := rt.executeSingleTool(context.Background(), loopToolsInput(tools), "msg", 0, testToolCall("c1", "restricted"), noopEmit) require.NoError(t, err) require.Contains(t, msg.message.Content, msgToolUnauthorized) @@ -698,7 +698,7 @@ func TestExecuteSingleTool_AuthorizationError(t *testing.T) { authTool := authorizerStubLocal{name: "err-tool", allow: false, authErr: errors.New("auth backend down")} rt, tools := newLoopRT(t, 5, &seqLLMClient{}, authTool) - _, err := rt.executeSingleTool(context.Background(), loopToolsInput(tools), "msg", + _, err := rt.executeSingleTool(context.Background(), loopToolsInput(tools), "msg", 0, testToolCall("c1", "err-tool"), noopEmit) require.Error(t, err) require.Contains(t, err.Error(), "auth backend down") @@ -710,7 +710,7 @@ func TestExecuteSingleTool_ApprovalUnavailable(t *testing.T) { rt, tools := newLoopRT(t, 5, &seqLLMClient{}, tool) msg, err := rt.executeSingleTool(context.Background(), - AgentLoopInput{ChannelName: "", ApprovalHandler: nil, Tools: tools}, "msg", + AgentLoopInput{ChannelName: "", ApprovalHandler: nil, Tools: tools}, "msg", 0, testToolCallNeedsApproval("c1", "guarded"), noopEmit) require.NoError(t, err) require.Contains(t, msg.message.Content, msgToolApprovalUnavailable) @@ -725,7 +725,7 @@ func TestExecuteSingleTool_ApprovalHandlerApproves(t *testing.T) { } msg, err := rt.executeSingleTool(context.Background(), - AgentLoopInput{ApprovalHandler: 
handler, Tools: tools}, "msg", + AgentLoopInput{ApprovalHandler: handler, Tools: tools}, "msg", 0, testToolCallNeedsApproval("c1", "guarded"), noopEmit) require.NoError(t, err) require.Equal(t, "ok", msg.message.Content) @@ -740,7 +740,7 @@ func TestExecuteSingleTool_ApprovalHandlerRejects(t *testing.T) { } msg, err := rt.executeSingleTool(context.Background(), - AgentLoopInput{ApprovalHandler: handler, Tools: tools}, "msg", + AgentLoopInput{ApprovalHandler: handler, Tools: tools}, "msg", 0, testToolCallNeedsApproval("c1", "guarded"), noopEmit) require.NoError(t, err) require.Equal(t, msgToolRejected, msg.message.Content) @@ -789,7 +789,7 @@ func TestExecuteSingleTool_StreamingApproveUnblocks(t *testing.T) { result, resultErr = rt.executeSingleTool( context.Background(), AgentLoopInput{ChannelName: "some-channel", Tools: tools}, // streaming path - "msg", + "msg", 0, testToolCallNeedsApproval("c1", "guarded"), emit, ) @@ -826,7 +826,7 @@ func TestExecuteSingleTool_ApprovalContextCancel(t *testing.T) { }() _, err := rt.executeSingleTool(ctx, - AgentLoopInput{ChannelName: "some-channel", Tools: tools}, "msg", + AgentLoopInput{ChannelName: "some-channel", Tools: tools}, "msg", 0, testToolCallNeedsApproval("c1", "guarded"), noopEmit) <-done diff --git a/internal/runtime/local/runtime.go b/internal/runtime/local/runtime.go index a4931ed..1450354 100644 --- a/internal/runtime/local/runtime.go +++ b/internal/runtime/local/runtime.go @@ -131,6 +131,7 @@ func (rt *LocalRuntime) Execute(ctx context.Context, req *sdkruntime.ExecuteRequ loopResult, err := rt.RunAgentLoop(runCtx, AgentLoopInput{ UserPrompt: req.UserPrompt, + RunID: runID, ConversationID: conversationID, MemoryScope: memoryScope, StreamingEnabled: false, @@ -145,7 +146,6 @@ func (rt *LocalRuntime) Execute(ctx context.Context, req *sdkruntime.ExecuteRequ return nil, err } - _ = runID return &types.AgentRunResult{ Content: loopResult.Content, AgentName: strings.TrimSpace(agentName), @@ -228,6 +228,7 @@ func (rt 
*LocalRuntime) ExecuteStream(ctx context.Context, req *sdkruntime.Execu }() result, loopErr := rt.RunAgentLoop(runCtx, AgentLoopInput{ UserPrompt: req.UserPrompt, + RunID: runID, ConversationID: conversationID, MemoryScope: memoryScope, StreamingEnabled: req.StreamingEnabled, diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index be593e3..4e1d71a 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -10,6 +10,7 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/eventbus" "github.com/agenticenv/agent-sdk-go/internal/events" + "github.com/agenticenv/agent-sdk-go/internal/hooks" "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/memory" @@ -92,7 +93,7 @@ type AgentSpec struct { ResponseFormat *interfaces.ResponseFormat } -// AgentConfig is static agent wiring on the runtime at construction: LLM client, tool approval policy, session, limits, and retriever config. +// AgentConfig is static agent wiring on the runtime at construction: LLM client, tool approval policy, session, limits, retriever config, and hooks. type AgentConfig struct { LLM AgentLLM ToolApprovalPolicy interfaces.AgentToolApprovalPolicy @@ -100,6 +101,7 @@ type AgentConfig struct { Session AgentSession Memory AgentMemory Limits AgentLimits + Hooks []hooks.HookGroup } // AgentMemory holds long-term memory configuration for recall and store. 
diff --git a/internal/runtime/temporal/agent_workflow.go b/internal/runtime/temporal/agent_workflow.go index 63e6b58..03b10b5 100644 --- a/internal/runtime/temporal/agent_workflow.go +++ b/internal/runtime/temporal/agent_workflow.go @@ -125,6 +125,7 @@ type AgentWorkflowInput struct { StreamingEnabled bool `json:"streaming_enabled,omitempty"` ConversationID string `json:"conversation_id,omitempty"` AgentFingerprint string `json:"agent_fingerprint,omitempty"` + RunID string `json:"run_id,omitempty"` MemoryScope interfaces.MemoryScope `json:"memory_scope,omitempty"` EventTypes []events.AgentEventType `json:"event_types,omitempty"` SubAgentDepth int `json:"sub_agent_depth,omitempty"` @@ -145,6 +146,7 @@ type AgentWorkflowState struct { // AgentRetrieverInput is the input to AgentRetrieverActivity. type AgentRetrieverInput struct { AgentFingerprint string `json:"agent_fingerprint,omitempty"` + RunID string `json:"run_id,omitempty"` UserPrompt string `json:"user_prompt"` } @@ -160,6 +162,7 @@ type AgentRetrieverResult struct { // AgentMemoryRecallInput is the input to AgentMemoryRecallActivity. type AgentMemoryRecallInput struct { AgentFingerprint string `json:"agent_fingerprint,omitempty"` + RunID string `json:"run_id,omitempty"` UserPrompt string `json:"user_prompt"` MemoryScope interfaces.MemoryScope `json:"memory_scope,omitempty"` } @@ -174,6 +177,7 @@ type AgentMemoryRecallResult struct { // AgentMemoryStoreInput is the input to AgentMemoryStoreActivity. 
type AgentMemoryStoreInput struct { AgentFingerprint string `json:"agent_fingerprint,omitempty"` + RunID string `json:"run_id,omitempty"` MemoryScope interfaces.MemoryScope `json:"memory_scope,omitempty"` Messages []interfaces.Message `json:"messages,omitempty"` } @@ -195,6 +199,8 @@ type AgentLLMInput struct { LocalChannelName string `json:"local_channel_name,omitempty"` MemoryContext string `json:"memory_context,omitempty"` RetrieverContext string `json:"retriever_context,omitempty"` + RunID string `json:"run_id,omitempty"` + Iteration int `json:"iteration,omitempty"` } // AgentLLMResult is the return value of AgentLLMActivity. Workflow uses it to decide: return content or execute tools. @@ -241,6 +247,7 @@ type agentToolCallInput struct { wfCtx workflow.Context input AgentWorkflowInput messageID string + iteration int emitEvent func(events.AgentEvent) error authorizeCtx workflow.Context approvalCtx workflow.Context @@ -267,6 +274,8 @@ type AgentToolExecuteInput struct { ConversationID string `json:"conversation_id,omitempty"` Messages []interfaces.Message `json:"messages,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"` + RunID string `json:"run_id,omitempty"` + Iteration int `json:"iteration,omitempty"` AgentFingerprint string `json:"agent_fingerprint,omitempty"` MemoryScope interfaces.MemoryScope `json:"memory_scope,omitempty"` } @@ -447,6 +456,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl var memoryResult AgentMemoryRecallResult if err := workflow.ExecuteActivity(memoryActCtx, rt.AgentMemoryRecallActivity, AgentMemoryRecallInput{ AgentFingerprint: input.AgentFingerprint, + RunID: input.RunID, UserPrompt: input.UserPrompt, MemoryScope: input.MemoryScope, }).Get(memoryActCtx, &memoryResult); err != nil { @@ -471,6 +481,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl logger.Debug("workflow: retriever prefetch started", "scope", "workflow", "retrieverMode", 
string(retrieverMode), "retrieverCount", len(rt.AgentConfig.Retrievers.Retrievers)) retrieverInput := AgentRetrieverInput{ AgentFingerprint: input.AgentFingerprint, + RunID: input.RunID, UserPrompt: input.UserPrompt, } var retrieverResult AgentRetrieverResult @@ -499,6 +510,8 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl Messages: messages, AgentFingerprint: input.AgentFingerprint, MessageID: messageID, + RunID: input.RunID, + Iteration: iter, EventWorkflowID: eventWorkflowID, EventTaskQueue: eventTaskQueue, LocalChannelName: input.LocalChannelName, @@ -600,7 +613,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl "toolName", tc.ToolName, "toolCallID", tc.ToolCallID) slot := strconv.Itoa(i) - parallelInput := rt.newAgentToolCallInput(gCtx, input, activityIDSuffix, messageID, emitAgentEvent, slot) + parallelInput := rt.newAgentToolCallInput(gCtx, input, activityIDSuffix, messageID, iter, emitAgentEvent, slot) toolOutput, runErr := rt.executeAgentToolCall(parallelInput, tc, streamingUnavailable) if runErr != nil { gLog.Debug("workflow: parallel tool branch finished with error", @@ -667,7 +680,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl "scope", "workflow", "executionMode", string(types.AgentToolExecutionModeSequential), "toolCount", len(llmResult.ToolCalls)) - toolInput := rt.newAgentToolCallInput(ctx, input, activityIDSuffix, messageID, emitAgentEvent, "") + toolInput := rt.newAgentToolCallInput(ctx, input, activityIDSuffix, messageID, iter, emitAgentEvent, "") toolResults = make([]agentToolResult, len(llmResult.ToolCalls)) for i, tc := range llmResult.ToolCalls { logger.Debug("workflow: sequential tool executing", @@ -788,6 +801,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl if rt.RunEndMemoryStoreEnabled() { if err := workflow.ExecuteActivity(memoryActCtx, rt.AgentMemoryStoreActivity, AgentMemoryStoreInput{ 
AgentFingerprint: input.AgentFingerprint, + RunID: input.RunID, MemoryScope: input.MemoryScope, Messages: messages, }).Get(memoryActCtx, nil); err != nil { @@ -828,6 +842,7 @@ func (rt *TemporalRuntime) newAgentToolCallInput( wfCtx workflow.Context, input AgentWorkflowInput, activityIDSuffix, messageID string, + iteration int, emitAgentEvent func(workflow.Context, events.AgentEvent) error, parallelSlot string, ) agentToolCallInput { @@ -846,6 +861,7 @@ func (rt *TemporalRuntime) newAgentToolCallInput( wfCtx: wfCtx, input: input, messageID: messageID, + iteration: iteration, emitEvent: func(ev events.AgentEvent) error { return emitAgentEvent(wfCtx, ev) }, @@ -986,6 +1002,8 @@ func (rt *TemporalRuntime) executeAgentToolCall(input agentToolCallInput, tc Too Args: tc.Args, ConversationID: input.input.ConversationID, ToolCallID: tc.ToolCallID, + RunID: input.input.RunID, + Iteration: input.iteration, AgentFingerprint: input.input.AgentFingerprint, MemoryScope: input.input.MemoryScope, } @@ -1097,6 +1115,8 @@ func (rt *TemporalRuntime) AgentLLMStreamActivity(ctx context.Context, input Age Logger: actLog, AgentName: agentName, MessageID: input.MessageID, + RunID: input.RunID, + Iteration: input.Iteration, Messages: messages, SkipTools: input.SkipTools, RetrieverContext: input.RetrieverContext, @@ -1122,7 +1142,12 @@ func (rt *TemporalRuntime) AgentRetrieverActivity(ctx context.Context, input Age return nil, err } actLog := newActivityLogger(activity.GetLogger(ctx)) - res, err := rt.ExecuteRetrievers(ctx, actLog, input.UserPrompt) + res, err := rt.ExecuteRetrievers(ctx, base.ExecuteRetrieversInput{ + Logger: actLog, + RunID: input.RunID, + Iteration: 0, + Query: input.UserPrompt, + }) if err != nil { return nil, err } @@ -1139,7 +1164,13 @@ func (rt *TemporalRuntime) AgentMemoryRecallActivity(ctx context.Context, input return nil, err } actLog := newActivityLogger(activity.GetLogger(ctx)) - res, err := rt.ExecuteMemoryRecall(ctx, actLog, input.MemoryScope, 
input.UserPrompt) + res, err := rt.ExecuteMemoryRecall(ctx, base.ExecuteMemoryRecallInput{ + Logger: actLog, + RunID: input.RunID, + Iteration: 0, + Scope: input.MemoryScope, + Query: input.UserPrompt, + }) if err != nil { return nil, err } @@ -1156,7 +1187,13 @@ func (rt *TemporalRuntime) AgentMemoryStoreActivity(ctx context.Context, input A return err } actLog := newActivityLogger(activity.GetLogger(ctx)) - return rt.ExecuteMemoryStore(ctx, actLog, input.MemoryScope, input.Messages) + return rt.ExecuteMemoryStore(ctx, base.ExecuteMemoryStoreInput{ + Logger: actLog, + RunID: input.RunID, + Iteration: 0, + Scope: input.MemoryScope, + Messages: input.Messages, + }) } // AgentLLMActivity calls the LLM and returns content plus any tool calls. @@ -1189,6 +1226,8 @@ func (rt *TemporalRuntime) AgentLLMActivity(ctx context.Context, input AgentLLMI Logger: actLog, AgentName: agentName, MessageID: input.MessageID, + RunID: input.RunID, + Iteration: input.Iteration, Messages: messages, SkipTools: input.SkipTools, RetrieverContext: input.RetrieverContext, @@ -1367,7 +1406,15 @@ func (rt *TemporalRuntime) AgentToolExecuteActivity(ctx context.Context, input A stopHB := startLongActivityHeartbeats(ctx) defer stopHB() actLog := newActivityLogger(activity.GetLogger(ctx)) - return rt.ExecuteToolWithMemoryScope(ctx, actLog, tools, input.ToolName, input.Args, input.MemoryScope) + return rt.ExecuteTool(ctx, base.ExecuteToolInput{ + Logger: actLog, + Tools: tools, + ToolName: input.ToolName, + Args: input.Args, + ToolCallID: input.ToolCallID, + RunID: input.RunID, + Iteration: input.Iteration, + }, input.MemoryScope) } // AgentToolAuthorizeActivity checks optional programmatic authorization before approval/execute. 
@@ -1412,8 +1459,18 @@ func (rt *TemporalRuntime) delegateToSubAgent(ctx workflow.Context, input AgentW if subAgentID == "" { subAgentID = tc.ToolName } + + var childSuffix string + if err := workflow.SideEffect(ctx, func(workflow.Context) interface{} { + return uuid.New().String() + }).Get(&childSuffix); err != nil { + logger.Warn("workflow: sub-agent child run id failed", "scope", "workflow", "error", err) + return "", err + } + childInput := AgentWorkflowInput{ UserPrompt: query, + RunID: childSuffix, EventWorkflowID: input.EventWorkflowID, EventTaskQueue: input.EventTaskQueue, LocalChannelName: input.LocalChannelName, @@ -1426,14 +1483,6 @@ func (rt *TemporalRuntime) delegateToSubAgent(ctx workflow.Context, input AgentW SubAgentRoutes: route.ChildRoutes, } - var childSuffix string - if err := workflow.SideEffect(ctx, func(workflow.Context) interface{} { - return uuid.New().String() - }).Get(&childSuffix); err != nil { - logger.Warn("workflow: sub-agent child run id failed", "scope", "workflow", "error", err) - return "", err - } - parentID := workflow.GetInfo(ctx).WorkflowExecution.ID childWfID := fmt.Sprintf("%s-sub-%s-%s", parentID, tc.ToolCallID, childSuffix) childTO := rt.subAgentChildWorkflowTimeout() diff --git a/internal/runtime/temporal/fingerprint.go b/internal/runtime/temporal/fingerprint.go index e026dc9..a911cbb 100644 --- a/internal/runtime/temporal/fingerprint.go +++ b/internal/runtime/temporal/fingerprint.go @@ -55,6 +55,10 @@ type AgentFingerprintPayload struct { // Omitted when empty. Must match pkg/agent [retrieverConfigFingerprint] on caller and worker. RetrieverFingerprint string `json:"retriever_fingerprint,omitempty"` + // HooksFingerprint is the pkg/agent digest of registered hook group names (sorted). + // Omitted when empty. Must match pkg/agent [hookGroupsFingerprint] on caller and worker. 
+ HooksFingerprint string `json:"hooks_fingerprint,omitempty"` + Sampling *sdkruntime.LLMSampling `json:"sampling,omitempty"` SessionSize int `json:"session_size"` @@ -94,6 +98,7 @@ func BuildAgentFingerprintPayload( agentMode string, agentToolExecutionMode types.AgentToolExecutionMode, retrieverFingerprint string, + hooksFingerprint string, ) AgentFingerprintPayload { names := append([]string(nil), toolNames...) sort.Strings(names) @@ -117,6 +122,7 @@ func BuildAgentFingerprintPayload( AgentMode: mode, AgentToolExecutionMode: string(toolExecutionMode), RetrieverFingerprint: retrieverFingerprint, + HooksFingerprint: hooksFingerprint, Sampling: cloneLLMSampling(sampling), SessionSize: sessionSize, MaxIterations: limits.MaxIterations, @@ -197,6 +203,7 @@ func computeAgentFingerprintFromRuntime(rt *TemporalRuntime, tools []interfaces. rt.agentMode, rt.ToolExecutionMode, rt.retrieverFingerprint, + rt.hooksFingerprint, ) return ComputeAgentFingerprint(mat) } diff --git a/internal/runtime/temporal/fingerprint_test.go b/internal/runtime/temporal/fingerprint_test.go index a4e4c0d..088f45b 100644 --- a/internal/runtime/temporal/fingerprint_test.go +++ b/internal/runtime/temporal/fingerprint_test.go @@ -23,8 +23,8 @@ func (f fpTool) Execute(ctx context.Context, args map[string]any) (any, error) { func TestComputeAgentFingerprint_toolOrderStable(t *testing.T) { spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} lim := sdkruntime.AgentLimits{MaxIterations: 3} - hA := ComputeAgentFingerprint(BuildAgentFingerprintPayload(spec, []string{"a", "b", "c"}, "auto", nil, 0, lim, "", "", "", "", "", "")) - hB := ComputeAgentFingerprint(BuildAgentFingerprintPayload(spec, []string{"c", "a", "b"}, "auto", nil, 0, lim, "", "", "", "", "", "")) + hA := ComputeAgentFingerprint(BuildAgentFingerprintPayload(spec, []string{"a", "b", "c"}, "auto", nil, 0, lim, "", "", "", "", "", "", "")) + hB := ComputeAgentFingerprint(BuildAgentFingerprintPayload(spec, []string{"c", "a", "b"}, "auto", 
nil, 0, lim, "", "", "", "", "", "", "")) if hA != hB { t.Fatalf("tool order should not matter: %q vs %q", hA, hB) } @@ -33,8 +33,8 @@ func TestComputeAgentFingerprint_toolOrderStable(t *testing.T) { func TestComputeAgentFingerprint_stableWithoutTools(t *testing.T) { spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} lim := sdkruntime.AgentLimits{MaxIterations: 3} - interactive := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "", "") - autonomous := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "autonomous", "", "") + interactive := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "", "", "") + autonomous := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "autonomous", "", "", "") if ComputeAgentFingerprint(interactive) == ComputeAgentFingerprint(autonomous) { t.Fatal("expected different digests for autonomous vs interactive") } @@ -44,8 +44,8 @@ func TestComputeAgentFingerprint_mcpFingerprintChangesDigest(t *testing.T) { spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} lim := sdkruntime.AgentLimits{MaxIterations: 3} tools := []string{"mcp_srv_echo"} - base := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "", "", "", "") - withMCP := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "abc123deadbeef", "", "", "", "", "") + base := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "", "", "", "", "") + withMCP := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "abc123deadbeef", "", "", "", "", "", "") h0 := ComputeAgentFingerprint(base) h1 := ComputeAgentFingerprint(withMCP) if h0 == h1 { @@ -57,8 +57,8 @@ func TestComputeAgentFingerprint_a2aFingerprintChangesDigest(t *testing.T) { spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} lim := sdkruntime.AgentLimits{MaxIterations: 3} tools := []string{"a2a_remote_echo"} - base := BuildAgentFingerprintPayload(spec, 
tools, "auto", nil, 0, lim, "", "", "", "", "", "") - withA2A := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "a2afp_deadbeef", "", "", "", "") + base := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "", "", "", "", "") + withA2A := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "a2afp_deadbeef", "", "", "", "", "") h0 := ComputeAgentFingerprint(base) h1 := ComputeAgentFingerprint(withA2A) if h0 == h1 { @@ -69,8 +69,8 @@ func TestComputeAgentFingerprint_a2aFingerprintChangesDigest(t *testing.T) { func TestComputeAgentFingerprint_retrieverFingerprintChangesDigest(t *testing.T) { spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} lim := sdkruntime.AgentLimits{MaxIterations: 3} - empty := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "", "") - withFP := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "", "retriever_fp_deadbeef") + empty := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "", "", "") + withFP := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "", "retriever_fp_deadbeef", "") if ComputeAgentFingerprint(empty) == ComputeAgentFingerprint(withFP) { t.Fatal("expected different digests when retriever fingerprint set") } @@ -80,8 +80,8 @@ func TestComputeAgentFingerprint_observabilityFingerprintChangesDigest(t *testin spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} lim := sdkruntime.AgentLimits{MaxIterations: 3} tools := []string{"t1"} - base := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "", "", "", "") - withObs := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "obs_deadbeef", "", "", "") + base := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "", "", "", "", "") + withObs := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "obs_deadbeef", "", "", "", "") h0 := 
ComputeAgentFingerprint(base) h1 := ComputeAgentFingerprint(withObs) if h0 == h1 { @@ -89,6 +89,16 @@ func TestComputeAgentFingerprint_observabilityFingerprintChangesDigest(t *testin } } +func TestComputeAgentFingerprint_hooksFingerprintChangesDigest(t *testing.T) { + spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} + lim := sdkruntime.AgentLimits{MaxIterations: 3} + empty := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "", "", "") + withHooks := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "", "", "hooks_fp_deadbeef") + if ComputeAgentFingerprint(empty) == ComputeAgentFingerprint(withHooks) { + t.Fatal("expected different digests when hooks fingerprint set") + } +} + func TestVerifyAgentFingerprint_mismatch(t *testing.T) { rt := &TemporalRuntime{ resolveToolsFn: func(context.Context) ([]interfaces.Tool, error) { @@ -149,7 +159,7 @@ func TestBuildAgentFingerprintPayload_responseFormatAndSampling(t *testing.T) { Reasoning: &interfaces.LLMReasoning{Effort: "low"}, } lim := sdkruntime.AgentLimits{MaxIterations: 1, Timeout: 0, ApprovalTimeout: 0} - p := BuildAgentFingerprintPayload(spec, []string{"t1"}, "p", sampling, 5, lim, "mcpfp", "", "", "", "", "") + p := BuildAgentFingerprintPayload(spec, []string{"t1"}, "p", sampling, 5, lim, "mcpfp", "", "", "", "", "", "") if p.ResponseFormat == nil || p.ResponseFormat.Type != string(interfaces.ResponseFormatJSON) { t.Fatalf("response format: %+v", p.ResponseFormat) } diff --git a/internal/runtime/temporal/options.go b/internal/runtime/temporal/options.go index 5494039..1a425bc 100644 --- a/internal/runtime/temporal/options.go +++ b/internal/runtime/temporal/options.go @@ -118,6 +118,12 @@ func WithRetrieverFingerprint(fp string) Option { return func(rt *TemporalRuntime) { rt.retrieverFingerprint = fp } } +// WithHooksFingerprint sets the hook group names digest used with [ComputeAgentFingerprint]. 
+// Must match pkg/agent [hookGroupsFingerprint] for the same WithHooks wiring. +func WithHooksFingerprint(fp string) Option { + return func(rt *TemporalRuntime) { rt.hooksFingerprint = fp } +} + // WithToolsResolver sets the callback that resolves tools at activity time on the worker runtime. func WithToolsResolver(fn ToolsResolver) Option { return func(rt *TemporalRuntime) { rt.resolveToolsFn = fn } diff --git a/internal/runtime/temporal/runtime.go b/internal/runtime/temporal/runtime.go index e4264a1..7e6feb4 100644 --- a/internal/runtime/temporal/runtime.go +++ b/internal/runtime/temporal/runtime.go @@ -78,6 +78,7 @@ type TemporalRuntime struct { // agentMode is the string form of [types.AgentMode] (e.g. "interactive", "autonomous"). agentMode string retrieverFingerprint string + hooksFingerprint string // Temporal-specific flags // disableLocalWorker mirrors pkg/agent DisableLocalWorker: when false, the client embeds a worker @@ -329,6 +330,7 @@ func (rt *TemporalRuntime) Execute(ctx context.Context, req *runtime.ExecuteRequ wfInput := AgentWorkflowInput{ UserPrompt: req.UserPrompt, + RunID: runID, StreamingEnabled: false, EventWorkflowID: "", LocalChannelName: eventChannelName(workflowID), @@ -520,6 +522,7 @@ func (rt *TemporalRuntime) ExecuteStream(ctx context.Context, req *runtime.Execu } wfInput := AgentWorkflowInput{ UserPrompt: req.UserPrompt, + RunID: runID, EventWorkflowID: eventWorkflowID, EventTaskQueue: eventTaskQueue, LocalChannelName: eventChannelName(workflowID), diff --git a/internal/types/retriever.go b/internal/types/retriever.go index 8fde9a9..b01f6a5 100644 --- a/internal/types/retriever.go +++ b/internal/types/retriever.go @@ -1,5 +1,13 @@ package types +import ( + "fmt" + "strings" +) + +// RetrieverToolNamePrefix is the tool name prefix for agentic retriever tools (see [RetrieverToolName]). +const RetrieverToolNamePrefix = "retriever_" + // RetrieverToolParamQuery is the tool/JSON parameter name for the query sent to a retriever. 
const RetrieverToolParamQuery = "query" @@ -28,3 +36,48 @@ const ( // RetrieverModeHybrid combines prefetch with agentic retrieval during the run. RetrieverModeHybrid RetrieverMode = "hybrid" ) + +// RetrieverToolName returns the registered tool name for a retriever name (e.g. "kb" → "retriever_kb"). +// Returns "" when retrieverName is empty after trim. +func RetrieverToolName(retrieverName string) string { + name := strings.TrimSpace(retrieverName) + if name == "" { + return "" + } + return RetrieverToolNamePrefix + name +} + +// RetrieverNameFromToolName extracts the retriever name from a retriever tool name. +// Returns ok false when toolName does not use [RetrieverToolNamePrefix] or the name is empty. +func RetrieverNameFromToolName(toolName string) (name string, ok bool) { + toolName = strings.TrimSpace(toolName) + if toolName == "" || !strings.HasPrefix(toolName, RetrieverToolNamePrefix) { + return "", false + } + name = strings.TrimSpace(toolName[len(RetrieverToolNamePrefix):]) + if name == "" { + return "", false + } + return name, true +} + +// RetrieverToolDisplayName returns the human-readable tool name for a retriever key. 
+func RetrieverToolDisplayName(retrieverKey string) string { + key := strings.TrimSpace(retrieverKey) + if key == "" { + return "" + } + return fmt.Sprintf("%s Retriever Tool", key) +} + +func RetrieverToolParamQueryValue(args map[string]any) (string, error) { + query, ok := args[RetrieverToolParamQuery].(string) + if !ok { + return "", fmt.Errorf("retriever tool: %q parameter required", RetrieverToolParamQuery) + } + query = strings.TrimSpace(query) + if query == "" { + return "", fmt.Errorf("retriever tool: %q must be non-empty", RetrieverToolParamQuery) + } + return query, nil +} diff --git a/internal/types/retriever_test.go b/internal/types/retriever_test.go new file mode 100644 index 0000000..7188291 --- /dev/null +++ b/internal/types/retriever_test.go @@ -0,0 +1,79 @@ +package types + +import ( + "strings" + "testing" +) + +func TestRetrieverToolName(t *testing.T) { + if got := RetrieverToolName(" wiki "); got != "retriever_wiki" { + t.Fatalf("got %q", got) + } + if RetrieverToolName("") != "" || RetrieverToolName(" ") != "" { + t.Fatal("expected empty for missing key") + } +} + +func TestRetrieverNameFromToolName(t *testing.T) { + name, ok := RetrieverNameFromToolName("retriever_kb") + if !ok || name != "kb" { + t.Fatalf("got (%q, %v)", name, ok) + } + if _, ok := RetrieverNameFromToolName("retriever_"); ok { + t.Fatal("expected false for prefix only") + } + if _, ok := RetrieverNameFromToolName("other_kb"); ok { + t.Fatal("expected false for wrong prefix") + } + if _, ok := RetrieverNameFromToolName(""); ok { + t.Fatal("expected false for empty") + } +} + +func TestRetrieverToolName_roundTrip(t *testing.T) { + key := "my_kb" + toolName := RetrieverToolName(key) + got, ok := RetrieverNameFromToolName(toolName) + if !ok || got != key { + t.Fatalf("round trip = (%q, %v)", got, ok) + } +} + +func TestRetrieverToolDisplayName(t *testing.T) { + if got := RetrieverToolDisplayName("wiki"); got != "wiki Retriever Tool" { + t.Fatalf("got %q", got) + } + if 
RetrieverToolDisplayName(" ") != "" { + t.Fatal("expected empty") + } +} + +func TestRetrieverToolParamQueryValue(t *testing.T) { + got, err := RetrieverToolParamQueryValue(map[string]any{RetrieverToolParamQuery: " golang "}) + if err != nil { + t.Fatal(err) + } + if got != "golang" { + t.Fatalf("got %q", got) + } + + _, err = RetrieverToolParamQueryValue(nil) + if err == nil { + t.Fatal("expected error for nil args") + } + + _, err = RetrieverToolParamQueryValue(map[string]any{}) + if err == nil { + t.Fatal("expected error for missing query") + } + + _, err = RetrieverToolParamQueryValue(map[string]any{RetrieverToolParamQuery: 42}) + if err == nil { + t.Fatal("expected error for non-string query") + } + + _, err = RetrieverToolParamQueryValue(map[string]any{RetrieverToolParamQuery: " "}) + if err == nil || !strings.Contains(err.Error(), "non-empty") { + t.Fatalf("got %v", err) + } +} diff --git a/internal/types/tool.go b/internal/types/tool.go index 36bb291..d9e409f 100644 --- a/internal/types/tool.go +++ b/internal/types/tool.go @@ -54,3 +54,8 @@ func (k ToolKind) CountsTowardToolTelemetry() bool { return true } } + +// HooksEligible reports whether [BeforeToolHook] and [AfterToolHook] run for this tool kind. 
+func (k ToolKind) HooksEligible() bool { + return k == ToolKindNative || k == ToolKindMCP +} diff --git a/internal/types/tool_test.go b/internal/types/tool_test.go index 9056fbd..6284871 100644 --- a/internal/types/tool_test.go +++ b/internal/types/tool_test.go @@ -39,3 +39,14 @@ func TestToolKind_CountsTowardToolTelemetry(t *testing.T) { t.Fatal("memory tool should count toward tool telemetry") } } + +func TestToolKind_HooksEligible(t *testing.T) { + if !ToolKindNative.HooksEligible() || !ToolKindMCP.HooksEligible() { + t.Fatal("native and mcp should be hook eligible") + } + for _, k := range []ToolKind{ToolKindA2A, ToolKindSubAgent, ToolKindRetriever, ToolKindMemory} { + if k.HooksEligible() { + t.Fatalf("%q should not be hook eligible", k) + } + } +} diff --git a/pkg/agent/config.go b/pkg/agent/config.go index 6654694..d65177f 100644 --- a/pkg/agent/config.go +++ b/pkg/agent/config.go @@ -13,6 +13,7 @@ import ( "log/slog" + "github.com/agenticenv/agent-sdk-go/internal/hooks" "github.com/agenticenv/agent-sdk-go/internal/runtime" "github.com/agenticenv/agent-sdk-go/internal/runtime/temporal" "github.com/agenticenv/agent-sdk-go/internal/types" @@ -189,7 +190,7 @@ type ObservabilityConfig struct { // WithMaxIterations, WithStream, WithLogger, WithLogLevel, WithConversation, WithMemory, // WithResponseFormat, WithLLMSampling, WithSubAgents, WithMaxSubAgentDepth, // WithMCPConfig, WithMCPClients, WithA2AConfig, WithA2AClients, WithRetrievers, WithRetrieverMode, WithAgentMode, WithDisableFingerprintCheck, WithAgentToolExecutionMode, -// WithObservabilityConfig, WithTracer, WithMetrics, WithLogs +// WithObservabilityConfig, WithTracer, WithMetrics, WithLogs, WithHooks // // When [WithObservabilityConfig] is set and a signal is not disabled, [buildAgentConfig] replaces // [WithTracer], [WithMetrics], and [WithLogs] for that signal with OTLP clients built from the config. 
@@ -270,6 +271,9 @@ type agentConfig struct { tracer interfaces.Tracer metrics interfaces.Metrics logs interfaces.Logs + + // Hooks: named middleware hook groups for the agent execution lifecycle. + hooks []hooks.HookGroup } // Default Run/Stream deadlines when [WithTimeout] is unset: shorter for interactive sessions, @@ -640,6 +644,34 @@ func WithLogs(l interfaces.Logs) Option { return func(c *agentConfig) { c.logs = l } } +// WithHooks registers a named group of middleware hooks on the agent. +// name is trimmed and must be non-empty and unique across all WithHooks calls ([buildAgentConfig] returns an error otherwise); it is used for Temporal fingerprinting and +// run correlation (see [RunMeta]). Can be called multiple times; each call appends a new group +// that runs independently with its own [RunMeta].HooksGroup value, preserving declaration order. +// +// Use hooks to intercept and modify agent behavior at runtime — for example guardrails and PII +// scrubbing on [BeforeLLMHook]/[AfterLLMHook], tool input scrubbing on [BeforeToolHook]/[AfterToolHook], +// retrieval filtering on [AfterRetrieveHook], and memory tenant isolation on +// [BeforeMemoryLoadHook]/[BeforeMemoryStoreHook]. See [AgentHooks] for the full list of hook +// points and common use cases. +func WithHooks(name string, agentHooks AgentHooks) Option { + return func(c *agentConfig) { + c.hooks = append(c.hooks, hooks.HookGroup{ + Name: strings.TrimSpace(name), + Hooks: agentHooks, + }) + } +} + +// mergedHooks returns all hook groups combined in registration order. +func (c *agentConfig) mergedHooks() AgentHooks { + var out AgentHooks + for _, g := range c.hooks { + out = out.Merge(g.Hooks) + } + return out +} + // otlpLogsClientConfigured reports whether logs holds a concrete OTLP [*observability.Logs] client // (built by the SDK or injected after [observability.NewLogs]), before [DefaultNoopLogs] fallback. 
func otlpLogsClientConfigured(logs interfaces.Logs) bool { @@ -686,6 +718,9 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { if c.toolApprovalPolicy == nil { c.toolApprovalPolicy = RequireAllToolApprovalPolicy{} } + if err := c.validateHookGroups(); err != nil { + return nil, err + } // Temporal-specific validation: only enforced when the caller explicitly opts in to the // Temporal backend. When neither is set the local runtime is used as the default backend. if c.temporalConfig != nil && c.temporalClient != nil { @@ -1180,6 +1215,20 @@ func validateToolNames(tools []interfaces.Tool) error { return nil } +func (c *agentConfig) validateHookGroups() error { + seen := make(map[string]struct{}, len(c.hooks)) + for _, g := range c.hooks { + if g.Name == "" { + return errors.New("WithHooks: hook group name is required") + } + if _, dup := seen[g.Name]; dup { + return fmt.Errorf("WithHooks: duplicate hook group name %q", g.Name) + } + seen[g.Name] = struct{}{} + } + return nil +} + // responseFormatForLLM returns the response format for LLM requests. // When user sets WithResponseFormat, that is used; otherwise text-only. func (c *agentConfig) responseFormatForLLM() *interfaces.ResponseFormat { @@ -1239,6 +1288,7 @@ func (c *agentConfig) runtimeAgentConfig() runtime.AgentConfig { Timeout: c.timeout, ApprovalTimeout: c.approvalTimeout, }, + Hooks: c.runtimeHookGroups(), } if c.llmSampling != nil { d.LLM.Sampling = &runtime.LLMSampling{ @@ -1252,6 +1302,37 @@ func (c *agentConfig) runtimeAgentConfig() runtime.AgentConfig { return d } +// runtimeHookGroups copies configured hook groups onto the runtime view. +func (c *agentConfig) runtimeHookGroups() []hooks.HookGroup { + if len(c.hooks) == 0 { + return nil + } + out := make([]hooks.HookGroup, len(c.hooks)) + copy(out, c.hooks) + return out +} + +// hookGroupsFingerprint returns a stable SHA-256 digest of configured hook group names for +// [temporal.ComputeAgentFingerprint]. Names are sorted for stability. 
Returns "" when no hook +// groups are configured. Hook implementations are not hashed — caller and worker must register +// matching behavior under each name. +func hookGroupsFingerprint(groups []hooks.HookGroup) string { + if len(groups) == 0 { + return "" + } + names := make([]string, len(groups)) + for i, g := range groups { + names[i] = g.Name + } + sort.Strings(names) + b, err := json.Marshal(names) + if err != nil { + return "" + } + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]) +} + type observabilityFpShot struct { Endpoint string `json:"endpoint"` Protocol string `json:"protocol"` diff --git a/pkg/agent/config_test.go b/pkg/agent/config_test.go index 6848292..38cecb9 100644 --- a/pkg/agent/config_test.go +++ b/pkg/agent/config_test.go @@ -53,6 +53,7 @@ func agentConfigFingerprintTools(c *agentConfig, tools []interfaces.Tool) string string(c.agentMode), c.agentToolExecutionMode, retrieverConfigFingerprint(c.retrieverMode, c.retrievers), + hookGroupsFingerprint(c.hooks), )) } diff --git a/pkg/agent/hooks.go b/pkg/agent/hooks.go new file mode 100644 index 0000000..b19f33f --- /dev/null +++ b/pkg/agent/hooks.go @@ -0,0 +1,48 @@ +package agent + +import ( + "github.com/agenticenv/agent-sdk-go/internal/hooks" +) + +// Core +type AgentHooks = hooks.AgentHooks +type HookGroup = hooks.HookGroup +type RunMeta = hooks.RunMeta + +// LLM +type BeforeLLMHookInput = hooks.BeforeLLMHookInput +type BeforeLLMHookOutput = hooks.BeforeLLMHookOutput +type AfterLLMHookInput = hooks.AfterLLMHookInput +type AfterLLMHookOutput = hooks.AfterLLMHookOutput +type BeforeLLMHook = hooks.BeforeLLMHook +type AfterLLMHook = hooks.AfterLLMHook + +// Tools +type BeforeToolHookInput = hooks.BeforeToolHookInput +type BeforeToolHookOutput = hooks.BeforeToolHookOutput +type AfterToolHookInput = hooks.AfterToolHookInput +type AfterToolHookOutput = hooks.AfterToolHookOutput +type BeforeToolHook = hooks.BeforeToolHook +type AfterToolHook = hooks.AfterToolHook + +// Retriever +type 
BeforeRetrieveHookInput = hooks.BeforeRetrieveHookInput +type BeforeRetrieveHookOutput = hooks.BeforeRetrieveHookOutput +type AfterRetrieveHookInput = hooks.AfterRetrieveHookInput +type AfterRetrieveHookOutput = hooks.AfterRetrieveHookOutput +type BeforeRetrieveHook = hooks.BeforeRetrieveHook +type AfterRetrieveHook = hooks.AfterRetrieveHook + +// Memory +type BeforeMemoryLoadHookInput = hooks.BeforeMemoryLoadHookInput +type BeforeMemoryLoadHookOutput = hooks.BeforeMemoryLoadHookOutput +type AfterMemoryLoadHookInput = hooks.AfterMemoryLoadHookInput +type AfterMemoryLoadHookOutput = hooks.AfterMemoryLoadHookOutput +type BeforeMemoryStoreHookInput = hooks.BeforeMemoryStoreHookInput +type BeforeMemoryStoreHookOutput = hooks.BeforeMemoryStoreHookOutput +type AfterMemoryStoreHookInput = hooks.AfterMemoryStoreHookInput +type AfterMemoryStoreHookOutput = hooks.AfterMemoryStoreHookOutput +type BeforeMemoryLoadHook = hooks.BeforeMemoryLoadHook +type AfterMemoryLoadHook = hooks.AfterMemoryLoadHook +type BeforeMemoryStoreHook = hooks.BeforeMemoryStoreHook +type AfterMemoryStoreHook = hooks.AfterMemoryStoreHook diff --git a/pkg/agent/hooks_test.go b/pkg/agent/hooks_test.go new file mode 100644 index 0000000..44aed7f --- /dev/null +++ b/pkg/agent/hooks_test.go @@ -0,0 +1,147 @@ +package agent + +import ( + "context" + "strings" + "testing" +) + +func TestWithHooks_MergesInDeclarationOrder(t *testing.T) { + h1 := func(context.Context, BeforeLLMHookInput) (BeforeLLMHookOutput, error) { + return BeforeLLMHookOutput{}, nil + } + h2 := func(context.Context, BeforeLLMHookInput) (BeforeLLMHookOutput, error) { + return BeforeLLMHookOutput{}, nil + } + h3 := func(context.Context, AfterToolHookInput) (AfterToolHookOutput, error) { + return AfterToolHookOutput{}, nil + } + + cfg, err := buildAgentConfig([]Option{ + WithName("hooks"), + WithLLMClient(stubLLM{}), + WithHooks("guardrails", AgentHooks{BeforeLLM: []BeforeLLMHook{h1}}), + WithHooks("audit", AgentHooks{ + BeforeLLM: 
[]BeforeLLMHook{h2}, + AfterTool: []AfterToolHook{h3}, + }), + }) + if err != nil { + t.Fatal(err) + } + merged := cfg.mergedHooks() + if len(merged.BeforeLLM) != 2 { + t.Fatalf("BeforeLLM len = %d, want 2", len(merged.BeforeLLM)) + } + if merged.BeforeLLM[0] == nil || merged.BeforeLLM[1] == nil { + t.Fatal("expected non-nil BeforeLLM hooks") + } + if len(merged.AfterTool) != 1 { + t.Fatalf("AfterTool len = %d, want 1", len(merged.AfterTool)) + } + if len(cfg.hooks) != 2 || cfg.hooks[0].Name != "guardrails" || cfg.hooks[1].Name != "audit" { + t.Fatalf("hooks = %#v", cfg.hooks) + } +} + +func TestWithHooks_RequiresName(t *testing.T) { + h := func(context.Context, BeforeRetrieveHookInput) (BeforeRetrieveHookOutput, error) { + return BeforeRetrieveHookOutput{}, nil + } + _, err := buildAgentConfig([]Option{ + WithName("hooks-empty"), + WithLLMClient(stubLLM{}), + WithHooks("", AgentHooks{BeforeRetrieve: []BeforeRetrieveHook{h}}), + }) + if err == nil || !strings.Contains(err.Error(), "hook group name is required") { + t.Fatalf("expected name required error, got %v", err) + } +} + +func TestWithHooks_RejectsDuplicateName(t *testing.T) { + h := func(context.Context, BeforeRetrieveHookInput) (BeforeRetrieveHookOutput, error) { + return BeforeRetrieveHookOutput{}, nil + } + _, err := buildAgentConfig([]Option{ + WithName("hooks-dup"), + WithLLMClient(stubLLM{}), + WithHooks("audit", AgentHooks{BeforeRetrieve: []BeforeRetrieveHook{h}}), + WithHooks("audit", AgentHooks{}), + }) + if err == nil || !strings.Contains(err.Error(), `duplicate hook group name "audit"`) { + t.Fatalf("expected duplicate name error, got %v", err) + } +} + +func TestRuntimeAgentConfig_PassesHookGroups(t *testing.T) { + h := func(context.Context, BeforeLLMHookInput) (BeforeLLMHookOutput, error) { + return BeforeLLMHookOutput{}, nil + } + cfg, err := buildAgentConfig([]Option{ + WithName("hooks-runtime"), + WithLLMClient(stubLLM{}), + WithHooks("guardrails", AgentHooks{BeforeLLM: []BeforeLLMHook{h}}), 
+ WithHooks("audit", AgentHooks{}), + }) + if err != nil { + t.Fatal(err) + } + rt := cfg.runtimeAgentConfig() + if len(rt.Hooks) != 2 { + t.Fatalf("Hooks len = %d, want 2", len(rt.Hooks)) + } + if rt.Hooks[0].Name != "guardrails" || len(rt.Hooks[0].Hooks.BeforeLLM) != 1 { + t.Fatalf("Hooks[0] = %#v", rt.Hooks[0]) + } + if rt.Hooks[1].Name != "audit" { + t.Fatalf("Hooks[1].Name = %q, want audit", rt.Hooks[1].Name) + } +} + +func TestHookGroupsFingerprint_emptyWhenNoGroups(t *testing.T) { + if got := hookGroupsFingerprint(nil); got != "" { + t.Fatalf("got %q, want empty", got) + } +} + +func TestHookGroupsFingerprint_sortedNamesStable(t *testing.T) { + fpAB := hookGroupsFingerprint([]HookGroup{{Name: "audit"}, {Name: "guardrails"}}) + fpBA := hookGroupsFingerprint([]HookGroup{{Name: "guardrails"}, {Name: "audit"}}) + if fpAB == "" { + t.Fatal("expected non-empty fingerprint") + } + if fpAB != fpBA { + t.Fatalf("registration order should not matter: %q vs %q", fpAB, fpBA) + } +} + +func TestHookGroupsFingerprint_differentNamesDifferentDigest(t *testing.T) { + fpOne := hookGroupsFingerprint([]HookGroup{{Name: "guardrails"}}) + fpTwo := hookGroupsFingerprint([]HookGroup{{Name: "guardrails"}, {Name: "audit"}}) + if fpOne == fpTwo { + t.Fatal("expected different fingerprints for different hook group sets") + } +} + +func TestAgentConfigFingerprint_HookGroupsChangesDigest(t *testing.T) { + h := func(context.Context, BeforeLLMHookInput) (BeforeLLMHookOutput, error) { + return BeforeLLMHookOutput{}, nil + } + baseOpts := []Option{ + WithName("hooks-fp"), + WithLLMClient(stubLLM{}), + } + cfgNoHooks, err := buildAgentConfig(baseOpts) + if err != nil { + t.Fatal(err) + } + cfgWithHooks, err := buildAgentConfig(append(baseOpts, + WithHooks("guardrails", AgentHooks{BeforeLLM: []BeforeLLMHook{h}}), + )) + if err != nil { + t.Fatal(err) + } + if agentConfigFingerprint(cfgNoHooks) == agentConfigFingerprint(cfgWithHooks) { + t.Fatal("expected different fingerprints when hook groups 
are configured") + } +} diff --git a/pkg/agent/retriever.go b/pkg/agent/retriever.go index cbaf906..4dbb4dc 100644 --- a/pkg/agent/retriever.go +++ b/pkg/agent/retriever.go @@ -14,11 +14,6 @@ import ( "github.com/agenticenv/agent-sdk-go/pkg/tools" ) -var ( - retrieverToolNameTemplate = "retriever_%s" - retrieverToolDisplayNameTemplate = "%s Retriever Tool" -) - var _ interfaces.Tool = (*RetrieverTool)(nil) var _ types.ToolKindProvider = (*RetrieverTool)(nil) @@ -29,26 +24,6 @@ type RetrieverTool struct { Retriever interfaces.Retriever } -// retrieverToolName returns the registered tool name for a retriever name -// (same format as [RetrieverTool.Name]). Trims whitespace; returns "" if empty after trim. -func retrieverToolName(retrieverName string) string { - n := strings.TrimSpace(retrieverName) - if n == "" { - return "" - } - return fmt.Sprintf(retrieverToolNameTemplate, n) -} - -// retrieverToolDisplayName returns the display name for a retriever name -// (same format as [RetrieverTool.DisplayName]). Trims whitespace; returns "" if empty after trim. -func retrieverToolDisplayName(retrieverName string) string { - n := strings.TrimSpace(retrieverName) - if n == "" { - return "" - } - return fmt.Sprintf(retrieverToolDisplayNameTemplate, n) -} - // NewRetrieverTool builds a RetrieverTool. Returns nil when retriever is nil or [interfaces.Retriever.Name] is empty. func NewRetrieverTool(retriever interfaces.Retriever) interfaces.Tool { if retriever == nil { @@ -69,7 +44,7 @@ func (t *RetrieverTool) Name() string { if t == nil { return "" } - return retrieverToolName(t.RetrieverName) + return types.RetrieverToolName(t.RetrieverName) } // DisplayName implements [interfaces.Tool]. @@ -77,7 +52,7 @@ func (t *RetrieverTool) DisplayName() string { if t == nil { return "" } - return fmt.Sprintf(retrieverToolDisplayNameTemplate, t.RetrieverName) + return types.RetrieverToolDisplayName(t.RetrieverName) } // Description implements [interfaces.Tool]. 
@@ -105,35 +80,16 @@ func (t *RetrieverTool) Parameters() interfaces.JSONSchema { } // Execute implements [interfaces.Tool]: reads the query argument, calls [interfaces.Retriever.Search], -// and returns a numbered plain-text summary of matching documents. +// and returns matching documents. Formatting for the LLM is done by the runtime. func (t *RetrieverTool) Execute(ctx context.Context, args map[string]any) (any, error) { if t.Retriever == nil { return nil, fmt.Errorf("retriever tool: nil retriever") } - raw, ok := args[types.RetrieverToolParamQuery].(string) - if !ok { - return nil, fmt.Errorf("retriever tool: %q parameter required", types.RetrieverToolParamQuery) - } - query := strings.TrimSpace(raw) - if query == "" { - return nil, fmt.Errorf("retriever tool: %q must be non-empty", types.RetrieverToolParamQuery) - } - docs, err := t.Retriever.Search(ctx, query) + query, err := types.RetrieverToolParamQueryValue(args) if err != nil { return nil, err } - return formatRetrieverDocs(docs), nil -} - -func formatRetrieverDocs(docs []interfaces.Document) string { - if len(docs) == 0 { - return "no relevant documents found" - } - var sb strings.Builder - for i, doc := range docs { - fmt.Fprintf(&sb, types.RetrieverDocFormat, i+1, doc.Content, doc.Source, doc.Score) - } - return sb.String() + return t.Retriever.Search(ctx, query) } // --------------------------------------------------------------------------- diff --git a/pkg/agent/retriever_test.go b/pkg/agent/retriever_test.go index c89b265..3717d4c 100644 --- a/pkg/agent/retriever_test.go +++ b/pkg/agent/retriever_test.go @@ -33,19 +33,19 @@ func (r *retrieverExecuteStub) Search(ctx context.Context, query string) ([]inte } func TestRetrieverToolName(t *testing.T) { - if got := retrieverToolName(" wiki "); got != "retriever_wiki" { + if got := types.RetrieverToolName(" wiki "); got != "retriever_wiki" { t.Fatalf("got %q", got) } - if retrieverToolName("") != "" || retrieverToolName(" ") != "" { + if 
types.RetrieverToolName("") != "" || types.RetrieverToolName(" ") != "" { t.Fatal("expected empty for missing name") } } func TestRetrieverToolDisplayName(t *testing.T) { - if got := retrieverToolDisplayName("wiki"); got != "wiki Retriever Tool" { + if got := types.RetrieverToolDisplayName("wiki"); got != "wiki Retriever Tool" { t.Fatalf("got %q", got) } - if retrieverToolDisplayName(" ") != "" { + if types.RetrieverToolDisplayName(" ") != "" { t.Fatal("expected empty") } } @@ -120,12 +120,12 @@ func TestRetrieverTool_Execute_success(t *testing.T) { if err != nil { t.Fatal(err) } - s, ok := out.(string) + docs, ok := out.([]interfaces.Document) if !ok { t.Fatalf("got %T", out) } - if !strings.Contains(s, "[1] Go is great") || !strings.Contains(s, "[2] Rust is fast") { - t.Fatalf("output = %q", s) + if len(docs) != 2 || docs[0].Content != "Go is great" || docs[1].Content != "Rust is fast" { + t.Fatalf("docs = %#v", docs) } if stub.lastQuery != "golang" { t.Fatalf("query = %q", stub.lastQuery) @@ -141,8 +141,12 @@ func TestRetrieverTool_Execute_noDocs(t *testing.T) { if err != nil { t.Fatal(err) } - if out != "no relevant documents found" { - t.Fatalf("got %q", out) + docs, ok := out.([]interfaces.Document) + if !ok { + t.Fatalf("got %T", out) + } + if len(docs) != 0 { + t.Fatalf("docs = %#v", docs) } } @@ -179,18 +183,6 @@ func TestRetrieverTool_Execute_nilRetriever(t *testing.T) { } } -func TestFormatRetrieverDocs(t *testing.T) { - if formatRetrieverDocs(nil) != "no relevant documents found" { - t.Fatal("nil docs") - } - got := formatRetrieverDocs([]interfaces.Document{ - {Content: "alpha", Source: "a", Score: 0.5}, - }) - if !strings.Contains(got, "[1] alpha") || !strings.Contains(got, "score: 0.50") { - t.Fatalf("got %q", got) - } -} - // --------------------------------------------------------------------------- // retrieverConfigFingerprint tests diff --git a/pkg/agent/runtime_factory.go b/pkg/agent/runtime_factory.go index 6be46d0..2e04fd9 100644 --- 
a/pkg/agent/runtime_factory.go +++ b/pkg/agent/runtime_factory.go @@ -24,6 +24,7 @@ func (cfg *agentConfig) buildTemporalRuntime(remoteWorker bool) (*temporal.Tempo temporal.WithAgentMode(string(cfg.agentMode)), temporal.WithAgentToolExecutionMode(cfg.agentToolExecutionMode), temporal.WithRetrieverFingerprint(retrieverConfigFingerprint(cfg.retrieverMode, cfg.retrievers)), + temporal.WithHooksFingerprint(hookGroupsFingerprint(cfg.hooks)), temporal.WithDisableLocalWorker(cfg.disableLocalWorker), // Never allow fingerprint bypass on remote worker runtime. temporal.WithDisableFingerprintCheck(cfg.disableFingerprintCheck && !remoteWorker), diff --git a/taskfiles/examples.yml b/taskfiles/examples.yml index 0f10983..7097cc2 100644 --- a/taskfiles/examples.yml +++ b/taskfiles/examples.yml @@ -199,6 +199,7 @@ tasks: - agent_with_tools/basic - agent_with_tools/custom - agent_with_tools/dynamic_registry + - agent_with_hooks - agent_with_tools/authorizer - agent_with_json_response - agent_with_stream