From c6653919dbed4bf81efe50ccc5ed23c1ab8846b0 Mon Sep 17 00:00:00 2001 From: vinodvx <8429554+vinodvx@users.noreply.github.com> Date: Sat, 20 Jun 2026 16:20:31 -0700 Subject: [PATCH] feat: Add agent long-term memory with ondemand and always store modes --- CONTRIBUTING.md | 19 +- README.md | 468 ++++++++----- benchmarks/README.md | 13 + benchmarks/agent_builder.go | 5 +- benchmarks/config.yaml | 5 + benchmarks/main.go | 6 + benchmarks/metrics.go | 29 +- benchmarks/report.go | 4 + benchmarks/runner.go | 24 +- benchmarks/setup/config.go | 66 +- benchmarks/setup/mock_llm.go | 28 + benchmarks/setup/opts.go | 12 + benchmarks/worker/main.go | 5 +- cmd/main.go | 4 +- eval-harness/README.md | 46 +- eval-harness/deepeval/harness.py | 7 + eval-harness/deepeval/test_agent.py | 26 +- eval-harness/promptfoo/config.yaml | 34 + eval-harness/run_agent_memory.sh | 12 + eval-harness/runner/config.yaml | 10 +- eval-harness/runner/main.go | 17 +- eval-harness/runner/output.go | 33 +- eval-harness/runner/runner.go | 49 +- eval-harness/runner/runner_test.go | 65 +- eval-harness/runner/setup/agent.go | 8 + eval-harness/runner/setup/config.go | 107 ++- eval-harness/runner/setup/config_test.go | 25 + eval-harness/runner/setup/load.go | 45 +- eval-harness/runner/setup/mock_llm.go | 40 +- examples/.env.defaults | 11 + examples/README.md | 20 +- examples/agent_with_conversation/main.go | 9 +- examples/agent_with_memory/README.md | 163 +++++ examples/agent_with_memory/common/config.go | 126 ++++ .../agent_with_memory/common/config_test.go | 35 + .../agent_with_memory/common/embed_openai.go | 77 ++ .../agent_with_memory/common/embedding.go | 14 + examples/agent_with_memory/common/opts.go | 57 ++ examples/agent_with_memory/common/run.go | 76 ++ examples/agent_with_memory/pgvector/main.go | 70 ++ examples/agent_with_memory/weaviate/main.go | 61 ++ .../agent_with_stream_conversation/main.go | 4 +- examples/docker/pgvector/setup.sql | 19 + examples/docker/weaviate/seed.sh | 21 + examples/shared/utils.go | 10 +- internal/runtime/base/memory.go | 354 ++++++++++ internal/runtime/base/runtime.go | 194 ++++- internal/runtime/base/runtime_test.go | 660 +++++++++++++++++- internal/runtime/base/types.go | 9 + internal/runtime/base/utils.go | 21 + internal/runtime/local/agent_loop.go | 39 +- internal/runtime/local/agent_loop_test.go | 113 +++ internal/runtime/local/runtime.go | 16 + internal/runtime/runtime.go | 7 + internal/runtime/temporal/agent_workflow.go | 123 +++- .../runtime/temporal/agent_workflow_test.go | 101 +++ internal/runtime/temporal/runtime.go | 18 + internal/testing/inmem_memory.go | 206 ++++++ internal/types/memory.go | 14 + internal/types/metrics.go | 39 +- internal/types/telemetry.go | 6 + internal/types/tool.go | 1 + internal/types/tool_test.go | 6 + pkg/agent/agent.go | 14 +- pkg/agent/agent_test.go | 7 +- pkg/agent/config.go | 119 +++- pkg/agent/config_test.go | 141 +++- pkg/agent/memory.go | 91 +++ pkg/agent/memory_test.go | 126 ++++ pkg/conversation/config.go | 51 ++ pkg/conversation/config_test.go | 71 ++ pkg/conversation/defaults.go | 16 + pkg/interfaces/memory.go | 161 +++++ pkg/interfaces/mocks/mock_memory.go | 90 +++ pkg/memory/config.go | 323 +++++++++ pkg/memory/config_test.go | 249 +++++++ pkg/memory/defaults.go | 97 +++ pkg/memory/pgvector/memory.go | 580 +++++++++++++++ pkg/memory/pgvector/memory_test.go | 462 ++++++++++++ pkg/memory/pgvector/schema.go | 27 + pkg/memory/weaviate/memory.go | 607 ++++++++++++++++ pkg/memory/weaviate/memory_test.go | 439 ++++++++++++ pkg/memory/weaviate/schema.go | 25 + taskfiles/examples.yml | 38 +- 84 files changed, 7322 insertions(+), 324 deletions(-) create mode 100755 eval-harness/run_agent_memory.sh create mode 100644 eval-harness/runner/setup/config_test.go create mode 100644 examples/agent_with_memory/README.md create mode 100644 examples/agent_with_memory/common/config.go create mode 100644 examples/agent_with_memory/common/config_test.go create mode 100644 examples/agent_with_memory/common/embed_openai.go create mode 100644 examples/agent_with_memory/common/embedding.go create mode 100644 examples/agent_with_memory/common/opts.go create mode 100644 examples/agent_with_memory/common/run.go create mode 100644 examples/agent_with_memory/pgvector/main.go create mode 100644 examples/agent_with_memory/weaviate/main.go create mode 100644 internal/runtime/base/memory.go create mode 100644 internal/testing/inmem_memory.go create mode 100644 internal/types/memory.go create mode 100644 pkg/agent/memory.go create mode 100644 pkg/agent/memory_test.go create mode 100644 pkg/conversation/config.go create mode 100644 pkg/conversation/config_test.go create mode 100644 pkg/conversation/defaults.go create mode 100644 pkg/interfaces/memory.go create mode 100644 pkg/interfaces/mocks/mock_memory.go create mode 100644 pkg/memory/config.go create mode 100644 pkg/memory/config_test.go create mode 100644 pkg/memory/defaults.go create mode 100644 pkg/memory/pgvector/memory.go create mode 100644 pkg/memory/pgvector/memory_test.go create mode 100644 pkg/memory/pgvector/schema.go create mode 100644 pkg/memory/weaviate/memory.go create mode 100644 pkg/memory/weaviate/memory_test.go create mode 100644 pkg/memory/weaviate/schema.go diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 32fdeb4..cc8f497 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -72,7 +72,7 @@ task examples:all Requires Task, Docker, and LLM credentials — see [examples/README.md](examples/README.md). -If you change **agent behavior** (e.g. `pkg/agent`, telemetry, tools, runtime) or **`eval-harness/`**, run: +If you change **agent behavior** (e.g. `pkg/agent`, `pkg/memory`, telemetry, tools, runtime) or **`eval-harness/`**, run: ```bash make eval-harness @@ -136,7 +136,7 @@ Or run a single example: go run ./examples/simple_agent "Hello" ``` -See [examples/README.md](examples/README.md) for all examples, env vars, Task install, and infra commands (`task infra:*`, `task examples:local`). +See [examples/README.md](examples/README.md) for all examples, env vars, Task install, and infra commands (`task infra:*`, `task examples:local`). Memory examples (`examples/agent_with_memory/`) need Weaviate or pgvector — see [examples/agent_with_memory/README.md](examples/agent_with_memory/README.md). ## Ways to Contribute @@ -180,7 +180,7 @@ Using the SDK and ran into issues, unclear docs, or confusing behavior? **Raise 2. **Tests** - Add tests for new features and bug fixes. - Unit tests go in `*_test.go` files alongside the code. - - Agent behavior changes (`pkg/agent`, telemetry, tools, runtime) or **`eval-harness/`** edits — run `make eval-harness` before submitting a PR. + - Agent behavior changes (`pkg/agent`, `pkg/memory`, telemetry, tools, runtime) or **`eval-harness/`** edits — run `make eval-harness` before submitting a PR. 3. **Commits** - Use [conventional commits](https://www.conventionalcommits.org) — these drive the release changelog: @@ -201,18 +201,7 @@ Using the SDK and ran into issues, unclear docs, or confusing behavior? **Raise - Keep changes focused. For larger work, consider splitting into multiple PRs. - For new LLM providers: implement `interfaces.LLMClient` (see `pkg/interfaces/llm.go` and existing providers in `pkg/llm/`). - For new tools: implement `interfaces.Tool` (see `pkg/interfaces/tools.go` and `pkg/tools/`). - -## Project Layout - -| Path | Purpose | -|------|---------| -| `pkg/agent/` | Agent core, workflow, config | -| `pkg/llm/` | LLM providers (OpenAI, Anthropic, Gemini) | -| `pkg/interfaces/` | Interfaces for LLM clients, tools, messages | -| `pkg/tools/` | Built-in and custom tools | -| `pkg/conversation/` | Message history (in-memory, Redis) | -| `cmd/` | agentctl CLI | -| `examples/` | Example programs | + - For new memory backends: implement `interfaces.Memory` (see `pkg/interfaces/memory.go` and `pkg/memory/weaviate` or `pkg/memory/pgvector`). ## Releasing (maintainers only) diff --git a/README.md b/README.md index 4b5e219..202dd03 100644 --- a/README.md +++ b/README.md @@ -29,9 +29,11 @@ - [Stream events](#stream-events-stream) - [Token usage](#token-usage-llmusage) - [Tools](#tools) + - [Conversation](#conversation) + - [Memory](#memory) + - [Retrieval (RAG)](#retrieval-rag) - [MCP](#mcp-model-context-protocol) - [A2A](#a2a-agent-to-agent) - - [Retrieval (RAG)](#retrieval-rag) - [Sub-agents](#sub-agents) - [Managing Capabilities at Runtime](#managing-capabilities-at-runtime) - [Approvals](#approvals) @@ -41,7 +43,6 @@ - [Reasoning / extended thinking](#reasoning--extended-thinking) - [Multiple agents](#multiple-agents) - [Agent and worker in separate processes](#agent-and-worker-in-separate-processes) - - [Conversation](#conversation-message-history) - [AG-UI Protocol](#ag-ui-protocol) - [Telemetry](#telemetry) - [Observability](#observability) @@ -74,10 +75,11 @@ The **in-process runtime** runs the agent loop directly in your process with no - **LLM providers** — OpenAI, Anthropic, and Gemini out of the box; bring your own via `interfaces.LLMClient`. - **Tools** — Register built-in or custom tools via `interfaces.Tool`; optional **parallel vs sequential** execution for multiple tool calls in one LLM round (`WithAgentToolExecutionMode`). - **Human-in-the-loop** — Approval gates on tool calls and delegation across `Run`, `RunAsync`, and `Stream`. -- **Conversation** — Persist multi-turn message history via `WithConversation`; in-memory store for in-process; in-memory and Redis for Temporal. Bring your own via `interfaces.Conversation`. +- **Conversation** — Persist multi-turn message history via `WithConversation` and `conversation.Config`; in-memory store for in-process; Redis for remote workers. Bring your own backend via `interfaces.Conversation`. - **Sub-agents** — Delegate to specialist agents via `WithSubAgents`; recursive delegation with depth limiting; all sub-agent events fan in to the parent stream on both runtimes. - **MCP** — Extend agent capabilities by connecting any MCP server as a tool source via `WithMCPConfig` or `WithMCPClients`. - **A2A** — Connect remote [Agent-to-Agent](https://github.com/a2aproject/A2A) agents as tool providers via `WithA2AConfig` or `WithA2AClients`; or expose the agent itself as an A2A server via `WithA2ADefaultServer` / `WithA2AServer` and `RunA2A`. +- **Memory** — Store and recall scoped facts and preferences across runs via `WithMemory` and `memory.Config`; built-in Weaviate and pgvector backends; same config on agent and worker for remote deployments. - **Retrieval (RAG)** — Ground agent responses in external knowledge bases via a pluggable `Retriever` interface with built-in Weaviate and pgvector support; extend with your own implementation. - **Streaming** — Partial tokens and events via `Stream` and `WithStream`. - **AG-UI** — Stream events conform to the [AG-UI protocol](https://docs.ag-ui.com); agents work out of the box with any AG-UI compatible frontend such as [CopilotKit](https://copilotkit.ai). @@ -372,6 +374,282 @@ result, _ := a.Run(ctx, "What's the weather in Tokyo?", nil) [examples/agent_with_tools](examples/agent_with_tools) +### Conversation + +Conversation lets agents **carry message history across turns in the same session** — user messages, assistant replies, and tool results — so multi-step chats stay coherent without resending the full thread yourself. Use it for chat UIs, support bots, or any agent where the user sends several prompts in one session. Pass `agent.WithConversation` on `NewAgent` and `NewAgentWorker` with a `conversation.Config`. + +**Backends** — in-memory (`pkg/conversation/inmem`) for single-process runs, Redis (`pkg/conversation/redis`) when the agent and worker run in separate processes, or your own implementation of `interfaces.Conversation` (see below). Example: [examples/agent_with_conversation](examples/agent_with_conversation). + +**Basic setup** — create a store, pass `conversation.Config`, and use the same conversation ID on every call: + +```go +import ( + "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/conversation" + "github.com/agenticenv/agent-sdk-go/pkg/conversation/inmem" +) + +conv := inmem.NewInMemoryConversation(inmem.WithMaxSize(100)) +a, _ := agent.NewAgent( + agent.WithTemporalConfig(...), + agent.WithLLMClient(...), + agent.WithConversation(conversation.Config{ + Conversation: conv, + Size: 20, + }), +) +defer a.Close() + +opts := &agent.AgentRunOptions{ConversationOptions: &agent.ConversationOptions{ID: "session-1"}} +a.Run(ctx, "I'm Alice. Remember that.", opts) +a.Run(ctx, "What's my name?", opts) +``` + +**Conversation ID** — when conversation is enabled, pass `AgentRunOptions.ConversationOptions.ID` on every `Run`, `RunAsync`, and `Stream` call with the same session ID so history is shared across turns. + +**Config** — `Size` limits how many past messages are loaded for the LLM (zero defaults to 20). `SaveOnIteration` persists messages after each tool round when external consumers need live updates during a multi-step run. Use `conversation.DefaultConfig(conv)` for SDK defaults with no overrides. + +**Deployment** — single process (agent and worker together): use in-memory. Remote workers (`DisableLocalWorker` or `EnableRemoteWorkers()`): use Redis or another distributed store; in-memory fails at build time with remote workers. Agent and worker must share the same store configuration. + +**Remote workers** — configure Redis on both processes; set `SaveOnIteration: true` on the worker when live updates are needed: + +```go +convW, _ := redis.NewRedisConversation(redis.WithAddr("localhost:6379")) +w, _ := agent.NewAgentWorker( + agent.WithTemporalConfig(...), + agent.WithConversation(conversation.Config{Conversation: convW, Size: 20, SaveOnIteration: true}), +) + +convA, _ := redis.NewRedisConversation(redis.WithAddr("localhost:6379")) +a, _ := agent.NewAgent( + agent.WithTemporalConfig(...), + agent.DisableLocalWorker(), + agent.WithConversation(conversation.Config{Conversation: convA, Size: 20}), +) +opts := &agent.AgentRunOptions{ConversationOptions: &agent.ConversationOptions{ID: "session-1"}} +a.Run(ctx, "Hello", opts) +``` + +**Lifecycle** — you own the conversation store. Call `Clear` when ending a session; the agent never clears history for you. + +**Implementing your own backend** — implement `interfaces.Conversation`: + +```go +type Conversation interface { + AddMessage(ctx context.Context, id string, msg Message) error + ListMessages(ctx context.Context, id string, opts ...ListMessagesOption) ([]Message, error) + Clear(ctx context.Context, id string) error + IsDistributed() bool +} +``` + +Set `IsDistributed() bool` to match your deployment. Pass the implementation in `conversation.Config.Conversation` and wire with `agent.WithConversation`. + +### Memory + +Memory lets agents **remember facts and preferences across separate runs** — scoped per user, tenant, or custom tags — without relying on conversation history alone. Use it when users return later, when preferences should persist beyond a single session, or when you need tenant- or project-scoped memory separate from chat history. Pass `agent.WithMemory` on `NewAgent` and `NewAgentWorker` with a `memory.Config`. + +**Backends** — Weaviate (`pkg/memory/weaviate`, vector search, Docker-friendly), pgvector (`pkg/memory/pgvector`, Postgres + pgvector, requires an embed function), or your own implementation of `interfaces.Memory` (see below). Examples: [examples/agent_with_memory/weaviate](examples/agent_with_memory/weaviate) · [examples/agent_with_memory/pgvector](examples/agent_with_memory/pgvector) — setup and env vars in [examples/agent_with_memory/README.md](examples/agent_with_memory/README.md). + +**Basic setup** — create a backend, wrap it in `memory.DefaultConfig`, and pass to the agent. Attach scope on the context for each run: + +```go +import ( + "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/memory" + wmem "github.com/agenticenv/agent-sdk-go/pkg/memory/weaviate" +) + +store, _ := wmem.NewMemory( + wmem.WithHost("localhost:8080"), + wmem.WithClassName("AgentMemory"), +) + +a, _ := agent.NewAgent( + agent.WithMemory(memory.DefaultConfig(store)), + agent.WithLLMClient(llmClient), + // ... +) +defer a.Close() + +ctx := memory.WithContextUserID(context.Background(), "user-123") +result, _ := a.Run(ctx, "Remember I prefer concise answers.", nil) +``` + +For **pgvector**, use `pgmem.NewMemory(embedFn, pgmem.WithDSN(...), pgmem.WithTable("agent_memories"))` with the same `WithMemory` pattern. + +**Deployment** — Weaviate and pgvector are external shared stores, so the same backend works for single-process runs and for remote workers. Use a store both processes can reach (same host, DSN, or cluster). + +**Remote workers** — pass the same `memory.Config` to `NewAgent` and `NewAgentWorker` (store, scope, store mode, recall, and TTL). Attach scope on the context in the agent process when calling `Run`, `RunAsync`, or `Stream`. With `StoreModeAlways` and no custom `Extract`, the worker also needs `WithLLMClient`: + +```go +memCfg := memory.DefaultConfig(store) + +w, _ := agent.NewAgentWorker( + agent.WithTemporalConfig(...), + agent.WithLLMClient(llmClient), + agent.WithMemory(memCfg), +) +go w.Start() + +a, _ := agent.NewAgent( + agent.WithTemporalConfig(...), + agent.WithLLMClient(llmClient), + agent.DisableLocalWorker(), + agent.WithMemory(memCfg), +) +ctx := memory.WithContextUserID(ctx, userID) +a.Run(ctx, prompt, nil) +``` + +**Store modes** — set `memCfg.Store.Mode`. **On-demand** (default): the LLM saves via the `save_memory` tool during the run: + +```go +memCfg := memory.DefaultConfig(store) +memCfg.Store.Mode = memory.StoreModeOnDemand +a, _ := agent.NewAgent(agent.WithMemory(memCfg), ...) +``` + +**Always**: memories are saved automatically when the run finishes: + +```go +memCfg := memory.DefaultConfig(store) +memCfg.Store.Mode = memory.StoreModeAlways +a, _ := agent.NewAgent(agent.WithMemory(memCfg), ...) +``` + +Use **on-demand** when the model should decide what to persist. Use **always** when every run should be saved without relying on tool calls. + +**Recall config** — tune `memCfg.Recall`: `Enabled` (default `true`; set `false` for store-only), `Limit` (default `10`), `MinScore` (default `0.35`), `Kinds` (empty = all kinds): + +```go +memCfg.Recall = memory.RecallConfig{ + Enabled: true, + Limit: 20, + MinScore: 0.4, + Kinds: []interfaces.MemoryKind{memory.KindPreference, memory.KindFact}, +} +``` + +**Scope config** — memories are isolated by scope. Defaults read tenant, user, and agent from request context: + +```go +ctx := memory.WithContextUserID(ctx, userID) +ctx = memory.WithContextTenantID(ctx, tenantID) // optional +a.Run(ctx, prompt, nil) +``` + +Use the **same scope values** on every run that should share memories. By default, memories are also isolated per agent name. Override resolvers or add custom tags on `memCfg.ScopeConfig`: + +```go +memCfg.ScopeConfig = memory.ScopeConfig{ + UserIDResolver: func(ctx context.Context) string { return myApp.UserID(ctx) }, + ExtraKeys: []string{"project_id"}, + TagResolvers: map[string]memory.ScopeResolver{ + "project_id": func(ctx context.Context) string { return myApp.ProjectID(ctx) }, + }, +} +``` + +**TTL policy** — `memCfg.TTLPolicy` sets expiry per kind at store time. Defaults: `decision` 7 days, `note` 48 hours, `fact` / `preference` / `instruction` no expiry. Override: + +```go +memCfg.TTLPolicy = memory.TTLPolicy{ + memory.KindPreference: 0, + memory.KindNote: 7 * 24 * time.Hour, +} +``` + +**Dedup** — `memCfg.Store.DedupMinScore` (default **0.85**) controls whether a similar memory is updated or appended. Lower to dedup more; raise to keep distinct entries. + +**Custom extraction** — with **always** mode, set `memCfg.Store.Extract` to control what is saved at run end (`nil` for on-demand): + +```go +memCfg.Store.Mode = memory.StoreModeAlways +memCfg.Store.Extract = func(ctx context.Context, messages []interfaces.Message) ([]interfaces.MemoryRecord, error) { + return []interfaces.MemoryRecord{{Text: "User prefers concise answers.", Kind: memory.KindPreference}}, nil +} +``` + +Omit `Extract` to use the SDK default (requires an LLM client). Also on `Store`: `DefaultKind` (default `note`) and `AllowedKinds` (optional allowlist). + +**Implementing your own backend** — implement `interfaces.Memory`: + +```go +type Memory interface { + Store(ctx context.Context, scope MemoryScope, record MemoryRecord, opts ...StoreMemoryOption) (string, error) + Load(ctx context.Context, scope MemoryScope, query string, opts ...LoadMemoryOption) ([]MemoryEntry, error) + Clear(ctx context.Context, scope MemoryScope) error +} +``` + +`Store` persists in scope and returns an ID (support `WithMemoryID` for upserts). `Load` returns scoped results; honor load options. `Clear` deletes by scope (application use). Pass to `memory.DefaultConfig(yourStore)` and `agent.WithMemory`. + +### Retrieval (RAG) + +Retrieval-Augmented Generation (RAG) lets agents query external knowledge bases and ground responses in up-to-date or domain-specific content — without hardcoding it into the prompt. + +Built-in retriever implementations are in `pkg/retriever/weaviate` and `pkg/retriever/pgvector`. Bring your own by implementing `interfaces.Retriever` (`Name`, `Search`). + +**Retriever modes** + +- **Agentic** (default) — LLM decides when to call the retriever as a tool, the same way it calls any other tool. Best for multi-step agents where retrieval is not always needed. +- **Prefetch** — Retrieval fires before every LLM call. Retrieved context is injected automatically. Best for always-grounded Q&A or enterprise knowledge-base scenarios. +- **Hybrid** — Both: retriever context is pre-fetched and injected (prefetch), and the LLM can also call the retriever as a tool (agentic). + +Set mode with `agent.WithRetrieverMode`: + +```go +agent.WithRetrieverMode(agent.RetrieverModeAgentic) // default +agent.WithRetrieverMode(agent.RetrieverModePrefetch) +agent.WithRetrieverMode(agent.RetrieverModeHybrid) // prefetch + agentic +``` + +**Weaviate** (local Docker, zero auth for dev): + +```go +import "github.com/agenticenv/agent-sdk-go/pkg/retriever/weaviate" + +r, err := weaviate.NewRetriever("product_knowledge", + weaviate.WithHost("localhost:8080"), + weaviate.WithClassName("ProductDocs"), +) + +a, _ := agent.NewAgent( + agent.WithRetrievers(r), + agent.WithRetrieverMode(agent.RetrieverModeAgentic), + ... +) +``` + +**pgvector** (Postgres with pgvector extension; requires an embed function): + +```go +import "github.com/agenticenv/agent-sdk-go/pkg/retriever/pgvector" + +r, err := pgvector.NewRetriever("support_knowledge", embedFn, + pgvector.WithDSN("postgres://user:pass@localhost:5432/mydb"), + pgvector.WithTable("documents"), +) +``` + +**Custom retriever** — implement `interfaces.Retriever`: + +```go +type Retriever interface { + Name() string + Search(ctx context.Context, query string) ([]interfaces.Document, error) +} +``` + +**Multiple retrievers** — pass as many as needed; each must have a unique name: + +```go +agent.WithRetrievers(productRetriever, supportRetriever) +``` + +[examples/agent_with_retriever/weaviate](examples/agent_with_retriever/weaviate) · [examples/agent_with_retriever/pgvector](examples/agent_with_retriever/pgvector) + ### MCP (Model Context Protocol) MCP servers extend your agent with external tools that work identically to built-in tools across `Run`, `Stream`, `RunAsync`, and approval gates. Each server needs a **unique** name in config (the `WithMCPConfig` map key or the first argument to `mcpclient.NewClient`); tools are registered under stable names so they do not collide when several servers expose the same logical tool id. @@ -597,71 +875,6 @@ You may use **Option 1** for some remote agents and **Option 2** for others on t [examples/agent_with_a2a_config](examples/agent_with_a2a_config) and [examples/agent_with_a2a_client](examples/agent_with_a2a_client) show A2A from env (`A2A_URL`, optional bearer/headers/filter). Variables: [examples/.env.defaults](examples/.env.defaults). Running examples from `examples/`: [examples/README.md](examples/README.md). **Remote agent setup (e.g. `a2a-samples` helloworld), curl checks:** [examples/agent_with_a2a_config/README.md](examples/agent_with_a2a_config/README.md). -### Retrieval (RAG) - -Retrieval-Augmented Generation (RAG) lets agents query external knowledge bases and ground responses in up-to-date or domain-specific content — without hardcoding it into the prompt. - -Built-in retriever implementations are in `pkg/retriever/weaviate` and `pkg/retriever/pgvector`. Bring your own by implementing `interfaces.Retriever` (`Name`, `Search`). - -**Retriever modes** - -- **Agentic** (default) — LLM decides when to call the retriever as a tool, the same way it calls any other tool. Best for multi-step agents where retrieval is not always needed. -- **Prefetch** — Retrieval fires before every LLM call. Retrieved context is injected automatically. Best for always-grounded Q&A or enterprise knowledge-base scenarios. -- **Hybrid** — Both: retriever context is pre-fetched and injected (prefetch), and the LLM can also call the retriever as a tool (agentic). - -Set mode with `agent.WithRetrieverMode`: - -```go -agent.WithRetrieverMode(agent.RetrieverModeAgentic) // default -agent.WithRetrieverMode(agent.RetrieverModePrefetch) -agent.WithRetrieverMode(agent.RetrieverModeHybrid) // prefetch + agentic -``` - -**Weaviate** (local Docker, zero auth for dev): - -```go -import "github.com/agenticenv/agent-sdk-go/pkg/retriever/weaviate" - -r, err := weaviate.NewRetriever("product_knowledge", - weaviate.WithHost("localhost:8080"), - weaviate.WithClassName("ProductDocs"), -) - -a, _ := agent.NewAgent( - agent.WithRetrievers(r), - agent.WithRetrieverMode(agent.RetrieverModeAgentic), - ... -) -``` - -**pgvector** (Postgres with pgvector extension; requires an embed function): - -```go -import "github.com/agenticenv/agent-sdk-go/pkg/retriever/pgvector" - -r, err := pgvector.NewRetriever("support_knowledge", embedFn, - pgvector.WithDSN("postgres://user:pass@localhost:5432/mydb"), - pgvector.WithTable("documents"), -) -``` - -**Custom retriever** — implement `interfaces.Retriever`: - -```go -type Retriever interface { - Name() string - Search(ctx context.Context, query string) ([]interfaces.Document, error) -} -``` - -**Multiple retrievers** — pass as many as needed; each must have a unique name: - -```go -agent.WithRetrievers(productRetriever, supportRetriever) -``` - -[examples/agent_with_retriever/weaviate](examples/agent_with_retriever/weaviate) · [examples/agent_with_retriever/pgvector](examples/agent_with_retriever/pgvector) - ### Sub-agents Build each specialist with `NewAgent` (its own `TaskQueue`, LLM, tools, and prompts). Register specialists on the main agent with `WithSubAgents`. Use `WithName` and `WithDescription` when you want clearer labels for routing. Use `WithMaxSubAgentDepth` only if the default nesting limit is not enough. Run `Run`, `Stream`, or `RunAsync` on the main agent. Sub-agents always run without a conversation ID—they do not inherit the main agent session history. If you use `DisableLocalWorker`, pair each `NewAgentWorker` with the same options as the `NewAgent` that runs that agent. @@ -1006,90 +1219,6 @@ result, _ := a.Run(ctx, "Hello", nil) > `agent.WithAgentMode(agent.AgentModeAutonomous)` to skip the worker check and use a > 60-minute default timeout. See `[WithAgentMode](#configuration)` for full detail. -### Conversation (message history) - -Pass `agent.WithConversation(conv)` to persist message history for multi-turn context. Use `agent.WithConversationSize(n)` to limit how many messages are fetched for LLM context (default 20). - -By default, messages from the current run are saved once when the run finishes. Use `agent.EnableConversationSaveOnIteration()` when external consumers (e.g. a UI polling Redis) need live updates **during** a multi-step run—after each tool round, not only at the end. This adds extra store writes. For Temporal remote workers, set it on **`AgentWorker`** (where `WithConversation` and persistence run); the agent caller process does not need it. - -**Conversation ID:** When the agent is configured with a conversation, pass an `*agent.AgentRunOptions` with `ConversationOptions.ID` set to the same session ID on every call to `Run`, `RunAsync`, and `Stream`—so history is shared across turns. - -Choose implementation by deployment: - - -| Deployment | Use | -| -------------------------------------------------------------------- | --------------------------------------------------------- | -| **Single process** (agent and worker in same process) | `inmem.NewInMemoryConversation` | -| **Remote workers** (`DisableLocalWorker` or `EnableRemoteWorkers()`) | `redis.NewRedisConversation` or another distributed store | - - -To add a new conversation store (e.g., Postgres, MongoDB), implement the `interfaces.Conversation` interface in `[pkg/interfaces/conversation.go](pkg/interfaces/conversation.go)`. The interface requires `AddMessage`, `ListMessages`, `Clear`, and `IsDistributed`. See `pkg/conversation/inmem` and `pkg/conversation/redis` for reference. - -In-memory cannot be used with remote workers—the agent will return an error at build time. - -**Remote workers:** Agent and worker must use the same conversation store (same Redis config) so both processes access the same data. Only the process that calls `Run` or `Stream` passes the conversation ID; the worker does not. - -```go -// Single process (default) -conv := inmem.NewInMemoryConversation(inmem.WithMaxSize(100)) -a, _ := agent.NewAgent( - agent.WithTemporalConfig(...), - agent.WithLLMClient(...), - agent.WithConversation(conv), - agent.WithConversationSize(20), // optional; default 20 -) -opts := &agent.AgentRunOptions{ConversationOptions: &agent.ConversationOptions{ID: "session-1"}} -result, _ := a.Run(ctx, "Hello", opts) - -// Worker process -convW, _ := redis.NewRedisConversation(redis.WithAddr("localhost:6379")) -defer convW.Close() -w, _ := agent.NewAgentWorker( - agent.WithTemporalConfig(...), - agent.WithLLMClient(...), - agent.WithConversation(convW), -) -go w.Start() - -// Agent process -convA, _ := redis.NewRedisConversation(redis.WithAddr("localhost:6379")) -defer convA.Close() -a, _ := agent.NewAgent( - agent.WithTemporalConfig(...), - agent.WithLLMClient(...), - agent.DisableLocalWorker(), - agent.WithConversation(convA), -) -opts := &agent.AgentRunOptions{ConversationOptions: &agent.ConversationOptions{ID: "session-1"}} -result, _ := a.Run(ctx, "Hello", opts) -``` - -**Lifecycle:** You own the conversation. Call `Clear` when ending a session or when you no longer need the history. The agent never calls `Clear`. - -**Example (in-memory, single process):** - -```go -import ( - "github.com/agenticenv/agent-sdk-go/pkg/agent" - "github.com/agenticenv/agent-sdk-go/pkg/conversation/inmem" -) - -conv := inmem.NewInMemoryConversation(inmem.WithMaxSize(100)) -a, _ := agent.NewAgent( - agent.WithTemporalConfig(...), - agent.WithLLMClient(...), - agent.WithConversation(conv), - agent.WithConversationSize(20), -) -defer a.Close() - -opts := &agent.AgentRunOptions{ConversationOptions: &agent.ConversationOptions{ID: "session-1"}} -a.Run(ctx, "I'm Alice. Remember that.", opts) -a.Run(ctx, "What's my name?", opts) // agent uses history: "Alice" -``` - -[examples/agent_with_conversation](examples/agent_with_conversation) - ### AG-UI Protocol Agent stream events follow the [AG-UI open protocol](https://docs.ag-ui.com), making your agents natively compatible with any AG-UI frontend without extra integration work. @@ -1125,7 +1254,7 @@ Every run populates `AgentTelemetry` inside `AgentRunResult` with behavioral met - **Run** — start/end time, total LLM calls, and finish reason (`complete` or `max_iterations`) - **Tools** — total calls, failed calls, and per-tool breakdown for registered tools and MCP tools -- **Storage** — RAG retriever search counts split by mode (`prefetch_searches`, `agentic_searches`) and failure count; all fields are zero when no retriever is configured +- **Storage** — RAG retriever search counts split by mode (`prefetch_searches`, `agentic_searches`) and failure count; memory recall/store counts when `WithMemory` is configured; all fields are zero when the corresponding feature is not configured ```go result, _ := ag.Run(ctx, "prompt") @@ -1136,6 +1265,9 @@ fmt.Printf("retriever_searches=%d prefetch=%d agentic=%d\n", t.Storage.TotalRetrieverSearches, t.Storage.PrefetchSearches, t.Storage.AgenticSearches) +fmt.Printf("memory_recalls=%d memory_stores=%d\n", + t.Storage.TotalMemoryRecalls, + t.Storage.TotalMemoryStores) ``` **Stream** — telemetry is on `Result.Telemetry` inside the `RUN_FINISHED` event: @@ -1221,8 +1353,13 @@ a, _ := agent.NewAgent( | `tool.authorize` | `AgentToolAuthorizeActivity` | | `conversation.get_messages` | Fetch conversation history activity | | `conversation.add_messages` | Persist conversation activity | +| `memory.recall` | Load scoped memories before a run (`WithMemory`) | +| `memory.store` | Persist one memory record | +| `memory.store.batch` | Persist multiple records in one store call | +| `memory.dedup` | Semantic dedup lookup before store | +| `memory.extract` | Run-end memory extraction (`StoreModeAlways`) | -Common attributes: `agent.name`, `conversation.id`, `input.length`, `model`, `provider`, `tool`. +Common attributes: `agent.name`, `conversation.id`, `input.length`, `model`, `provider`, `tool`, `memory.kind`, `memory.dedup`, `query` (recall). ### Metrics @@ -1241,7 +1378,7 @@ All metric names are defined in `internal/types/metrics.go`. | `agent.stream.failed` | counter | Dispatch failed | | `agent.stream.duration_ms` | histogram | Dispatch wall-clock time in ms | -**Runtime** (emitted by Temporal activities; attributes: `model`, `provider`, `tool`): +**Runtime** (emitted per run on local and Temporal runtimes; attributes: `model`, `provider`, `tool`, `retriever`, `memory.kind`, `memory.dedup`): | Metric | Kind | Description | |---|---|---| @@ -1255,6 +1392,22 @@ All metric names are defined in `internal/types/metrics.go`. | `agent.tool.call.completed` | counter | Tool execute succeeded | | `agent.tool.call.failed` | counter | Tool execute failed | | `agent.tool.latency_ms` | histogram | Tool wall-clock time in ms | +| `agent.memory.recall.started` | counter | Memory recall started | +| `agent.memory.recall.completed` | counter | Memory recall succeeded | +| `agent.memory.recall.failed` | counter | Memory recall failed | +| `agent.memory.recall.latency_ms` | histogram | Memory recall wall-clock time in ms | +| `agent.memory.store.started` | counter | Memory store started | +| `agent.memory.store.completed` | counter | Memory store succeeded | +| `agent.memory.store.failed` | counter | Memory store failed | +| `agent.memory.store.latency_ms` | histogram | Memory store wall-clock time in ms | +| `agent.memory.dedup.started` | counter | Semantic dedup lookup started | +| `agent.memory.dedup.completed` | counter | Semantic dedup lookup succeeded | +| `agent.memory.dedup.failed` | counter | Semantic dedup lookup failed | +| `agent.memory.dedup.latency_ms` | histogram | Semantic dedup wall-clock time in ms | +| `agent.memory.extract.started` | counter | Run-end memory extract started | +| `agent.memory.extract.completed` | counter | Run-end memory extract succeeded | +| `agent.memory.extract.failed` | counter | Run-end memory extract failed | +| `agent.memory.extract.latency_ms` | histogram | Run-end memory extract wall-clock time in ms | ### Logs @@ -1276,9 +1429,8 @@ A Temporal connection (`WithTemporalConfig` or `WithTemporalClient`) is **option - **WithTemporalClient**: Pre-configured Temporal client. Use for TLS, API key auth, Temporal Cloud. Requires `WithTaskQueue`. Agent does not close the client. - **WithTaskQueue**: Task queue name. Required when using `WithTemporalClient`. Ignored when using `WithTemporalConfig`. - **WithResponseFormat**: LLM response format. Omit for text-only. Use `&interfaces.ResponseFormat{Type, Name, Schema}` for JSON with schema. See [Response format](#response-format). -- **WithConversation**: Message history store. Use `inmem` for single process; `redis` for remote workers. Pass the conversation ID via `AgentRunOptions` to `Run`, `RunAsync`, and `Stream` to share history across turns. See [Conversation](#conversation-message-history). -- **WithConversationSize**: Max messages to fetch for LLM context (default 20). Only applies when `WithConversation` is set. -- **EnableConversationSaveOnIteration**: Persist conversation messages after each tool round instead of batching at run end. For live visibility (e.g. Redis UI) during long runs. Set on `AgentWorker` for Temporal remote workers. +- **WithConversation**: Message history via `conversation.Config` — backend, `Size`, and `SaveOnIteration`. Pass the conversation ID via `AgentRunOptions` on each run. See [Conversation](#conversation). +- **WithMemory**: Memory via `memory.Config` — backend, scope, store mode, recall, and TTL. See [Memory](#memory). - **EnableRemoteWorkers**: Pass `EnableRemoteWorkers()` when using `DisableLocalWorker` with approval or streaming (starts the event worker/workflow path). - **WithSubAgents**: Attach specialist agents the main agent can delegate to. Each needs its own task queue and worker. See [Sub-agents](#sub-agents). - **WithMaxSubAgentDepth**: Maximum delegation hops from this agent (default 2). See [Sub-agents](#sub-agents). diff --git a/benchmarks/README.md b/benchmarks/README.md index 5504697..0808ab8 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -45,6 +45,7 @@ Mock components apply configurable latency and jitter so results reflect realist - Process CPU time - Total input/output tokens (from mock LLM stats; includes sub-agent LLM calls) - Success rate (`Run()` completed without error) +- Long-term memory recalls/stores (when `memory.enabled: true`; from run telemetry) - `est_cost_usd` — placeholder `0` until pricing is configured Reports are written to `benchmarks/reports/` (JSON or text). SDK logs (optional) go to `benchmarks/logs/`. @@ -205,6 +206,18 @@ All paths in config (`dir` fields) are relative to the **repository root** unles | `subagents.count` | Sub-agents per level (0 to disable). | | `subagents.levels` | Max sub-agent nesting depth (1–5). | +### `memory` + +Long-term memory (`agent.WithMemory`) using an in-process inmem backend (no Docker). Disabled by default. + +| Field | Description | +| :--- | :--- | +| `enabled` | `true` wires recall before each run and store after (mode-dependent). | +| `store_mode` | `ondemand` (LLM `save_memory` tool) or `always` (extract at run end). | +| `user_id` | Scope user ID passed via `memory.WithContextUserID` (default `benchmark-user`). | + +When `memory.enabled: true`, `agent.tools.count` may be `0` (memory-only runs). The mock LLM handles `save_memory` tool args and memory-extract JSON like the eval harness. + ### `logger` | Field | Description | diff --git a/benchmarks/agent_builder.go b/benchmarks/agent_builder.go index 1103e6e..73b5306 100644 --- a/benchmarks/agent_builder.go +++ b/benchmarks/agent_builder.go @@ -16,7 +16,10 @@ type AgentBundle struct { func buildAgentBundle(cfg *setup.Config, llm *setup.MockLLMClient, lgr logger.Logger, tree *setup.AgentTree) (*AgentBundle, error) { enableRemote := cfg.ExternalWorkersEnabled() - opts := setup.RootOptions(cfg, llm, lgr, setup.RootAgentName, tree.RootPrompt, tree.SubAgents, cfg.Temporal.TaskQueue, enableRemote) + opts, err := setup.AppendMemoryOptions(cfg, setup.RootOptions(cfg, llm, lgr, setup.RootAgentName, tree.RootPrompt, tree.SubAgents, cfg.Temporal.TaskQueue, enableRemote)) + if err != nil { + return nil, err + } root, err := agent.NewAgent(opts...) if err != nil { diff --git a/benchmarks/config.yaml b/benchmarks/config.yaml index ad5781c..1970952 100644 --- a/benchmarks/config.yaml +++ b/benchmarks/config.yaml @@ -29,6 +29,11 @@ agent: count: 2 levels: 1 # 1 or 2 +memory: + enabled: false + store_mode: ondemand # ondemand or always + user_id: benchmark-user + logger: enabled: false # true writes SDK logs under benchmarks/logs dir: benchmarks/logs diff --git a/benchmarks/main.go b/benchmarks/main.go index 24917a8..69a5403 100644 --- a/benchmarks/main.go +++ b/benchmarks/main.go @@ -28,6 +28,9 @@ type BenchmarkMetrics struct { TotalRuns int `json:"total_runs"` SuccessRate float64 `json:"success_rate"` + + TotalMemoryRecalls int64 `json:"total_memory_recalls"` + TotalMemoryStores int64 `json:"total_memory_stores"` } func main() { @@ -67,6 +70,9 @@ func main() { fmt.Printf("Starting agent-sdk-go benchmark (%s runtime)\n", cfg.Runtime) fmt.Printf("Runs: %d Concurrent: %t Tools: %d Sub-agents: %d (levels %d)\n", cfg.Agent.Runs, cfg.Agent.Concurrent, cfg.Agent.Tools.Count, cfg.Agent.Subagents.Count, cfg.Agent.Subagents.Levels) + if cfg.MemoryEnabled() { + fmt.Printf("Memory : enabled (store_mode=%s, user_id=%s)\n", cfg.Memory.StoreMode, cfg.Memory.UserID) + } if cfg.UseTemporal() { fmt.Printf("External workers : %d\n", cfg.Temporal.WorkersCount) } diff --git a/benchmarks/metrics.go b/benchmarks/metrics.go index 4ccf3f6..c1035e1 100644 --- a/benchmarks/metrics.go +++ b/benchmarks/metrics.go @@ -9,11 +9,14 @@ import ( func aggregateMetrics(outcomes []runOutcome, memBefore, memAfter runtime.MemStats, cpuMs float64, inputTokens, outputTokens int) *BenchmarkMetrics { latencies := make([]float64, 0, len(outcomes)) successes := 0 + var totalRecalls, totalStores int64 for _, o := range outcomes { latencies = append(latencies, o.latencyMs) if o.success { successes++ } + totalRecalls += o.memoryRecalls + totalStores += o.memoryStores } sort.Float64s(latencies) @@ -24,18 +27,20 @@ func aggregateMetrics(outcomes []runOutcome, memBefore, memAfter runtime.MemStat } return &BenchmarkMetrics{ - P50Ms: percentile(latencies, 50), - P95Ms: percentile(latencies, 95), - P99Ms: percentile(latencies, 99), - AvgMs: average(latencies), - HeapAllocBytes: deltaUint64(memAfter.Alloc, memBefore.Alloc), - TotalAllocBytes: deltaUint64(memAfter.TotalAlloc, memBefore.TotalAlloc), - CPUTimeMs: cpuMs, - TotalInputTokens: inputTokens, - TotalOutputTokens: outputTokens, - EstCostUSD: 0, // pricing to be defined later - TotalRuns: totalRuns, - SuccessRate: successRate, + P50Ms: percentile(latencies, 50), + P95Ms: percentile(latencies, 95), + P99Ms: percentile(latencies, 99), + AvgMs: average(latencies), + HeapAllocBytes: deltaUint64(memAfter.Alloc, memBefore.Alloc), + TotalAllocBytes: deltaUint64(memAfter.TotalAlloc, memBefore.TotalAlloc), + CPUTimeMs: cpuMs, + TotalInputTokens: inputTokens, + TotalOutputTokens: outputTokens, + EstCostUSD: 0, // pricing to be defined later + TotalRuns: totalRuns, + SuccessRate: successRate, + TotalMemoryRecalls: totalRecalls, + TotalMemoryStores: totalStores, } } diff --git a/benchmarks/report.go b/benchmarks/report.go index 027ac37..ed10ab3 100644 --- a/benchmarks/report.go +++ b/benchmarks/report.go @@ -89,5 +89,9 @@ func formatTextReport(cfg *setup.Config, metrics *BenchmarkMetrics) string { fmt.Fprintf(&b, "Output tokens : %d\n", metrics.TotalOutputTokens) fmt.Fprintf(&b, "Est. cost (USD) : %.4f # pricing placeholder\n", metrics.EstCostUSD) fmt.Fprintf(&b, "Success rate (%%) : %.2f\n", metrics.SuccessRate) + if cfg.MemoryEnabled() { + fmt.Fprintf(&b, "Memory recalls : %d\n", metrics.TotalMemoryRecalls) + fmt.Fprintf(&b, "Memory stores : %d\n", metrics.TotalMemoryStores) + } return b.String() } diff --git a/benchmarks/runner.go b/benchmarks/runner.go index 2cdbe76..59edc57 100644 --- a/benchmarks/runner.go +++ b/benchmarks/runner.go @@ -12,11 +12,14 @@ import ( "github.com/agenticenv/agent-sdk-go/benchmarks/setup" "github.com/agenticenv/agent-sdk-go/pkg/agent" "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/memory" ) type runOutcome struct { - latencyMs float64 - success bool + latencyMs float64 + success bool + memoryRecalls int64 + memoryStores int64 } func runBenchmark(ctx context.Context, cfg *setup.Config, llm *setup.MockLLMClient, lgr logger.Logger, runRng *rand.Rand) (*BenchmarkMetrics, error) { @@ -64,7 +67,7 @@ func runBenchmark(ctx context.Context, cfg *setup.Config, llm *setup.MockLLMClie agentIdx := i % len(bundles) go func(bundle *AgentBundle) { defer wg.Done() - outcome := executeRun(ctx, bundle.Root, runRng) + outcome := executeRun(ctx, cfg, bundle.Root, runRng) outcomesMu.Lock() outcomes = append(outcomes, outcome) outcomesMu.Unlock() @@ -89,13 +92,22 @@ func runBenchmark(ctx context.Context, cfg *setup.Config, llm *setup.MockLLMClie return aggregateMetrics(outcomes, memBefore, memAfter, cpuAfter-cpuBefore, inputTokens, outputTokens), nil } -func executeRun(ctx context.Context, a *agent.Agent, rng *rand.Rand) runOutcome { +func executeRun(ctx context.Context, cfg *setup.Config, a *agent.Agent, rng *rand.Rand) runOutcome { + runCtx := ctx + if cfg.MemoryEnabled() { + runCtx = memory.WithContextUserID(ctx, cfg.Memory.UserID) + } start := time.Now() - _, err := a.Run(ctx, setup.RandomUserPrompt(rng), nil) - return runOutcome{ + result, err := a.Run(runCtx, setup.RandomUserPrompt(rng), nil) + outcome := runOutcome{ latencyMs: float64(time.Since(start).Milliseconds()), success: err == nil, } + if result != nil && result.Telemetry != nil { + outcome.memoryRecalls = result.Telemetry.Storage.TotalMemoryRecalls + outcome.memoryStores = result.Telemetry.Storage.TotalMemoryStores + } + return outcome } func processCPUTimeMs() (float64, error) { diff --git a/benchmarks/setup/config.go b/benchmarks/setup/config.go index e2a30fc..1044294 100644 --- a/benchmarks/setup/config.go +++ b/benchmarks/setup/config.go @@ -6,10 +6,14 @@ import ( "path/filepath" "strings" + testutil "github.com/agenticenv/agent-sdk-go/internal/testing" + "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/memory" "github.com/spf13/viper" ) const BenchmarkTreeSeed int64 = 42 +const defaultMemoryUserID = "benchmark-user" type Config struct { Runtime string `mapstructure:"runtime"` @@ -17,6 +21,7 @@ type Config struct { LLM LLMConfig `mapstructure:"llm"` Tool ToolConfig `mapstructure:"tool"` Agent AgentConfig `mapstructure:"agent"` + Memory MemoryConfig `mapstructure:"memory"` Logger LoggerConfig `mapstructure:"logger"` Output OutputConfig `mapstructure:"output"` } @@ -58,6 +63,13 @@ type SubagentsConfig struct { Levels int `mapstructure:"levels"` } +// MemoryConfig configures long-term memory for benchmark runs. +type MemoryConfig struct { + Enabled bool `mapstructure:"enabled"` + StoreMode string `mapstructure:"store_mode"` + UserID string `mapstructure:"user_id"` +} + type LoggerConfig struct { Enabled bool `mapstructure:"enabled"` Dir string `mapstructure:"dir"` @@ -79,6 +91,50 @@ func (c *Config) ExternalWorkersEnabled() bool { return c.UseTemporal() && c.Temporal.WorkersCount > 0 } +// MemoryEnabled reports whether long-term memory is wired for benchmark runs. +func (c *Config) MemoryEnabled() bool { + return c != nil && c.Memory.Enabled +} + +func (m *MemoryConfig) applyDefaults() { + if m == nil { + return + } + if strings.TrimSpace(m.UserID) == "" { + m.UserID = defaultMemoryUserID + } + if strings.TrimSpace(m.StoreMode) == "" { + m.StoreMode = string(memory.StoreModeOnDemand) + } +} + +func parseMemoryStoreMode(raw string) (memory.StoreMode, error) { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", string(memory.StoreModeOnDemand), "on-demand", "on_demand": + return memory.StoreModeOnDemand, nil + case string(memory.StoreModeAlways): + return memory.StoreModeAlways, nil + default: + return "", fmt.Errorf("memory.store_mode must be %q or %q", memory.StoreModeOnDemand, memory.StoreModeAlways) + } +} + +// MemoryAgentOption returns WithMemory when memory is enabled. +func MemoryAgentOption(cfg *Config) (agent.Option, error) { + if cfg == nil || !cfg.MemoryEnabled() { + return nil, nil + } + cfg.Memory.applyDefaults() + mode, err := parseMemoryStoreMode(cfg.Memory.StoreMode) + if err != nil { + return nil, err + } + memCfg := memory.DefaultConfig(testutil.NewInmemMemory()) + memCfg.Store.Mode = mode + memCfg.Recall.Enabled = true + return agent.WithMemory(memCfg), nil +} + func LoadConfig(path string) (*Config, error) { if path == "" { path = defaultConfigPath() @@ -106,8 +162,8 @@ func (c *Config) validate() error { if c.Agent.Concurrent && c.Agent.ConcurrentCount <= 0 { return fmt.Errorf("agent.concurrent_count must be > 0 when concurrent is true") } - if c.Agent.Tools.Count <= 0 { - return fmt.Errorf("agent.tools.count must be > 0") + if c.Agent.Tools.Count <= 0 && !c.Memory.Enabled { + return fmt.Errorf("agent.tools.count must be > 0 when memory is disabled") } if c.Agent.Subagents.Levels < 0 { return fmt.Errorf("agent.subagents.levels must be >= 0") @@ -148,6 +204,12 @@ func (c *Config) validate() error { if c.Temporal.Namespace == "" { c.Temporal.Namespace = "default" } + c.Memory.applyDefaults() + if c.Memory.Enabled { + if _, err := parseMemoryStoreMode(c.Memory.StoreMode); err != nil { + return err + } + } return nil } diff --git a/benchmarks/setup/mock_llm.go b/benchmarks/setup/mock_llm.go index da6390d..dc6bb07 100644 --- a/benchmarks/setup/mock_llm.go +++ b/benchmarks/setup/mock_llm.go @@ -9,11 +9,14 @@ import ( "time" "github.com/agenticenv/agent-sdk-go/internal/runtime" + "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" ) const MockLLMModel = "benchmark-mock" +const mockMemoryExtractText = "User prefers concise answers" + type LLMStats struct { mu sync.Mutex TotalInputTokens int @@ -64,6 +67,17 @@ func (m *MockLLMClient) Generate(ctx context.Context, request *interfaces.LLMReq promptTokens, completionTokens := splitMockTokens(m.cfg.MockTokens) m.stats.add(promptTokens, completionTokens) + if isMemoryExtractRequest(request) { + return &interfaces.LLMResponse{ + Content: fmt.Sprintf(`{"memories":[{"text":%q,"kind":"preference"}]}`, mockMemoryExtractText), + Usage: &interfaces.LLMUsage{ + PromptTokens: int64(promptTokens), + CompletionTokens: int64(completionTokens), + TotalTokens: int64(promptTokens + completionTokens), + }, + }, nil + } + if hasToolResultMessages(request) { return &interfaces.LLMResponse{ Content: "benchmark complete", @@ -149,12 +163,26 @@ func hasToolResultMessages(request *interfaces.LLMRequest) bool { } func mockToolArgs(toolName string) map[string]any { + if toolName == types.SaveMemoryToolName { + return map[string]any{ + types.MemoryToolParamText: mockMemoryExtractText, + types.MemoryToolParamKind: "preference", + } + } if strings.HasPrefix(toolName, "subagent_") { return map[string]any{runtime.SubAgentToolParamQuery: "benchmark subtask"} } return map[string]any{"input": "benchmark"} } +func isMemoryExtractRequest(request *interfaces.LLMRequest) bool { + if request == nil || request.ResponseFormat == nil { + return false + } + return request.ResponseFormat.Type == interfaces.ResponseFormatJSON && + request.ResponseFormat.Name == "MemoryExtraction" +} + func splitMockTokens(total int) (prompt, completion int) { if total <= 0 { return 0, 0 diff --git a/benchmarks/setup/opts.go b/benchmarks/setup/opts.go index 98fe940..0b97288 100644 --- a/benchmarks/setup/opts.go +++ b/benchmarks/setup/opts.go @@ -53,6 +53,18 @@ func RootOptions( return opts } +// AppendMemoryOptions adds WithMemory when memory is enabled in cfg. +func AppendMemoryOptions(cfg *Config, opts []agent.Option) ([]agent.Option, error) { + memOpt, err := MemoryAgentOption(cfg) + if err != nil { + return nil, err + } + if memOpt != nil { + opts = append(opts, memOpt) + } + return opts, nil +} + func mapToolExecutionMode(raw string) agent.AgentToolExecutionMode { switch strings.ToLower(strings.TrimSpace(raw)) { case "sequential": diff --git a/benchmarks/worker/main.go b/benchmarks/worker/main.go index 1e2f0b7..057971a 100644 --- a/benchmarks/worker/main.go +++ b/benchmarks/worker/main.go @@ -44,7 +44,10 @@ func main() { } defer setup.CloseAgents(tree.Created) - opts := setup.RootOptions(cfg, llm, lgr, setup.RootAgentName, tree.RootPrompt, tree.SubAgents, cfg.Temporal.TaskQueue, false) + opts, err := setup.AppendMemoryOptions(cfg, setup.RootOptions(cfg, llm, lgr, setup.RootAgentName, tree.RootPrompt, tree.SubAgents, cfg.Temporal.TaskQueue, false)) + if err != nil { + log.Fatalf("memory options: %v", err) + } w, err := agent.NewAgentWorker(opts...) if err != nil { log.Fatalf("create agent worker: %v", err) diff --git a/cmd/main.go b/cmd/main.go index 507456b..a274803 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -13,6 +13,7 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/conversation" "github.com/agenticenv/agent-sdk-go/pkg/conversation/inmem" "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" "github.com/agenticenv/agent-sdk-go/pkg/tools/currenttime" @@ -92,8 +93,7 @@ func main() { agent.WithLLMClient(llmClient), agent.WithStream(true), agent.WithToolRegistry(reg), - agent.WithConversation(conv), - agent.WithConversationSize(20), + agent.WithConversation(conversation.DefaultConfig(conv)), agent.WithLogger(lgr), } opts = append(opts, RuntimeOption(cfg)...) diff --git a/eval-harness/README.md b/eval-harness/README.md index 0cc9450..cb11fc3 100644 --- a/eval-harness/README.md +++ b/eval-harness/README.md @@ -39,7 +39,21 @@ Default path: `eval-harness/runner/config.yaml` | `temporal.namespace` | `default` | Temporal namespace | | `temporal.task_queue` | `eval-harness` | Task queue | -Temporal mode uses an embedded local worker. Start Temporal before running (e.g. `task infra:temporal:up` from `examples/`). +**Memory** (same file; `enabled: false` for default tool tests): + +| Field | Default | Description | +|-------|---------|-------------| +| `memory.enabled` | `false` | Set true in YAML, or pass `-memory` on the runner | +| `memory.store_mode` | `ondemand` | `ondemand` or `always` | +| `memory.scenario` | `store_recall` | Two-run store then recall when enabled | +| `memory.store_prompt` / `memory.recall_prompt` | — | Used when scenario is active | + +```bash +./eval-harness/run_agent_memory.sh ondemand # same config.yaml, -memory flag +./eval-harness/run_agent_memory.sh always +``` + +Temporal mode uses an embedded local worker. ### Output @@ -49,10 +63,16 @@ Stdout is always JSON: { "content": "eval complete", "llm_usage": { "prompt_tokens": 600, "completion_tokens": 400, "total_tokens": 1000 }, - "telemetry": { "run": { ... }, "tools": { ... }, "storage": { ... } } + "telemetry": { "run": { ... }, "tools": { ... }, "storage": { ... } }, + "memory_scenario": { + "store": { "content": "...", "telemetry": { ... } }, + "recall": { "content": "...", "telemetry": { ... } } + } } ``` +`memory_scenario` is present only for `memory.scenario: store_recall`. + ## PromptFoo Config: `eval-harness/promptfoo/config.yaml` @@ -87,21 +107,23 @@ The runner accepts PromptFoo’s prompt as a positional argument when `-prompt` ### Tests -Four test cases in `config.yaml`, each with a JavaScript assertion on runner JSON: +Six test cases in `config.yaml` (each scoped to one provider — Promptfoo reports **6 runs**, not test×provider duplicates): -| Test | Checks | -|------|--------| -| all mock tools were called | `telemetry.tools.breakdown` — `eval_tool_1`, `eval_tool_2`, `eval_tool_3`, each called once | -| agent completed successfully | `telemetry.run.finish_reason === "complete"` and `content === "eval complete"` | -| no failed tool calls | `telemetry.tools.failed_calls === 0` | -| llm usage reported | `llm_usage.total_tokens > 0` | +| Test | Provider | Checks | +|------|----------|--------| +| all mock tools were called | `eval-agent` | `telemetry.tools.breakdown` — three eval tools, each once | +| agent completed successfully | `eval-agent` | `finish_reason === "complete"` and `content === "eval complete"` | +| no failed tool calls | `eval-agent` | `telemetry.tools.failed_calls === 0` | +| llm usage reported | `eval-agent` | `llm_usage.total_tokens > 0` | +| memory ondemand stores then recalls | `eval-agent-memory-ondemand` | `memory_scenario` store/recall telemetry | +| memory always stores then recalls | `eval-agent-memory-always` | same for always store mode | ### Customizing - **Change the prompt** — edit `prompts` in `promptfoo/config.yaml`, or add `vars` and use `{{var}}` in the prompt string. - **Change agent behavior** — edit `eval-harness/runner/config.yaml` (tool count, runtime, system prompt), or adjust `eval-harness/run_agent.sh`. - **Add tests** — append cases under `tests:` with `type: javascript` and `value:` returning a boolean. -- **Filter providers** — use `label: eval-agent` in test `options.providers` if you add more providers later. +- **Filter providers** — set `providers: [label]` on each test so agent checks do not run on memory providers (and vice versa). ## DeepEval @@ -144,12 +166,14 @@ finish_reason = agent_res["telemetry"]["run"]["finish_reason"] ### Tests -Two pytest tests in `test_agent.py`: +Four pytest tests in `test_agent.py`: | Test | Checks | |------|--------| | `test_agent_completes_with_telemetry` | `content`, `llm_usage`, `finish_reason`, `failed_calls`, `total_calls`, `breakdown` keys | | `test_agent_tool_correctness` | `ToolCorrectnessMetric` — `tools_called` from telemetry vs expected tools | +| `test_memory_store_recall_ondemand` | `memory_scenario` store/recall telemetry (ondemand) | +| `test_memory_store_recall_always` | `memory_scenario` store/recall telemetry (always) | ### Customizing diff --git a/eval-harness/deepeval/harness.py b/eval-harness/deepeval/harness.py index a5a16b2..761a8ab 100644 --- a/eval-harness/deepeval/harness.py +++ b/eval-harness/deepeval/harness.py @@ -35,6 +35,13 @@ def run_agent(prompt: str = DEFAULT_PROMPT) -> dict: return json.loads(raw) +def run_agent_memory(store_mode: str = "ondemand") -> dict: + """Execute the memory store_recall eval harness scenario.""" + script = REPO_ROOT / "eval-harness" / "run_agent_memory.sh" + raw = subprocess.check_output([str(script), store_mode], cwd=REPO_ROOT, text=True) + return json.loads(raw) + + def tools_called(agent_res: dict) -> list[str]: """Return tool names from telemetry breakdown.""" breakdown = agent_res["telemetry"]["tools"]["breakdown"] diff --git a/eval-harness/deepeval/test_agent.py b/eval-harness/deepeval/test_agent.py index a89ee35..53a5576 100644 --- a/eval-harness/deepeval/test_agent.py +++ b/eval-harness/deepeval/test_agent.py @@ -4,7 +4,7 @@ from deepeval.metrics import ToolCorrectnessMetric from deepeval.test_case import LLMTestCase, ToolCall -from harness import DEFAULT_PROMPT, StubJudge, run_agent, tools_called +from harness import DEFAULT_PROMPT, StubJudge, run_agent, run_agent_memory, tools_called EXPECTED_TOOLS = [ ToolCall(name="eval_tool_1"), @@ -54,3 +54,27 @@ def test_agent_tool_correctness(): async_mode=False, ) assert_test(test_case, [metric]) + + +def test_memory_store_recall_ondemand(): + """Memory ondemand: store run persists, recall run loads scoped memories.""" + agent_res = run_agent_memory("ondemand") + store = agent_res["memory_scenario"]["store"]["telemetry"]["storage"] + recall = agent_res["memory_scenario"]["recall"]["telemetry"]["storage"] + + assert store["total_memory_stores"] >= 1 + assert store.get("failed_memory_stores", 0) == 0 + assert recall["total_memory_recalls"] >= 1 + assert recall.get("failed_memory_recalls", 0) == 0 + + +def test_memory_store_recall_always(): + """Memory always: run-end extract stores, recall run loads scoped memories.""" + agent_res = run_agent_memory("always") + store = agent_res["memory_scenario"]["store"]["telemetry"]["storage"] + recall = agent_res["memory_scenario"]["recall"]["telemetry"]["storage"] + + assert store["total_memory_stores"] >= 1 + assert store.get("failed_memory_stores", 0) == 0 + assert recall["total_memory_recalls"] >= 1 + assert recall.get("failed_memory_recalls", 0) == 0 diff --git a/eval-harness/promptfoo/config.yaml b/eval-harness/promptfoo/config.yaml index 106c260..7700c89 100644 --- a/eval-harness/promptfoo/config.yaml +++ b/eval-harness/promptfoo/config.yaml @@ -9,9 +9,14 @@ prompts: providers: - id: exec:../run_agent.sh label: eval-agent + - id: exec:../run_agent_memory.sh ondemand + label: eval-agent-memory-ondemand + - id: exec:../run_agent_memory.sh always + label: eval-agent-memory-always tests: - description: all mock tools were called + providers: [eval-agent] assert: - type: javascript value: | @@ -24,6 +29,7 @@ tests: && breakdown.eval_tool_3 === 1; - description: agent completed successfully + providers: [eval-agent] assert: - type: javascript value: | @@ -32,6 +38,7 @@ tests: && res.content === "eval complete"; - description: no failed tool calls + providers: [eval-agent] assert: - type: javascript value: | @@ -39,8 +46,35 @@ tests: return res.telemetry.tools.failed_calls === 0; - description: llm usage reported + providers: [eval-agent] assert: - type: javascript value: | const res = JSON.parse(output); return res.llm_usage.total_tokens > 0; + + - description: memory ondemand stores then recalls + providers: [eval-agent-memory-ondemand] + assert: + - type: javascript + value: | + const res = JSON.parse(output); + const store = res.memory_scenario.store.telemetry.storage; + const recall = res.memory_scenario.recall.telemetry.storage; + return store.total_memory_stores >= 1 + && (store.failed_memory_stores || 0) === 0 + && recall.total_memory_recalls >= 1 + && (recall.failed_memory_recalls || 0) === 0; + + - description: memory always stores then recalls + providers: [eval-agent-memory-always] + assert: + - type: javascript + value: | + const res = JSON.parse(output); + const store = res.memory_scenario.store.telemetry.storage; + const recall = res.memory_scenario.recall.telemetry.storage; + return store.total_memory_stores >= 1 + && (store.failed_memory_stores || 0) === 0 + && recall.total_memory_recalls >= 1 + && (recall.failed_memory_recalls || 0) === 0; diff --git a/eval-harness/run_agent_memory.sh b/eval-harness/run_agent_memory.sh new file mode 100755 index 0000000..519255e --- /dev/null +++ b/eval-harness/run_agent_memory.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +# Memory eval harness: enables memory store_recall on the shared config.yaml. +# Usage: run_agent_memory.sh [ondemand|always] + +set -euo pipefail + +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +CONFIG="${ROOT}/eval-harness/runner/config.yaml" +MODE="${1:-ondemand}" + +cd "$ROOT" +exec go run ./eval-harness/runner -config "$CONFIG" -memory -memory-store-mode "$MODE" diff --git a/eval-harness/runner/config.yaml b/eval-harness/runner/config.yaml index 0c8b260..dbba841 100644 --- a/eval-harness/runner/config.yaml +++ b/eval-harness/runner/config.yaml @@ -1,4 +1,4 @@ -runtime: local # local or temporal +runtime: local user_prompt: "run eval check" @@ -7,6 +7,14 @@ agent: system_prompt: "You are an evaluation agent. Use available tools when helpful, then answer concisely." tool_count: 3 +memory: + enabled: false + store_mode: ondemand + user_id: eval-user + scenario: store_recall + store_prompt: "Remember for future runs: I prefer concise answers." + recall_prompt: "What answer style do I prefer?" + temporal: host: localhost port: 7233 diff --git a/eval-harness/runner/main.go b/eval-harness/runner/main.go index 1b8ad64..db233a9 100644 --- a/eval-harness/runner/main.go +++ b/eval-harness/runner/main.go @@ -15,6 +15,8 @@ func main() { prompt := flag.String("prompt", "", "override user_prompt from config") runtimeFlag := flag.String("runtime", "", "override runtime: local or temporal") toolCount := flag.Int("tools", 0, "override agent.tool_count (0 = use config)") + memoryStoreMode := flag.String("memory-store-mode", "", "override memory.store_mode: ondemand or always") + memoryScenario := flag.Bool("memory", false, "enable memory store_recall scenario from config") flag.Parse() fileCfg, err := setup.LoadConfig(*configPath) @@ -35,15 +37,26 @@ func main() { if *toolCount > 0 { runCfg.ToolCount = *toolCount } + if *memoryStoreMode != "" { + mode, err := setup.ParseMemoryStoreMode(*memoryStoreMode) + if err != nil { + log.Fatalf("memory store mode: %v", err) + } + runCfg.Memory.StoreMode = mode + } + if *memoryScenario { + runCfg.Memory.Enabled = true + runCfg.ToolCount = 0 + } - result, err := Run(context.Background(), runCfg) + outcome, err := Run(context.Background(), runCfg) if err != nil { log.Fatalf("eval run failed: %v", err) } enc := json.NewEncoder(os.Stdout) enc.SetIndent("", " ") - if err := enc.Encode(OutputFromResult(result)); err != nil { + if err := enc.Encode(OutputFromResult(outcome)); err != nil { log.Fatalf("encode result: %v", err) } } diff --git a/eval-harness/runner/output.go b/eval-harness/runner/output.go index fa94cf8..558edc7 100644 --- a/eval-harness/runner/output.go +++ b/eval-harness/runner/output.go @@ -4,13 +4,36 @@ import "github.com/agenticenv/agent-sdk-go/pkg/agent" // Output is a JSON-friendly view of an agent run for eval harness tools. type Output struct { - Content string `json:"content"` - LLMUsage *agent.LLMUsage `json:"llm_usage,omitempty"` - Telemetry *agent.AgentTelemetry `json:"telemetry,omitempty"` + Content string `json:"content"` + LLMUsage *agent.LLMUsage `json:"llm_usage,omitempty"` + Telemetry *agent.AgentTelemetry `json:"telemetry,omitempty"` + MemoryScenario *MemoryScenarioOutput `json:"memory_scenario,omitempty"` } -// OutputFromResult maps an AgentRunResult into Output for assertions or CLI JSON output. -func OutputFromResult(result *agent.AgentRunResult) *Output { +// MemoryScenarioOutput is JSON for two-run memory regression scenarios. +type MemoryScenarioOutput struct { + Store *Output `json:"store"` + Recall *Output `json:"recall"` +} + +// OutputFromResult maps a RunOutcome into Output for assertions or CLI JSON output. +func OutputFromResult(outcome *RunOutcome) *Output { + if outcome == nil || outcome.Result == nil { + return nil + } + output := OutputFromRunResult(outcome.Result) + if outcome.MemoryScenario == nil { + return output + } + output.MemoryScenario = &MemoryScenarioOutput{ + Store: OutputFromRunResult(outcome.MemoryScenario.Store), + Recall: OutputFromRunResult(outcome.MemoryScenario.Recall), + } + return output +} + +// OutputFromRunResult maps an AgentRunResult into Output. +func OutputFromRunResult(result *agent.AgentRunResult) *Output { if result == nil { return nil } diff --git a/eval-harness/runner/runner.go b/eval-harness/runner/runner.go index 40f0d74..b32e840 100644 --- a/eval-harness/runner/runner.go +++ b/eval-harness/runner/runner.go @@ -6,15 +6,32 @@ import ( "github.com/agenticenv/agent-sdk-go/eval-harness/runner/setup" "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/memory" ) +// RunOutcome is the result of an eval harness execution. +type RunOutcome struct { + Result *agent.AgentRunResult + MemoryScenario *MemoryScenarioOutcome +} + +// MemoryScenarioOutcome holds store/recall runs for memory regression checks. +type MemoryScenarioOutcome struct { + Store *agent.AgentRunResult + Recall *agent.AgentRunResult +} + // Run executes one agent run with mock LLM and mock tools, then closes the agent. -func Run(ctx context.Context, cfg setup.Config) (*agent.AgentRunResult, error) { +func Run(ctx context.Context, cfg setup.Config) (*RunOutcome, error) { cfg.ApplyDefaults() if err := cfg.Validate(); err != nil { return nil, err } + if cfg.UsesMemoryScenario() { + return runMemoryStoreRecall(ctx, cfg) + } + a, err := setup.BuildAgent(cfg) if err != nil { return nil, err @@ -25,5 +42,33 @@ func Run(ctx context.Context, cfg setup.Config) (*agent.AgentRunResult, error) { if err != nil { return nil, fmt.Errorf("agent run: %w", err) } - return result, nil + return &RunOutcome{Result: result}, nil +} + +func runMemoryStoreRecall(ctx context.Context, cfg setup.Config) (*RunOutcome, error) { + a, err := setup.BuildAgent(cfg) + if err != nil { + return nil, err + } + defer a.Close() + + scoped := memory.WithContextUserID(ctx, cfg.Memory.UserID) + + storeResult, err := a.Run(scoped, cfg.Memory.StorePrompt, nil) + if err != nil { + return nil, fmt.Errorf("memory store run: %w", err) + } + + recallResult, err := a.Run(scoped, cfg.Memory.RecallPrompt, nil) + if err != nil { + return nil, fmt.Errorf("memory recall run: %w", err) + } + + return &RunOutcome{ + Result: recallResult, + MemoryScenario: &MemoryScenarioOutcome{ + Store: storeResult, + Recall: recallResult, + }, + }, nil } diff --git a/eval-harness/runner/runner_test.go b/eval-harness/runner/runner_test.go index 19fb5f5..02f0b3b 100644 --- a/eval-harness/runner/runner_test.go +++ b/eval-harness/runner/runner_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/agenticenv/agent-sdk-go/eval-harness/runner/setup" + "github.com/agenticenv/agent-sdk-go/pkg/memory" "github.com/stretchr/testify/require" ) @@ -15,29 +16,71 @@ func TestLoadConfig_Defaults(t *testing.T) { require.Equal(t, "run eval check", cfg.UserPrompt) require.Equal(t, "local", cfg.Runtime) require.Equal(t, 3, cfg.Agent.ToolCount) + require.False(t, cfg.Memory.Enabled) +} + +func TestLoadConfig_MemoryFields(t *testing.T) { + cfg, err := setup.LoadConfig("config.yaml") + require.NoError(t, err) + require.False(t, cfg.Memory.Enabled) + require.Equal(t, setup.MemoryScenarioStoreRecall, cfg.Memory.Scenario) + require.NotEmpty(t, cfg.Memory.StorePrompt) + require.NotEmpty(t, cfg.Memory.RecallPrompt) } func TestRun_FromFileConfig(t *testing.T) { fileCfg, err := setup.LoadConfig("config.yaml") require.NoError(t, err) - result, err := Run(context.Background(), fileCfg.Config()) + outcome, err := Run(context.Background(), fileCfg.Config()) require.NoError(t, err) - require.NotEmpty(t, result.Content) - require.Equal(t, int64(3), result.Telemetry.Tools.TotalCalls) + require.NotNil(t, outcome.Result) + require.NotEmpty(t, outcome.Result.Content) + require.Equal(t, int64(3), outcome.Result.Telemetry.Tools.TotalCalls) } func TestRun_LocalRuntime(t *testing.T) { - result, err := Run(context.Background(), setup.Config{ + outcome, err := Run(context.Background(), setup.Config{ UserPrompt: "run eval check", Runtime: setup.RuntimeLocal, ToolCount: 2, }) require.NoError(t, err) - require.NotNil(t, result) - require.NotEmpty(t, result.Content) - require.NotNil(t, result.Telemetry) - require.Equal(t, int64(2), result.Telemetry.Tools.TotalCalls) + require.NotNil(t, outcome.Result) + require.NotEmpty(t, outcome.Result.Content) + require.NotNil(t, outcome.Result.Telemetry) + require.Equal(t, int64(2), outcome.Result.Telemetry.Tools.TotalCalls) +} + +func TestRun_MemoryStoreRecall_OnDemand(t *testing.T) { + fileCfg, err := setup.LoadConfig("config.yaml") + require.NoError(t, err) + + runCfg := fileCfg.Config() + runCfg.Memory.Enabled = true + runCfg.ToolCount = 0 + + outcome, err := Run(context.Background(), runCfg) + require.NoError(t, err) + require.NotNil(t, outcome.MemoryScenario) + require.GreaterOrEqual(t, outcome.MemoryScenario.Store.Telemetry.Storage.TotalMemoryStores, int64(1)) + require.GreaterOrEqual(t, outcome.MemoryScenario.Recall.Telemetry.Storage.TotalMemoryRecalls, int64(1)) +} + +func TestRun_MemoryStoreRecall_Always(t *testing.T) { + fileCfg, err := setup.LoadConfig("config.yaml") + require.NoError(t, err) + + runCfg := fileCfg.Config() + runCfg.Memory.Enabled = true + runCfg.Memory.StoreMode = memory.StoreModeAlways + runCfg.ToolCount = 0 + + outcome, err := Run(context.Background(), runCfg) + require.NoError(t, err) + require.NotNil(t, outcome.MemoryScenario) + require.GreaterOrEqual(t, outcome.MemoryScenario.Store.Telemetry.Storage.TotalMemoryStores, int64(1)) + require.GreaterOrEqual(t, outcome.MemoryScenario.Recall.Telemetry.Storage.TotalMemoryRecalls, int64(1)) } func TestRun_TemporalRuntime(t *testing.T) { @@ -45,14 +88,14 @@ func TestRun_TemporalRuntime(t *testing.T) { t.Skip("set EVAL_HARNESS_TEMPORAL=true with Temporal running on localhost:7233") } - result, err := Run(context.Background(), setup.Config{ + outcome, err := Run(context.Background(), setup.Config{ UserPrompt: "run eval check", Runtime: setup.RuntimeTemporal, ToolCount: 2, }) require.NoError(t, err) - require.NotEmpty(t, result.Content) - require.Equal(t, int64(2), result.Telemetry.Tools.TotalCalls) + require.NotEmpty(t, outcome.Result.Content) + require.Equal(t, int64(2), outcome.Result.Telemetry.Tools.TotalCalls) } func TestRun_RequiresUserPrompt(t *testing.T) { diff --git a/eval-harness/runner/setup/agent.go b/eval-harness/runner/setup/agent.go index 863fb64..754d23c 100644 --- a/eval-harness/runner/setup/agent.go +++ b/eval-harness/runner/setup/agent.go @@ -42,6 +42,14 @@ func BuildAgent(cfg Config) (*agent.Agent, error) { })) } + memOpt, err := MemoryAgentOption(cfg) + if err != nil { + return nil, fmt.Errorf("memory option: %w", err) + } + if memOpt != nil { + opts = append(opts, memOpt) + } + a, err := agent.NewAgent(opts...) if err != nil { return nil, fmt.Errorf("new agent: %w", err) diff --git a/eval-harness/runner/setup/config.go b/eval-harness/runner/setup/config.go index de4f53f..965a75d 100644 --- a/eval-harness/runner/setup/config.go +++ b/eval-harness/runner/setup/config.go @@ -4,17 +4,22 @@ import ( "fmt" "strings" + testutil "github.com/agenticenv/agent-sdk-go/internal/testing" "github.com/agenticenv/agent-sdk-go/pkg/agent" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/memory" ) const ( - DefaultAgentName = "eval-agent" - DefaultToolCount = 3 - DefaultMockTokens = 500 - DefaultSystemPrompt = "You are an evaluation agent. Use available tools when helpful, then answer concisely." - DefaultRuntime = RuntimeLocal + DefaultAgentName = "eval-agent" + DefaultToolCount = 3 + DefaultMockTokens = 500 + DefaultSystemPrompt = "You are an evaluation agent. Use available tools when helpful, then answer concisely." + DefaultRuntime = RuntimeLocal + DefaultMemoryUserID = "eval-user" + DefaultMemoryStoreMode = memory.StoreModeOnDemand + MemoryScenarioStoreRecall = "store_recall" ) // Runtime selects the agent execution backend. @@ -41,6 +46,16 @@ type TemporalConfig struct { TaskQueue string `mapstructure:"task_queue"` } +// MemoryConfig configures long-term memory for eval harness runs. +type MemoryConfig struct { + Enabled bool + StoreMode memory.StoreMode + UserID string + Scenario string + StorePrompt string + RecallPrompt string +} + // Config holds settings for a single eval agent run. type Config struct { UserPrompt string @@ -51,6 +66,7 @@ type Config struct { LLM LLMConfig Tool ToolConfig ToolCount int + Memory MemoryConfig LLMClient interfaces.LLMClient ToolRegistry agent.ToolRegistry Logger logger.Logger @@ -61,6 +77,29 @@ func (c *Config) UseTemporal() bool { return c != nil && strings.EqualFold(strings.TrimSpace(string(c.Runtime)), string(RuntimeTemporal)) } +// MemoryEnabled reports whether memory is wired for this run. +func (c *Config) MemoryEnabled() bool { + return c != nil && c.Memory.Enabled +} + +// UsesMemoryScenario reports whether the runner executes a multi-step memory scenario. +func (c *Config) UsesMemoryScenario() bool { + return c.MemoryEnabled() && strings.EqualFold(strings.TrimSpace(c.Memory.Scenario), MemoryScenarioStoreRecall) +} + +// ApplyMemoryDefaults fills unset memory config fields. +func (m *MemoryConfig) ApplyMemoryDefaults() { + if m == nil { + return + } + if strings.TrimSpace(m.UserID) == "" { + m.UserID = DefaultMemoryUserID + } + if strings.TrimSpace(string(m.StoreMode)) == "" { + m.StoreMode = DefaultMemoryStoreMode + } +} + // ApplyDefaults fills unset config fields. func (c *Config) ApplyDefaults() { if c == nil { @@ -75,7 +114,7 @@ func (c *Config) ApplyDefaults() { if c.SystemPrompt == "" { c.SystemPrompt = DefaultSystemPrompt } - if c.ToolCount <= 0 { + if c.ToolCount <= 0 && !c.MemoryEnabled() { c.ToolCount = DefaultToolCount } if c.LLM.MockTokens <= 0 { @@ -84,6 +123,7 @@ func (c *Config) ApplyDefaults() { if c.Logger == nil { c.Logger = logger.NoopLogger() } + c.Memory.ApplyMemoryDefaults() if c.Temporal.TaskQueue == "" { c.Temporal.TaskQueue = "eval-harness" } @@ -98,18 +138,65 @@ func (c *Config) ApplyDefaults() { } } +// ValidateMemory checks memory-related config when enabled. +func (c *Config) ValidateMemory() error { + if c == nil || !c.Memory.Enabled { + return nil + } + c.Memory.ApplyMemoryDefaults() + switch c.Memory.StoreMode { + case memory.StoreModeOnDemand, memory.StoreModeAlways: + default: + return fmt.Errorf("memory.store_mode must be %q or %q", memory.StoreModeOnDemand, memory.StoreModeAlways) + } + if !c.UsesMemoryScenario() { + return nil + } + if strings.TrimSpace(c.Memory.StorePrompt) == "" { + return fmt.Errorf("memory.store_prompt is required when memory.scenario is %q", MemoryScenarioStoreRecall) + } + if strings.TrimSpace(c.Memory.RecallPrompt) == "" { + return fmt.Errorf("memory.recall_prompt is required when memory.scenario is %q", MemoryScenarioStoreRecall) + } + return nil +} + // Validate checks required config fields. func (c *Config) Validate() error { if c == nil { return fmt.Errorf("config is required") } - if c.UserPrompt == "" { - return fmt.Errorf("user prompt is required") - } switch strings.ToLower(strings.TrimSpace(string(c.Runtime))) { case string(RuntimeLocal), string(RuntimeTemporal): default: return fmt.Errorf("runtime must be %q or %q", RuntimeLocal, RuntimeTemporal) } - return nil + if !c.UsesMemoryScenario() && c.UserPrompt == "" { + return fmt.Errorf("user prompt is required") + } + return c.ValidateMemory() +} + +// ParseMemoryStoreMode parses eval harness store mode strings. +func ParseMemoryStoreMode(raw string) (memory.StoreMode, error) { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", string(memory.StoreModeOnDemand), "on-demand", "on_demand": + return memory.StoreModeOnDemand, nil + case string(memory.StoreModeAlways): + return memory.StoreModeAlways, nil + default: + return "", fmt.Errorf("memory store mode must be %q or %q", memory.StoreModeOnDemand, memory.StoreModeAlways) + } +} + +// MemoryAgentOption returns WithMemory when memory is enabled. +func MemoryAgentOption(cfg Config) (agent.Option, error) { + if !cfg.MemoryEnabled() { + return nil, nil + } + cfg.Memory.ApplyMemoryDefaults() + memCfg := memory.DefaultConfig(testutil.NewInmemMemory()) + memCfg.Store.Mode = cfg.Memory.StoreMode + memCfg.Recall.Enabled = true + return agent.WithMemory(memCfg), nil } diff --git a/eval-harness/runner/setup/config_test.go b/eval-harness/runner/setup/config_test.go new file mode 100644 index 0000000..d35a3a6 --- /dev/null +++ b/eval-harness/runner/setup/config_test.go @@ -0,0 +1,25 @@ +package setup + +import ( + "testing" + + "github.com/agenticenv/agent-sdk-go/pkg/memory" +) + +func TestParseMemoryStoreMode(t *testing.T) { + t.Parallel() + + mode, err := ParseMemoryStoreMode("always") + if err != nil || mode != memory.StoreModeAlways { + t.Fatalf("ParseMemoryStoreMode(always) = %q, %v", mode, err) + } + + mode, err = ParseMemoryStoreMode("") + if err != nil || mode != memory.StoreModeOnDemand { + t.Fatalf("ParseMemoryStoreMode(empty) = %q, %v", mode, err) + } + + if _, err := ParseMemoryStoreMode("invalid"); err == nil { + t.Fatal("expected error for invalid mode") + } +} diff --git a/eval-harness/runner/setup/load.go b/eval-harness/runner/setup/load.go index 4ad0f76..afec8af 100644 --- a/eval-harness/runner/setup/load.go +++ b/eval-harness/runner/setup/load.go @@ -10,10 +10,21 @@ import ( // FileConfig is the YAML configuration for eval-harness runs. type FileConfig struct { - Runtime string `mapstructure:"runtime"` - UserPrompt string `mapstructure:"user_prompt"` - Agent FileAgentConfig `mapstructure:"agent"` - Temporal TemporalConfig `mapstructure:"temporal"` + Runtime string `mapstructure:"runtime"` + UserPrompt string `mapstructure:"user_prompt"` + Agent FileAgentConfig `mapstructure:"agent"` + Memory FileMemoryConfig `mapstructure:"memory"` + Temporal TemporalConfig `mapstructure:"temporal"` +} + +// FileMemoryConfig holds memory fields from YAML. +type FileMemoryConfig struct { + Enabled bool `mapstructure:"enabled"` + StoreMode string `mapstructure:"store_mode"` + UserID string `mapstructure:"user_id"` + Scenario string `mapstructure:"scenario"` + StorePrompt string `mapstructure:"store_prompt"` + RecallPrompt string `mapstructure:"recall_prompt"` } // FileAgentConfig holds agent fields from YAML. @@ -28,6 +39,7 @@ func (f *FileConfig) Config() Config { if f == nil { return Config{} } + storeMode, _ := ParseMemoryStoreMode(f.Memory.StoreMode) return Config{ UserPrompt: f.UserPrompt, Runtime: Runtime(f.Runtime), @@ -35,6 +47,14 @@ func (f *FileConfig) Config() Config { AgentName: f.Agent.Name, SystemPrompt: f.Agent.SystemPrompt, ToolCount: f.Agent.ToolCount, + Memory: MemoryConfig{ + Enabled: f.Memory.Enabled, + StoreMode: storeMode, + UserID: f.Memory.UserID, + Scenario: f.Memory.Scenario, + StorePrompt: f.Memory.StorePrompt, + RecallPrompt: f.Memory.RecallPrompt, + }, } } @@ -79,7 +99,7 @@ func (f *FileConfig) validate() error { if f == nil { return fmt.Errorf("config is required") } - if strings.TrimSpace(f.UserPrompt) == "" { + if strings.TrimSpace(f.UserPrompt) == "" && !strings.EqualFold(strings.TrimSpace(f.Memory.Scenario), MemoryScenarioStoreRecall) { return fmt.Errorf("user_prompt is required") } switch strings.ToLower(strings.TrimSpace(f.Runtime)) { @@ -91,7 +111,7 @@ func (f *FileConfig) validate() error { default: return fmt.Errorf("runtime must be %q or %q", RuntimeLocal, RuntimeTemporal) } - if f.Agent.ToolCount <= 0 { + if f.Agent.ToolCount <= 0 && !f.Memory.Enabled { f.Agent.ToolCount = DefaultToolCount } if f.Agent.Name == "" { @@ -112,5 +132,18 @@ func (f *FileConfig) validate() error { if f.Temporal.Namespace == "" { f.Temporal.Namespace = "default" } + if f.Memory.Enabled { + if _, err := ParseMemoryStoreMode(f.Memory.StoreMode); err != nil { + return err + } + if strings.EqualFold(strings.TrimSpace(f.Memory.Scenario), MemoryScenarioStoreRecall) { + if strings.TrimSpace(f.Memory.StorePrompt) == "" { + return fmt.Errorf("memory.store_prompt is required when memory.scenario is %q", MemoryScenarioStoreRecall) + } + if strings.TrimSpace(f.Memory.RecallPrompt) == "" { + return fmt.Errorf("memory.recall_prompt is required when memory.scenario is %q", MemoryScenarioStoreRecall) + } + } + } return nil } diff --git a/eval-harness/runner/setup/mock_llm.go b/eval-harness/runner/setup/mock_llm.go index 8ef9064..126a22e 100644 --- a/eval-harness/runner/setup/mock_llm.go +++ b/eval-harness/runner/setup/mock_llm.go @@ -6,11 +6,17 @@ import ( "math/rand" "time" + "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" ) const mockLLMModel = "eval-mock" +const ( + mockMemoryExtractText = "User prefers concise answers" + mockRecallContent = "You prefer concise answers." +) + // MockLLMClient is a deterministic mock LLM for eval harness runs. type MockLLMClient struct { cfg LLMConfig @@ -33,6 +39,13 @@ func (m *MockLLMClient) Generate(ctx context.Context, request *interfaces.LLMReq TotalTokens: int64(promptTokens + completionTokens), } + if isMemoryExtractRequest(request) { + return &interfaces.LLMResponse{ + Content: fmt.Sprintf(`{"memories":[{"text":%q,"kind":"preference"}]}`, mockMemoryExtractText), + Usage: usage, + }, nil + } + if hasToolResultMessages(request) { return &interfaces.LLMResponse{ Content: "eval complete", @@ -40,12 +53,19 @@ func (m *MockLLMClient) Generate(ctx context.Context, request *interfaces.LLMReq }, nil } + if len(request.Tools) == 0 { + return &interfaces.LLMResponse{ + Content: mockRecallContent, + Usage: usage, + }, nil + } + toolCalls := make([]*interfaces.ToolCall, 0, len(request.Tools)) for i, spec := range request.Tools { toolCalls = append(toolCalls, &interfaces.ToolCall{ ToolCallID: fmt.Sprintf("tc-%d", i+1), ToolName: spec.Name, - Args: map[string]any{"input": "eval"}, + Args: mockToolArgs(spec.Name), }) } @@ -56,6 +76,24 @@ func (m *MockLLMClient) Generate(ctx context.Context, request *interfaces.LLMReq }, nil } +func mockToolArgs(toolName string) map[string]any { + if toolName == types.SaveMemoryToolName { + return map[string]any{ + types.MemoryToolParamText: mockMemoryExtractText, + types.MemoryToolParamKind: "preference", + } + } + return map[string]any{"input": "eval"} +} + +func isMemoryExtractRequest(request *interfaces.LLMRequest) bool { + if request == nil || request.ResponseFormat == nil { + return false + } + return request.ResponseFormat.Type == interfaces.ResponseFormatJSON && + request.ResponseFormat.Name == "MemoryExtraction" +} + func (m *MockLLMClient) GenerateStream(ctx context.Context, request *interfaces.LLMRequest) (interfaces.LLMStream, error) { resp, err := m.Generate(ctx, request) if err != nil { diff --git a/examples/.env.defaults b/examples/.env.defaults index a722111..a3aaede 100644 --- a/examples/.env.defaults +++ b/examples/.env.defaults @@ -73,6 +73,17 @@ PGVECTOR_SOURCE_COL=source PGVECTOR_EMBEDDING_COL=embedding PGVECTOR_MIN_SCORE=0.35 +# --- Long-term memory (agent_with_memory/weaviate, agent_with_memory/pgvector) --- +# MEMORY_STORE_MODE: always (run-end extract) or ondemand (save_memory tool). Default ondemand. +# Task examples:local runs each backend in both modes (4 runs). Same infra as retriever examples. +MEMORY_USER_ID=demo-user +MEMORY_STORE_MODE=ondemand +MEMORY_RECALL_ENABLED=true +MEMORY_RECALL_LIMIT=10 +MEMORY_RECALL_MIN_SCORE=0.35 +WEAVIATE_MEMORY_CLASS=AgentMemory +PGVECTOR_MEMORY_TABLE=agent_memories + # --- OpenAI-compatible embeddings (pgvector client-side; Weaviate text2vec-openai in Docker) --- # Set EMBEDDING_OPENAI_APIKEY in examples/.env (separate from chat LLM_APIKEY for Anthropic/Gemini). EMBEDDING_OPENAI_APIKEY= diff --git a/examples/README.md b/examples/README.md index 3476cb2..3f567cf 100644 --- a/examples/README.md +++ b/examples/README.md @@ -24,7 +24,7 @@ These examples run with `AGENT_RUNTIME=local` (default) or `AGENT_RUNTIME=tempor | Example | What it demonstrates | Infra (Task, from `examples/`) | |---------|---------------------|--------------------------------| | `simple_agent` | Minimal agent, no tools — system prompt, LLM client, single `Run()`; prints `AgentResponse.Usage` (token counts) when the provider reports them | — | -| `agent_with_conversation` | Redis conversation with `WithConversation` — multi-turn context, same `conversationID` for `Run` | `infra:redis:up` (or `infra:deps:up`) | +| `agent_with_conversation` | Redis conversation with `WithConversation(conversation.Config{...})` — multi-turn context, same `conversationID` for `Run` | `infra:redis:up` (or `infra:deps:up`) | | `agent_with_tools/basic` | Built-in tools (echo, calculator, weather, wikipedia, search) with auto-approval | — | | `agent_with_tools/approval` | Tools + `WithApprovalHandler` — user approves or rejects each tool run (`Run` only) | — | | `agent_with_tools/authorizer` | Custom tool authorization via `interfaces.ToolAuthorizer` — denied calls surface as `tool_result` with `denied` status | — | @@ -45,6 +45,7 @@ These examples run with `AGENT_RUNTIME=local` (default) or `AGENT_RUNTIME=tempor | `agent_with_a2a_server` | **Inbound** A2A server — **`A2A_SERVER_*`**; **[README](agent_with_a2a_server/README.md)** | `go run` or `infra:a2a:up` | | `agent_with_observability` | OTLP — **`config/`** vs **`objects/`**; **[README](agent_with_observability/README.md)** | `infra:lgtm:up` (or manual collector) | | `agent_with_retriever` | **`weaviate/`** or **`pgvector/`**; **`RETRIEVER_MODE`** — **[README](agent_with_retriever/README.md)** | `infra:weaviate:up` or `infra:pgvector:up` | +| `agent_with_memory` | **`weaviate/`** or **`pgvector/`** — **[README](agent_with_memory/README.md)**; `MEMORY_STORE_MODE=always\|ondemand` | `infra:weaviate:up` or `infra:pgvector:up` | ### Temporal only @@ -221,6 +222,17 @@ RETRIEVER_MODE=prefetch go run ./agent_with_retriever/weaviate "What are the ret Setup guides: **[agent_with_retriever/README.md](agent_with_retriever/README.md)**. +### Long-term memory (`agent_with_memory`) + +Same vector infra as retriever (`task infra:weaviate:up` or `task infra:pgvector:up`). **Weaviate** uses always store (run-end extract); **pgvector** uses on-demand store (`save_memory`). No CLI args runs a two-turn demo (store, then recall). + +```bash +go run ./agent_with_memory/weaviate +go run ./agent_with_memory/pgvector +``` + +Setup guide: **[agent_with_memory/README.md](agent_with_memory/README.md)**. + --- ### Temporal-only examples @@ -277,7 +289,7 @@ All examples call [`shared.PrintRunFooters`](shared/utils.go) after each run. Se | Env var | Default | When `true` | |---------|---------|-------------| | `SHOW_LLM_USAGE` | `false` | Prints token usage (`prompt_tokens`, `completion_tokens`, etc.) | -| `SHOW_TELEMETRY` | `false` | Prints run telemetry (`total_llm_calls`, tool counts, retriever searches, etc.) | +| `SHOW_TELEMETRY` | `false` | Prints run telemetry (`total_llm_calls`, tool counts, retriever searches, memory recalls/stores, etc.) | ```bash SHOW_LLM_USAGE=true go run ./simple_agent "Hello, what can you do?" @@ -287,6 +299,8 @@ SHOW_LLM_USAGE=true SHOW_TELEMETRY=true go run ./agent_with_stream "What's 17 * For retriever examples, `SHOW_TELEMETRY=true` also prints prefetch/agentic search breakdowns — see [agent_with_retriever/README.md](agent_with_retriever/README.md). +For memory examples, `SHOW_TELEMETRY=true` also prints `total_memory_recalls` and `total_memory_stores` — see [agent_with_memory/README.md](agent_with_memory/README.md). + ## Env vars | Env var | Description | @@ -331,3 +345,5 @@ For retriever examples, `SHOW_TELEMETRY=true` also prints prefetch/agentic searc | `RETRIEVER_MODE` | For **`agent_with_retriever`**: **`agentic`** (default), **`prefetch`**, or **`hybrid`** | | `WEAVIATE_HOST`, `WEAVIATE_SCHEME`, `WEAVIATE_CLASS`, … | Weaviate backend — **`.env.defaults`** and **[agent_with_retriever/README.md#weaviate](agent_with_retriever/README.md#weaviate)** | | `PGVECTOR_DSN`, `PGVECTOR_TABLE`, `EMBEDDING_OPENAI_MODEL`, … | pgvector backend — **`PGVECTOR_DSN` required**; **[agent_with_retriever/README.md#pgvector](agent_with_retriever/README.md#pgvector)** | +| `MEMORY_USER_ID`, `MEMORY_STORE_MODE`, `MEMORY_RECALL_ENABLED`, `MEMORY_RECALL_LIMIT`, `MEMORY_RECALL_MIN_SCORE` | For **`agent_with_memory`**: scope user, store mode (`always` / `ondemand`), recall settings — **[agent_with_memory/README.md](agent_with_memory/README.md)** | +| `WEAVIATE_MEMORY_CLASS`, `PGVECTOR_MEMORY_TABLE` | Memory backend class/table names (defaults: `AgentMemory`, `agent_memories`) | diff --git a/examples/agent_with_conversation/main.go b/examples/agent_with_conversation/main.go index 8623c09..6899f50 100644 --- a/examples/agent_with_conversation/main.go +++ b/examples/agent_with_conversation/main.go @@ -11,6 +11,7 @@ import ( config "github.com/agenticenv/agent-sdk-go/examples" "github.com/agenticenv/agent-sdk-go/examples/shared" "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/conversation" "github.com/agenticenv/agent-sdk-go/pkg/conversation/redis" "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" @@ -48,9 +49,11 @@ func main() { agent.WithLLMClient(llmClient), agent.WithToolRegistry(reg), agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), - agent.WithConversation(conv), - agent.WithConversationSize(20), - agent.EnableConversationSaveOnIteration(), + agent.WithConversation(conversation.Config{ + Conversation: conv, + Size: 20, + SaveOnIteration: true, + }), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), } opts = append(opts, config.RuntimeOption(cfg)...) diff --git a/examples/agent_with_memory/README.md b/examples/agent_with_memory/README.md new file mode 100644 index 0000000..d4125fb --- /dev/null +++ b/examples/agent_with_memory/README.md @@ -0,0 +1,163 @@ +# Agent with memory (`agent_with_memory`) + +Examples that wire **long-term memory** into **agent-sdk-go**. Pick **one backend** per run. + +| Backend | Package | Example entrypoint | +|---------|---------|-------------------| +| Weaviate | [`pkg/memory/weaviate`](../../pkg/memory/weaviate) | `go run ./agent_with_memory/weaviate` | +| PostgreSQL + pgvector | [`pkg/memory/pgvector`](../../pkg/memory/pgvector) | `go run ./agent_with_memory/pgvector` | + +Store mode is selected with **`MEMORY_STORE_MODE`** (`always` or `ondemand`; default `ondemand`). Both backends support both modes. `task examples:local` runs four combinations (each backend × each mode). + +Uses the same Docker stack as retriever examples ([`../docker/`](../docker/)). `task infra:weaviate:up` / `task infra:pgvector:up` creates the **memory** class/table (`AgentMemory` / `agent_memories`) in addition to retriever schema. No seed rows for memory — rows are written by agent runs. + +## Prerequisites + +- **Runtime** — **`AGENT_RUNTIME=local`** (default): in-process, no Temporal. Optional **`AGENT_RUNTIME=temporal`**: from `examples/`, run `task infra:temporal:up` (and `task infra:temporal:wait` if the example fails to connect). See [`temporal-setup.md`](../../temporal-setup.md). +- **`examples/.env`** — `LLM_APIKEY`, `LLM_MODEL`, and **`EMBEDDING_OPENAI_APIKEY`** (see **`.env.defaults`**) +- **Task** (`go-task`) and **Docker** (`task infra:weaviate:up` or `task infra:pgvector:up`) + +From `examples/`: + +```bash +task infra:status # see what is up +``` + +## Example behavior + +- **Store mode** — `MEMORY_STORE_MODE=always` extracts and stores at run end; `ondemand` registers `save_memory` for the LLM during the run (default). +- **No CLI args** — two runs in one process: run 1 stores a preference, run 2 recalls it. +- **With args** — single custom prompt. +- **Scope** — `MEMORY_USER_ID` in `.env` (default `demo-user`); must be the same across runs you want to share memories. + +Set `MEMORY_RECALL_ENABLED=false` in `.env` for store-only (skip load before LLM). + +Use `SHOW_TELEMETRY=true` to see `total_memory_stores` on run 1 when store succeeds. + +--- + +## Weaviate + +Weaviate embeds memory text via **nearText** (`text2vec-openai` in Docker). + +### Setup + +```bash +cd examples +task infra:weaviate:up +task infra:weaviate:down # when finished +``` + +Compose: [`docker/docker-compose.yml`](../docker/docker-compose.yml). Seed: [`docker/weaviate/seed.sh`](../docker/weaviate/seed.sh) (creates class **`AgentMemory`**). + +`EMBEDDING_OPENAI_APIKEY` must be set in `examples/.env` **before** `up`. After a key change: `task infra:weaviate:down && task infra:weaviate:up`. + +Verify the memory class exists: + +```bash +curl -s http://localhost:8080/v1/schema | jq '.classes[].class' +# expect Document and AgentMemory +``` + +### Environment + +```bash +WEAVIATE_HOST=localhost:8080 +WEAVIATE_SCHEME=http +WEAVIATE_MEMORY_CLASS=AgentMemory +MEMORY_USER_ID=demo-user +MEMORY_RECALL_ENABLED=true +MEMORY_RECALL_LIMIT=10 +MEMORY_RECALL_MIN_SCORE=0.35 +``` + +### Run + +```bash +go run ./agent_with_memory/weaviate +go run ./agent_with_memory/weaviate "Remember my favorite color is blue" +``` + +```bash +SHOW_TELEMETRY=true go run ./agent_with_memory/weaviate +``` + +### Weaviate troubleshooting + +| Symptom | What to do | +|---------|------------| +| `missing class data` / memory recall error on first run | Usually empty class — Weaviate returns `null` not `[]` (fixed in SDK). Update and re-run; or `MEMORY_RECALL_ENABLED=false` for store-only | +| Class **`AgentMemory`** missing from schema | `task infra:weaviate:down && task infra:weaviate:up`; verify: `curl -s http://localhost:8080/v1/schema \| jq '.classes[].class'` | +| Compose / API key errors | Set `EMBEDDING_OPENAI_APIKEY`, then `task infra:weaviate:down && task infra:weaviate:up` | +| Connection refused `:8080` | `task infra:status`, `curl -s http://localhost:8080/v1/.well-known/ready`, `docker logs weaviate` | +| Run 2 does not recall run 1 | Same `MEMORY_USER_ID`; ensure run 1 completed and run-end store succeeded (check `LOG_LEVEL=debug` or `SHOW_TELEMETRY=true`) | +| Port 8080 / 50051 in use | `task infra:weaviate:down`; set `WEAVIATE_HTTP_PORT` / `WEAVIATE_GRPC_PORT` before `up` | + +```bash +LOG_LEVEL=debug go run ./agent_with_memory/weaviate +``` + +--- + +## pgvector + +Client-side **OpenAI-compatible** embeddings, then cosine search in Postgres ([pgvector](https://github.com/pgvector/pgvector)). + +### Setup + +```bash +cd examples +task infra:pgvector:up +task infra:pgvector:down # when finished +``` + +Schema: [`docker/pgvector/setup.sql`](../docker/pgvector/setup.sql) (table **`agent_memories`**). Seed: [`docker/pgvector/seed.sh`](../docker/pgvector/seed.sh). + +Default DSN (in **`.env.defaults`**): `postgres://postgres:secret@localhost:5432/vectordb?sslmode=disable` + +Verify the memory table exists: + +```bash +docker exec pgvector psql -U postgres -d vectordb -c "\d agent_memories" +``` + +### Environment + +```bash +PGVECTOR_DSN=postgres://postgres:secret@localhost:5432/vectordb?sslmode=disable +PGVECTOR_MEMORY_TABLE=agent_memories +EMBEDDING_OPENAI_MODEL=text-embedding-3-small +EMBEDDING_OPENAI_APIKEY=sk-... +MEMORY_USER_ID=demo-user +MEMORY_RECALL_ENABLED=true +MEMORY_RECALL_LIMIT=10 +MEMORY_RECALL_MIN_SCORE=0.35 +``` + +With **Anthropic/Gemini** chat, `EMBEDDING_OPENAI_APIKEY` is still required (not `LLM_APIKEY`). + +### Run + +```bash +go run ./agent_with_memory/pgvector +go run ./agent_with_memory/pgvector "What answer style do I prefer?" +``` + +```bash +SHOW_TELEMETRY=true go run ./agent_with_memory/pgvector +``` + +### pgvector troubleshooting + +| Symptom | What to do | +|---------|------------| +| `relation "agent_memories" does not exist` | `task infra:pgvector:down && task infra:pgvector:up`; verify with `\d agent_memories` above | +| `embedding config` / Anthropic chat | Set `EMBEDDING_OPENAI_APIKEY`; re-run `task infra:pgvector:up` | +| `PGVECTOR_DSN is required` | Use default DSN or match compose `PGVECTOR_*` vars | +| Run 2 does not recall run 1 | Same `MEMORY_USER_ID`; ensure run 1 finished without error and the LLM called `save_memory` (check `LOG_LEVEL=debug` or `SHOW_TELEMETRY=true`) | +| Dimension / SQL errors | Model must match `vector(1536)` in `setup.sql` | +| Port 5432 in use | `task infra:pgvector:down`; set `PGVECTOR_PORT` and update `PGVECTOR_DSN` | + +```bash +LOG_LEVEL=debug go run ./agent_with_memory/pgvector +``` diff --git a/examples/agent_with_memory/common/config.go b/examples/agent_with_memory/common/config.go new file mode 100644 index 0000000..6ad23f3 --- /dev/null +++ b/examples/agent_with_memory/common/config.go @@ -0,0 +1,126 @@ +// Package common holds shared configuration and agent options for the agent_with_memory examples. +package common + +import ( + "fmt" + "os" + "strconv" + "strings" + + "github.com/agenticenv/agent-sdk-go/pkg/memory" +) + +// Settings holds env-driven values shared by the weaviate and pgvector example entry points. +type Settings struct { + UserID string + StoreMode memory.StoreMode + RecallEnabled bool + RecallLimit int + RecallMinScore float32 + + // Weaviate + WeaviateHost string + WeaviateScheme string + WeaviateClass string + WeaviateMemoryClass string + + // PostgreSQL / pgvector + PGDSN string + PGMemoryTable string + PGEmbeddingCol string + EmbeddingModel string + EmbeddingBaseURL string + EmbeddingAPIKey string +} + +func getEnv(key, def string) string { + if v := os.Getenv(key); v != "" { + return v + } + return def +} + +func getEnvInt(key string, def int) int { + if v := os.Getenv(key); v != "" { + if i, err := strconv.Atoi(v); err == nil { + return i + } + } + return def +} + +func getEnvBool(key string, def bool) bool { + v := strings.TrimSpace(os.Getenv(key)) + if v == "" { + return def + } + switch strings.ToLower(v) { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + default: + return def + } +} + +func getEnvFloat(key string, def float64) float64 { + if v := os.Getenv(key); v != "" { + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f + } + } + return def +} + +// ParseStoreMode reads MEMORY_STORE_MODE (always | ondemand). Empty defaults to ondemand. +func ParseStoreMode(raw string) (memory.StoreMode, error) { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", "ondemand", "on-demand", "on_demand": + return memory.StoreModeOnDemand, nil + case "always": + return memory.StoreModeAlways, nil + default: + return "", fmt.Errorf("MEMORY_STORE_MODE must be always or ondemand, got %q", raw) + } +} + +// StoreModeHint returns a one-line demo hint for the configured store mode. +func StoreModeHint(mode memory.StoreMode) string { + if mode == memory.StoreModeAlways { + return "no args runs run-end store (run 1) then recall (run 2); pass one prompt for a single run" + } + return "no args runs save_memory (run 1) then recall (run 2); pass one prompt for a single run" +} + +// LoadSettings reads memory example env vars. LLM and Temporal vars come from examples/config.LoadFromEnv. +func LoadSettings() (*Settings, error) { + storeMode, err := ParseStoreMode(getEnv("MEMORY_STORE_MODE", "ondemand")) + if err != nil { + return nil, err + } + s := &Settings{ + UserID: getEnv("MEMORY_USER_ID", "demo-user"), + StoreMode: storeMode, + RecallEnabled: getEnvBool("MEMORY_RECALL_ENABLED", true), + RecallLimit: getEnvInt("MEMORY_RECALL_LIMIT", 10), + RecallMinScore: float32(getEnvFloat("MEMORY_RECALL_MIN_SCORE", 0.35)), + WeaviateHost: getEnv("WEAVIATE_HOST", "localhost:8080"), + WeaviateScheme: getEnv("WEAVIATE_SCHEME", "http"), + WeaviateClass: getEnv("WEAVIATE_CLASS", "Document"), + WeaviateMemoryClass: getEnv("WEAVIATE_MEMORY_CLASS", "AgentMemory"), + PGDSN: strings.TrimSpace(getEnv("PGVECTOR_DSN", "")), + PGMemoryTable: getEnv("PGVECTOR_MEMORY_TABLE", "agent_memories"), + PGEmbeddingCol: getEnv("PGVECTOR_EMBEDDING_COL", "embedding"), + EmbeddingModel: getEnv("EMBEDDING_OPENAI_MODEL", "text-embedding-3-small"), + EmbeddingBaseURL: strings.TrimSpace(getEnv("EMBEDDING_OPENAI_BASEURL", "")), + EmbeddingAPIKey: strings.TrimSpace(getEnv("EMBEDDING_OPENAI_APIKEY", "")), + } + if s.EmbeddingBaseURL == "" { + s.EmbeddingBaseURL = strings.TrimSpace(getEnv("LLM_BASEURL", "https://api.openai.com/v1")) + } + if s.RecallLimit <= 0 { + return nil, fmt.Errorf("MEMORY_RECALL_LIMIT must be positive, got %d", s.RecallLimit) + } + return s, nil +} diff --git a/examples/agent_with_memory/common/config_test.go b/examples/agent_with_memory/common/config_test.go new file mode 100644 index 0000000..1ce0f66 --- /dev/null +++ b/examples/agent_with_memory/common/config_test.go @@ -0,0 +1,35 @@ +package common + +import ( + "testing" + + "github.com/agenticenv/agent-sdk-go/pkg/memory" +) + +func TestParseStoreMode(t *testing.T) { + t.Parallel() + + tests := []struct { + raw string + want memory.StoreMode + }{ + {"", memory.StoreModeOnDemand}, + {"ondemand", memory.StoreModeOnDemand}, + {"on-demand", memory.StoreModeOnDemand}, + {"always", memory.StoreModeAlways}, + } + + for _, tt := range tests { + got, err := ParseStoreMode(tt.raw) + if err != nil { + t.Fatalf("ParseStoreMode(%q): %v", tt.raw, err) + } + if got != tt.want { + t.Fatalf("ParseStoreMode(%q) = %q, want %q", tt.raw, got, tt.want) + } + } + + if _, err := ParseStoreMode("invalid"); err == nil { + t.Fatal("expected error for invalid mode") + } +} diff --git a/examples/agent_with_memory/common/embed_openai.go b/examples/agent_with_memory/common/embed_openai.go new file mode 100644 index 0000000..ce7010f --- /dev/null +++ b/examples/agent_with_memory/common/embed_openai.go @@ -0,0 +1,77 @@ +package common + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + pgmem "github.com/agenticenv/agent-sdk-go/pkg/memory/pgvector" +) + +// OpenAIEmbedFunc returns a [pgmem.EmbedFunc] that calls an OpenAI-compatible embeddings API. +func OpenAIEmbedFunc(settings *Settings) (pgmem.EmbedFunc, error) { + if settings == nil { + return nil, fmt.Errorf("embed: settings is nil") + } + if settings.EmbeddingAPIKey == "" { + return nil, fmt.Errorf("embed: EMBEDDING_OPENAI_APIKEY is required for pgvector") + } + model := strings.TrimSpace(settings.EmbeddingModel) + if model == "" { + return nil, fmt.Errorf("embed: EMBEDDING_OPENAI_MODEL is required") + } + base := strings.TrimRight(strings.TrimSpace(settings.EmbeddingBaseURL), "/") + client := &http.Client{Timeout: 60 * time.Second} + + return func(ctx context.Context, text string) ([]float32, error) { + body, err := json.Marshal(map[string]any{ + "input": text, + "model": model, + }) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, base+"/embeddings", bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+settings.EmbeddingAPIKey) + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("embeddings API %s: %s", resp.Status, strings.TrimSpace(string(raw))) + } + + var parsed struct { + Data []struct { + Embedding []float64 `json:"embedding"` + } `json:"data"` + } + if err := json.Unmarshal(raw, &parsed); err != nil { + return nil, err + } + if len(parsed.Data) == 0 || len(parsed.Data[0].Embedding) == 0 { + return nil, fmt.Errorf("embeddings API returned no vectors") + } + out := make([]float32, len(parsed.Data[0].Embedding)) + for i, v := range parsed.Data[0].Embedding { + out[i] = float32(v) + } + return out, nil + }, nil +} diff --git a/examples/agent_with_memory/common/embedding.go b/examples/agent_with_memory/common/embedding.go new file mode 100644 index 0000000..487b9e3 --- /dev/null +++ b/examples/agent_with_memory/common/embedding.go @@ -0,0 +1,14 @@ +package common + +import "fmt" + +// ValidateEmbeddingConfig ensures pgvector can call an OpenAI-compatible embeddings API. +func ValidateEmbeddingConfig(settings *Settings) error { + if settings == nil { + return fmt.Errorf("settings is nil") + } + if settings.EmbeddingAPIKey == "" { + return fmt.Errorf("EMBEDDING_OPENAI_APIKEY is required for pgvector memory embeddings (OpenAI-compatible; separate from LLM_APIKEY)") + } + return nil +} diff --git a/examples/agent_with_memory/common/opts.go b/examples/agent_with_memory/common/opts.go new file mode 100644 index 0000000..bfe916e --- /dev/null +++ b/examples/agent_with_memory/common/opts.go @@ -0,0 +1,57 @@ +package common + +import ( + "fmt" + + excfg "github.com/agenticenv/agent-sdk-go/examples" + "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/memory" +) + +// MemoryConfig builds [memory.Config] from settings, backend, and store mode. +func MemoryConfig(store interfaces.Memory, settings *Settings, mode memory.StoreMode) memory.Config { + cfg := memory.DefaultConfig(store) + cfg.Store.Mode = mode + cfg.Recall = memory.RecallConfig{ + Enabled: settings.RecallEnabled, + Limit: settings.RecallLimit, + MinScore: settings.RecallMinScore, + } + return cfg +} + +// AgentOptions builds shared agent options: runtime, LLM, memory, and system prompt. +func AgentOptions( + cfg *excfg.Config, + llmClient interfaces.LLMClient, + log logger.Logger, + settings *Settings, + memCfg memory.Config, + backendLabel string, +) []agent.Option { + recallNote := "recall enabled before each run" + if !settings.RecallEnabled { + recallNote = "store-only (recall disabled)" + } + prompt := fmt.Sprintf( + "You are a helpful assistant with long-term memory backed by %s (%s). "+ + "When the system prompt includes a Relevant Memories section, treat those as facts from prior runs and answer from them.", + backendLabel, + recallNote, + ) + if memCfg.Store.Mode == memory.StoreModeOnDemand { + prompt += " When the user asks you to remember something for future runs, persist it with your tools before acknowledging." + } + opts := []agent.Option{ + agent.WithName(fmt.Sprintf("agent-with-memory-%s", backendLabel)), + agent.WithDescription(fmt.Sprintf("Agent with %s long-term memory", backendLabel)), + agent.WithSystemPrompt(prompt), + agent.WithLLMClient(llmClient), + agent.WithLogger(log), + agent.WithMemory(memCfg), + agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), + } + return append(opts, excfg.RuntimeOption(cfg)...) +} diff --git a/examples/agent_with_memory/common/run.go b/examples/agent_with_memory/common/run.go new file mode 100644 index 0000000..1a27a5d --- /dev/null +++ b/examples/agent_with_memory/common/run.go @@ -0,0 +1,76 @@ +package common + +import ( + "context" + "fmt" + "log" + "os" + "strings" + + "github.com/agenticenv/agent-sdk-go/examples/shared" + "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/memory" +) + +const ( + // DefaultStorePrompt is run 1 in the two-turn demo. + DefaultStorePrompt = "Remember for all future runs: I prefer concise answers. Persist this preference to long-term memory before you reply." + // defaultStoreRetryPrompt is used when the first on-demand store run did not persist anything. + defaultStoreRetryPrompt = "Persist this to long-term memory before replying: I prefer concise answers in all future conversations." + // DefaultRecallPrompt is run 2 in the two-turn demo. + DefaultRecallPrompt = "What answer style do I prefer?" +) + +// ScopedContext attaches the demo user id for memory scope resolution. +func ScopedContext(ctx context.Context, userID string) context.Context { + return memory.WithContextUserID(ctx, userID) +} + +// RunAgent executes one prompt and prints the assistant reply plus optional footers. +func RunAgent(ctx context.Context, a *agent.Agent, userID, label, prompt string) *agent.AgentRunResult { + fmt.Printf("\n--- %s ---\n", label) + fmt.Println("user:", prompt) + result, err := a.Run(ScopedContext(ctx, userID), prompt, nil) + if err != nil { + log.Printf("%s failed: %v", label, err) + return nil + } + fmt.Println("assistant:", result.Content) + shared.PrintRunFooters(result) + return result +} + +func memoryStores(result *agent.AgentRunResult) int64 { + if result == nil || result.Telemetry == nil { + return 0 + } + return result.Telemetry.Storage.TotalMemoryStores +} + +func runOnDemandStoreDemo(ctx context.Context, a *agent.Agent, userID string) { + result := RunAgent(ctx, a, userID, "run 1 (save_memory)", DefaultStorePrompt) + if memoryStores(result) > 0 { + return + } + fmt.Println("warning: no memory was stored on run 1 (the model may have skipped the memory tool); retrying") + RunAgent(ctx, a, userID, "run 1 retry (save_memory)", defaultStoreRetryPrompt) +} + +// RunFromArgs runs the two-turn store/recall demo when no CLI args are given; otherwise a single custom prompt. +func RunFromArgs(ctx context.Context, a *agent.Agent, userID string, storeMode memory.StoreMode) { + args := os.Args[1:] + if len(args) == 0 { + run1Label := "run 1 (save_memory)" + if storeMode == memory.StoreModeAlways { + run1Label = "run 1 (run-end store)" + } + if storeMode == memory.StoreModeOnDemand { + runOnDemandStoreDemo(ctx, a, userID) + } else { + RunAgent(ctx, a, userID, run1Label, DefaultStorePrompt) + } + RunAgent(ctx, a, userID, "run 2 (recall)", DefaultRecallPrompt) + return + } + RunAgent(ctx, a, userID, "run", strings.Join(args, " ")) +} diff --git a/examples/agent_with_memory/pgvector/main.go b/examples/agent_with_memory/pgvector/main.go new file mode 100644 index 0000000..877d156 --- /dev/null +++ b/examples/agent_with_memory/pgvector/main.go @@ -0,0 +1,70 @@ +// Example agent using PostgreSQL pgvector for long-term memory. +// +// Run from examples/ (no args = two-turn store then recall demo): +// +// go run ./agent_with_memory/pgvector +// MEMORY_STORE_MODE=always go run ./agent_with_memory/pgvector +package main + +import ( + "context" + "fmt" + "log" + + examplecfg "github.com/agenticenv/agent-sdk-go/examples" + "github.com/agenticenv/agent-sdk-go/examples/agent_with_memory/common" + "github.com/agenticenv/agent-sdk-go/pkg/agent" + pgmem "github.com/agenticenv/agent-sdk-go/pkg/memory/pgvector" +) + +func main() { + cfg := examplecfg.LoadFromEnv() + memCfg, err := common.LoadSettings() + if err != nil { + log.Fatalf("memory config: %v", err) + } + if memCfg.PGDSN == "" { + log.Fatal("PGVECTOR_DSN is required for the pgvector memory example; see ../README.md") + } + if err := common.ValidateEmbeddingConfig(memCfg); err != nil { + log.Fatalf("embedding config: %v", err) + } + + llmClient, err := examplecfg.NewLLMClientFromConfig(cfg) + if err != nil { + log.Fatalf("failed to create LLM client: %v", err) + } + logr := examplecfg.NewLoggerFromLogConfig(cfg) + + embed, err := common.OpenAIEmbedFunc(memCfg) + if err != nil { + log.Fatalf("embed func: %v", err) + } + + store, err := pgmem.NewMemory(embed, + pgmem.WithDSN(memCfg.PGDSN), + pgmem.WithTable(memCfg.PGMemoryTable), + pgmem.WithEmbeddingCol(memCfg.PGEmbeddingCol), + pgmem.WithDefaultLimit(memCfg.RecallLimit), + pgmem.WithDefaultMinScore(memCfg.RecallMinScore), + pgmem.WithLogger(logr), + ) + if err != nil { + log.Fatalf("pgvector memory: %v", err) + } + + memoryConfig := common.MemoryConfig(store, memCfg, memCfg.StoreMode) + opts := common.AgentOptions(cfg, llmClient, logr, memCfg, memoryConfig, "pgvector") + + a, err := agent.NewAgent(opts...) + if err != nil { + log.Fatal(examplecfg.FormatNewAgentError("failed to create agent", err)) + } + defer a.Close() + + fmt.Printf("backend: pgvector table: %s user: %s store: %s recall: %v limit: %d\n", + memCfg.PGMemoryTable, memCfg.UserID, memCfg.StoreMode, memCfg.RecallEnabled, memCfg.RecallLimit) + fmt.Println("hint:", common.StoreModeHint(memCfg.StoreMode)) + + common.RunFromArgs(context.Background(), a, memCfg.UserID, memCfg.StoreMode) +} diff --git a/examples/agent_with_memory/weaviate/main.go b/examples/agent_with_memory/weaviate/main.go new file mode 100644 index 0000000..12ad845 --- /dev/null +++ b/examples/agent_with_memory/weaviate/main.go @@ -0,0 +1,61 @@ +// Example agent using Weaviate for long-term memory. +// +// Run from examples/ (no args = two-turn store then recall demo): +// +// go run ./agent_with_memory/weaviate +// MEMORY_STORE_MODE=always go run ./agent_with_memory/weaviate +package main + +import ( + "context" + "fmt" + "log" + + examplecfg "github.com/agenticenv/agent-sdk-go/examples" + "github.com/agenticenv/agent-sdk-go/examples/agent_with_memory/common" + "github.com/agenticenv/agent-sdk-go/pkg/agent" + wmem "github.com/agenticenv/agent-sdk-go/pkg/memory/weaviate" +) + +func main() { + cfg := examplecfg.LoadFromEnv() + memCfg, err := common.LoadSettings() + if err != nil { + log.Fatalf("memory config: %v", err) + } + + llmClient, err := examplecfg.NewLLMClientFromConfig(cfg) + if err != nil { + log.Fatalf("failed to create LLM client: %v", err) + } + logr := examplecfg.NewLoggerFromLogConfig(cfg) + + wOpts := []wmem.Option{ + wmem.WithHost(memCfg.WeaviateHost), + wmem.WithScheme(memCfg.WeaviateScheme), + wmem.WithClassName(memCfg.WeaviateMemoryClass), + wmem.WithDefaultLimit(memCfg.RecallLimit), + wmem.WithDefaultMinScore(memCfg.RecallMinScore), + wmem.WithLogger(logr), + } + + store, err := wmem.NewMemory(wOpts...) + if err != nil { + log.Fatalf("weaviate memory: %v", err) + } + + memoryConfig := common.MemoryConfig(store, memCfg, memCfg.StoreMode) + opts := common.AgentOptions(cfg, llmClient, logr, memCfg, memoryConfig, "weaviate") + + a, err := agent.NewAgent(opts...) + if err != nil { + log.Fatal(examplecfg.FormatNewAgentError("failed to create agent", err)) + } + defer a.Close() + + fmt.Printf("backend: weaviate class: %s user: %s store: %s recall: %v limit: %d\n", + memCfg.WeaviateMemoryClass, memCfg.UserID, memCfg.StoreMode, memCfg.RecallEnabled, memCfg.RecallLimit) + fmt.Println("hint:", common.StoreModeHint(memCfg.StoreMode)) + + common.RunFromArgs(context.Background(), a, memCfg.UserID, memCfg.StoreMode) +} diff --git a/examples/agent_with_stream_conversation/main.go b/examples/agent_with_stream_conversation/main.go index 4a60fe9..d122eb0 100644 --- a/examples/agent_with_stream_conversation/main.go +++ b/examples/agent_with_stream_conversation/main.go @@ -11,6 +11,7 @@ import ( config "github.com/agenticenv/agent-sdk-go/examples" "github.com/agenticenv/agent-sdk-go/examples/shared" "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/conversation" "github.com/agenticenv/agent-sdk-go/pkg/conversation/inmem" "github.com/agenticenv/agent-sdk-go/pkg/tools/calculator" "github.com/agenticenv/agent-sdk-go/pkg/tools/echo" @@ -43,8 +44,7 @@ func main() { agent.WithStream(true), agent.WithToolRegistry(reg), agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), - agent.WithConversation(conv), - agent.WithConversationSize(20), + agent.WithConversation(conversation.DefaultConfig(conv)), agent.WithLogger(config.NewLoggerFromLogConfig(cfg)), } opts = append(opts, config.RuntimeOption(cfg)...) diff --git a/examples/docker/pgvector/setup.sql b/examples/docker/pgvector/setup.sql index 4dac62e..45a21df 100644 --- a/examples/docker/pgvector/setup.sql +++ b/examples/docker/pgvector/setup.sql @@ -12,3 +12,22 @@ CREATE TABLE IF NOT EXISTS documents ( CREATE INDEX IF NOT EXISTS documents_embedding_idx ON documents USING hnsw (embedding vector_cosine_ops); + +-- Long-term memory table for agent_with_memory/pgvector (no seed rows — filled by agent runs). +CREATE TABLE IF NOT EXISTS agent_memories ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + text TEXT NOT NULL, + kind TEXT NOT NULL DEFAULT '', + user_id TEXT, + tenant_id TEXT, + agent_id TEXT, + scope_tags TEXT[] NOT NULL DEFAULT '{}', + metadata JSONB NOT NULL DEFAULT '{}', + expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + embedding vector(1536) +); + +CREATE INDEX IF NOT EXISTS agent_memories_embedding_idx + ON agent_memories USING hnsw (embedding vector_cosine_ops); diff --git a/examples/docker/weaviate/seed.sh b/examples/docker/weaviate/seed.sh index c6db35b..58664b1 100755 --- a/examples/docker/weaviate/seed.sh +++ b/examples/docker/weaviate/seed.sh @@ -92,3 +92,24 @@ while IFS= read -r row; do done < <(jq -c '.[]' "$DOCS_FILE") echo "Inserted ${count} documents from sample-documents.json" + +MEMORY_CLASS="${WEAVIATE_MEMORY_CLASS:-AgentMemory}" +echo "Creating class ${MEMORY_CLASS} (long-term memory)..." +curl -sf -X POST "${WEAVIATE_URL}/v1/schema" \ + -H 'Content-Type: application/json' \ + -d "{ + \"class\": \"${MEMORY_CLASS}\", + \"vectorizer\": \"text2vec-openai\", + \"properties\": [ + {\"name\": \"text\", \"dataType\": [\"text\"]}, + {\"name\": \"kind\", \"dataType\": [\"text\"]}, + {\"name\": \"user_id\", \"dataType\": [\"text\"]}, + {\"name\": \"tenant_id\", \"dataType\": [\"text\"]}, + {\"name\": \"agent_id\", \"dataType\": [\"text\"]}, + {\"name\": \"scope_tags\", \"dataType\": [\"text[]\"]}, + {\"name\": \"metadata\", \"dataType\": [\"text\"]}, + {\"name\": \"expires_at\", \"dataType\": [\"date\"]}, + {\"name\": \"created_at\", \"dataType\": [\"date\"]}, + {\"name\": \"updated_at\", \"dataType\": [\"date\"]} + ] + }" >/dev/null || true diff --git a/examples/shared/utils.go b/examples/shared/utils.go index 0085db2..15a76f1 100644 --- a/examples/shared/utils.go +++ b/examples/shared/utils.go @@ -146,10 +146,14 @@ func TelemetryFooter(telemetry *agent.AgentTelemetry) string { } lines = append(lines, "[TELEMETRY STORAGE]", - fmt.Sprintf(" total_retriever_searches: %d", telemetry.Storage.TotalRetrieverSearches), + fmt.Sprintf(" total_retriever_searches: %d", telemetry.Storage.TotalRetrieverSearches), fmt.Sprintf(" failed_retriever_searches: %d", telemetry.Storage.FailedRetrieverSearches), - fmt.Sprintf(" prefetch_searches: %d", telemetry.Storage.PrefetchSearches), - fmt.Sprintf(" agentic_searches: %d", telemetry.Storage.AgenticSearches), + fmt.Sprintf(" prefetch_searches: %d", telemetry.Storage.PrefetchSearches), + fmt.Sprintf(" agentic_searches: %d", telemetry.Storage.AgenticSearches), + fmt.Sprintf(" total_memory_recalls: %d", telemetry.Storage.TotalMemoryRecalls), + fmt.Sprintf(" failed_memory_recalls: %d", telemetry.Storage.FailedMemoryRecalls), + fmt.Sprintf(" total_memory_stores: %d", telemetry.Storage.TotalMemoryStores), + fmt.Sprintf(" failed_memory_stores: %d", telemetry.Storage.FailedMemoryStores), ) return strings.Join(lines, "\n") } diff --git a/internal/runtime/base/memory.go b/internal/runtime/base/memory.go new file mode 100644 index 0000000..8c8547a --- /dev/null +++ b/internal/runtime/base/memory.go @@ -0,0 +1,354 @@ +package base + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/memory" +) + +const memoryExtractSystemPrompt = "Extract durable long-term memories from the conversation. " + + "Return only facts, preferences, decisions, or instructions worth recalling in future runs. " + + "Skip greetings, transient context, and tool noise. Return an empty memories array when nothing should persist." + +var errMemoryExtractUnavailable = errors.New("memory extract unavailable: StoreMode always requires custom Extract or LLM client") + +// StoreMemoryRecords persists records through kind policy, dedup, TTL, and the memory backend. +func (rt *Runtime) StoreMemoryRecords(ctx context.Context, log logger.Logger, scope interfaces.MemoryScope, records []interfaces.MemoryRecord) error { + if !rt.MemoryConfigured() { + return nil + } + + ctx, batchSp := rt.Tracer.StartSpan(ctx, "memory.store.batch", + interfaces.Attribute{Key: "record.count", Value: len(records)}, + ) + defer batchSp.End() + + for _, rec := range records { + if err := rt.storeRecord(ctx, log, scope, rec); err != nil { + batchSp.RecordError(err) + return err + } + } + return nil +} + +func (rt *Runtime) storeRecord(ctx context.Context, log logger.Logger, scope interfaces.MemoryScope, rec interfaces.MemoryRecord) error { + cfg := rt.AgentConfig.Memory.Config + text := strings.TrimSpace(rec.Text) + if text == "" { + return nil + } + + ctx, sp := rt.Tracer.StartSpan(ctx, "memory.store") + defer sp.End() + + kind, err := cfg.Store.ResolveKind(rec.Kind) + if err != nil { + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryStoreFailed) + log.Error(ctx, "runtime: memory store kind rejected", slog.String("scope", "runtime"), slog.Any("error", err)) + return fmt.Errorf("memory store: %w", err) + } + + kindAttr := interfaces.Attribute{Key: types.MetricAttrMemoryKind, Value: string(kind)} + sp.SetAttribute(string(types.MetricAttrMemoryKind), string(kind)) + log.Debug(ctx, "runtime: memory store started", slog.String("scope", "runtime"), slog.String("kind", string(kind))) + + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryStoreStarted, kindAttr) + start := time.Now() + + storeOpts, dedupAction, dedupErr := rt.dedupStoreOptions(ctx, scope, text) + if dedupErr != nil { + latency := float64(time.Since(start).Milliseconds()) + sp.RecordError(dedupErr) + sp.SetAttribute("latency_ms", latency) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryStoreFailed, kindAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryStoreLatencyMs, latency, kindAttr) + log.Error(ctx, "runtime: memory dedup lookup failed", slog.String("scope", "runtime"), slog.Any("error", dedupErr)) + return fmt.Errorf("memory store: dedup: %w", dedupErr) + } + + dedupAttr := interfaces.Attribute{Key: types.MetricAttrMemoryDedup, Value: dedupAction} + sp.SetAttribute(string(types.MetricAttrMemoryDedup), dedupAction) + + now := time.Now().UTC() + record := interfaces.MemoryRecord{ + Text: text, + Kind: kind, + Metadata: rec.Metadata, + ExpiresAt: cfg.ExpiresAtForKind(kind, now), + } + + if _, err := cfg.Memory.Store(ctx, scope, record, storeOpts...); err != nil { + latency := float64(time.Since(start).Milliseconds()) + sp.RecordError(err) + sp.SetAttribute("latency_ms", latency) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryStoreFailed, kindAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryStoreLatencyMs, latency, kindAttr) + log.Error(ctx, "runtime: memory store failed", slog.String("scope", "runtime"), slog.Any("error", err)) + return fmt.Errorf("memory store: %w", err) + } + + latency := float64(time.Since(start).Milliseconds()) + sp.SetAttribute("latency_ms", latency) + sp.SetAttribute("dedup.upsert", dedupAction == "upsert") + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryStoreCompleted, kindAttr, dedupAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryStoreLatencyMs, latency, kindAttr) + log.Debug(ctx, "runtime: memory store completed", + slog.String("scope", "runtime"), + slog.String("dedup", dedupAction)) + return nil +} + +func (rt *Runtime) dedupStoreOptions(ctx context.Context, scope interfaces.MemoryScope, text string) ([]interfaces.StoreMemoryOption, string, error) { + cfg := rt.AgentConfig.Memory.Config + minScore := cfg.Store.DedupMinScore + if minScore <= 0 { + return nil, "append", nil + } + + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryDedupStarted) + dedupStart := time.Now() + + ctx, sp := rt.Tracer.StartSpan(ctx, "memory.dedup", + interfaces.Attribute{Key: "min_score", Value: minScore}, + ) + defer sp.End() + + matches, err := cfg.Memory.Load(ctx, scope, text, + interfaces.WithLoadLimit(1), + interfaces.WithMinScore(minScore), + ) + latency := float64(time.Since(dedupStart).Milliseconds()) + sp.SetAttribute("latency_ms", latency) + + if err != nil { + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryDedupFailed) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryDedupLatencyMs, latency) + return nil, "", err + } + + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryDedupCompleted) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryDedupLatencyMs, latency) + + if len(matches) == 0 { + sp.SetAttribute("matched", false) + return nil, "append", nil + } + + sp.SetAttribute("matched", true) + sp.SetAttribute("match.id", matches[0].ID) + return []interfaces.StoreMemoryOption{interfaces.WithMemoryID(matches[0].ID)}, "upsert", nil +} + +func parseSaveMemoryToolArgs(args map[string]any) (interfaces.MemoryRecord, error) { + rawText, ok := args[types.MemoryToolParamText].(string) + if !ok { + return interfaces.MemoryRecord{}, fmt.Errorf("save_memory: %q parameter required", types.MemoryToolParamText) + } + text := strings.TrimSpace(rawText) + if text == "" { + return interfaces.MemoryRecord{}, fmt.Errorf("save_memory: %q must be non-empty", types.MemoryToolParamText) + } + record := interfaces.MemoryRecord{ + Text: text, + Metadata: map[string]string{ + "source": types.SaveMemoryToolName, + }, + } + if rawKind, ok := args[types.MemoryToolParamKind].(string); ok { + record.Kind = interfaces.MemoryKind(strings.TrimSpace(rawKind)) + } + return record, nil +} + +func (rt *Runtime) extractMemoryRecords( + ctx context.Context, + log logger.Logger, + messages []interfaces.Message, + extract memory.ExtractFunc, +) ([]interfaces.MemoryRecord, error) { + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryExtractStarted) + start := time.Now() + + ctx, sp := rt.Tracer.StartSpan(ctx, "memory.extract", + interfaces.Attribute{Key: "message.count", Value: len(messages)}, + ) + defer sp.End() + + log.Debug(ctx, "runtime: memory extract started", slog.String("scope", "runtime")) + + records, err := extract(ctx, messages) + latency := float64(time.Since(start).Milliseconds()) + sp.SetAttribute("latency_ms", latency) + + if err != nil { + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryExtractFailed) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryExtractLatencyMs, latency) + log.Error(ctx, "runtime: memory extract failed", slog.String("scope", "runtime"), slog.Any("error", err)) + return nil, fmt.Errorf("memory store: extract: %w", err) + } + + sp.SetAttribute("record.count", len(records)) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryExtractCompleted) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryExtractLatencyMs, latency) + log.Debug(ctx, "runtime: memory extract completed", + slog.String("scope", "runtime"), + slog.Int("records", len(records))) + + return records, nil +} + +func (rt *Runtime) recordMemoryExtractUnavailable(ctx context.Context, log logger.Logger) { + ctx, sp := rt.Tracer.StartSpan(ctx, "memory.extract") + defer sp.End() + sp.RecordError(errMemoryExtractUnavailable) + sp.SetAttribute("reason", "no_extractor") + + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryExtractFailed) + log.Warn(ctx, "runtime: memory extract unavailable", + slog.String("scope", "runtime"), + slog.Any("error", errMemoryExtractUnavailable)) +} + +// resolveMemoryExtractFunc returns the user Extract hook or the SDK default when Always store is enabled. +func (rt *Runtime) resolveMemoryExtractFunc() memory.ExtractFunc { + if !rt.RunEndMemoryStoreEnabled() { + return nil + } + if extract := rt.AgentConfig.Memory.Config.Store.Extract; extract != nil { + return extract + } + if client := rt.AgentConfig.LLM.Client; client != nil { + return defaultMemoryExtractFunc(client) + } + return nil +} + +func defaultMemoryExtractFunc(client interfaces.LLMClient) memory.ExtractFunc { + return func(ctx context.Context, messages []interfaces.Message) ([]interfaces.MemoryRecord, error) { + return extractMemoriesWithLLM(ctx, client, messages) + } +} + +func extractMemoriesWithLLM(ctx context.Context, client interfaces.LLMClient, messages []interfaces.Message) ([]interfaces.MemoryRecord, error) { + msgs := messagesForMemoryExtraction(messages) + if len(msgs) == 0 { + return nil, nil + } + + resp, err := client.Generate(ctx, &interfaces.LLMRequest{ + SystemMessage: memoryExtractSystemPrompt, + Messages: msgs, + ResponseFormat: memoryExtractResponseFormat(), + }) + if err != nil { + return nil, fmt.Errorf("memory extract: llm: %w", err) + } + + return parseMemoryExtractResponse(resp.Content) +} + +const memoryExtractTurnPrompt = "Extract durable memories from the conversation above." + +func messagesForMemoryExtraction(messages []interfaces.Message) []interfaces.Message { + out := make([]interfaces.Message, 0, len(messages)+1) + for _, m := range messages { + switch m.Role { + case interfaces.MessageRoleUser, interfaces.MessageRoleAssistant: + if strings.TrimSpace(m.Content) != "" { + out = append(out, m) + } + } + } + if len(out) == 0 { + return nil + } + // Structured output providers (e.g. Anthropic) reject assistant as the final message. + if out[len(out)-1].Role == interfaces.MessageRoleAssistant { + out = append(out, interfaces.Message{ + Role: interfaces.MessageRoleUser, + Content: memoryExtractTurnPrompt, + }) + } + return out +} + +func memoryExtractResponseFormat() *interfaces.ResponseFormat { + return &interfaces.ResponseFormat{ + Type: interfaces.ResponseFormatJSON, + Name: "MemoryExtraction", + Schema: interfaces.JSONSchema{ + "type": "object", + "properties": interfaces.JSONSchema{ + "memories": interfaces.JSONSchema{ + "type": "array", + "items": interfaces.JSONSchema{ + "type": "object", + "properties": interfaces.JSONSchema{ + "text": interfaces.JSONSchema{ + "type": "string", + "description": "Distilled memory text", + }, + "kind": interfaces.JSONSchema{ + "type": "string", + "description": "Optional kind: preference, fact, decision, instruction, note", + }, + }, + "required": []any{"text"}, + "additionalProperties": false, + }, + }, + }, + "required": []any{"memories"}, + "additionalProperties": false, + }, + } +} + +func parseMemoryExtractResponse(content string) ([]interfaces.MemoryRecord, error) { + content = strings.TrimSpace(content) + if content == "" { + return nil, nil + } + + var parsed struct { + Memories []struct { + Text string `json:"text"` + Kind string `json:"kind"` + } `json:"memories"` + } + if err := json.Unmarshal([]byte(content), &parsed); err != nil { + return nil, fmt.Errorf("memory extract: parse response: %w", err) + } + + records := make([]interfaces.MemoryRecord, 0, len(parsed.Memories)) + for _, m := range parsed.Memories { + text := strings.TrimSpace(m.Text) + if text == "" { + continue + } + rec := interfaces.MemoryRecord{ + Text: text, + Metadata: map[string]string{ + "source": "extract", + }, + } + if kind := strings.TrimSpace(m.Kind); kind != "" { + rec.Kind = interfaces.MemoryKind(kind) + } + records = append(records, rec) + } + return records, nil +} diff --git a/internal/runtime/base/runtime.go b/internal/runtime/base/runtime.go index 7a9cfa1..838842f 100644 --- a/internal/runtime/base/runtime.go +++ b/internal/runtime/base/runtime.go @@ -16,6 +16,7 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/memory" "github.com/google/uuid" ) @@ -37,18 +38,22 @@ type ExecuteLLMInput struct { MessageID string Messages []interfaces.Message SkipTools bool + MemoryContext string RetrieverContext string Tools []interfaces.Tool Emit func(events.AgentEvent) } // BuildLLMRequest constructs an LLMRequest from the given messages and options. -// When retrieverContext is non-empty it is appended to the system prompt (prefetch/hybrid mode). +// When memoryContext or retrieverContext is non-empty each is appended to the system prompt. // tools is the per-run resolved tool list from [runtime.ExecuteRequest] or activity resolve. -func (rt *Runtime) BuildLLMRequest(messages []interfaces.Message, skipTools bool, retrieverContext string, tools []interfaces.Tool) *interfaces.LLMRequest { +func (rt *Runtime) BuildLLMRequest(messages []interfaces.Message, skipTools bool, memoryContext, retrieverContext string, tools []interfaces.Tool) *interfaces.LLMRequest { systemMessage := rt.AgentSpec.SystemPrompt + if memoryContext != "" { + systemMessage = fmt.Sprintf("%s\n\nRelevant Memories:\n%s", systemMessage, memoryContext) + } if retrieverContext != "" { - systemMessage = fmt.Sprintf("%s\n\nRelevant Context:\n%s", rt.AgentSpec.SystemPrompt, retrieverContext) + systemMessage = fmt.Sprintf("%s\n\nRelevant Context:\n%s", systemMessage, retrieverContext) } req := &interfaces.LLMRequest{ SystemMessage: systemMessage, @@ -146,7 +151,7 @@ func emitEvent(fn func(events.AgentEvent), ev events.AgentEvent) { // TEXT_MESSAGE_START / TEXT_MESSAGE_CONTENT / TEXT_MESSAGE_END events, and returns LLMResult. // messageID and agentName are used only for event construction; emit may be nil. func (rt *Runtime) ExecuteLLM(ctx context.Context, input ExecuteLLMInput) (*LLMResult, error) { - req := rt.BuildLLMRequest(input.Messages, input.SkipTools, input.RetrieverContext, input.Tools) + req := rt.BuildLLMRequest(input.Messages, input.SkipTools, input.MemoryContext, input.RetrieverContext, input.Tools) llmClient := rt.AgentConfig.LLM.Client model := llmClient.GetModel() @@ -201,7 +206,7 @@ func (rt *Runtime) ExecuteLLM(ctx context.Context, input ExecuteLLMInput) (*LLMR // emit as chunks arrive; a final TEXT_MESSAGE_START/CONTENT/END triple is emitted for non-streaming // fallback. emit may be nil. func (rt *Runtime) ExecuteLLMStream(ctx context.Context, input ExecuteLLMInput) (*LLMResult, error) { - req := rt.BuildLLMRequest(input.Messages, input.SkipTools, input.RetrieverContext, input.Tools) + req := rt.BuildLLMRequest(input.Messages, input.SkipTools, input.MemoryContext, input.RetrieverContext, input.Tools) llmClient := rt.AgentConfig.LLM.Client model := llmClient.GetModel() @@ -401,6 +406,48 @@ func (rt *Runtime) ExecuteTool(ctx context.Context, log logger.Logger, tools []i return fmt.Sprintf("%v", result), nil } +// ExecuteToolWithMemoryScope runs a tool; save_memory on on-demand store routes to [StoreMemoryRecords]. +func (rt *Runtime) ExecuteToolWithMemoryScope(ctx context.Context, log logger.Logger, tools []interfaces.Tool, toolName string, args map[string]any, memScope interfaces.MemoryScope) (string, error) { + if toolName == types.SaveMemoryToolName && rt.MemoryStoreOnDemand() { + return rt.executeSaveMemoryTool(ctx, log, memScope, args) + } + return rt.ExecuteTool(ctx, log, tools, toolName, args) +} + +func (rt *Runtime) executeSaveMemoryTool(ctx context.Context, log logger.Logger, scope interfaces.MemoryScope, args map[string]any) (string, error) { + toolAttr := interfaces.Attribute{Key: types.MetricAttrTool, Value: types.SaveMemoryToolName} + rt.Metrics.IncrementCounter(ctx, types.MetricToolCallStarted, toolAttr) + toolStart := time.Now() + + ctx, sp := rt.Tracer.StartSpan(ctx, "tool.execute", + interfaces.Attribute{Key: "tool.name", Value: types.SaveMemoryToolName}, + interfaces.Attribute{Key: "arg.count", Value: len(args)}, + ) + defer sp.End() + + record, err := parseSaveMemoryToolArgs(args) + if err != nil { + toolLatency := float64(time.Since(toolStart).Milliseconds()) + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricToolCallFailed, toolAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricToolLatencyMs, toolLatency, toolAttr) + return "", err + } + + if err := rt.StoreMemoryRecords(ctx, log, scope, []interfaces.MemoryRecord{record}); err != nil { + toolLatency := float64(time.Since(toolStart).Milliseconds()) + sp.RecordError(err) + rt.Metrics.IncrementCounter(ctx, types.MetricToolCallFailed, toolAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricToolLatencyMs, toolLatency, toolAttr) + return "", err + } + + toolLatency := float64(time.Since(toolStart).Milliseconds()) + rt.Metrics.RecordHistogram(ctx, types.MetricToolLatencyMs, toolLatency, toolAttr) + rt.Metrics.IncrementCounter(ctx, types.MetricToolCallCompleted, toolAttr) + return "memory saved", nil +} + // AuthorizeTool checks programmatic authorization for a tool before approval/execution. // Tools that do not implement interfaces.ToolAuthorizer are allowed by default. func (rt *Runtime) AuthorizeTool(ctx context.Context, log logger.Logger, tools []interfaces.Tool, toolName string, args map[string]any) (AuthorizeResult, error) { @@ -523,3 +570,140 @@ func (rt *Runtime) ExecuteRetrievers(ctx context.Context, log logger.Logger, que FailedSearches: int64(failedCount), }, nil } + +// MemoryConfigured reports whether long-term memory is wired on the runtime. +func (rt *Runtime) MemoryConfigured() bool { + return rt.AgentConfig.Memory.Config != nil && rt.AgentConfig.Memory.Config.Memory != nil +} + +// RecallEnabled reports whether the SDK should load memories before each run. +func (rt *Runtime) RecallEnabled() bool { + if !rt.MemoryConfigured() { + return false + } + return rt.AgentConfig.Memory.Config.Recall.Enabled +} + +// RunEndMemoryStoreEnabled reports whether run-end memory store runs ([memory.StoreModeAlways]). +func (rt *Runtime) RunEndMemoryStoreEnabled() bool { + if !rt.MemoryConfigured() { + return false + } + return rt.AgentConfig.Memory.Config.Store.Mode == memory.StoreModeAlways +} + +// MemoryStoreOnDemand reports whether save_memory tool store is active. +func (rt *Runtime) MemoryStoreOnDemand() bool { + if !rt.MemoryConfigured() { + return false + } + return rt.AgentConfig.Memory.Config.Store.Mode == memory.StoreModeOnDemand +} + +// ResolveMemoryScope builds scope from the request context using configured resolvers. +func (rt *Runtime) ResolveMemoryScope(ctx context.Context) (interfaces.MemoryScope, error) { + if !rt.MemoryConfigured() { + return interfaces.MemoryScope{}, nil + } + return rt.AgentConfig.Memory.Config.ScopeConfig.Resolve(ctx) +} + +// FormatMemoryEntries formats memories for injection into the LLM system prompt. +func FormatMemoryEntries(entries []interfaces.MemoryEntry) string { + if len(entries) == 0 { + return "" + } + var sb strings.Builder + for i, entry := range entries { + fmt.Fprintf(&sb, types.MemoryEntryFormat, i+1, entry.Text, entry.Kind, entry.Score) + } + return sb.String() +} + +// ExecuteMemoryRecall loads scoped memories for query and returns formatted prompt context. +func (rt *Runtime) ExecuteMemoryRecall(ctx context.Context, log logger.Logger, scope interfaces.MemoryScope, query string) (*MemoryResult, error) { + cfg := rt.AgentConfig.Memory.Config + if cfg == nil || cfg.Memory == nil { + return &MemoryResult{}, nil + } + + log.Debug(ctx, "runtime: memory recall started", + slog.String("scope", "runtime"), + slog.String("query", query)) + + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryRecallStarted) + start := time.Now() + + ctx, sp := rt.Tracer.StartSpan(ctx, "memory.recall", + interfaces.Attribute{Key: "query", Value: query}, + ) + defer sp.End() + + entries, err := cfg.Memory.Load(ctx, scope, query, cfg.Recall.LoadOptions()...) + if err != nil { + latency := float64(time.Since(start).Milliseconds()) + sp.RecordError(err) + sp.SetAttribute("latency_ms", latency) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryRecallFailed) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryRecallLatencyMs, latency) + log.Error(ctx, "runtime: memory recall failed", slog.String("scope", "runtime"), slog.Any("error", err)) + return nil, fmt.Errorf("memory recall: %w", err) + } + + // Semantic recall often misses distilled memories; fall back to scoped recency list. + if len(entries) == 0 && strings.TrimSpace(query) != "" { + log.Debug(ctx, "runtime: memory recall semantic empty, trying recency fallback", + slog.String("scope", "runtime")) + fallback, fbErr := cfg.Memory.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) + if fbErr != nil { + latency := float64(time.Since(start).Milliseconds()) + sp.RecordError(fbErr) + sp.SetAttribute("latency_ms", latency) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryRecallFailed) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryRecallLatencyMs, latency) + log.Error(ctx, "runtime: memory recall fallback failed", slog.String("scope", "runtime"), slog.Any("error", fbErr)) + return nil, fmt.Errorf("memory recall: %w", fbErr) + } + entries = fallback + } + + latency := float64(time.Since(start).Milliseconds()) + memoryContext := strings.TrimSpace(FormatMemoryEntries(entries)) + sp.SetAttribute("entry.count", len(entries)) + sp.SetAttribute("has_context", memoryContext != "") + sp.SetAttribute("latency_ms", latency) + rt.Metrics.IncrementCounter(ctx, types.MetricMemoryRecallCompleted) + rt.Metrics.RecordHistogram(ctx, types.MetricMemoryRecallLatencyMs, latency) + log.Debug(ctx, "runtime: memory recall completed", + slog.String("scope", "runtime"), + slog.Int("entries", len(entries)), + slog.Bool("hasContext", memoryContext != "")) + + return &MemoryResult{ + Context: memoryContext, + TotalRecalls: 1, + FailedRecalls: 0, + }, nil +} + +// ExecuteMemoryStore extracts long-term memories from the run and persists them in scope. +func (rt *Runtime) ExecuteMemoryStore(ctx context.Context, log logger.Logger, scope interfaces.MemoryScope, messages []interfaces.Message) error { + if !rt.RunEndMemoryStoreEnabled() { + return nil + } + + extract := rt.resolveMemoryExtractFunc() + if extract == nil { + rt.recordMemoryExtractUnavailable(ctx, log) + return nil + } + + records, err := rt.extractMemoryRecords(ctx, log, messages, extract) + if err != nil { + return err + } + if len(records) == 0 { + return nil + } + return rt.StoreMemoryRecords(ctx, log, scope, records) +} diff --git a/internal/runtime/base/runtime_test.go b/internal/runtime/base/runtime_test.go index d01ed4b..3f58991 100644 --- a/internal/runtime/base/runtime_test.go +++ b/internal/runtime/base/runtime_test.go @@ -3,14 +3,18 @@ package base import ( "context" "errors" + "strings" "testing" + "time" "github.com/agenticenv/agent-sdk-go/internal/events" sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" + testutil "github.com/agenticenv/agent-sdk-go/internal/testing" "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" ifmocks "github.com/agenticenv/agent-sdk-go/pkg/interfaces/mocks" "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/memory" "github.com/agenticenv/agent-sdk-go/pkg/observability" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" @@ -54,7 +58,7 @@ func TestBuildLLMRequest_Basic(t *testing.T) { LLM: sdkruntime.AgentLLM{Client: stubLLMClient{}}, }) msgs := []interfaces.Message{{Role: interfaces.MessageRoleUser, Content: "hello"}} - req := rt.BuildLLMRequest(msgs, false, "", nil) + req := rt.BuildLLMRequest(msgs, false, "", "", nil) require.Equal(t, "you are helpful", req.SystemMessage) require.Equal(t, msgs, req.Messages) @@ -65,11 +69,22 @@ func TestBuildLLMRequest_WithRetrieverContext(t *testing.T) { rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{}}, }) - req := rt.BuildLLMRequest(nil, false, "extra context", nil) + req := rt.BuildLLMRequest(nil, false, "", "extra context", nil) require.Contains(t, req.SystemMessage, "you are helpful") require.Contains(t, req.SystemMessage, "extra context") } +func TestBuildLLMRequest_WithMemoryContext(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: stubLLMClient{}}, + }) + req := rt.BuildLLMRequest(nil, false, "memory fact", "retriever doc", nil) + require.Contains(t, req.SystemMessage, "Relevant Memories") + require.Contains(t, req.SystemMessage, "memory fact") + require.Contains(t, req.SystemMessage, "Relevant Context") + require.Contains(t, req.SystemMessage, "retriever doc") +} + func TestBuildLLMRequest_SkipTools(t *testing.T) { ctrl := gomock.NewController(t) tool := ifmocks.NewMockTool(ctrl) @@ -80,7 +95,7 @@ func TestBuildLLMRequest_SkipTools(t *testing.T) { rt := newTestRuntime(sdkruntime.AgentConfig{ LLM: sdkruntime.AgentLLM{Client: stubLLMClient{}}, }) - req := rt.BuildLLMRequest(nil, true, "", []interfaces.Tool{tool}) + req := rt.BuildLLMRequest(nil, true, "", "", []interfaces.Tool{tool}) require.Empty(t, req.Tools) } @@ -533,6 +548,645 @@ func TestExecuteRetrievers_PartialFailure(t *testing.T) { require.Equal(t, int64(1), got.FailedSearches) } +// --- StoreMemoryRecords --- + +func TestStoreMemoryRecords_appliesTTL(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + before := time.Now().UTC() + + require.NoError(t, rt.StoreMemoryRecords(ctx, noopLog(), scope, []interfaces.MemoryRecord{ + {Text: "User prefers concise answers", Kind: memory.KindNote}, + })) + + entries, err := store.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Len(t, entries, 1) + require.False(t, entries[0].ExpiresAt.IsZero()) + want := before.Add(memory.TTLNote) + require.WithinDuration(t, want, entries[0].ExpiresAt, 2*time.Second) +} + +func TestStoreMemoryRecords_allowlistRejectsKind(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + cfg.Store.AllowedKinds = []interfaces.MemoryKind{memory.KindFact} + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + err := rt.StoreMemoryRecords(context.Background(), noopLog(), interfaces.MemoryScope{UserID: "u1"}, + []interfaces.MemoryRecord{{Text: "note text", Kind: memory.KindNote}}) + require.Error(t, err) +} + +func TestStoreMemoryRecords_dedupUpserts(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + text := "favorite color is blue" + + require.NoError(t, rt.StoreMemoryRecords(ctx, noopLog(), scope, []interfaces.MemoryRecord{ + {Text: text, Kind: memory.KindPreference}, + })) + require.NoError(t, rt.StoreMemoryRecords(ctx, noopLog(), scope, []interfaces.MemoryRecord{ + {Text: text, Kind: memory.KindFact}, + })) + + entries, err := store.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Len(t, entries, 1) + require.Equal(t, memory.KindFact, entries[0].Kind) +} + +func TestStoreMemoryRecords_dedupAppendsDistinctText(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + + require.NoError(t, rt.StoreMemoryRecords(ctx, noopLog(), scope, []interfaces.MemoryRecord{ + {Text: "favorite color is blue", Kind: memory.KindPreference}, + {Text: "prefers concise answers", Kind: memory.KindPreference}, + })) + + entries, err := store.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Len(t, entries, 2) +} + +func TestStoreMemoryRecords_appliesDefaultKind(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + + require.NoError(t, rt.StoreMemoryRecords(ctx, noopLog(), scope, []interfaces.MemoryRecord{ + {Text: "remember this"}, + })) + + entries, err := store.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Len(t, entries, 1) + require.Equal(t, memory.KindNote, entries[0].Kind) +} + +func TestStoreMemoryRecords_notConfigured(t *testing.T) { + rt := newTestRuntime(sdkruntime.AgentConfig{}) + require.NoError(t, rt.StoreMemoryRecords(context.Background(), noopLog(), interfaces.MemoryScope{UserID: "u1"}, + []interfaces.MemoryRecord{{Text: "x"}})) +} + +func TestStoreMemoryRecords_skipsEmptyText(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + + require.NoError(t, rt.StoreMemoryRecords(ctx, noopLog(), scope, []interfaces.MemoryRecord{ + {Text: " "}, + })) + + entries, err := store.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Empty(t, entries) +} + +func TestStoreMemoryRecords_emitsMetrics(t *testing.T) { + ctrl := gomock.NewController(t) + metrics := ifmocks.NewMockMetrics(ctrl) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryStoreStarted, gomock.Any()).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryDedupStarted).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryDedupCompleted).Times(1) + metrics.EXPECT().RecordHistogram(gomock.Any(), types.MetricMemoryDedupLatencyMs, gomock.Any()).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryStoreCompleted, gomock.Any(), gomock.Any()).Times(1) + metrics.EXPECT().RecordHistogram(gomock.Any(), types.MetricMemoryStoreLatencyMs, gomock.Any(), gomock.Any()).Times(1) + + mem := ifmocks.NewMockMemory(ctrl) + scope := interfaces.MemoryScope{UserID: "u1"} + mem.EXPECT().Load(gomock.Any(), scope, "hello world", gomock.Any()).Return(nil, nil).Times(1) + mem.EXPECT().Store(gomock.Any(), scope, gomock.Any()).Return("id-1", nil).Times(1) + + cfg := memory.DefaultConfig(mem) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + rt.Metrics = metrics + + require.NoError(t, rt.StoreMemoryRecords(context.Background(), noopLog(), scope, + []interfaces.MemoryRecord{{Text: "hello world", Kind: memory.KindNote}})) +} + +func TestStoreMemoryRecords_kindRejectedEmitsFailedMetric(t *testing.T) { + ctrl := gomock.NewController(t) + metrics := ifmocks.NewMockMetrics(ctrl) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryStoreFailed).Times(1) + + mem := ifmocks.NewMockMemory(ctrl) + cfg := memory.DefaultConfig(mem) + cfg.Store.AllowedKinds = []interfaces.MemoryKind{memory.KindFact} + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + rt.Metrics = metrics + + err := rt.StoreMemoryRecords(context.Background(), noopLog(), interfaces.MemoryScope{UserID: "u1"}, + []interfaces.MemoryRecord{{Text: "x", Kind: memory.KindNote}}) + require.Error(t, err) +} + +// --- default memory extract (Always mode) --- + +type stubMemoryExtractLLM struct { + resp *interfaces.LLMResponse + err error + req *interfaces.LLMRequest +} + +func (s *stubMemoryExtractLLM) Generate(_ context.Context, req *interfaces.LLMRequest) (*interfaces.LLMResponse, error) { + s.req = req + if s.err != nil { + return nil, s.err + } + return s.resp, nil +} + +func (stubMemoryExtractLLM) GenerateStream(context.Context, *interfaces.LLMRequest) (interfaces.LLMStream, error) { + return nil, errors.New("not supported") +} +func (stubMemoryExtractLLM) GetModel() string { return "stub" } +func (stubMemoryExtractLLM) GetProvider() interfaces.LLMProvider { return interfaces.LLMProviderOpenAI } +func (stubMemoryExtractLLM) IsStreamSupported() bool { return false } + +func TestResolveMemoryExtractFunc_defaultLLM(t *testing.T) { + llm := &stubMemoryExtractLLM{resp: &interfaces.LLMResponse{ + Content: `{"memories":[{"text":"prefers concise answers","kind":"preference"}]}`, + }} + cfg := memory.DefaultConfig(testutil.NewInmemMemory()) + cfg.Store.Mode = memory.StoreModeAlways + rt := newTestRuntime(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: llm}, + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + extract := rt.resolveMemoryExtractFunc() + require.NotNil(t, extract) + require.Nil(t, cfg.Store.Extract) + + records, err := extract(context.Background(), []interfaces.Message{ + {Role: interfaces.MessageRoleUser, Content: "keep it short"}, + {Role: interfaces.MessageRoleAssistant, Content: "will do"}, + }) + require.NoError(t, err) + require.Len(t, records, 1) + require.Equal(t, "prefers concise answers", records[0].Text) + require.Equal(t, memory.KindPreference, records[0].Kind) + require.Equal(t, "extract", records[0].Metadata["source"]) + require.NotNil(t, llm.req.ResponseFormat) +} + +func TestResolveMemoryExtractFunc_skipsOnDemand(t *testing.T) { + llm := &stubMemoryExtractLLM{} + cfg := memory.DefaultConfig(testutil.NewInmemMemory()) + rt := newTestRuntime(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: llm}, + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + require.Nil(t, rt.resolveMemoryExtractFunc()) +} + +func TestResolveMemoryExtractFunc_preservesCustom(t *testing.T) { + custom := memory.ExtractFunc(func(context.Context, []interfaces.Message) ([]interfaces.MemoryRecord, error) { + return []interfaces.MemoryRecord{{Text: "custom"}}, nil + }) + cfg := memory.DefaultConfig(testutil.NewInmemMemory()) + cfg.Store.Mode = memory.StoreModeAlways + cfg.Store.Extract = custom + rt := newTestRuntime(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: &stubMemoryExtractLLM{}}, + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + records, err := rt.resolveMemoryExtractFunc()(context.Background(), nil) + require.NoError(t, err) + require.Equal(t, "custom", records[0].Text) +} + +func TestResolveMemoryExtractFunc_skipsToolMessages(t *testing.T) { + llm := &stubMemoryExtractLLM{resp: &interfaces.LLMResponse{Content: `{"memories":[]}`}} + cfg := memory.DefaultConfig(testutil.NewInmemMemory()) + cfg.Store.Mode = memory.StoreModeAlways + rt := newTestRuntime(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: llm}, + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + records, err := rt.resolveMemoryExtractFunc()(context.Background(), []interfaces.Message{ + {Role: interfaces.MessageRoleTool, Content: "tool output"}, + }) + require.NoError(t, err) + require.Nil(t, records) + require.Nil(t, llm.req) +} + +func TestMessagesForMemoryExtraction_appendsUserTurnAfterAssistant(t *testing.T) { + msgs := messagesForMemoryExtraction([]interfaces.Message{ + {Role: interfaces.MessageRoleUser, Content: "remember this"}, + {Role: interfaces.MessageRoleAssistant, Content: "ok"}, + }) + require.Len(t, msgs, 3) + require.Equal(t, interfaces.MessageRoleUser, msgs[len(msgs)-1].Role) + require.Equal(t, memoryExtractTurnPrompt, msgs[len(msgs)-1].Content) +} + +func TestMessagesForMemoryExtraction_keepsUserLast(t *testing.T) { + msgs := messagesForMemoryExtraction([]interfaces.Message{ + {Role: interfaces.MessageRoleUser, Content: "only user"}, + }) + require.Len(t, msgs, 1) + require.Equal(t, interfaces.MessageRoleUser, msgs[0].Role) +} + +func TestParseMemoryExtractResponse_invalidJSON(t *testing.T) { + _, err := parseMemoryExtractResponse("{") + require.Error(t, err) +} + +// --- ExecuteMemory --- + +func memoryConfigAlways(store interfaces.Memory) memory.Config { + cfg := memory.DefaultConfig(store) + cfg.Store.Mode = memory.StoreModeAlways + return cfg +} + +func testRunMessages(user, assistant string) []interfaces.Message { + var msgs []interfaces.Message + if user != "" { + msgs = append(msgs, interfaces.Message{Role: interfaces.MessageRoleUser, Content: user}) + } + if assistant != "" { + msgs = append(msgs, interfaces.Message{Role: interfaces.MessageRoleAssistant, Content: assistant}) + } + return msgs +} + +func testAlwaysStoreExtract(_ context.Context, messages []interfaces.Message) ([]interfaces.MemoryRecord, error) { + var user, assistant string + for _, m := range messages { + switch m.Role { + case interfaces.MessageRoleUser: + user = strings.TrimSpace(m.Content) + case interfaces.MessageRoleAssistant: + assistant = strings.TrimSpace(m.Content) + } + } + if user == "" && assistant == "" { + return nil, nil + } + var text string + switch { + case user != "" && assistant != "": + text = "User: " + user + "\nAssistant: " + assistant + case assistant != "": + text = "Assistant: " + assistant + default: + text = "User: " + user + } + return []interfaces.MemoryRecord{{ + Text: text, + Metadata: map[string]string{"source": "extract"}, + }}, nil +} + +func TestExecuteMemoryStore_skipsOnDemand(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + scope := interfaces.MemoryScope{UserID: "u1"} + require.NoError(t, rt.ExecuteMemoryStore(context.Background(), noopLog(), scope, testRunMessages("hello", "world"))) + entries, err := store.Load(context.Background(), scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Empty(t, entries) +} + +func TestExecuteToolWithMemoryScope_saveMemory(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + scope := interfaces.MemoryScope{UserID: "u1"} + out, err := rt.ExecuteToolWithMemoryScope(context.Background(), noopLog(), nil, types.SaveMemoryToolName, + map[string]any{types.MemoryToolParamText: "favorite color is blue"}, scope) + require.NoError(t, err) + require.Equal(t, "memory saved", out) + entries, err := store.Load(context.Background(), scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Len(t, entries, 1) +} + +func TestExecuteMemoryRecallAndStore(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memoryConfigAlways(store) + cfg.Store.Extract = testAlwaysStoreExtract + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + + require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, testRunMessages("hello", "world"))) + + res, err := rt.ExecuteMemoryRecall(ctx, noopLog(), scope, "hello") + require.NoError(t, err) + require.NotEmpty(t, res.Context) +} + +func TestExecuteMemoryStore_AppliesTTLFromPolicy(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memoryConfigAlways(store) + cfg.Store.Extract = testAlwaysStoreExtract + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + before := time.Now().UTC() + + require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, testRunMessages("hello", "world"))) + + entries, err := store.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Len(t, entries, 1) + require.False(t, entries[0].ExpiresAt.IsZero()) + want := before.Add(memory.TTLNote) + require.WithinDuration(t, want, entries[0].ExpiresAt, 2*time.Second) +} + +func TestExecuteMemoryStore_skipsEmptyMessages(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memoryConfigAlways(store) + cfg.Store.Extract = testAlwaysStoreExtract + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + + require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, nil)) + require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, []interfaces.Message{ + {Role: interfaces.MessageRoleTool, Content: "noise"}, + })) + + entries, err := store.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Empty(t, entries) +} + +func TestExecuteMemoryStore_noExtractorEmitsFailedMetric(t *testing.T) { + ctrl := gomock.NewController(t) + metrics := ifmocks.NewMockMetrics(ctrl) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryExtractFailed).Times(1) + + store := testutil.NewInmemMemory() + cfg := memoryConfigAlways(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + rt.Metrics = metrics + + require.NoError(t, rt.ExecuteMemoryStore(context.Background(), noopLog(), interfaces.MemoryScope{UserID: "u1"}, + testRunMessages("hi", "there"))) +} + +func TestExecuteMemoryExtract_EmitsMetrics(t *testing.T) { + ctrl := gomock.NewController(t) + metrics := ifmocks.NewMockMetrics(ctrl) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryExtractStarted).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryExtractCompleted).Times(1) + metrics.EXPECT().RecordHistogram(gomock.Any(), types.MetricMemoryExtractLatencyMs, gomock.Any()).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryDedupStarted).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryDedupCompleted).Times(1) + metrics.EXPECT().RecordHistogram(gomock.Any(), types.MetricMemoryDedupLatencyMs, gomock.Any()).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryStoreStarted, gomock.Any()).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryStoreCompleted, gomock.Any(), gomock.Any()).Times(1) + metrics.EXPECT().RecordHistogram(gomock.Any(), types.MetricMemoryStoreLatencyMs, gomock.Any(), gomock.Any()).Times(1) + + store := testutil.NewInmemMemory() + cfg := memoryConfigAlways(store) + cfg.Store.Extract = testAlwaysStoreExtract + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + rt.Metrics = metrics + + scope := interfaces.MemoryScope{UserID: "u1"} + require.NoError(t, rt.ExecuteMemoryStore(context.Background(), noopLog(), scope, testRunMessages("hello", "world"))) +} + +func TestExecuteMemoryExtract_EmitsFailedMetric(t *testing.T) { + ctrl := gomock.NewController(t) + metrics := ifmocks.NewMockMetrics(ctrl) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryExtractStarted).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryExtractFailed).Times(1) + metrics.EXPECT().RecordHistogram(gomock.Any(), types.MetricMemoryExtractLatencyMs, gomock.Any()).Times(1) + + store := testutil.NewInmemMemory() + cfg := memoryConfigAlways(store) + cfg.Store.Extract = func(context.Context, []interfaces.Message) ([]interfaces.MemoryRecord, error) { + return nil, errors.New("extract failed") + } + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + rt.Metrics = metrics + + err := rt.ExecuteMemoryStore(context.Background(), noopLog(), interfaces.MemoryScope{UserID: "u1"}, testRunMessages("hi", "there")) + require.Error(t, err) +} + +func TestExecuteMemoryStore_extractsWithDefaultLLM(t *testing.T) { + llm := &stubMemoryExtractLLM{resp: &interfaces.LLMResponse{ + Content: `{"memories":[{"text":"user likes tea","kind":"preference"}]}`, + }} + store := testutil.NewInmemMemory() + cfg := memoryConfigAlways(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: llm}, + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, testRunMessages("I like tea", "noted"))) + + entries, err := store.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Len(t, entries, 1) + require.Equal(t, "user likes tea", entries[0].Text) + require.Equal(t, memory.KindPreference, entries[0].Kind) + require.Equal(t, "extract", entries[0].Metadata["source"]) +} + +func TestExecuteMemoryStore_setsExtractMetadata(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memoryConfigAlways(store) + cfg.Store.Extract = testAlwaysStoreExtract + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + + require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, testRunMessages("hi", "there"))) + + entries, err := store.Load(ctx, scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Len(t, entries, 1) + require.Equal(t, "extract", entries[0].Metadata["source"]) + require.Contains(t, entries[0].Text, "User: hi") +} + +func TestExecuteMemoryRecall_OmitsExpired(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + + _, err := store.Store(ctx, scope, interfaces.MemoryRecord{ + Text: "User prefers concise answers", + Kind: memory.KindPreference, + ExpiresAt: time.Now().UTC().Add(-time.Hour), + }) + require.NoError(t, err) + + res, err := rt.ExecuteMemoryRecall(ctx, noopLog(), scope, "concise") + require.NoError(t, err) + require.Empty(t, res.Context) +} + +func TestExecuteMemoryRecall_SemanticMissFallsBackToRecency(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memoryConfigAlways(store) + cfg.Store.Extract = testAlwaysStoreExtract + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + + scope := interfaces.MemoryScope{UserID: "u1"} + ctx := context.Background() + + require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, testRunMessages( + "Remember that I prefer concise answers.", "Got it."))) + + res, err := rt.ExecuteMemoryRecall(ctx, noopLog(), scope, "What answer style do I prefer?") + require.NoError(t, err) + require.Contains(t, res.Context, "concise answers") +} + +func TestSubAgentScope(t *testing.T) { + parent := interfaces.MemoryScope{ + TenantID: "t1", + UserID: "u1", + AgentID: "main", + } + got := SubAgentScope(parent, "sub-researcher") + if got.TenantID != "t1" || got.UserID != "u1" { + t.Fatalf("tenant/user = %+v", got) + } + if got.AgentID != "sub-researcher" { + t.Fatalf("agentID = %q", got.AgentID) + } + if got.Tags[scopeKeyParentAgentID] != "main" { + t.Fatalf("tags = %+v", got.Tags) + } +} + +func TestSubAgentScope_nestedDelegation(t *testing.T) { + parent := SubAgentScope(interfaces.MemoryScope{ + UserID: "u1", + AgentID: "main", + }, "sub-a") + got := SubAgentScope(parent, "sub-b") + if got.AgentID != "sub-b" { + t.Fatalf("agentID = %q", got.AgentID) + } + if got.Tags[scopeKeyParentAgentID] != "sub-a" { + t.Fatalf("tags = %+v", got.Tags) + } +} + +func TestExecuteMemoryRecall_EmitsMetrics(t *testing.T) { + ctrl := gomock.NewController(t) + metrics := ifmocks.NewMockMetrics(ctrl) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryRecallStarted).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryRecallCompleted).Times(1) + metrics.EXPECT().RecordHistogram(gomock.Any(), types.MetricMemoryRecallLatencyMs, gomock.Any()).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryExtractStarted).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryExtractCompleted).Times(1) + metrics.EXPECT().RecordHistogram(gomock.Any(), types.MetricMemoryExtractLatencyMs, gomock.Any()).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryDedupStarted).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryDedupCompleted).Times(1) + metrics.EXPECT().RecordHistogram(gomock.Any(), types.MetricMemoryDedupLatencyMs, gomock.Any()).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryStoreStarted, gomock.Any()).Times(1) + metrics.EXPECT().IncrementCounter(gomock.Any(), types.MetricMemoryStoreCompleted, gomock.Any(), gomock.Any()).Times(1) + metrics.EXPECT().RecordHistogram(gomock.Any(), types.MetricMemoryStoreLatencyMs, gomock.Any(), gomock.Any()).Times(1) + + mem := ifmocks.NewMockMemory(ctrl) + scope := interfaces.MemoryScope{UserID: "u1"} + mem.EXPECT().Store(gomock.Any(), scope, gomock.Any()).Return("id-1", nil).Times(1) + mem.EXPECT().Load(gomock.Any(), scope, gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, _ interfaces.MemoryScope, query string, _ ...interfaces.LoadMemoryOption) ([]interfaces.MemoryEntry, error) { + if query == "hello" { + return []interfaces.MemoryEntry{{Text: "User: hello\nAssistant: world"}}, nil + } + return nil, nil + }).AnyTimes() + + cfg := memoryConfigAlways(mem) + cfg.Store.Extract = testAlwaysStoreExtract + rt := newTestRuntime(sdkruntime.AgentConfig{ + Memory: sdkruntime.AgentMemory{Config: &cfg}, + }) + rt.Metrics = metrics + + ctx := context.Background() + require.NoError(t, rt.ExecuteMemoryStore(ctx, noopLog(), scope, testRunMessages("hello", "world"))) + _, err := rt.ExecuteMemoryRecall(ctx, noopLog(), scope, "hello") + require.NoError(t, err) +} + // --- ExecuteLLMStream --- // streamCapableLLMClient wraps a stubLLMClient and sets IsStreamSupported=true. diff --git a/internal/runtime/base/types.go b/internal/runtime/base/types.go index eb34534..fa21493 100644 --- a/internal/runtime/base/types.go +++ b/internal/runtime/base/types.go @@ -5,6 +5,8 @@ import ( "github.com/agenticenv/agent-sdk-go/pkg/interfaces" ) +const scopeKeyParentAgentID = "parent_agent_id" + // LLMResult is the result of a successful LLM call. // Content holds the assistant text; ToolCalls holds any tool invocations resolved against // the registered tools list (NeedsApproval pre-computed from the approval policy). @@ -39,3 +41,10 @@ type RetrieverResult struct { TotalSearches int64 FailedSearches int64 } + +// MemoryResult is the outcome of ExecuteMemoryRecall. +type MemoryResult struct { + Context string + TotalRecalls int64 + FailedRecalls int64 +} diff --git a/internal/runtime/base/utils.go b/internal/runtime/base/utils.go index dc0ab36..8e6dc32 100644 --- a/internal/runtime/base/utils.go +++ b/internal/runtime/base/utils.go @@ -19,6 +19,27 @@ func SubAgentQuery(args map[string]any) string { return q } +// SubAgentScope derives memory scope for a delegated sub-agent from the parent run scope. +func SubAgentScope(parent interfaces.MemoryScope, subAgentID string) interfaces.MemoryScope { + subAgentID = strings.TrimSpace(subAgentID) + scope := interfaces.MemoryScope{ + TenantID: parent.TenantID, + UserID: parent.UserID, + AgentID: subAgentID, + } + if parent.AgentID != "" || len(parent.Tags) > 0 { + tags := make(map[string]string, len(parent.Tags)+1) + for key, value := range parent.Tags { + tags[key] = value + } + if parent.AgentID != "" { + tags[scopeKeyParentAgentID] = parent.AgentID + } + scope.Tags = tags + } + return scope +} + // FindToolByName returns the first tool whose Name() matches toolName. func FindToolByName(tools []interfaces.Tool, toolName string) (interfaces.Tool, bool) { for _, t := range tools { diff --git a/internal/runtime/local/agent_loop.go b/internal/runtime/local/agent_loop.go index 888ab9f..5d405e7 100644 --- a/internal/runtime/local/agent_loop.go +++ b/internal/runtime/local/agent_loop.go @@ -44,6 +44,8 @@ type AgentLoopInput struct { SubAgentDepth int // MaxSubAgentDepth caps recursive delegation. Mirrors AgentWorkflowInput.MaxSubAgentDepth. MaxSubAgentDepth int + // MemoryScope is resolved before the run and used for recall/store. + MemoryScope interfaces.MemoryScope } // AgentLoopResult is the outcome of a completed local agent run. @@ -132,6 +134,22 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) } } + // Pre-fetch long-term memory context when recall is enabled. + memoryContext := "" + if rt.MemoryConfigured() && rt.RecallEnabled() { + log.Debug(ctx, "local: memory recall started", slog.String("scope", "loop")) + res, err := rt.ExecuteMemoryRecall(ctx, log, input.MemoryScope, input.UserPrompt) + if err != nil { + return nil, fmt.Errorf("memory recall: %w", err) + } + memoryContext = res.Context + telemetry.Storage.TotalMemoryRecalls += res.TotalRecalls + telemetry.Storage.FailedMemoryRecalls += res.FailedRecalls + log.Debug(ctx, "local: memory recall done", + slog.String("scope", "loop"), + slog.Bool("hasContext", memoryContext != "")) + } + // Pre-fetch retriever context for prefetch/hybrid modes. retrieverContext := "" retrieverMode := rt.AgentConfig.Retrievers.Mode @@ -172,6 +190,7 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) MessageID: messageID, Messages: messages, SkipTools: false, + MemoryContext: memoryContext, RetrieverContext: retrieverContext, Tools: tools, Emit: emit, @@ -210,6 +229,7 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) MessageID: finalMessageID, Messages: messages, SkipTools: true, + MemoryContext: memoryContext, RetrieverContext: retrieverContext, Tools: tools, Emit: emit, @@ -278,6 +298,13 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) telemetry.Storage.FailedRetrieverSearches++ } } + if tc.ToolName == types.SaveMemoryToolName { + if result.failed { + telemetry.Storage.FailedMemoryStores++ + } else { + telemetry.Storage.TotalMemoryStores++ + } + } } if rt.conversationMemoryEnabled(input) && rt.AgentConfig.Session.ConversationSaveOnIteration && len(messages) > persistedMessageCount { @@ -291,6 +318,15 @@ func (rt *LocalRuntime) RunAgentLoop(ctx context.Context, input AgentLoopInput) rt.persistConversationMessages(ctx, input.ConversationID, messages[persistedMessageCount:]) } + if rt.RunEndMemoryStoreEnabled() { + if err := rt.ExecuteMemoryStore(ctx, log, input.MemoryScope, messages); err != nil { + log.Warn(ctx, "local: memory store failed", slog.String("scope", "loop"), slog.Any("error", err)) + telemetry.Storage.FailedMemoryStores++ + } else { + telemetry.Storage.TotalMemoryStores++ + } + } + log.Info(ctx, "local: agent run completed", slog.String("scope", "loop"), slog.String("agentName", agentName), @@ -555,6 +591,7 @@ func (rt *LocalRuntime) executeSingleTool( StreamingEnabled: input.StreamingEnabled, ChannelName: input.ChannelName, ApprovalHandler: input.ApprovalHandler, + MemoryScope: base.SubAgentScope(input.MemoryScope, stepName), SubAgentRoutes: subAgentRoute.children, SubAgentDepth: input.SubAgentDepth + 1, MaxSubAgentDepth: input.MaxSubAgentDepth, @@ -574,7 +611,7 @@ func (rt *LocalRuntime) executeSingleTool( slog.String("scope", "loop"), slog.String("tool", tc.ToolName), slog.String("toolCallID", tc.ToolCallID)) - result, execErr := rt.ExecuteTool(ctx, log, tools, tc.ToolName, tc.Args) + result, execErr := rt.ExecuteToolWithMemoryScope(ctx, log, tools, tc.ToolName, tc.Args, input.MemoryScope) if execErr != nil { content = "Tool execution failed: " + execErr.Error() failed = true diff --git a/internal/runtime/local/agent_loop_test.go b/internal/runtime/local/agent_loop_test.go index 98bb350..cd00207 100644 --- a/internal/runtime/local/agent_loop_test.go +++ b/internal/runtime/local/agent_loop_test.go @@ -10,10 +10,12 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/events" sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" "github.com/agenticenv/agent-sdk-go/internal/runtime/base" + testutil "github.com/agenticenv/agent-sdk-go/internal/testing" "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" ifmocks "github.com/agenticenv/agent-sdk-go/pkg/interfaces/mocks" "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/memory" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -92,6 +94,117 @@ func TestRunAgentLoop_SimpleTextResponse(t *testing.T) { require.Equal(t, "hello world", result.Content) } +func TestRunAgentLoop_MemoryRecallAndStore(t *testing.T) { + store := testutil.NewInmemMemory() + memCfg := memory.DefaultConfig(store) + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{{Content: "I will be concise."}}, + } + + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentSpec(sdkruntime.AgentSpec{Name: "mem-agent", SystemPrompt: "sys"}), + WithAgentConfig(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: client}, + Memory: sdkruntime.AgentMemory{Config: &memCfg}, + Limits: sdkruntime.AgentLimits{MaxIterations: 3, Timeout: 10 * time.Second}, + }), + ) + require.NoError(t, err) + + scope := interfaces.MemoryScope{UserID: "u1"} + _, err = store.Store(context.Background(), scope, interfaces.MemoryRecord{ + Text: "User prefers concise answers", + Kind: memory.KindPreference, + }) + require.NoError(t, err) + + result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{ + UserPrompt: "What style do I prefer?", + MemoryScope: scope, + }) + require.NoError(t, err) + require.Equal(t, "I will be concise.", result.Content) + require.Equal(t, int64(1), result.Telemetry.Storage.TotalMemoryRecalls) + require.Equal(t, int64(0), result.Telemetry.Storage.TotalMemoryStores) +} + +func TestRunAgentLoop_MemoryAlwaysRunEndStore(t *testing.T) { + store := testutil.NewInmemMemory() + memCfg := memory.DefaultConfig(store) + memCfg.Store.Mode = memory.StoreModeAlways + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{ + {Content: "done"}, + {Content: `{"memories":[{"text":"greeting","kind":"note"}]}`}, + }, + } + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentSpec(sdkruntime.AgentSpec{Name: "mem-agent", SystemPrompt: "sys"}), + WithAgentConfig(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: client}, + Memory: sdkruntime.AgentMemory{Config: &memCfg}, + Limits: sdkruntime.AgentLimits{MaxIterations: 3, Timeout: 10 * time.Second}, + }), + ) + require.NoError(t, err) + scope := interfaces.MemoryScope{UserID: "u1"} + result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{ + UserPrompt: "hello", + MemoryScope: scope, + }) + require.NoError(t, err) + require.Equal(t, int64(1), result.Telemetry.Storage.TotalMemoryStores) + + entries, err := store.Load(context.Background(), scope, "", memCfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Len(t, entries, 1) + require.Equal(t, "greeting", entries[0].Text) +} + +func TestRunAgentLoop_OnDemandSaveMemoryTool(t *testing.T) { + store := testutil.NewInmemMemory() + memCfg := memory.DefaultConfig(store) + client := &seqLLMClient{ + responses: []*interfaces.LLMResponse{ + {ToolCalls: []*interfaces.ToolCall{{ + ToolCallID: "c1", + ToolName: types.SaveMemoryToolName, + Args: map[string]any{types.MemoryToolParamText: "favorite color is blue"}, + }}}, + {Content: "saved"}, + }, + } + rt, err := NewLocalRuntime( + WithLogger(logger.NoopLogger()), + WithAgentSpec(sdkruntime.AgentSpec{Name: "mem-agent", SystemPrompt: "sys"}), + WithAgentConfig(sdkruntime.AgentConfig{ + LLM: sdkruntime.AgentLLM{Client: client}, + Memory: sdkruntime.AgentMemory{Config: &memCfg}, + Limits: sdkruntime.AgentLimits{MaxIterations: 3, Timeout: 10 * time.Second}, + }), + ) + require.NoError(t, err) + scope := interfaces.MemoryScope{UserID: "u1"} + tool := stubKindTool{ + stubTool: stubTool{name: types.SaveMemoryToolName}, + kind: types.ToolKindMemory, + } + result, err := rt.RunAgentLoop(context.Background(), AgentLoopInput{ + UserPrompt: "remember my color", + MemoryScope: scope, + Tools: []interfaces.Tool{tool}, + }) + require.NoError(t, err) + require.Equal(t, "saved", result.Content) + require.Equal(t, int64(1), result.Telemetry.Storage.TotalMemoryStores) + entries, err := store.Load(context.Background(), scope, "", memCfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Len(t, entries, 1) + require.Equal(t, "favorite color is blue", entries[0].Text) +} + func TestRunAgentLoop_LLMError(t *testing.T) { client := &seqLLMClient{errs: []error{errors.New("llm fail")}} rt, _ := newLoopRT(t, 5, client) diff --git a/internal/runtime/local/runtime.go b/internal/runtime/local/runtime.go index 96208c5..a4931ed 100644 --- a/internal/runtime/local/runtime.go +++ b/internal/runtime/local/runtime.go @@ -118,6 +118,13 @@ func (rt *LocalRuntime) Execute(ctx context.Context, req *sdkruntime.ExecuteRequ } conversationID := base.GetConversationID(req) + memoryScope, memErr := rt.ResolveMemoryScope(runCtx) + if memErr != nil { + rt.logger.Warn(runCtx, "runtime memory scope resolve failed, continuing with empty scope", + slog.String("scope", "runtime"), + slog.Any("error", memErr)) + memoryScope = interfaces.MemoryScope{} + } runID := uuid.New().String() tools := req.Tools @@ -125,6 +132,7 @@ func (rt *LocalRuntime) Execute(ctx context.Context, req *sdkruntime.ExecuteRequ loopResult, err := rt.RunAgentLoop(runCtx, AgentLoopInput{ UserPrompt: req.UserPrompt, ConversationID: conversationID, + MemoryScope: memoryScope, StreamingEnabled: false, ChannelName: "", ApprovalHandler: req.ApprovalHandler, @@ -158,6 +166,13 @@ func (rt *LocalRuntime) ExecuteStream(ctx context.Context, req *sdkruntime.Execu slog.Int("inputLen", len(req.UserPrompt))) conversationID := base.GetConversationID(req) + memoryScope, memErr := rt.ResolveMemoryScope(ctx) + if memErr != nil { + rt.logger.Warn(ctx, "runtime memory scope resolve failed, continuing with empty scope", + slog.String("scope", "runtime"), + slog.Any("error", memErr)) + memoryScope = interfaces.MemoryScope{} + } runID := uuid.New().String() threadID := conversationID @@ -214,6 +229,7 @@ func (rt *LocalRuntime) ExecuteStream(ctx context.Context, req *sdkruntime.Execu result, loopErr := rt.RunAgentLoop(runCtx, AgentLoopInput{ UserPrompt: req.UserPrompt, ConversationID: conversationID, + MemoryScope: memoryScope, StreamingEnabled: req.StreamingEnabled, ChannelName: channel, ApprovalHandler: req.ApprovalHandler, diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 920a9f6..be593e3 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -12,6 +12,7 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/events" "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/memory" ) //go:generate mockgen -destination=./mocks/mock_runtime.go -package=mocks github.com/agenticenv/agent-sdk-go/internal/runtime Runtime @@ -97,9 +98,15 @@ type AgentConfig struct { ToolApprovalPolicy interfaces.AgentToolApprovalPolicy Retrievers AgentRetrievers Session AgentSession + Memory AgentMemory Limits AgentLimits } +// AgentMemory holds long-term memory configuration for recall and store. +type AgentMemory struct { + Config *memory.Config +} + // AgentRetrievers holds the retriever instances and mode for prefetch and hybrid RAG. type AgentRetrievers struct { // Retrievers is the list of retriever instances registered with the agent. diff --git a/internal/runtime/temporal/agent_workflow.go b/internal/runtime/temporal/agent_workflow.go index 0077e9f..63e6b58 100644 --- a/internal/runtime/temporal/agent_workflow.go +++ b/internal/runtime/temporal/agent_workflow.go @@ -109,6 +109,7 @@ func (rt *TemporalRuntime) sendAgentEventWorkflowUpdate(ctx context.Context, eve // ConversationID is set when conversation is used; workflow fetches messages and writes assistant/tool via activities. // SubAgentDepth is 0 for a top-level user run; each child workflow increments it (runtime cap vs maxSubAgentDepth). // SubAgentRoutes maps sub-agent tool name -> route; built from WithSubAgents when the run starts. +// MemoryScope is resolved before the workflow starts and passed through for recall/store activities. // LocalChannelName is the in-process pub/sub channel name used for in-memory event fan-in across the // delegation tree. Set once at the top level (agent_event_) and propagated unchanged // to all sub-agents. Contrast with EventWorkflowID which is used for out-of-process (remote) routing. @@ -124,6 +125,7 @@ type AgentWorkflowInput struct { StreamingEnabled bool `json:"streaming_enabled,omitempty"` ConversationID string `json:"conversation_id,omitempty"` AgentFingerprint string `json:"agent_fingerprint,omitempty"` + MemoryScope interfaces.MemoryScope `json:"memory_scope,omitempty"` EventTypes []events.AgentEventType `json:"event_types,omitempty"` SubAgentDepth int `json:"sub_agent_depth,omitempty"` SubAgentRoutes map[string]SubAgentRoute `json:"sub_agent_routes,omitempty"` @@ -155,10 +157,32 @@ type AgentRetrieverResult struct { FailedSearches int64 `json:"failed_searches,omitempty"` } +// AgentMemoryRecallInput is the input to AgentMemoryRecallActivity. +type AgentMemoryRecallInput struct { + AgentFingerprint string `json:"agent_fingerprint,omitempty"` + UserPrompt string `json:"user_prompt"` + MemoryScope interfaces.MemoryScope `json:"memory_scope,omitempty"` +} + +// AgentMemoryRecallResult is the return value of AgentMemoryRecallActivity. +type AgentMemoryRecallResult struct { + MemoryContext string `json:"memory_context,omitempty"` + TotalRecalls int64 `json:"total_recalls,omitempty"` + FailedRecalls int64 `json:"failed_recalls,omitempty"` +} + +// AgentMemoryStoreInput is the input to AgentMemoryStoreActivity. +type AgentMemoryStoreInput struct { + AgentFingerprint string `json:"agent_fingerprint,omitempty"` + MemoryScope interfaces.MemoryScope `json:"memory_scope,omitempty"` + Messages []interfaces.Message `json:"messages,omitempty"` +} + // AgentLLMInput is the input to AgentLLMActivity and AgentLLMStreamActivity. // When ConversationID is set, the activity loads history from the store. MessageID is the assistant text id // for TEXT_MESSAGE_* (and stream ordering with REASONING_*); the workflow sets it each turn. // RetrieverContext is the pre-fetched RAG context from AgentRetrieverActivity (prefetch / hybrid modes). +// MemoryContext is the pre-fetched long-term memory context from AgentMemoryRecallActivity. type AgentLLMInput struct { AgentName string `json:"agent_name,omitempty"` ConversationID string `json:"conversation_id,omitempty"` @@ -169,6 +193,7 @@ type AgentLLMInput struct { EventWorkflowID string `json:"event_workflow_id,omitempty"` EventTaskQueue string `json:"event_task_queue,omitempty"` LocalChannelName string `json:"local_channel_name,omitempty"` + MemoryContext string `json:"memory_context,omitempty"` RetrieverContext string `json:"retriever_context,omitempty"` } @@ -237,12 +262,13 @@ type agentToolResult struct { // AgentToolExecuteInput is the input to AgentToolExecuteActivity. type AgentToolExecuteInput struct { - ToolName string `json:"tool_name"` - Args map[string]any `json:"args"` - ConversationID string `json:"conversation_id,omitempty"` - Messages []interfaces.Message `json:"messages,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` - AgentFingerprint string `json:"agent_fingerprint,omitempty"` + ToolName string `json:"tool_name"` + Args map[string]any `json:"args"` + ConversationID string `json:"conversation_id,omitempty"` + Messages []interfaces.Message `json:"messages,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + AgentFingerprint string `json:"agent_fingerprint,omitempty"` + MemoryScope interfaces.MemoryScope `json:"memory_scope,omitempty"` } type AgentToolApprovalInput struct { @@ -341,6 +367,11 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl StartToCloseTimeout: agentRetrieverActivityTaskTimeout, RetryPolicy: retryPolicy(agentRetrieverActivityMaxAttempts), }) + memoryActCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + ActivityID: "AgentMemoryActivity_" + activityIDSuffix, + StartToCloseTimeout: agentRetrieverActivityTaskTimeout, + RetryPolicy: retryPolicy(agentRetrieverActivityMaxAttempts), + }) var streamingUnavailable bool // emitAgentEvent must use wfCtx (the coroutine that calls Get) for ExecuteActivity().Get — not the root @@ -410,6 +441,26 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl messages := input.State.Messages + memoryContext := "" + if rt.MemoryConfigured() && rt.RecallEnabled() { + logger.Debug("workflow: memory recall started", "scope", "workflow") + var memoryResult AgentMemoryRecallResult + if err := workflow.ExecuteActivity(memoryActCtx, rt.AgentMemoryRecallActivity, AgentMemoryRecallInput{ + AgentFingerprint: input.AgentFingerprint, + UserPrompt: input.UserPrompt, + MemoryScope: input.MemoryScope, + }).Get(memoryActCtx, &memoryResult); err != nil { + if temporal.IsCanceledError(err) { + return nil, err + } + return nil, err + } + memoryContext = memoryResult.MemoryContext + telemetry.Storage.TotalMemoryRecalls += memoryResult.TotalRecalls + telemetry.Storage.FailedMemoryRecalls += memoryResult.FailedRecalls + logger.Debug("workflow: memory recall done", "scope", "workflow", "hasContext", memoryContext != "") + } + // Pre-fetch retrieval context once before the first LLM call (prefetch and hybrid modes). // The resulting retrieverContext is forwarded to every AgentLLMInput in the run so the LLM always // sees the retrieved documents in its system prompt, regardless of the number of iterations. @@ -451,6 +502,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl EventWorkflowID: eventWorkflowID, EventTaskQueue: eventTaskQueue, LocalChannelName: input.LocalChannelName, + MemoryContext: memoryContext, RetrieverContext: retrieverContext, } @@ -674,6 +726,13 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl telemetry.Storage.FailedRetrieverSearches++ } } + if tc.ToolName == types.SaveMemoryToolName { + if result.failed { + telemetry.Storage.FailedMemoryStores++ + } else { + telemetry.Storage.TotalMemoryStores++ + } + } messages = append(messages, result.message) } @@ -726,6 +785,22 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl } } + if rt.RunEndMemoryStoreEnabled() { + if err := workflow.ExecuteActivity(memoryActCtx, rt.AgentMemoryStoreActivity, AgentMemoryStoreInput{ + AgentFingerprint: input.AgentFingerprint, + MemoryScope: input.MemoryScope, + Messages: messages, + }).Get(memoryActCtx, nil); err != nil { + if temporal.IsCanceledError(err) { + return nil, err + } + logger.Warn("workflow: memory store failed", "scope", "workflow", "error", err) + telemetry.Storage.FailedMemoryStores++ + } else { + telemetry.Storage.TotalMemoryStores++ + } + } + // Log summary only; avoid full content to prevent leaking sensitive data logger.Info("workflow: agent run completed", "scope", "workflow", "contentLen", len(lastContent)) @@ -912,6 +987,7 @@ func (rt *TemporalRuntime) executeAgentToolCall(input agentToolCallInput, tc Too ConversationID: input.input.ConversationID, ToolCallID: tc.ToolCallID, AgentFingerprint: input.input.AgentFingerprint, + MemoryScope: input.input.MemoryScope, } errExec := workflow.ExecuteActivity(input.execCtx, rt.AgentToolExecuteActivity, execInput).Get(input.execCtx, &result) if errExec != nil { @@ -1024,6 +1100,7 @@ func (rt *TemporalRuntime) AgentLLMStreamActivity(ctx context.Context, input Age Messages: messages, SkipTools: input.SkipTools, RetrieverContext: input.RetrieverContext, + MemoryContext: input.MemoryContext, Tools: tools, Emit: emit, } @@ -1056,6 +1133,32 @@ func (rt *TemporalRuntime) AgentRetrieverActivity(ctx context.Context, input Age }, nil } +// AgentMemoryRecallActivity loads scoped long-term memories and returns formatted prompt context. +func (rt *TemporalRuntime) AgentMemoryRecallActivity(ctx context.Context, input AgentMemoryRecallInput) (*AgentMemoryRecallResult, error) { + if err := rt.verifyAgentFingerprint(ctx, input.AgentFingerprint, nil); err != nil { + return nil, err + } + actLog := newActivityLogger(activity.GetLogger(ctx)) + res, err := rt.ExecuteMemoryRecall(ctx, actLog, input.MemoryScope, input.UserPrompt) + if err != nil { + return nil, err + } + return &AgentMemoryRecallResult{ + MemoryContext: res.Context, + TotalRecalls: res.TotalRecalls, + FailedRecalls: res.FailedRecalls, + }, nil +} + +// AgentMemoryStoreActivity extracts and persists long-term memories from the run. +func (rt *TemporalRuntime) AgentMemoryStoreActivity(ctx context.Context, input AgentMemoryStoreInput) error { + if err := rt.verifyAgentFingerprint(ctx, input.AgentFingerprint, nil); err != nil { + return err + } + actLog := newActivityLogger(activity.GetLogger(ctx)) + return rt.ExecuteMemoryStore(ctx, actLog, input.MemoryScope, input.Messages) +} + // AgentLLMActivity calls the LLM and returns content plus any tool calls. // When input.ConversationID is set, fetches from store and adds assistant message on completion. func (rt *TemporalRuntime) AgentLLMActivity(ctx context.Context, input AgentLLMInput) (*AgentLLMResult, error) { @@ -1089,6 +1192,7 @@ func (rt *TemporalRuntime) AgentLLMActivity(ctx context.Context, input AgentLLMI Messages: messages, SkipTools: input.SkipTools, RetrieverContext: input.RetrieverContext, + MemoryContext: input.MemoryContext, Tools: tools, Emit: emit, } @@ -1263,7 +1367,7 @@ func (rt *TemporalRuntime) AgentToolExecuteActivity(ctx context.Context, input A stopHB := startLongActivityHeartbeats(ctx) defer stopHB() actLog := newActivityLogger(activity.GetLogger(ctx)) - return rt.ExecuteTool(ctx, actLog, tools, input.ToolName, input.Args) + return rt.ExecuteToolWithMemoryScope(ctx, actLog, tools, input.ToolName, input.Args, input.MemoryScope) } // AgentToolAuthorizeActivity checks optional programmatic authorization before approval/execute. @@ -1304,6 +1408,10 @@ func (rt *TemporalRuntime) delegateToSubAgent(ctx workflow.Context, input AgentW } query := base.SubAgentQuery(tc.Args) + subAgentID := strings.TrimSpace(route.Name) + if subAgentID == "" { + subAgentID = tc.ToolName + } childInput := AgentWorkflowInput{ UserPrompt: query, EventWorkflowID: input.EventWorkflowID, @@ -1313,6 +1421,7 @@ func (rt *TemporalRuntime) delegateToSubAgent(ctx workflow.Context, input AgentW ConversationID: "", AgentFingerprint: route.AgentFingerprint, EventTypes: input.EventTypes, + MemoryScope: base.SubAgentScope(input.MemoryScope, subAgentID), SubAgentDepth: input.SubAgentDepth + 1, SubAgentRoutes: route.ChildRoutes, } diff --git a/internal/runtime/temporal/agent_workflow_test.go b/internal/runtime/temporal/agent_workflow_test.go index 4ec5611..a37661d 100644 --- a/internal/runtime/temporal/agent_workflow_test.go +++ b/internal/runtime/temporal/agent_workflow_test.go @@ -14,10 +14,12 @@ import ( sdkruntime "github.com/agenticenv/agent-sdk-go/internal/runtime" "github.com/agenticenv/agent-sdk-go/internal/runtime/base" + testutil "github.com/agenticenv/agent-sdk-go/internal/testing" "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/interfaces/mocks" "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/memory" "github.com/agenticenv/agent-sdk-go/pkg/observability" ) @@ -912,3 +914,102 @@ func TestAgentWorkflow_AgenticMode_SkipsRetrieverActivity(t *testing.T) { require.True(t, env.IsWorkflowCompleted()) require.NoError(t, env.GetWorkflowError()) } + +func TestAgentToolExecuteActivity_saveMemory(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := testRuntimeForWorkflow(t) + rt.AgentConfig.Memory = sdkruntime.AgentMemory{Config: &cfg} + wireTestToolsResolver(rt, nil) + + env := newActivityTestEnv(t) + env.RegisterActivity(rt.AgentToolExecuteActivity) + + scope := interfaces.MemoryScope{UserID: "u1"} + val, err := env.ExecuteActivity(rt.AgentToolExecuteActivity, AgentToolExecuteInput{ + ToolName: types.SaveMemoryToolName, + Args: map[string]any{types.MemoryToolParamText: "remember this"}, + MemoryScope: scope, + }) + require.NoError(t, err) + var got string + require.NoError(t, val.Get(&got)) + require.Equal(t, "memory saved", got) + + entries, err := store.Load(context.Background(), scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Len(t, entries, 1) + require.Equal(t, "remember this", entries[0].Text) +} + +func TestAgentMemoryStoreActivity_skipsOnDemand(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + rt := testRuntimeForWorkflow(t) + rt.AgentConfig.Memory = sdkruntime.AgentMemory{Config: &cfg} + + env := newActivityTestEnv(t) + env.RegisterActivity(rt.AgentMemoryStoreActivity) + + scope := interfaces.MemoryScope{UserID: "u1"} + _, err := env.ExecuteActivity(rt.AgentMemoryStoreActivity, AgentMemoryStoreInput{ + MemoryScope: scope, + Messages: []interfaces.Message{ + {Role: interfaces.MessageRoleUser, Content: "hi"}, + {Role: interfaces.MessageRoleAssistant, Content: "there"}, + }, + }) + require.NoError(t, err) + + entries, err := store.Load(context.Background(), scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Empty(t, entries) +} + +func TestAgentMemoryStoreActivity_alwaysExtracts(t *testing.T) { + store := testutil.NewInmemMemory() + cfg := memory.DefaultConfig(store) + cfg.Store.Mode = memory.StoreModeAlways + rt := testRuntimeForWorkflow(t) + llm := &stubMemoryExtractLLM{resp: &interfaces.LLMResponse{ + Content: `{"memories":[{"text":"stored fact","kind":"fact"}]}`, + }} + rt.AgentConfig.LLM = sdkruntime.AgentLLM{Client: llm} + rt.AgentConfig.Memory = sdkruntime.AgentMemory{Config: &cfg} + + env := newActivityTestEnv(t) + env.RegisterActivity(rt.AgentMemoryStoreActivity) + + scope := interfaces.MemoryScope{UserID: "u1"} + _, err := env.ExecuteActivity(rt.AgentMemoryStoreActivity, AgentMemoryStoreInput{ + MemoryScope: scope, + Messages: []interfaces.Message{ + {Role: interfaces.MessageRoleUser, Content: "remember this"}, + {Role: interfaces.MessageRoleAssistant, Content: "ok"}, + }, + }) + require.NoError(t, err) + + entries, err := store.Load(context.Background(), scope, "", cfg.Recall.RecencyLoadOptions()...) + require.NoError(t, err) + require.Len(t, entries, 1) + require.Equal(t, "stored fact", entries[0].Text) +} + +type stubMemoryExtractLLM struct { + resp *interfaces.LLMResponse + err error +} + +func (s *stubMemoryExtractLLM) Generate(_ context.Context, _ *interfaces.LLMRequest) (*interfaces.LLMResponse, error) { + if s.err != nil { + return nil, s.err + } + return s.resp, nil +} +func (stubMemoryExtractLLM) GenerateStream(context.Context, *interfaces.LLMRequest) (interfaces.LLMStream, error) { + return nil, fmt.Errorf("not supported") +} +func (stubMemoryExtractLLM) GetModel() string { return "stub" } +func (stubMemoryExtractLLM) GetProvider() interfaces.LLMProvider { return interfaces.LLMProviderOpenAI } +func (stubMemoryExtractLLM) IsStreamSupported() bool { return false } diff --git a/internal/runtime/temporal/runtime.go b/internal/runtime/temporal/runtime.go index 743a282..e4264a1 100644 --- a/internal/runtime/temporal/runtime.go +++ b/internal/runtime/temporal/runtime.go @@ -193,6 +193,8 @@ func (rt *TemporalRuntime) Start(ctx context.Context) error { w.RegisterActivityWithOptions(rt.AgentLLMActivity, activity.RegisterOptions{Name: "AgentLLMActivity"}) w.RegisterActivityWithOptions(rt.AgentLLMStreamActivity, activity.RegisterOptions{Name: "AgentLLMStreamActivity"}) w.RegisterActivityWithOptions(rt.AgentRetrieverActivity, activity.RegisterOptions{Name: "AgentRetrieverActivity"}) + w.RegisterActivityWithOptions(rt.AgentMemoryRecallActivity, activity.RegisterOptions{Name: "AgentMemoryRecallActivity"}) + w.RegisterActivityWithOptions(rt.AgentMemoryStoreActivity, activity.RegisterOptions{Name: "AgentMemoryStoreActivity"}) w.RegisterActivityWithOptions(rt.AgentToolAuthorizeActivity, activity.RegisterOptions{Name: "AgentToolAuthorizeActivity"}) w.RegisterActivityWithOptions(rt.AgentToolApprovalActivity, activity.RegisterOptions{Name: "AgentToolApprovalActivity"}) w.RegisterActivityWithOptions(rt.AgentToolExecuteActivity, activity.RegisterOptions{Name: "AgentToolExecuteActivity"}) @@ -299,6 +301,13 @@ func (rt *TemporalRuntime) Execute(ctx context.Context, req *runtime.ExecuteRequ } conversationID := base.GetConversationID(req) + memoryScope, memErr := rt.ResolveMemoryScope(runCtx) + if memErr != nil { + rt.logger.Warn(runCtx, "runtime memory scope resolve failed, continuing with empty scope", + slog.String("scope", "runtime"), + slog.Any("error", memErr)) + memoryScope = interfaces.MemoryScope{} + } runID := uuid.New().String() threadID := conversationID @@ -324,6 +333,7 @@ func (rt *TemporalRuntime) Execute(ctx context.Context, req *runtime.ExecuteRequ EventWorkflowID: "", LocalChannelName: eventChannelName(workflowID), ConversationID: conversationID, + MemoryScope: memoryScope, AgentFingerprint: computeAgentFingerprintFromRuntime(rt, req.Tools), EventTypes: []events.AgentEventType{}, SubAgentDepth: 0, @@ -460,6 +470,13 @@ func (rt *TemporalRuntime) ExecuteStream(ctx context.Context, req *runtime.Execu rt.logger.Debug(ctx, "runtime stream run dispatch", slog.String("scope", "runtime"), slog.String("agent", agentNameFromRuntime(rt)), slog.Int("inputLen", len(req.UserPrompt))) conversationID := base.GetConversationID(req) + memoryScope, memErr := rt.ResolveMemoryScope(ctx) + if memErr != nil { + rt.logger.Warn(ctx, "runtime memory scope resolve failed, continuing with empty scope", + slog.String("scope", "runtime"), + slog.Any("error", memErr)) + memoryScope = interfaces.MemoryScope{} + } runID := uuid.New().String() threadID := conversationID @@ -508,6 +525,7 @@ func (rt *TemporalRuntime) ExecuteStream(ctx context.Context, req *runtime.Execu LocalChannelName: eventChannelName(workflowID), StreamingEnabled: req.StreamingEnabled, ConversationID: conversationID, + MemoryScope: memoryScope, AgentFingerprint: computeAgentFingerprintFromRuntime(rt, req.Tools), EventTypes: streamEventTypes, SubAgentDepth: 0, diff --git a/internal/testing/inmem_memory.go b/internal/testing/inmem_memory.go new file mode 100644 index 0000000..cacb09c --- /dev/null +++ b/internal/testing/inmem_memory.go @@ -0,0 +1,206 @@ +// Package testutil provides test helpers for the agent SDK. +package testutil + +import ( + "context" + "encoding/json" + "sort" + "strings" + "sync" + "time" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/memory" + "github.com/google/uuid" +) + +var _ interfaces.Memory = (*InmemMemory)(nil) + +type record struct { + entry interfaces.MemoryEntry +} + +// InmemMemory is a mutex-protected in-memory [interfaces.Memory] for tests. +type InmemMemory struct { + mu sync.RWMutex + records map[string]record +} + +// NewInmemMemory returns an empty in-memory store. +func NewInmemMemory() *InmemMemory { + return &InmemMemory{records: make(map[string]record)} +} + +// Store persists a memory in scope and returns its ID. +func (m *InmemMemory) Store(ctx context.Context, scope interfaces.MemoryScope, rec interfaces.MemoryRecord, opts ...interfaces.StoreMemoryOption) (string, error) { + _ = ctx + storeOpts := interfaces.StoreMemoryOptions{} + for _, opt := range opts { + opt(&storeOpts) + } + + id := strings.TrimSpace(storeOpts.ID) + if id == "" { + id = uuid.NewString() + } + + now := time.Now().UTC() + m.mu.Lock() + defer m.mu.Unlock() + + existing, ok := m.records[id] + createdAt := now + if ok { + createdAt = existing.entry.CreatedAt + } + + entry := interfaces.MemoryEntry{ + ID: id, + Text: rec.Text, + Kind: rec.Kind, + Scope: cloneScope(scope), + Metadata: cloneMetadata(rec.Metadata), + ExpiresAt: rec.ExpiresAt.UTC(), + CreatedAt: createdAt, + UpdatedAt: now, + } + m.records[id] = record{entry: entry} + return id, nil +} + +// Load retrieves memories within scope. +func (m *InmemMemory) Load(ctx context.Context, scope interfaces.MemoryScope, query string, opts ...interfaces.LoadMemoryOption) ([]interfaces.MemoryEntry, error) { + _ = ctx + loadOpts := interfaces.LoadMemoryOptions{} + for _, opt := range opts { + opt(&loadOpts) + } + limit := loadOpts.Limit + if limit <= 0 { + limit = 10 + } + + m.mu.RLock() + defer m.mu.RUnlock() + + var matches []interfaces.MemoryEntry + query = strings.ToLower(strings.TrimSpace(query)) + recencyOnly := query == "" + for _, rec := range m.records { + entry := rec.entry + if entry.Expired() { + continue + } + if !scopeMatches(entry.Scope, scope) { + continue + } + if !kindMatches(entry.Kind, loadOpts.Kinds) { + continue + } + matched := entry + if !recencyOnly { + if !strings.Contains(strings.ToLower(entry.Text), query) { + continue + } + matched.Score = 1.0 + } + if !recencyOnly && loadOpts.MinScore > 0 && matched.Score < loadOpts.MinScore { + continue + } + matches = append(matches, matched) + } + + sort.Slice(matches, func(i, j int) bool { + return matches[i].UpdatedAt.After(matches[j].UpdatedAt) + }) + if len(matches) > limit { + matches = matches[:limit] + } + return cloneEntries(matches), nil +} + +// Clear removes all memories matching scope. +func (m *InmemMemory) Clear(ctx context.Context, scope interfaces.MemoryScope) error { + _ = ctx + if scopeIsEmpty(scope) { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + for id, rec := range m.records { + if scopeMatches(rec.entry.Scope, scope) { + delete(m.records, id) + } + } + return nil +} + +func scopeMatches(stored, filter interfaces.MemoryScope) bool { + metaStored := memory.ScopeMetadata(stored) + metaFilter := memory.ScopeMetadata(filter) + for key, want := range metaFilter { + if got, ok := metaStored[key]; !ok || got != want { + return false + } + } + return true +} + +func kindMatches(kind interfaces.MemoryKind, kinds []interfaces.MemoryKind) bool { + if len(kinds) == 0 { + return true + } + for _, allowed := range kinds { + if kind == allowed { + return true + } + } + return false +} + +func scopeIsEmpty(scope interfaces.MemoryScope) bool { + return scope.UserID == "" && scope.TenantID == "" && scope.AgentID == "" && len(scope.Tags) == 0 +} + +func cloneScope(scope interfaces.MemoryScope) interfaces.MemoryScope { + out := interfaces.MemoryScope{ + UserID: scope.UserID, + TenantID: scope.TenantID, + AgentID: scope.AgentID, + } + if len(scope.Tags) > 0 { + out.Tags = make(map[string]string, len(scope.Tags)) + for k, v := range scope.Tags { + out.Tags[k] = v + } + } + return out +} + +func cloneMetadata(metadata map[string]string) map[string]string { + if len(metadata) == 0 { + return nil + } + raw, _ := json.Marshal(metadata) + var out map[string]string + _ = json.Unmarshal(raw, &out) + return out +} + +func cloneEntries(entries []interfaces.MemoryEntry) []interfaces.MemoryEntry { + out := make([]interfaces.MemoryEntry, len(entries)) + for i, entry := range entries { + out[i] = interfaces.MemoryEntry{ + ID: entry.ID, + Text: entry.Text, + Kind: entry.Kind, + Scope: cloneScope(entry.Scope), + Metadata: cloneMetadata(entry.Metadata), + ExpiresAt: entry.ExpiresAt, + Score: entry.Score, + CreatedAt: entry.CreatedAt, + UpdatedAt: entry.UpdatedAt, + } + } + return out +} diff --git a/internal/types/memory.go b/internal/types/memory.go new file mode 100644 index 0000000..68c874d --- /dev/null +++ b/internal/types/memory.go @@ -0,0 +1,14 @@ +package types + +// MemoryEntryFormat is the printf format used to render a single [interfaces.MemoryEntry] for LLM context. +// Arguments: 1-based index (int), text (string), kind (string), score (float32). +const MemoryEntryFormat = "[%d] %s\n(kind: %s, score: %.2f)\n\n" + +// SaveMemoryToolName is the LLM-facing tool name for on-demand long-term memory store. +const SaveMemoryToolName = "save_memory" + +// Memory tool JSON parameter names. +const ( + MemoryToolParamText = "text" + MemoryToolParamKind = "kind" +) diff --git a/internal/types/metrics.go b/internal/types/metrics.go index 7724d7d..7d0cb7b 100644 --- a/internal/types/metrics.go +++ b/internal/types/metrics.go @@ -46,9 +46,40 @@ const ( // Runtime — retriever search wall-clock latency. MetricRetrieverLatencyMs = "agent.retriever.latency_ms" + // Runtime — emitted per memory.Load (recall) call. + MetricMemoryRecallStarted = "agent.memory.recall.started" + MetricMemoryRecallCompleted = "agent.memory.recall.completed" + MetricMemoryRecallFailed = "agent.memory.recall.failed" + + // Runtime — memory recall wall-clock latency. + MetricMemoryRecallLatencyMs = "agent.memory.recall.latency_ms" + + // Runtime — emitted per memory.Store call. + MetricMemoryStoreStarted = "agent.memory.store.started" + MetricMemoryStoreCompleted = "agent.memory.store.completed" + MetricMemoryStoreFailed = "agent.memory.store.failed" + + // Runtime — memory store wall-clock latency. + MetricMemoryStoreLatencyMs = "agent.memory.store.latency_ms" + + // Runtime — semantic dedup lookup before memory.Store (Load for upsert decision). + MetricMemoryDedupStarted = "agent.memory.dedup.started" + MetricMemoryDedupCompleted = "agent.memory.dedup.completed" + MetricMemoryDedupFailed = "agent.memory.dedup.failed" + MetricMemoryDedupLatencyMs = "agent.memory.dedup.latency_ms" + + // Runtime — run-end memory extraction (StoreMode always). + MetricMemoryExtractStarted = "agent.memory.extract.started" + MetricMemoryExtractCompleted = "agent.memory.extract.completed" + MetricMemoryExtractFailed = "agent.memory.extract.failed" + MetricMemoryExtractLatencyMs = "agent.memory.extract.latency_ms" + // Attribute keys used on both metrics and spans. - MetricAttrModel = "model" - MetricAttrProvider = "provider" - MetricAttrTool = "tool" - MetricAttrRetriever = "retriever" + MetricAttrModel = "model" + MetricAttrProvider = "provider" + MetricAttrTool = "tool" + MetricAttrRetriever = "retriever" + MetricAttrMemoryKind = "memory.kind" + // MetricAttrMemoryDedup is "upsert" when an existing record is updated, else "append". + MetricAttrMemoryDedup = "memory.dedup" ) diff --git a/internal/types/telemetry.go b/internal/types/telemetry.go index 69c4cc1..a28c6d1 100644 --- a/internal/types/telemetry.go +++ b/internal/types/telemetry.go @@ -93,4 +93,10 @@ type StorageTelemetry struct { // Breakdown by mode — zero if mode not used. PrefetchSearches int64 `json:"prefetch_searches,omitempty"` AgenticSearches int64 `json:"agentic_searches,omitempty"` + + // Memory — long-term recall/store operations, zero if not configured. + TotalMemoryRecalls int64 `json:"total_memory_recalls,omitempty"` + FailedMemoryRecalls int64 `json:"failed_memory_recalls,omitempty"` + TotalMemoryStores int64 `json:"total_memory_stores,omitempty"` + FailedMemoryStores int64 `json:"failed_memory_stores,omitempty"` } diff --git a/internal/types/tool.go b/internal/types/tool.go index 42d5ee8..36bb291 100644 --- a/internal/types/tool.go +++ b/internal/types/tool.go @@ -24,6 +24,7 @@ const ( ToolKindA2A ToolKind = "a2a" ToolKindSubAgent ToolKind = "sub_agent" ToolKindRetriever ToolKind = "retriever" + ToolKindMemory ToolKind = "memory" ) // ToolKindProvider is implemented by SDK tool wrappers (MCP, A2A, sub-agent, retriever). diff --git a/internal/types/tool_test.go b/internal/types/tool_test.go index 43f3f9e..9056fbd 100644 --- a/internal/types/tool_test.go +++ b/internal/types/tool_test.go @@ -18,6 +18,9 @@ func TestKindOf(t *testing.T) { if KindOf(stubKindTool{kind: ToolKindMCP}) != ToolKindMCP { t.Fatal("mcp kind") } + if KindOf(stubKindTool{kind: ToolKindMemory}) != ToolKindMemory { + t.Fatal("memory kind") + } if KindOf(stubKindTool{kind: ""}) != ToolKindNative { t.Fatal("empty kind falls back to native") } @@ -32,4 +35,7 @@ func TestToolKind_CountsTowardToolTelemetry(t *testing.T) { t.Fatalf("%q should not count toward tool telemetry", k) } } + if !ToolKindMemory.CountsTowardToolTelemetry() { + t.Fatal("memory tool should count toward tool telemetry") + } } diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index a71aaa6..450c518 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -14,6 +14,7 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/runtime" "github.com/agenticenv/agent-sdk-go/internal/types" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/memory" ) // Agent runs LLM-backed agent execution through the configured execution runtime. @@ -142,6 +143,7 @@ func (a *Agent) Run(ctx context.Context, input string, opts *AgentRunOptions) (* } func (a *Agent) runInternal(ctx context.Context, input string, opts *AgentRunOptions, runAsync bool) (*AgentRunResult, error) { + ctx = a.attachMemoryScopeContext(ctx) conversationID := conversationIDFromOpts(opts) spanName := "agent.run" @@ -244,6 +246,7 @@ func copyApprovalArgs(src map[string]any) map[string]any { func (a *Agent) Stream(ctx context.Context, input string, opts *AgentRunOptions) (<-chan events.AgentEvent, error) { a.logger.Debug(ctx, "agent run stream started", slog.String("scope", "agent"), slog.String("name", a.Name), slog.Int("inputLen", len(input))) + ctx = a.attachMemoryScopeContext(ctx) conversationID := conversationIDFromOpts(opts) start := time.Now() @@ -292,6 +295,13 @@ func (a *Agent) Stream(ctx context.Context, input string, opts *AgentRunOptions) return streamCh, nil } +func (a *Agent) attachMemoryScopeContext(ctx context.Context) context.Context { + if a.Name != "" { + ctx = memory.WithContextAgentID(ctx, a.Name) + } + return ctx +} + func conversationIDFromOpts(opts *AgentRunOptions) string { if opts != nil && opts.ConversationOptions != nil { return opts.ConversationOptions.ID @@ -300,10 +310,10 @@ func conversationIDFromOpts(opts *AgentRunOptions) string { } func (a *Agent) validateConversationID(conversationID string) error { - if conversationID != "" && a.conversation == nil { + if conversationID != "" && a.conversationConfig == nil { return fmt.Errorf("conversationID %s requires conversation configuration", conversationID) } - if conversationID == "" && a.conversation != nil { + if conversationID == "" && a.conversationConfig != nil { return fmt.Errorf("conversationID is required when using conversation") } return nil diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 69bd15c..7296349 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -14,6 +14,7 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/runtime" rtmocks "github.com/agenticenv/agent-sdk-go/internal/runtime/mocks" "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/conversation" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/logger" "github.com/agenticenv/agent-sdk-go/pkg/observability" @@ -235,7 +236,7 @@ func TestAgent_Run_ForwardsRunOptions(t *testing.T) { }) a := testAgentWithRuntime(mockRT) - a.conversation = &mockConversation{} + a.conversationConfig = &conversation.Config{Conversation: &mockConversation{}} _, err := a.Run(context.Background(), "hello", opts) if err != nil { t.Fatal(err) @@ -244,7 +245,7 @@ func TestAgent_Run_ForwardsRunOptions(t *testing.T) { func TestAgent_Stream_RejectsMissingConversationID(t *testing.T) { a := testAgentWithRuntime(&stubRuntime{}) - a.conversation = &mockConversation{} + a.conversationConfig = &conversation.Config{Conversation: &mockConversation{}} _, err := a.Stream(context.Background(), "prompt", nil) if err == nil { t.Fatal("expected error when conversation configured but opts nil") @@ -262,7 +263,7 @@ func TestAgent_ValidateConversationID(t *testing.T) { t.Error("non-empty conversationID with no conversation should error") } - a.conversation = &mockConversation{} + a.conversationConfig = &conversation.Config{Conversation: &mockConversation{}} if err := a.validateConversationID(""); err == nil { t.Error("empty conversationID with conversation should error") } diff --git a/pkg/agent/config.go b/pkg/agent/config.go index 3df6850..6654694 100644 --- a/pkg/agent/config.go +++ b/pkg/agent/config.go @@ -16,8 +16,10 @@ import ( "github.com/agenticenv/agent-sdk-go/internal/runtime" "github.com/agenticenv/agent-sdk-go/internal/runtime/temporal" "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/conversation" "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/memory" "github.com/agenticenv/agent-sdk-go/pkg/observability" "github.com/google/uuid" "go.temporal.io/sdk/client" @@ -184,7 +186,7 @@ type ObservabilityConfig struct { // - Both: WithName, WithDescription, WithSystemPrompt, WithTemporalConfig, WithTemporalClient, // WithInstanceId, WithLLMClient, WithToolApprovalPolicy, WithTools, WithToolRegistry, // WithMCPRegistry, WithA2ARegistry, WithSubAgentRegistry, -// WithMaxIterations, WithStream, WithLogger, WithLogLevel, WithConversation, WithConversationSize, EnableConversationSaveOnIteration, +// WithMaxIterations, WithStream, WithLogger, WithLogLevel, WithConversation, WithMemory, // WithResponseFormat, WithLLMSampling, WithSubAgents, WithMaxSubAgentDepth, // WithMCPConfig, WithMCPClients, WithA2AConfig, WithA2AClients, WithRetrievers, WithRetrieverMode, WithAgentMode, WithDisableFingerprintCheck, WithAgentToolExecutionMode, // WithObservabilityConfig, WithTracer, WithMetrics, WithLogs @@ -215,9 +217,9 @@ type agentConfig struct { timeout time.Duration approvalTimeout time.Duration // max wait per tool approval; must be < timeout when tools require approval - conversation interfaces.Conversation - conversationSize int // max messages to fetch for LLM context (default 20) - conversationSaveOnIteration bool // save the conversation on each iteration, defaults to false + conversationConfig *conversation.Config + + memoryConfig *memory.Config // responseFormat: when set, LLM requests use it; otherwise use text-only (no JSON schema). responseFormat *interfaces.ResponseFormat @@ -442,34 +444,29 @@ func WithDisableFingerprintCheck(disable bool) Option { return func(c *agentConfig) { c.disableFingerprintCheck = disable } } -// WithConversation sets the conversation for message history. Applies to Agent and AgentWorker. -// The user creates the conversation (inmem or redis) and passes it to the agent. -// System messages are not stored; agent SystemPrompt is used for LLM calls. +// WithConversation sets conversation history for the agent. Applies to Agent and AgentWorker. +// Pass [conversation.DefaultConfig] for SDK defaults or a custom [conversation.Config]. // // Choose implementation based on deployment: // - Single process: use inmem.NewInMemoryConversation // - Remote workers: use redis.NewRedisConversation (in-memory cannot be used across processes) // -// The user owns the conversation lifecycle. Call Clear on the conversation when appropriate -// (e.g., when ending a session). The agent never calls Clear. -// Note: Agent and worker must use the same conversation and ID when using remote workers. -func WithConversation(conv interfaces.Conversation) Option { - return func(c *agentConfig) { c.conversation = conv } -} - -// WithConversationSize sets the max messages to fetch for LLM context (default 20). -func WithConversationSize(size int) Option { - return func(c *agentConfig) { c.conversationSize = size } +// The application owns Clear on the conversation when required; the agent runtime does not call Clear. +// Agent and worker must use the same conversation config and ID when using remote workers. +func WithConversation(cfg conversation.Config) Option { + return func(c *agentConfig) { + c.conversationConfig = &cfg + } } -// EnableConversationSaveOnIteration persists conversation messages after each tool round instead of -// batching the full run at the end. Use when external consumers need live updates from conversation -// storage (e.g. Redis) between iterations. This degrades performance. -// -// For Temporal, set this on [AgentWorker] (worker process) where [WithConversation] is configured; -// the agent caller process does not need it. -func EnableConversationSaveOnIteration() Option { - return func(c *agentConfig) { c.conversationSaveOnIteration = true } +// WithMemory enables long-term memory for the agent. Applies to Agent and AgentWorker. +// Pass [memory.DefaultConfig] for SDK defaults or a custom [memory.Config]. +// Use a shipped backend (pkg/memory/weaviate or pkg/memory/pgvector) or your own [interfaces.Memory]. +// The application owns Clear on the store when required; the agent runtime does not call Clear. +func WithMemory(cfg memory.Config) Option { + return func(c *agentConfig) { + c.memoryConfig = &cfg + } } // WithResponseFormat sets the LLM response format (e.g. JSON with schema). Applies to Agent and AgentWorker. @@ -703,11 +700,27 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { if c.LLMClient == nil { return nil, errors.New("LLM client is required") } - if c.conversation != nil && (c.enableRemoteWorkers || c.disableLocalWorker) && !c.conversation.IsDistributed() { - return nil, errors.New("in-memory conversation cannot be used with remote workers (DisableLocalWorker or EnableRemoteWorkers()): use distributed storage such as redis.NewRedisConversation") + if c.conversationConfig != nil { + cfg := c.conversationConfig.WithDefaults() + if err := cfg.Validate(); err != nil { + return nil, err + } + remoteWorkers := c.enableRemoteWorkers || c.disableLocalWorker + if err := conversation.ValidateDistributed(cfg.Conversation, remoteWorkers); err != nil { + return nil, err + } + c.conversationConfig = &cfg } - if c.conversationSize <= 0 { - c.conversationSize = 20 + if c.memoryConfig != nil { + cfg := c.memoryConfig.WithDefaults() + if err := cfg.Validate(); err != nil { + return nil, err + } + if cfg.Store.Mode == memory.StoreModeAlways && cfg.Store.Extract == nil && c.LLMClient == nil { + c.logger.Warn(context.Background(), "memory StoreMode always requires custom Extract or LLM client — run-end store will be skipped", + slog.String("scope", "agent")) + } + c.memoryConfig = &cfg } if c.agentMode == "" { c.agentMode = AgentModeInteractive @@ -856,7 +869,7 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { slog.Int("subAgentRegistryCount", len(c.subAgentRegistry.List())), slog.Int("retrieverCount", len(c.retrievers)), slog.String("retrieverMode", string(c.retrieverMode)), - slog.Bool("hasConversation", c.conversation != nil), + slog.Bool("hasConversation", c.conversationConfig != nil), slog.Bool("hasObservability", c.observabilityConfig != nil), slog.Bool("enabledTracer", c.tracer != nil), slog.Bool("enabledMetrics", c.metrics != nil), @@ -1080,6 +1093,12 @@ func (c *agentConfig) resolveTools(ctx context.Context) ([]interfaces.Tool, erro } tools = append(tools, retrieverTools...) + memoryTools, err := c.resolveMemoryTools() + if err != nil { + return nil, err + } + tools = append(tools, memoryTools...) + if err := validateToolNames(tools); err != nil { return nil, err } @@ -1091,6 +1110,18 @@ func (c *agentConfig) resolveTools(ctx context.Context) ([]interfaces.Tool, erro return tools, nil } +// resolveMemoryTools registers save_memory when long-term memory uses on-demand store. +func (c *agentConfig) resolveMemoryTools() ([]interfaces.Tool, error) { + if c.memoryConfig == nil { + return nil, nil + } + cfg := c.memoryConfig.WithDefaults() + if cfg.Memory == nil || cfg.Store.Mode != memory.StoreModeOnDemand { + return nil, nil + } + return []interfaces.Tool{NewRegisteredMemoryTool()}, nil +} + // resolveSubAgentTools returns sub-agent delegation tools from [subAgentRegistry]. func (c *agentConfig) resolveSubAgentTools() ([]interfaces.Tool, error) { if c.subAgentRegistry == nil { @@ -1169,6 +1200,27 @@ func (c *agentConfig) runtimeAgentSpec() runtime.AgentSpec { } } +// runtimeAgentMemory maps memory config onto the runtime memory view. +func (c *agentConfig) runtimeAgentMemory() runtime.AgentMemory { + if c.memoryConfig == nil { + return runtime.AgentMemory{} + } + cfg := c.memoryConfig.WithDefaults() + return runtime.AgentMemory{Config: &cfg} +} + +// runtimeAgentSession maps conversation config onto the runtime session view. +func (c *agentConfig) runtimeAgentSession() runtime.AgentSession { + if c.conversationConfig == nil { + return runtime.AgentSession{} + } + return runtime.AgentSession{ + Conversation: c.conversationConfig.Conversation, + ConversationSize: c.conversationConfig.Size, + ConversationSaveOnIteration: c.conversationConfig.SaveOnIteration, + } +} + // runtimeAgentConfig is static wiring copied onto the runtime at construction. func (c *agentConfig) runtimeAgentConfig() runtime.AgentConfig { d := runtime.AgentConfig{ @@ -1180,11 +1232,8 @@ func (c *agentConfig) runtimeAgentConfig() runtime.AgentConfig { Retrievers: c.retrievers, Mode: c.retrieverMode, }, - Session: runtime.AgentSession{ - Conversation: c.conversation, - ConversationSize: c.conversationSize, - ConversationSaveOnIteration: c.conversationSaveOnIteration, - }, + Session: c.runtimeAgentSession(), + Memory: c.runtimeAgentMemory(), Limits: runtime.AgentLimits{ MaxIterations: c.maxIterations, Timeout: c.timeout, diff --git a/pkg/agent/config_test.go b/pkg/agent/config_test.go index 1fdb58d..6848292 100644 --- a/pkg/agent/config_test.go +++ b/pkg/agent/config_test.go @@ -17,6 +17,7 @@ import ( "github.com/agenticenv/agent-sdk-go/pkg/interfaces" "github.com/agenticenv/agent-sdk-go/pkg/logger" mcpclient "github.com/agenticenv/agent-sdk-go/pkg/mcp/client" + "github.com/agenticenv/agent-sdk-go/pkg/memory" "github.com/agenticenv/agent-sdk-go/pkg/observability" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -31,12 +32,16 @@ func agentConfigFingerprint(c *agentConfig) string { } func agentConfigFingerprintTools(c *agentConfig, tools []interfaces.Tool) string { + convSize := 0 + if c.conversationConfig != nil { + convSize = c.conversationConfig.Size + } return temporal.ComputeAgentFingerprint(temporal.BuildAgentFingerprintPayload( c.runtimeAgentSpec(), temporal.ToolNamesFromTools(tools), toolPolicyFingerprint(c.toolApprovalPolicy), llmSamplingRuntimeView(c.llmSampling), - c.conversationSize, + convSize, runtime.AgentLimits{ MaxIterations: c.maxIterations, Timeout: c.timeout, @@ -683,6 +688,140 @@ func TestBuildRetrieverTools(t *testing.T) { }) } +func TestResolveMemoryTools(t *testing.T) { + stub := stubMemoryBackend{} + t.Run("ondemand", func(t *testing.T) { + cfg := memory.DefaultConfig(stub) + c := &agentConfig{memoryConfig: &cfg} + tools, err := c.resolveMemoryTools() + if err != nil { + t.Fatal(err) + } + if len(tools) != 1 || tools[0].Name() != types.SaveMemoryToolName { + t.Fatalf("tools = %v", tools) + } + }) + t.Run("always", func(t *testing.T) { + cfg := memory.DefaultConfig(stub) + cfg.Store.Mode = memory.StoreModeAlways + c := &agentConfig{memoryConfig: &cfg} + tools, err := c.resolveMemoryTools() + if err != nil { + t.Fatal(err) + } + if len(tools) != 0 { + t.Fatalf("tools = %v", tools) + } + }) + t.Run("no_memory", func(t *testing.T) { + c := &agentConfig{} + tools, err := c.resolveMemoryTools() + if err != nil { + t.Fatal(err) + } + if len(tools) != 0 { + t.Fatalf("tools = %v", tools) + } + }) +} + +func TestBuildAgentConfig_WithMemory_registersSaveMemory(t *testing.T) { + cfg, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithMemory(memory.DefaultConfig(stubMemoryBackend{})), + }) + if err != nil { + t.Fatal(err) + } + tools, err := cfg.resolveTools(context.Background()) + if err != nil { + t.Fatal(err) + } + found := false + for _, tool := range tools { + if tool.Name() == types.SaveMemoryToolName { + found = true + break + } + } + if !found { + t.Fatal("save_memory not in resolved tools") + } +} + +func TestBuildAgentConfig_WithMemoryAlways_leavesExtractNil(t *testing.T) { + cfg := memory.DefaultConfig(stubMemoryBackend{}) + cfg.Store.Mode = memory.StoreModeAlways + got, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithMemory(cfg), + }) + if err != nil { + t.Fatal(err) + } + mem := got.runtimeAgentMemory() + if mem.Config == nil || mem.Config.Store.Extract != nil { + t.Fatal("expected nil Extract on config; default resolves lazily at run-end") + } +} + +func TestBuildAgentConfig_WithMemoryOnDemand_noExtract(t *testing.T) { + cfg := memory.DefaultConfig(stubMemoryBackend{}) + got, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithMemory(cfg), + }) + if err != nil { + t.Fatal(err) + } + mem := got.runtimeAgentMemory() + if mem.Config == nil || mem.Config.Store.Extract != nil { + t.Fatal("expected nil Extract for ondemand") + } +} + +func TestBuildAgentConfig_WithMemoryAlways_preservesCustomExtract(t *testing.T) { + custom := memory.ExtractFunc(func(context.Context, []interfaces.Message) ([]interfaces.MemoryRecord, error) { + return nil, nil + }) + cfg := memory.DefaultConfig(stubMemoryBackend{}) + cfg.Store.Mode = memory.StoreModeAlways + cfg.Store.Extract = custom + got, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithMemory(cfg), + }) + if err != nil { + t.Fatal(err) + } + mem := got.runtimeAgentMemory() + if mem.Config == nil || mem.Config.Store.Extract == nil { + t.Fatal("expected custom extract") + } + records, err := mem.Config.Store.Extract(context.Background(), nil) + if err != nil || records != nil { + t.Fatalf("custom extract: records=%v err=%v", records, err) + } +} + +type stubMemoryBackend struct{} + +func (stubMemoryBackend) Store(context.Context, interfaces.MemoryScope, interfaces.MemoryRecord, ...interfaces.StoreMemoryOption) (string, error) { + return "", nil +} +func (stubMemoryBackend) Load(context.Context, interfaces.MemoryScope, string, ...interfaces.LoadMemoryOption) ([]interfaces.MemoryEntry, error) { + return nil, nil +} +func (stubMemoryBackend) Clear(context.Context, interfaces.MemoryScope) error { return nil } + func TestBuildAgentConfig_WithRetrievers(t *testing.T) { r1, r2 := namedStubRetriever("kb-a"), namedStubRetriever("kb-b") cfg, err := buildAgentConfig([]Option{ diff --git a/pkg/agent/memory.go b/pkg/agent/memory.go new file mode 100644 index 0000000..0dfc143 --- /dev/null +++ b/pkg/agent/memory.go @@ -0,0 +1,91 @@ +package agent + +import ( + "context" + "fmt" + "strings" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/tools" +) + +var _ interfaces.Tool = (*MemoryTool)(nil) +var _ types.ToolKindProvider = (*MemoryTool)(nil) + +// ErrMemoryToolNotExecutable is returned when save_memory runs outside a managed agent runtime. +var ErrMemoryToolNotExecutable = fmt.Errorf("save_memory must be executed via runtime") + +// MemoryStoreFunc persists extracted memory records for the current run scope. +type MemoryStoreFunc func(ctx context.Context, records []interfaces.MemoryRecord) error + +// MemoryTool implements [interfaces.Tool] for on-demand long-term memory store ([memory.StoreModeOnDemand]). +type MemoryTool struct { + Store MemoryStoreFunc +} + +// NewMemoryTool returns a save_memory tool. Returns nil when store is nil. +func NewMemoryTool(store MemoryStoreFunc) interfaces.Tool { + if store == nil { + return nil + } + return &MemoryTool{Store: store} +} + +// NewRegisteredMemoryTool returns save_memory for agent tool registration ([memory.StoreModeOnDemand]). +func NewRegisteredMemoryTool() interfaces.Tool { + return &MemoryTool{} +} + +// ToolKind implements [types.ToolKindProvider]. +func (t *MemoryTool) ToolKind() types.ToolKind { return types.ToolKindMemory } + +// Name implements [interfaces.Tool]. +func (t *MemoryTool) Name() string { return types.SaveMemoryToolName } + +// DisplayName implements [interfaces.Tool]. +func (t *MemoryTool) DisplayName() string { return "Save Memory" } + +// Description implements [interfaces.Tool]. +func (t *MemoryTool) Description() string { + return "Save a fact, preference, or decision to long-term memory for future runs. " + + "Required when the user asks to remember, save, or persist something for later — " + + "call this tool before acknowledging; a text reply alone does not store memory." +} + +// Parameters implements [interfaces.Tool]. +func (t *MemoryTool) Parameters() interfaces.JSONSchema { + return tools.Params(map[string]interfaces.JSONSchema{ + types.MemoryToolParamText: tools.ParamString( + "The memory text to store (fact, preference, or decision distilled from the conversation)", + ), + types.MemoryToolParamKind: tools.ParamString( + "Optional memory kind (e.g. preference, fact, decision, instruction, note)", + ), + }, types.MemoryToolParamText) +} + +// Execute implements [interfaces.Tool]. +func (t *MemoryTool) Execute(ctx context.Context, args map[string]any) (any, error) { + if t.Store == nil { + return nil, ErrMemoryToolNotExecutable + } + rawText, ok := args[types.MemoryToolParamText].(string) + if !ok { + return nil, fmt.Errorf("save_memory: %q parameter required", types.MemoryToolParamText) + } + text := strings.TrimSpace(rawText) + if text == "" { + return nil, fmt.Errorf("save_memory: %q must be non-empty", types.MemoryToolParamText) + } + + record := interfaces.MemoryRecord{Text: text} + if rawKind, ok := args[types.MemoryToolParamKind].(string); ok { + record.Kind = interfaces.MemoryKind(strings.TrimSpace(rawKind)) + } + + if err := t.Store(ctx, []interfaces.MemoryRecord{record}); err != nil { + return nil, err + } + return "memory saved", nil +} diff --git a/pkg/agent/memory_test.go b/pkg/agent/memory_test.go new file mode 100644 index 0000000..062d71f --- /dev/null +++ b/pkg/agent/memory_test.go @@ -0,0 +1,126 @@ +package agent + +import ( + "context" + "errors" + "testing" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +func TestNewMemoryTool_nil(t *testing.T) { + if NewMemoryTool(nil) != nil { + t.Fatal("expected nil") + } +} + +func TestMemoryTool_metadata(t *testing.T) { + tool := NewMemoryTool(func(context.Context, []interfaces.MemoryRecord) error { return nil }) + if tool.Name() != types.SaveMemoryToolName { + t.Fatalf("Name = %q", tool.Name()) + } + if tool.(*MemoryTool).ToolKind() != types.ToolKindMemory { + t.Fatalf("ToolKind = %q", tool.(*MemoryTool).ToolKind()) + } +} + +func TestMemoryTool_Execute_storesRecord(t *testing.T) { + var stored []interfaces.MemoryRecord + tool := NewMemoryTool(func(_ context.Context, records []interfaces.MemoryRecord) error { + stored = records + return nil + }) + + out, err := tool.Execute(context.Background(), map[string]any{ + types.MemoryToolParamText: "favorite color is blue", + types.MemoryToolParamKind: "preference", + }) + if err != nil { + t.Fatal(err) + } + if out != "memory saved" { + t.Fatalf("out = %v", out) + } + if len(stored) != 1 || stored[0].Text != "favorite color is blue" || stored[0].Kind != "preference" { + t.Fatalf("stored = %+v", stored) + } +} + +func TestMemoryTool_Execute_emptyText(t *testing.T) { + tool := NewMemoryTool(func(context.Context, []interfaces.MemoryRecord) error { return nil }) + _, err := tool.Execute(context.Background(), map[string]any{ + types.MemoryToolParamText: " ", + }) + if err == nil { + t.Fatal("expected error for empty text") + } +} + +func TestMemoryTool_Execute_missingText(t *testing.T) { + tool := NewMemoryTool(func(context.Context, []interfaces.MemoryRecord) error { return nil }) + _, err := tool.Execute(context.Background(), map[string]any{}) + if err == nil { + t.Fatal("expected error for missing text") + } +} + +func TestMemoryTool_Execute_storeError(t *testing.T) { + tool := NewMemoryTool(func(context.Context, []interfaces.MemoryRecord) error { + return errors.New("backend down") + }) + _, err := tool.Execute(context.Background(), map[string]any{ + types.MemoryToolParamText: "remember this", + }) + if err == nil || err.Error() != "backend down" { + t.Fatalf("err = %v", err) + } +} + +func TestMemoryTool_Execute_withoutKind(t *testing.T) { + var stored []interfaces.MemoryRecord + tool := NewMemoryTool(func(_ context.Context, records []interfaces.MemoryRecord) error { + stored = records + return nil + }) + + _, err := tool.Execute(context.Background(), map[string]any{ + types.MemoryToolParamText: "remember this", + }) + if err != nil { + t.Fatal(err) + } + if len(stored) != 1 || stored[0].Kind != "" { + t.Fatalf("stored = %+v", stored) + } +} + +func TestMemoryTool_Parameters(t *testing.T) { + tool := NewMemoryTool(func(context.Context, []interfaces.MemoryRecord) error { return nil }) + schema := tool.(*MemoryTool).Parameters() + required, ok := schema["required"].([]string) + if !ok || len(required) != 1 || required[0] != types.MemoryToolParamText { + t.Fatalf("required = %v", schema["required"]) + } + props, ok := schema["properties"].(map[string]interfaces.JSONSchema) + if !ok || props[types.MemoryToolParamText] == nil || props[types.MemoryToolParamKind] == nil { + t.Fatalf("properties = %v", schema["properties"]) + } +} + +func TestMemoryTool_Execute_nilStore(t *testing.T) { + tool := &MemoryTool{} + _, err := tool.Execute(context.Background(), map[string]any{ + types.MemoryToolParamText: "x", + }) + if !errors.Is(err, ErrMemoryToolNotExecutable) { + t.Fatalf("err = %v, want ErrMemoryToolNotExecutable", err) + } +} + +func TestNewRegisteredMemoryTool(t *testing.T) { + tool := NewRegisteredMemoryTool() + if tool == nil || tool.Name() != types.SaveMemoryToolName { + t.Fatalf("tool = %+v", tool) + } +} diff --git a/pkg/conversation/config.go b/pkg/conversation/config.go new file mode 100644 index 0000000..2e2b4ac --- /dev/null +++ b/pkg/conversation/config.go @@ -0,0 +1,51 @@ +package conversation + +import ( + "errors" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +// Config wires conversation history for the agent SDK. +type Config struct { + // Conversation is the conversation backend implementation. Required. + Conversation interfaces.Conversation + + // Size is the maximum number of messages to fetch for LLM context. + // Zero or negative defaults to [DefaultSize]. + Size int + + // SaveOnIteration persists messages after each tool round instead of batching at run end. + SaveOnIteration bool +} + +// WithDefaults fills zero fields with SDK defaults. Conversation must be set separately. +func (c Config) WithDefaults() Config { + if c.Size <= 0 { + c.Size = DefaultSize + } + return c +} + +// Validate checks the config. Call [WithDefaults] first. +func (c Config) Validate() error { + if c.Conversation == nil { + return errors.New("conversation config: Conversation is required") + } + return nil +} + +// ValidateDistributed returns an error when conv is not distributed but remote workers require it. +func ValidateDistributed(conv interfaces.Conversation, remoteWorkers bool) error { + if remoteWorkers && conv != nil && !conv.IsDistributed() { + return errors.New("in-memory conversation cannot be used with remote workers: use distributed storage such as redis.NewRedisConversation") + } + return nil +} + +// ListOptions builds [interfaces.ListMessagesOption] values from this config. +func (c Config) ListOptions() []interfaces.ListMessagesOption { + return []interfaces.ListMessagesOption{ + interfaces.WithLimit(c.Size), + } +} diff --git a/pkg/conversation/config_test.go b/pkg/conversation/config_test.go new file mode 100644 index 0000000..6440a20 --- /dev/null +++ b/pkg/conversation/config_test.go @@ -0,0 +1,71 @@ +package conversation_test + +import ( + "context" + "testing" + + "github.com/agenticenv/agent-sdk-go/pkg/conversation" + "github.com/agenticenv/agent-sdk-go/pkg/conversation/inmem" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +func TestDefaultConfig(t *testing.T) { + conv := inmem.NewInMemoryConversation() + cfg := conversation.DefaultConfig(conv) + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + if cfg.Size != conversation.DefaultSize { + t.Fatalf("size = %d", cfg.Size) + } +} + +func TestConfig_WithDefaults(t *testing.T) { + cfg := (conversation.Config{Conversation: inmem.NewInMemoryConversation()}).WithDefaults() + if cfg.Size != conversation.DefaultSize { + t.Fatalf("size = %d", cfg.Size) + } +} + +func TestConfig_Validate_missingConversation(t *testing.T) { + if err := (conversation.Config{}).Validate(); err == nil { + t.Fatal("expected error") + } +} + +func TestConfig_ListOptions(t *testing.T) { + cfg := conversation.DefaultConfig(inmem.NewInMemoryConversation()) + opts := cfg.ListOptions() + if len(opts) != 1 { + t.Fatalf("opts len = %d", len(opts)) + } +} + +func TestValidateDistributed_inmemRemoteWorkers(t *testing.T) { + conv := inmem.NewInMemoryConversation() + if err := conversation.ValidateDistributed(conv, true); err == nil { + t.Fatal("expected distributed error") + } + if err := conversation.ValidateDistributed(conv, false); err != nil { + t.Fatal(err) + } +} + +type distributedConv struct { + interfaces.Conversation +} + +func (distributedConv) IsDistributed() bool { return true } +func (distributedConv) AddMessage(context.Context, string, interfaces.Message) error { + return nil +} +func (distributedConv) ListMessages(context.Context, string, ...interfaces.ListMessagesOption) ([]interfaces.Message, error) { + return nil, nil +} +func (distributedConv) Clear(context.Context, string) error { return nil } + +func TestValidateDistributed_distributedOK(t *testing.T) { + if err := conversation.ValidateDistributed(distributedConv{}, true); err != nil { + t.Fatal(err) + } +} diff --git a/pkg/conversation/defaults.go b/pkg/conversation/defaults.go new file mode 100644 index 0000000..54b04bb --- /dev/null +++ b/pkg/conversation/defaults.go @@ -0,0 +1,16 @@ +package conversation + +import ( + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +// DefaultSize is the default max messages fetched for LLM context. +const DefaultSize = 20 + +// DefaultConfig returns a [Config] with SDK defaults for size and save behavior. +func DefaultConfig(conv interfaces.Conversation) Config { + return Config{ + Conversation: conv, + Size: DefaultSize, + } +} diff --git a/pkg/interfaces/memory.go b/pkg/interfaces/memory.go new file mode 100644 index 0000000..901a211 --- /dev/null +++ b/pkg/interfaces/memory.go @@ -0,0 +1,161 @@ +package interfaces + +import ( + "context" + "time" +) + +//go:generate mockgen -destination=./mocks/mock_memory.go -package=mocks github.com/agenticenv/agent-sdk-go/pkg/interfaces Memory + +// MemoryKind is an arbitrary label for categorizing memories (e.g. "preference", "fact", "bug_report"). +// Different agents define their own taxonomies; the interface does not prescribe fixed kinds. +type MemoryKind string + +// MemoryScope identifies the namespace a memory belongs to. +// Non-empty fields are AND-ed when storing, loading, or clearing. +// +// UserID, TenantID, and AgentID are common isolation dimensions; Tags hold custom keys +// (e.g. project_id, env). The agent SDK may populate scope via pkg/memory ScopeConfig; +// implementations only see the resolved [MemoryScope] values. +type MemoryScope struct { + UserID string + TenantID string + AgentID string + + // Tags holds additional scope dimensions (e.g. project, team, environment). + // Do not duplicate UserID, TenantID, or AgentID as tag keys. + Tags map[string]string +} + +// MemoryRecord is the content written to long-term storage. +type MemoryRecord struct { + // Text is the distilled fact, preference, or instruction to remember. + Text string + + // Kind is an optional category label for filtered recall. Empty is allowed. + Kind MemoryKind + + // Metadata holds optional attributes (source run, confidence, custom tags). + // Scope fields from [MemoryScope] are stored separately and should not be duplicated here. + Metadata map[string]string + + // ExpiresAt is when the entry should no longer be recalled. Zero means no expiry. + // The agent SDK sets this from [memory.Config] TTL policy on Store; direct callers may leave it zero. + ExpiresAt time.Time +} + +// MemoryEntry is a memory returned from [Memory.Load]. +type MemoryEntry struct { + // ID is the stable record identifier assigned by the backend. + // Always non-empty; required for upserts via [WithMemoryID]. + ID string + + Text string + Kind MemoryKind + Scope MemoryScope + Metadata map[string]string + ExpiresAt time.Time + + // Score is query relevance when the backend supports ranked retrieval (e.g. vector search). + // Zero means not applicable (e.g. key-value or recency-only backends). + Score float32 + + CreatedAt time.Time + UpdatedAt time.Time +} + +// Expired reports whether the entry has passed its expiry time. +// Entries with a zero ExpiresAt never expire. +func (e MemoryEntry) Expired() bool { + return !e.ExpiresAt.IsZero() && time.Now().After(e.ExpiresAt) +} + +// Memory stores and retrieves long-term agent context across runs. +// +// Store/Load/Clear with scope and query work with vector, relational, or key-value backends. +// Scope drives isolation; query semantics are backend-specific (semantic search, filter, etc.). +// +// Agent SDK usage (runtime): +// 1. Store — persist extracted context after a run. +// 2. Load — recall memories for the current query before or during a run. +// +// Application usage (optional, not invoked by agent runtime): +// 3. Clear — remove all memories in a scope (e.g. tenant offboarding, forget-me). +type Memory interface { + // Store persists a new memory in the given scope and returns its assigned ID. + // The returned ID is always non-empty; implementations must assign one if the + // backend does not (e.g. a UUID). + // + // Expiry (ExpiresAt on [MemoryRecord]) is set by the agent SDK from TTL policy. + // Use [WithMemoryID] to upsert an existing record. + Store(ctx context.Context, scope MemoryScope, record MemoryRecord, opts ...StoreMemoryOption) (id string, err error) + + // Load retrieves memories within scope. + // Query drives ranking or filtering; pass empty query to list by recency when supported. + // Implementations must omit expired entries (ExpiresAt non-zero and in the past). + // + // Non-empty scope fields filter results (AND). Use [WithLoadLimit], [WithMinScore], + // and [WithLoadKinds] to narrow recall. + Load(ctx context.Context, scope MemoryScope, query string, opts ...LoadMemoryOption) ([]MemoryEntry, error) + + // Clear removes all memories matching the scope. Called by the application when + // required (e.g. user offboarding). Not invoked by the agent runtime. + // Warning: a TenantID-only scope deletes all memories for that tenant across every user and agent. + Clear(ctx context.Context, scope MemoryScope) error +} + +// --- Store options --- + +// StoreMemoryOptions configures a [Memory.Store] call. +type StoreMemoryOptions struct { + // ID upserts the record when non-empty. Use the ID returned by a prior Store call. + ID string +} + +type StoreMemoryOption func(*StoreMemoryOptions) + +// WithMemoryID upserts the record with the given ID when the backend supports +// stable identifiers. Use the ID returned by a prior [Memory.Store] call. +func WithMemoryID(id string) StoreMemoryOption { + return func(o *StoreMemoryOptions) { + o.ID = id + } +} + +// --- Load options --- + +// LoadMemoryOptions configures a [Memory.Load] call. +type LoadMemoryOptions struct { + // Limit is the maximum number of memories to return. Zero or negative means backend default. + Limit int + + // MinScore filters out entries below the given relevance score when Score is applicable. + MinScore float32 + + // Kinds restricts recall to the given memory kinds. Empty means all kinds. + Kinds []MemoryKind +} + +type LoadMemoryOption func(*LoadMemoryOptions) + +// WithLoadLimit sets the maximum number of memories to return. +// Zero or negative means backend default. +func WithLoadLimit(limit int) LoadMemoryOption { + return func(o *LoadMemoryOptions) { + o.Limit = limit + } +} + +// WithMinScore filters out entries below the given relevance score. +func WithMinScore(minScore float32) LoadMemoryOption { + return func(o *LoadMemoryOptions) { + o.MinScore = minScore + } +} + +// WithLoadKinds restricts recall to the given memory kinds. +func WithLoadKinds(kinds ...MemoryKind) LoadMemoryOption { + return func(o *LoadMemoryOptions) { + o.Kinds = kinds + } +} diff --git a/pkg/interfaces/mocks/mock_memory.go b/pkg/interfaces/mocks/mock_memory.go new file mode 100644 index 0000000..5c6f8ab --- /dev/null +++ b/pkg/interfaces/mocks/mock_memory.go @@ -0,0 +1,90 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/agenticenv/agent-sdk-go/pkg/interfaces (interfaces: Memory) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + interfaces "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + gomock "github.com/golang/mock/gomock" +) + +// MockMemory is a mock of Memory interface. +type MockMemory struct { + ctrl *gomock.Controller + recorder *MockMemoryMockRecorder +} + +// MockMemoryMockRecorder is the mock recorder for MockMemory. +type MockMemoryMockRecorder struct { + mock *MockMemory +} + +// NewMockMemory creates a new mock instance. +func NewMockMemory(ctrl *gomock.Controller) *MockMemory { + mock := &MockMemory{ctrl: ctrl} + mock.recorder = &MockMemoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMemory) EXPECT() *MockMemoryMockRecorder { + return m.recorder +} + +// Clear mocks base method. +func (m *MockMemory) Clear(arg0 context.Context, arg1 interfaces.MemoryScope) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Clear", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Clear indicates an expected call of Clear. +func (mr *MockMemoryMockRecorder) Clear(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockMemory)(nil).Clear), arg0, arg1) +} + +// Load mocks base method. +func (m *MockMemory) Load(arg0 context.Context, arg1 interfaces.MemoryScope, arg2 string, arg3 ...interfaces.LoadMemoryOption) ([]interfaces.MemoryEntry, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Load", varargs...) + ret0, _ := ret[0].([]interfaces.MemoryEntry) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Load indicates an expected call of Load. +func (mr *MockMemoryMockRecorder) Load(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockMemory)(nil).Load), varargs...) +} + +// Store mocks base method. +func (m *MockMemory) Store(arg0 context.Context, arg1 interfaces.MemoryScope, arg2 interfaces.MemoryRecord, arg3 ...interfaces.StoreMemoryOption) (string, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Store", varargs...) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Store indicates an expected call of Store. +func (mr *MockMemoryMockRecorder) Store(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Store", reflect.TypeOf((*MockMemory)(nil).Store), varargs...) +} diff --git a/pkg/memory/config.go b/pkg/memory/config.go new file mode 100644 index 0000000..c1b9578 --- /dev/null +++ b/pkg/memory/config.go @@ -0,0 +1,323 @@ +package memory + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +// StoreMode selects when the SDK persists long-term memories. +type StoreMode string + +const ( + // StoreModeOnDemand registers the save_memory tool; the LLM stores via tool calls (default). + StoreModeOnDemand StoreMode = "ondemand" + // StoreModeAlways extracts memories at run end and stores them automatically. + StoreModeAlways StoreMode = "always" +) + +// ExtractFunc extracts long-term memories from a completed run. +// Used only when [StoreMode] is [StoreModeAlways]. Nil uses the SDK default LLM extractor. +type ExtractFunc func(ctx context.Context, messages []interfaces.Message) ([]interfaces.MemoryRecord, error) + +// Config wires long-term memory for the agent SDK. +type Config struct { + // Memory is the memory backend implementation. Required. + // Shipped backends: pkg/memory/weaviate and pkg/memory/pgvector. + Memory interfaces.Memory + + // ScopeConfig resolves [interfaces.MemoryScope] from request context. + ScopeConfig ScopeConfig + + // TTLPolicy sets ExpiresAt at Store time from record kind. + TTLPolicy TTLPolicy + + // Store controls when and how memories are written. + Store StoreConfig + + // Recall controls Load behavior when the SDK recalls memories for a run. + Recall RecallConfig +} + +// StoreConfig controls store timing, extraction, deduplication, and kind policy. +type StoreConfig struct { + // Mode controls when memories are written. Default [StoreModeOnDemand]. + Mode StoreMode + + // DedupMinScore is the semantic similarity floor for upserting an existing memory instead of appending. + // Zero defaults to [defaultDedupMinScore]. + DedupMinScore float32 + + // Extract overrides run-end memory extraction when [Mode] is [StoreModeAlways]. + // Must be nil when [Mode] is [StoreModeOnDemand]. + Extract ExtractFunc + + // DefaultKind is used when the record kind is empty. Zero defaults to [KindNote]. + DefaultKind interfaces.MemoryKind + + // AllowedKinds restricts stored kinds when non-empty. + AllowedKinds []interfaces.MemoryKind +} + +// ScopeResolver extracts one scope field value from the request context. +type ScopeResolver func(ctx context.Context) string + +// ScopeConfig controls how the agent SDK builds [interfaces.MemoryScope] from context. +// Configured via [Config.ScopeConfig] and agent [WithMemory]. +// +// Defaults resolve tenant_id, user_id, and agent_id from context values set by +// [WithContextTenantID], [WithContextUserID], and [WithContextAgentID]. +// Override any field with a custom resolver when your auth or tenancy model differs. +type ScopeConfig struct { + TenantIDResolver ScopeResolver + UserIDResolver ScopeResolver + AgentIDResolver ScopeResolver + + // ExtraKeys lists tag keys always resolved on Store/Load/Clear (e.g. "project_id", "env"). + // Each key must have a matching entry in TagResolvers. + ExtraKeys []string + + // TagResolvers supplies values for [ExtraKeys]. + TagResolvers map[string]ScopeResolver +} + +// TTLPolicy maps memory kind strings to time-to-live. +// Zero duration means no expiry for that kind. Opt-in via agent config; not enforced by [interfaces.Memory]. +type TTLPolicy map[interfaces.MemoryKind]time.Duration + +// RecallConfig controls SDK-initiated memory Load before or during a run. +type RecallConfig struct { + // Enabled recalls memories automatically for each run when true (default). + // When false the SDK still stores memories after each run but skips Load. + Enabled bool + + // Limit is the maximum number of memories to load. Zero or negative defaults to [defaultRecallLimit]. + Limit int + + // MinScore filters out entries below this relevance score when Score is applicable. + MinScore float32 + + // Kinds restricts recall to these kinds. Empty means all kinds. + Kinds []interfaces.MemoryKind +} + +type ctxKey struct{ name string } + +// WithContextTenantID attaches tenant ID for default scope resolution. +func WithContextTenantID(ctx context.Context, tenantID string) context.Context { + return context.WithValue(ctx, ctxKeyTenantID, tenantID) +} + +// WithContextUserID attaches user ID for default scope resolution. +func WithContextUserID(ctx context.Context, userID string) context.Context { + return context.WithValue(ctx, ctxKeyUserID, userID) +} + +// WithContextAgentID attaches agent ID for default scope resolution. +func WithContextAgentID(ctx context.Context, agentID string) context.Context { + return context.WithValue(ctx, ctxKeyAgentID, agentID) +} + +// Validate checks that every [ScopeConfig.ExtraKeys] entry has a TagResolver. +func (c ScopeConfig) Validate() error { + for _, key := range c.ExtraKeys { + if c.TagResolvers == nil || c.TagResolvers[key] == nil { + return fmt.Errorf("memory scope: ExtraKeys %q requires a TagResolver", key) + } + } + return nil +} + +// Resolve builds a [interfaces.MemoryScope] from context using configured resolvers. +func (c ScopeConfig) Resolve(ctx context.Context) (interfaces.MemoryScope, error) { + if err := c.Validate(); err != nil { + return interfaces.MemoryScope{}, err + } + + scope := interfaces.MemoryScope{} + if c.TenantIDResolver != nil { + scope.TenantID = c.TenantIDResolver(ctx) + } + if c.UserIDResolver != nil { + scope.UserID = c.UserIDResolver(ctx) + } + if c.AgentIDResolver != nil { + scope.AgentID = c.AgentIDResolver(ctx) + } + + if len(c.ExtraKeys) > 0 { + tags := make(map[string]string, len(c.ExtraKeys)) + for _, key := range c.ExtraKeys { + if v := c.TagResolvers[key](ctx); v != "" { + tags[key] = v + } + } + if len(tags) > 0 { + scope.Tags = tags + } + } + + return scope, nil +} + +// ScopeMetadata returns a flat map of scope fields for vector-store or metadata filters. +// SDK default keys are defined in [ScopeKeyTenantID], [ScopeKeyUserID], and [ScopeKeyAgentID]. +// Struct fields take precedence over conflicting tag keys. +func ScopeMetadata(s interfaces.MemoryScope) map[string]string { + out := make(map[string]string, len(s.Tags)+3) + for k, v := range s.Tags { + if k == ScopeKeyUserID || k == ScopeKeyTenantID || k == ScopeKeyAgentID { + continue + } + out[k] = v + } + if s.UserID != "" { + out[ScopeKeyUserID] = s.UserID + } + if s.TenantID != "" { + out[ScopeKeyTenantID] = s.TenantID + } + if s.AgentID != "" { + out[ScopeKeyAgentID] = s.AgentID + } + return out +} + +// ExpiresAt returns the expiry timestamp for kind at now. +// Empty or unknown kinds return zero time (no expiry) unless present in the policy map. +func (p TTLPolicy) ExpiresAt(kind interfaces.MemoryKind, now time.Time) time.Time { + if len(p) == 0 { + return time.Time{} + } + ttl, ok := p[kind] + if !ok { + return time.Time{} + } + if ttl <= 0 { + return time.Time{} + } + return now.Add(ttl) +} + +// ResolveKind returns the kind to store, applying default and allowlist policy. +func (s StoreConfig) ResolveKind(kind interfaces.MemoryKind) (interfaces.MemoryKind, error) { + if kind == "" { + kind = s.DefaultKind + } + if kind == "" { + kind = KindNote + } + if len(s.AllowedKinds) == 0 { + return kind, nil + } + for _, allowed := range s.AllowedKinds { + if kind == allowed { + return kind, nil + } + } + return "", fmt.Errorf("memory store: kind %q is not allowed", kind) +} + +// WithDefaults fills zero store fields with SDK defaults. +func (s StoreConfig) WithDefaults() StoreConfig { + if s.Mode == "" { + s.Mode = DefaultStoreMode() + } + if s.DedupMinScore <= 0 { + s.DedupMinScore = defaultDedupMinScore + } + if s.DefaultKind == "" && len(s.AllowedKinds) == 0 { + s.DefaultKind = KindNote + } + return s +} + +// Validate checks store configuration. +func (s StoreConfig) Validate() error { + if _, err := s.ResolveKind(""); err != nil { + return err + } + switch s.Mode { + case StoreModeOnDemand, StoreModeAlways: + default: + return fmt.Errorf("memory config: invalid StoreMode %q", s.Mode) + } + if s.Mode == StoreModeOnDemand && s.Extract != nil { + return errors.New("memory config: Extract is only valid with StoreMode always") + } + if s.DedupMinScore < 0 || s.DedupMinScore > 1 { + return fmt.Errorf("memory config: DedupMinScore must be between 0 and 1, got %v", s.DedupMinScore) + } + return nil +} + +// WithDefaults fills zero policy fields with SDK defaults. Memory must be set separately. +func (c Config) WithDefaults() Config { + if c.ScopeConfig.TenantIDResolver == nil && + c.ScopeConfig.UserIDResolver == nil && + c.ScopeConfig.AgentIDResolver == nil && + len(c.ScopeConfig.ExtraKeys) == 0 { + c.ScopeConfig = DefaultScopeConfig() + } + if len(c.TTLPolicy) == 0 { + c.TTLPolicy = DefaultTTLPolicy() + } + c.Store = c.Store.WithDefaults() + if c.Recall.Limit <= 0 { + c.Recall.Limit = defaultRecallLimit + } + return c +} + +// Validate checks the config. Call [WithDefaults] first. +func (c Config) Validate() error { + if c.Memory == nil { + return errors.New("memory config: Memory is required") + } + if err := c.ScopeConfig.Validate(); err != nil { + return err + } + return c.Store.Validate() +} + +// ExpiresAtForKind returns expiry for kind at now using the config TTL policy. +func (c Config) ExpiresAtForKind(kind interfaces.MemoryKind, now time.Time) time.Time { + resolved, err := c.Store.ResolveKind(kind) + if err != nil { + return time.Time{} + } + return c.TTLPolicy.ExpiresAt(resolved, now) +} + +// LoadOptions builds [interfaces.LoadMemoryOption] values from recall settings. +func (r RecallConfig) LoadOptions() []interfaces.LoadMemoryOption { + return r.loadOptions(true) +} + +// RecencyLoadOptions builds load options for scoped recency listing (no semantic min score). +func (r RecallConfig) RecencyLoadOptions() []interfaces.LoadMemoryOption { + return r.loadOptions(false) +} + +func (r RecallConfig) loadOptions(withMinScore bool) []interfaces.LoadMemoryOption { + opts := []interfaces.LoadMemoryOption{ + interfaces.WithLoadLimit(r.Limit), + } + if withMinScore && r.MinScore > 0 { + opts = append(opts, interfaces.WithMinScore(r.MinScore)) + } + if len(r.Kinds) > 0 { + opts = append(opts, interfaces.WithLoadKinds(r.Kinds...)) + } + return opts +} + +func contextStringResolver(key ctxKey) ScopeResolver { + return func(ctx context.Context) string { + v, _ := ctx.Value(key).(string) + return v + } +} diff --git a/pkg/memory/config_test.go b/pkg/memory/config_test.go new file mode 100644 index 0000000..176181c --- /dev/null +++ b/pkg/memory/config_test.go @@ -0,0 +1,249 @@ +package memory_test + +import ( + "context" + "testing" + "time" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/memory" +) + +type stubMemory struct{} + +func (stubMemory) Store(context.Context, interfaces.MemoryScope, interfaces.MemoryRecord, ...interfaces.StoreMemoryOption) (string, error) { + return "id-1", nil +} +func (stubMemory) Load(context.Context, interfaces.MemoryScope, string, ...interfaces.LoadMemoryOption) ([]interfaces.MemoryEntry, error) { + return nil, nil +} +func (stubMemory) Clear(context.Context, interfaces.MemoryScope) error { return nil } + +func TestDefaultConfig(t *testing.T) { + cfg := memory.DefaultConfig(stubMemory{}) + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + if cfg.Recall.Limit != 10 { + t.Fatalf("recall limit = %d", cfg.Recall.Limit) + } + if !cfg.Recall.Enabled { + t.Fatal("expected recall enabled by default") + } + if cfg.Store.Mode != memory.StoreModeOnDemand { + t.Fatalf("store mode = %q", cfg.Store.Mode) + } + if cfg.Store.DedupMinScore != 0.85 { + t.Fatalf("dedup min score = %v", cfg.Store.DedupMinScore) + } +} + +func TestConfig_Validate_extractWithOnDemand(t *testing.T) { + cfg := memory.DefaultConfig(stubMemory{}) + cfg.Store.Extract = func(context.Context, []interfaces.Message) ([]interfaces.MemoryRecord, error) { + return nil, nil + } + if err := cfg.Validate(); err == nil { + t.Fatal("expected error when Extract is set with OnDemand") + } +} + +func TestConfig_Validate_invalidStoreMode(t *testing.T) { + cfg := memory.DefaultConfig(stubMemory{}) + cfg.Store.Mode = "invalid" + if err := cfg.Validate(); err == nil { + t.Fatal("expected error for invalid StoreMode") + } +} + +func TestConfig_Validate_invalidDedupMinScore(t *testing.T) { + cfg := memory.DefaultConfig(stubMemory{}) + cfg.Store.DedupMinScore = 1.5 + if err := cfg.Validate(); err == nil { + t.Fatal("expected error for DedupMinScore > 1") + } +} + +func TestConfig_Validate_alwaysWithExtract(t *testing.T) { + cfg := memory.DefaultConfig(stubMemory{}) + cfg.Store.Mode = memory.StoreModeAlways + cfg.Store.Extract = func(context.Context, []interfaces.Message) ([]interfaces.MemoryRecord, error) { + return []interfaces.MemoryRecord{{Text: "fact", Kind: memory.KindFact}}, nil + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } +} + +func TestStoreConfig_WithDefaults(t *testing.T) { + got := (memory.StoreConfig{}).WithDefaults() + if got.Mode != memory.StoreModeOnDemand { + t.Fatalf("mode = %q", got.Mode) + } + if got.DedupMinScore != 0.85 { + t.Fatalf("dedup = %v", got.DedupMinScore) + } + if got.DefaultKind != memory.KindNote { + t.Fatalf("default kind = %q", got.DefaultKind) + } +} + +func TestStoreConfig_ResolveKind(t *testing.T) { + s := memory.StoreConfig{DefaultKind: memory.KindFact} + got, err := s.ResolveKind("") + if err != nil || got != memory.KindFact { + t.Fatalf("got %q err=%v", got, err) + } + + s = memory.StoreConfig{AllowedKinds: []interfaces.MemoryKind{memory.KindFact}} + if _, err := s.ResolveKind(memory.KindNote); err == nil { + t.Fatal("expected allowlist error") + } +} + +func TestDefaultStoreConfig(t *testing.T) { + got := memory.DefaultStoreConfig() + if got.Mode != memory.StoreModeOnDemand || got.DedupMinScore != 0.85 || got.DefaultKind != memory.KindNote { + t.Fatalf("got = %+v", got) + } +} + +func TestConfig_WithDefaults_appliesStore(t *testing.T) { + cfg := (memory.Config{Memory: stubMemory{}}).WithDefaults() + if cfg.Store.Mode != memory.StoreModeOnDemand { + t.Fatalf("store mode = %q", cfg.Store.Mode) + } + if cfg.Store.DefaultKind != memory.KindNote { + t.Fatalf("default kind = %q", cfg.Store.DefaultKind) + } +} + +func TestConfig_WithDefaults(t *testing.T) { + cfg := (memory.Config{Memory: stubMemory{}}).WithDefaults() + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } +} + +func TestConfig_Validate_missingMemory(t *testing.T) { + if err := (memory.Config{}).Validate(); err == nil { + t.Fatal("expected error for missing store") + } +} + +func TestConfig_ExpiresAtForKind(t *testing.T) { + cfg := memory.DefaultConfig(stubMemory{}) + now := time.Date(2026, 6, 18, 12, 0, 0, 0, time.UTC) + exp := cfg.ExpiresAtForKind(memory.KindDecision, now) + if !exp.Equal(now.Add(memory.TTLDecision)) { + t.Fatalf("expires = %v", exp) + } +} + +func TestRecallConfig_LoadOptions(t *testing.T) { + opts := (memory.RecallConfig{Limit: 5, MinScore: 0.5, Kinds: []interfaces.MemoryKind{memory.KindFact}}).LoadOptions() + if len(opts) != 3 { + t.Fatalf("opts len = %d", len(opts)) + } +} + +func TestDefaultScopeConfig_Resolve(t *testing.T) { + ctx := context.Background() + ctx = memory.WithContextTenantID(ctx, "tenant-1") + ctx = memory.WithContextUserID(ctx, "user-1") + ctx = memory.WithContextAgentID(ctx, "agent-1") + + scope, err := memory.DefaultScopeConfig().Resolve(ctx) + if err != nil { + t.Fatal(err) + } + if scope.TenantID != "tenant-1" || scope.UserID != "user-1" || scope.AgentID != "agent-1" { + t.Fatalf("scope = %+v", scope) + } +} + +func TestScopeConfig_Resolve_extraKeys(t *testing.T) { + cfg := memory.ScopeConfig{ + TenantIDResolver: func(ctx context.Context) string { return "t1" }, + ExtraKeys: []string{"project_id", "env"}, + TagResolvers: map[string]memory.ScopeResolver{ + "project_id": func(ctx context.Context) string { return "proj-a" }, + "env": func(ctx context.Context) string { return "prod" }, + }, + } + scope, err := cfg.Resolve(context.Background()) + if err != nil { + t.Fatal(err) + } + if scope.Tags["project_id"] != "proj-a" || scope.Tags["env"] != "prod" { + t.Fatalf("tags = %+v", scope.Tags) + } +} + +func TestScopeConfig_Validate_missingTagResolver(t *testing.T) { + cfg := memory.ScopeConfig{ + ExtraKeys: []string{"project_id"}, + } + if err := cfg.Validate(); err == nil { + t.Fatal("expected validation error") + } +} + +func TestScopeConfig_Resolve_customResolvers(t *testing.T) { + cfg := memory.ScopeConfig{ + TenantIDResolver: func(ctx context.Context) string { return "custom-tenant" }, + UserIDResolver: func(ctx context.Context) string { return "custom-user" }, + } + scope, err := cfg.Resolve(context.Background()) + if err != nil { + t.Fatal(err) + } + if scope.TenantID != "custom-tenant" || scope.UserID != "custom-user" { + t.Fatalf("scope = %+v", scope) + } +} + +func TestScopeMetadata(t *testing.T) { + meta := memory.ScopeMetadata(interfaces.MemoryScope{ + TenantID: "t1", + UserID: "u1", + Tags: map[string]string{ + "project_id": "p1", + memory.ScopeKeyUserID: "ignored", + }, + }) + if meta[memory.ScopeKeyTenantID] != "t1" || meta[memory.ScopeKeyUserID] != "u1" { + t.Fatalf("meta = %+v", meta) + } + if meta["project_id"] != "p1" { + t.Fatalf("tags = %+v", meta) + } + if _, ok := meta["ignored"]; ok { + t.Fatal("conflicting tag key should be skipped") + } +} + +func TestTTLPolicy_ExpiresAt(t *testing.T) { + now := time.Date(2026, 6, 18, 12, 0, 0, 0, time.UTC) + p := memory.DefaultTTLPolicy() + + if got := p.ExpiresAt(memory.KindDecision, now); !got.Equal(now.Add(memory.TTLDecision)) { + t.Fatalf("decision expiry = %v, want %v", got, now.Add(memory.TTLDecision)) + } + if !p.ExpiresAt(memory.KindFact, now).IsZero() { + t.Fatal("fact should not expire") + } + if !p.ExpiresAt(interfaces.MemoryKind("custom"), now).IsZero() { + t.Fatal("unknown kind should not expire") + } + if !p.ExpiresAt("", now).IsZero() { + t.Fatal("empty kind should not expire") + } +} + +func TestTTLPolicy_ExpiresAt_nilPolicy(t *testing.T) { + var p memory.TTLPolicy + if !p.ExpiresAt(memory.KindNote, time.Now()).IsZero() { + t.Fatal("nil policy should not expire") + } +} diff --git a/pkg/memory/defaults.go b/pkg/memory/defaults.go new file mode 100644 index 0000000..55fb010 --- /dev/null +++ b/pkg/memory/defaults.go @@ -0,0 +1,97 @@ +package memory + +import ( + "time" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +const defaultRecallLimit = 10 + +// defaultRecallMinScore is the default semantic similarity floor for recall. +// Run-summary memories often score below 0.75 against follow-up questions; 0.35 matches retriever example tuning. +const defaultRecallMinScore float32 = 0.35 + +// defaultDedupMinScore is the default semantic similarity floor for upserting an existing memory on store. +const defaultDedupMinScore float32 = 0.85 + +// Optional default kinds for general-purpose agents. Custom agents may define and use their own kind strings. +const ( + KindPreference interfaces.MemoryKind = "preference" + KindFact interfaces.MemoryKind = "fact" + KindDecision interfaces.MemoryKind = "decision" + KindInstruction interfaces.MemoryKind = "instruction" + KindNote interfaces.MemoryKind = "note" +) + +// Default scope metadata keys for [DefaultScopeConfig] and [ScopeMetadata]. +const ( + ScopeKeyTenantID = "tenant_id" + ScopeKeyUserID = "user_id" + ScopeKeyAgentID = "agent_id" +) + +// Default TTL values for [DefaultTTLPolicy]. Zero duration means no expiry. +const ( + TTLDecision = 7 * 24 * time.Hour + TTLFact = 0 + TTLPreference = 0 + TTLInstruction = 0 + TTLNote = 48 * time.Hour +) + +// Context keys used by [DefaultScopeConfig] resolvers. +var ( + ctxKeyTenantID = ctxKey{"memory:tenant_id"} + ctxKeyUserID = ctxKey{"memory:user_id"} + ctxKeyAgentID = ctxKey{"memory:agent_id"} +) + +// DefaultConfig returns a [Config] with SDK defaults for scope, TTL, store, and recall policies. +func DefaultConfig(store interfaces.Memory) Config { + return Config{ + Memory: store, + ScopeConfig: DefaultScopeConfig(), + TTLPolicy: DefaultTTLPolicy(), + Store: DefaultStoreConfig(), + Recall: DefaultRecallConfig(), + } +} + +// DefaultStoreConfig returns SDK defaults for store behavior. +func DefaultStoreConfig() StoreConfig { + return StoreConfig{ + Mode: DefaultStoreMode(), + DedupMinScore: defaultDedupMinScore, + DefaultKind: KindNote, + } +} + +// DefaultStoreMode returns the default store mode ([StoreModeOnDemand]). +func DefaultStoreMode() StoreMode { + return StoreModeOnDemand +} + +// DefaultScopeConfig returns SDK defaults: tenant_id, user_id, and agent_id from context. +func DefaultScopeConfig() ScopeConfig { + return ScopeConfig{ + TenantIDResolver: contextStringResolver(ctxKeyTenantID), + UserIDResolver: contextStringResolver(ctxKeyUserID), + AgentIDResolver: contextStringResolver(ctxKeyAgentID), + } +} + +// DefaultTTLPolicy returns a convenience TTL map for general-purpose agents. +func DefaultTTLPolicy() TTLPolicy { + return TTLPolicy{ + KindDecision: TTLDecision, + KindFact: TTLFact, + KindPreference: TTLPreference, + KindInstruction: TTLInstruction, + KindNote: TTLNote, + } +} + +func DefaultRecallConfig() RecallConfig { + return RecallConfig{Enabled: true, Limit: defaultRecallLimit, MinScore: defaultRecallMinScore} +} diff --git a/pkg/memory/pgvector/memory.go b/pkg/memory/pgvector/memory.go new file mode 100644 index 0000000..ac49ea8 --- /dev/null +++ b/pkg/memory/pgvector/memory.go @@ -0,0 +1,580 @@ +// Package pgvector provides a [interfaces.Memory] implementation backed by PostgreSQL with pgvector. +// +// Expected table schema (embedding dimension must match [EmbedFunc] output): +// +// CREATE TABLE agent_memories ( +// id UUID PRIMARY KEY DEFAULT gen_random_uuid(), +// text TEXT NOT NULL, +// kind TEXT NOT NULL DEFAULT '', +// user_id TEXT, +// tenant_id TEXT, +// agent_id TEXT, +// scope_tags TEXT[] NOT NULL DEFAULT '{}', +// metadata JSONB NOT NULL DEFAULT '{}', +// expires_at TIMESTAMPTZ, +// created_at TIMESTAMPTZ NOT NULL DEFAULT now(), +// updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), +// embedding vector(1536) +// ); +// +// Callers provide [EmbedFunc] to vectorize memory text on store and recall queries on load. +package pgvector + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/memory" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + pgvec "github.com/pgvector/pgvector-go" + pgxvec "github.com/pgvector/pgvector-go/pgx" +) + +var _ interfaces.Memory = (*Memory)(nil) + +// EmbedFunc converts plain text into a vector embedding. +type EmbedFunc func(ctx context.Context, text string) ([]float32, error) + +// pgRows is the subset of [pgx.Rows] used by Load, allowing injection in tests. +type pgRows interface { + Close() + Next() bool + Scan(dest ...any) error + Err() error +} + +// pgDB abstracts database calls; satisfied by [pgxPoolDB] and test stubs. +type pgDB interface { + Query(ctx context.Context, sql string, args ...any) (pgRows, error) + Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) +} + +type pgxPoolDB struct{ pool *pgxpool.Pool } + +func (d *pgxPoolDB) Query(ctx context.Context, sql string, args ...any) (pgRows, error) { + return d.pool.Query(ctx, sql, args...) +} + +func (d *pgxPoolDB) Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error) { + return d.pool.Exec(ctx, sql, args...) +} + +// Memory stores and recalls agent memories in PostgreSQL with pgvector. +type Memory struct { + db pgDB + table string + textCol string + embeddingCol string + embed EmbedFunc + + defaultLimit int + defaultMinScore float32 + + logger logger.Logger + dsn string + logLevel string +} + +// Option configures [Memory]. +type Option func(*Memory) + +// WithPool sets an existing [pgxpool.Pool]. When provided, [WithDSN] is ignored. +func WithPool(pool *pgxpool.Pool) Option { + return func(m *Memory) { m.db = &pgxPoolDB{pool: pool} } +} + +// WithDSN sets the PostgreSQL connection string used to create a pool when [WithPool] is omitted. +func WithDSN(dsn string) Option { + return func(m *Memory) { m.dsn = dsn } +} + +// WithTable sets the PostgreSQL table. Defaults to [DefaultTable]. +func WithTable(table string) Option { + return func(m *Memory) { m.table = table } +} + +// WithTextCol sets the column that holds memory text. Defaults to [DefaultTextCol]. +func WithTextCol(col string) Option { + return func(m *Memory) { m.textCol = col } +} + +// WithEmbeddingCol sets the column that holds the pgvector embedding. Defaults to [DefaultEmbeddingCol]. +func WithEmbeddingCol(col string) Option { + return func(m *Memory) { m.embeddingCol = col } +} + +// WithDefaultLimit sets the load limit when callers omit [interfaces.WithLoadLimit]. +func WithDefaultLimit(limit int) Option { + return func(m *Memory) { m.defaultLimit = limit } +} + +// WithDefaultMinScore sets the cosine similarity floor when callers omit [interfaces.WithMinScore]. +func WithDefaultMinScore(minScore float32) Option { + return func(m *Memory) { m.defaultMinScore = minScore } +} + +// WithLogger sets the logger. +func WithLogger(l logger.Logger) Option { + return func(m *Memory) { m.logger = l } +} + +// WithLogLevel sets the log level when no logger is provided. +func WithLogLevel(level string) Option { + return func(m *Memory) { m.logLevel = level } +} + +// NewMemory builds a pgvector-backed [interfaces.Memory]. embed is required. +// When [WithPool] is omitted, [WithDSN] must be provided. +func NewMemory(embed EmbedFunc, opts ...Option) (*Memory, error) { + if embed == nil { + return nil, errors.New("embed func is required") + } + m := &Memory{embed: embed} + for _, opt := range opts { + opt(m) + } + if m.table == "" { + m.table = DefaultTable + } + if m.textCol == "" { + m.textCol = DefaultTextCol + } + if m.embeddingCol == "" { + m.embeddingCol = DefaultEmbeddingCol + } + if m.defaultLimit <= 0 { + m.defaultLimit = DefaultLoadLimit + } + if m.defaultMinScore == 0 { + m.defaultMinScore = DefaultMinScore + } + if m.logLevel == "" { + m.logLevel = "error" + } + if m.logger == nil { + m.logger = logger.DefaultLogger(m.logLevel) + } + if m.db == nil { + if m.dsn == "" { + return nil, errors.New("DSN is required when not using WithPool; use WithDSN or WithPool") + } + cfg, err := pgxpool.ParseConfig(m.dsn) + if err != nil { + return nil, fmt.Errorf("parse DSN: %w", err) + } + cfg.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + return pgxvec.RegisterTypes(ctx, conn) + } + pool, err := pgxpool.NewWithConfig(context.Background(), cfg) + if err != nil { + return nil, fmt.Errorf("create pgx pool: %w", err) + } + m.db = &pgxPoolDB{pool: pool} + } + m.logger.Info(context.Background(), "pgvector memory built", + slog.String("scope", "pgvector-memory"), + slog.String("table", m.table), + slog.String("textCol", m.textCol), + slog.String("embeddingCol", m.embeddingCol), + slog.Int("defaultLimit", m.defaultLimit), + slog.Float64("defaultMinScore", float64(m.defaultMinScore)), + ) + return m, nil +} + +// Store persists a memory in scope and returns its ID. +func (m *Memory) Store(ctx context.Context, scope interfaces.MemoryScope, record interfaces.MemoryRecord, opts ...interfaces.StoreMemoryOption) (string, error) { + if m.db == nil { + return "", errors.New("database is not set") + } + + storeOpts := interfaces.StoreMemoryOptions{} + for _, opt := range opts { + opt(&storeOpts) + } + + vec, err := m.embed(ctx, record.Text) + if err != nil { + return "", fmt.Errorf("embed memory text: %w", err) + } + + now := time.Now().UTC() + meta := memory.ScopeMetadata(scope) + scopeTags := encodeScopeTags(meta) + + metadataJSON, err := marshalMetadata(record.Metadata) + if err != nil { + return "", err + } + + id := strings.TrimSpace(storeOpts.ID) + if id == "" { + id = uuid.NewString() + } + + userID := meta[memory.ScopeKeyUserID] + tenantID := meta[memory.ScopeKeyTenantID] + agentID := meta[memory.ScopeKeyAgentID] + + expiresArg := any(nil) + if !record.ExpiresAt.IsZero() { + expiresArg = record.ExpiresAt.UTC() + } + + //nolint:gosec // table/column identifiers are build-time developer config. + sql := fmt.Sprintf(` + INSERT INTO %s ( + %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12 + ) + ON CONFLICT (%s) DO UPDATE SET + %s = EXCLUDED.%s, + %s = EXCLUDED.%s, + %s = EXCLUDED.%s, + %s = EXCLUDED.%s, + %s = EXCLUDED.%s, + %s = EXCLUDED.%s, + %s = EXCLUDED.%s, + %s = EXCLUDED.%s, + %s = EXCLUDED.%s, + %s = %s.%s, + %s = EXCLUDED.%s + RETURNING %s`, + m.table, + ColID, m.textCol, ColKind, ColUserID, ColTenantID, ColAgentID, ColScopeTags, ColMetadata, + ColExpiresAt, ColCreatedAt, ColUpdatedAt, m.embeddingCol, + ColID, + m.textCol, m.textCol, + ColKind, ColKind, + ColUserID, ColUserID, + ColTenantID, ColTenantID, + ColAgentID, ColAgentID, + ColScopeTags, ColScopeTags, + ColMetadata, ColMetadata, + ColExpiresAt, ColExpiresAt, + ColUpdatedAt, ColUpdatedAt, + ColCreatedAt, m.table, ColCreatedAt, + m.embeddingCol, m.embeddingCol, + ColID, + ) + + args := []any{ + id, + record.Text, + string(record.Kind), + nullIfEmpty(userID), + nullIfEmpty(tenantID), + nullIfEmpty(agentID), + scopeTags, + metadataJSON, + expiresArg, + now, + now, + pgvec.NewVector(vec), + } + + rows, err := m.db.Query(ctx, sql, args...) + if err != nil { + return "", fmt.Errorf("pgvector store memory: %w", err) + } + defer rows.Close() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return "", fmt.Errorf("pgvector store memory: %w", err) + } + return "", errors.New("pgvector store memory: no id returned") + } + var returnedID string + if err := rows.Scan(&returnedID); err != nil { + return "", fmt.Errorf("pgvector store memory scan id: %w", err) + } + if err := rows.Err(); err != nil { + return "", fmt.Errorf("pgvector store memory: %w", err) + } + return returnedID, nil +} + +// Load recalls memories within scope. Non-empty query uses vector similarity; empty query lists by updated_at. +func (m *Memory) Load(ctx context.Context, scope interfaces.MemoryScope, query string, opts ...interfaces.LoadMemoryOption) ([]interfaces.MemoryEntry, error) { + if m.db == nil { + return nil, errors.New("database is not set") + } + + loadOpts := interfaces.LoadMemoryOptions{} + for _, opt := range opts { + opt(&loadOpts) + } + limit := loadOpts.Limit + if limit <= 0 { + limit = m.defaultLimit + } + minScore := loadOpts.MinScore + if minScore == 0 { + minScore = m.defaultMinScore + } + + query = strings.TrimSpace(query) + if query != "" { + vec, err := m.embed(ctx, query) + if err != nil { + return nil, fmt.Errorf("embed recall query: %w", err) + } + + whereSQL, scopeArgs := buildScopeArgs(scope, loadOpts.Kinds, 4) + whereSQL = appendNotExpired(whereSQL) + args := append([]any{pgvec.NewVector(vec), float64(minScore), limit}, scopeArgs...) + scoreExpr := fmt.Sprintf("1 - (%s <=> $1)", m.embeddingCol) + //nolint:gosec + sql := fmt.Sprintf(` + SELECT %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s AS score + FROM %s + WHERE %s AND %s >= $2 + ORDER BY %s <=> $1 + LIMIT $3`, + ColID, m.textCol, ColKind, ColUserID, ColTenantID, ColAgentID, ColScopeTags, ColMetadata, + ColExpiresAt, ColCreatedAt, ColUpdatedAt, scoreExpr, + m.table, + whereSQL, scoreExpr, + m.embeddingCol, + ) + rows, err := m.db.Query(ctx, sql, args...) + if err != nil { + return nil, fmt.Errorf("pgvector load memories: %w", err) + } + return scanMemoryRows(rows) + } + + whereSQL, args := buildScopeArgs(scope, loadOpts.Kinds, 1) + whereSQL = appendNotExpired(whereSQL) + args = append(args, limit) + limitArg := len(args) + //nolint:gosec + sql := fmt.Sprintf(` + SELECT %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, 0::float8 AS score + FROM %s + WHERE %s + ORDER BY %s DESC + LIMIT $%d`, + ColID, m.textCol, ColKind, ColUserID, ColTenantID, ColAgentID, ColScopeTags, ColMetadata, + ColExpiresAt, ColCreatedAt, ColUpdatedAt, + m.table, + whereSQL, + ColUpdatedAt, + limitArg, + ) + rows, err := m.db.Query(ctx, sql, args...) + if err != nil { + return nil, fmt.Errorf("pgvector load memories: %w", err) + } + return scanMemoryRows(rows) +} + +// Clear removes all memories matching scope. +func (m *Memory) Clear(ctx context.Context, scope interfaces.MemoryScope) error { + if m.db == nil { + return errors.New("database is not set") + } + + meta := memory.ScopeMetadata(scope) + if len(meta) == 0 { + return errors.New("scope must include at least one non-empty field") + } + + whereSQL, args := buildScopeArgs(scope, nil, 1) + + //nolint:gosec + sql := fmt.Sprintf("DELETE FROM %s WHERE %s", m.table, whereSQL) + if _, err := m.db.Exec(ctx, sql, args...); err != nil { + return fmt.Errorf("pgvector clear memories: %w", err) + } + return nil +} + +func buildScopeArgs(scope interfaces.MemoryScope, kinds []interfaces.MemoryKind, startArg int) (string, []any) { + meta := memory.ScopeMetadata(scope) + if len(meta) == 0 && len(kinds) == 0 { + return "TRUE", nil + } + + var parts []string + var args []any + nextArg := startArg + + for key, value := range meta { + switch key { + case memory.ScopeKeyUserID, memory.ScopeKeyTenantID, memory.ScopeKeyAgentID: + args = append(args, value) + parts = append(parts, fmt.Sprintf("%s = $%d", key, nextArg)) + nextArg++ + default: + args = append(args, []string{scopeTagToken(key, value)}) + parts = append(parts, fmt.Sprintf("%s @> $%d::text[]", ColScopeTags, nextArg)) + nextArg++ + } + } + + if len(kinds) == 1 { + args = append(args, string(kinds[0])) + parts = append(parts, fmt.Sprintf("%s = $%d", ColKind, nextArg)) + } else if len(kinds) > 1 { + kindValues := make([]string, len(kinds)) + for i, kind := range kinds { + kindValues[i] = string(kind) + } + args = append(args, kindValues) + parts = append(parts, fmt.Sprintf("%s = ANY($%d::text[])", ColKind, nextArg)) + } + + if len(parts) == 0 { + return "TRUE", args + } + return strings.Join(parts, " AND "), args +} + +func appendNotExpired(whereSQL string) string { + clause := fmt.Sprintf("(%s IS NULL OR %s > now())", ColExpiresAt, ColExpiresAt) + if whereSQL == "" || whereSQL == "TRUE" { + return clause + } + return whereSQL + " AND " + clause +} + +func scanMemoryRows(rows pgRows) ([]interfaces.MemoryEntry, error) { + defer rows.Close() + + var entries []interfaces.MemoryEntry + for rows.Next() { + var ( + entry interfaces.MemoryEntry + kind string + userID *string + tenantID *string + agentID *string + scopeTags []string + metadata []byte + expiresAt *time.Time + createdAt time.Time + updatedAt time.Time + score float64 + ) + if err := rows.Scan( + &entry.ID, + &entry.Text, + &kind, + &userID, + &tenantID, + &agentID, + &scopeTags, + &metadata, + &expiresAt, + &createdAt, + &updatedAt, + &score, + ); err != nil { + return nil, fmt.Errorf("scan memory row: %w", err) + } + + entry.Kind = interfaces.MemoryKind(kind) + entry.Scope = interfaces.MemoryScope{ + UserID: derefString(userID), + TenantID: derefString(tenantID), + AgentID: derefString(agentID), + Tags: decodeScopeTags(scopeTags), + } + if len(metadata) > 0 { + if err := json.Unmarshal(metadata, &entry.Metadata); err != nil { + return nil, fmt.Errorf("unmarshal metadata: %w", err) + } + } + if expiresAt != nil { + entry.ExpiresAt = expiresAt.UTC() + } + entry.CreatedAt = createdAt.UTC() + entry.UpdatedAt = updatedAt.UTC() + entry.Score = float32(score) + + if entry.Expired() { + continue + } + entries = append(entries, entry) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate memory rows: %w", err) + } + return entries, nil +} + +func marshalMetadata(metadata map[string]string) ([]byte, error) { + if len(metadata) == 0 { + return []byte("{}"), nil + } + raw, err := json.Marshal(metadata) + if err != nil { + return nil, fmt.Errorf("marshal metadata: %w", err) + } + return raw, nil +} + +func encodeScopeTags(meta map[string]string) []string { + if len(meta) == 0 { + return []string{} + } + tags := make([]string, 0, len(meta)) + for key, value := range meta { + switch key { + case memory.ScopeKeyUserID, memory.ScopeKeyTenantID, memory.ScopeKeyAgentID: + continue + default: + tags = append(tags, scopeTagToken(key, value)) + } + } + return tags +} + +func decodeScopeTags(tags []string) map[string]string { + if len(tags) == 0 { + return nil + } + out := make(map[string]string, len(tags)) + for _, token := range tags { + key, value, ok := strings.Cut(token, "=") + if !ok || key == "" { + continue + } + out[key] = value + } + if len(out) == 0 { + return nil + } + return out +} + +func scopeTagToken(key, value string) string { + return key + "=" + value +} + +func nullIfEmpty(s string) any { + if s == "" { + return nil + } + return s +} + +func derefString(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/pkg/memory/pgvector/memory_test.go b/pkg/memory/pgvector/memory_test.go new file mode 100644 index 0000000..f23a072 --- /dev/null +++ b/pkg/memory/pgvector/memory_test.go @@ -0,0 +1,462 @@ +package pgvector + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/jackc/pgx/v5/pgconn" +) + +type noopLogger struct{} + +func (noopLogger) Debug(_ context.Context, _ string, _ ...any) {} +func (noopLogger) Info(_ context.Context, _ string, _ ...any) {} +func (noopLogger) Warn(_ context.Context, _ string, _ ...any) {} +func (noopLogger) Error(_ context.Context, _ string, _ ...any) {} + +var _ logger.Logger = noopLogger{} + +type stubIDRows struct { + id string + scanned bool + err error +} + +func (r *stubIDRows) Close() {} +func (r *stubIDRows) Next() bool { + if r.scanned { + return false + } + r.scanned = true + return true +} +func (r *stubIDRows) Scan(dest ...any) error { + if r.err != nil { + return r.err + } + *dest[0].(*string) = r.id + return nil +} +func (r *stubIDRows) Err() error { return nil } + +type memoryRowData struct { + id string + text string + kind string + userID *string + tenantID *string + agentID *string + scopeTags []string + metadata []byte + expiresAt *time.Time + createdAt time.Time + updatedAt time.Time + score float64 +} + +type stubMemoryRows struct { + data []memoryRowData + pos int + scanErr error + iterErr error +} + +func (r *stubMemoryRows) Close() {} +func (r *stubMemoryRows) Next() bool { + r.pos++ + return r.pos <= len(r.data) +} +func (r *stubMemoryRows) Scan(dest ...any) error { + if r.scanErr != nil { + return r.scanErr + } + row := r.data[r.pos-1] + *dest[0].(*string) = row.id + *dest[1].(*string) = row.text + *dest[2].(*string) = row.kind + *dest[3].(**string) = row.userID + *dest[4].(**string) = row.tenantID + *dest[5].(**string) = row.agentID + *dest[6].(*[]string) = row.scopeTags + *dest[7].(*[]byte) = row.metadata + *dest[8].(**time.Time) = row.expiresAt + *dest[9].(*time.Time) = row.createdAt + *dest[10].(*time.Time) = row.updatedAt + *dest[11].(*float64) = row.score + return nil +} +func (r *stubMemoryRows) Err() error { return r.iterErr } + +type stubDB struct { + queryRows pgRows + queryErr error + execTag pgconn.CommandTag + execErr error + + lastQuerySQL string + lastQueryArgs []any + lastExecSQL string + lastExecArgs []any +} + +func (s *stubDB) Query(_ context.Context, sql string, args ...any) (pgRows, error) { + s.lastQuerySQL = sql + s.lastQueryArgs = args + return s.queryRows, s.queryErr +} + +func (s *stubDB) Exec(_ context.Context, sql string, args ...any) (pgconn.CommandTag, error) { + s.lastExecSQL = sql + s.lastExecArgs = args + return s.execTag, s.execErr +} + +func stubEmbed(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil +} + +func errEmbed(_ context.Context, _ string) ([]float32, error) { + return nil, errors.New("embed error") +} + +func newTestMemory(t *testing.T, db pgDB, opts ...Option) *Memory { + t.Helper() + m := &Memory{ + embed: stubEmbed, + db: db, + table: DefaultTable, + textCol: DefaultTextCol, + embeddingCol: DefaultEmbeddingCol, + defaultLimit: DefaultLoadLimit, + defaultMinScore: DefaultMinScore, + logger: noopLogger{}, + } + for _, opt := range opts { + opt(m) + } + return m +} + +func TestNewMemory_MissingEmbed(t *testing.T) { + _, err := NewMemory(nil, WithTable(DefaultTable)) + if err == nil || !contains(err.Error(), "embed") { + t.Fatalf("err = %v", err) + } +} + +func TestNewMemory_MissingDSNAndPool(t *testing.T) { + _, err := NewMemory(stubEmbed, WithLogger(noopLogger{})) + if err == nil || !contains(err.Error(), "DSN") { + t.Fatalf("err = %v", err) + } +} + +func TestNewMemory_InvalidDSN(t *testing.T) { + _, err := NewMemory(stubEmbed, WithDSN("not-a-valid-dsn"), WithLogger(noopLogger{})) + if err == nil { + t.Fatal("expected parse DSN error") + } +} + +func TestNewMemory_Defaults(t *testing.T) { + m := newTestMemory(t, &stubDB{}) + if m.table != DefaultTable { + t.Fatalf("table = %q", m.table) + } + if m.textCol != DefaultTextCol || m.embeddingCol != DefaultEmbeddingCol { + t.Fatalf("cols = %q %q", m.textCol, m.embeddingCol) + } + if m.defaultLimit != DefaultLoadLimit || m.defaultMinScore != DefaultMinScore { + t.Fatalf("defaults = %d %v", m.defaultLimit, m.defaultMinScore) + } +} + +func TestStore_NoDB(t *testing.T) { + m := &Memory{embed: stubEmbed, db: nil} + _, err := m.Store(context.Background(), interfaces.MemoryScope{}, interfaces.MemoryRecord{Text: "x"}) + if err == nil || !contains(err.Error(), "database is not set") { + t.Fatalf("err = %v", err) + } +} + +func TestStore_EmbedError(t *testing.T) { + m := newTestMemory(t, &stubDB{}) + m.embed = errEmbed + _, err := m.Store(context.Background(), interfaces.MemoryScope{UserID: "u1"}, interfaces.MemoryRecord{Text: "x"}) + if err == nil || !contains(err.Error(), "embed") { + t.Fatalf("err = %v", err) + } +} + +func TestStore_Create(t *testing.T) { + db := &stubDB{ + queryRows: &stubIDRows{id: "mem-1"}, + } + m := newTestMemory(t, db) + + id, err := m.Store(context.Background(), + interfaces.MemoryScope{UserID: "u1", TenantID: "t1", Tags: map[string]string{"team": "a"}}, + interfaces.MemoryRecord{ + Text: "likes dark mode", + Kind: interfaces.MemoryKind("preference"), + Metadata: map[string]string{"source": "run-1"}, + }, + ) + if err != nil { + t.Fatal(err) + } + if id != "mem-1" { + t.Fatalf("id = %q", id) + } + if !contains(db.lastQuerySQL, "INSERT INTO "+DefaultTable) { + t.Fatalf("sql = %s", db.lastQuerySQL) + } + if !contains(db.lastQuerySQL, "ON CONFLICT") { + t.Fatalf("sql missing upsert: %s", db.lastQuerySQL) + } + if contains(db.lastQuerySQL, "updated_at = updated_at.") { + t.Fatalf("sql must not reference updated_at as table: %s", db.lastQuerySQL) + } + if !contains(db.lastQuerySQL, "updated_at = EXCLUDED.updated_at") { + t.Fatalf("sql missing updated_at upsert: %s", db.lastQuerySQL) + } + if !contains(db.lastQuerySQL, "created_at = "+DefaultTable+".created_at") { + t.Fatalf("sql must preserve created_at on conflict: %s", db.lastQuerySQL) + } + if len(db.lastQueryArgs) != 12 { + t.Fatalf("args len = %d", len(db.lastQueryArgs)) + } +} + +func TestStore_UpsertWithID(t *testing.T) { + db := &stubDB{ + queryRows: &stubIDRows{id: "fixed-id"}, + } + m := newTestMemory(t, db) + + id, err := m.Store(context.Background(), + interfaces.MemoryScope{UserID: "u1"}, + interfaces.MemoryRecord{Text: "updated"}, + interfaces.WithMemoryID("fixed-id"), + ) + if err != nil { + t.Fatal(err) + } + if id != "fixed-id" { + t.Fatalf("id = %q", id) + } + if db.lastQueryArgs[0] != "fixed-id" { + t.Fatalf("first arg = %v", db.lastQueryArgs[0]) + } +} + +func TestLoad_Semantic(t *testing.T) { + now := time.Now().UTC() + user := "u1" + db := &stubDB{ + queryRows: &stubMemoryRows{data: []memoryRowData{{ + id: "abc", + text: "likes dark mode", + kind: "preference", + userID: &user, + metadata: []byte(`{"source":"run-1"}`), + createdAt: now, + updatedAt: now, + score: 0.91, + }}}, + } + m := newTestMemory(t, db) + + entries, err := m.Load(context.Background(), + interfaces.MemoryScope{UserID: "u1"}, + "theme preference", + interfaces.WithLoadLimit(5), + interfaces.WithMinScore(0.8), + ) + if err != nil { + t.Fatal(err) + } + if len(entries) != 1 { + t.Fatalf("len = %d", len(entries)) + } + if entries[0].ID != "abc" || entries[0].Text != "likes dark mode" || entries[0].Score != 0.91 { + t.Fatalf("entry = %#v", entries[0]) + } + if entries[0].Metadata["source"] != "run-1" { + t.Fatalf("metadata = %#v", entries[0].Metadata) + } + if !contains(db.lastQuerySQL, "<=>") { + t.Fatalf("sql missing vector search: %s", db.lastQuerySQL) + } + if !contains(db.lastQuerySQL, "user_id = $") { + t.Fatalf("sql missing scope filter: %s", db.lastQuerySQL) + } +} + +func TestLoad_Recency(t *testing.T) { + now := time.Now().UTC() + db := &stubDB{ + queryRows: &stubMemoryRows{data: []memoryRowData{{ + id: "note-1", + text: "recent note", + kind: "note", + createdAt: now, + updatedAt: now, + }}}, + } + m := newTestMemory(t, db) + + entries, err := m.Load(context.Background(), interfaces.MemoryScope{UserID: "u1"}, "") + if err != nil { + t.Fatal(err) + } + if len(entries) != 1 || entries[0].Text != "recent note" { + t.Fatalf("entries = %#v", entries) + } + if contains(db.lastQuerySQL, "<=>") { + t.Fatalf("unexpected vector search for empty query: %s", db.lastQuerySQL) + } + if !contains(db.lastQuerySQL, "ORDER BY updated_at DESC") { + t.Fatalf("sql missing recency sort: %s", db.lastQuerySQL) + } +} + +func TestLoad_EmbedError(t *testing.T) { + m := newTestMemory(t, &stubDB{}) + m.embed = errEmbed + _, err := m.Load(context.Background(), interfaces.MemoryScope{UserID: "u1"}, "query") + if err == nil || !contains(err.Error(), "embed") { + t.Fatalf("err = %v", err) + } +} + +func TestLoad_SkipsExpired(t *testing.T) { + past := time.Now().UTC().Add(-time.Hour) + db := &stubDB{ + queryRows: &stubMemoryRows{data: []memoryRowData{{ + id: "expired", + text: "old", + expiresAt: &past, + createdAt: past, + updatedAt: past, + }}}, + } + m := newTestMemory(t, db) + + entries, err := m.Load(context.Background(), interfaces.MemoryScope{}, "") + if err != nil { + t.Fatal(err) + } + if len(entries) != 0 { + t.Fatalf("expected expired entry to be omitted, got %#v", entries) + } +} + +func TestLoad_KindsFilter(t *testing.T) { + db := &stubDB{queryRows: &stubMemoryRows{}} + m := newTestMemory(t, db) + + _, err := m.Load(context.Background(), interfaces.MemoryScope{UserID: "u1"}, "", + interfaces.WithLoadKinds("fact", "note"), + ) + if err != nil { + t.Fatal(err) + } + if !contains(db.lastQuerySQL, "kind = ANY") { + t.Fatalf("sql = %s", db.lastQuerySQL) + } +} + +func TestClear(t *testing.T) { + db := &stubDB{} + m := newTestMemory(t, db) + + if err := m.Clear(context.Background(), interfaces.MemoryScope{TenantID: "t1"}); err != nil { + t.Fatal(err) + } + if !contains(db.lastExecSQL, "DELETE FROM "+DefaultTable) { + t.Fatalf("sql = %s", db.lastExecSQL) + } + if !contains(db.lastExecSQL, "tenant_id = $1") { + t.Fatalf("sql = %s", db.lastExecSQL) + } +} + +func TestClear_EmptyScope(t *testing.T) { + m := newTestMemory(t, &stubDB{}) + err := m.Clear(context.Background(), interfaces.MemoryScope{}) + if err == nil || !contains(err.Error(), "scope must include") { + t.Fatalf("err = %v", err) + } +} + +func TestEncodeDecodeScopeTags(t *testing.T) { + meta := map[string]string{ + "user_id": "u1", + "project_id": "p1", + "env": "prod", + } + encoded := encodeScopeTags(meta) + if len(encoded) != 2 { + t.Fatalf("encoded = %#v", encoded) + } + decoded := decodeScopeTags([]string{"project_id=p1", "env=prod"}) + if decoded["project_id"] != "p1" || decoded["env"] != "prod" { + t.Fatalf("decoded = %#v", decoded) + } +} + +func TestBuildScopeArgs_CustomTag(t *testing.T) { + sql, args := buildScopeArgs(interfaces.MemoryScope{ + Tags: map[string]string{"team": "alpha"}, + }, nil, 1) + if !contains(sql, "scope_tags @> $1::text[]") { + t.Fatalf("sql = %q", sql) + } + tags, ok := args[0].([]string) + if !ok || len(tags) != 1 || tags[0] != "team=alpha" { + t.Fatalf("args = %#v", args) + } +} + +func TestScanMemoryRows_MetadataError(t *testing.T) { + now := time.Now().UTC() + rows := &stubMemoryRows{data: []memoryRowData{{ + id: "x", + text: "t", + metadata: []byte("not-json"), + createdAt: now, + updatedAt: now, + }}} + _, err := scanMemoryRows(rows) + if err == nil || !contains(err.Error(), "unmarshal metadata") { + t.Fatalf("err = %v", err) + } +} + +func TestMarshalMetadata(t *testing.T) { + raw, err := marshalMetadata(nil) + if err != nil { + t.Fatal(err) + } + if string(raw) != "{}" { + t.Fatalf("raw = %s", raw) + } +} + +func contains(s, sub string) bool { + return len(sub) > 0 && len(s) >= len(sub) && (s == sub || containsStr(s, sub)) +} + +func containsStr(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/pkg/memory/pgvector/schema.go b/pkg/memory/pgvector/schema.go new file mode 100644 index 0000000..1d656f5 --- /dev/null +++ b/pkg/memory/pgvector/schema.go @@ -0,0 +1,27 @@ +package pgvector + +import "github.com/agenticenv/agent-sdk-go/pkg/memory" + +// Default table and column names for [Memory]. +const ( + DefaultTable = "agent_memories" + DefaultTextCol = "text" + DefaultEmbeddingCol = "embedding" + + ColID = "id" + ColKind = "kind" + ColMetadata = "metadata" + ColScopeTags = "scope_tags" + ColExpiresAt = "expires_at" + ColCreatedAt = "created_at" + ColUpdatedAt = "updated_at" + ColUserID = memory.ScopeKeyUserID + ColTenantID = memory.ScopeKeyTenantID + ColAgentID = memory.ScopeKeyAgentID +) + +// DefaultLoadLimit is the maximum memories returned when [interfaces.WithLoadLimit] is zero or negative. +const DefaultLoadLimit = 10 + +// DefaultMinScore is the default cosine similarity when [interfaces.WithMinScore] is zero. +const DefaultMinScore float32 = 0.35 diff --git a/pkg/memory/weaviate/memory.go b/pkg/memory/weaviate/memory.go new file mode 100644 index 0000000..b8e028c --- /dev/null +++ b/pkg/memory/weaviate/memory.go @@ -0,0 +1,607 @@ +// Package weaviate provides a [interfaces.Memory] implementation backed by Weaviate. +// +// Expected class schema (vectorizer should target the text field): +// +// { +// "class": "AgentMemory", +// "vectorizer": "text2vec-...", +// "properties": [ +// {"name": "text", "dataType": ["text"]}, +// {"name": "kind", "dataType": ["text"]}, +// {"name": "user_id", "dataType": ["text"]}, +// {"name": "tenant_id", "dataType": ["text"]}, +// {"name": "agent_id", "dataType": ["text"]}, +// {"name": "scope_tags", "dataType": ["text[]"]}, +// {"name": "metadata", "dataType": ["text"]}, +// {"name": "expires_at", "dataType": ["date"]}, +// {"name": "created_at", "dataType": ["date"]}, +// {"name": "updated_at", "dataType": ["date"]} +// ] +// } +package weaviate + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/agenticenv/agent-sdk-go/pkg/memory" + "github.com/google/uuid" + client "github.com/weaviate/weaviate-go-client/v5/weaviate" + "github.com/weaviate/weaviate-go-client/v5/weaviate/data" + "github.com/weaviate/weaviate-go-client/v5/weaviate/fault" + "github.com/weaviate/weaviate-go-client/v5/weaviate/filters" + "github.com/weaviate/weaviate-go-client/v5/weaviate/graphql" + "github.com/weaviate/weaviate/entities/models" +) + +var _ interfaces.Memory = (*Memory)(nil) + +// Memory stores and recalls agent memories in a Weaviate class. +type Memory struct { + className string + textField string + tenant string + + defaultLimit int + defaultMinScore float32 + + logger logger.Logger + client *client.Client + + host string + scheme string + logLevel string +} + +// Option configures [Memory]. +type Option func(*Memory) + +// WithClient sets the Weaviate client. +func WithClient(c *client.Client) Option { + return func(m *Memory) { m.client = c } +} + +// WithHost sets the Weaviate host (required when [WithClient] is omitted). +func WithHost(host string) Option { + return func(m *Memory) { m.host = host } +} + +// WithScheme sets the Weaviate scheme. Defaults to [types.DefaultScheme]. +func WithScheme(scheme string) Option { + return func(m *Memory) { m.scheme = scheme } +} + +// WithClassName sets the Weaviate class name. Defaults to [DefaultClassName]. +func WithClassName(className string) Option { + return func(m *Memory) { m.className = className } +} + +// WithTextField sets the property used for memory text and vectorization. Defaults to [DefaultTextField]. +func WithTextField(textField string) Option { + return func(m *Memory) { m.textField = textField } +} + +// WithTenant sets the Weaviate multi-tenancy tenant for all operations. +func WithTenant(tenant string) Option { + return func(m *Memory) { m.tenant = tenant } +} + +// WithDefaultLimit sets the load limit when callers omit [interfaces.WithLoadLimit]. +func WithDefaultLimit(limit int) Option { + return func(m *Memory) { m.defaultLimit = limit } +} + +// WithDefaultMinScore sets the nearText certainty when callers omit [interfaces.WithMinScore]. +func WithDefaultMinScore(minScore float32) Option { + return func(m *Memory) { m.defaultMinScore = minScore } +} + +// WithLogger sets the logger. +func WithLogger(l logger.Logger) Option { + return func(m *Memory) { m.logger = l } +} + +// WithLogLevel sets the log level when no logger is provided. +func WithLogLevel(logLevel string) Option { + return func(m *Memory) { m.logLevel = logLevel } +} + +// NewMemory builds a Weaviate-backed [interfaces.Memory]. When [WithClient] is omitted, host is required. +func NewMemory(opts ...Option) (*Memory, error) { + m := &Memory{} + for _, opt := range opts { + opt(m) + } + if m.className == "" { + m.className = DefaultClassName + } + if m.textField == "" { + m.textField = DefaultTextField + } + if m.defaultLimit <= 0 { + m.defaultLimit = DefaultLoadLimit + } + if m.defaultMinScore == 0 { + m.defaultMinScore = DefaultMinScore + } + if m.logLevel == "" { + m.logLevel = "error" + } + if m.logger == nil { + m.logger = logger.DefaultLogger(m.logLevel) + } + if m.client == nil { + if m.host == "" { + return nil, errors.New("host is required when not using WithClient") + } + if m.scheme == "" { + m.scheme = types.DefaultScheme + } + wc, err := client.NewClient(client.Config{ + Scheme: m.scheme, + Host: m.host, + }) + if err != nil { + return nil, fmt.Errorf("create weaviate client: %w", err) + } + m.client = wc + } + m.logger.Info(context.Background(), "weaviate memory built", + slog.String("scope", "weaviate-memory"), + slog.String("class", m.className), + slog.String("textField", m.textField), + slog.Int("defaultLimit", m.defaultLimit), + slog.Float64("defaultMinScore", float64(m.defaultMinScore)), + ) + return m, nil +} + +// Store persists a memory in scope and returns its ID. +func (m *Memory) Store(ctx context.Context, scope interfaces.MemoryScope, record interfaces.MemoryRecord, opts ...interfaces.StoreMemoryOption) (string, error) { + if m.client == nil { + return "", errors.New("client is not set") + } + + storeOpts := interfaces.StoreMemoryOptions{} + for _, opt := range opts { + opt(&storeOpts) + } + + now := time.Now().UTC() + props, err := buildProperties(m.textField, scope, record, now, storeOpts.ID == "") + if err != nil { + return "", err + } + + id := strings.TrimSpace(storeOpts.ID) + if id != "" { + updater := m.client.Data().Updater(). + WithClassName(m.className). + WithID(id). + WithProperties(props). + WithMerge() + if m.tenant != "" { + updater = updater.WithTenant(m.tenant) + } + if err := updater.Do(ctx); err != nil { + if !isNotFound(err) { + return "", fmt.Errorf("weaviate update memory: %w", err) + } + creator := m.client.Data().Creator(). + WithClassName(m.className). + WithID(id). + WithProperties(props) + if m.tenant != "" { + creator = creator.WithTenant(m.tenant) + } + wrapper, createErr := creator.Do(ctx) + if createErr != nil { + return "", fmt.Errorf("weaviate create memory: %w", createErr) + } + return objectID(wrapper), nil + } + return id, nil + } + + creator := m.client.Data().Creator(). + WithClassName(m.className). + WithProperties(props) + if m.tenant != "" { + creator = creator.WithTenant(m.tenant) + } + wrapper, err := creator.Do(ctx) + if err != nil { + return "", fmt.Errorf("weaviate create memory: %w", err) + } + if got := objectID(wrapper); got != "" { + return got, nil + } + return uuid.NewString(), nil +} + +// Load recalls memories within scope. Non-empty query uses nearText; empty query lists by updated_at. +func (m *Memory) Load(ctx context.Context, scope interfaces.MemoryScope, query string, opts ...interfaces.LoadMemoryOption) ([]interfaces.MemoryEntry, error) { + if m.client == nil { + return nil, errors.New("client is not set") + } + + loadOpts := interfaces.LoadMemoryOptions{} + for _, opt := range opts { + opt(&loadOpts) + } + limit := loadOpts.Limit + if limit <= 0 { + limit = m.defaultLimit + } + minScore := loadOpts.MinScore + if minScore == 0 { + minScore = m.defaultMinScore + } + + where := combineWhere( + scopeWhere(scope), + kindsWhere(loadOpts.Kinds), + ) + + fields := []graphql.Field{ + {Name: m.textField}, + {Name: PropKind}, + {Name: PropUserID}, + {Name: PropTenantID}, + {Name: PropAgentID}, + {Name: PropScopeTags}, + {Name: PropMetadata}, + {Name: PropExpiresAt}, + {Name: PropCreatedAt}, + {Name: PropUpdatedAt}, + {Name: "_additional { id certainty }"}, + } + + builder := m.client.GraphQL().Get(). + WithClassName(m.className). + WithLimit(limit). + WithFields(fields...). + WithSort(graphql.Sort{Path: []string{PropUpdatedAt}, Order: graphql.Desc}) + + if where != nil { + builder = builder.WithWhere(where) + } + + query = strings.TrimSpace(query) + if query != "" { + nearText := m.client.GraphQL(). + NearTextArgBuilder(). + WithConcepts([]string{query}). + WithCertainty(minScore) + builder = builder.WithNearText(nearText) + } + + result, err := builder.Do(ctx) + if err != nil { + return nil, fmt.Errorf("weaviate load memories: %w", err) + } + if err := graphqlErrors(result); err != nil { + return nil, fmt.Errorf("weaviate load memories: %w", err) + } + + entries, err := m.parseEntries(result) + if err != nil { + return nil, err + } + + if minScore > 0 && query != "" { + filtered := entries[:0] + for _, e := range entries { + if e.Score >= minScore { + filtered = append(filtered, e) + } + } + entries = filtered + } + + return entries, nil +} + +// Clear removes all memories matching scope. +func (m *Memory) Clear(ctx context.Context, scope interfaces.MemoryScope) error { + if m.client == nil { + return errors.New("client is not set") + } + + where := scopeWhere(scope) + if where == nil { + return errors.New("scope must include at least one non-empty field") + } + + deleter := m.client.Batch().ObjectsBatchDeleter(). + WithClassName(m.className). + WithWhere(where) + if m.tenant != "" { + deleter = deleter.WithTenant(m.tenant) + } + if _, err := deleter.Do(ctx); err != nil { + return fmt.Errorf("weaviate clear memories: %w", err) + } + return nil +} + +func (m *Memory) parseEntries(result *models.GraphQLResponse) ([]interfaces.MemoryEntry, error) { + if result == nil || result.Data == nil { + return nil, nil + } + + get, ok := result.Data["Get"].(map[string]interface{}) + if !ok { + return nil, errors.New("invalid response: missing Get") + } + + itemsRaw, ok := get[m.className] + if !ok || itemsRaw == nil { + return nil, nil + } + items, ok := itemsRaw.([]interface{}) + if !ok { + return nil, errors.New("invalid response: missing class data") + } + + entries := make([]interfaces.MemoryEntry, 0, len(items)) + for _, item := range items { + obj, ok := item.(map[string]interface{}) + if !ok { + m.logger.Warn(context.Background(), "weaviate memory: skipping non-object item", + slog.String("scope", "weaviate-memory"), + slog.String("class", m.className), + ) + continue + } + entry, err := parseEntry(obj, m.textField) + if err != nil { + return nil, err + } + if entry.Expired() { + continue + } + entries = append(entries, entry) + } + return entries, nil +} + +func buildProperties(textField string, scope interfaces.MemoryScope, record interfaces.MemoryRecord, now time.Time, setCreated bool) (map[string]interface{}, error) { + meta := memory.ScopeMetadata(scope) + props := map[string]interface{}{ + textField: record.Text, + PropKind: string(record.Kind), + PropUpdatedAt: now, + } + if setCreated { + props[PropCreatedAt] = now + } + + if v := meta[memory.ScopeKeyUserID]; v != "" { + props[PropUserID] = v + } + if v := meta[memory.ScopeKeyTenantID]; v != "" { + props[PropTenantID] = v + } + if v := meta[memory.ScopeKeyAgentID]; v != "" { + props[PropAgentID] = v + } + if tags := encodeScopeTags(meta); len(tags) > 0 { + props[PropScopeTags] = tags + } + if len(record.Metadata) > 0 { + raw, err := json.Marshal(record.Metadata) + if err != nil { + return nil, fmt.Errorf("marshal metadata: %w", err) + } + props[PropMetadata] = string(raw) + } + if !record.ExpiresAt.IsZero() { + props[PropExpiresAt] = record.ExpiresAt.UTC() + } + return props, nil +} + +func parseEntry(obj map[string]interface{}, textField string) (interfaces.MemoryEntry, error) { + entry := interfaces.MemoryEntry{ + ID: additionalID(obj), + Text: getString(obj, textField), + Kind: interfaces.MemoryKind(getString(obj, PropKind)), + Scope: interfaces.MemoryScope{ + UserID: getString(obj, PropUserID), + TenantID: getString(obj, PropTenantID), + AgentID: getString(obj, PropAgentID), + Tags: decodeScopeTags(obj[PropScopeTags]), + }, + ExpiresAt: getTime(obj, PropExpiresAt), + CreatedAt: getTime(obj, PropCreatedAt), + UpdatedAt: getTime(obj, PropUpdatedAt), + } + if raw := getString(obj, PropMetadata); raw != "" { + if err := json.Unmarshal([]byte(raw), &entry.Metadata); err != nil { + return interfaces.MemoryEntry{}, fmt.Errorf("unmarshal metadata: %w", err) + } + } + if additional, ok := obj["_additional"].(map[string]interface{}); ok { + if certainty, ok := additional["certainty"].(float64); ok { + entry.Score = float32(certainty) + } + } + return entry, nil +} + +func scopeWhere(scope interfaces.MemoryScope) *filters.WhereBuilder { + meta := memory.ScopeMetadata(scope) + if len(meta) == 0 { + return nil + } + + operands := make([]*filters.WhereBuilder, 0, len(meta)) + for key, value := range meta { + switch key { + case memory.ScopeKeyUserID, memory.ScopeKeyTenantID, memory.ScopeKeyAgentID: + operands = append(operands, filters.Where(). + WithPath([]string{key}). + WithOperator(filters.Equal). + WithValueString(value)) + default: + operands = append(operands, filters.Where(). + WithPath([]string{PropScopeTags}). + WithOperator(filters.ContainsAll). + WithValueText(scopeTagToken(key, value))) + } + } + return combineWhereOperands(operands) +} + +func kindsWhere(kinds []interfaces.MemoryKind) *filters.WhereBuilder { + if len(kinds) == 0 { + return nil + } + if len(kinds) == 1 { + return filters.Where(). + WithPath([]string{PropKind}). + WithOperator(filters.Equal). + WithValueString(string(kinds[0])) + } + values := make([]string, len(kinds)) + for i, kind := range kinds { + values[i] = string(kind) + } + return filters.Where(). + WithPath([]string{PropKind}). + WithOperator(filters.ContainsAny). + WithValueText(values...) +} + +func combineWhere(parts ...*filters.WhereBuilder) *filters.WhereBuilder { + operands := make([]*filters.WhereBuilder, 0, len(parts)) + for _, part := range parts { + if part != nil { + operands = append(operands, part) + } + } + return combineWhereOperands(operands) +} + +func combineWhereOperands(operands []*filters.WhereBuilder) *filters.WhereBuilder { + if len(operands) == 0 { + return nil + } + if len(operands) == 1 { + return operands[0] + } + return filters.Where().WithOperator(filters.And).WithOperands(operands) +} + +func encodeScopeTags(meta map[string]string) []string { + if len(meta) == 0 { + return nil + } + tags := make([]string, 0, len(meta)) + for key, value := range meta { + switch key { + case memory.ScopeKeyUserID, memory.ScopeKeyTenantID, memory.ScopeKeyAgentID: + continue + default: + tags = append(tags, scopeTagToken(key, value)) + } + } + return tags +} + +func graphqlErrors(result *models.GraphQLResponse) error { + if result == nil || len(result.Errors) == 0 { + return nil + } + msgs := make([]string, len(result.Errors)) + for i, e := range result.Errors { + if e != nil && e.Message != "" { + msgs[i] = e.Message + } + } + return fmt.Errorf("graphql: %s", strings.Join(msgs, "; ")) +} + +func decodeScopeTags(raw any) map[string]string { + values, ok := raw.([]interface{}) + if !ok || len(values) == 0 { + return nil + } + out := make(map[string]string, len(values)) + for _, item := range values { + token, ok := item.(string) + if !ok { + continue + } + key, value, ok := strings.Cut(token, "=") + if !ok || key == "" { + continue + } + out[key] = value + } + if len(out) == 0 { + return nil + } + return out +} + +func scopeTagToken(key, value string) string { + return key + "=" + value +} + +func objectID(wrapper *data.ObjectWrapper) string { + if wrapper == nil || wrapper.Object == nil || wrapper.Object.ID == "" { + return "" + } + return wrapper.Object.ID.String() +} + +func isNotFound(err error) bool { + var wcErr *fault.WeaviateClientError + if errors.As(err, &wcErr) { + return wcErr.StatusCode == 404 + } + return false +} + +func getString(obj map[string]interface{}, key string) string { + if v, ok := obj[key].(string); ok { + return v + } + return "" +} + +func getTime(obj map[string]interface{}, key string) time.Time { + raw := getString(obj, key) + if raw == "" { + return time.Time{} + } + t, err := time.Parse(time.RFC3339Nano, raw) + if err != nil { + t, err = time.Parse(time.RFC3339, raw) + if err != nil { + return time.Time{} + } + } + return t.UTC() +} + +func additionalID(obj map[string]interface{}) string { + additional, ok := obj["_additional"].(map[string]interface{}) + if !ok { + return "" + } + if id, ok := additional["id"].(string); ok { + return id + } + return "" +} diff --git a/pkg/memory/weaviate/memory_test.go b/pkg/memory/weaviate/memory_test.go new file mode 100644 index 0000000..e8c811f --- /dev/null +++ b/pkg/memory/weaviate/memory_test.go @@ -0,0 +1,439 @@ +package weaviate + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + weaviateclient "github.com/weaviate/weaviate-go-client/v5/weaviate" + "github.com/weaviate/weaviate/entities/models" +) + +type noopLogger struct{} + +func (noopLogger) Debug(ctx context.Context, msg string, _ ...any) {} +func (noopLogger) Info(ctx context.Context, msg string, _ ...any) {} +func (noopLogger) Warn(ctx context.Context, msg string, _ ...any) {} +func (noopLogger) Error(ctx context.Context, msg string, _ ...any) {} + +func testWeaviateHost(t *testing.T, srv *httptest.Server) string { + t.Helper() + u, err := url.Parse(srv.URL) + if err != nil { + t.Fatal(err) + } + return u.Host +} + +func graphQLData(entries map[string]interface{}) map[string]models.JSONObject { + out := make(map[string]models.JSONObject, len(entries)) + for k, v := range entries { + out[k] = v + } + return out +} + +func TestNew_MissingHost(t *testing.T) { + _, err := NewMemory() + if err == nil || !strings.Contains(err.Error(), "host") { + t.Fatalf("err = %v", err) + } +} + +func TestNew_Defaults(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + mem, err := NewMemory(WithHost(testWeaviateHost(t, srv))) + if err != nil { + t.Fatal(err) + } + if mem.className != DefaultClassName { + t.Fatalf("className = %q", mem.className) + } + if mem.textField != DefaultTextField { + t.Fatalf("textField = %q", mem.textField) + } + if mem.defaultLimit != DefaultLoadLimit { + t.Fatalf("defaultLimit = %d", mem.defaultLimit) + } + if mem.scheme != types.DefaultScheme { + t.Fatalf("scheme = %q", mem.scheme) + } +} + +func TestNew_WithClient(t *testing.T) { + wc, err := weaviateclient.NewClient(weaviateclient.Config{Scheme: "http", Host: "unused:0"}) + if err != nil { + t.Fatal(err) + } + mem, err := NewMemory( + WithClient(wc), + WithClassName("CustomMemory"), + WithTextField("body"), + WithDefaultLimit(3), + WithDefaultMinScore(0.5), + ) + if err != nil { + t.Fatal(err) + } + if mem.className != "CustomMemory" || mem.textField != "body" { + t.Fatalf("class/text = %q %q", mem.className, mem.textField) + } + if mem.defaultLimit != 3 || mem.defaultMinScore != 0.5 { + t.Fatalf("defaults = %d %v", mem.defaultLimit, mem.defaultMinScore) + } +} + +func TestStore_NoClient(t *testing.T) { + mem := &Memory{className: DefaultClassName, client: nil} + _, err := mem.Store(context.Background(), interfaces.MemoryScope{}, interfaces.MemoryRecord{Text: "x"}) + if err == nil || !strings.Contains(err.Error(), "client is not set") { + t.Fatalf("err = %v", err) + } +} + +func TestStore_Create(t *testing.T) { + const className = DefaultClassName + var gotBody string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/meta", "/v1/.well-known/ready": + w.WriteHeader(http.StatusOK) + return + case "/v1/objects": + if r.Method != http.MethodPost { + t.Errorf("method = %s", r.Method) + } + body, _ := io.ReadAll(r.Body) + gotBody = string(body) + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"class":"AgentMemory","id":"11111111-1111-1111-1111-111111111111"}`) + return + default: + t.Errorf("unexpected path %s", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + mem, err := NewMemory(WithHost(testWeaviateHost(t, srv)), WithLogger(noopLogger{})) + if err != nil { + t.Fatal(err) + } + + id, err := mem.Store(context.Background(), + interfaces.MemoryScope{UserID: "u1", TenantID: "t1"}, + interfaces.MemoryRecord{ + Text: "likes dark mode", + Kind: interfaces.MemoryKind("preference"), + Metadata: map[string]string{"source": "run-1"}, + }, + ) + if err != nil { + t.Fatal(err) + } + if id != "11111111-1111-1111-1111-111111111111" { + t.Fatalf("id = %q", id) + } + if !strings.Contains(gotBody, className) { + t.Fatalf("body missing class: %s", gotBody) + } + if !strings.Contains(gotBody, "likes dark mode") || !strings.Contains(gotBody, "u1") { + t.Fatalf("body = %s", gotBody) + } +} + +func TestStore_UpsertUpdate(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/meta", "/v1/.well-known/ready": + w.WriteHeader(http.StatusOK) + return + case "/v1/objects/22222222-2222-2222-2222-222222222222": + if r.Method != http.MethodPatch { + t.Errorf("method = %s", r.Method) + } + w.WriteHeader(http.StatusNoContent) + return + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + mem, err := NewMemory(WithHost(testWeaviateHost(t, srv)), WithLogger(noopLogger{})) + if err != nil { + t.Fatal(err) + } + + id, err := mem.Store(context.Background(), + interfaces.MemoryScope{UserID: "u1"}, + interfaces.MemoryRecord{Text: "updated"}, + interfaces.WithMemoryID("22222222-2222-2222-2222-222222222222"), + ) + if err != nil { + t.Fatal(err) + } + if id != "22222222-2222-2222-2222-222222222222" { + t.Fatalf("id = %q", id) + } +} + +func TestLoad_Semantic(t *testing.T) { + mockBody := `{ + "data": { + "Get": { + "AgentMemory": [ + { + "text": "likes dark mode", + "kind": "preference", + "user_id": "u1", + "metadata": "{\"source\":\"run-1\"}", + "_additional": { "id": "abc", "certainty": 0.91 } + } + ] + } + } + }` + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/meta", "/v1/.well-known/ready": + w.WriteHeader(http.StatusOK) + return + case "/v1/graphql": + body, _ := io.ReadAll(r.Body) + if !strings.Contains(string(body), "nearText") { + t.Errorf("body missing nearText: %s", body) + } + _, _ = io.WriteString(w, mockBody) + return + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + mem, err := NewMemory(WithHost(testWeaviateHost(t, srv)), WithLogger(noopLogger{})) + if err != nil { + t.Fatal(err) + } + + entries, err := mem.Load(context.Background(), + interfaces.MemoryScope{UserID: "u1"}, + "theme preference", + interfaces.WithLoadLimit(5), + interfaces.WithMinScore(0.8), + ) + if err != nil { + t.Fatal(err) + } + if len(entries) != 1 { + t.Fatalf("len = %d", len(entries)) + } + if entries[0].ID != "abc" || entries[0].Text != "likes dark mode" || entries[0].Score != 0.91 { + t.Fatalf("entry = %#v", entries[0]) + } + if entries[0].Metadata["source"] != "run-1" { + t.Fatalf("metadata = %#v", entries[0].Metadata) + } +} + +func TestLoad_Recency(t *testing.T) { + mockBody := `{ + "data": { + "Get": { + "AgentMemory": [ + { + "text": "recent note", + "kind": "note", + "user_id": "u1", + "_additional": { "id": "note-1" } + } + ] + } + } + }` + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/meta", "/v1/.well-known/ready": + w.WriteHeader(http.StatusOK) + return + case "/v1/graphql": + body, _ := io.ReadAll(r.Body) + bodyStr := string(body) + if strings.Contains(bodyStr, "nearText") { + t.Errorf("unexpected nearText for empty query: %s", body) + } + if !strings.Contains(bodyStr, "sort") { + t.Errorf("body missing sort: %s", body) + } + _, _ = io.WriteString(w, mockBody) + return + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + mem, err := NewMemory(WithHost(testWeaviateHost(t, srv)), WithLogger(noopLogger{})) + if err != nil { + t.Fatal(err) + } + + entries, err := mem.Load(context.Background(), interfaces.MemoryScope{UserID: "u1"}, "") + if err != nil { + t.Fatal(err) + } + if len(entries) != 1 || entries[0].Text != "recent note" { + t.Fatalf("entries = %#v", entries) + } +} + +func TestClear(t *testing.T) { + var gotDelete bool + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/meta", "/v1/.well-known/ready": + w.WriteHeader(http.StatusOK) + return + case "/v1/batch/objects": + if r.Method != http.MethodDelete { + t.Errorf("method = %s", r.Method) + } + gotDelete = true + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"results":{"matches":1,"successful":1}}`) + return + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + mem, err := NewMemory(WithHost(testWeaviateHost(t, srv)), WithLogger(noopLogger{})) + if err != nil { + t.Fatal(err) + } + + if err := mem.Clear(context.Background(), interfaces.MemoryScope{TenantID: "t1"}); err != nil { + t.Fatal(err) + } + if !gotDelete { + t.Fatal("expected batch delete") + } +} + +func TestClear_EmptyScope(t *testing.T) { + wc, _ := weaviateclient.NewClient(weaviateclient.Config{Scheme: "http", Host: "unused:0"}) + mem, err := NewMemory(WithClient(wc)) + if err != nil { + t.Fatal(err) + } + err = mem.Clear(context.Background(), interfaces.MemoryScope{}) + if err == nil || !strings.Contains(err.Error(), "scope must include") { + t.Fatalf("err = %v", err) + } +} + +func TestEncodeDecodeScopeTags(t *testing.T) { + meta := map[string]string{ + "user_id": "u1", + "project_id": "p1", + "env": "prod", + } + encoded := encodeScopeTags(meta) + if len(encoded) != 2 { + t.Fatalf("encoded = %#v", encoded) + } + + decoded := decodeScopeTags([]interface{}{"project_id=p1", "env=prod"}) + if decoded["project_id"] != "p1" || decoded["env"] != "prod" { + t.Fatalf("decoded = %#v", decoded) + } +} + +func TestParseEntry_MetadataError(t *testing.T) { + _, err := parseEntry(map[string]interface{}{ + "text": "x", + "metadata": "not-json", + }, DefaultTextField) + if err == nil { + t.Fatal("expected error") + } +} + +func TestBuildProperties_Metadata(t *testing.T) { + props, err := buildProperties(DefaultTextField, + interfaces.MemoryScope{Tags: map[string]string{"team": "a"}}, + interfaces.MemoryRecord{ + Text: "hello", + Metadata: map[string]string{"k": "v"}, + }, + mustParseTime(t, "2026-01-01T00:00:00Z"), + true, + ) + if err != nil { + t.Fatal(err) + } + raw, ok := props[PropMetadata].(string) + if !ok { + t.Fatalf("metadata type = %T", props[PropMetadata]) + } + var decoded map[string]string + if err := json.Unmarshal([]byte(raw), &decoded); err != nil { + t.Fatal(err) + } + if decoded["k"] != "v" { + t.Fatalf("decoded = %#v", decoded) + } + tags, ok := props[PropScopeTags].([]string) + if !ok || len(tags) != 1 || tags[0] != "team=a" { + t.Fatalf("tags = %#v", props[PropScopeTags]) + } +} + +func TestParseEntries_InvalidResponse(t *testing.T) { + mem := &Memory{className: DefaultClassName, textField: DefaultTextField, logger: noopLogger{}} + _, err := mem.parseEntries(&models.GraphQLResponse{ + Data: graphQLData(map[string]interface{}{"Get": "bad"}), + }) + if err == nil || !strings.Contains(err.Error(), "missing Get") { + t.Fatalf("err = %v", err) + } + + entries, err := mem.parseEntries(&models.GraphQLResponse{ + Data: graphQLData(map[string]interface{}{ + "Get": map[string]interface{}{DefaultClassName: nil}, + }), + }) + if err != nil { + t.Fatalf("null class should be empty result: %v", err) + } + if len(entries) != 0 { + t.Fatalf("entries = %#v", entries) + } +} + +func mustParseTime(t *testing.T, raw string) time.Time { + t.Helper() + ts, err := time.Parse(time.RFC3339, raw) + if err != nil { + t.Fatal(err) + } + return ts.UTC() +} diff --git a/pkg/memory/weaviate/schema.go b/pkg/memory/weaviate/schema.go new file mode 100644 index 0000000..92b5aad --- /dev/null +++ b/pkg/memory/weaviate/schema.go @@ -0,0 +1,25 @@ +package weaviate + +import "github.com/agenticenv/agent-sdk-go/pkg/memory" + +// Default Weaviate class and property names for [Memory]. +const ( + DefaultClassName = "AgentMemory" + DefaultTextField = "text" + + PropKind = "kind" + PropMetadata = "metadata" + PropScopeTags = "scope_tags" + PropExpiresAt = "expires_at" + PropCreatedAt = "created_at" + PropUpdatedAt = "updated_at" + PropUserID = memory.ScopeKeyUserID + PropTenantID = memory.ScopeKeyTenantID + PropAgentID = memory.ScopeKeyAgentID +) + +// DefaultLoadLimit is the maximum memories returned when [interfaces.WithLoadLimit] is zero or negative. +const DefaultLoadLimit = 10 + +// DefaultMinScore is the default nearText certainty when [interfaces.WithMinScore] is zero. +const DefaultMinScore float32 = 0.35 diff --git a/taskfiles/examples.yml b/taskfiles/examples.yml index d8839d5..0f10983 100644 --- a/taskfiles/examples.yml +++ b/taskfiles/examples.yml @@ -122,17 +122,21 @@ tasks: NAME: '{{.NAME}}' PROMPT: '{{.PROMPT | default ""}}' SKIP_RUN: '{{.SKIP_RUN | default "false"}}' + MEMORY_STORE_MODE: '{{.MEMORY_STORE_MODE | default ""}}' + RUN_LABEL: '{{.NAME}}{{if .MEMORY_STORE_MODE}} store={{.MEMORY_STORE_MODE}}{{end}}' cmds: - | if [ "{{.SKIP_RUN}}" = "true" ]; then RESULT="○ PLAN" if [ -n "{{.PROMPT}}" ]; then - echo "○ Plan {{.NAME}} ({{.RUNTIME}}) — go run ./{{.NAME}} \"{{.PROMPT}}\"" + echo "○ Plan {{.RUN_LABEL}} ({{.RUNTIME}}) — go run ./{{.NAME}} \"{{.PROMPT}}\"" + elif [ -n "{{.MEMORY_STORE_MODE}}" ]; then + echo "○ Plan {{.RUN_LABEL}} ({{.RUNTIME}}) — MEMORY_STORE_MODE={{.MEMORY_STORE_MODE}} go run ./{{.NAME}}" else - echo "○ Plan {{.NAME}} ({{.RUNTIME}}) — go run ./{{.NAME}}" + echo "○ Plan {{.RUN_LABEL}} ({{.RUNTIME}}) — go run ./{{.NAME}}" fi if [ -n "{{.REPORT_FILE}}" ]; then - echo "$RESULT {{.NAME}}" >> {{.REPORT_FILE}} + echo "$RESULT {{.RUN_LABEL}}" >> {{.REPORT_FILE}} fi if [ -n "{{.COUNT_FILE}}" ]; then read -r PASS FAIL < "{{.COUNT_FILE}}" 2>/dev/null || PASS=0 FAIL=0 @@ -141,7 +145,7 @@ tasks: fi exit 0 fi - echo "🚀 Running {{.NAME}} with {{.RUNTIME}} runtime..." + echo "🚀 Running {{.RUN_LABEL}} with {{.RUNTIME}} runtime..." set +e if [ -n "{{.LOG_FILE}}" ]; then if [ -n "{{.PROMPT}}" ]; then @@ -163,9 +167,9 @@ tasks: else RESULT="❌ FAIL" fi - echo "$RESULT {{.NAME}} runtime={{.RUNTIME}}" + echo "$RESULT {{.RUN_LABEL}} runtime={{.RUNTIME}}" if [ -n "{{.REPORT_FILE}}" ]; then - echo "$RESULT {{.NAME}}" >> {{.REPORT_FILE}} + echo "$RESULT {{.RUN_LABEL}}" >> {{.REPORT_FILE}} fi if [ -n "{{.COUNT_FILE}}" ]; then read -r PASS FAIL < "{{.COUNT_FILE}}" 2>/dev/null || PASS=0 FAIL=0 @@ -180,6 +184,7 @@ tasks: AGENT_RUNTIME: '{{.RUNTIME}}' # Batch only — not in .env.defaults; manual go run leaves unset (interactive y/n). EXAMPLES_AUTO_APPROVE: '{{.EXAMPLES_AUTO_APPROVE | default "true"}}' + MEMORY_STORE_MODE: '{{.MEMORY_STORE_MODE}}' exec:examples: internal: true @@ -210,6 +215,15 @@ tasks: - agent_with_subagents - agent_with_tools/approval - agent_with_run_async + EXAMPLES_MEMORY: + - NAME: agent_with_memory/weaviate + MEMORY_STORE_MODE: always + - NAME: agent_with_memory/weaviate + MEMORY_STORE_MODE: ondemand + - NAME: agent_with_memory/pgvector + MEMORY_STORE_MODE: always + - NAME: agent_with_memory/pgvector + MEMORY_STORE_MODE: ondemand EXAMPLES_WITH_PROMPTS: - agent_with_conversation - agent_with_stream_conversation @@ -233,6 +247,18 @@ tasks: LOG_FILE: '{{.LOG_FILE}}' COUNT_FILE: '{{.COUNT_FILE}}' SKIP_RUN: '{{.SKIP_RUN}}' + - for: + var: EXAMPLES_MEMORY + ignore_error: true + task: exec:example + vars: + NAME: '{{.ITEM.NAME}}' + MEMORY_STORE_MODE: '{{.ITEM.MEMORY_STORE_MODE}}' + RUNTIME: '{{.RUNTIME}}' + REPORT_FILE: '{{.REPORT_FILE}}' + LOG_FILE: '{{.LOG_FILE}}' + COUNT_FILE: '{{.COUNT_FILE}}' + SKIP_RUN: '{{.SKIP_RUN}}' - for: var: EXAMPLES_WITH_PROMPTS ignore_error: true