From b09e981199611e00ee0b96ce595805ce26290c80 Mon Sep 17 00:00:00 2001 From: Michael Salaverry Date: Sat, 23 May 2026 17:06:07 +0300 Subject: [PATCH] feat(mcp): add memory_get and memory_query tools memory_get(hash) returns the full dbmem_content row (hash, path, context, value, created_at, last_accessed) for a given content hash from search results. memory_query(query) executes an arbitrary SELECT against the memory database. Read-only safety is enforced at the SQLite engine level via PRAGMA query_only=ON on a dedicated connection, rather than a string-prefix check. Co-Authored-By: Claude Sonnet 4.6 --- cli/internal/mcp/mcp.go | 12 ++++ cli/internal/mcp/mcp_test.go | 2 + cli/internal/memory/memory.go | 66 ++++++++++++++++++++++ cli/internal/memory/memory_test.go | 89 ++++++++++++++++++++++++++++++ 4 files changed, 169 insertions(+) diff --git a/cli/internal/mcp/mcp.go b/cli/internal/mcp/mcp.go index ba0ac33..9305ee8 100644 --- a/cli/internal/mcp/mcp.go +++ b/cli/internal/mcp/mcp.go @@ -43,6 +43,8 @@ func ToolNames() []string { "memory_delete_context", "memory_reindex", "memory_status", + "memory_get", + "memory_query", } } @@ -155,6 +157,12 @@ func tools() []map[string]any { }, []string{"context"}), tool("memory_reindex", map[string]any{}, nil), tool("memory_status", map[string]any{}, nil), + tool("memory_get", map[string]any{ + "hash": stringSchema("Content hash from a memory_search result"), + }, []string{"hash"}), + tool("memory_query", map[string]any{ + "query": stringSchema("Read-only SELECT statement to run against the memory database"), + }, []string{"query"}), } } @@ -198,6 +206,10 @@ func (s Server) callTool(ctx context.Context, name string, args map[string]any) return "ok", memory.Delete(ctx, s.DB, strArg(args, "hash")) case "memory_delete_context": return "ok", memory.DeleteContext(ctx, s.DB, strArg(args, "context")) + case "memory_get": + return memory.Get(ctx, s.DB, strArg(args, "hash")) + case "memory_query": + return memory.Query(ctx, s.DB, strArg(args, "query")) case "memory_reindex": return "ok", memory.Reindex(ctx, s.DB) case "memory_status": diff --git a/cli/internal/mcp/mcp_test.go b/cli/internal/mcp/mcp_test.go index cd88a0c..e0270bf 100644 --- a/cli/internal/mcp/mcp_test.go +++ b/cli/internal/mcp/mcp_test.go @@ -23,6 +23,8 @@ func TestToolNames(t *testing.T) { "memory_delete_context": true, "memory_reindex": true, "memory_status": true, + "memory_get": true, + "memory_query": true, } for _, name := range names { delete(want, name) diff --git a/cli/internal/memory/memory.go b/cli/internal/memory/memory.go index b8397ae..5ebbc80 100644 --- a/cli/internal/memory/memory.go +++ b/cli/internal/memory/memory.go @@ -157,6 +157,72 @@ func Status(ctx context.Context, db *sql.DB) (map[string]any, error) { return out, nil } +type ContentResult struct { + Hash string `json:"hash"` + Path string `json:"path"` + Context *string `json:"context"` + Value *string `json:"value"` + CreatedAt int64 `json:"created_at"` + LastAccessed int64 `json:"last_accessed"` +} + +func Get(ctx context.Context, db *sql.DB, hash string) (string, error) { + var r ContentResult + err := db.QueryRowContext(ctx, + "SELECT hash, path, context, value, created_at, last_accessed FROM dbmem_content WHERE hash = ?", + hash, + ).Scan(&r.Hash, &r.Path, &r.Context, &r.Value, &r.CreatedAt, &r.LastAccessed) + if errors.Is(err, sql.ErrNoRows) { + return "", fmt.Errorf("hash not found: %s", hash) + } + if err != nil { + return "", err + } + data, _ := json.MarshalIndent(r, "", " ") + return string(data), nil +} + +func Query(ctx context.Context, db *sql.DB, query string) (string, error) { + conn, err := db.Conn(ctx) + if err != nil { + return "", err + } + defer conn.Close() + if _, err := conn.ExecContext(ctx, "PRAGMA query_only = ON"); err != nil { + return "", err + } + rows, err := conn.QueryContext(ctx, query) + if err != nil { + return "", err + } + defer rows.Close() + cols, err := rows.Columns() + if err != nil { + return "", err + } + var results []map[string]any + for rows.Next() { + vals := make([]any, len(cols)) + ptrs := make([]any, len(cols)) + for i := range vals { + ptrs[i] = &vals[i] + } + if err := rows.Scan(ptrs...); err != nil { + return "", err + } + row := make(map[string]any, len(cols)) + for i, col := range cols { + row[col] = vals[i] + } + results = append(results, row) + } + if err := rows.Err(); err != nil { + return "", err + } + data, _ := json.MarshalIndent(results, "", " ") + return string(data), nil +} + func ResultsJSON(results []SearchResult) string { data, _ := json.MarshalIndent(results, "", " ") return string(data) diff --git a/cli/internal/memory/memory_test.go b/cli/internal/memory/memory_test.go index ca86a2d..cb9c310 100644 --- a/cli/internal/memory/memory_test.go +++ b/cli/internal/memory/memory_test.go @@ -1,11 +1,100 @@ package memory import ( + "context" + "database/sql" "testing" + _ "github.com/mattn/go-sqlite3" + "github.com/sqliteai/sqlite-memory/cli/internal/config" ) +func openTestDB(t *testing.T) *sql.DB { + t.Helper() + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { db.Close() }) + _, err = db.Exec(`CREATE TABLE dbmem_content ( + hash TEXT PRIMARY KEY, + path TEXT NOT NULL, + context TEXT, + value TEXT, + created_at INTEGER NOT NULL DEFAULT 0, + last_accessed INTEGER NOT NULL DEFAULT 0 + )`) + if err != nil { + t.Fatal(err) + } + return db +} + +func TestGetReturnsFullContent(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() + _, err := db.ExecContext(ctx, `INSERT INTO dbmem_content (hash, path, context, value, created_at, last_accessed) + VALUES ('abc123', '/docs/test.md', 'test-ctx', 'hello world', 1000, 2000)`) + if err != nil { + t.Fatal(err) + } + out, err := Get(ctx, db, "abc123") + if err != nil { + t.Fatal(err) + } + for _, want := range []string{"abc123", "/docs/test.md", "test-ctx", "hello world"} { + if !contains(out, want) { + t.Errorf("output missing %q:\n%s", want, out) + } + } +} + +func TestGetNotFound(t *testing.T) { + db := openTestDB(t) + _, err := Get(context.Background(), db, "nope") + if err == nil { + t.Fatal("expected error for missing hash") + } +} + +func TestQuerySelectWorks(t *testing.T) { + db := openTestDB(t) + ctx := context.Background() + _, err := db.ExecContext(ctx, `INSERT INTO dbmem_content (hash, path, value, created_at, last_accessed) + VALUES ('h1', '/a.md', 'content', 0, 0)`) + if err != nil { + t.Fatal(err) + } + out, err := Query(ctx, db, "SELECT hash, path FROM dbmem_content") + if err != nil { + t.Fatal(err) + } + if !contains(out, "h1") || !contains(out, "/a.md") { + t.Errorf("unexpected output: %s", out) + } +} + +func TestQueryRejectsWrites(t *testing.T) { + db := openTestDB(t) + _, err := Query(context.Background(), db, "INSERT INTO dbmem_content (hash, path, created_at, last_accessed) VALUES ('x', '/x', 0, 0)") + if err == nil { + t.Fatal("expected error for write statement under query_only") + } +} + +func contains(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(sub) == 0 || + func() bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false + }()) +} + func TestResolveModelLocalWithoutAPIKey(t *testing.T) { cfg := config.Default() cfg.Embedding.Model = "/models/local.gguf"