From 55a360f5074c82f5ccbfbca5e4589cb9a0822742 Mon Sep 17 00:00:00 2001 From: Vinod Vanjarapu Date: Sun, 24 May 2026 11:39:23 -0700 Subject: [PATCH 1/2] feat: add RAG retriever support with Weaviate/pgvector --- .github/ISSUE_TEMPLATE/bug_report.md | 2 +- .github/workflows/ci.yml | 2 +- CONTRIBUTING.md | 4 +- Makefile | 11 +- README.md | 193 ++++++-- examples/README.md | 20 + examples/agent_with_retriever/README.md | 72 +++ .../agent_with_retriever/common/config.go | 131 ++++++ .../common/embed_openai.go | 77 ++++ .../agent_with_retriever/common/embedding.go | 35 ++ examples/agent_with_retriever/common/opts.go | 46 ++ .../common/sample-documents.json | 26 ++ .../agent_with_retriever/pgvector/README.md | 136 ++++++ .../agent_with_retriever/pgvector/cleanup.sh | 24 + .../agent_with_retriever/pgvector/main.go | 93 ++++ .../agent_with_retriever/pgvector/setup.sh | 223 ++++++++++ .../agent_with_retriever/pgvector/setup.sql | 14 + .../agent_with_retriever/pgvector/verify.sh | 92 ++++ .../agent_with_retriever/weaviate/README.md | 136 ++++++ .../agent_with_retriever/weaviate/cleanup.sh | 24 + .../agent_with_retriever/weaviate/main.go | 83 ++++ .../agent_with_retriever/weaviate/setup.sh | 164 +++++++ examples/env.sample | 31 ++ go.mod | 70 ++- go.sum | 193 ++++---- internal/runtime/runtime.go | 17 +- internal/runtime/temporal/agent_workflow.go | 162 ++++++- .../runtime/temporal/agent_workflow_test.go | 267 +++++++++++ internal/runtime/temporal/config.go | 10 + internal/runtime/temporal/fingerprint.go | 7 + internal/runtime/temporal/fingerprint_test.go | 33 +- internal/runtime/temporal/runtime.go | 1 + internal/types/metrics.go | 15 +- internal/types/retriever.go | 30 ++ pkg/agent/a2a.go | 30 +- pkg/agent/config.go | 139 +++++- pkg/agent/config_test.go | 418 ++++++++++++++++++ pkg/agent/mcp.go | 30 +- pkg/agent/retriever.go | 176 ++++++++ pkg/agent/retriever_test.go | 323 ++++++++++++++ pkg/agent/runtime_factory.go | 1 + pkg/interfaces/mocks/mock_retriever.go | 65 +++ 
pkg/interfaces/retriever.go | 19 + pkg/retriever/pgvector/retriever.go | 295 ++++++++++++ pkg/retriever/pgvector/retriever_test.go | 402 +++++++++++++++++ pkg/retriever/weaviate/retriever.go | 280 ++++++++++++ pkg/retriever/weaviate/retriever_test.go | 364 +++++++++++++++ 47 files changed, 4767 insertions(+), 219 deletions(-) create mode 100644 examples/agent_with_retriever/README.md create mode 100644 examples/agent_with_retriever/common/config.go create mode 100644 examples/agent_with_retriever/common/embed_openai.go create mode 100644 examples/agent_with_retriever/common/embedding.go create mode 100644 examples/agent_with_retriever/common/opts.go create mode 100644 examples/agent_with_retriever/common/sample-documents.json create mode 100644 examples/agent_with_retriever/pgvector/README.md create mode 100755 examples/agent_with_retriever/pgvector/cleanup.sh create mode 100644 examples/agent_with_retriever/pgvector/main.go create mode 100755 examples/agent_with_retriever/pgvector/setup.sh create mode 100644 examples/agent_with_retriever/pgvector/setup.sql create mode 100755 examples/agent_with_retriever/pgvector/verify.sh create mode 100644 examples/agent_with_retriever/weaviate/README.md create mode 100755 examples/agent_with_retriever/weaviate/cleanup.sh create mode 100644 examples/agent_with_retriever/weaviate/main.go create mode 100755 examples/agent_with_retriever/weaviate/setup.sh create mode 100644 internal/types/retriever.go create mode 100644 pkg/agent/retriever.go create mode 100644 pkg/agent/retriever_test.go create mode 100644 pkg/interfaces/mocks/mock_retriever.go create mode 100644 pkg/interfaces/retriever.go create mode 100644 pkg/retriever/pgvector/retriever.go create mode 100644 pkg/retriever/pgvector/retriever_test.go create mode 100644 pkg/retriever/weaviate/retriever.go create mode 100644 pkg/retriever/weaviate/retriever_test.go diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index a954f04..9e60025 
100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -28,7 +28,7 @@ What actually happened. ## Environment -- Go version: (e.g. 1.25.x; must be ≥ the `go` line in `go.mod`) +- Go version: (e.g. 1.26.x; must be ≥ the `go` line in `go.mod`) - OS: (e.g. macOS, Linux) - Temporal: (local dev server, Temporal Cloud, version) - LLM provider: (OpenAI, Anthropic, Gemini) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a897707..8f0972b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,7 @@ jobs: with: go-version-file: go.mod - # v2 is built with Go >= 1.25; v1.x binaries refuse modules targeting go 1.25+ + # v2 is built with Go >= 1.26; v1.x binaries refuse modules targeting go 1.26+ - name: Install golangci-lint run: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a4ef7b4..5bb9741 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,7 +8,7 @@ Before contributing, ensure you have: | Requirement | Version / Notes | |-------------|-----------------| -| **Go** | **Minimum `go 1.25.0`** (see the `go` line in `go.mod`; use that version or newer). | +| **Go** | **Minimum `go 1.26.0`** (see the `go` line in `go.mod`; use that version or newer). | | **Temporal server** | Required for examples, CLI, and tests — see [Temporal setup](temporal-setup.md) | | **golangci-lint** | Required for `make lint` — install **v2** with Go **≥** the `go` line in `go.mod`: `go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest` | | **gofmt** | `make lint` runs `gofmt -s` check first; run `make fmt` to apply `gofmt -s -w` project-wide | @@ -78,7 +78,7 @@ make lint This runs `go vet` and `golangci-lint`. All contributions must pass lint with zero errors. 
-**golangci-lint vs Go version:** If you see `the Go language version used to build golangci-lint is lower than the targeted Go version`, your `golangci-lint` binary is too old for this module (Go 1.25+ requires **golangci-lint v2**). Reinstall: `go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest`, ensure `$(go env GOPATH)/bin` is on `PATH` ahead of any older install, then run `golangci-lint version` — it should report **v2.x** and a Go build **≥ 1.25**. +**golangci-lint vs Go version:** If you see `the Go language version used to build golangci-lint is lower than the targeted Go version`, your `golangci-lint` binary is too old for this module (Go 1.26+ requires **golangci-lint v2**). Reinstall: `go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest`, ensure `$(go env GOPATH)/bin` is on `PATH` ahead of any older install, then run `golangci-lint version` — it should report **v2.x** and a Go build **≥ 1.26**. ### 5. Generate coverage diff --git a/Makefile b/Makefile index 2a8876c..128bf74 100644 --- a/Makefile +++ b/Makefile @@ -3,11 +3,12 @@ BIN_DIR := cmd/bin BINARY := $(BIN_DIR)/agentctl GOPATH_BIN := $(shell go env GOPATH)/bin -# Go 1.25+: coverage merges via covdata; with GOTOOLCHAIN=auto the fetched toolchain can fail on -# packages with no tests ("go: no such tool covdata"). Pin minimum toolchain to module go line. +# Coverage merges via covdata; with GOTOOLCHAIN=auto the fetched toolchain can fail on +# packages with no tests ("go: no such tool covdata"). Pin to the exact toolchain line in go.mod +# (e.g. go1.26.0) — the bare language version (go1.26) is not a valid toolchain name. 
# https://github.com/golang/go/issues/75031 -GO_MOD_VERSION := $(shell awk '/^go / { print $$2; exit }' go.mod) -GOTOOLCHAIN_COVERAGE := go$(GO_MOD_VERSION)+auto +GO_TOOLCHAIN := $(shell awk '/^toolchain / { print $$2; exit }' go.mod) +GOTOOLCHAIN_COVERAGE := $(GO_TOOLCHAIN)+auto # Embedded in agentctl -version (git describe, or "dev" outside a repo) VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo dev) LDFLAGS := -ldflags "-X main.version=$(VERSION)" @@ -68,7 +69,7 @@ spell: go run github.com/client9/misspell/cmd/misspell@latest -error . # Run linters (gofmt -s, misspell, go vet + golangci-lint). -# Use golangci-lint v2 when go.mod is 1.25+ — v1.x was built with Go 1.24 and errors on newer language targets. +# Use golangci-lint v2 when go.mod is 1.26+ — v1.x binaries error on newer language targets. lint: fmt-check spell @echo "==> Checking lints (go vet + golangci-lint)..." go vet ./... diff --git a/README.md b/README.md index b4b9a9a..fec4d63 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Agent SDK for Go - Temporal-first -**Build durable, production-grade AI agents in Go** — Temporal-backed workflows that survive crashes and deploys. OpenAI, Anthropic, Gemini, MCP, A2A, AG-UI, observability, streaming, sub-agents, and human-in-the-loop approvals. +**Build durable, production-grade AI agents in Go** — Temporal-backed workflows that survive crashes and deploys. See [Capabilities](#capabilities) for the full feature set. 
[![CI](https://github.com/agenticenv/agent-sdk-go/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/agenticenv/agent-sdk-go/actions) [![Release](https://img.shields.io/github/v/release/agenticenv/agent-sdk-go?label=Release)](https://github.com/agenticenv/agent-sdk-go/releases) @@ -10,6 +10,50 @@ [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) > **Versioning:** [Semantic versioning](https://semver.org/); published lines are **git tags** (e.g. `v0.1.2`). See the **[latest release](https://github.com/agenticenv/agent-sdk-go/releases/latest)** — the README does not pin a patch number so it stays accurate after each tag. +> +> **Note:** Independent community library — **not** affiliated with Temporal Technologies. + +## Table of Contents + +- [Overview](#overview) +- [Capabilities](#capabilities) +- [Reference apps](#reference-apps) +- [Temporal Runtime](#temporal-runtime) + - [Durable agents: survive crashes, restarts, and deploys](#durable-agents-survive-crashes-restarts-and-deploys) + - [Streaming and approvals](#streaming-and-approvals) +- [Getting Started](#getting-started) + - [Prerequisites](#prerequisites) + - [Create an agent and run](#create-an-agent-and-run) + - [Temporal connection](#temporal-connection) + - [LLM providers](#create-an-llm-client-openai-anthropic-or-gemini) + - [Stream events](#stream-events-stream) + - [Token usage](#token-usage-llmusage) + - [Tools](#tools) + - [MCP](#mcp-model-context-protocol) + - [A2A](#a2a-agent-to-agent) + - [Retrieval (RAG)](#retrieval-rag) + - [Sub-agents](#sub-agents) + - [Approvals](#approvals) + - [Timeouts and deadlines](#timeouts-and-deadlines) + - [Custom tools](#custom-tools) + - [Response format](#response-format) + - [Reasoning / extended thinking](#reasoning--extended-thinking) + - [Multiple agents](#multiple-agents) + - [Agent and worker in separate processes](#agent-and-worker-in-separate-processes) + - 
[Conversation](#conversation-message-history) + - [AG-UI Protocol](#ag-ui-protocol) +- [Observability](#observability) + - [Wire OTLP](#wire-otlp-traces--metrics--logs-in-one-block) + - [Bring your own tracer / metrics](#bring-your-own-tracer--metrics) + - [Traces](#traces-spans) + - [Metrics](#metrics) + - [Logs](#logs) +- [Configuration](#configuration) +- [Development](#development) + - [Code Coverage](#code-coverage) +- [Setup and run examples](#setup-and-run-examples) +- [Production Readiness Checklist](#production-readiness-checklist) +- [Disclaimer](#disclaimer) ## Overview @@ -19,25 +63,31 @@ `pkg/agent` exposes three entry points — `Run`, `Stream`, and `RunAsync` — each mapped directly to a Temporal workflow. Connect via `WithTemporalConfig` or `WithTemporalClient` to your cluster. See [Getting Started](#getting-started) to set up, or [Temporal Runtime](#temporal-runtime) for deeper detail on workers, queues, and streaming. -> **Note:** Independent community library — **not** affiliated with Temporal Technologies. - ## Capabilities -> Every run is a Temporal workflow — with full replay, retry, A2A protocol support, AG-UI event streaming, and built-in observability. No in-memory execution path. +> Every agent run is a Temporal workflow: durable, replay-safe, and observable. No in-memory execution path. - **LLM providers** — OpenAI, Anthropic, and Gemini out of the box; bring your own via `interfaces.LLMClient`. +- **Tools** — Register built-in or custom tools via `interfaces.Tool`; optional **parallel vs sequential** execution for multiple tool calls in one LLM round (`WithAgentToolExecutionMode`). +- **Human-in-the-loop** — Approval gates on tool calls and delegation across `Run`, `RunAsync`, and `Stream`. +- **Conversation** — Persist multi-turn message history across runs via `WithConversation`; built-in in-memory and Redis stores, or bring your own. +- **Sub-agents** — Delegate to specialist agents via `WithSubAgents`. 
+- **MCP** — Extend agent capabilities by connecting any MCP server as a tool source via `WithMCPConfig` or `WithMCPClients`. +- **A2A** — Connect remote [Agent-to-Agent](https://github.com/a2aproject/A2A) agents as tool providers via `WithA2AConfig` or `WithA2AClients`; or expose the agent itself as an A2A server via `WithA2ADefaultServer` / `WithA2AServer` and `RunA2A`. +- **Retrieval (RAG)** — Ground agent responses in external knowledge bases via a pluggable `Retriever` interface with built-in Weaviate and pgvector support; extend with your own implementation. - **Streaming** — Partial tokens and events via `Stream` and `WithStream`. - **AG-UI** — Stream events conform to the [AG-UI protocol](https://docs.ag-ui.com); agents work out of the box with any AG-UI compatible frontend such as [CopilotKit](https://copilotkit.ai). - **Reasoning** — Extended thinking / chain-of-thought where supported (Anthropic, Gemini). - **Token usage** — Track input, output, and reasoning token counts per run. -- **Tools** — Register built-in or custom tools via `interfaces.Tool`; optional **parallel vs sequential** execution for multiple tool calls in one LLM round (`WithAgentToolExecutionMode`). -- **MCP** — Extend agent capabilities by connecting any MCP server as a tool source via `WithMCPConfig` or `WithMCPClients`. -- **A2A** — Connect remote [Agent-to-Agent](https://github.com/a2aproject/A2A) agents as tool providers via `WithA2AConfig` or `WithA2AClients`; or expose the agent itself as an A2A server via `WithA2ADefaultServer` / `WithA2AServer` and `RunA2A`. -- **Human-in-the-loop** — Approval gates on tool calls and delegation across `Run`, `RunAsync`, and `Stream`. -- **Sub-agents** — Delegate to specialist agents via `WithSubAgents`. - **Scale** — Add Temporal workers to scale agent execution horizontally. - **Observability** — OpenTelemetry traces, metrics, and structured logs across all agent execution paths; export to any OTLP-compatible backend. 
+## Reference apps + +Demo applications that use **agent-sdk-go** end-to-end: + +- **[Agent Chat](https://github.com/agenticenv/agent-chat)** — Web chat demo with durable conversations; a good reference for wiring the SDK into an HTTP-backed app. + ## Temporal Runtime **Temporal** powers agents through three moving parts: a **Temporal client** that launches agent workflows, **workers** (typically `NewAgentWorker`) that poll task queues and execute workflow and activity code, and **workflow history** that makes each run durable. Workers are stateless — they replay and advance history, not hold state themselves. @@ -62,8 +112,6 @@ graph TD Child --> Mem2[Activity: save memory] ``` - - Details: [Temporal connection](#temporal-connection), [Sub-agents](#sub-agents), [Agent and worker in separate processes](#agent-and-worker-in-separate-processes). ### Durable agents: survive crashes, restarts, and deploys @@ -82,44 +130,13 @@ Stream events and approval events cross two boundaries: **Temporal** (durable wo - **Your responsibility.** Keep worker processes supervised and restarting on crash, maintain a stable connection to your Temporal cluster, and ensure stream subscribers can reconnect. - **Client reconnection and UX.** For interactive apps, if the process serving `Stream` crashes, the workflow continues in Temporal but your client loses the connection. Once a stream is lost, reconnecting to that specific run is not supported — the recommended approach is to block the user from sending a new prompt until the current one completes, then fetch the final response and display it. This keeps conversation turns sequential and avoids out-of-order state. For autonomous agents, this is a non-issue since the caller waits for completion and the workflow finishes regardless. -## AG-UI Protocol - -Agent stream events follow the [AG-UI open protocol](https://docs.ag-ui.com), making your agents natively compatible with any AG-UI frontend without extra integration work. 
- -Events like `RUN_STARTED`, `TEXT_MESSAGE_CONTENT`, `TOOL_CALL_START`, and `REASONING_MESSAGE_CONTENT` are emitted in the correct AG-UI sequence during every `Stream()` call. Serialize any event with `event.ToJSON()` and forward it over SSE, WebSocket, or Redis to a TypeScript/React frontend using the AG-UI client SDK. - -For a complete server + UI reference, see `[examples/agent_copilotkit](examples/agent_copilotkit)` (Go SSE server in `server/main.go`, Next.js + CopilotKit bridge in `ui/app/api/copilotkit/route.ts`). - -```go -ch, err := a.Stream(ctx, prompt, conversationID) -if err != nil { - return err -} -for ev := range ch { - if ev == nil { - continue - } - data, err := ev.ToJSON() - if err != nil { - continue - } - _ = data // e.g. SSE or WebSocket -} -``` - -## Reference apps - -Demo applications that use **agent-sdk-go** end-to-end. More may be added over time (e.g. web apps, autonomous agents, other integration patterns). - -- **[Agent Chat](https://github.com/agenticenv/agent-chat)** — Web chat demo with durable conversations; a good reference for wiring the SDK into an HTTP-backed app. - ## Getting Started How to **use** the SDK—agents, LLMs, Temporal connection, examples. ### Prerequisites -**agent-sdk-go** runs agents on the **[Temporal](https://temporal.io)** runtime (durable workflows and activities), so a **running Temporal server** is required. See **[Temporal setup](temporal-setup.md)**. Also **Go 1.25+** (see `go.mod`) and credentials for your LLM provider. +**agent-sdk-go** runs agents on the **[Temporal](https://temporal.io)** runtime (durable workflows and activities), so a **running Temporal server** is required. See **[Temporal setup](temporal-setup.md)**. Also **Go 1.26+** (see `go.mod`) and credentials for your LLM provider. 
**Module:** `github.com/agenticenv/agent-sdk-go` @@ -540,6 +557,71 @@ You may use **Option 1** for some remote agents and **Option 2** for others on t [examples/agent_with_a2a_config](examples/agent_with_a2a_config) and [examples/agent_with_a2a_client](examples/agent_with_a2a_client) show A2A from env (`A2A_URL`, optional bearer/headers/filter). Variables: [examples/env.sample](examples/env.sample). Running examples from `examples/`: [examples/README.md](examples/README.md). **Remote agent setup (e.g. `a2a-samples` helloworld), curl checks:** [examples/agent_with_a2a_config/README.md](examples/agent_with_a2a_config/README.md). +### Retrieval (RAG) + +Retrieval-Augmented Generation (RAG) lets agents query external knowledge bases and ground responses in up-to-date or domain-specific content — without hardcoding it into the prompt. + +Built-in retriever implementations are in `pkg/retriever/weaviate` and `pkg/retriever/pgvector`. Bring your own by implementing `interfaces.Retriever` (`Name`, `Search`). + +**Retriever modes** + +- **Agentic** (default) — LLM decides when to call the retriever as a tool, the same way it calls any other tool. Best for multi-step agents where retrieval is not always needed. +- **Prefetch** — Retrieval fires before every LLM call. Retrieved context is injected automatically. Best for always-grounded Q&A or enterprise knowledge-base scenarios. +- **Hybrid** — Both: retriever context is pre-fetched and injected (prefetch), and the LLM can also call the retriever as a tool (agentic). 
+ +Set mode with `agent.WithRetrieverMode`: + +```go +agent.WithRetrieverMode(agent.RetrieverModeAgentic) // default +agent.WithRetrieverMode(agent.RetrieverModePrefetch) +agent.WithRetrieverMode(agent.RetrieverModeHybrid) // prefetch + agentic +``` + +**Weaviate** (local Docker, zero auth for dev): + +```go +import "github.com/agenticenv/agent-sdk-go/pkg/retriever/weaviate" + +r, err := weaviate.NewRetriever("product_knowledge", + weaviate.WithHost("localhost:8080"), + weaviate.WithClassName("ProductDocs"), +) + +a, _ := agent.NewAgent( + agent.WithRetrievers(r), + agent.WithRetrieverMode(agent.RetrieverModeAgentic), + ... +) +``` + +**pgvector** (Postgres with pgvector extension; requires an embed function): + +```go +import "github.com/agenticenv/agent-sdk-go/pkg/retriever/pgvector" + +r, err := pgvector.NewRetriever("support_knowledge", embedFn, + pgvector.WithDSN("postgres://user:pass@localhost:5432/mydb"), + pgvector.WithTable("documents"), +) +``` + +**Custom retriever** — implement `interfaces.Retriever`: + +```go +type Retriever interface { + Name() string + Search(ctx context.Context, query string) ([]interfaces.Document, error) +} +``` + +**Multiple retrievers** — pass as many as needed; each must have a unique name: + +```go +agent.WithRetrievers(productRetriever, supportRetriever) +``` + +[examples/agent_with_retriever/weaviate](examples/agent_with_retriever/weaviate) · [examples/agent_with_retriever/pgvector](examples/agent_with_retriever/pgvector) + ### Sub-agents Build each specialist with `NewAgent` (its own `TaskQueue`, LLM, tools, and prompts). Register specialists on the main agent with `WithSubAgents`. Use `WithName` and `WithDescription` when you want clearer labels for routing. Use `WithMaxSubAgentDepth` only if the default nesting limit is not enough. Run `Run`, `Stream`, or `RunAsync` on the main agent. Sub-agents always run without a conversation ID—they do not inherit the main agent session history. 
If you use `DisableLocalWorker`, pair each `NewAgentWorker` with the same options as the `NewAgent` that runs that agent. @@ -905,6 +987,31 @@ a.Run(ctx, "What's my name?", convID) // agent uses history: "Alice" [examples/agent_with_conversation](examples/agent_with_conversation) +### AG-UI Protocol + +Agent stream events follow the [AG-UI open protocol](https://docs.ag-ui.com), making your agents natively compatible with any AG-UI frontend without extra integration work. + +Events like `RUN_STARTED`, `TEXT_MESSAGE_CONTENT`, `TOOL_CALL_START`, and `REASONING_MESSAGE_CONTENT` are emitted in the correct AG-UI sequence during every `Stream()` call. Serialize any event with `event.ToJSON()` and forward it over SSE, WebSocket, or Redis to a TypeScript/React frontend using the AG-UI client SDK. + +For a complete server + UI reference, see [examples/agent_copilotkit](examples/agent_copilotkit) (Go SSE server in `server/main.go`, Next.js + CopilotKit bridge in `ui/app/api/copilotkit/route.ts`). + +```go +ch, err := a.Stream(ctx, prompt, conversationID) +if err != nil { + return err +} +for ev := range ch { + if ev == nil { + continue + } + data, err := ev.ToJSON() + if err != nil { + continue + } + _ = data // e.g. 
SSE or WebSocket +} +``` + --- ## Observability diff --git a/examples/README.md b/examples/README.md index 14031a9..68e16a0 100644 --- a/examples/README.md +++ b/examples/README.md @@ -35,6 +35,7 @@ The examples use `TEMPORAL_HOST`, `TEMPORAL_PORT`, and `TEMPORAL_NAMESPACE` from | `agent_with_a2a_client` | Same env, explicit **`pkg/a2a/client`** — **[README](agent_with_a2a_client/README.md)** | | `agent_with_a2a_server` | **Inbound** A2A server — **`A2A_SERVER_*`**; **[README](agent_with_a2a_server/README.md)** (curl, **`a2a` CLI**, client example) | | `agent_with_observability` | OpenTelemetry OTLP exports — two runnable programs: **`config/`** ([`WithObservabilityConfig`](../pkg/agent/config.go)) vs **`objects/`** (pre-built [`pkg/observability`](../pkg/observability/) tracer/metrics + [`WithTracer`](../pkg/agent/config.go) / [`WithMetrics`](../pkg/agent/config.go)); shared **`setup/`** helper package — **[README](agent_with_observability/README.md)** (collector endpoint, ports **`4317`**/**`4318`**) | +| `agent_with_retriever` | Vector retrievers — **`weaviate/`** or **`pgvector/`** backends; shared **`common/`**; modes **`agentic`**, **`prefetch`**, **`hybrid`** via **`RETRIEVER_MODE`** — **[README](agent_with_retriever/README.md)** (Weaviate / Postgres setup in subfolder READMEs) | ## Setup @@ -189,6 +190,22 @@ go run ./agent_with_observability/objects/ "Say hello in one sentence" Details, env semantics, and collector notes: **[agent_with_observability/README.md](agent_with_observability/README.md)**. +### Vector retriever (`agent_with_retriever`) + +Requires a running vector store (Weaviate **or** Postgres with pgvector) plus Temporal and LLM env. Set backend-specific vars in **`env.sample`** (`WEAVIATE_*` or **`PGVECTOR_DSN`**). + +```bash +# Weaviate (run ./agent_with_retriever/weaviate/setup.sh; ./cleanup.sh when done) +go run ./agent_with_retriever/weaviate "What is the return policy?" 
+ +# pgvector (run ./agent_with_retriever/pgvector/setup.sh; ./cleanup.sh when done) +go run ./agent_with_retriever/pgvector "What is the return policy?" + +RETRIEVER_MODE=prefetch go run ./agent_with_retriever/weaviate "What are the return and shipping rules?" +``` + +Setup guides: **[agent_with_retriever/README.md](agent_with_retriever/README.md)**, **[weaviate/README.md](agent_with_retriever/weaviate/README.md)**, **[pgvector/README.md](agent_with_retriever/pgvector/README.md)**. + ## Logging Examples send conversation (user prompt, assistant response) to **stdout** and internal logs to **stderr**. By default only errors are logged. @@ -242,3 +259,6 @@ Examples send conversation (user prompt, assistant response) to **stdout** and i | `OTEL_EXPORTER_OTLP_ENDPOINT` | **Required** for **`agent_with_observability`** examples: OTLP collector **`host:port`** only (no `http://` scheme), e.g. **`localhost:4317`** (gRPC) or **`localhost:4318`** (HTTP) | | `OTLP_PROTOCOL` | Optional for **`agent_with_observability`**: **`grpc`** (default) or **`http`** — must match how the collector listens | | `OTLP_INSECURE` | Optional: set to **`true`** for plaintext export (typical for local collectors without TLS) | +| `RETRIEVER_MODE` | For **`agent_with_retriever`**: **`agentic`** (default), **`prefetch`**, or **`hybrid`** | +| `WEAVIATE_HOST`, `WEAVIATE_SCHEME`, `WEAVIATE_CLASS`, … | Weaviate backend — see **`env.sample`** and **[agent_with_retriever/weaviate/README.md](agent_with_retriever/weaviate/README.md)** | +| `PGVECTOR_DSN`, `PGVECTOR_TABLE`, `EMBEDDING_MODEL`, … | pgvector backend — **`PGVECTOR_DSN` required**; see **[agent_with_retriever/pgvector/README.md](agent_with_retriever/pgvector/README.md)** | diff --git a/examples/agent_with_retriever/README.md b/examples/agent_with_retriever/README.md new file mode 100644 index 0000000..6d4e70b --- /dev/null +++ b/examples/agent_with_retriever/README.md @@ -0,0 +1,72 @@ +# Agent with retriever (`agent_with_retriever`) + 
+Examples that wire a **vector retriever** into **agent-sdk-go**. Pick **one backend** per run. + +| Backend | Directory | Guide | +|---------|-----------|--------| +| Weaviate | [`weaviate/`](weaviate/) | [`weaviate/README.md`](weaviate/README.md) | +| PostgreSQL + pgvector | [`pgvector/`](pgvector/) | [`pgvector/README.md`](pgvector/README.md) | + +Shared sample data: [`common/sample-documents.json`](common/sample-documents.json). + +## Prerequisites + +- **Temporal** — [`temporal-setup.md`](../../temporal-setup.md) +- **LLM** — `LLM_APIKEY`, `LLM_MODEL` in `examples/.env` ([`env.sample`](../env.sample)) +- **Vector store** — set up via `./setup.sh` in the backend folder you choose + +## Quick start + +```bash +cd examples +cp env.sample .env +# Edit .env: LLM keys and backend vars (see env.sample) + +# Weaviate +cd agent_with_retriever/weaviate && ./setup.sh && cd ../.. +go run ./agent_with_retriever/weaviate "What is the return policy?" + +# pgvector +cd agent_with_retriever/pgvector && ./setup.sh && cd ../.. +go run ./agent_with_retriever/pgvector "What is the return policy?" +``` + +Cleanup: `./cleanup.sh` in the backend folder when done. + +## Retriever modes + +Set `RETRIEVER_MODE` in `.env` (default `agentic`): + +| Mode | Behavior | +|------|----------| +| `agentic` | Retriever exposed as a tool; LLM decides when to search | +| `prefetch` | Search runs once before the first LLM call; context injected into system prompt | +| `hybrid` | Prefetch and retriever tools | + +```bash +RETRIEVER_MODE=prefetch go run ./agent_with_retriever/weaviate "What is the return policy?" 
+``` + +## Troubleshooting + +| Issue | Where to look | +|-------|----------------| +| Weaviate setup, search, vectorizer | [`weaviate/README.md`](weaviate/README.md#troubleshooting) | +| pgvector setup, embeddings, `minScore` | [`pgvector/README.md`](pgvector/README.md#troubleshooting) | + +**Common checks (all examples):** + +- **Temporal** running — see [`temporal-setup.md`](../../temporal-setup.md) +- **`examples/.env`** — `LLM_APIKEY`, `LLM_MODEL`, and backend vars from [`env.sample`](../env.sample) +- **Vector store up** — `./setup.sh` in `weaviate/` or `pgvector/` before `go run` +- **Retriever mode** — `RETRIEVER_MODE=agentic|prefetch|hybrid` in `.env` +- **Debug** — `LOG_LEVEL=debug go run ./agent_with_retriever/weaviate "..."` (or `./agent_with_retriever/pgvector`) + +**pgvector + Anthropic/Gemini chat:** set `EMBEDDING_APIKEY` (OpenAI) in `.env`, plus `EMBEDDING_BASEURL` if `LLM_BASEURL` points at a non-OpenAI endpoint; when these are unset the example falls back to `LLM_APIKEY` / `LLM_BASEURL`, which an Anthropic or Gemini key cannot satisfy for embeddings. + +**Clean restart a backend:** + +```bash +cd agent_with_retriever/weaviate # or pgvector +./cleanup.sh && ./setup.sh +``` diff --git a/examples/agent_with_retriever/common/config.go b/examples/agent_with_retriever/common/config.go new file mode 100644 index 0000000..4e9cee2 --- /dev/null +++ b/examples/agent_with_retriever/common/config.go @@ -0,0 +1,131 @@ +// Package common holds shared configuration and agent options for the agent_with_retriever examples. +package common + +import ( + "fmt" + "os" + "strconv" + "strings" + + "github.com/agenticenv/agent-sdk-go/pkg/agent" +) + +// Settings holds env-driven values shared by the weaviate and pgvector example entry points. +type Settings struct { + // RetrieverMode is agentic, prefetch, or hybrid (see agent.WithRetrieverMode). 
+ RetrieverMode agent.RetrieverMode + + // Weaviate + WeaviateHost string + WeaviateScheme string + WeaviateClass string + WeaviateRetrieverName string + WeaviateContentField string + WeaviateSourceField string + WeaviateTopK int + WeaviateMinScore float64 + + // PostgreSQL / pgvector + PGDSN string + PGTable string + PGContentCol string + PGSourceCol string + PGEmbeddingCol string + PGRetrieverName string + PGTopK int + PGMinScore float64 + EmbeddingModel string + EmbeddingBaseURL string + EmbeddingAPIKey string +} + +func getEnv(key, def string) string { + if v := os.Getenv(key); v != "" { + return v + } + return def +} + +func getEnvInt(key string, def int) int { + if v := os.Getenv(key); v != "" { + if i, err := strconv.Atoi(v); err == nil { + return i + } + } + return def +} + +func getEnvFloat(key string, def float64) float64 { + if v := os.Getenv(key); v != "" { + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f + } + } + return def +} + +// LoadSettings reads retriever example env vars. LLM and Temporal vars come from examples/config.LoadFromEnv. 
+func LoadSettings() (*Settings, error) { + mode, err := ParseRetrieverMode(strings.TrimSpace(getEnv("RETRIEVER_MODE", "agentic"))) + if err != nil { + return nil, err + } + s := &Settings{ + RetrieverMode: mode, + + WeaviateHost: getEnv("WEAVIATE_HOST", "localhost:8080"), + WeaviateScheme: getEnv("WEAVIATE_SCHEME", "http"), + WeaviateClass: getEnv("WEAVIATE_CLASS", "Document"), + WeaviateRetrieverName: getEnv("WEAVIATE_RETRIEVER_NAME", "weaviate-kb"), + WeaviateContentField: getEnv("WEAVIATE_CONTENT_FIELD", "content"), + WeaviateSourceField: getEnv("WEAVIATE_SOURCE_FIELD", "source"), + WeaviateTopK: getEnvInt("WEAVIATE_TOP_K", 0), + WeaviateMinScore: getEnvFloat("WEAVIATE_MIN_SCORE", 0), + + PGDSN: strings.TrimSpace(getEnv("PGVECTOR_DSN", "")), + PGTable: getEnv("PGVECTOR_TABLE", "documents"), + PGContentCol: getEnv("PGVECTOR_CONTENT_COL", "content"), + PGSourceCol: getEnv("PGVECTOR_SOURCE_COL", "source"), + PGEmbeddingCol: getEnv("PGVECTOR_EMBEDDING_COL", "embedding"), + PGRetrieverName: getEnv("PGVECTOR_RETRIEVER_NAME", "pgvector-kb"), + PGTopK: getEnvInt("PGVECTOR_TOP_K", 0), + // Example default 0.35 — sample KB often scores 0.3–0.6 per topic; 0.5 drops secondary docs on combined queries. + PGMinScore: getEnvFloat("PGVECTOR_MIN_SCORE", 0.35), + EmbeddingModel: getEnv("EMBEDDING_MODEL", "text-embedding-3-small"), + EmbeddingBaseURL: strings.TrimSpace(getEnv("EMBEDDING_BASEURL", "")), + EmbeddingAPIKey: strings.TrimSpace(getEnv("EMBEDDING_APIKEY", "")), + } + if s.EmbeddingBaseURL == "" { + s.EmbeddingBaseURL = strings.TrimSpace(getEnv("LLM_BASEURL", "https://api.openai.com/v1")) + } + if s.EmbeddingAPIKey == "" { + s.EmbeddingAPIKey = strings.TrimSpace(getEnv("LLM_APIKEY", "")) + } + return s, nil +} + +// ParseRetrieverMode maps env text to agent.RetrieverMode. 
+func ParseRetrieverMode(raw string) (agent.RetrieverMode, error) { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", "agentic": + return agent.RetrieverModeAgentic, nil + case "prefetch": + return agent.RetrieverModePrefetch, nil + case "hybrid": + return agent.RetrieverModeHybrid, nil + default: + return "", fmt.Errorf("retriever: unknown RETRIEVER_MODE %q (use agentic, prefetch, or hybrid)", raw) + } +} + +// ModeHint returns a short phrase describing how the current mode uses retrievers. +func ModeHint(mode agent.RetrieverMode) string { + switch mode { + case agent.RetrieverModePrefetch: + return "context is prefetched before the first LLM call (no retriever tools)" + case agent.RetrieverModeHybrid: + return "context is prefetched and retriever tools remain available" + default: + return "the LLM may call retriever_* tools when it needs documents" + } +} diff --git a/examples/agent_with_retriever/common/embed_openai.go b/examples/agent_with_retriever/common/embed_openai.go new file mode 100644 index 0000000..90b453f --- /dev/null +++ b/examples/agent_with_retriever/common/embed_openai.go @@ -0,0 +1,77 @@ +package common + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + pgretriever "github.com/agenticenv/agent-sdk-go/pkg/retriever/pgvector" +) + +// OpenAIEmbedFunc returns an [pgretriever.EmbedFunc] that calls an OpenAI-compatible embeddings API. 
+func OpenAIEmbedFunc(settings *Settings) (pgretriever.EmbedFunc, error) { + if settings == nil { + return nil, fmt.Errorf("embed: settings is nil") + } + if settings.EmbeddingAPIKey == "" { + return nil, fmt.Errorf("embed: EMBEDDING_APIKEY or LLM_APIKEY is required for pgvector") + } + model := strings.TrimSpace(settings.EmbeddingModel) + if model == "" { + return nil, fmt.Errorf("embed: EMBEDDING_MODEL is required") + } + base := strings.TrimRight(strings.TrimSpace(settings.EmbeddingBaseURL), "/") + client := &http.Client{Timeout: 60 * time.Second} + + return func(ctx context.Context, text string) ([]float32, error) { + body, err := json.Marshal(map[string]any{ + "input": text, + "model": model, + }) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, base+"/embeddings", bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+settings.EmbeddingAPIKey) + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + raw, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("embeddings API %s: %s", resp.Status, strings.TrimSpace(string(raw))) + } + + var parsed struct { + Data []struct { + Embedding []float64 `json:"embedding"` + } `json:"data"` + } + if err := json.Unmarshal(raw, &parsed); err != nil { + return nil, err + } + if len(parsed.Data) == 0 || len(parsed.Data[0].Embedding) == 0 { + return nil, fmt.Errorf("embeddings API returned no vectors") + } + out := make([]float32, len(parsed.Data[0].Embedding)) + for i, v := range parsed.Data[0].Embedding { + out[i] = float32(v) + } + return out, nil + }, nil +} diff --git a/examples/agent_with_retriever/common/embedding.go b/examples/agent_with_retriever/common/embedding.go new file mode 100644 index 
0000000..2fd038d --- /dev/null +++ b/examples/agent_with_retriever/common/embedding.go @@ -0,0 +1,35 @@ +package common + +import ( + "fmt" + "os" + "strings" + + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +// ValidateEmbeddingConfig ensures pgvector can call an OpenAI-compatible embeddings API. +// When LLM_PROVIDER is not openai, EMBEDDING_APIKEY (or OPENAI_APIKEY) must be set explicitly. +func ValidateEmbeddingConfig(provider interfaces.LLMProvider, settings *Settings) error { + if settings == nil { + return fmt.Errorf("settings is nil") + } + if settings.EmbeddingAPIKey == "" { + return fmt.Errorf("EMBEDDING_APIKEY or LLM_APIKEY is required for pgvector embeddings") + } + explicit := strings.TrimSpace(os.Getenv("EMBEDDING_APIKEY")) != "" || + strings.TrimSpace(os.Getenv("OPENAI_APIKEY")) != "" + if explicit { + return nil + } + switch provider { + case interfaces.LLMProviderOpenAI, "": + return nil + default: + return fmt.Errorf( + "pgvector embeddings need an OpenAI-compatible API key in EMBEDDING_APIKEY (or OPENAI_APIKEY); "+ + "LLM_PROVIDER=%s cannot use LLM_APIKEY for /embeddings", + provider, + ) + } +} diff --git a/examples/agent_with_retriever/common/opts.go b/examples/agent_with_retriever/common/opts.go new file mode 100644 index 0000000..5290c6c --- /dev/null +++ b/examples/agent_with_retriever/common/opts.go @@ -0,0 +1,46 @@ +package common + +import ( + "fmt" + + "github.com/agenticenv/agent-sdk-go/pkg/agent" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" +) + +// AgentOptions builds shared agent options: Temporal, LLM, retriever mode, and system prompt. +// backendLabel is shown in the agent name/description (e.g. "weaviate" or "pgvector"). 
+func AgentOptions( + host string, + port int, + namespace, taskQueue string, + llmClient interfaces.LLMClient, + log logger.Logger, + settings *Settings, + backendLabel string, +) []agent.Option { + mode := settings.RetrieverMode + prompt := fmt.Sprintf( + "You are a helpful assistant with access to a %s knowledge base (%s mode). "+ + "Use retrieved documents to answer questions accurately. "+ + "When in agentic or hybrid mode, call the retriever tool when you need facts from the knowledge base. "+ + "Cite sources when possible.", + backendLabel, + mode, + ) + return []agent.Option{ + agent.WithName(fmt.Sprintf("agent-with-retriever-%s", backendLabel)), + agent.WithDescription(fmt.Sprintf("Agent with %s retriever (%s)", backendLabel, mode)), + agent.WithSystemPrompt(prompt), + agent.WithTemporalConfig(&agent.TemporalConfig{ + Host: host, + Port: port, + Namespace: namespace, + TaskQueue: taskQueue, + }), + agent.WithLLMClient(llmClient), + agent.WithLogger(log), + agent.WithRetrieverMode(mode), + agent.WithToolApprovalPolicy(agent.AutoToolApprovalPolicy()), + } +} diff --git a/examples/agent_with_retriever/common/sample-documents.json b/examples/agent_with_retriever/common/sample-documents.json new file mode 100644 index 0000000..6d6cb92 --- /dev/null +++ b/examples/agent_with_retriever/common/sample-documents.json @@ -0,0 +1,26 @@ +[ + { + "content": "Standard shipping within the continental United States takes 3–5 business days after the order ships. Express shipping (1–2 business days) is available at checkout for an additional fee. Orders placed after 2 p.m. local time ship the next business day.", + "source": "kb/shipping-and-delivery" + }, + { + "content": "Most items can be returned within 30 days of delivery if they are unused and in original packaging. Refunds are issued to the original payment method within 5–7 business days after we receive and inspect the return. 
Clearance items are final sale unless defective.", + "source": "kb/returns-and-refunds" + }, + { + "content": "To reset your account password, open the sign-in page and choose Forgot password. Enter the email on your account; you will receive a link that expires in 24 hours. If you do not see the email, check spam or contact support with your order number.", + "source": "kb/account/password-reset" + }, + { + "content": "Hardware products include a one-year limited warranty covering manufacturing defects. The warranty does not cover accidental damage, water damage, or normal wear. To start a claim, open a support ticket with your serial number and a short description of the issue.", + "source": "kb/warranty/hardware" + }, + { + "content": "Pro and Enterprise plans include priority email support with a target first response within one business day. Business hours are Monday–Friday, 9 a.m.–6 p.m. Eastern Time, excluding U.S. federal holidays. Phone support is available on Enterprise plans only.", + "source": "kb/support/hours-and-sla" + }, + { + "content": "Invoices for subscription plans are emailed on the first of each month. You can download past invoices from Billing → Invoice history in the customer portal. Tax is calculated based on your billing address and local regulations.", + "source": "kb/billing/invoices" + } +] diff --git a/examples/agent_with_retriever/pgvector/README.md b/examples/agent_with_retriever/pgvector/README.md new file mode 100644 index 0000000..c33c84a --- /dev/null +++ b/examples/agent_with_retriever/pgvector/README.md @@ -0,0 +1,136 @@ +# pgvector retriever example + +This program uses [`pkg/retriever/pgvector`](../../../pkg/retriever/pgvector): queries are embedded with an **OpenAI-compatible API**, then searched in PostgreSQL with [**pgvector**](https://github.com/pgvector/pgvector). + +Parent overview: [`../README.md`](../README.md). 
+ +## Quick setup + +```bash +cd examples/agent_with_retriever/pgvector +chmod +x setup.sh cleanup.sh verify.sh +./setup.sh +``` + +Requires **Docker**, **curl**, **jq**, and an OpenAI-compatible key for embeddings (`EMBEDDING_APIKEY`, `OPENAI_APIKEY`, or `LLM_APIKEY` in `examples/.env`). + +**[`setup.sh`](setup.sh)** starts Postgres, applies [`setup.sql`](setup.sql), embeds [`../common/sample-documents.json`](../common/sample-documents.json), and prints `PGVECTOR_DSN` for `.env`. + +```bash +./cleanup.sh # when finished +``` + +## Configure `.env` + +From `examples/` (after `./setup.sh`): + +```bash +# Temporal + LLM (required) +LLM_APIKEY=sk-... +LLM_MODEL=gpt-4o + +# Postgres +PGVECTOR_DSN=postgres://postgres:secret@localhost:5432/vectordb?sslmode=disable +PGVECTOR_TABLE=documents +PGVECTOR_RETRIEVER_NAME=pgvector-kb + +# Embeddings (must match ./setup.sh) +EMBEDDING_MODEL=text-embedding-3-small +EMBEDDING_APIKEY=sk-... # required when LLM_PROVIDER is not openai +# PGVECTOR_MIN_SCORE=0.35 # example default; see env.sample + +# Optional: agentic | prefetch | hybrid +RETRIEVER_MODE=agentic +``` + +Embeddings use **OpenAI** (or `EMBEDDING_*`). Chat can use another provider (e.g. Anthropic). + +## Run the example + +```bash +cd examples +go run ./agent_with_retriever/pgvector "What is the return policy?" +go run ./agent_with_retriever/pgvector "How long does standard shipping take in the US?" + +RETRIEVER_MODE=prefetch go run ./agent_with_retriever/pgvector "What is the return policy?" + +RETRIEVER_MODE=hybrid go run ./agent_with_retriever/pgvector "What are Pro and Enterprise support hours?" +``` + +Sample prompts match the customer-support articles in [`../common/sample-documents.json`](../common/sample-documents.json) (returns, shipping, warranty, support hours, etc.). + +## Verify search (optional) + +```bash +./verify.sh "What is the return policy?" +``` + +Shows row count and similarity scores without running the agent. 
+ +## Troubleshooting + +### `no relevant documents found` + +The retriever ran but no rows passed the similarity filter. + +1. Check data and scores: + ```bash + ./verify.sh "What is the return policy?" + ``` + - **`COUNT` is 0** → run `./setup.sh` again, or fix `PGVECTOR_DSN` in `examples/.env`. + - **Rows exist but low `score`** → lower the threshold below the example default (0.35) in `examples/.env`: + ```bash + PGVECTOR_MIN_SCORE=0.25 + ``` + Re-run the example (startup line shows `minScore: 0.25`). + +2. **Embeddings key** — search uses OpenAI `/embeddings`, not your chat LLM. If `LLM_PROVIDER=anthropic` or `gemini`, set: + ```bash + EMBEDDING_APIKEY=sk-... + EMBEDDING_BASEURL=https://api.openai.com/v1 + ``` + Re-run `./setup.sh` so stored vectors use the same model as queries. + +3. **Prefetch / hybrid** — your full user message is embedded as the search query. Use a concrete question (e.g. *“What is the return policy?”*), not *“Summarize the knowledge base”*. + +### `embedding config: ... LLM_PROVIDER=anthropic` + +Set `EMBEDDING_APIKEY` (OpenAI-compatible) in `examples/.env`. `LLM_APIKEY` alone is not enough when chat uses Anthropic/Gemini. + +### `PGVECTOR_DSN is required` + +Copy `PGVECTOR_DSN` from `./setup.sh` output into `examples/.env`. + +### `dimension mismatch` or SQL errors + +`EMBEDDING_MODEL` must produce vectors that match `vector(1536)` in `setup.sql` (the default `text-embedding-3-small` has 1536 dimensions). After changing the model, run `./cleanup.sh` and `./setup.sh`, then update `.env`. + +### Connection / port errors + +```bash +./cleanup.sh +./setup.sh +docker logs pgvector +docker ps +``` + +Port **5432** already in use → stop other Postgres or set `PGVECTOR_PORT` and update `PGVECTOR_DSN`. + +### Weak or incomplete answers (prefetch) + +Only documents above `PGVECTOR_MIN_SCORE` are injected. Run `./verify.sh` with your exact prompt; lower `PGVECTOR_MIN_SCORE` if needed docs are below the threshold. + +### Debug logs + +```bash +LOG_LEVEL=debug go run ./agent_with_retriever/pgvector "What is the return policy?" 
+``` + +Look for `pgvector search done` with `docs=0` vs embed/query errors. + +### Clean reset + +```bash +./cleanup.sh && ./setup.sh +./verify.sh "What is the return policy?" +``` diff --git a/examples/agent_with_retriever/pgvector/cleanup.sh b/examples/agent_with_retriever/pgvector/cleanup.sh new file mode 100755 index 0000000..79f6feb --- /dev/null +++ b/examples/agent_with_retriever/pgvector/cleanup.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# Stop and remove the local pgvector Postgres Docker container for this example. +# +# Usage (from this directory): +# ./cleanup.sh +# +# Environment: +# PGVECTOR_CONTAINER_NAME default pgvector +set -euo pipefail + +CONTAINER_NAME="${PGVECTOR_CONTAINER_NAME:-pgvector}" + +if ! command -v docker >/dev/null 2>&1; then + echo "error: docker is required but not installed" >&2 + exit 1 +fi + +if docker ps -a --format '{{.Names}}' | grep -qx "$CONTAINER_NAME"; then + echo "Stopping and removing '${CONTAINER_NAME}'..." + docker rm -f "$CONTAINER_NAME" >/dev/null + echo "Done." +else + echo "No container named '${CONTAINER_NAME}' found." +fi diff --git a/examples/agent_with_retriever/pgvector/main.go b/examples/agent_with_retriever/pgvector/main.go new file mode 100644 index 0000000..cad1b05 --- /dev/null +++ b/examples/agent_with_retriever/pgvector/main.go @@ -0,0 +1,93 @@ +// Example agent using a PostgreSQL pgvector retriever. +// +// Run from the repository root (or examples/): +// +// go run ./examples/agent_with_retriever/pgvector "What do you know about our docs?" +// +// See ../README.md and ./README.md for Postgres/pgvector setup and env vars. 
+package main + +import ( + "context" + "fmt" + "log" + "os" + "strings" + + examplecfg "github.com/agenticenv/agent-sdk-go/examples" + "github.com/agenticenv/agent-sdk-go/examples/agent_with_retriever/common" + "github.com/agenticenv/agent-sdk-go/pkg/agent" + pgretriever "github.com/agenticenv/agent-sdk-go/pkg/retriever/pgvector" +) + +func main() { + cfg := examplecfg.LoadFromEnv() + retrieverCfg, err := common.LoadSettings() + if err != nil { + log.Fatalf("retriever config: %v", err) + } + if retrieverCfg.PGDSN == "" { + log.Fatal("PGVECTOR_DSN is required for the pgvector example; see ./README.md") + } + if err := common.ValidateEmbeddingConfig(cfg.Provider, retrieverCfg); err != nil { + log.Fatalf("embedding config: %v", err) + } + + llmClient, err := examplecfg.NewLLMClientFromConfig(cfg) + if err != nil { + log.Fatalf("failed to create LLM client: %v", err) + } + logr := examplecfg.NewLoggerFromLogConfig(cfg) + + embed, err := common.OpenAIEmbedFunc(retrieverCfg) + if err != nil { + log.Fatalf("embed func: %v", err) + } + + pOpts := []pgretriever.Option{ + pgretriever.WithDSN(retrieverCfg.PGDSN), + pgretriever.WithTable(retrieverCfg.PGTable), + pgretriever.WithContentCol(retrieverCfg.PGContentCol), + pgretriever.WithSourceCol(retrieverCfg.PGSourceCol), + pgretriever.WithEmbeddingCol(retrieverCfg.PGEmbeddingCol), + pgretriever.WithLogger(logr), + } + if retrieverCfg.PGTopK > 0 { + pOpts = append(pOpts, pgretriever.WithTopK(retrieverCfg.PGTopK)) + } + pOpts = append(pOpts, pgretriever.WithMinScore(retrieverCfg.PGMinScore)) + + retriever, err := pgretriever.NewRetriever(retrieverCfg.PGRetrieverName, embed, pOpts...) + if err != nil { + log.Fatalf("pgvector retriever: %v", err) + } + + opts := common.AgentOptions( + cfg.Host, cfg.Port, cfg.Namespace, cfg.TaskQueue, + llmClient, logr, retrieverCfg, "pgvector", + ) + opts = append(opts, agent.WithRetrievers(retriever)) + + a, err := agent.NewAgent(opts...) 
+ if err != nil { + log.Fatal(examplecfg.FormatNewAgentError("failed to create agent", err)) + } + defer a.Close() + + prompt := strings.Join(os.Args[1:], " ") + if prompt == "" { + prompt = "What is the return policy according to the knowledge base?" + } + + fmt.Printf("backend: pgvector mode: %s retriever: %s table: %s minScore: %.2f\n", + retrieverCfg.RetrieverMode, retriever.Name(), retrieverCfg.PGTable, retrieverCfg.PGMinScore) + fmt.Printf("hint: %s\n", common.ModeHint(retrieverCfg.RetrieverMode)) + fmt.Println("user:", prompt) + + result, err := a.Run(context.Background(), prompt, "") + if err != nil { + log.Printf("run failed: %v", err) + return + } + fmt.Println("assistant:", result.Content) +} diff --git a/examples/agent_with_retriever/pgvector/setup.sh b/examples/agent_with_retriever/pgvector/setup.sh new file mode 100755 index 0000000..68711e9 --- /dev/null +++ b/examples/agent_with_retriever/pgvector/setup.sh @@ -0,0 +1,223 @@ +#!/usr/bin/env bash +# One-shot pgvector setup for the agent_with_retriever/pgvector example: +# - starts PostgreSQL with pgvector in Docker (or reuses existing container) +# - waits until Postgres is ready +# - applies setup.sql (extension, table, index) +# - embeds sample-documents.json via OpenAI-compatible API and inserts rows +# +# Usage (from this directory): +# ./setup.sh +# +# Teardown: ./cleanup.sh +# +# Environment: +# OPENAI_APIKEY / LLM_APIKEY from env or examples/.env (required for embeddings) +# EMBEDDING_MODEL default text-embedding-3-small (1536 dimensions) +# EMBEDDING_BASEURL default LLM_BASEURL from .env or https://api.openai.com/v1 +# PGVECTOR_CONTAINER_NAME default pgvector +# PGVECTOR_PORT default 5432 +set -euo pipefail + +CONTAINER_NAME="${PGVECTOR_CONTAINER_NAME:-pgvector}" +PG_IMAGE="${PGVECTOR_IMAGE:-pgvector/pgvector:pg16}" +PG_PORT="${PGVECTOR_PORT:-5432}" +PG_USER="${PGVECTOR_USER:-postgres}" +PG_PASSWORD="${PGVECTOR_PASSWORD:-secret}" +PG_DB="${PGVECTOR_DB:-vectordb}" 
+PG_TABLE="${PGVECTOR_TABLE:-documents}" +EMBEDDING_MODEL="${EMBEDDING_MODEL:-text-embedding-3-small}" +READY_TIMEOUT_SEC="${PGVECTOR_READY_TIMEOUT_SEC:-120}" + +PG_DSN="postgres://${PG_USER}:${PG_PASSWORD}@localhost:${PG_PORT}/${PG_DB}?sslmode=disable" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ENV_FILE="${SCRIPT_DIR}/../../.env" +DOCS_FILE="${SCRIPT_DIR}/../common/sample-documents.json" +SQL_FILE="${SCRIPT_DIR}/setup.sql" + +read_env_value() { + local key="$1" file="$2" + [[ -f "$file" ]] || return 1 + local line + line="$(grep -E "^${key}=" "$file" | tail -1 || true)" + [[ -n "$line" ]] || return 1 + line="${line#${key}=}" + line="${line%$'\r'}" + line="${line#\"}"; line="${line%\"}" + line="${line#\'}"; line="${line%\'}" + printf '%s' "$line" +} + +require_cmd() { + if ! command -v "$1" >/dev/null 2>&1; then + echo "error: '$1' is required but not installed" >&2 + exit 1 + fi +} + +sql_escape() { + printf '%s' "$1" | sed "s/'/''/g" +} + +resolve_openai_api_key() { + if [[ -n "${OPENAI_APIKEY:-}" ]]; then + return 0 + fi + if key="$(read_env_value OPENAI_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then + export OPENAI_APIKEY="$key" + echo "Using OPENAI_APIKEY from ${ENV_FILE}" + return 0 + fi + if key="$(read_env_value EMBEDDING_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then + export OPENAI_APIKEY="$key" + echo "Using EMBEDDING_APIKEY from ${ENV_FILE}" + return 0 + fi + if key="$(read_env_value LLM_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then + export OPENAI_APIKEY="$key" + echo "Using LLM_APIKEY from ${ENV_FILE} for embeddings" + return 0 + fi + echo "error: set OPENAI_APIKEY / EMBEDDING_APIKEY / LLM_APIKEY for embedding seed data" >&2 + exit 1 +} + +resolve_embedding_base_url() { + if [[ -n "${EMBEDDING_BASEURL:-}" ]]; then + return 0 + fi + if url="$(read_env_value EMBEDDING_BASEURL "$ENV_FILE" 2>/dev/null)" && [[ -n "$url" ]]; then + export EMBEDDING_BASEURL="$url" + return 0 + fi + if 
url="$(read_env_value LLM_BASEURL "$ENV_FILE" 2>/dev/null)" && [[ -n "$url" ]]; then + export EMBEDDING_BASEURL="$url" + return 0 + fi + export EMBEDDING_BASEURL="https://api.openai.com/v1" +} + +container_running() { + docker ps --format '{{.Names}}' | grep -qx "$CONTAINER_NAME" +} + +container_exists() { + docker ps -a --format '{{.Names}}' | grep -qx "$CONTAINER_NAME" +} + +wait_for_postgres() { + local deadline=$((SECONDS + READY_TIMEOUT_SEC)) + echo "Waiting for Postgres in '${CONTAINER_NAME}' (timeout ${READY_TIMEOUT_SEC}s)..." + while (( SECONDS < deadline )); do + if docker exec "$CONTAINER_NAME" pg_isready -U "$PG_USER" -d "$PG_DB" >/dev/null 2>&1; then + echo "Postgres is ready." + return 0 + fi + sleep 2 + done + echo "error: Postgres did not become ready within ${READY_TIMEOUT_SEC}s" >&2 + echo "Check logs: docker logs ${CONTAINER_NAME}" >&2 + exit 1 +} + +start_postgres() { + if container_running; then + echo "Container '${CONTAINER_NAME}' is already running." + return 0 + fi + if container_exists; then + echo "Starting existing container '${CONTAINER_NAME}'..." + docker start "$CONTAINER_NAME" >/dev/null + return 0 + fi + + echo "Creating and starting '${CONTAINER_NAME}' (${PG_IMAGE})..." + docker run -d --name "$CONTAINER_NAME" \ + -e POSTGRES_PASSWORD="$PG_PASSWORD" \ + -e POSTGRES_DB="$PG_DB" \ + -p "${PG_PORT}:5432" \ + "$PG_IMAGE" >/dev/null +} + +apply_schema() { + if [[ ! -f "$SQL_FILE" ]]; then + echo "error: missing ${SQL_FILE}" >&2 + exit 1 + fi + echo "Applying schema from setup.sql..." 
+ docker exec -i "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -v ON_ERROR_STOP=1 < "$SQL_FILE" +} + +embed_text() { + local text="$1" + local body response + body=$(jq -n --arg input "$text" --arg model "$EMBEDDING_MODEL" '{input: $input, model: $model}') + response=$(curl -sf "${EMBEDDING_BASEURL%/}/embeddings" \ + -H "Authorization: Bearer ${OPENAI_APIKEY}" \ + -H "Content-Type: application/json" \ + -d "$body") + echo "$response" | jq -c '.data[0].embedding' +} + +seed_documents() { + if [[ ! -f "$DOCS_FILE" ]]; then + echo "error: missing ${DOCS_FILE}" >&2 + exit 1 + fi + + echo "Clearing existing rows in ${PG_TABLE}..." + docker exec "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -v ON_ERROR_STOP=1 \ + -c "TRUNCATE ${PG_TABLE} RESTART IDENTITY;" >/dev/null + + local count=0 row content source vec content_sql source_sql + while IFS= read -r row; do + content="$(echo "$row" | jq -r '.content')" + source="$(echo "$row" | jq -r '.source')" + echo "Embedding document $((count + 1)): ${source}" + vec="$(embed_text "$content")" + if [[ -z "$vec" || "$vec" == "null" ]]; then + echo "error: empty embedding for ${source}" >&2 + exit 1 + fi + content_sql="$(sql_escape "$content")" + source_sql="$(sql_escape "$source")" + docker exec "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -v ON_ERROR_STOP=1 \ + -c "INSERT INTO ${PG_TABLE} (content, source, embedding) VALUES ('${content_sql}', '${source_sql}', '${vec}'::vector);" >/dev/null + count=$((count + 1)) + done < <(jq -c '.[]' "$DOCS_FILE") + + echo "Inserted ${count} documents from sample-documents.json" +} + +require_cmd docker +require_cmd curl +require_cmd jq +resolve_openai_api_key +resolve_embedding_base_url +start_postgres +wait_for_postgres +apply_schema +seed_documents + +cat </dev/null 2>&1 || { echo "error: need $1" >&2; exit 1; } +} + +resolve_openai_api_key() { + [[ -n "${OPENAI_APIKEY:-}" ]] && return 0 + if key="$(read_env_value OPENAI_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then + 
export OPENAI_APIKEY="$key"; return 0 + fi + if key="$(read_env_value EMBEDDING_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then + export OPENAI_APIKEY="$key"; return 0 + fi + if key="$(read_env_value LLM_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then + export OPENAI_APIKEY="$key"; return 0 + fi + echo "error: set OPENAI_APIKEY or EMBEDDING_APIKEY" >&2 + exit 1 +} + +resolve_embedding_base_url() { + [[ -n "${EMBEDDING_BASEURL:-}" ]] && return 0 + if url="$(read_env_value EMBEDDING_BASEURL "$ENV_FILE" 2>/dev/null)" && [[ -n "$url" ]]; then + export EMBEDDING_BASEURL="$url"; return 0 + fi + if url="$(read_env_value LLM_BASEURL "$ENV_FILE" 2>/dev/null)" && [[ -n "$url" ]]; then + export EMBEDDING_BASEURL="$url"; return 0 + fi + export EMBEDDING_BASEURL="https://api.openai.com/v1" +} + +embed_text() { + local text="$1" body response + body=$(jq -n --arg input "$text" --arg model "${EMBEDDING_MODEL:-text-embedding-3-small}" \ + '{input: $input, model: $model}') + response=$(curl -sf "${EMBEDDING_BASEURL%/}/embeddings" \ + -H "Authorization: Bearer ${OPENAI_APIKEY}" \ + -H "Content-Type: application/json" \ + -d "$body") + echo "$response" | jq -c '.data[0].embedding' +} + +require_cmd docker +require_cmd curl +require_cmd jq +resolve_openai_api_key +resolve_embedding_base_url + +echo "=== row count ===" +docker exec "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -t -c "SELECT COUNT(*) FROM ${PG_TABLE};" + +echo "=== top matches (no min_score filter) for: ${QUERY} ===" +vec="$(embed_text "$QUERY")" +docker exec "$CONTAINER_NAME" psql -U "$PG_USER" -d "$PG_DB" -c \ + "SELECT source, LEFT(content, 60) AS preview, + ROUND((1 - (embedding <=> '${vec}'::vector))::numeric, 4) AS score + FROM ${PG_TABLE} + ORDER BY embedding <=> '${vec}'::vector + LIMIT 5;" + +echo "" +echo "If expected docs are missing, lower PGVECTOR_MIN_SCORE in examples/.env (example default 0.35)" +echo "If COUNT is 0, re-run ./setup.sh" diff --git 
a/examples/agent_with_retriever/weaviate/README.md b/examples/agent_with_retriever/weaviate/README.md new file mode 100644 index 0000000..0deb56a --- /dev/null +++ b/examples/agent_with_retriever/weaviate/README.md @@ -0,0 +1,136 @@ +# Weaviate retriever example + +This program uses [`pkg/retriever/weaviate`](../../../pkg/retriever/weaviate): Weaviate embeds queries via **nearText** (no client-side embedding). + +Parent overview: [`../README.md`](../README.md). + +## Quick setup + +```bash +cd examples/agent_with_retriever/weaviate +chmod +x setup.sh cleanup.sh +export OPENAI_APIKEY=sk-your-key # or set in examples/.env +./setup.sh +``` + +Requires **Docker**, **curl**, **jq**, and an OpenAI API key for Weaviate’s `text2vec-openai` module. + +**[`setup.sh`](setup.sh)** starts Weaviate, creates the schema, and loads [`../common/sample-documents.json`](../common/sample-documents.json). + +```bash +./cleanup.sh # when finished +``` + +## Configure `.env` + +From `examples/`: + +```bash +# Temporal + LLM (required) +LLM_APIKEY=sk-... +LLM_MODEL=gpt-4o + +# Weaviate (defaults shown) +WEAVIATE_HOST=localhost:8080 +WEAVIATE_SCHEME=http +WEAVIATE_CLASS=Document +WEAVIATE_RETRIEVER_NAME=weaviate-kb + +# Optional: agentic | prefetch | hybrid +RETRIEVER_MODE=agentic +``` + +Weaviate uses **OpenAI** inside Docker for vectors. Chat can use another provider (e.g. Anthropic). + +## Run the example + +```bash +cd examples +go run ./agent_with_retriever/weaviate "What is the return policy?" +go run ./agent_with_retriever/weaviate "How long does standard shipping take in the US?" + +RETRIEVER_MODE=prefetch go run ./agent_with_retriever/weaviate "What is the return policy?" + +RETRIEVER_MODE=hybrid go run ./agent_with_retriever/weaviate "What are Pro and Enterprise support hours?" +``` + +Prompts match articles in [`../common/sample-documents.json`](../common/sample-documents.json). 
+ +## Troubleshooting + +### `OPENAI_APIKEY` error from setup.sh + +Weaviate’s `text2vec-openai` module needs an OpenAI key in the container: + +```bash +export OPENAI_APIKEY=sk-your-key +./setup.sh +``` + +Or add `OPENAI_APIKEY` / `LLM_APIKEY` to `examples/.env` before running `./setup.sh`. + +### Connection refused on `:8080` + +Weaviate is not running or `WEAVIATE_HOST` is wrong. + +```bash +docker ps +./setup.sh +curl -s http://localhost:8080/v1/.well-known/ready +``` + +### Empty search or `no relevant documents found` + +1. Re-seed the sample KB: `./setup.sh` +2. Confirm class name matches `.env`: `WEAVIATE_CLASS=Document` +3. Optional: lower certainty — `WEAVIATE_MIN_SCORE=0.5` in `.env` (SDK default is **0.75**) +4. List objects: + ```bash + curl -s "http://localhost:8080/v1/objects?class=Document&limit=5" + ``` + +### Vectorizer / OpenAI errors in logs + +`OPENAI_APIKEY` must be set when the container starts. Fix and recreate: + +```bash +./cleanup.sh +export OPENAI_APIKEY=sk-your-key +./setup.sh +docker logs weaviate +``` + +### Port already in use (`8080` or `50051`) + +Another process or old container is using the port: + +```bash +./cleanup.sh +./setup.sh +``` + +### Answers ignore the knowledge base + +- Run `./setup.sh` and confirm objects exist (curl above). +- **Agentic mode** — LLM must call `retriever_weaviate-kb`; try **prefetch** to force retrieval: + ```bash + RETRIEVER_MODE=prefetch go run ./agent_with_retriever/weaviate "What is the return policy?" + ``` +- Check `WEAVIATE_HOST`, `WEAVIATE_CLASS`, and `WEAVIATE_SCHEME` in `.env`. + +### Prefetch / hybrid returns little context + +Prefetch searches with your **exact user message**. Use topic questions aligned with the sample KB (returns, shipping, warranty, etc.). + +### Debug logs + +```bash +LOG_LEVEL=debug go run ./agent_with_retriever/weaviate "What is the return policy?" 
+docker logs weaviate +``` + +### Clean reset + +```bash +./cleanup.sh && ./setup.sh +``` diff --git a/examples/agent_with_retriever/weaviate/cleanup.sh b/examples/agent_with_retriever/weaviate/cleanup.sh new file mode 100755 index 0000000..53d5e18 --- /dev/null +++ b/examples/agent_with_retriever/weaviate/cleanup.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# Stop and remove the local Weaviate Docker container for this example. +# +# Usage (from this directory): +# ./cleanup.sh +# +# Environment: +# WEAVIATE_CONTAINER_NAME default weaviate +set -euo pipefail + +CONTAINER_NAME="${WEAVIATE_CONTAINER_NAME:-weaviate}" + +if ! command -v docker >/dev/null 2>&1; then + echo "error: docker is required but not installed" >&2 + exit 1 +fi + +if docker ps -a --format '{{.Names}}' | grep -qx "$CONTAINER_NAME"; then + echo "Stopping and removing '${CONTAINER_NAME}'..." + docker rm -f "$CONTAINER_NAME" >/dev/null + echo "Done." +else + echo "No container named '${CONTAINER_NAME}' found." +fi diff --git a/examples/agent_with_retriever/weaviate/main.go b/examples/agent_with_retriever/weaviate/main.go new file mode 100644 index 0000000..5c948c2 --- /dev/null +++ b/examples/agent_with_retriever/weaviate/main.go @@ -0,0 +1,83 @@ +// Example agent using a Weaviate vector retriever. +// +// Run from the repository root (or examples/): +// +// go run ./examples/agent_with_retriever/weaviate "What do you know about our docs?" +// +// See ../README.md and ./README.md for Weaviate setup and env vars. 
+package main + +import ( + "context" + "fmt" + "log" + "os" + "strings" + + examplecfg "github.com/agenticenv/agent-sdk-go/examples" + "github.com/agenticenv/agent-sdk-go/examples/agent_with_retriever/common" + "github.com/agenticenv/agent-sdk-go/pkg/agent" + weaviate "github.com/agenticenv/agent-sdk-go/pkg/retriever/weaviate" +) + +func main() { + cfg := examplecfg.LoadFromEnv() + retrieverCfg, err := common.LoadSettings() + if err != nil { + log.Fatalf("retriever config: %v", err) + } + + llmClient, err := examplecfg.NewLLMClientFromConfig(cfg) + if err != nil { + log.Fatalf("failed to create LLM client: %v", err) + } + logr := examplecfg.NewLoggerFromLogConfig(cfg) + + wOpts := []weaviate.Option{ + weaviate.WithHost(retrieverCfg.WeaviateHost), + weaviate.WithScheme(retrieverCfg.WeaviateScheme), + weaviate.WithClassName(retrieverCfg.WeaviateClass), + weaviate.WithContentField(retrieverCfg.WeaviateContentField), + weaviate.WithSourceField(retrieverCfg.WeaviateSourceField), + weaviate.WithLogger(logr), + } + if retrieverCfg.WeaviateTopK > 0 { + wOpts = append(wOpts, weaviate.WithTopK(retrieverCfg.WeaviateTopK)) + } + if retrieverCfg.WeaviateMinScore > 0 { + wOpts = append(wOpts, weaviate.WithMinScore(retrieverCfg.WeaviateMinScore)) + } + + retriever, err := weaviate.NewRetriever(retrieverCfg.WeaviateRetrieverName, wOpts...) + if err != nil { + log.Fatalf("weaviate retriever: %v", err) + } + + opts := common.AgentOptions( + cfg.Host, cfg.Port, cfg.Namespace, cfg.TaskQueue, + llmClient, logr, retrieverCfg, "weaviate", + ) + opts = append(opts, agent.WithRetrievers(retriever)) + + a, err := agent.NewAgent(opts...) + if err != nil { + log.Fatal(examplecfg.FormatNewAgentError("failed to create agent", err)) + } + defer a.Close() + + prompt := strings.Join(os.Args[1:], " ") + if prompt == "" { + prompt = "What is the return policy according to the knowledge base?" 
+ } + + fmt.Printf("backend: weaviate mode: %s retriever: %s\n", retrieverCfg.RetrieverMode, retriever.Name()) + fmt.Printf("hint: %s\n", common.ModeHint(retrieverCfg.RetrieverMode)) + fmt.Println("user:", prompt) + + result, err := a.Run(context.Background(), prompt, "") + if err != nil { + log.Printf("run failed: %v", err) + return + } + fmt.Println("assistant:", result.Content) +} diff --git a/examples/agent_with_retriever/weaviate/setup.sh b/examples/agent_with_retriever/weaviate/setup.sh new file mode 100755 index 0000000..ff91d84 --- /dev/null +++ b/examples/agent_with_retriever/weaviate/setup.sh @@ -0,0 +1,164 @@ +#!/usr/bin/env bash +# One-shot Weaviate setup for the agent_with_retriever/weaviate example: +# - starts Docker (or reuses an existing weaviate container) +# - waits until the API is ready +# - creates schema and loads sample-documents.json +# +# Usage (from this directory): +# ./setup.sh +# +# Teardown: ./cleanup.sh +# +# Environment: +# OPENAI_APIKEY required for text2vec-openai (falls back to LLM_APIKEY from examples/.env) +# WEAVIATE_URL default http://localhost:8080 +# WEAVIATE_CLASS default Document +set -euo pipefail + +CONTAINER_NAME="${WEAVIATE_CONTAINER_NAME:-weaviate}" +WEAVIATE_IMAGE="${WEAVIATE_IMAGE:-cr.weaviate.io/semitechnologies/weaviate:1.27.0}" +WEAVIATE_HTTP_PORT="${WEAVIATE_HTTP_PORT:-8080}" +WEAVIATE_GRPC_PORT="${WEAVIATE_GRPC_PORT:-50051}" +WEAVIATE_URL="${WEAVIATE_URL:-http://localhost:${WEAVIATE_HTTP_PORT}}" +WEAVIATE_CLASS="${WEAVIATE_CLASS:-Document}" +READY_TIMEOUT_SEC="${WEAVIATE_READY_TIMEOUT_SEC:-120}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ENV_FILE="${SCRIPT_DIR}/../../.env" +DOCS_FILE="${SCRIPT_DIR}/../common/sample-documents.json" + +read_env_value() { + local key="$1" file="$2" + [[ -f "$file" ]] || return 1 + local line + line="$(grep -E "^${key}=" "$file" | tail -1 || true)" + [[ -n "$line" ]] || return 1 + line="${line#${key}=}" + line="${line%$'\r'}" + line="${line#\"}"; 
line="${line%\"}" + line="${line#\'}"; line="${line%\'}" + printf '%s' "$line" +} + +require_cmd() { + if ! command -v "$1" >/dev/null 2>&1; then + echo "error: '$1' is required but not installed" >&2 + exit 1 + fi +} + +resolve_openai_api_key() { + if [[ -n "${OPENAI_APIKEY:-}" ]]; then + return 0 + fi + if key="$(read_env_value OPENAI_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then + export OPENAI_APIKEY="$key" + echo "Using OPENAI_APIKEY from ${ENV_FILE}" + return 0 + fi + if key="$(read_env_value LLM_APIKEY "$ENV_FILE" 2>/dev/null)" && [[ -n "$key" ]]; then + export OPENAI_APIKEY="$key" + echo "Using LLM_APIKEY from ${ENV_FILE} for Weaviate text2vec-openai" + return 0 + fi + echo "error: set OPENAI_APIKEY (Weaviate vectorizer) or add OPENAI_APIKEY / LLM_APIKEY to ${ENV_FILE}" >&2 + exit 1 +} + +container_running() { + docker ps --format '{{.Names}}' | grep -qx "$CONTAINER_NAME" +} + +container_exists() { + docker ps -a --format '{{.Names}}' | grep -qx "$CONTAINER_NAME" +} + +wait_for_ready() { + local deadline=$((SECONDS + READY_TIMEOUT_SEC)) + echo "Waiting for Weaviate at ${WEAVIATE_URL} (timeout ${READY_TIMEOUT_SEC}s)..." + while (( SECONDS < deadline )); do + if curl -sf "${WEAVIATE_URL}/v1/.well-known/ready" >/dev/null 2>&1; then + echo "Weaviate is ready." + return 0 + fi + sleep 2 + done + echo "error: Weaviate did not become ready within ${READY_TIMEOUT_SEC}s" >&2 + echo "Check logs: docker logs ${CONTAINER_NAME}" >&2 + exit 1 +} + +start_weaviate() { + if container_running; then + echo "Container '${CONTAINER_NAME}' is already running." + return 0 + fi + if container_exists; then + echo "Starting existing container '${CONTAINER_NAME}'..." + docker start "$CONTAINER_NAME" >/dev/null + return 0 + fi + + echo "Creating and starting '${CONTAINER_NAME}' (${WEAVIATE_IMAGE})..." 
+ docker run -d --name "$CONTAINER_NAME" \ + -p "${WEAVIATE_HTTP_PORT}:8080" \ + -p "${WEAVIATE_GRPC_PORT}:50051" \ + -e AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED=true \ + -e DEFAULT_VECTORIZER_MODULE=text2vec-openai \ + -e ENABLE_MODULES=text2vec-openai \ + -e OPENAI_APIKEY="${OPENAI_APIKEY}" \ + "$WEAVIATE_IMAGE" >/dev/null +} + +seed_documents() { + if [[ ! -f "$DOCS_FILE" ]]; then + echo "error: missing ${DOCS_FILE}" >&2 + exit 1 + fi + + echo "Creating class ${WEAVIATE_CLASS} at ${WEAVIATE_URL} (ignored if it already exists)..." + curl -sf -X POST "${WEAVIATE_URL}/v1/schema" \ + -H 'Content-Type: application/json' \ + -d "{ + \"class\": \"${WEAVIATE_CLASS}\", + \"vectorizer\": \"text2vec-openai\", + \"properties\": [ + {\"name\": \"content\", \"dataType\": [\"text\"]}, + {\"name\": \"source\", \"dataType\": [\"text\"]} + ] + }" >/dev/null || true + + local count=0 row payload + while IFS= read -r row; do + payload=$(jq -n \ + --arg class "$WEAVIATE_CLASS" \ + --arg content "$(echo "$row" | jq -r '.content')" \ + --arg source "$(echo "$row" | jq -r '.source')" \ + '{class: $class, properties: {content: $content, source: $source}}') + curl -sf -X POST "${WEAVIATE_URL}/v1/objects" \ + -H 'Content-Type: application/json' \ + -d "$payload" >/dev/null + count=$((count + 1)) + done < <(jq -c '.[]' "$DOCS_FILE") + + echo "Inserted ${count} documents from sample-documents.json" +} + +require_cmd docker +require_cmd curl +require_cmd jq +resolve_openai_api_key +start_weaviate +wait_for_ready +seed_documents + +cat < ./pkg/tools @@ -19,46 +21,83 @@ require ( github.com/redis/go-redis/v9 v9.18.0 github.com/spf13/viper v1.21.0 github.com/stretchr/testify v1.11.1 + github.com/weaviate/weaviate-go-client/v5 v5.7.3 + go.opentelemetry.io/contrib/bridges/otelslog v0.18.0 go.opentelemetry.io/otel v1.43.0 + go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0 + go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.19.0 
go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 + go.opentelemetry.io/otel/log v0.19.0 go.opentelemetry.io/otel/metric v1.43.0 go.opentelemetry.io/otel/sdk v1.43.0 + go.opentelemetry.io/otel/sdk/log v0.19.0 go.opentelemetry.io/otel/sdk/metric v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 go.temporal.io/api v1.62.2 go.temporal.io/sdk v1.41.0 + go.temporal.io/sdk/contrib/opentelemetry v0.7.0 golang.org/x/oauth2 v0.35.0 google.golang.org/genai v1.51.0 google.golang.org/grpc v1.80.0 ) require ( - cloud.google.com/go v0.116.0 // indirect - cloud.google.com/go/auth v0.9.3 // indirect + cloud.google.com/go v0.123.0 // indirect + cloud.google.com/go/auth v0.18.1 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-openapi/analysis v0.24.1 // indirect + github.com/go-openapi/errors v0.22.4 // indirect + github.com/go-openapi/jsonpointer v0.22.4 // indirect + github.com/go-openapi/jsonreference v0.21.4 // indirect + github.com/go-openapi/loads v0.23.2 // indirect + github.com/go-openapi/runtime v0.29.2 // indirect + github.com/go-openapi/spec v0.22.3 // indirect + github.com/go-openapi/strfmt v0.25.0 // indirect + github.com/go-openapi/swag v0.23.0 // 
indirect + github.com/go-openapi/swag/conv v0.25.4 // indirect + github.com/go-openapi/swag/fileutils v0.25.1 // indirect + github.com/go-openapi/swag/jsonname v0.25.4 // indirect + github.com/go-openapi/swag/jsonutils v0.25.4 // indirect + github.com/go-openapi/swag/loading v0.25.4 // indirect + github.com/go-openapi/swag/mangling v0.25.1 // indirect + github.com/go-openapi/swag/stringutils v0.25.4 // indirect + github.com/go-openapi/swag/typeutils v0.25.4 // indirect + github.com/go-openapi/swag/yamlutils v0.25.4 // indirect + github.com/go-openapi/validate v0.25.1 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/jsonschema-go v0.4.2 // indirect - github.com/google/s2a-go v0.1.8 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect + github.com/google/s2a-go v0.1.9 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect + github.com/googleapis/gax-go/v2 v2.17.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.9.2 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/josharian/intern v1.0.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect + github.com/oklog/ulid v1.3.1 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/pgvector/pgvector-go v0.4.0 // indirect + github.com/pgvector/pgvector-go/pgx v0.4.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib 
v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/robfig/cron v1.2.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/segmentio/asm v1.1.3 // indirect @@ -67,24 +106,21 @@ require ( github.com/spf13/afero v1.15.0 // indirect github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect - github.com/stretchr/objx v0.5.2 // indirect + github.com/stretchr/objx v0.5.3 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/weaviate/weaviate v1.37.2 // indirect + github.com/x448/float16 v0.8.4 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect - go.opencensus.io v0.24.0 // indirect + go.mongodb.org/mongo-driver v1.17.6 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/contrib/bridges/otelslog v0.18.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.19.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect - go.opentelemetry.io/otel/log v0.19.0 // indirect - go.opentelemetry.io/otel/sdk/log v0.19.0 // indirect go.opentelemetry.io/proto/otlp v1.10.0 // indirect - go.temporal.io/sdk/contrib/opentelemetry v0.7.0 // indirect go.uber.org/atomic v1.11.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/crypto v0.49.0 // indirect @@ -93,7 +129,7 @@ require ( golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect - golang.org/x/time v0.6.0 // indirect + golang.org/x/time v0.14.0 // indirect google.golang.org/genproto/googleapis/api 
v0.0.0-20260427160629-7cedc36a6bc4 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 // indirect google.golang.org/protobuf v1.36.11 // indirect diff --git a/go.sum b/go.sum index 03197bd..b160c17 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,9 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= -cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= -cloud.google.com/go/auth v0.9.3 h1:VOEUIAADkkLtyfr3BLa3R8Ed/j6w1jTBmARx+wb5w5U= -cloud.google.com/go/auth v0.9.3/go.mod h1:7z6VY+7h3KUdRov5F1i8NDP5ZzWKYmEPO842BgCsmTk= +cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= +cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= +cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs= +cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/a2aproject/a2a-go/v2 v2.2.1 h1:NAfUoceWAStJ7FnF8TTfWuHejf9mzKgc9QmaKV1hyXw= github.com/a2aproject/a2a-go/v2 v2.2.1/go.mod h1:mkZr8y2bUgAVQsjs/5fHK7xrRlAHDybMEyxWh2tKRC8= github.com/alicebob/miniredis/v2 v2.37.0 h1:RheObYW32G1aiJIj81XVt78ZHJpHonHLHW7OLIshq68= @@ -18,24 +16,19 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= 
-github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a h1:yDWHCSQ40h88yih2JAcL6Ls/kVkSE8GFACTGVnMPruw= github.com/facebookgo/clock 
v0.0.0-20150410010913-600d898af40a/go.mod h1:7Ga40egUymuWXxAe151lTNnCv97MddSOVsjpPPkityA= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= @@ -45,74 +38,119 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/analysis v0.24.1 h1:Xp+7Yn/KOnVWYG8d+hPksOYnCYImE3TieBa7rBOesYM= +github.com/go-openapi/analysis v0.24.1/go.mod h1:dU+qxX7QGU1rl7IYhBC8bIfmWQdX4Buoea4TGtxXY84= +github.com/go-openapi/errors v0.22.4 h1:oi2K9mHTOb5DPW2Zjdzs/NIvwi2N3fARKaTJLdNabaM= +github.com/go-openapi/errors v0.22.4/go.mod h1:z9S8ASTUqx7+CP1Q8dD8ewGH/1JWFFLX/2PmAYNQLgk= +github.com/go-openapi/jsonpointer v0.22.4 h1:dZtK82WlNpVLDW2jlA1YCiVJFVqkED1MegOUy9kR5T4= +github.com/go-openapi/jsonpointer v0.22.4/go.mod h1:elX9+UgznpFhgBuaMQ7iu4lvvX1nvNsesQ3oxmYTw80= +github.com/go-openapi/jsonreference v0.21.4 h1:24qaE2y9bx/q3uRK/qN+TDwbok1NhbSmGjjySRCHtC8= +github.com/go-openapi/jsonreference v0.21.4/go.mod h1:rIENPTjDbLpzQmQWCj5kKj3ZlmEh+EFVbz3RTUh30/4= +github.com/go-openapi/loads v0.23.2 h1:rJXAcP7g1+lWyBHC7iTY+WAF0rprtM+pm8Jxv1uQJp4= +github.com/go-openapi/loads v0.23.2/go.mod h1:IEVw1GfRt/P2Pplkelxzj9BYFajiWOtY2nHZNj4UnWY= +github.com/go-openapi/runtime v0.29.2 h1:UmwSGWNmWQqKm1c2MGgXVpC2FTGwPDQeUsBMufc5Yj0= +github.com/go-openapi/runtime v0.29.2/go.mod h1:biq5kJXRJKBJxTDJXAa00DOTa/anflQPhT0/wmjuy+0= 
+github.com/go-openapi/spec v0.22.3 h1:qRSmj6Smz2rEBxMnLRBMeBWxbbOvuOoElvSvObIgwQc= +github.com/go-openapi/spec v0.22.3/go.mod h1:iIImLODL2loCh3Vnox8TY2YWYJZjMAKYyLH2Mu8lOZs= +github.com/go-openapi/strfmt v0.25.0 h1:7R0RX7mbKLa9EYCTHRcCuIPcaqlyQiWNPTXwClK0saQ= +github.com/go-openapi/strfmt v0.25.0/go.mod h1:nNXct7OzbwrMY9+5tLX4I21pzcmE6ccMGXl3jFdPfn8= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/go-openapi/swag/conv v0.25.4 h1:/Dd7p0LZXczgUcC/Ikm1+YqVzkEeCc9LnOWjfkpkfe4= +github.com/go-openapi/swag/conv v0.25.4/go.mod h1:3LXfie/lwoAv0NHoEuY1hjoFAYkvlqI/Bn5EQDD3PPU= +github.com/go-openapi/swag/fileutils v0.25.1 h1:rSRXapjQequt7kqalKXdcpIegIShhTPXx7yw0kek2uU= +github.com/go-openapi/swag/fileutils v0.25.1/go.mod h1:+NXtt5xNZZqmpIpjqcujqojGFek9/w55b3ecmOdtg8M= +github.com/go-openapi/swag/jsonname v0.25.4 h1:bZH0+MsS03MbnwBXYhuTttMOqk+5KcQ9869Vye1bNHI= +github.com/go-openapi/swag/jsonname v0.25.4/go.mod h1:GPVEk9CWVhNvWhZgrnvRA6utbAltopbKwDu8mXNUMag= +github.com/go-openapi/swag/jsonutils v0.25.4 h1:VSchfbGhD4UTf4vCdR2F4TLBdLwHyUDTd1/q4i+jGZA= +github.com/go-openapi/swag/jsonutils v0.25.4/go.mod h1:7OYGXpvVFPn4PpaSdPHJBtF0iGnbEaTk8AvBkoWnaAY= +github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4 h1:IACsSvBhiNJwlDix7wq39SS2Fh7lUOCJRmx/4SN4sVo= +github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4/go.mod h1:Mt0Ost9l3cUzVv4OEZG+WSeoHwjWLnarzMePNDAOBiM= +github.com/go-openapi/swag/loading v0.25.4 h1:jN4MvLj0X6yhCDduRsxDDw1aHe+ZWoLjW+9ZQWIKn2s= +github.com/go-openapi/swag/loading v0.25.4/go.mod h1:rpUM1ZiyEP9+mNLIQUdMiD7dCETXvkkC30z53i+ftTE= +github.com/go-openapi/swag/mangling v0.25.1 h1:XzILnLzhZPZNtmxKaz/2xIGPQsBsvmCjrJOWGNz/ync= +github.com/go-openapi/swag/mangling v0.25.1/go.mod h1:CdiMQ6pnfAgyQGSOIYnZkXvqhnnwOn997uXZMAd/7mQ= +github.com/go-openapi/swag/stringutils v0.25.4 h1:O6dU1Rd8bej4HPA3/CLPciNBBDwZj9HiEpdVsb8B5A8= 
+github.com/go-openapi/swag/stringutils v0.25.4/go.mod h1:GTsRvhJW5xM5gkgiFe0fV3PUlFm0dr8vki6/VSRaZK0= +github.com/go-openapi/swag/typeutils v0.25.4 h1:1/fbZOUN472NTc39zpa+YGHn3jzHWhv42wAJSN91wRw= +github.com/go-openapi/swag/typeutils v0.25.4/go.mod h1:Ou7g//Wx8tTLS9vG0UmzfCsjZjKhpjxayRKTHXf2pTE= +github.com/go-openapi/swag/yamlutils v0.25.4 h1:6jdaeSItEUb7ioS9lFoCZ65Cne1/RZtPBZ9A56h92Sw= +github.com/go-openapi/swag/yamlutils v0.25.4/go.mod h1:MNzq1ulQu+yd8Kl7wPOut/YHAAU/H6hL91fF+E2RFwc= +github.com/go-openapi/testify/enable/yaml/v2 v2.0.2 h1:0+Y41Pz1NkbTHz8NngxTuAXxEodtNSI1WG1c/m5Akw4= +github.com/go-openapi/testify/enable/yaml/v2 v2.0.2/go.mod h1:kme83333GCtJQHXQ8UKX3IBZu6z8T5Dvy5+CW3NLUUg= +github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls= +github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= +github.com/go-openapi/validate v0.25.1 h1:sSACUI6Jcnbo5IWqbYHgjibrhhmt3vR6lCzKZnmAgBw= +github.com/go-openapi/validate v0.25.1/go.mod h1:RMVyVFYte0gbSTaZ0N4KmTn6u/kClvAFp+mAVfS/DQc= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= -github.com/golang/groupcache 
v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 
h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= -github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= -github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= -github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= -github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= +github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao= +github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8= +github.com/googleapis/gax-go/v2 v2.17.0 h1:RksgfBpxqff0EZkDWYuz9q/uWsTVz+kf43LsZ1J6SMc= +github.com/googleapis/gax-go/v2 v2.17.0/go.mod h1:mzaqghpQp4JDh3HvADwrat+6M3MOIDp5YKHhb9PAgDY= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2 h1:sGm2vDRFUrQJO/Veii4h4zG2vvqG6uWNkBHSTqXOZk0= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.2/go.mod h1:wd1YpapPLivG6nQgbf7ZkG1hhSOXDhhn4MLTknx2aAc= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 
h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw= +github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= -github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.11 h1:0OwqZRYI2rFrjS4kvkDnqJkKHdHaRnCm68/DY4OxRzU= +github.com/klauspost/cpuid/v2 v2.2.11/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 
h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/modelcontextprotocol/go-sdk v1.4.0 h1:u0kr8lbJc1oBcawK7Df+/ajNMpIDFE41OEPxdeTLOn8= github.com/modelcontextprotocol/go-sdk v1.4.0/go.mod h1:Nxc2n+n/GdCebUaqCOhTetptS17SXXNu9IfNTaLDi1E= github.com/nexus-rpc/sdk-go v0.6.0 h1:QRgnP2zTbxEbiyWG/aXH8uSC5LV/Mg1fqb19jb4DBlo= github.com/nexus-rpc/sdk-go v0.6.0/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= +github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= +github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/openai/openai-go/v3 v3.26.0 h1:bRt6H/ozMNt/dDkN4gobnLqaEGrRGBzmbVs0xxJEnQE= github.com/openai/openai-go/v3 v3.26.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pgvector/pgvector-go v0.4.0 h1:879hQCnuix1bkfa5TQISnnK9ik4Fo+cHj2vuZSgW5v4= +github.com/pgvector/pgvector-go v0.4.0/go.mod h1:4fSXyjl1TYAIdByAql6JazKWRr2s7J0g4hcRY5cBFCk= +github.com/pgvector/pgvector-go/pgx v0.4.0 h1:wHFoQRtCksVfmrBaHoxeT8IkonmnxlvnLzz3T4EW9Y0= +github.com/pgvector/pgvector-go/pgx v0.4.0/go.mod h1:G61nQVFeCjO8sJU9SsihwGf5Ko34IOnaqXfOWe2kBpU= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod 
h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= github.com/robfig/cron v1.2.0 h1:ZjScXvvxeQ63Dbyxy76Fj3AT3Ut0aKsyd2/tl3DTMuQ= @@ -136,13 +174,10 @@ github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3A github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/objx v0.5.3 h1:jmXUvGomnU1o3W/V5h2VEradbpJDwGrzugQQvL0POH4= +github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+QYS+U0= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod 
h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= @@ -157,6 +192,12 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/weaviate/weaviate v1.37.2 h1:LgDVfNfYA01KAsHtlznuB2JoRWA5tSuoDttwe9PEg2Q= +github.com/weaviate/weaviate v1.37.2/go.mod h1:dZtzUfJmL9zs/47ADqi+d/pyAQDRwFrun/dCjKj7R5A= +github.com/weaviate/weaviate-go-client/v5 v5.7.3 h1:AB7asp3Nv3QMh34EtJarpezQQ600p6TJVrdRhmisrMI= +github.com/weaviate/weaviate-go-client/v5 v5.7.3/go.mod h1:Q6lG7oDiKZLcE0UPiPkJ2+V7bOejSOpgJV3kZari1MY= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -166,12 +207,14 @@ github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= -go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= -go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.mongodb.org/mongo-driver v1.17.6 h1:87JUG1wZfWsr6rIz3ZmpH90rL5tea7O3IHuSwHUpsss= +go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ= 
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/bridges/otelslog v0.18.0 h1:hhPGP3zvvy1xWT9RTy970wlniSxFttBIsAK1gvMguJM= go.opentelemetry.io/contrib/bridges/otelslog v0.18.0/go.mod h1:twJF7inoMza6kxMcF8JOdL3mPmtOZu7GEr34CUNE6Dg= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.19.0 h1:Dn8rkudDzY6KV9dr/D/bTUuWgqDf9xe0rr4G2elrn0Y= @@ -196,6 +239,8 @@ go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= go.opentelemetry.io/otel/sdk/log v0.19.0 h1:scYVLqT22D2gqXItnWiocLUKGH9yvkkeql5dBDiXyko= go.opentelemetry.io/otel/sdk/log v0.19.0/go.mod h1:vFBowwXGLlW9AvpuF7bMgnNI95LiW10szrOdvzBHlAg= +go.opentelemetry.io/otel/sdk/log/logtest v0.19.0 h1:BEbF7ZBB6qQloV/Ub1+3NQoOUnVtcGkU3XX4Ws3GQfk= +go.opentelemetry.io/otel/sdk/log/logtest v0.19.0/go.mod h1:Lua81/3yM0wOmoHTokLj9y9ADeA02v1naRrVrkAZuKk= go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= @@ -219,39 +264,26 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net 
v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -265,13 +297,9 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.35.0 
h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= -golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= -golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= @@ -284,33 +312,14 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genai v1.51.0 h1:IZGuUqgfx40INv3hLFGCbOSGp0qFqm7LVmDghzNIYqg= google.golang.org/genai v1.51.0/go.mod 
h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4 h1:yOzSCGPx+cp5VO7IxvZ9SBFF7j1tZVcNtlHR2iYKtVo= google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:Q9HWtNeE7tM9npdIsEvqXj1QJIvVoeAV3rtXtS715Cw= google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 h1:tEkOQcXgF6dH1G+MVKZrfpYvozGrzb91k6ha7jireSM= google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= 
-google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -321,5 +330,3 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 3ba51fb..e1280ea 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -83,10 +83,19 @@ type AgentSpec struct { // stay stable so callers do not depend on a single flat blob that might be reshaped later. // Temporal-backed runtimes typically use worker-local configuration for activities; this is a snapshot. 
type AgentExecution struct { - LLM AgentLLM - Tools AgentTools - Session AgentSession - Limits AgentLimits + LLM AgentLLM + Tools AgentTools + Retrievers AgentRetrievers + Session AgentSession + Limits AgentLimits +} + +// AgentRetrievers holds the retriever instances and mode for prefetch and hybrid RAG. +type AgentRetrievers struct { + // Retrievers is the list of retriever instances registered with the agent. + Retrievers []interfaces.Retriever + // Mode is the retriever mode (agentic, prefetch, hybrid). + Mode types.RetrieverMode } // LLMSampling is the runtime package name for per-run sampling options. diff --git a/internal/runtime/temporal/agent_workflow.go b/internal/runtime/temporal/agent_workflow.go index ea13335..bfa1e34 100644 --- a/internal/runtime/temporal/agent_workflow.go +++ b/internal/runtime/temporal/agent_workflow.go @@ -36,6 +36,9 @@ var ( agentToolExecuteActivityTaskTimeout time.Duration = 30 * time.Minute agentToolExecuteActivityMaxAttempts int32 = 3 + agentRetrieverActivityTaskTimeout time.Duration = 5 * time.Minute + agentRetrieverActivityMaxAttempts int32 = 3 + sendEventActivityTaskTimeout time.Duration = 15 * time.Second sendEventActivityMaxAttempts int32 = 1 @@ -134,9 +137,23 @@ type AgentWorkflowState struct { Messages []interfaces.Message `json:"messages"` } +// AgentRetrieverInput is the input to AgentRetrieverActivity. +type AgentRetrieverInput struct { + AgentFingerprint string `json:"agent_fingerprint,omitempty"` + UserPrompt string `json:"user_prompt"` +} + +// AgentRetrieverResult is the return value of AgentRetrieverActivity. +// RetrieverContext is the combined, formatted document context from all retrievers; empty when no +// documents were found. It is injected into the system prompt by AgentLLMActivity and AgentLLMStreamActivity. +type AgentRetrieverResult struct { + RetrieverContext string `json:"retriever_context,omitempty"` +} + // AgentLLMInput is the input to AgentLLMActivity and AgentLLMStreamActivity. 
// When ConversationID is set, the activity loads history from the store. MessageID is the assistant text id // for TEXT_MESSAGE_* (and stream ordering with REASONING_*); the workflow sets it each turn. +// RetrieverContext is the pre-fetched RAG context from AgentRetrieverActivity (prefetch / hybrid modes). type AgentLLMInput struct { AgentName string `json:"agent_name,omitempty"` ConversationID string `json:"conversation_id,omitempty"` @@ -147,6 +164,7 @@ type AgentLLMInput struct { EventWorkflowID string `json:"event_workflow_id,omitempty"` EventTaskQueue string `json:"event_task_queue,omitempty"` LocalChannelName string `json:"local_channel_name,omitempty"` + RetrieverContext string `json:"retriever_context,omitempty"` } // AgentLLMResult is the return value of AgentLLMActivity. Workflow uses it to decide: return content or execute tools. @@ -284,6 +302,11 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl StartToCloseTimeout: conversationActivityTaskTimeout, RetryPolicy: retryPolicy(conversationActivityMaxAttempts), }) + retrieverActCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ + ActivityID: "AgentRetrieverActivity_" + activityIDSuffix, + StartToCloseTimeout: agentRetrieverActivityTaskTimeout, + RetryPolicy: retryPolicy(agentRetrieverActivityMaxAttempts), + }) var streamingUnavailable bool // emitAgentEvent must use wfCtx (the coroutine that calls Get) for ExecuteActivity().Get — not the root @@ -347,6 +370,29 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl messages := input.State.Messages + // Pre-fetch retrieval context once before the first LLM call (prefetch and hybrid modes). + // The resulting retrieverContext is forwarded to every AgentLLMInput in the run so the LLM always + // sees the retrieved documents in its system prompt, regardless of the number of iterations. 
+ retrieverContext := "" + retrieverMode := rt.AgentExecution.Retrievers.Mode + if (retrieverMode == types.RetrieverModePrefetch || retrieverMode == types.RetrieverModeHybrid) && + len(rt.AgentExecution.Retrievers.Retrievers) > 0 { + logger.Debug("workflow: retriever prefetch started", "scope", "workflow", "retrieverMode", string(retrieverMode), "retrieverCount", len(rt.AgentExecution.Retrievers.Retrievers)) + retrieverInput := AgentRetrieverInput{ + AgentFingerprint: input.AgentFingerprint, + UserPrompt: input.UserPrompt, + } + var retrieverResult AgentRetrieverResult + if err := workflow.ExecuteActivity(retrieverActCtx, rt.AgentRetrieverActivity, retrieverInput).Get(retrieverActCtx, &retrieverResult); err != nil { + if temporal.IsCanceledError(err) { + return nil, err + } + return nil, err + } + retrieverContext = retrieverResult.RetrieverContext + logger.Debug("workflow: retriever prefetch done", "scope", "workflow", "hasContext", retrieverContext != "") + } + lastContent := "" var runUsage *interfaces.LLMUsage var llmResult AgentLLMResult @@ -363,6 +409,7 @@ func (rt *TemporalRuntime) AgentWorkflow(ctx workflow.Context, input AgentWorkfl EventWorkflowID: eventWorkflowID, EventTaskQueue: eventTaskQueue, LocalChannelName: input.LocalChannelName, + RetrieverContext: retrieverContext, } if useStreaming { @@ -860,7 +907,7 @@ func (rt *TemporalRuntime) AgentLLMStreamActivity(ctx context.Context, input Age logger.Debug("activity: LLM stream started", "scope", "activity", "runID", agentWorkflowID, "messageCount", len(messages)) - req, tools := rt.buildLLMRequest(messages, input.SkipTools) + req, tools := rt.buildLLMRequest(messages, input.SkipTools, input.RetrieverContext) emitDelta := func(ev events.AgentEvent) { rt.publishAgentEventToStream(ctx, agentName, input.LocalChannelName, input.EventWorkflowID, input.EventTaskQueue, ev) @@ -1026,11 +1073,17 @@ func (rt *TemporalRuntime) AgentLLMStreamActivity(ctx context.Context, input Age return result, nil } -// 
buildLLMRequest builds an LLMRequest from messages and skipTools. Returns the request and tools list. -func (rt *TemporalRuntime) buildLLMRequest(messages []interfaces.Message, skipTools bool) (*interfaces.LLMRequest, []interfaces.Tool) { +// buildLLMRequest builds an LLMRequest from messages, skipTools, and optional retrieverContext. +// When retrieverContext is non-empty (prefetch / hybrid mode) it is appended to the system prompt so the +// LLM sees pre-fetched documents on every call in the run. Returns the request and tools list. +func (rt *TemporalRuntime) buildLLMRequest(messages []interfaces.Message, skipTools bool, retrieverContext string) (*interfaces.LLMRequest, []interfaces.Tool) { tools := rt.AgentExecution.Tools.Tools + systemMessage := rt.AgentSpec.SystemPrompt + if retrieverContext != "" { + systemMessage = fmt.Sprintf("%s\n\nRelevant Context:\n%s", rt.AgentSpec.SystemPrompt, retrieverContext) + } req := &interfaces.LLMRequest{ - SystemMessage: rt.AgentSpec.SystemPrompt, + SystemMessage: systemMessage, ResponseFormat: rt.AgentSpec.ResponseFormat, Messages: messages, } @@ -1100,6 +1153,105 @@ func (rt *TemporalRuntime) llmResponseToResult(resp *interfaces.LLMResponse, too return result, nil } +// AgentRetrieverActivity runs all configured retrievers in parallel using input.UserPrompt as the query, +// then returns a combined document context string for injection into the LLM system prompt. +// Called only for [types.RetrieverModePrefetch] and [types.RetrieverModeHybrid]. +// Partial failures (some retrievers fail) are logged and skipped; if all retrievers fail, the activity +// returns an error so Temporal can retry per the retry policy. 
+func (rt *TemporalRuntime) AgentRetrieverActivity(ctx context.Context, input AgentRetrieverInput) (*AgentRetrieverResult, error) { + if err := rt.verifyAgentFingerprint(input.AgentFingerprint); err != nil { + return nil, err + } + + retrievers := rt.AgentExecution.Retrievers.Retrievers + if len(retrievers) == 0 { + return &AgentRetrieverResult{}, nil + } + + logger := activity.GetLogger(ctx) + logger.Debug("activity: retriever prefetch started", "scope", "activity", "retrieverCount", len(retrievers), "query", input.UserPrompt) + + type retrieverResult struct { + name string + docs []interfaces.Document + err error + } + + results := make([]retrieverResult, len(retrievers)) + var wg sync.WaitGroup + for i, r := range retrievers { + wg.Add(1) + go func(idx int, ret interfaces.Retriever) { + defer wg.Done() + name := ret.Name() + retrieverAttr := interfaces.Attribute{Key: types.MetricAttrRetriever, Value: name} + rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallStarted, retrieverAttr) + start := time.Now() + + searchCtx, sp := rt.Tracer.StartSpan(ctx, "retriever.search", + interfaces.Attribute{Key: "retriever.name", Value: name}, + interfaces.Attribute{Key: "query", Value: input.UserPrompt}, + ) + docs, err := ret.Search(searchCtx, input.UserPrompt) + latency := float64(time.Since(start).Milliseconds()) + if err != nil { + sp.RecordError(err) + sp.End() + rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallFailed, retrieverAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricRetrieverLatencyMs, latency, retrieverAttr) + } else { + sp.End() + rt.Metrics.IncrementCounter(ctx, types.MetricRetrieverCallCompleted, retrieverAttr) + rt.Metrics.RecordHistogram(ctx, types.MetricRetrieverLatencyMs, latency, retrieverAttr) + } + results[idx] = retrieverResult{name: name, docs: docs, err: err} + }(i, r) + } + wg.Wait() + + multipleRetrievers := len(retrievers) > 1 + var sb strings.Builder + failedCount := 0 + for _, res := range results { + if res.err != nil { + 
failedCount++ + logger.Error("activity: retriever search failed, skipping", "scope", "activity", "retriever", res.name, "error", res.err) + continue + } + if len(res.docs) == 0 { + continue + } + if multipleRetrievers { + fmt.Fprintf(&sb, "## %s\n", res.name) + } + sb.WriteString(formatRetrieverDocs(res.docs)) + } + + if failedCount == len(retrievers) { + return nil, fmt.Errorf("retriever prefetch: all %d retriever(s) failed", len(retrievers)) + } + if failedCount > 0 { + logger.Warn("activity: some retrievers failed, continuing with partial context", "scope", "activity", "failed", failedCount, "total", len(retrievers)) + } + + retrieverContext := strings.TrimSpace(sb.String()) + logger.Debug("activity: retriever prefetch completed", "scope", "activity", "retrieverCount", len(retrievers), "hasContext", retrieverContext != "") + return &AgentRetrieverResult{RetrieverContext: retrieverContext}, nil +} + +// formatRetrieverDocs formats a list of documents for injection into the LLM system prompt. +// Format: "[N] content\n(source: s, score: 0.XX)\n\n" for each document. +func formatRetrieverDocs(docs []interfaces.Document) string { + if len(docs) == 0 { + return "" + } + var sb strings.Builder + for i, doc := range docs { + fmt.Fprintf(&sb, types.RetrieverDocFormat, i+1, doc.Content, doc.Source, doc.Score) + } + return sb.String() +} + // AgentLLMActivity calls the LLM and returns content plus any tool calls. // When input.ConversationID is set, fetches from store and adds assistant message on completion. 
func (rt *TemporalRuntime) AgentLLMActivity(ctx context.Context, input AgentLLMInput) (*AgentLLMResult, error) { @@ -1118,7 +1270,7 @@ func (rt *TemporalRuntime) AgentLLMActivity(ctx context.Context, input AgentLLMI } logger.Debug("activity: LLM generate started", "scope", "activity", "messageCount", len(messages)) - req, tools := rt.buildLLMRequest(messages, input.SkipTools) + req, tools := rt.buildLLMRequest(messages, input.SkipTools, input.RetrieverContext) llmClient := rt.AgentExecution.LLM.Client model := llmClient.GetModel() diff --git a/internal/runtime/temporal/agent_workflow_test.go b/internal/runtime/temporal/agent_workflow_test.go index 46b3575..9f6f302 100644 --- a/internal/runtime/temporal/agent_workflow_test.go +++ b/internal/runtime/temporal/agent_workflow_test.go @@ -2,6 +2,7 @@ package temporal import ( "context" + "fmt" "testing" "github.com/golang/mock/gomock" @@ -479,6 +480,272 @@ func TestAgentWorkflow_ContinueAsNewOnHistorySizeAfterTools(t *testing.T) { require.True(t, workflow.IsContinueAsNewError(wfErr), "expected continue-as-new, got: %v", wfErr) } +// --------------------------------------------------------------------------- +// AgentRetrieverActivity tests +// --------------------------------------------------------------------------- + +func makeRetrieverRuntime(t *testing.T, retrievers []interfaces.Retriever, mode types.RetrieverMode) *TemporalRuntime { + t.Helper() + mockLLM := mocks.NewMockLLMClient(gomock.NewController(t)) + mockLLM.EXPECT().GetModel().Return("test-model").AnyTimes() + mockLLM.EXPECT().GetProvider().Return(interfaces.LLMProviderOpenAI).AnyTimes() + return &TemporalRuntime{ + TemporalRuntimeConfig: TemporalRuntimeConfig{ + AgentSpec: sdkruntime.AgentSpec{Name: "RetrieverTest"}, + AgentExecution: sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: mockLLM}, + Retrievers: sdkruntime.AgentRetrievers{ + Retrievers: retrievers, + Mode: mode, + }, + }, + logger: logger.NoopLogger(), + Tracer: 
observability.DefaultNoopTracer, + Metrics: observability.DefaultNoopMetrics, + }, + } +} + +func TestAgentRetrieverActivity_NoRetrievers(t *testing.T) { + rt := makeRetrieverRuntime(t, nil, types.RetrieverModePrefetch) + actEnv := newActivityTestEnv(t) + actEnv.RegisterActivity(rt.AgentRetrieverActivity) + + val, err := actEnv.ExecuteActivity(rt.AgentRetrieverActivity, AgentRetrieverInput{UserPrompt: "test"}) + require.NoError(t, err) + + var got AgentRetrieverResult + require.NoError(t, val.Get(&got)) + require.Empty(t, got.RetrieverContext) +} + +func TestAgentRetrieverActivity_SingleRetriever(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockR := mocks.NewMockRetriever(ctrl) + mockR.EXPECT().Name().Return("kb").AnyTimes() + mockR.EXPECT().Search(gomock.Any(), "what is Go?").Return([]interfaces.Document{ + {Content: "Go is a language", Source: "docs.go.dev", Score: 0.95}, + }, nil) + + rt := makeRetrieverRuntime(t, []interfaces.Retriever{mockR}, types.RetrieverModePrefetch) + actEnv := newActivityTestEnv(t) + actEnv.RegisterActivity(rt.AgentRetrieverActivity) + + val, err := actEnv.ExecuteActivity(rt.AgentRetrieverActivity, AgentRetrieverInput{UserPrompt: "what is Go?"}) + require.NoError(t, err) + + var got AgentRetrieverResult + require.NoError(t, val.Get(&got)) + require.Contains(t, got.RetrieverContext, "Go is a language") + require.Contains(t, got.RetrieverContext, "docs.go.dev") + require.Contains(t, got.RetrieverContext, "0.95") + // Single retriever: no section header + require.NotContains(t, got.RetrieverContext, "## kb") +} + +func TestAgentRetrieverActivity_MultipleRetrievers_SectionHeaders(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockR1 := mocks.NewMockRetriever(ctrl) + mockR1.EXPECT().Name().Return("r1").AnyTimes() + mockR1.EXPECT().Search(gomock.Any(), "q").Return([]interfaces.Document{ + {Content: "doc from r1", Source: "s1", Score: 0.9}, + }, nil) + + mockR2 := 
mocks.NewMockRetriever(ctrl) + mockR2.EXPECT().Name().Return("r2").AnyTimes() + mockR2.EXPECT().Search(gomock.Any(), "q").Return([]interfaces.Document{ + {Content: "doc from r2", Source: "s2", Score: 0.8}, + }, nil) + + rt := makeRetrieverRuntime(t, []interfaces.Retriever{mockR1, mockR2}, types.RetrieverModeHybrid) + actEnv := newActivityTestEnv(t) + actEnv.RegisterActivity(rt.AgentRetrieverActivity) + + val, err := actEnv.ExecuteActivity(rt.AgentRetrieverActivity, AgentRetrieverInput{UserPrompt: "q"}) + require.NoError(t, err) + + var got AgentRetrieverResult + require.NoError(t, val.Get(&got)) + require.Contains(t, got.RetrieverContext, "## r1") + require.Contains(t, got.RetrieverContext, "doc from r1") + require.Contains(t, got.RetrieverContext, "## r2") + require.Contains(t, got.RetrieverContext, "doc from r2") +} + +func TestAgentRetrieverActivity_PartialFailure_ContinuesWithPartialContext(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockOK := mocks.NewMockRetriever(ctrl) + mockOK.EXPECT().Name().Return("ok").AnyTimes() + mockOK.EXPECT().Search(gomock.Any(), "q").Return([]interfaces.Document{ + {Content: "good doc", Source: "src", Score: 0.88}, + }, nil) + + mockFail := mocks.NewMockRetriever(ctrl) + mockFail.EXPECT().Name().Return("bad").AnyTimes() + mockFail.EXPECT().Search(gomock.Any(), "q").Return(nil, fmt.Errorf("connection refused")) + + rt := makeRetrieverRuntime(t, []interfaces.Retriever{mockOK, mockFail}, types.RetrieverModePrefetch) + actEnv := newActivityTestEnv(t) + actEnv.RegisterActivity(rt.AgentRetrieverActivity) + + val, err := actEnv.ExecuteActivity(rt.AgentRetrieverActivity, AgentRetrieverInput{UserPrompt: "q"}) + require.NoError(t, err) + + var got AgentRetrieverResult + require.NoError(t, val.Get(&got)) + require.Contains(t, got.RetrieverContext, "good doc") + require.NotContains(t, got.RetrieverContext, "bad") +} + +func TestAgentRetrieverActivity_AllFail_ReturnsError(t *testing.T) { + ctrl := 
gomock.NewController(t) + defer ctrl.Finish() + + mockFail := mocks.NewMockRetriever(ctrl) + mockFail.EXPECT().Name().Return("bad").AnyTimes() + mockFail.EXPECT().Search(gomock.Any(), "q").Return(nil, fmt.Errorf("timeout")) + + rt := makeRetrieverRuntime(t, []interfaces.Retriever{mockFail}, types.RetrieverModePrefetch) + actEnv := newActivityTestEnv(t) + actEnv.RegisterActivity(rt.AgentRetrieverActivity) + + _, err := actEnv.ExecuteActivity(rt.AgentRetrieverActivity, AgentRetrieverInput{UserPrompt: "q"}) + require.Error(t, err) + require.Contains(t, err.Error(), "all") +} + +func TestAgentRetrieverActivity_EmptyDocs_EmptyContext(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockR := mocks.NewMockRetriever(ctrl) + mockR.EXPECT().Name().Return("kb").AnyTimes() + mockR.EXPECT().Search(gomock.Any(), "q").Return(nil, nil) + + rt := makeRetrieverRuntime(t, []interfaces.Retriever{mockR}, types.RetrieverModePrefetch) + actEnv := newActivityTestEnv(t) + actEnv.RegisterActivity(rt.AgentRetrieverActivity) + + val, err := actEnv.ExecuteActivity(rt.AgentRetrieverActivity, AgentRetrieverInput{UserPrompt: "q"}) + require.NoError(t, err) + + var got AgentRetrieverResult + require.NoError(t, val.Get(&got)) + require.Empty(t, got.RetrieverContext) +} + +// --------------------------------------------------------------------------- +// buildLLMRequest RAG context tests +// --------------------------------------------------------------------------- + +func TestBuildLLMRequest_WithRagContext_AugmentsSystemPrompt(t *testing.T) { + rt := &TemporalRuntime{ + TemporalRuntimeConfig: TemporalRuntimeConfig{ + AgentSpec: sdkruntime.AgentSpec{Name: "Test", SystemPrompt: "You are helpful."}, + AgentExecution: sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}, + }, + } + req, _ := rt.buildLLMRequest(nil, false, "doc context") + require.Contains(t, req.SystemMessage, "You are helpful.") + require.Contains(t, req.SystemMessage, "Relevant Context:") 
+ require.Contains(t, req.SystemMessage, "doc context") +} + +func TestBuildLLMRequest_NoRagContext_UnchangedSystemPrompt(t *testing.T) { + rt := &TemporalRuntime{ + TemporalRuntimeConfig: TemporalRuntimeConfig{ + AgentSpec: sdkruntime.AgentSpec{Name: "Test", SystemPrompt: "You are helpful."}, + AgentExecution: sdkruntime.AgentExecution{LLM: sdkruntime.AgentLLM{Client: stubLLM{}}}, + }, + } + req, _ := rt.buildLLMRequest(nil, false, "") + require.Equal(t, "You are helpful.", req.SystemMessage) +} + +// --------------------------------------------------------------------------- +// AgentWorkflow + prefetch mode integration +// --------------------------------------------------------------------------- + +func TestAgentWorkflow_PrefetchMode_CallsRetrieverActivityFirst(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockR := mocks.NewMockRetriever(ctrl) + mockR.EXPECT().Name().Return("kb").AnyTimes() + + var suite testsuite.WorkflowTestSuite + env := suite.NewTestWorkflowEnvironment() + rt := &TemporalRuntime{ + TemporalRuntimeConfig: TemporalRuntimeConfig{ + AgentSpec: sdkruntime.AgentSpec{Name: "PrefetchAgent", SystemPrompt: "base prompt"}, + AgentExecution: sdkruntime.AgentExecution{ + LLM: sdkruntime.AgentLLM{Client: stubLLM{}}, + Limits: sdkruntime.AgentLimits{MaxIterations: 5}, + Retrievers: sdkruntime.AgentRetrievers{ + Retrievers: []interfaces.Retriever{mockR}, + Mode: types.RetrieverModePrefetch, + }, + }, + logger: logger.NoopLogger(), + Tracer: observability.DefaultNoopTracer, + Metrics: observability.DefaultNoopMetrics, + }, + } + + env.RegisterWorkflow(rt.AgentWorkflow) + + retrieverCalled := false + env.OnActivity(rt.AgentRetrieverActivity, mock.Anything, mock.Anything).Return( + func(ctx context.Context, in AgentRetrieverInput) (*AgentRetrieverResult, error) { + retrieverCalled = true + require.Equal(t, "user query", in.UserPrompt) + return &AgentRetrieverResult{RetrieverContext: "[1] prefetched doc"}, nil + }) + + 
env.OnActivity(rt.AgentLLMActivity, mock.Anything, mock.Anything).Return( + func(ctx context.Context, in AgentLLMInput) (*AgentLLMResult, error) { + require.Contains(t, in.RetrieverContext, "prefetched doc") + return &AgentLLMResult{Content: "answer"}, nil + }) + + env.ExecuteWorkflow(rt.AgentWorkflow, AgentWorkflowInput{UserPrompt: "user query"}) + + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + require.True(t, retrieverCalled, "AgentRetrieverActivity must have been called") + + var result types.AgentRunResult + require.NoError(t, env.GetWorkflowResult(&result)) + require.Equal(t, "answer", result.Content) +} + +func TestAgentWorkflow_AgenticMode_SkipsRetrieverActivity(t *testing.T) { + var suite testsuite.WorkflowTestSuite + env := suite.NewTestWorkflowEnvironment() + rt := testRuntimeForWorkflow(t) + + env.RegisterWorkflow(rt.AgentWorkflow) + env.OnActivity(rt.AgentLLMActivity, mock.Anything, mock.Anything).Return(&AgentLLMResult{Content: "done"}, nil) + + // AgentRetrieverActivity must NOT be called when mode is agentic (default / empty) + env.OnActivity(rt.AgentRetrieverActivity, mock.Anything, mock.Anything).Return( + func(ctx context.Context, in AgentRetrieverInput) (*AgentRetrieverResult, error) { + t.Error("AgentRetrieverActivity must not be called in agentic mode") + return &AgentRetrieverResult{}, nil + }) + + env.ExecuteWorkflow(rt.AgentWorkflow, AgentWorkflowInput{UserPrompt: "hi"}) + + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) +} + func TestMergeLLMUsage(t *testing.T) { a := &interfaces.LLMUsage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15} b := &interfaces.LLMUsage{PromptTokens: 3, CompletionTokens: 7, TotalTokens: 10, CachedPromptTokens: 2, ReasoningTokens: 1} diff --git a/internal/runtime/temporal/config.go b/internal/runtime/temporal/config.go index dc56425..34f1198 100644 --- a/internal/runtime/temporal/config.go +++ 
b/internal/runtime/temporal/config.go @@ -50,6 +50,8 @@ type TemporalRuntimeConfig struct { AgentMode string // AgentToolExecutionMode is the [types.AgentToolExecutionMode] (e.g. "sequential", "parallel"); must match pkg/agent WithAgentToolExecutionMode. AgentToolExecutionMode types.AgentToolExecutionMode + // RetrieverFingerprint is from pkg/agent retrieverConfigFingerprint; must match caller temporal.ComputeAgentFingerprint inputs. + RetrieverFingerprint string // DisableLocalWorker mirrors pkg/agent [DisableLocalWorker]: when false, the client embeds a worker // so Execute/ExecuteStream skip DescribeTaskQueue poller checks. ([NewAgentWorker] never calls those methods.) DisableLocalWorker bool @@ -171,6 +173,14 @@ func WithAgentToolExecutionMode(mode types.AgentToolExecutionMode) Option { } } +// WithRetrieverFingerprint sets the retriever wiring digest (mode + retriever names). +// Must match pkg/agent [retrieverConfigFingerprint] for the same agent. +func WithRetrieverFingerprint(fp string) Option { + return func(c *TemporalRuntimeConfig) { + c.RetrieverFingerprint = fp + } +} + // WithDisableLocalWorker mirrors pkg/agent [DisableLocalWorker]. When false, the client embeds a worker // and the runtime skips DescribeTaskQueue poller checks before starting workflows. func WithDisableLocalWorker(disable bool) Option { diff --git a/internal/runtime/temporal/fingerprint.go b/internal/runtime/temporal/fingerprint.go index 07a0a4b..d929c62 100644 --- a/internal/runtime/temporal/fingerprint.go +++ b/internal/runtime/temporal/fingerprint.go @@ -51,6 +51,10 @@ type AgentFingerprintPayload struct { // AgentToolExecutionMode is the tool execution mode (e.g. sequential vs parallel); must match pkg/agent WithAgentToolExecutionMode on caller and worker. AgentToolExecutionMode string `json:"agent_tool_execution_mode"` + // RetrieverFingerprint is the pkg/agent digest of retriever mode and registered retriever names. + // Omitted when empty. 
Must match pkg/agent [retrieverConfigFingerprint] on caller and worker. + RetrieverFingerprint string `json:"retriever_fingerprint,omitempty"` + Sampling *sdkruntime.LLMSampling `json:"sampling,omitempty"` SessionSize int `json:"session_size"` @@ -89,6 +93,7 @@ func BuildAgentFingerprintPayload( observabilityFingerprint string, agentMode string, agentToolExecutionMode types.AgentToolExecutionMode, + retrieverFingerprint string, ) AgentFingerprintPayload { names := append([]string(nil), toolNames...) sort.Strings(names) @@ -111,6 +116,7 @@ func BuildAgentFingerprintPayload( ObservabilityFingerprint: observabilityFingerprint, AgentMode: mode, AgentToolExecutionMode: string(toolExecutionMode), + RetrieverFingerprint: retrieverFingerprint, Sampling: cloneLLMSampling(sampling), SessionSize: sessionSize, MaxIterations: limits.MaxIterations, @@ -187,6 +193,7 @@ func computeAgentFingerprintFromRuntimeConfig(c *TemporalRuntimeConfig) string { c.ObservabilityFingerprint, c.AgentMode, c.AgentToolExecutionMode, + c.RetrieverFingerprint, ) return ComputeAgentFingerprint(mat) } diff --git a/internal/runtime/temporal/fingerprint_test.go b/internal/runtime/temporal/fingerprint_test.go index 2ddf005..9ec144a 100644 --- a/internal/runtime/temporal/fingerprint_test.go +++ b/internal/runtime/temporal/fingerprint_test.go @@ -34,6 +34,7 @@ func TestComputeAgentFingerprint_stableAndToolOrder(t *testing.T) { "", "", "", + "", ) h1 := ComputeAgentFingerprint(m) h2 := ComputeAgentFingerprint(m) @@ -41,8 +42,8 @@ func TestComputeAgentFingerprint_stableAndToolOrder(t *testing.T) { t.Fatalf("fingerprint len=%d h1=%q h2=%q", len(h1), h1, h2) } - hA := ComputeAgentFingerprint(BuildAgentFingerprintPayload(spec, []string{"a", "b", "c"}, "auto", nil, 0, lim, "", "", "", "", "")) - hB := ComputeAgentFingerprint(BuildAgentFingerprintPayload(spec, []string{"c", "a", "b"}, "auto", nil, 0, lim, "", "", "", "", "")) + hA := ComputeAgentFingerprint(BuildAgentFingerprintPayload(spec, []string{"a", "b", 
"c"}, "auto", nil, 0, lim, "", "", "", "", "", "")) + hB := ComputeAgentFingerprint(BuildAgentFingerprintPayload(spec, []string{"c", "a", "b"}, "auto", nil, 0, lim, "", "", "", "", "", "")) if hA != hB { t.Fatalf("tool order should not matter: %q vs %q", hA, hB) } @@ -51,8 +52,8 @@ func TestComputeAgentFingerprint_stableAndToolOrder(t *testing.T) { func TestComputeAgentFingerprint_agentModeChangesDigest(t *testing.T) { spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} lim := sdkruntime.AgentLimits{MaxIterations: 3} - interactive := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "") - autonomous := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "autonomous", "") + interactive := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "", "") + autonomous := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "autonomous", "", "") if ComputeAgentFingerprint(interactive) == ComputeAgentFingerprint(autonomous) { t.Fatal("expected different digests for autonomous vs interactive") } @@ -62,8 +63,8 @@ func TestComputeAgentFingerprint_mcpFingerprintChangesDigest(t *testing.T) { spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} lim := sdkruntime.AgentLimits{MaxIterations: 3} tools := []string{"mcp_srv_echo"} - base := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "", "", "") - withMCP := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "abc123deadbeef", "", "", "", "") + base := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "", "", "", "") + withMCP := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "abc123deadbeef", "", "", "", "", "") h0 := ComputeAgentFingerprint(base) h1 := ComputeAgentFingerprint(withMCP) if h0 == h1 { @@ -75,8 +76,8 @@ func TestComputeAgentFingerprint_a2aFingerprintChangesDigest(t *testing.T) { spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} lim := 
sdkruntime.AgentLimits{MaxIterations: 3} tools := []string{"a2a_remote_echo"} - base := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "", "", "") - withA2A := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "a2afp_deadbeef", "", "", "") + base := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "", "", "", "") + withA2A := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "a2afp_deadbeef", "", "", "", "") h0 := ComputeAgentFingerprint(base) h1 := ComputeAgentFingerprint(withA2A) if h0 == h1 { @@ -84,12 +85,22 @@ func TestComputeAgentFingerprint_a2aFingerprintChangesDigest(t *testing.T) { } } +func TestComputeAgentFingerprint_retrieverFingerprintChangesDigest(t *testing.T) { + spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} + lim := sdkruntime.AgentLimits{MaxIterations: 3} + empty := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "", "") + withFP := BuildAgentFingerprintPayload(spec, nil, "auto", nil, 0, lim, "", "", "", "", "", "retriever_fp_deadbeef") + if ComputeAgentFingerprint(empty) == ComputeAgentFingerprint(withFP) { + t.Fatal("expected different digests when retriever fingerprint set") + } +} + func TestComputeAgentFingerprint_observabilityFingerprintChangesDigest(t *testing.T) { spec := sdkruntime.AgentSpec{Name: "a", SystemPrompt: "p"} lim := sdkruntime.AgentLimits{MaxIterations: 3} tools := []string{"t1"} - base := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "", "", "") - withObs := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "obs_deadbeef", "", "") + base := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "", "", "", "") + withObs := BuildAgentFingerprintPayload(spec, tools, "auto", nil, 0, lim, "", "", "obs_deadbeef", "", "", "") h0 := ComputeAgentFingerprint(base) h1 := ComputeAgentFingerprint(withObs) if h0 == h1 { @@ -192,7 +203,7 @@ func 
TestBuildAgentFingerprintPayload_responseFormatAndSampling(t *testing.T) { Reasoning: &interfaces.LLMReasoning{Effort: "low"}, } lim := sdkruntime.AgentLimits{MaxIterations: 1, Timeout: 0, ApprovalTimeout: 0} - p := BuildAgentFingerprintPayload(spec, []string{"t1"}, "p", sampling, 5, lim, "mcpfp", "", "", "", "") + p := BuildAgentFingerprintPayload(spec, []string{"t1"}, "p", sampling, 5, lim, "mcpfp", "", "", "", "", "") if p.ResponseFormat == nil || p.ResponseFormat.Type != string(interfaces.ResponseFormatJSON) { t.Fatalf("response format: %+v", p.ResponseFormat) } diff --git a/internal/runtime/temporal/runtime.go b/internal/runtime/temporal/runtime.go index 0ffeb02..8196fe2 100644 --- a/internal/runtime/temporal/runtime.go +++ b/internal/runtime/temporal/runtime.go @@ -137,6 +137,7 @@ func (rt *TemporalRuntime) Start(ctx context.Context) error { w.RegisterWorkflowWithOptions(rt.AgentWorkflow, workflow.RegisterOptions{Name: "AgentWorkflow"}) w.RegisterActivityWithOptions(rt.AgentLLMActivity, activity.RegisterOptions{Name: "AgentLLMActivity"}) w.RegisterActivityWithOptions(rt.AgentLLMStreamActivity, activity.RegisterOptions{Name: "AgentLLMStreamActivity"}) + w.RegisterActivityWithOptions(rt.AgentRetrieverActivity, activity.RegisterOptions{Name: "AgentRetrieverActivity"}) w.RegisterActivityWithOptions(rt.AgentToolAuthorizeActivity, activity.RegisterOptions{Name: "AgentToolAuthorizeActivity"}) w.RegisterActivityWithOptions(rt.AgentToolApprovalActivity, activity.RegisterOptions{Name: "AgentToolApprovalActivity"}) w.RegisterActivityWithOptions(rt.AgentToolExecuteActivity, activity.RegisterOptions{Name: "AgentToolExecuteActivity"}) diff --git a/internal/types/metrics.go b/internal/types/metrics.go index 28da9c7..7724d7d 100644 --- a/internal/types/metrics.go +++ b/internal/types/metrics.go @@ -38,8 +38,17 @@ const ( // Runtime — tool wall-clock latency. 
MetricToolLatencyMs = "agent.tool.latency_ms" + // Runtime — emitted per retriever.Search call (prefetch and hybrid modes). + MetricRetrieverCallStarted = "agent.retriever.call.started" + MetricRetrieverCallCompleted = "agent.retriever.call.completed" + MetricRetrieverCallFailed = "agent.retriever.call.failed" + + // Runtime — retriever search wall-clock latency. + MetricRetrieverLatencyMs = "agent.retriever.latency_ms" + // Attribute keys used on both metrics and spans. - MetricAttrModel = "model" - MetricAttrProvider = "provider" - MetricAttrTool = "tool" + MetricAttrModel = "model" + MetricAttrProvider = "provider" + MetricAttrTool = "tool" + MetricAttrRetriever = "retriever" ) diff --git a/internal/types/retriever.go b/internal/types/retriever.go new file mode 100644 index 0000000..8fde9a9 --- /dev/null +++ b/internal/types/retriever.go @@ -0,0 +1,30 @@ +package types + +// RetrieverToolParamQuery is the tool/JSON parameter name for the query sent to a retriever. +const RetrieverToolParamQuery = "query" + +// RetrieverDocFormat is the printf format used to render a single [interfaces.Document] for LLM context. +// Arguments: 1-based index (int), content (string), source (string), score (float64). +const RetrieverDocFormat = "[%d] %s\n(source: %s, score: %.2f)\n\n" + +// Default retriever settings. +const ( + DefaultTopK = 5 + DefaultMinScore = 0.75 + DefaultScheme = "http" + DefaultContentField = "content" + DefaultSourceField = "source" +) + +// RetrieverMode selects how registered retrievers participate in agent runs. +// String values are stable for configuration (see pkg/agent.WithRetrieverMode). +type RetrieverMode string + +const ( + // RetrieverModeAgentic is the default: the agent decides when to query retrievers (e.g. via tools). + RetrieverModeAgentic RetrieverMode = "agentic" + // RetrieverModePrefetch runs retrievers before the first LLM call and injects context up front. 
+ RetrieverModePrefetch RetrieverMode = "prefetch" + // RetrieverModeHybrid combines prefetch with agentic retrieval during the run. + RetrieverModeHybrid RetrieverMode = "hybrid" +) diff --git a/pkg/agent/a2a.go b/pkg/agent/a2a.go index dd42fdd..03b7dc1 100644 --- a/pkg/agent/a2a.go +++ b/pkg/agent/a2a.go @@ -76,35 +76,35 @@ func NewA2ATool(serverName string, spec interfaces.ToolSpec, skillSpec interface } // Name implements [interfaces.Tool]. -func (m *A2ATool) Name() string { - if m == nil { +func (t *A2ATool) Name() string { + if t == nil { return "" } - return a2aToolName(m.ServerName, m.Spec.Name) + return a2aToolName(t.ServerName, t.Spec.Name) } // DisplayName implements [interfaces.Tool]. -func (m *A2ATool) DisplayName() string { - if m == nil { +func (t *A2ATool) DisplayName() string { + if t == nil { return "" } - return a2aToolDisplayName(m.ServerName, m.Spec.Name) + return a2aToolDisplayName(t.ServerName, t.Spec.Name) } // Description implements [interfaces.Tool]. -func (m *A2ATool) Description() string { - if m == nil { +func (t *A2ATool) Description() string { + if t == nil { return "" } - return m.Spec.Description + return t.Spec.Description } // Parameters implements [interfaces.Tool]. Returns a default object schema when spec parameters are nil. -func (m *A2ATool) Parameters() interfaces.JSONSchema { - if m == nil || m.Spec.Parameters == nil { +func (t *A2ATool) Parameters() interfaces.JSONSchema { + if t == nil || t.Spec.Parameters == nil { return interfaces.JSONSchema{"type": "object"} } - return m.Spec.Parameters + return t.Spec.Parameters } // Execute implements [interfaces.Tool]. @@ -115,15 +115,15 @@ func (m *A2ATool) Parameters() interfaces.JSONSchema { // - Task result (async): the task is JSON-encoded and returned as a string. // Callers that need full task lifecycle management should use [interfaces.A2AClient] directly. // - Empty result (neither message nor task): an empty string is returned without error. 
-func (m *A2ATool) Execute(ctx context.Context, args map[string]any) (any, error) { - if m == nil || m.Client == nil { +func (t *A2ATool) Execute(ctx context.Context, args map[string]any) (any, error) { + if t == nil || t.Client == nil { return nil, fmt.Errorf("a2a tool: nil client") } raw, err := json.Marshal(args) if err != nil { return nil, fmt.Errorf("a2a tool: marshal args: %w", err) } - result, err := m.Client.SendMessage(ctx, interfaces.A2ASendMessageRequest{ + result, err := t.Client.SendMessage(ctx, interfaces.A2ASendMessageRequest{ Message: interfaces.A2AMessage{ Role: "user", Parts: []interfaces.A2APart{{Kind: "text", Text: string(raw)}}, diff --git a/pkg/agent/config.go b/pkg/agent/config.go index 1e19fdd..551021d 100644 --- a/pkg/agent/config.go +++ b/pkg/agent/config.go @@ -56,6 +56,15 @@ const ( AgentToolExecutionModeSequential = types.AgentToolExecutionModeSequential ) +// RetrieverMode selects how retrievers are used in a run. Aliases [types.RetrieverMode]. +type RetrieverMode = types.RetrieverMode + +const ( + RetrieverModeAgentic = types.RetrieverModeAgentic + RetrieverModePrefetch = types.RetrieverModePrefetch + RetrieverModeHybrid = types.RetrieverModeHybrid +) + // MCPServers maps a stable server key (e.g. "github", "slack") to per-server MCP settings. // Registered tool names use prefix mcp__ (see [MCPTool]). // nil or empty map means no MCP servers from configuration. 
@@ -178,7 +187,7 @@ type ObservabilityConfig struct { // WithInstanceId, WithLLMClient, WithToolApprovalPolicy, WithTools, WithToolRegistry, // WithMaxIterations, WithStream, WithLogger, WithLogLevel, WithConversation, WithConversationSize, // WithResponseFormat, WithLLMSampling, WithSubAgents, WithMaxSubAgentDepth, -// WithMCPConfig, WithMCPClients, WithA2AConfig, WithA2AClients, WithAgentMode, WithDisableFingerprintCheck, WithAgentToolExecutionMode, +// WithMCPConfig, WithMCPClients, WithA2AConfig, WithA2AClients, WithRetrievers, WithRetrieverMode, WithAgentMode, WithDisableFingerprintCheck, WithAgentToolExecutionMode, // WithObservabilityConfig, WithTracer, WithMetrics, WithLogs // // When [WithObservabilityConfig] is set and a signal is not disabled, [buildAgentConfig] replaces @@ -235,6 +244,12 @@ type agentConfig struct { a2aClients []interfaces.A2AClient a2aTools []interfaces.Tool + // Retrievers: optional vector/document backends (e.g. Weaviate) for RAG; validated at build. + // retrieverTools is filled by buildRetrieverTools for agentic/hybrid modes (see [RetrieverTool]). + retrievers []interfaces.Retriever + retrieverMode RetrieverMode + retrieverTools []interfaces.Tool + //A2A Server: optional server config; merged at build into a2aServer (see RunA2A). a2aServerConfig *A2AServerConfig @@ -499,6 +514,27 @@ func WithA2AClients(clients ...interfaces.A2AClient) Option { } } +// WithRetrievers registers vector/document retrievers (e.g. [pkg/retriever/weaviate]). +// Each entry must be non-nil. Applies to Agent and AgentWorker. +func WithRetrievers(retrievers ...interfaces.Retriever) Option { + return func(c *agentConfig) { + if len(retrievers) == 0 { + c.retrievers = nil + return + } + c.retrievers = append([]interfaces.Retriever(nil), retrievers...) + } +} + +// WithRetrieverMode sets how retrievers participate in runs. Applies to Agent and AgentWorker. 
+// When omitted, [RetrieverModeAgentic] is used: retrievers are exposed as tools and the LLM +// decides when to call them. [RetrieverModeHybrid] combines pre-fetched context with agentic +// tool access. [RetrieverModePrefetch] injects context before the first LLM call without +// exposing retriever tools. +func WithRetrieverMode(mode RetrieverMode) Option { + return func(c *agentConfig) { c.retrieverMode = mode } +} + // WithA2ADefaultServer enables the built-in A2A HTTP server with default // settings (hostname "localhost", port 9999). Use this when you want to // expose the agent as an A2A server without customising the listen address. @@ -679,6 +715,17 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { if err := c.buildA2ATools(); err != nil { return nil, err } + if err := validateRetrievers(c.retrievers); err != nil { + return nil, err + } + mode, err := validateRetrieverMode(c.retrieverMode) + if err != nil { + return nil, err + } + c.retrieverMode = mode + if err := c.buildRetrieverTools(); err != nil { + return nil, err + } if err := c.buildSubAgentTools(); err != nil { return nil, err } @@ -793,6 +840,9 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { slog.Int("subAgentToolCount", len(c.subAgentTools)), slog.Int("mcpToolCount", len(c.mcpTools)), slog.Int("a2aToolCount", len(c.a2aTools)), + slog.Int("retrieverCount", len(c.retrievers)), + slog.Int("retrieverToolCount", len(c.retrieverTools)), + slog.String("retrieverMode", string(c.retrieverMode)), slog.Bool("hasConversation", c.conversation != nil), slog.Bool("hasObservability", c.observabilityConfig != nil), slog.Bool("enabledTracer", c.tracer != nil), @@ -814,7 +864,8 @@ func buildAgentConfig(opts []Option) (*agentConfig, error) { return c, nil } -// toolsList returns WithTools or registry tools, merged MCP tools ([mcpTools]), A2A tools ([a2aTools]), then [subAgentTools] from [buildSubAgentTools]. 
+// toolsList returns WithTools or registry tools, merged MCP tools ([mcpTools]), A2A tools ([a2aTools]), +// retriever tools ([retrieverTools]), then [subAgentTools] from [buildSubAgentTools]. func (c *agentConfig) toolsList() []interfaces.Tool { var base []interfaces.Tool if c.toolRegistry != nil { @@ -834,6 +885,12 @@ func (c *agentConfig) toolsList() []interfaces.Tool { copy(merged[len(base):], c.a2aTools) base = merged } + if len(c.retrieverTools) > 0 { + merged := make([]interfaces.Tool, len(base)+len(c.retrieverTools)) + copy(merged, base) + copy(merged[len(base):], c.retrieverTools) + base = merged + } if len(c.subAgentTools) > 0 { merged := make([]interfaces.Tool, len(base)+len(c.subAgentTools)) copy(merged, base) @@ -907,7 +964,8 @@ func dfsSubAgentDepth(a *Agent, path map[*Agent]struct{}, depth, maxDepth int) e return nil } -// validateToolNames ensures tool names are unique across WithTools/registry, MCP tools, A2A tools, and [subAgentTools]. +// validateToolNames ensures tool names are unique across WithTools/registry, MCP tools, A2A tools, +// retriever tools, and [subAgentTools]. 
func (c *agentConfig) validateToolNames() error { var base []interfaces.Tool if c.toolRegistry != nil { @@ -943,6 +1001,16 @@ func (c *agentConfig) validateToolNames() error { } names[n] = struct{}{} } + for _, t := range c.retrieverTools { + if t == nil { + return fmt.Errorf("retriever tool must not be nil") + } + n := t.Name() + if _, ok := names[n]; ok { + return fmt.Errorf("duplicate tool name %q: retriever tool conflicts with an existing tool", n) + } + names[n] = struct{}{} + } for _, t := range c.subAgentTools { if t == nil { return fmt.Errorf("sub-agent tool must not be nil") @@ -987,6 +1055,10 @@ func (c *agentConfig) runtimeAgentExecution() runtime.AgentExecution { Registry: c.toolRegistry, ApprovalPolicy: c.toolApprovalPolicy, }, + Retrievers: runtime.AgentRetrievers{ + Retrievers: c.retrievers, + Mode: c.retrieverMode, + }, Session: runtime.AgentSession{ Conversation: c.conversation, ConversationSize: c.conversationSize, @@ -1078,7 +1150,8 @@ func observabilityOptions(c *agentConfig) []observability.Option { // agentConfigFingerprint hashes identity, prompts, tools, sampling, limits, approval policy, // MCP wiring digest (transports, timeouts, filters, extra MCP client names), A2A wiring digest // for outbound clients only ([WithA2AConfig] / [WithA2AClients] via [a2aConfigFingerprint]), -// and observability OTLP wiring ([observabilityConfigFingerprint] from [WithObservabilityConfig]). +// observability OTLP wiring ([observabilityConfigFingerprint] from [WithObservabilityConfig]), +// [WithRetrieverMode], and retriever names ([retrieverConfigFingerprint]). // Same inputs as temporal.NewTemporalRuntime agent fingerprint. 
// // Inbound [A2AServerConfig] from [WithA2AServer] / [WithA2ADefaultServer] (listen address, @@ -1102,6 +1175,7 @@ func (c *agentConfig) agentConfigFingerprint() string { observabilityConfigFingerprint(c.observabilityConfig), string(c.agentMode), c.agentToolExecutionMode, + retrieverConfigFingerprint(c.retrieverMode, c.retrievers), ) return temporal.ComputeAgentFingerprint(mat) } @@ -1246,6 +1320,63 @@ func (c *agentConfig) buildMCPTools() error { return nil } +// validateRetrievers checks for nil entries in [WithRetrievers]. +func validateRetrievers(retrievers []interfaces.Retriever) error { + for i, r := range retrievers { + if r == nil { + return fmt.Errorf("retriever at index %d must not be nil", i) + } + } + return nil +} + +// validateRetrieverMode applies the default [RetrieverModeAgentic] when mode is empty and +// ensures mode is one of the supported values. +func validateRetrieverMode(mode RetrieverMode) (RetrieverMode, error) { + if mode == "" { + mode = RetrieverModeAgentic + } + switch mode { + case RetrieverModeAgentic, RetrieverModePrefetch, RetrieverModeHybrid: + return mode, nil + default: + return "", fmt.Errorf("invalid retriever mode %q: use %q, %q, or %q", + mode, RetrieverModeAgentic, RetrieverModePrefetch, RetrieverModeHybrid) + } +} + +// buildRetrieverTools registers a [RetrieverTool] per [WithRetrievers] entry when mode is +// [RetrieverModeAgentic] or [RetrieverModeHybrid], and appends to [agentConfig.retrieverTools]. +// [RetrieverModePrefetch] does not expose tools (context is injected before the first LLM call). 
+func (c *agentConfig) buildRetrieverTools() error { + c.retrieverTools = nil + if c.retrieverMode != RetrieverModeAgentic && c.retrieverMode != RetrieverModeHybrid { + return nil + } + if len(c.retrievers) == 0 { + return nil + } + seen := make(map[string]struct{}, len(c.retrievers)) + tools := make([]interfaces.Tool, 0, len(c.retrievers)) + for _, r := range c.retrievers { + n := strings.TrimSpace(r.Name()) + if n == "" { + return fmt.Errorf("retriever name must not be empty") + } + if _, dup := seen[n]; dup { + return fmt.Errorf("duplicate retriever name %q", n) + } + seen[n] = struct{}{} + tool := NewRetrieverTool(r) + if tool == nil { + return fmt.Errorf("retriever %q: failed to build tool", n) + } + tools = append(tools, tool) + } + c.retrieverTools = tools + return nil +} + // validateA2AClients checks for nil clients, empty names, and duplicate [interfaces.A2AClient.Name] values. func validateA2AClients(clients []interfaces.A2AClient) error { seen := make(map[string]struct{}, len(clients)) diff --git a/pkg/agent/config_test.go b/pkg/agent/config_test.go index 3c33145..69de6e9 100644 --- a/pkg/agent/config_test.go +++ b/pkg/agent/config_test.go @@ -463,6 +463,424 @@ func (s *stubA2AClient) Close() error { return nil } var _ interfaces.A2AClient = (*stubA2AClient)(nil) +type stubRetriever struct{} + +func (stubRetriever) Name() string { return "stub" } + +func (stubRetriever) Search(context.Context, string) ([]interfaces.Document, error) { + return nil, nil +} + +var _ interfaces.Retriever = stubRetriever{} + +type namedStubRetriever string + +func (n namedStubRetriever) Name() string { return string(n) } + +func (namedStubRetriever) Search(context.Context, string) ([]interfaces.Document, error) { + return nil, nil +} + +var _ interfaces.Retriever = namedStubRetriever("") + +// --------------------------------------------------------------------------- +// Retriever config tests +// --------------------------------------------------------------------------- + 
+func TestValidateRetrieverMode(t *testing.T) { + t.Run("default", func(t *testing.T) { + mode, err := validateRetrieverMode("") + if err != nil || mode != RetrieverModeAgentic { + t.Fatalf("mode=%q err=%v", mode, err) + } + }) + t.Run("valid", func(t *testing.T) { + for _, want := range []RetrieverMode{ + RetrieverModeAgentic, + RetrieverModePrefetch, + RetrieverModeHybrid, + } { + mode, err := validateRetrieverMode(want) + if err != nil || mode != want { + t.Fatalf("want %q got %q err=%v", want, mode, err) + } + } + }) + t.Run("invalid", func(t *testing.T) { + _, err := validateRetrieverMode(RetrieverMode("bogus")) + if err == nil || !strings.Contains(err.Error(), "invalid retriever mode") { + t.Fatalf("got %v", err) + } + }) +} + +func TestValidateRetrievers(t *testing.T) { + t.Run("nil", func(t *testing.T) { + err := validateRetrievers([]interfaces.Retriever{nil}) + if err == nil || !strings.Contains(err.Error(), "nil") { + t.Fatalf("got %v", err) + } + }) + t.Run("ok", func(t *testing.T) { + if err := validateRetrievers([]interfaces.Retriever{stubRetriever{}, stubRetriever{}}); err != nil { + t.Fatalf("got %v", err) + } + }) +} + +func TestBuildRetrieverTools(t *testing.T) { + t.Run("agentic_builds_tools", func(t *testing.T) { + c := &agentConfig{ + retrieverMode: RetrieverModeAgentic, + retrievers: []interfaces.Retriever{namedStubRetriever("kb")}, + } + if err := c.buildRetrieverTools(); err != nil { + t.Fatal(err) + } + if len(c.retrieverTools) != 1 || c.retrieverTools[0].Name() != "retriever_kb" { + t.Fatalf("retrieverTools = %v", c.retrieverTools) + } + }) + t.Run("hybrid_builds_tools", func(t *testing.T) { + c := &agentConfig{ + retrieverMode: RetrieverModeHybrid, + retrievers: []interfaces.Retriever{stubRetriever{}}, + } + if err := c.buildRetrieverTools(); err != nil { + t.Fatal(err) + } + if len(c.retrieverTools) != 1 { + t.Fatalf("len = %d", len(c.retrieverTools)) + } + }) + t.Run("prefetch_skips_tools", func(t *testing.T) { + c := &agentConfig{ + 
retrieverMode: RetrieverModePrefetch, + retrievers: []interfaces.Retriever{stubRetriever{}}, + } + if err := c.buildRetrieverTools(); err != nil { + t.Fatal(err) + } + if c.retrieverTools != nil { + t.Fatalf("retrieverTools = %v, want nil", c.retrieverTools) + } + }) + t.Run("no_retrievers", func(t *testing.T) { + c := &agentConfig{retrieverMode: RetrieverModeAgentic} + if err := c.buildRetrieverTools(); err != nil { + t.Fatal(err) + } + if c.retrieverTools != nil { + t.Fatalf("retrieverTools = %v, want nil", c.retrieverTools) + } + }) + t.Run("duplicate_name", func(t *testing.T) { + c := &agentConfig{ + retrieverMode: RetrieverModeAgentic, + retrievers: []interfaces.Retriever{namedStubRetriever("x"), namedStubRetriever("x")}, + } + err := c.buildRetrieverTools() + if err == nil || !strings.Contains(err.Error(), "duplicate retriever name") { + t.Fatalf("got %v", err) + } + }) + t.Run("empty_name", func(t *testing.T) { + c := &agentConfig{ + retrieverMode: RetrieverModeAgentic, + retrievers: []interfaces.Retriever{namedStubRetriever(" ")}, + } + err := c.buildRetrieverTools() + if err == nil || !strings.Contains(err.Error(), "must not be empty") { + t.Fatalf("got %v", err) + } + }) +} + +func TestBuildAgentConfig_WithRetrievers(t *testing.T) { + r1, r2 := namedStubRetriever("kb-a"), namedStubRetriever("kb-b") + cfg, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithRetrievers(r1, r2), + }) + if err != nil { + t.Fatal(err) + } + if len(cfg.retrievers) != 2 { + t.Fatalf("retrievers len = %d", len(cfg.retrievers)) + } + if len(cfg.retrieverTools) != 2 { + t.Fatalf("retrieverTools len = %d, want 2 (default agentic mode)", len(cfg.retrieverTools)) + } +} + +func TestBuildAgentConfig_RetrieverMode_prefetchNoTools(t *testing.T) { + cfg, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + 
WithRetrievers(stubRetriever{}), + WithRetrieverMode(RetrieverModePrefetch), + }) + if err != nil { + t.Fatal(err) + } + if len(cfg.retrieverTools) != 0 { + t.Fatalf("retrieverTools len = %d, want 0 for prefetch", len(cfg.retrieverTools)) + } +} + +func TestBuildAgentConfig_RetrieverMode_agenticBuildsTools(t *testing.T) { + cfg, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithRetrievers(stubRetriever{}), + WithRetrieverMode(RetrieverModeAgentic), + }) + if err != nil { + t.Fatal(err) + } + if len(cfg.retrieverTools) != 1 || cfg.retrieverTools[0].Name() != "retriever_stub" { + t.Fatalf("retrieverTools = %v", cfg.retrieverTools) + } +} + +func TestBuildAgentConfig_AgenticNoRetrievers_NoTools(t *testing.T) { + cfg, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithRetrieverMode(RetrieverModeAgentic), + }) + if err != nil { + t.Fatal(err) + } + if len(cfg.retrieverTools) != 0 { + t.Fatalf("retrieverTools len = %d, want 0", len(cfg.retrieverTools)) + } +} + +func TestBuildAgentConfig_RetrieverMode_hybridBuildsTools(t *testing.T) { + cfg, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithRetrievers(stubRetriever{}), + WithRetrieverMode(RetrieverModeHybrid), + }) + if err != nil { + t.Fatal(err) + } + if len(cfg.retrieverTools) != 1 || cfg.retrieverTools[0].Name() != "retriever_stub" { + t.Fatalf("retrieverTools = %v", cfg.retrieverTools) + } +} + +func TestBuildAgentConfig_RetrieverDuplicateName(t *testing.T) { + _, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithRetrievers(namedStubRetriever("dup"), namedStubRetriever("dup")), + }) + if err == nil || !strings.Contains(err.Error(), "duplicate retriever 
name") { + t.Fatalf("got %v", err) + } +} + +func TestBuildAgentConfig_toolsList_includesRetrieverTools(t *testing.T) { + cfg, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithTools(mockTool{name: "echo"}), + WithRetrievers(stubRetriever{}), + }) + if err != nil { + t.Fatal(err) + } + list := cfg.toolsList() + if len(list) != 2 { + t.Fatalf("toolsList len = %d, want 2", len(list)) + } + if list[1].Name() != "retriever_stub" { + t.Fatalf("tool[1].Name = %q", list[1].Name()) + } +} + +func TestBuildAgentConfig_validateToolNames_RetrieverConflict(t *testing.T) { + c := &agentConfig{ + tools: []interfaces.Tool{mockTool{name: "retriever_stub"}}, + retrieverTools: []interfaces.Tool{ + NewRetrieverTool(stubRetriever{}), + }, + } + err := c.validateToolNames() + if err == nil || !strings.Contains(err.Error(), "retriever tool conflicts") { + t.Fatalf("got %v", err) + } +} + +func TestBuildAgentConfig_validateToolNames_nilRetrieverTool(t *testing.T) { + c := &agentConfig{retrieverTools: []interfaces.Tool{nil}} + err := c.validateToolNames() + if err == nil || !strings.Contains(err.Error(), "retriever tool must not be nil") { + t.Fatalf("got %v", err) + } +} + +func TestBuildAgentConfig_WithRetrievers_nilEntry(t *testing.T) { + _, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithRetrievers(stubRetriever{}, nil), + }) + if err == nil || !strings.Contains(err.Error(), "nil") { + t.Fatalf("got %v", err) + } +} + +func TestBuildAgentConfig_WithRetrievers_emptyClears(t *testing.T) { + cfg, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithRetrievers(stubRetriever{}), + WithRetrievers(), + }) + if err != nil { + t.Fatal(err) + } + if cfg.retrievers != nil { + t.Fatalf("retrievers = %v, want nil", 
cfg.retrievers) + } + if len(cfg.retrieverTools) != 0 { + t.Fatalf("retrieverTools len = %d, want 0", len(cfg.retrieverTools)) + } +} + +func TestBuildAgentConfig_RetrieverMode_default(t *testing.T) { + cfg, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + }) + if err != nil { + t.Fatal(err) + } + if cfg.retrieverMode != RetrieverModeAgentic { + t.Fatalf("retrieverMode = %q, want %q", cfg.retrieverMode, RetrieverModeAgentic) + } +} + +func TestBuildAgentConfig_RetrieverMode_explicit(t *testing.T) { + for _, mode := range []RetrieverMode{ + RetrieverModeAgentic, + RetrieverModePrefetch, + RetrieverModeHybrid, + } { + t.Run(string(mode), func(t *testing.T) { + cfg, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithRetrieverMode(mode), + }) + if err != nil { + t.Fatal(err) + } + if cfg.retrieverMode != mode { + t.Fatalf("retrieverMode = %q, want %q", cfg.retrieverMode, mode) + } + }) + } +} + +func TestAgentConfigFingerprint_RetrieverModeChangesDigest(t *testing.T) { + baseOpts := []Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + } + build := func(mode RetrieverMode) string { + t.Helper() + opts := append(append([]Option(nil), baseOpts...), WithRetrieverMode(mode)) + cfg, err := buildAgentConfig(opts) + if err != nil { + t.Fatal(err) + } + return cfg.agentConfigFingerprint() + } + fpAgentic := build(RetrieverModeAgentic) + fpPrefetch := build(RetrieverModePrefetch) + fpHybrid := build(RetrieverModeHybrid) + if fpAgentic == fpPrefetch { + t.Fatal("expected different fingerprints for agentic vs prefetch retriever mode") + } + if fpAgentic == fpHybrid { + t.Fatal("expected different fingerprints for agentic vs hybrid retriever mode") + } + if fpPrefetch == fpHybrid { + t.Fatal("expected different fingerprints for prefetch vs 
hybrid retriever mode") + } +} + +func TestBuildAgentConfig_toolsList_includesRetrieverTools_hybrid(t *testing.T) { + cfg, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithTools(mockTool{name: "echo"}), + WithRetrievers(stubRetriever{}), + WithRetrieverMode(RetrieverModeHybrid), + }) + if err != nil { + t.Fatal(err) + } + list := cfg.toolsList() + if len(list) != 2 { + t.Fatalf("toolsList len = %d, want 2 (base tool + retriever tool)", len(list)) + } + if list[1].Name() != "retriever_stub" { + t.Fatalf("tool[1].Name = %q, want retriever_stub", list[1].Name()) + } +} + +func TestAgentConfigFingerprint_AgenticRetrieverNamesChangesDigest(t *testing.T) { + baseOpts := []Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithRetrieverMode(RetrieverModeAgentic), + } + cfgNoR, err := buildAgentConfig(baseOpts) + if err != nil { + t.Fatal(err) + } + cfgWithR, err := buildAgentConfig(append(baseOpts, WithRetrievers(namedStubRetriever("wiki")))) + if err != nil { + t.Fatal(err) + } + if cfgNoR.agentConfigFingerprint() == cfgWithR.agentConfigFingerprint() { + t.Fatal("expected different fingerprints for agentic mode with vs without retriever names") + } +} + +func TestBuildAgentConfig_RetrieverMode_invalid(t *testing.T) { + _, err := buildAgentConfig([]Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithRetrieverMode(RetrieverMode("bogus")), + }) + if err == nil || !strings.Contains(err.Error(), "invalid retriever mode") { + t.Fatalf("got %v", err) + } +} + // --------------------------------------------------------------------------- // A2A config tests // --------------------------------------------------------------------------- diff --git a/pkg/agent/mcp.go b/pkg/agent/mcp.go index 77226e3..23c239a 100644 --- a/pkg/agent/mcp.go +++ b/pkg/agent/mcp.go @@ 
-63,47 +63,47 @@ func NewMCPTool(serverName string, spec interfaces.ToolSpec, client interfaces.M } // Name implements interfaces.Tool. -func (m *MCPTool) Name() string { - if m == nil { +func (t *MCPTool) Name() string { + if t == nil { return "" } - return mcpToolName(m.ServerName, m.Spec.Name) + return mcpToolName(t.ServerName, t.Spec.Name) } // DisplayName implements interfaces.Tool. -func (m *MCPTool) DisplayName() string { - if m == nil { +func (t *MCPTool) DisplayName() string { + if t == nil { return "" } - return mcpToolDisplayName(m.ServerName, m.Spec.Name) + return mcpToolDisplayName(t.ServerName, t.Spec.Name) } // Description implements interfaces.Tool. -func (m *MCPTool) Description() string { - if m == nil { +func (t *MCPTool) Description() string { + if t == nil { return "" } - return m.Spec.Description + return t.Spec.Description } // Parameters implements interfaces.Tool. -func (m *MCPTool) Parameters() interfaces.JSONSchema { - if m == nil || m.Spec.Parameters == nil { +func (t *MCPTool) Parameters() interfaces.JSONSchema { + if t == nil || t.Spec.Parameters == nil { return interfaces.JSONSchema{"type": "object"} } - return m.Spec.Parameters + return t.Spec.Parameters } // Execute implements interfaces.Tool: marshal args, CallTool, return decoded JSON or raw string. 
-func (m *MCPTool) Execute(ctx context.Context, args map[string]any) (any, error) { - if m == nil || m.Client == nil { +func (t *MCPTool) Execute(ctx context.Context, args map[string]any) (any, error) { + if t == nil || t.Client == nil { return nil, fmt.Errorf("mcp tool: nil client") } raw, err := json.Marshal(args) if err != nil { return nil, err } - out, err := m.Client.CallTool(ctx, m.Spec.Name, raw) + out, err := t.Client.CallTool(ctx, t.Spec.Name, raw) if err != nil { return nil, err } diff --git a/pkg/agent/retriever.go b/pkg/agent/retriever.go new file mode 100644 index 0000000..f8e1901 --- /dev/null +++ b/pkg/agent/retriever.go @@ -0,0 +1,176 @@ +package agent + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "sort" + "strings" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/tools" +) + +var ( + retrieverToolNameTemplate = "retriever_%s" + retrieverToolDisplayNameTemplate = "%s Retriever Tool" +) + +var _ interfaces.Tool = (*RetrieverTool)(nil) + +// RetrieverTool implements [interfaces.Tool] for [RetrieverModeAgentic] and [RetrieverModeHybrid]. +type RetrieverTool struct { + // RetrieverName is the stable key from [interfaces.Retriever.Name] (used in tool name / display name). + RetrieverName string + Retriever interfaces.Retriever +} + +// retrieverToolName returns the registered tool name for a retriever name +// (same format as [RetrieverTool.Name]). Trims whitespace; returns "" if empty after trim. +func retrieverToolName(retrieverName string) string { + n := strings.TrimSpace(retrieverName) + if n == "" { + return "" + } + return fmt.Sprintf(retrieverToolNameTemplate, n) +} + +// retrieverToolDisplayName returns the display name for a retriever name +// (same format as [RetrieverTool.DisplayName]). Trims whitespace; returns "" if empty after trim. 
+func retrieverToolDisplayName(retrieverName string) string { + n := strings.TrimSpace(retrieverName) + if n == "" { + return "" + } + return fmt.Sprintf(retrieverToolDisplayNameTemplate, n) +} + +// NewRetrieverTool builds a RetrieverTool. Returns nil when retriever is nil or [interfaces.Retriever.Name] is empty. +func NewRetrieverTool(retriever interfaces.Retriever) interfaces.Tool { + if retriever == nil { + return nil + } + rn := strings.TrimSpace(retriever.Name()) + if rn == "" { + return nil + } + return &RetrieverTool{RetrieverName: rn, Retriever: retriever} +} + +// Name implements [interfaces.Tool]. +func (t *RetrieverTool) Name() string { + if t == nil { + return "" + } + return retrieverToolName(t.RetrieverName) +} + +// DisplayName implements [interfaces.Tool]. +func (t *RetrieverTool) DisplayName() string { + if t == nil { + return "" + } + return fmt.Sprintf(retrieverToolDisplayNameTemplate, t.RetrieverName) +} + +// Description implements [interfaces.Tool]. +func (t *RetrieverTool) Description() string { + if t == nil { + return "" + } + return fmt.Sprintf( + "Search the %s knowledge base for relevant context. "+ + "Call this when you need external knowledge to answer the user query.", + t.RetrieverName, + ) +} + +// Parameters implements [interfaces.Tool]. Requires [types.RetrieverToolParamQuery]. +func (t *RetrieverTool) Parameters() interfaces.JSONSchema { + if t == nil { + return interfaces.JSONSchema{"type": "object"} + } + return tools.Params(map[string]interfaces.JSONSchema{ + types.RetrieverToolParamQuery: tools.ParamString( + fmt.Sprintf("Search query to find relevant knowledge in %s", t.RetrieverName), + ), + }, types.RetrieverToolParamQuery) +} + +// Execute implements [interfaces.Tool]: reads the query argument, calls [interfaces.Retriever.Search], +// and returns a numbered plain-text summary of matching documents. 
+func (t *RetrieverTool) Execute(ctx context.Context, args map[string]any) (any, error) { + if t.Retriever == nil { + return nil, fmt.Errorf("retriever tool: nil retriever") + } + raw, ok := args[types.RetrieverToolParamQuery].(string) + if !ok { + return nil, fmt.Errorf("retriever tool: %q parameter required", types.RetrieverToolParamQuery) + } + query := strings.TrimSpace(raw) + if query == "" { + return nil, fmt.Errorf("retriever tool: %q must be non-empty", types.RetrieverToolParamQuery) + } + docs, err := t.Retriever.Search(ctx, query) + if err != nil { + return nil, err + } + return formatRetrieverDocs(docs), nil +} + +func formatRetrieverDocs(docs []interfaces.Document) string { + if len(docs) == 0 { + return "no relevant documents found" + } + var sb strings.Builder + for i, doc := range docs { + fmt.Fprintf(&sb, types.RetrieverDocFormat, i+1, doc.Content, doc.Source, doc.Score) + } + return sb.String() +} + +// --------------------------------------------------------------------------- +// Fingerprint +// --------------------------------------------------------------------------- + +// retrieverConfigFingerprint returns a stable SHA-256 digest of retriever mode and retriever names +// for [agentConfig.agentConfigFingerprint]. Names are deduplicated, whitespace-trimmed, and sorted +// for stability. Returns "" for [RetrieverModeAgentic] with no retrievers (no fingerprint contribution). 
+func retrieverConfigFingerprint(mode types.RetrieverMode, retrievers []interfaces.Retriever) string { + rm := mode + if rm == "" { + rm = types.RetrieverModeAgentic + } + seen := make(map[string]struct{}, len(retrievers)) + var names []string + for _, r := range retrievers { + if r == nil { + continue + } + n := strings.TrimSpace(r.Name()) + if n == "" { + continue + } + if _, dup := seen[n]; dup { + continue + } + seen[n] = struct{}{} + names = append(names, n) + } + sort.Strings(names) + if len(names) == 0 && rm == types.RetrieverModeAgentic { + return "" + } + b, err := json.Marshal(struct { + Mode string `json:"mode"` + Names []string `json:"names,omitempty"` + }{Mode: string(rm), Names: names}) + if err != nil { + return "" + } + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]) +} diff --git a/pkg/agent/retriever_test.go b/pkg/agent/retriever_test.go new file mode 100644 index 0000000..7d3428c --- /dev/null +++ b/pkg/agent/retriever_test.go @@ -0,0 +1,323 @@ +package agent + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" +) + +// testCtxKey is an unexported type used as a context key in tests (avoids SA1029 staticcheck). 
+type testCtxKey struct{} + +type retrieverExecuteStub struct { + name string + lastCtx context.Context + lastQuery string + docs []interfaces.Document + err error +} + +func (r *retrieverExecuteStub) Name() string { return r.name } + +func (r *retrieverExecuteStub) Search(ctx context.Context, query string) ([]interfaces.Document, error) { + r.lastCtx = ctx + r.lastQuery = query + if r.err != nil { + return nil, r.err + } + return r.docs, nil +} + +func TestRetrieverToolName(t *testing.T) { + if got := retrieverToolName(" wiki "); got != "retriever_wiki" { + t.Fatalf("got %q", got) + } + if retrieverToolName("") != "" || retrieverToolName(" ") != "" { + t.Fatal("expected empty for missing name") + } +} + +func TestRetrieverToolDisplayName(t *testing.T) { + if got := retrieverToolDisplayName("wiki"); got != "wiki Retriever Tool" { + t.Fatalf("got %q", got) + } + if retrieverToolDisplayName(" ") != "" { + t.Fatal("expected empty") + } +} + +func TestNewRetrieverTool_nil(t *testing.T) { + if NewRetrieverTool(nil) != nil { + t.Fatal("expected nil") + } +} + +func TestNewRetrieverTool_emptyName(t *testing.T) { + if NewRetrieverTool(&retrieverExecuteStub{name: " "}) != nil { + t.Fatal("expected nil for empty name") + } +} + +func TestRetrieverTool_NilReceiver(t *testing.T) { + var t0 *RetrieverTool + if t0.Name() != "" || t0.DisplayName() != "" || t0.Description() != "" { + t.Fatal("nil receiver should return empty strings") + } + p := t0.Parameters() + if p["type"] != "object" { + t.Fatalf("Parameters = %v", p) + } +} + +func TestRetrieverToolName_ViaStruct(t *testing.T) { + tool := NewRetrieverTool(&retrieverExecuteStub{name: " weaviate "}) + if tool.Name() != "retriever_weaviate" { + t.Fatalf("Name = %q", tool.Name()) + } +} + +func TestRetrieverTool_NameDisplayDescription(t *testing.T) { + tool := NewRetrieverTool(&retrieverExecuteStub{name: "weaviate"}).(*RetrieverTool) + if tool.Name() != "retriever_weaviate" { + t.Fatalf("Name = %q", tool.Name()) + } + if 
tool.DisplayName() != "weaviate Retriever Tool" { + t.Fatalf("DisplayName = %q", tool.DisplayName()) + } + if !strings.Contains(tool.Description(), "weaviate") { + t.Fatalf("Description = %q", tool.Description()) + } +} + +func TestRetrieverTool_Parameters(t *testing.T) { + tool := NewRetrieverTool(&retrieverExecuteStub{name: "kb"}) + p := tool.Parameters() + if p["type"] != "object" { + t.Fatalf("type = %v", p["type"]) + } + req, ok := p["required"].([]string) + if !ok || len(req) != 1 || req[0] != types.RetrieverToolParamQuery { + t.Fatalf("required = %v", p["required"]) + } +} + +func TestRetrieverTool_Execute_success(t *testing.T) { + stub := &retrieverExecuteStub{ + name: "kb", + docs: []interfaces.Document{ + {Content: "Go is great", Source: "doc1.md", Score: 0.9}, + {Content: "Rust is fast", Source: "doc2.md", Score: 0.8}, + }, + } + tool := NewRetrieverTool(stub) + ctx := context.WithValue(context.Background(), testCtxKey{}, "marker") + + out, err := tool.Execute(ctx, map[string]any{types.RetrieverToolParamQuery: " golang "}) + if err != nil { + t.Fatal(err) + } + s, ok := out.(string) + if !ok { + t.Fatalf("got %T", out) + } + if !strings.Contains(s, "[1] Go is great") || !strings.Contains(s, "[2] Rust is fast") { + t.Fatalf("output = %q", s) + } + if stub.lastQuery != "golang" { + t.Fatalf("query = %q", stub.lastQuery) + } + if stub.lastCtx != ctx { + t.Fatal("Search did not receive Execute context") + } +} + +func TestRetrieverTool_Execute_noDocs(t *testing.T) { + tool := NewRetrieverTool(&retrieverExecuteStub{name: "kb"}) + out, err := tool.Execute(context.Background(), map[string]any{types.RetrieverToolParamQuery: "x"}) + if err != nil { + t.Fatal(err) + } + if out != "no relevant documents found" { + t.Fatalf("got %q", out) + } +} + +func TestRetrieverTool_Execute_missingQuery(t *testing.T) { + tool := NewRetrieverTool(&retrieverExecuteStub{name: "kb"}) + _, err := tool.Execute(context.Background(), map[string]any{}) + if err == nil || 
!strings.Contains(err.Error(), types.RetrieverToolParamQuery) { + t.Fatalf("got %v", err) + } +} + +func TestRetrieverTool_Execute_emptyQuery(t *testing.T) { + tool := NewRetrieverTool(&retrieverExecuteStub{name: "kb"}) + _, err := tool.Execute(context.Background(), map[string]any{types.RetrieverToolParamQuery: " "}) + if err == nil || !strings.Contains(err.Error(), "non-empty") { + t.Fatalf("got %v", err) + } +} + +func TestRetrieverTool_Execute_searchError(t *testing.T) { + want := errors.New("search failed") + tool := NewRetrieverTool(&retrieverExecuteStub{name: "kb", err: want}) + _, err := tool.Execute(context.Background(), map[string]any{types.RetrieverToolParamQuery: "q"}) + if !errors.Is(err, want) { + t.Fatalf("got %v", err) + } +} + +func TestRetrieverTool_Execute_nilRetriever(t *testing.T) { + tool := &RetrieverTool{RetrieverName: "kb", Retriever: nil} + _, err := tool.Execute(context.Background(), map[string]any{types.RetrieverToolParamQuery: "q"}) + if err == nil || !strings.Contains(err.Error(), "nil retriever") { + t.Fatalf("got %v", err) + } +} + +func TestFormatRetrieverDocs(t *testing.T) { + if formatRetrieverDocs(nil) != "no relevant documents found" { + t.Fatal("nil docs") + } + got := formatRetrieverDocs([]interfaces.Document{ + {Content: "alpha", Source: "a", Score: 0.5}, + }) + if !strings.Contains(got, "[1] alpha") || !strings.Contains(got, "score: 0.50") { + t.Fatalf("got %q", got) + } +} + +// --------------------------------------------------------------------------- +// retrieverConfigFingerprint tests + +func TestRetrieverConfigFingerprint_nilEntriesIgnored(t *testing.T) { + r := &retrieverExecuteStub{name: "kb"} + // A list with nil entries must produce the same fingerprint as the list without them. 
+ fpClean := retrieverConfigFingerprint(RetrieverModeAgentic, []interfaces.Retriever{r}) + fpNils := retrieverConfigFingerprint(RetrieverModeAgentic, []interfaces.Retriever{nil, r, nil}) + if fpClean != fpNils { + t.Fatalf("nil entries must not affect the fingerprint: %q vs %q", fpClean, fpNils) + } +} + +func TestRetrieverConfigFingerprint_duplicateNamesIgnored(t *testing.T) { + r1 := &retrieverExecuteStub{name: "kb"} + r2 := &retrieverExecuteStub{name: "kb"} + fpOne := retrieverConfigFingerprint(RetrieverModeAgentic, []interfaces.Retriever{r1}) + fpDup := retrieverConfigFingerprint(RetrieverModeAgentic, []interfaces.Retriever{r1, r2}) + if fpOne != fpDup { + t.Fatalf("duplicate retriever names must not affect the fingerprint: %q vs %q", fpOne, fpDup) + } +} + +func TestRetrieverConfigFingerprint_agenticEmptyReturnsEmpty(t *testing.T) { + if fp := retrieverConfigFingerprint(RetrieverModeAgentic, nil); fp != "" { + t.Fatalf("agentic with no retrievers should return empty string, got %q", fp) + } + // Explicit agentic with empty-name-only retrievers (all filtered out) also returns "". 
+ r := &retrieverExecuteStub{name: " "} + if fp := retrieverConfigFingerprint(RetrieverModeAgentic, []interfaces.Retriever{r}); fp != "" { + t.Fatalf("agentic with blank-name-only retrievers should return empty string, got %q", fp) + } +} + +func TestRetrieverConfigFingerprint_agenticWithNamesNonEmpty(t *testing.T) { + r := &retrieverExecuteStub{name: "kb"} + fp := retrieverConfigFingerprint(RetrieverModeAgentic, []interfaces.Retriever{r}) + if fp == "" { + t.Fatal("agentic with a named retriever should produce a non-empty fingerprint") + } + if len(fp) != 64 { + t.Fatalf("expected 64-char hex digest, got %q (len=%d)", fp, len(fp)) + } +} + +func TestRetrieverConfigFingerprint_modeAndNames(t *testing.T) { + fpAgentic := retrieverConfigFingerprint(RetrieverModeAgentic, nil) + fpPrefetch := retrieverConfigFingerprint(RetrieverModePrefetch, nil) + if fpAgentic == fpPrefetch { + t.Fatal("expected different fingerprints for agentic vs prefetch") + } + if fpPrefetch == "" { + t.Fatal("prefetch mode should produce non-empty fingerprint even with no retrievers") + } + + r1 := &retrieverExecuteStub{name: "wiki"} + r2 := &retrieverExecuteStub{name: "docs"} + fpOne := retrieverConfigFingerprint(RetrieverModePrefetch, []interfaces.Retriever{r1}) + fpTwo := retrieverConfigFingerprint(RetrieverModePrefetch, []interfaces.Retriever{r1, r2}) + if fpOne == fpTwo { + t.Fatal("expected different fingerprints when retriever names differ") + } +} + +func TestRetrieverConfigFingerprint_hybridDiffersFromAgenticAndPrefetch(t *testing.T) { + fpHybrid := retrieverConfigFingerprint(RetrieverModeHybrid, nil) + if fpHybrid == "" { + t.Fatal("hybrid mode should produce a non-empty fingerprint even with no retrievers") + } + fpAgentic := retrieverConfigFingerprint(RetrieverModeAgentic, nil) + if fpHybrid == fpAgentic { + t.Fatal("hybrid fingerprint must differ from agentic fingerprint") + } + fpPrefetch := retrieverConfigFingerprint(RetrieverModePrefetch, nil) + if fpHybrid == fpPrefetch { + 
t.Fatal("hybrid fingerprint must differ from prefetch fingerprint") + } +} + +func TestRetrieverConfigFingerprint_hybridNamesChangeDigest(t *testing.T) { + r1 := &retrieverExecuteStub{name: "wiki"} + r2 := &retrieverExecuteStub{name: "docs"} + fpOne := retrieverConfigFingerprint(RetrieverModeHybrid, []interfaces.Retriever{r1}) + fpTwo := retrieverConfigFingerprint(RetrieverModeHybrid, []interfaces.Retriever{r1, r2}) + if fpOne == fpTwo { + t.Fatal("expected different hybrid fingerprints when retriever names differ") + } +} + +func TestRetrieverConfigFingerprint_stability(t *testing.T) { + r := &retrieverExecuteStub{name: "kb"} + retrievers := []interfaces.Retriever{r} + for _, mode := range []RetrieverMode{RetrieverModeAgentic, RetrieverModePrefetch, RetrieverModeHybrid} { + fp1 := retrieverConfigFingerprint(mode, retrievers) + fp2 := retrieverConfigFingerprint(mode, retrievers) + if fp1 != fp2 { + t.Fatalf("fingerprint not stable for mode %q: %q vs %q", mode, fp1, fp2) + } + } +} + +func TestRetrieverConfigFingerprint_nameOrderDoesNotMatter(t *testing.T) { + a := &retrieverExecuteStub{name: "alpha"} + b := &retrieverExecuteStub{name: "beta"} + fp1 := retrieverConfigFingerprint(RetrieverModeAgentic, []interfaces.Retriever{a, b}) + fp2 := retrieverConfigFingerprint(RetrieverModeAgentic, []interfaces.Retriever{b, a}) + if fp1 != fp2 { + t.Fatalf("fingerprint must be order-independent: %q vs %q", fp1, fp2) + } +} + +func TestAgentConfigFingerprint_RetrieverNamesChangesDigest(t *testing.T) { + baseOpts := []Option{ + WithName("test"), + WithTemporalConfig(&TemporalConfig{TaskQueue: "q"}), + WithLLMClient(stubLLM{}), + WithRetrieverMode(RetrieverModePrefetch), + } + cfgNoR, err := buildAgentConfig(baseOpts) + if err != nil { + t.Fatal(err) + } + cfgWithR, err := buildAgentConfig(append(baseOpts, WithRetrievers(&retrieverExecuteStub{name: "wiki"}))) + if err != nil { + t.Fatal(err) + } + if cfgNoR.agentConfigFingerprint() == cfgWithR.agentConfigFingerprint() { + 
t.Fatal("expected different fingerprints when retriever names are registered") + } +} diff --git a/pkg/agent/runtime_factory.go b/pkg/agent/runtime_factory.go index 888b4aa..b8f446a 100644 --- a/pkg/agent/runtime_factory.go +++ b/pkg/agent/runtime_factory.go @@ -34,6 +34,7 @@ func (cfg *agentConfig) buildTemporalRuntime(remoteWorker bool) (runtime.Runtime temporal.WithMetrics(cfg.metrics), temporal.WithAgentMode(string(cfg.agentMode)), temporal.WithAgentToolExecutionMode(cfg.agentToolExecutionMode), + temporal.WithRetrieverFingerprint(retrieverConfigFingerprint(cfg.retrieverMode, cfg.retrievers)), temporal.WithDisableLocalWorker(cfg.disableLocalWorker), // Never allow fingerprint bypass on remote worker runtime. temporal.WithDisableFingerprintCheck(cfg.disableFingerprintCheck && !remoteWorker), diff --git a/pkg/interfaces/mocks/mock_retriever.go b/pkg/interfaces/mocks/mock_retriever.go new file mode 100644 index 0000000..49bc800 --- /dev/null +++ b/pkg/interfaces/mocks/mock_retriever.go @@ -0,0 +1,65 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/agenticenv/agent-sdk-go/pkg/interfaces (interfaces: Retriever) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + interfaces "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + gomock "github.com/golang/mock/gomock" +) + +// MockRetriever is a mock of Retriever interface. +type MockRetriever struct { + ctrl *gomock.Controller + recorder *MockRetrieverMockRecorder +} + +// MockRetrieverMockRecorder is the mock recorder for MockRetriever. +type MockRetrieverMockRecorder struct { + mock *MockRetriever +} + +// NewMockRetriever creates a new mock instance. +func NewMockRetriever(ctrl *gomock.Controller) *MockRetriever { + mock := &MockRetriever{ctrl: ctrl} + mock.recorder = &MockRetrieverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. 
+func (m *MockRetriever) EXPECT() *MockRetrieverMockRecorder { + return m.recorder +} + +// Name mocks base method. +func (m *MockRetriever) Name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name. +func (mr *MockRetrieverMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockRetriever)(nil).Name)) +} + +// Search mocks base method. +func (m *MockRetriever) Search(arg0 context.Context, arg1 string) ([]interfaces.Document, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Search", arg0, arg1) + ret0, _ := ret[0].([]interfaces.Document) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Search indicates an expected call of Search. +func (mr *MockRetrieverMockRecorder) Search(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Search", reflect.TypeOf((*MockRetriever)(nil).Search), arg0, arg1) +} diff --git a/pkg/interfaces/retriever.go b/pkg/interfaces/retriever.go new file mode 100644 index 0000000..651a3f4 --- /dev/null +++ b/pkg/interfaces/retriever.go @@ -0,0 +1,19 @@ +package interfaces + +import "context" + +//go:generate mockgen -destination=./mocks/mock_retriever.go -package=mocks github.com/agenticenv/agent-sdk-go/pkg/interfaces Retriever + +type Retriever interface { + // Name returns the unique name of the retriever. + Name() string + // Search searches the retriever for documents matching the query. 
+ Search(ctx context.Context, query string) ([]Document, error) +} + +type Document struct { + Content string + Source string + Score float64 + Metadata map[string]any +} diff --git a/pkg/retriever/pgvector/retriever.go b/pkg/retriever/pgvector/retriever.go new file mode 100644 index 0000000..339f7a5 --- /dev/null +++ b/pkg/retriever/pgvector/retriever.go @@ -0,0 +1,295 @@ +// Package pgvector provides a retriever backed by PostgreSQL with the pgvector extension. +// Callers provide a plain-text query; the retriever converts it to an embedding via [EmbedFunc] +// and runs a cosine-similarity nearest-neighbour search against the configured table. +package pgvector + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + pgvec "github.com/pgvector/pgvector-go" + pgxvec "github.com/pgvector/pgvector-go/pgx" +) + +var _ interfaces.Retriever = (*PgvectorRetriever)(nil) + +// EmbedFunc converts a plain-text query into a vector embedding. +// Callers typically wrap an LLM embedding API (e.g. OpenAI text-embedding-3-small). +type EmbedFunc func(ctx context.Context, text string) ([]float32, error) + +// pgRows is the subset of [pgx.Rows] used by Search, allowing injection in tests. +type pgRows interface { + Close() + Next() bool + Scan(dest ...any) error + Err() error +} + +// pgQuerier abstracts the database query call; satisfied by [pgxPoolQuerier] and test stubs. +type pgQuerier interface { + Query(ctx context.Context, sql string, args ...any) (pgRows, error) +} + +// pgxPoolQuerier wraps [pgxpool.Pool] to satisfy [pgQuerier]. +type pgxPoolQuerier struct{ pool *pgxpool.Pool } + +func (q *pgxPoolQuerier) Query(ctx context.Context, sql string, args ...any) (pgRows, error) { + return q.pool.Query(ctx, sql, args...) 
+} + +// PgvectorRetriever searches a PostgreSQL table with the pgvector extension using cosine similarity. +// The query text is converted to an embedding via [EmbedFunc] before each search. +type PgvectorRetriever struct { + // name is the stable identifier returned by [Name]; required. + name string + + // runtime fields — used by Search. + db pgQuerier + table string + contentCol string + sourceCol string + embeddingCol string + topK int + minScore float64 + embed EmbedFunc + logger logger.Logger + + // build-time fields — consumed by NewRetriever; not used after construction. + dsn string + logLevel string +} + +// Option configures PgvectorRetriever. +type Option func(*PgvectorRetriever) + +// WithPool sets an existing [pgxpool.Pool]. When provided, [WithDSN] is ignored. +// Callers must register pgvector types on the pool: +// +// config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { +// return pgxvec.RegisterTypes(ctx, conn) +// } +func WithPool(pool *pgxpool.Pool) Option { + return func(r *PgvectorRetriever) { r.db = &pgxPoolQuerier{pool: pool} } +} + +// WithDSN sets the PostgreSQL connection string used to create a new pool when [WithPool] is omitted. +// The pool is configured to register pgvector types on each new connection automatically. +func WithDSN(dsn string) Option { + return func(r *PgvectorRetriever) { r.dsn = dsn } +} + +// WithTable sets the PostgreSQL table (or view) to search. Required. +func WithTable(table string) Option { + return func(r *PgvectorRetriever) { r.table = table } +} + +// WithContentCol sets the column that holds document text. Defaults to [types.DefaultContentField]. +func WithContentCol(col string) Option { + return func(r *PgvectorRetriever) { r.contentCol = col } +} + +// WithSourceCol sets the column that holds the document source identifier. Defaults to [types.DefaultSourceField]. 
+func WithSourceCol(col string) Option { + return func(r *PgvectorRetriever) { r.sourceCol = col } +} + +// WithEmbeddingCol sets the column that holds the pgvector embedding. Defaults to "embedding". +func WithEmbeddingCol(col string) Option { + return func(r *PgvectorRetriever) { r.embeddingCol = col } +} + +// WithTopK sets the maximum number of documents returned per search. Defaults to [types.DefaultTopK]. +func WithTopK(topK int) Option { + return func(r *PgvectorRetriever) { r.topK = topK } +} + +// WithMinScore sets the minimum cosine similarity (0–1) for returned documents. Defaults to [types.DefaultMinScore]. +func WithMinScore(minScore float64) Option { + return func(r *PgvectorRetriever) { r.minScore = minScore } +} + +// WithLogger sets the logger. When omitted, a default logger at the configured log level is used. +func WithLogger(l logger.Logger) Option { + return func(r *PgvectorRetriever) { r.logger = l } +} + +// WithLogLevel sets the default log level when [WithLogger] is omitted. Defaults to "error". +func WithLogLevel(level string) Option { + return func(r *PgvectorRetriever) { r.logLevel = level } +} + +// NewRetriever builds a PgvectorRetriever. name must be non-empty and unique across all retrievers +// registered with the same agent. embed is required. [WithTable] is required. When [WithPool] is +// omitted, [WithDSN] must be provided. Zero-valued topK and minScore default to [types.DefaultTopK] and +// [types.DefaultMinScore]. 
+func NewRetriever(name string, embed EmbedFunc, opts ...Option) (*PgvectorRetriever, error) { + if strings.TrimSpace(name) == "" { + return nil, errors.New("name is required and must be non-empty") + } + if embed == nil { + return nil, errors.New("embed func is required") + } + r := &PgvectorRetriever{name: strings.TrimSpace(name), embed: embed} + for _, opt := range opts { + opt(r) + } + if r.table == "" { + return nil, errors.New("table is required; use WithTable") + } + if r.contentCol == "" { + r.contentCol = types.DefaultContentField + } + if r.sourceCol == "" { + r.sourceCol = types.DefaultSourceField + } + if r.embeddingCol == "" { + r.embeddingCol = "embedding" + } + if r.topK == 0 { + r.topK = types.DefaultTopK + } + if r.minScore == 0 { + r.minScore = types.DefaultMinScore + } + if r.logLevel == "" { + r.logLevel = "error" + } + if r.logger == nil { + r.logger = logger.DefaultLogger(r.logLevel) + } + if r.db == nil { + if r.dsn == "" { + return nil, errors.New("DSN is required when not using WithPool; use WithDSN or WithPool") + } + cfg, err := pgxpool.ParseConfig(r.dsn) + if err != nil { + return nil, fmt.Errorf("parse DSN: %w", err) + } + // Register pgvector types on every new connection so the <=> operator and + // pgvec.NewVector arguments are correctly encoded and decoded. 
+ cfg.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + return pgxvec.RegisterTypes(ctx, conn) + } + pool, err := pgxpool.NewWithConfig(context.Background(), cfg) + if err != nil { + return nil, fmt.Errorf("create pgx pool: %w", err) + } + r.db = &pgxPoolQuerier{pool: pool} + } + r.logger.Info(context.Background(), "pgvector retriever built", + slog.String("scope", "pgvector"), + slog.String("name", r.name), + slog.String("table", r.table), + slog.String("contentCol", r.contentCol), + slog.String("sourceCol", r.sourceCol), + slog.String("embeddingCol", r.embeddingCol), + slog.Int("topK", r.topK), + slog.Float64("minScore", r.minScore), + ) + return r, nil +} + +// Name implements [interfaces.Retriever]. +func (r *PgvectorRetriever) Name() string { + return r.name +} + +// Search embeds the query and runs a cosine-similarity nearest-neighbour search against the +// configured PostgreSQL table. Returns at most [WithTopK] documents with similarity ≥ [WithMinScore]. +// +// Table and column names are developer-controlled build-time configuration and are not +// sanitised against SQL injection because they are never derived from runtime user input. +func (r *PgvectorRetriever) Search(ctx context.Context, query string) ([]interfaces.Document, error) { + r.logger.Debug(ctx, "pgvector search start", + slog.String("scope", "pgvector"), + slog.String("name", r.name), + slog.String("table", r.table), + slog.String("query", query), + slog.Int("topK", r.topK), + slog.Float64("minScore", r.minScore), + ) + start := time.Now() + + vec, err := r.embed(ctx, query) + if err != nil { + r.logger.Error(ctx, "pgvector embed failed", + slog.String("scope", "pgvector"), + slog.String("name", r.name), + slog.Duration("elapsed", time.Since(start)), + slog.Any("error", err), + ) + return nil, fmt.Errorf("embed query: %w", err) + } + + // Cosine distance operator (<=>): 0 = identical, 2 = opposite. 
+ // Score = 1 − cosine_distance, giving cosine similarity in [−1, 1] (typically [0, 1] for text). + // Table/column identifiers are build-time developer config, not runtime user input. + //nolint:gosec + sql := fmt.Sprintf( + `SELECT %s, %s, 1 - (%s <=> $1) AS score + FROM %s + WHERE 1 - (%s <=> $1) >= $2 + ORDER BY %s <=> $1 + LIMIT $3`, + r.contentCol, r.sourceCol, r.embeddingCol, + r.table, + r.embeddingCol, + r.embeddingCol, + ) + + rows, err := r.db.Query(ctx, sql, pgvec.NewVector(vec), r.minScore, r.topK) + if err != nil { + r.logger.Error(ctx, "pgvector search failed", + slog.String("scope", "pgvector"), + slog.String("name", r.name), + slog.String("table", r.table), + slog.Duration("elapsed", time.Since(start)), + slog.Any("error", err), + ) + return nil, fmt.Errorf("pgvector query: %w", err) + } + + docs, err := scanRows(rows) + if err != nil { + r.logger.Error(ctx, "pgvector scan failed", + slog.String("scope", "pgvector"), + slog.String("name", r.name), + slog.String("table", r.table), + slog.Duration("elapsed", time.Since(start)), + slog.Any("error", err), + ) + return nil, err + } + + r.logger.Debug(ctx, "pgvector search done", + slog.String("scope", "pgvector"), + slog.String("name", r.name), + slog.String("table", r.table), + slog.Int("docs", len(docs)), + slog.Duration("elapsed", time.Since(start)), + ) + return docs, nil +} + +// scanRows reads content, source, and score from each row into []interfaces.Document. 
+func scanRows(rows pgRows) ([]interfaces.Document, error) { + defer rows.Close() + var docs []interfaces.Document + for rows.Next() { + var doc interfaces.Document + if err := rows.Scan(&doc.Content, &doc.Source, &doc.Score); err != nil { + return nil, fmt.Errorf("scan row: %w", err) + } + docs = append(docs, doc) + } + return docs, rows.Err() +} diff --git a/pkg/retriever/pgvector/retriever_test.go b/pkg/retriever/pgvector/retriever_test.go new file mode 100644 index 0000000..fa95c82 --- /dev/null +++ b/pkg/retriever/pgvector/retriever_test.go @@ -0,0 +1,402 @@ +package pgvector + +import ( + "context" + "errors" + "testing" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/logger" +) + +// --------------------------------------------------------------------------- +// Stubs +// --------------------------------------------------------------------------- + +// noopLogger discards all log output. +type noopLogger struct{} + +func (noopLogger) Debug(_ context.Context, _ string, _ ...any) {} +func (noopLogger) Info(_ context.Context, _ string, _ ...any) {} +func (noopLogger) Warn(_ context.Context, _ string, _ ...any) {} +func (noopLogger) Error(_ context.Context, _ string, _ ...any) {} + +var _ logger.Logger = noopLogger{} + +// stubRows drives scanRows without a real database connection. +type stubRows struct { + data []rowData + pos int + scanErr error + iterErr error +} + +type rowData struct { + content string + source string + score float64 +} + +func (r *stubRows) Close() {} +func (r *stubRows) Next() bool { + r.pos++ + return r.pos <= len(r.data) +} +func (r *stubRows) Scan(dest ...any) error { + if r.scanErr != nil { + return r.scanErr + } + row := r.data[r.pos-1] + *dest[0].(*string) = row.content + *dest[1].(*string) = row.source + *dest[2].(*float64) = row.score + return nil +} +func (r *stubRows) Err() error { return r.iterErr } + +// stubQuerier returns a pre-canned pgRows (or error) for any Query call. 
+type stubQuerier struct { + rows pgRows + err error +} + +func (s *stubQuerier) Query(_ context.Context, _ string, _ ...any) (pgRows, error) { + return s.rows, s.err +} + +// stubEmbed is a fixed-vector embedding function for tests. +func stubEmbed(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil +} + +func errEmbed(_ context.Context, _ string) ([]float32, error) { + return nil, errors.New("embed error") +} + +// newTestRetriever builds a PgvectorRetriever with a stub querier, bypassing the DSN/pool +// validation path so tests do not need a real database. Options are applied after defaults. +func newTestRetriever(t *testing.T, q pgQuerier, opts ...Option) *PgvectorRetriever { + t.Helper() + r := &PgvectorRetriever{ + name: "test-kb", + embed: stubEmbed, + db: q, + table: "items", + contentCol: "content", + sourceCol: "source", + embeddingCol: "embedding", + topK: 5, + minScore: 0.75, + logger: noopLogger{}, + } + for _, opt := range opts { + opt(r) + } + return r +} + +// --------------------------------------------------------------------------- +// Constructor tests +// --------------------------------------------------------------------------- + +func TestNewRetriever_MissingName(t *testing.T) { + _, err := NewRetriever("", stubEmbed, WithTable("items")) + if err == nil || !contains(err.Error(), "name") { + t.Fatalf("expected name error, got %v", err) + } +} + +func TestNewRetriever_WhitespaceName(t *testing.T) { + _, err := NewRetriever(" ", stubEmbed, WithTable("items")) + if err == nil { + t.Fatal("expected error for whitespace name") + } +} + +func TestNewRetriever_MissingEmbed(t *testing.T) { + _, err := NewRetriever("kb", nil, WithTable("items")) + if err == nil || !contains(err.Error(), "embed") { + t.Fatalf("expected embed error, got %v", err) + } +} + +func TestNewRetriever_MissingTable(t *testing.T) { + _, err := NewRetriever("kb", stubEmbed) + if err == nil || !contains(err.Error(), "table") { + 
t.Fatalf("expected table error, got %v", err) + } +} + +func TestNewRetriever_MissingDSNAndPool(t *testing.T) { + _, err := NewRetriever("kb", stubEmbed, WithTable("items"), WithLogger(noopLogger{})) + if err == nil || !contains(err.Error(), "DSN") { + t.Fatalf("expected DSN error, got %v", err) + } +} + +func TestNewRetriever_InvalidDSN(t *testing.T) { + _, err := NewRetriever("kb", stubEmbed, + WithTable("items"), + WithDSN("not-a-valid-dsn"), + WithLogger(noopLogger{}), + ) + if err == nil { + t.Fatal("expected parse DSN error") + } +} + +func TestNewRetriever_Defaults(t *testing.T) { + r := newTestRetriever(t, &stubQuerier{}) + if r.contentCol != "content" { + t.Errorf("contentCol = %q, want %q", r.contentCol, "content") + } + if r.sourceCol != "source" { + t.Errorf("sourceCol = %q, want %q", r.sourceCol, "source") + } + if r.embeddingCol != "embedding" { + t.Errorf("embeddingCol = %q, want %q", r.embeddingCol, "embedding") + } + if r.topK != 5 { + t.Errorf("topK = %d, want 5", r.topK) + } + if r.minScore != types.DefaultMinScore { + t.Errorf("minScore = %f, want %f", r.minScore, types.DefaultMinScore) + } +} + +func TestNewRetriever_NameTrimmed(t *testing.T) { + r := newTestRetriever(t, &stubQuerier{}) + r2 := newTestRetriever(t, &stubQuerier{}, WithTopK(3)) + _ = r2 + // The retriever in newTestRetriever uses name "test-kb" (already trimmed) + if r.name != "test-kb" { + t.Errorf("name = %q, want %q", r.name, "test-kb") + } +} + +func TestNewRetriever_CustomOptions(t *testing.T) { + r := newTestRetriever(t, &stubQuerier{}, + WithContentCol("body"), + WithSourceCol("url"), + WithEmbeddingCol("vec"), + WithTopK(10), + WithMinScore(0.9), + ) + if r.contentCol != "body" { + t.Errorf("contentCol = %q", r.contentCol) + } + if r.sourceCol != "url" { + t.Errorf("sourceCol = %q", r.sourceCol) + } + if r.embeddingCol != "vec" { + t.Errorf("embeddingCol = %q", r.embeddingCol) + } + if r.topK != 10 { + t.Errorf("topK = %d", r.topK) + } + if r.minScore != 0.9 { + 
t.Errorf("minScore = %f", r.minScore) + } +} + +func TestNewRetriever_WithLogger(t *testing.T) { + r := newTestRetriever(t, &stubQuerier{}, WithLogger(noopLogger{})) + if _, ok := r.logger.(noopLogger); !ok { + t.Errorf("logger type = %T, want noopLogger", r.logger) + } +} + +// --------------------------------------------------------------------------- +// Name +// --------------------------------------------------------------------------- + +func TestPgvectorRetriever_Name(t *testing.T) { + r := newTestRetriever(t, &stubQuerier{}) + if r.Name() != "test-kb" { + t.Errorf("Name() = %q", r.Name()) + } +} + +// --------------------------------------------------------------------------- +// Search tests +// --------------------------------------------------------------------------- + +func TestSearch_ReturnsDocs(t *testing.T) { + q := &stubQuerier{ + rows: &stubRows{data: []rowData{ + {content: "Go routines are lightweight", source: "go.dev", score: 0.95}, + {content: "Channels enable communication", source: "go.dev", score: 0.88}, + }}, + } + r := newTestRetriever(t, q) + + docs, err := r.Search(context.Background(), "concurrency in Go") + if err != nil { + t.Fatal(err) + } + if len(docs) != 2 { + t.Fatalf("docs len = %d, want 2", len(docs)) + } + if docs[0].Content != "Go routines are lightweight" { + t.Errorf("docs[0].Content = %q", docs[0].Content) + } + if docs[0].Source != "go.dev" { + t.Errorf("docs[0].Source = %q", docs[0].Source) + } + if docs[0].Score != 0.95 { + t.Errorf("docs[0].Score = %f", docs[0].Score) + } + if docs[1].Content != "Channels enable communication" { + t.Errorf("docs[1].Content = %q", docs[1].Content) + } +} + +func TestSearch_EmptyResult(t *testing.T) { + q := &stubQuerier{rows: &stubRows{}} + r := newTestRetriever(t, q) + + docs, err := r.Search(context.Background(), "nothing matches") + if err != nil { + t.Fatal(err) + } + if len(docs) != 0 { + t.Fatalf("expected 0 docs, got %d", len(docs)) + } +} + +func TestSearch_EmbedError(t 
*testing.T) { + r := newTestRetriever(t, &stubQuerier{}) + r.embed = errEmbed + + _, err := r.Search(context.Background(), "query") + if err == nil || !contains(err.Error(), "embed") { + t.Fatalf("expected embed error, got %v", err) + } +} + +func TestSearch_QueryError(t *testing.T) { + q := &stubQuerier{err: errors.New("connection refused")} + r := newTestRetriever(t, q) + + _, err := r.Search(context.Background(), "query") + if err == nil || !contains(err.Error(), "pgvector query") { + t.Fatalf("expected query error, got %v", err) + } +} + +func TestSearch_ScanError(t *testing.T) { + q := &stubQuerier{ + rows: &stubRows{ + data: []rowData{{content: "x", source: "y", score: 0.9}}, + scanErr: errors.New("scan failed"), + }, + } + r := newTestRetriever(t, q) + + _, err := r.Search(context.Background(), "query") + if err == nil || !contains(err.Error(), "scan") { + t.Fatalf("expected scan error, got %v", err) + } +} + +func TestSearch_IterError(t *testing.T) { + q := &stubQuerier{ + rows: &stubRows{iterErr: errors.New("cursor error")}, + } + r := newTestRetriever(t, q) + + _, err := r.Search(context.Background(), "query") + if err == nil || !contains(err.Error(), "cursor") { + t.Fatalf("expected iter error, got %v", err) + } +} + +func TestSearch_PassesTopKAndMinScore(t *testing.T) { + var gotSQL string + var gotArgs []any + + capturingQ := &capturingQuerier{rows: &stubRows{}} + r := newTestRetriever(t, capturingQ, WithTopK(3), WithMinScore(0.8)) + + _, err := r.Search(context.Background(), "q") + if err != nil { + t.Fatal(err) + } + gotSQL = capturingQ.lastSQL + gotArgs = capturingQ.lastArgs + + if !contains(gotSQL, "LIMIT $3") { + t.Errorf("SQL missing LIMIT: %s", gotSQL) + } + // args: $1=vector, $2=minScore, $3=topK + if len(gotArgs) != 3 { + t.Fatalf("args len = %d, want 3", len(gotArgs)) + } + if gotArgs[1] != 0.8 { + t.Errorf("minScore arg = %v, want 0.8", gotArgs[1]) + } + if gotArgs[2] != 3 { + t.Errorf("topK arg = %v, want 3", gotArgs[2]) + } +} + +// 
--------------------------------------------------------------------------- +// scanRows tests +// --------------------------------------------------------------------------- + +func TestScanRows_SingleDoc(t *testing.T) { + rows := &stubRows{data: []rowData{ + {content: "hello", source: "src.md", score: 0.91}, + }} + docs, err := scanRows(rows) + if err != nil { + t.Fatal(err) + } + if len(docs) != 1 { + t.Fatalf("len = %d", len(docs)) + } + if docs[0].Content != "hello" || docs[0].Source != "src.md" || docs[0].Score != 0.91 { + t.Errorf("got %+v", docs[0]) + } +} + +func TestScanRows_Empty(t *testing.T) { + docs, err := scanRows(&stubRows{}) + if err != nil { + t.Fatal(err) + } + if len(docs) != 0 { + t.Fatalf("expected empty, got %d docs", len(docs)) + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func contains(s, sub string) bool { + return len(sub) > 0 && len(s) >= len(sub) && + (s == sub || len(s) > 0 && containsStr(s, sub)) +} + +func containsStr(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} + +// capturingQuerier records the last SQL and args for assertion. +type capturingQuerier struct { + rows pgRows + lastSQL string + lastArgs []any +} + +func (c *capturingQuerier) Query(_ context.Context, sql string, args ...any) (pgRows, error) { + c.lastSQL = sql + c.lastArgs = args + return c.rows, nil +} diff --git a/pkg/retriever/weaviate/retriever.go b/pkg/retriever/weaviate/retriever.go new file mode 100644 index 0000000..a38f922 --- /dev/null +++ b/pkg/retriever/weaviate/retriever.go @@ -0,0 +1,280 @@ +// Package weaviate provides a vector retriever backed by Weaviate's nearText GraphQL API. +// The server embeds query text internally; callers supply plain-text queries only. 
+package weaviate + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + "github.com/agenticenv/agent-sdk-go/pkg/logger" + client "github.com/weaviate/weaviate-go-client/v5/weaviate" + "github.com/weaviate/weaviate-go-client/v5/weaviate/graphql" + "github.com/weaviate/weaviate/entities/models" +) + +var _ interfaces.Retriever = (*WeaviateRetriever)(nil) + +// WeaviateRetriever searches a Weaviate class via nearText and maps hits to interfaces.Document. +type WeaviateRetriever struct { + // name is the stable identifier returned by [Name]; required, set via the first argument of [NewRetriever]. + name string + + // runtime fields — used by Search. + className string + contentField string + sourceField string + topK int + minScore float64 + logger logger.Logger + client *client.Client + + // build-time fields — consumed by NewRetriever; not used after construction. + host string + scheme string + logLevel string +} + +// Option configures WeaviateRetriever. +type Option func(*WeaviateRetriever) + +// WithClient sets the Weaviate client. +func WithClient(client *client.Client) Option { + return func(c *WeaviateRetriever) { c.client = client } +} + +// WithHost sets the Weaviate host. +func WithHost(host string) Option { + return func(c *WeaviateRetriever) { c.host = host } +} + +// WithScheme sets the Weaviate scheme. +func WithScheme(scheme string) Option { + return func(c *WeaviateRetriever) { c.scheme = scheme } +} + +// WithContentField sets the Weaviate content field. +func WithContentField(contentField string) Option { + return func(c *WeaviateRetriever) { c.contentField = contentField } +} + +// WithSourceField sets the Weaviate source field. +func WithSourceField(sourceField string) Option { + return func(c *WeaviateRetriever) { c.sourceField = sourceField } +} + +// WithClassName sets the Weaviate class name. 
+func WithClassName(className string) Option { + return func(c *WeaviateRetriever) { c.className = className } +} + +// WithTopK sets the maximum number of documents returned per search. +func WithTopK(topK int) Option { + return func(c *WeaviateRetriever) { c.topK = topK } +} + +// WithMinScore sets the Weaviate minimum score. +func WithMinScore(minScore float64) Option { + return func(c *WeaviateRetriever) { c.minScore = minScore } +} + +// WithLogger sets the Weaviate logger. +func WithLogger(logger logger.Logger) Option { + return func(c *WeaviateRetriever) { c.logger = logger } +} + +// WithLogLevel sets the Weaviate log level. +func WithLogLevel(logLevel string) Option { + return func(c *WeaviateRetriever) { c.logLevel = logLevel } +} + +// NewRetriever builds a WeaviateRetriever. name is the stable identifier returned by [Name] and +// must be non-empty and unique across all retrievers registered with the same agent. +// className is required. When WithClient is omitted, host is required and scheme defaults to +// [types.DefaultScheme]. Zero-valued topK and minScore default to [types.DefaultTopK] and +// [types.DefaultMinScore]. 
+func NewRetriever(name string, opts ...Option) (*WeaviateRetriever, error) { + if strings.TrimSpace(name) == "" { + return nil, errors.New("name is required and must be non-empty") + } + r := &WeaviateRetriever{name: strings.TrimSpace(name)} + for _, opt := range opts { + opt(r) + } + if r.className == "" { + return nil, errors.New("className is required") + } + if r.contentField == "" { + r.contentField = types.DefaultContentField + } + if r.sourceField == "" { + r.sourceField = types.DefaultSourceField + } + if r.topK == 0 { + r.topK = types.DefaultTopK + } + if r.minScore == 0 { + r.minScore = types.DefaultMinScore + } + if r.logLevel == "" { + r.logLevel = "error" + } + if r.logger == nil { + r.logger = logger.DefaultLogger(r.logLevel) + } + if r.client == nil { + if r.host == "" { + return nil, errors.New("host is required when not using WithClient") + } + if r.scheme == "" { + r.scheme = types.DefaultScheme + } + weaviateConfig := client.Config{ + Scheme: r.scheme, + Host: r.host, + } + client, err := client.NewClient(weaviateConfig) + if err != nil { + return nil, fmt.Errorf("create weaviate client: %w", err) + } + r.client = client + } + r.logger.Info(context.Background(), "weaviate retriever built", + slog.String("scope", "weaviate"), + slog.String("name", r.name), + slog.String("class", r.className), + slog.Int("topK", r.topK), + slog.Float64("minScore", r.minScore), + slog.String("contentField", r.contentField), + slog.String("sourceField", r.sourceField), + ) + return r, nil +} + +// Name implements [interfaces.Retriever]. +func (r *WeaviateRetriever) Name() string { + return r.name +} + +// Search runs a nearText GraphQL query against the configured class and returns ranked documents. 
+func (r *WeaviateRetriever) Search(ctx context.Context, query string) ([]interfaces.Document, error) { + if r.client == nil { + return nil, errors.New("client is not set") + } + + r.logger.Debug(ctx, "weaviate search start", + slog.String("scope", "weaviate"), + slog.String("name", r.name), + slog.String("class", r.className), + slog.String("query", query), + slog.Int("topK", r.topK), + slog.Float64("minScore", r.minScore), + ) + start := time.Now() + + fields := []graphql.Field{ + {Name: r.contentField}, + {Name: r.sourceField}, + {Name: "_additional { certainty }"}, + } + nearText := r.client.GraphQL(). + NearTextArgBuilder(). + WithConcepts([]string{query}). + WithCertainty(float32(r.minScore)) + + result, err := r.client.GraphQL().Get(). + WithClassName(r.className). + WithNearText(nearText). + WithLimit(r.topK). + WithFields(fields...). + Do(ctx) + if err != nil { + r.logger.Error(ctx, "weaviate search failed", + slog.String("scope", "weaviate"), + slog.String("name", r.name), + slog.String("class", r.className), + slog.Duration("elapsed", time.Since(start)), + slog.Any("error", err), + ) + return nil, err + } + + docs, err := r.parseDocuments(ctx, result) + if err != nil { + r.logger.Error(ctx, "weaviate parse response failed", + slog.String("scope", "weaviate"), + slog.String("name", r.name), + slog.String("class", r.className), + slog.Duration("elapsed", time.Since(start)), + slog.Any("error", err), + ) + return nil, err + } + + r.logger.Debug(ctx, "weaviate search done", + slog.String("scope", "weaviate"), + slog.String("name", r.name), + slog.String("class", r.className), + slog.Int("docs", len(docs)), + slog.Duration("elapsed", time.Since(start)), + ) + return docs, nil +} + +// parseDocuments maps a Weaviate GraphQL response into []interfaces.Document. +// Shape: data.Get[className][]object with content/source and _additional.certainty. 
+func (r *WeaviateRetriever) parseDocuments(ctx context.Context, result *models.GraphQLResponse) ([]interfaces.Document, error) { + if result == nil || result.Data == nil { + return nil, nil + } + + get, ok := result.Data["Get"].(map[string]interface{}) + if !ok { + return nil, errors.New("invalid response: missing Get") + } + + items, ok := get[r.className].([]interface{}) + if !ok { + return nil, errors.New("invalid response: missing class data") + } + + var docs []interfaces.Document + for _, item := range items { + obj, ok := item.(map[string]interface{}) + if !ok { + r.logger.Warn(ctx, "weaviate: skipping non-object item in response", + slog.String("scope", "weaviate"), + slog.String("name", r.name), + slog.String("class", r.className), + ) + continue + } + + doc := interfaces.Document{ + Content: getString(obj, r.contentField), + Source: getString(obj, r.sourceField), + Metadata: obj, + } + if additional, ok := obj["_additional"].(map[string]interface{}); ok { + if certainty, ok := additional["certainty"].(float64); ok { + doc.Score = certainty + } + } + + docs = append(docs, doc) + } + return docs, nil +} + +// getString reads a string property from a Weaviate object map, or "" if missing or wrong type. 
+func getString(obj map[string]interface{}, key string) string { + if v, ok := obj[key].(string); ok { + return v + } + return "" +} diff --git a/pkg/retriever/weaviate/retriever_test.go b/pkg/retriever/weaviate/retriever_test.go new file mode 100644 index 0000000..2122792 --- /dev/null +++ b/pkg/retriever/weaviate/retriever_test.go @@ -0,0 +1,364 @@ +package weaviate + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/agenticenv/agent-sdk-go/internal/types" + "github.com/agenticenv/agent-sdk-go/pkg/interfaces" + weaviateclient "github.com/weaviate/weaviate-go-client/v5/weaviate" + "github.com/weaviate/weaviate/entities/models" +) + +// noopLogger is a logger.Logger that discards all output, used in tests. +type noopLogger struct{} + +func (noopLogger) Debug(ctx context.Context, msg string, _ ...any) {} +func (noopLogger) Info(ctx context.Context, msg string, _ ...any) {} +func (noopLogger) Warn(ctx context.Context, msg string, _ ...any) {} +func (noopLogger) Error(ctx context.Context, msg string, _ ...any) {} + +func testWeaviateHost(t *testing.T, srv *httptest.Server) string { + t.Helper() + u, err := url.Parse(srv.URL) + if err != nil { + t.Fatal(err) + } + return u.Host +} + +func TestNewRetriever_MissingName(t *testing.T) { + _, err := NewRetriever("", WithHost("localhost:8080"), WithClassName("Article")) + if err == nil || !strings.Contains(err.Error(), "name is required") { + t.Fatalf("err = %v", err) + } + _, err = NewRetriever(" ", WithHost("localhost:8080"), WithClassName("Article")) + if err == nil || !strings.Contains(err.Error(), "name is required") { + t.Fatalf("whitespace-only name: err = %v", err) + } +} + +func TestNewRetriever_MissingClassName(t *testing.T) { + _, err := NewRetriever("kb", WithHost("localhost:8080")) + if err == nil || !strings.Contains(err.Error(), "className") { + t.Fatalf("err = %v", err) + } +} + +func TestNewRetriever_MissingHost(t *testing.T) { + _, err := 
NewRetriever("kb", WithClassName("Article")) + if err == nil || !strings.Contains(err.Error(), "host") { + t.Fatalf("err = %v", err) + } +} + +func TestNewRetriever_Defaults(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + r, err := NewRetriever("kb", + WithHost(testWeaviateHost(t, srv)), + WithClassName("Article"), + ) + if err != nil { + t.Fatal(err) + } + if r.name != "kb" { + t.Fatalf("name = %q", r.name) + } + if r.contentField != types.DefaultContentField { + t.Fatalf("contentField = %q", r.contentField) + } + if r.sourceField != types.DefaultSourceField { + t.Fatalf("sourceField = %q", r.sourceField) + } + if r.topK != types.DefaultTopK { + t.Fatalf("topK = %d", r.topK) + } + if r.minScore != types.DefaultMinScore { + t.Fatalf("minScore = %v", r.minScore) + } + if r.scheme != types.DefaultScheme { + t.Fatalf("scheme = %q", r.scheme) + } + if r.client == nil { + t.Fatal("client nil") + } +} + +func TestNewRetriever_WithClient(t *testing.T) { + wc, err := weaviateclient.NewClient(weaviateclient.Config{ + Scheme: "http", + Host: "unused:0", + }) + if err != nil { + t.Fatal(err) + } + r, err := NewRetriever("articles", + WithClient(wc), + WithClassName("Article"), + WithTopK(3), + WithMinScore(0.5), + WithContentField("body"), + WithSourceField("url"), + ) + if err != nil { + t.Fatal(err) + } + if r.topK != 3 || r.minScore != 0.5 { + t.Fatalf("topK=%d minScore=%v", r.topK, r.minScore) + } + if r.contentField != "body" || r.sourceField != "url" { + t.Fatalf("fields %q %q", r.contentField, r.sourceField) + } +} + +func TestWeaviateRetriever_Search_NoClient(t *testing.T) { + r := &WeaviateRetriever{name: "kb", className: "Article", client: nil} + _, err := r.Search(context.Background(), "query") + if err == nil || !strings.Contains(err.Error(), "client is not set") { + t.Fatalf("err = %v", err) + } +} + +func TestGetString(t *testing.T) { + obj := 
map[string]interface{}{ + "content": "hello", + "count": 42, + } + if got := getString(obj, "content"); got != "hello" { + t.Fatalf("got %q", got) + } + if got := getString(obj, "missing"); got != "" { + t.Fatalf("got %q", got) + } + if got := getString(obj, "count"); got != "" { + t.Fatalf("got %q", got) + } +} + +func TestParseDocuments_NilAndEmpty(t *testing.T) { + r := &WeaviateRetriever{className: "Article", contentField: "content", sourceField: "source", logger: noopLogger{}} + + docs, err := r.parseDocuments(context.Background(), nil) + if err != nil || docs != nil { + t.Fatalf("nil: docs=%v err=%v", docs, err) + } + + docs, err = r.parseDocuments(context.Background(), &models.GraphQLResponse{}) + if err != nil || docs != nil { + t.Fatalf("empty data: docs=%v err=%v", docs, err) + } +} + +func graphQLData(entries map[string]interface{}) map[string]models.JSONObject { + out := make(map[string]models.JSONObject, len(entries)) + for k, v := range entries { + out[k] = v + } + return out +} + +func TestParseDocuments_InvalidResponse(t *testing.T) { + r := &WeaviateRetriever{className: "Article", contentField: "content", sourceField: "source", logger: noopLogger{}} + + _, err := r.parseDocuments(context.Background(), &models.GraphQLResponse{ + Data: graphQLData(map[string]interface{}{"Get": "not-a-map"}), + }) + if err == nil || !strings.Contains(err.Error(), "missing Get") { + t.Fatalf("err = %v", err) + } + + _, err = r.parseDocuments(context.Background(), &models.GraphQLResponse{ + Data: graphQLData(map[string]interface{}{ + "Get": map[string]interface{}{"Article": "not-a-slice"}, + }), + }) + if err == nil || !strings.Contains(err.Error(), "missing class data") { + t.Fatalf("err = %v", err) + } +} + +func TestParseDocuments_Success(t *testing.T) { + r := &WeaviateRetriever{ + className: "Article", + contentField: "content", + sourceField: "source", + logger: noopLogger{}, + } + result := &models.GraphQLResponse{ + Data: graphQLData(map[string]interface{}{ + 
"Get": map[string]interface{}{ + "Article": []interface{}{ + map[string]interface{}{ + "content": "first doc", + "source": "a.md", + "tags": []interface{}{"go"}, + "_additional": map[string]interface{}{ + "certainty": 0.91, + }, + }, + "not-an-object", + map[string]interface{}{ + "content": 42, + "source": "b.md", + }, + }, + }, + }), + } + + docs, err := r.parseDocuments(context.Background(), result) + if err != nil { + t.Fatal(err) + } + if len(docs) != 2 { + t.Fatalf("len(docs) = %d", len(docs)) + } + if docs[0].Content != "first doc" || docs[0].Source != "a.md" || docs[0].Score != 0.91 { + t.Fatalf("doc[0] = %#v", docs[0]) + } + if docs[0].Metadata == nil { + t.Fatal("metadata nil") + } + if docs[1].Content != "" || docs[1].Source != "b.md" || docs[1].Score != 0 { + t.Fatalf("doc[1] = %#v", docs[1]) + } +} + +func TestWeaviateRetriever_Search_Success(t *testing.T) { + const className = "Article" + mockBody := `{ + "data": { + "Get": { + "Article": [ + { + "content": "Weaviate is a vector database", + "source": "docs/weaviate.md", + "_additional": { "certainty": 0.88 } + } + ] + } + } + }` + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/v1/meta", "/v1/.well-known/ready": + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{}`) + return + case "/v1/graphql": + if r.Method != http.MethodPost { + t.Errorf("graphql method = %s", r.Method) + } + body, _ := io.ReadAll(r.Body) + if !strings.Contains(string(body), "nearText") { + t.Errorf("body missing nearText: %s", body) + } + if !strings.Contains(string(body), className) { + t.Errorf("body missing class: %s", body) + } + _, _ = io.WriteString(w, mockBody) + return + default: + t.Errorf("unexpected path = %s", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + r, err := NewRetriever("kb", + WithHost(testWeaviateHost(t, srv)), + 
WithClassName(className), + WithTopK(2), + WithMinScore(0.7), + ) + if err != nil { + t.Fatal(err) + } + + docs, err := r.Search(context.Background(), "vector database") + if err != nil { + t.Fatal(err) + } + if len(docs) != 1 { + t.Fatalf("len(docs) = %d", len(docs)) + } + want := interfaces.Document{ + Content: "Weaviate is a vector database", + Source: "docs/weaviate.md", + Score: 0.88, + } + if docs[0].Content != want.Content || docs[0].Source != want.Source || docs[0].Score != want.Score { + t.Fatalf("got %#v want content/source/score match", docs[0]) + } + if docs[0].Metadata == nil { + t.Fatal("metadata nil") + } +} + +func TestWeaviateRetriever_Name(t *testing.T) { + wc, _ := weaviateclient.NewClient(weaviateclient.Config{Scheme: "http", Host: "unused:0"}) + r, err := NewRetriever("kb-articles", WithClient(wc), WithClassName("Article")) + if err != nil { + t.Fatal(err) + } + if got := r.Name(); got != "kb-articles" { + t.Fatalf("Name() = %q, want %q", got, "kb-articles") + } +} + +func TestNewRetriever_NameTrimmed(t *testing.T) { + wc, _ := weaviateclient.NewClient(weaviateclient.Config{Scheme: "http", Host: "unused:0"}) + r, err := NewRetriever(" kb ", WithClient(wc), WithClassName("Article")) + if err != nil { + t.Fatal(err) + } + if got := r.Name(); got != "kb" { + t.Fatalf("Name() = %q, want trimmed %q", got, "kb") + } +} + +func TestNewRetriever_WithLogger(t *testing.T) { + wc, _ := weaviateclient.NewClient(weaviateclient.Config{Scheme: "http", Host: "unused:0"}) + r, err := NewRetriever("kb", WithClient(wc), WithClassName("Article"), WithLogger(noopLogger{})) + if err != nil { + t.Fatal(err) + } + if r.logger == nil { + t.Fatal("logger is nil after WithLogger") + } + if _, ok := r.logger.(noopLogger); !ok { + t.Fatalf("logger type = %T, want noopLogger", r.logger) + } +} + +func TestWeaviateRetriever_Search_GraphQLError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + 
w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, `{"error":[{"message":"boom"}]}`) + })) + defer srv.Close() + + r, err := NewRetriever("kb", + WithHost(testWeaviateHost(t, srv)), + WithClassName("Article"), + ) + if err != nil { + t.Fatal(err) + } + + _, err = r.Search(context.Background(), "query") + if err == nil { + t.Fatal("expected error") + } +} From f168c5d093155a16acfcecf9027d0c4df57bf136 Mon Sep 17 00:00:00 2001 From: Vinod Vanjarapu Date: Sun, 24 May 2026 11:52:50 -0700 Subject: [PATCH 2/2] fix: format issue and align make check with ci --- Makefile | 7 ++++--- examples/agent_with_retriever/common/config.go | 16 ++++++++-------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/Makefile b/Makefile index 128bf74..b51d485 100644 --- a/Makefile +++ b/Makefile @@ -33,9 +33,10 @@ test: go test ./internal/... -count=1 @echo "==> Tests complete" -# Full gate: format, then tests + fmt-check + spell + go vet + golangci-lint + secrets-scan (lint runs fmt-check and spell) -check: fmt test lint secrets-scan - @echo "==> All checks passed" +# Run before push: lint, test, build, and secrets scan (same core gates as CI; no auto-format). +# Coverage is CI-only (`make test-coverage` when you want the report). If fmt-check fails, run `make fmt`. 
+check: lint test build secrets-scan + @echo "==> All checks passed (ready to push)" # Run tests with coverage test-coverage: diff --git a/examples/agent_with_retriever/common/config.go b/examples/agent_with_retriever/common/config.go index 4e9cee2..2ad3f96 100644 --- a/examples/agent_with_retriever/common/config.go +++ b/examples/agent_with_retriever/common/config.go @@ -82,15 +82,15 @@ func LoadSettings() (*Settings, error) { WeaviateTopK: getEnvInt("WEAVIATE_TOP_K", 0), WeaviateMinScore: getEnvFloat("WEAVIATE_MIN_SCORE", 0), - PGDSN: strings.TrimSpace(getEnv("PGVECTOR_DSN", "")), - PGTable: getEnv("PGVECTOR_TABLE", "documents"), - PGContentCol: getEnv("PGVECTOR_CONTENT_COL", "content"), - PGSourceCol: getEnv("PGVECTOR_SOURCE_COL", "source"), - PGEmbeddingCol: getEnv("PGVECTOR_EMBEDDING_COL", "embedding"), - PGRetrieverName: getEnv("PGVECTOR_RETRIEVER_NAME", "pgvector-kb"), - PGTopK: getEnvInt("PGVECTOR_TOP_K", 0), + PGDSN: strings.TrimSpace(getEnv("PGVECTOR_DSN", "")), + PGTable: getEnv("PGVECTOR_TABLE", "documents"), + PGContentCol: getEnv("PGVECTOR_CONTENT_COL", "content"), + PGSourceCol: getEnv("PGVECTOR_SOURCE_COL", "source"), + PGEmbeddingCol: getEnv("PGVECTOR_EMBEDDING_COL", "embedding"), + PGRetrieverName: getEnv("PGVECTOR_RETRIEVER_NAME", "pgvector-kb"), + PGTopK: getEnvInt("PGVECTOR_TOP_K", 0), // Example default 0.35 — sample KB often scores 0.3–0.6 per topic; 0.5 drops secondary docs on combined queries. - PGMinScore: getEnvFloat("PGVECTOR_MIN_SCORE", 0.35), + PGMinScore: getEnvFloat("PGVECTOR_MIN_SCORE", 0.35), EmbeddingModel: getEnv("EMBEDDING_MODEL", "text-embedding-3-small"), EmbeddingBaseURL: strings.TrimSpace(getEnv("EMBEDDING_BASEURL", "")), EmbeddingAPIKey: strings.TrimSpace(getEnv("EMBEDDING_APIKEY", "")),