From 0a20f1ac5b574b2841e2facd6fc1b5f74a18f573 Mon Sep 17 00:00:00 2001 From: Omer <639682+omercnet@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:40:13 +0000 Subject: [PATCH 01/31] feat(config): load multiple databases --- internal/config/config.go | 38 ++-- internal/config/config_test.go | 15 ++ internal/config/databases.go | 249 +++++++++++++++++++++++ internal/config/databases_errors_test.go | 176 ++++++++++++++++ internal/config/databases_test.go | 157 ++++++++++++++ 5 files changed, 620 insertions(+), 15 deletions(-) create mode 100644 internal/config/databases.go create mode 100644 internal/config/databases_errors_test.go create mode 100644 internal/config/databases_test.go diff --git a/internal/config/config.go b/internal/config/config.go index 046bbd6..2ada393 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,10 +14,12 @@ import ( // Config is the fully-resolved application configuration. type Config struct { - Server Server - DB DB - StorePath string - RowCap int + Server Server + DB DB + Databases []DatabaseEntry + DefaultDatabaseID string + StorePath string + RowCap int } // Server holds HTTP server settings. @@ -51,15 +53,13 @@ type DB struct { // Load reads and validates configuration from the environment. func Load() (*Config, error) { - dsn, err := envOrFile("DATABASE_URL") + stmtTimeout := envDur("PGPEEK_STATEMENT_TIMEOUT", 30*time.Second) + iamAuth := envBool("PGPEEK_DB_IAM_AUTH", false) + region := env("PGPEEK_AWS_REGION", os.Getenv("AWS_REGION")) + databases, defaultDatabaseID, err := loadDatabases(iamAuth, region) if err != nil { return nil, err } - if dsn == "" { - return nil, errors.New("DATABASE_URL (or DATABASE_URL_FILE) is required") - } - - stmtTimeout := envDur("PGPEEK_STATEMENT_TIMEOUT", 30*time.Second) c := &Config{ Server: Server{ @@ -72,15 +72,20 @@ func Load() (*Config, error) { TLSKeyFile: os.Getenv("PGPEEK_TLS_KEY_FILE"), }, DB: DB{ - DSN: dsn, + DSN: "", MaxConns: int32(envInt("PGPEEK_MAX_CONNS", 8)), StatementTimeout: stmtTimeout, IdleTxTimeout: envDur("PGPEEK_IDLE_TX_TIMEOUT", 30*time.Second), - IAMAuth: envBool("PGPEEK_DB_IAM_AUTH", false), - Region: env("PGPEEK_AWS_REGION", os.Getenv("AWS_REGION")), + IAMAuth: iamAuth, + Region: region, }, - StorePath: env("PGPEEK_STORE_PATH", "/data/pgpeek.db"), - RowCap: envInt("PGPEEK_ROW_CAP", 1000), + Databases: databases, + DefaultDatabaseID: defaultDatabaseID, + StorePath: env("PGPEEK_STORE_PATH", "/data/pgpeek.db"), + RowCap: envInt("PGPEEK_ROW_CAP", 1000), + } + if err := applyDefaultDatabase(c); err != nil { + return nil, err } if err := c.validate(); err != nil { return nil, err @@ -98,6 +103,9 @@ func (c *Config) validate() error { if c.DB.IAMAuth && c.DB.Region == "" { return errors.New("PGPEEK_DB_IAM_AUTH requires PGPEEK_AWS_REGION (or AWS_REGION)") } + if err := validateDatabases(c.Databases); err != nil { + return err + } if (c.Server.TLSCertFile == "") != (c.Server.TLSKeyFile == "") { return errors.New("PGPEEK_TLS_CERT_FILE and PGPEEK_TLS_KEY_FILE must be set together") } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 3e75d50..6f16e62 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -3,6 +3,7 @@ package config import ( "os" "path/filepath" + "strconv" "testing" "time" ) @@ -18,9 +19,17 @@ func clearEnv(t *testing.T) { "PGPEEK_MAX_CONNS", "PGPEEK_STATEMENT_TIMEOUT", "PGPEEK_IDLE_TX_TIMEOUT", "PGPEEK_DB_IAM_AUTH", "PGPEEK_AWS_REGION", "AWS_REGION", "PGPEEK_STORE_PATH", "PGPEEK_ROW_CAP", + "PGPEEK_DATABASE_URLS", "PGPEEK_DATABASE_IDS", "PGPEEK_DATABASE_NAMES", + "PGPEEK_DATABASES_FILE", "PGPEEK_DEFAULT_DATABASE", } { t.Setenv(k, "") } + for i := 1; i <= maxNumberedDatabases; i++ { + t.Setenv("PGPEEK_DATABASE_URL_"+strconv.Itoa(i), "") + t.Setenv("PGPEEK_DATABASE_URL_"+strconv.Itoa(i)+"_FILE", "") + t.Setenv("PGPEEK_DATABASE_ID_"+strconv.Itoa(i), "") + t.Setenv("PGPEEK_DATABASE_NAME_"+strconv.Itoa(i), "") + } } func TestLoad_Defaults(t *testing.T) { @@ -53,6 +62,12 @@ func TestLoad_Defaults(t *testing.T) { if c.Server.TLSEnabled() { t.Error("TLS should be disabled by default") } + if c.DefaultDatabaseID != "default" || len(c.Databases) != 1 { + t.Fatalf("default database registry = %q/%d", c.DefaultDatabaseID, len(c.Databases)) + } + if c.Databases[0].ID != "default" || c.Databases[0].Name != "Default" || c.Databases[0].DSN != c.DB.DSN { + t.Errorf("default database entry = %+v", c.Databases[0]) + } } func TestLoad_Overrides(t *testing.T) { diff --git a/internal/config/databases.go b/internal/config/databases.go new file mode 100644 index 0000000..fa989a1 --- /dev/null +++ b/internal/config/databases.go @@ -0,0 +1,249 @@ +package config + +import ( + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "os" + "strings" +) + +const maxNumberedDatabases = 64 + +// DatabaseEntry holds one named database target. +type DatabaseEntry struct { + ID string + Name string + DSN string + IAMAuth bool + Region string +} + +func loadDatabases(globalIAMAuth bool, globalRegion string) ([]DatabaseEntry, string, error) { + entries := make([]DatabaseEntry, 0, 4) + fileDefault, err := appendDatabaseFileEntries(&entries, globalRegion) + if err != nil { + return nil, "", err + } + if err := appendListDatabaseEntries(&entries, globalIAMAuth, globalRegion); err != nil { + return nil, "", err + } + if err := appendNumberedDatabaseEntries(&entries, globalIAMAuth, globalRegion); err != nil { + return nil, "", err + } + if len(entries) == 0 { + dsn, err := envOrFile("DATABASE_URL") + if err != nil { + return nil, "", err + } + if dsn == "" { + return nil, "", errors.New("DATABASE_URL (or DATABASE_URL_FILE) is required") + } + entries = append(entries, DatabaseEntry{ID: "default", Name: "Default", DSN: dsn, IAMAuth: globalIAMAuth, Region: globalRegion}) + } + defaultID := env("PGPEEK_DEFAULT_DATABASE", fileDefault) + if defaultID == "" && len(entries) > 0 { + defaultID = entries[0].ID + } + return entries, defaultID, nil +} + +func appendDatabaseFileEntries(entries *[]DatabaseEntry, globalRegion string) (string, error) { + if os.Getenv("PGPEEK_DATABASES_FILE") == "" { + return "", nil + } + b, err := readOperatorFile(os.Getenv("PGPEEK_DATABASES_FILE"), "PGPEEK_DATABASES_FILE") + if err != nil { + return "", err + } + var file struct { + Default string `json:"default"` + DefaultDatabaseID string `json:"defaultDatabaseID"` + Databases []struct { + ID, Name, URL, URLFile, Region string + IAMAuth bool + } `json:"databases"` + } + if err := json.Unmarshal(b, &file); err != nil { + return "", fmt.Errorf("parse PGPEEK_DATABASES_FILE: %w", err) + } + for i, item := range file.Databases { + dsn, err := databaseFileDSN(item.URL, item.URLFile) + if err != nil { + return "", err + } + *entries = append(*entries, DatabaseEntry{ID: item.ID, Name: entryName(item.Name, i+1), DSN: dsn, IAMAuth: item.IAMAuth, Region: pick(item.Region, globalRegion)}) + } + if file.DefaultDatabaseID != "" { + return file.DefaultDatabaseID, nil + } + return file.Default, nil +} + +func databaseFileDSN(url, urlFile string) (string, error) { + dsn := strings.TrimSpace(url) + if dsn != "" || urlFile == "" { + return dsn, nil + } + b, err := readOperatorFile(urlFile, "database urlFile") + if err != nil { + return "", err + } + return strings.TrimSpace(string(b)), nil +} + +func appendListDatabaseEntries(entries *[]DatabaseEntry, globalIAMAuth bool, globalRegion string) error { + urls, err := splitDatabaseList(os.Getenv("PGPEEK_DATABASE_URLS")) + if err != nil || len(urls) == 0 { + return err + } + ids, err := splitDatabaseList(os.Getenv("PGPEEK_DATABASE_IDS")) + if err != nil { + return err + } + names, err := splitDatabaseList(os.Getenv("PGPEEK_DATABASE_NAMES")) + if err != nil { + return err + } + for i, url := range urls { + *entries = append(*entries, DatabaseEntry{ID: listValue(ids, i, fmt.Sprintf("db%d", i+1)), Name: entryName(listValue(names, i, ""), i+1), DSN: strings.TrimSpace(url), IAMAuth: globalIAMAuth, Region: globalRegion}) + } + return nil +} + +func appendNumberedDatabaseEntries(entries *[]DatabaseEntry, globalIAMAuth bool, globalRegion string) error { + for i := 1; i <= maxNumberedDatabases; i++ { + key := fmt.Sprintf("PGPEEK_DATABASE_URL_%d", i) + dsn, err := envOrFile(key) + if err != nil { + return err + } + if dsn == "" { + break + } + *entries = append(*entries, DatabaseEntry{ID: env(fmt.Sprintf("PGPEEK_DATABASE_ID_%d", i), fmt.Sprintf("db%d", i)), Name: entryName(os.Getenv(fmt.Sprintf("PGPEEK_DATABASE_NAME_%d", i)), i), DSN: dsn, IAMAuth: globalIAMAuth, Region: globalRegion}) + } + return nil +} + +func applyDefaultDatabase(c *Config) error { + for _, db := range c.Databases { + if db.ID == c.DefaultDatabaseID { + c.DB.DSN = db.DSN + c.DB.IAMAuth = db.IAMAuth + c.DB.Region = db.Region + return nil + } + } + return fmt.Errorf("PGPEEK_DEFAULT_DATABASE %q does not match a configured database", c.DefaultDatabaseID) +} + +func validateDatabases(databases []DatabaseEntry) error { + if len(databases) == 0 { + return errors.New("at least one database is required") + } + seen := make(map[string]struct{}, len(databases)) + for _, db := range databases { + if err := validateDatabaseEntry(db, seen); err != nil { + return err + } + } + return nil +} + +func validateDatabaseEntry(db DatabaseEntry, seen map[string]struct{}) error { + if db.ID == "" { + return errors.New("database ID must be non-empty") + } + if !isDatabaseID(db.ID) { + return fmt.Errorf("database ID %q must contain only letters, numbers, dot, underscore, or dash", db.ID) + } + if _, ok := seen[db.ID]; ok { + return fmt.Errorf("database ID %q is duplicated", db.ID) + } + seen[db.ID] = struct{}{} + if db.DSN == "" { + return fmt.Errorf("database %q requires url or urlFile", db.ID) + } + if db.IAMAuth && db.Region == "" { + return fmt.Errorf("database %q IAM auth requires region", db.ID) + } + return nil +} + +func splitDatabaseList(value string) ([]string, error) { + if strings.TrimSpace(value) == "" { + return nil, nil + } + reader := csv.NewReader(strings.NewReader(normalizeListSeparators(value))) + reader.TrimLeadingSpace = true + items, err := reader.Read() + if err != nil { + return nil, fmt.Errorf("parse database list: %w", err) + } + result := make([]string, 0, len(items)) + for _, item := range items { + if trimmed := strings.TrimSpace(item); trimmed != "" { + result = append(result, trimmed) + } + } + return result, nil +} + +func normalizeListSeparators(value string) string { + inQuotes := false + var b strings.Builder + for i := 0; i < len(value); i++ { + ch := value[i] + if ch == '"' { + inQuotes = !inQuotes + } + if ch == ';' && !inQuotes { + ch = ',' + } + b.WriteByte(ch) + } + return b.String() +} + +func listValue(values []string, index int, fallback string) string { + if index < len(values) && values[index] != "" { + return values[index] + } + return fallback +} + +func entryName(value string, index int) string { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + return fmt.Sprintf("Database %d", index) +} + +func pick(value, fallback string) string { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + return fallback +} + +func isDatabaseID(value string) bool { + for _, r := range value { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-' || r == '.' { + continue + } + return false + } + return value != "" +} + +func readOperatorFile(path, label string) ([]byte, error) { + // The path is supplied by the operator (env var / mounted-secret convention), + // not by any request input — this is the intended use. + b, err := os.ReadFile(path) //nolint:gosec // operator-controlled config/secret path + if err != nil { + return nil, fmt.Errorf("read %s: %w", label, err) + } + return b, nil +} diff --git a/internal/config/databases_errors_test.go b/internal/config/databases_errors_test.go new file mode 100644 index 0000000..4020944 --- /dev/null +++ b/internal/config/databases_errors_test.go @@ -0,0 +1,176 @@ +package config + +import ( + "os" + "path/filepath" + "strconv" + "strings" + "testing" +) + +func TestLoadDatabases_returns_required_error_when_no_sources_configured(t *testing.T) { + // Given: no database URL source is configured. + clearEnv(t) + + // When: database entries are loaded. + _, _, err := loadDatabases(false, "") + + // Then: configuration fails with the public missing-URL message. + if err == nil || !strings.Contains(err.Error(), "DATABASE_URL") { + t.Fatalf("loadDatabases error = %v, want DATABASE_URL requirement", err) + } +} + +func TestAppendDatabaseFileEntries_uses_default_database_id_when_present(t *testing.T) { + // Given: a mounted database config uses the explicit defaultDatabaseID field. + clearEnv(t) + dir := t.TempDir() + configPath := filepath.Join(dir, "databases.json") + body := `{"default":"legacy","defaultDatabaseID":"primary","databases":[{"id":"primary","url":"postgres://u:p@h/db"}]}` + if err := os.WriteFile(configPath, []byte(body), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASES_FILE", configPath) + entries := []DatabaseEntry{} + + // When: database entries are loaded from the file. + defaultID, err := appendDatabaseFileEntries(&entries, "us-east-1") + + // Then: defaultDatabaseID takes precedence over legacy default. + if err != nil { + t.Fatalf("appendDatabaseFileEntries: %v", err) + } + if defaultID != "primary" || len(entries) != 1 || entries[0].Region != "us-east-1" { + t.Fatalf("default=%q entries=%+v", defaultID, entries) + } +} + +func TestAppendDatabaseFileEntries_returns_parse_error_when_file_invalid(t *testing.T) { + // Given: the mounted database config is not valid JSON. + clearEnv(t) + configPath := filepath.Join(t.TempDir(), "databases.json") + if err := os.WriteFile(configPath, []byte(`{"databases":`), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASES_FILE", configPath) + + // When: databases are loaded. + _, _, err := loadDatabases(false, "") + + // Then: the parse error names the operator-provided file variable. + if err == nil || !strings.Contains(err.Error(), "parse PGPEEK_DATABASES_FILE") { + t.Fatalf("appendDatabaseFileEntries error = %v", err) + } +} + +func TestAppendDatabaseFileEntries_returns_read_error_when_file_missing(t *testing.T) { + // Given: the mounted database config file is missing. + clearEnv(t) + t.Setenv("PGPEEK_DATABASES_FILE", filepath.Join(t.TempDir(), "missing.json")) + + // When: databases are loaded. + _, _, err := loadDatabases(false, "") + + // Then: the operator file read error is returned. + if err == nil || !strings.Contains(err.Error(), "read PGPEEK_DATABASES_FILE") { + t.Fatalf("loadDatabases error = %v", err) + } +} + +func TestDatabaseFileDSN_returns_read_error_when_url_file_missing(t *testing.T) { + // Given: an entry references a missing urlFile. + clearEnv(t) + dir := t.TempDir() + missing := filepath.Join(dir, "missing-url") + configPath := filepath.Join(dir, "databases.json") + body := `{"databases":[{"id":"primary","urlFile":` + strconv.Quote(missing) + `}]}` + if err := os.WriteFile(configPath, []byte(body), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASES_FILE", configPath) + + // When: databases are loaded. + _, _, err := loadDatabases(false, "") + + // Then: the read error is returned without exposing a DSN. + if err == nil || !strings.Contains(err.Error(), "read database urlFile") { + t.Fatalf("databaseFileDSN error = %v", err) + } +} + +func TestAppendListDatabaseEntries_returns_parse_error_when_csv_invalid(t *testing.T) { + // Given: list-form URLs contain malformed CSV quoting. + clearEnv(t) + t.Setenv("PGPEEK_DATABASE_URLS", `"postgres://u:p@h/db`) + + // When: databases are loaded. + _, _, err := loadDatabases(false, "") + + // Then: parsing fails before any entry is appended. + if err == nil || !strings.Contains(err.Error(), "parse database list") { + t.Fatalf("appendListDatabaseEntries error = %v", err) + } +} + +func TestAppendListDatabaseEntries_returns_parse_error_when_ids_or_names_invalid(t *testing.T) { + tests := []struct { + name string + env string + }{ + {name: "ids", env: "PGPEEK_DATABASE_IDS"}, + {name: "names", env: "PGPEEK_DATABASE_NAMES"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given: list-form URLs are valid but a companion list has malformed CSV quoting. + clearEnv(t) + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:p@h/db") + t.Setenv(tt.env, `"unterminated`) + + // When: databases are loaded. + _, _, err := loadDatabases(false, "") + + // Then: companion-list parsing fails. + if err == nil || !strings.Contains(err.Error(), "parse database list") { + t.Fatalf("loadDatabases error = %v", err) + } + }) + } +} + +func TestAppendNumberedDatabaseEntries_returns_file_error_when_url_file_missing(t *testing.T) { + // Given: numbered env configuration points at a missing mounted secret. + clearEnv(t) + t.Setenv("PGPEEK_DATABASE_URL_1_FILE", filepath.Join(t.TempDir(), "missing-url")) + + // When: databases are loaded. + _, _, err := loadDatabases(false, "") + + // Then: the read error is returned. + if err == nil || !strings.Contains(err.Error(), "read PGPEEK_DATABASE_URL_1") { + t.Fatalf("appendNumberedDatabaseEntries error = %v", err) + } +} + +func TestValidateDatabases_rejects_empty_or_incomplete_entries(t *testing.T) { + tests := []struct { + name string + dbs []DatabaseEntry + }{ + {name: "empty registry", dbs: nil}, + {name: "empty id", dbs: []DatabaseEntry{{DSN: "postgres://u:p@h/db"}}}, + {name: "empty dsn", dbs: []DatabaseEntry{{ID: "primary"}}}, + {name: "iam missing region", dbs: []DatabaseEntry{{ID: "primary", DSN: "postgres://u:p@h/db", IAMAuth: true}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // When: database entries are validated. + err := validateDatabases(tt.dbs) + + // Then: invalid entries are rejected. + if err == nil { + t.Fatal("validateDatabases expected error") + } + }) + } +} diff --git a/internal/config/databases_test.go b/internal/config/databases_test.go new file mode 100644 index 0000000..2bac074 --- /dev/null +++ b/internal/config/databases_test.go @@ -0,0 +1,157 @@ +package config + +import ( + "os" + "path/filepath" + "strconv" + "strings" + "testing" +) + +func TestLoad_DatabaseURLsList(t *testing.T) { + clearEnv(t) + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:secret@h/one, postgres://u:secret@h/two;postgres://u:secret@h/three") + t.Setenv("PGPEEK_DATABASE_IDS", "analytics, billing") + t.Setenv("PGPEEK_DATABASE_NAMES", "Analytics, Billing") + t.Setenv("PGPEEK_DEFAULT_DATABASE", "billing") + + c, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + if c.DefaultDatabaseID != "billing" || c.DB.DSN != "postgres://u:secret@h/two" { + t.Fatalf("default database selection failed; id = %q", c.DefaultDatabaseID) + } + entries := c.Databases + if len(entries) != 3 { + t.Fatalf("databases = %d", len(entries)) + } + if entries[0].ID != "analytics" || entries[0].Name != "Analytics" { + t.Errorf("first entry id/name = %q/%q", entries[0].ID, entries[0].Name) + } + if entries[1].ID != "billing" || entries[1].Name != "Billing" { + t.Errorf("second entry id/name = %q/%q", entries[1].ID, entries[1].Name) + } + if entries[2].ID != "db3" || entries[2].Name != "Database 3" { + t.Errorf("derived entry id/name = %q/%q", entries[2].ID, entries[2].Name) + } + if strings.Contains(entries[2].Name, "secret") || strings.Contains(entries[2].Name, "postgres://") { + t.Errorf("derived name exposes DSN: %q", entries[2].Name) + } +} + +func TestLoad_DatabaseURLsListQuotedSeparators(t *testing.T) { + clearEnv(t) + t.Setenv("PGPEEK_DATABASE_URLS", `"postgres://u:p@h/one?options=a,b";"postgres://u:p@h/two?options=x;y"`) + t.Setenv("PGPEEK_DATABASE_IDS", `"one";"two"`) + t.Setenv("PGPEEK_DATABASE_NAMES", `"One, Primary";"Two; Replica"`) + t.Setenv("PGPEEK_DEFAULT_DATABASE", "two") + + c, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + if len(c.Databases) != 2 { + t.Fatalf("databases = %d", len(c.Databases)) + } + if c.Databases[0].ID != "one" || c.Databases[0].Name != "One, Primary" || c.Databases[0].DSN != "postgres://u:p@h/one?options=a,b" { + t.Errorf("quoted comma entry = id %q dsn %q", c.Databases[0].ID, c.Databases[0].DSN) + } + if c.Databases[1].ID != "two" || c.Databases[1].Name != "Two; Replica" || c.DB.DSN != "postgres://u:p@h/two?options=x;y" { + t.Errorf("quoted semicolon entry = id %q", c.Databases[1].ID) + } +} + +func TestLoad_NumberedDatabaseURLs(t *testing.T) { + clearEnv(t) + dir := t.TempDir() + path := filepath.Join(dir, "two") + if err := os.WriteFile(path, []byte(" postgres://u:secret@h/two\n"), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASE_URL_1", "postgres://u:secret@h/one") + t.Setenv("PGPEEK_DATABASE_ID_1", "one") + t.Setenv("PGPEEK_DATABASE_NAME_1", "One") + t.Setenv("PGPEEK_DATABASE_URL_2_FILE", path) + t.Setenv("PGPEEK_DATABASE_ID_2", "two") + + c, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + if len(c.Databases) != 2 { + t.Fatalf("databases = %d", len(c.Databases)) + } + if c.Databases[1].ID != "two" || c.Databases[1].Name != "Database 2" { + t.Errorf("file-backed numbered entry id/name = %q/%q", c.Databases[1].ID, c.Databases[1].Name) + } + if c.Databases[1].DSN != "postgres://u:secret@h/two" { + t.Error("file-backed numbered entry DSN was not read from file") + } +} + +func TestLoad_DatabasesJSONFile(t *testing.T) { + clearEnv(t) + dir := t.TempDir() + secretPath := filepath.Join(dir, "west-url") + if err := os.WriteFile(secretPath, []byte("postgres://u:secret@h/west\n"), 0o600); err != nil { + t.Fatal(err) + } + configPath := filepath.Join(dir, "databases.json") + body := `{"default":"west","databases":[{"id":"east","name":"East","url":"postgres://u:secret@h/east","iamAuth":true,"region":"us-east-1"},{"id":"west","name":"West","urlFile":` + strconv.Quote(secretPath) + `,"iamAuth":true}]}` + if err := os.WriteFile(configPath, []byte(body), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASES_FILE", configPath) + t.Setenv("AWS_REGION", "us-west-2") + + c, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + if c.DefaultDatabaseID != "west" || c.DB.DSN != "postgres://u:secret@h/west" || !c.DB.IAMAuth || c.DB.Region != "us-west-2" { + t.Fatalf("default DB selection failed; id = %q iam = %v region = %q", c.DefaultDatabaseID, c.DB.IAMAuth, c.DB.Region) + } + if c.Databases[0].Region != "us-east-1" || c.Databases[1].Region != "us-west-2" { + t.Errorf("entry regions = %q/%q", c.Databases[0].Region, c.Databases[1].Region) + } +} + +func TestLoad_DatabaseValidationErrors(t *testing.T) { + cases := []struct { + name string + setup func(t *testing.T) + secretDSN string + }{ + {"duplicate id", func(t *testing.T) { + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:secret@h/one,postgres://u:secret@h/two") + t.Setenv("PGPEEK_DATABASE_IDS", "same,same") + }, "postgres://u:secret@h/one"}, + {"invalid id", func(t *testing.T) { + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:secret@h/one") + t.Setenv("PGPEEK_DATABASE_IDS", "bad id") + }, "postgres://u:secret@h/one"}, + {"unknown default", func(t *testing.T) { + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:secret@h/one") + t.Setenv("PGPEEK_DEFAULT_DATABASE", "missing") + }, "postgres://u:secret@h/one"}, + {"iam missing region", func(t *testing.T) { + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:secret@h/one") + t.Setenv("PGPEEK_DATABASES_FILE", "") + t.Setenv("PGPEEK_DB_IAM_AUTH", "true") + }, "postgres://u:secret@h/one"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + clearEnv(t) + tc.setup(t) + _, err := Load() + if err == nil { + t.Fatal("expected validation error") + } + if strings.Contains(err.Error(), tc.secretDSN) || strings.Contains(err.Error(), "secret") { + t.Fatalf("error exposes DSN secret: %v", err) + } + }) + } +} From d8eecbf4c18d44398c7904ae545f4bbba24ed5f1 Mon Sep 17 00:00:00 2001 From: Omer <639682+omercnet@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:40:26 +0000 Subject: [PATCH 02/31] feat(db): add database pool registry --- internal/db/registry.go | 109 +++++++++++++++++ internal/db/registry_test.go | 221 +++++++++++++++++++++++++++++++++++ 2 files changed, 330 insertions(+) create mode 100644 internal/db/registry.go create mode 100644 internal/db/registry_test.go diff --git a/internal/db/registry.go b/internal/db/registry.go new file mode 100644 index 0000000..253a4ff --- /dev/null +++ b/internal/db/registry.go @@ -0,0 +1,109 @@ +package db + +import ( + "context" + "errors" + "fmt" +) + +var ErrPoolNotFound = errors.New("database pool not found") + +type PoolMetadata struct { + ID string `json:"id"` + Name string `json:"name"` +} + +type RegistryEntry struct { + ID string + Name string + Pool *Pool + Default bool +} + +type Registry struct { + entries []registeredPool + byID map[string]*Pool + defaultID string +} + +type registeredPool struct { + metadata PoolMetadata + pool *Pool +} + +func NewRegistry(entries []RegistryEntry) (*Registry, error) { + if len(entries) == 0 { + return nil, errors.New("database registry requires at least one pool") + } + + registered := make([]registeredPool, 0, len(entries)) + byID := make(map[string]*Pool, len(entries)) + defaultID := "" + for _, entry := range entries { + if entry.ID == "" { + return nil, errors.New("database registry entry id is required") + } + if entry.Name == "" { + return nil, fmt.Errorf("database registry entry %q name is required", entry.ID) + } + if entry.Pool == nil { + return nil, fmt.Errorf("database registry entry %q pool is required", entry.ID) + } + if _, exists := byID[entry.ID]; exists { + return nil, fmt.Errorf("database registry entry %q is duplicated", entry.ID) + } + if entry.Default { + if defaultID != "" { + return nil, errors.New("database registry has multiple defaults") + } + defaultID = entry.ID + } + + byID[entry.ID] = entry.Pool + registered = append(registered, registeredPool{ + metadata: PoolMetadata{ID: entry.ID, Name: entry.Name}, + pool: entry.Pool, + }) + } + if defaultID == "" { + defaultID = entries[0].ID + } + + return &Registry{entries: registered, byID: byID, defaultID: defaultID}, nil +} + +func (r *Registry) List() []PoolMetadata { + metadata := make([]PoolMetadata, 0, len(r.entries)) + for _, entry := range r.entries { + metadata = append(metadata, entry.metadata) + } + return metadata +} + +func (r *Registry) DefaultID() string { return r.defaultID } + +func (r *Registry) Pool(id string) (*Pool, error) { + if id == "" { + id = r.defaultID + } + pool, ok := r.byID[id] + if !ok { + return nil, fmt.Errorf("%w: %s", ErrPoolNotFound, id) + } + return pool, nil +} + +func (r *Registry) Ping(ctx context.Context) error { + for _, entry := range r.entries { + if err := entry.pool.Ping(ctx); err != nil { + return fmt.Errorf("ping database pool %q: %w", entry.metadata.ID, err) + } + } + return nil +} + +func (r *Registry) Close() { + for i := len(r.entries) - 1; i >= 0; i-- { + r.entries[i].pool.Close() + } +} diff --git a/internal/db/registry_test.go b/internal/db/registry_test.go new file mode 100644 index 0000000..6a19770 --- /dev/null +++ b/internal/db/registry_test.go @@ -0,0 +1,221 @@ +package db + +import ( + "context" + "errors" + "reflect" + "testing" +) + +func TestNewRegistry_lists_metadata_without_dsn_when_entries_have_private_names(t *testing.T) { + // Given: registry entries carry public ids/names and pool handles only. + registry, err := NewRegistry([]RegistryEntry{ + {ID: "billing", Name: "Billing", Pool: &Pool{pool: &fakePool{}, rowCap: 100}, Default: true}, + {ID: "support", Name: "Support", Pool: &Pool{pool: &fakePool{}, rowCap: 200}}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: callers request API-safe metadata. + got := registry.List() + + // Then: only id/name are exposed, in entry order. + want := []PoolMetadata{{ID: "billing", Name: "Billing"}, {ID: "support", Name: "Support"}} + if !reflect.DeepEqual(got, want) { + t.Fatalf("List() = %#v, want %#v", got, want) + } +} + +func TestRegistry_pool_returns_default_when_selection_empty(t *testing.T) { + // Given: registry has explicit default and another pool. + defaultPool := &Pool{pool: &fakePool{}, rowCap: 100} + otherPool := &Pool{pool: &fakePool{}, rowCap: 200} + registry, err := NewRegistry([]RegistryEntry{ + {ID: "billing", Name: "Billing", Pool: defaultPool, Default: true}, + {ID: "support", Name: "Support", Pool: otherPool}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: caller selects empty id. + got, err := registry.Pool("") + + // Then: default pool is returned, not an unknown-id error. + if err != nil { + t.Fatalf("Pool(empty): %v", err) + } + if got != defaultPool { + t.Fatalf("Pool(empty) returned %p, want default %p", got, defaultPool) + } + if registry.DefaultID() != "billing" { + t.Fatalf("DefaultID() = %q, want billing", registry.DefaultID()) + } +} + +func TestRegistry_pool_returns_requested_pool_when_id_known(t *testing.T) { + // Given: registry has two named pools. + defaultPool := &Pool{pool: &fakePool{}, rowCap: 100} + otherPool := &Pool{pool: &fakePool{}, rowCap: 200} + registry, err := NewRegistry([]RegistryEntry{ + {ID: "billing", Name: "Billing", Pool: defaultPool}, + {ID: "support", Name: "Support", Pool: otherPool}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: caller selects a known id. + got, err := registry.Pool("support") + + // Then: matching pool is returned. + if err != nil { + t.Fatalf("Pool(support): %v", err) + } + if got != otherPool { + t.Fatalf("Pool(support) returned %p, want %p", got, otherPool) + } +} + +func TestRegistry_pool_returns_not_found_when_non_empty_id_unknown(t *testing.T) { + // Given: registry has one pool. + registry, err := NewRegistry([]RegistryEntry{{ID: "billing", Name: "Billing", Pool: &Pool{pool: &fakePool{}, rowCap: 100}}}) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: caller selects an unknown non-empty id. + _, err = registry.Pool("missing") + + // Then: caller can distinguish unknown id from default selection. + if !errors.Is(err, ErrPoolNotFound) { + t.Fatalf("Pool(missing) error = %v, want ErrPoolNotFound", err) + } +} + +func TestRegistry_ping_visits_pools_in_metadata_order_and_stops_on_error(t *testing.T) { + // Given: second pool fails ping. + first := &orderedPool{name: "first"} + second := &orderedPool{name: "second", pingErr: errors.New("down")} + third := &orderedPool{name: "third"} + registry, err := NewRegistry([]RegistryEntry{ + {ID: "first", Name: "First", Pool: &Pool{pool: first, rowCap: 10}}, + {ID: "second", Name: "Second", Pool: &Pool{pool: second, rowCap: 10}}, + {ID: "third", Name: "Third", Pool: &Pool{pool: third, rowCap: 10}}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: readiness pings all pools. + err = registry.Ping(context.Background()) + + // Then: order is deterministic and ping stops at failing pool. + if err == nil { + t.Fatal("Ping() expected error") + } + got := append([]string{}, first.pings...) + got = append(got, second.pings...) + if want := []string{"first", "second"}; !reflect.DeepEqual(got, want) { + t.Fatalf("ping order = %v, want %v", got, want) + } + if len(third.pings) != 0 { + t.Fatalf("third pool pinged after error: %v", third.pings) + } +} + +func TestRegistry_ping_returns_nil_when_all_pools_healthy(t *testing.T) { + // Given: every registered pool can be pinged. + first := &orderedPool{name: "first"} + second := &orderedPool{name: "second"} + registry, err := NewRegistry([]RegistryEntry{ + {ID: "first", Name: "First", Pool: &Pool{pool: first, rowCap: 10}}, + {ID: "second", Name: "Second", Pool: &Pool{pool: second, rowCap: 10}}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: readiness pings the registry. + err = registry.Ping(context.Background()) + + // Then: all pools are checked and no error is returned. + if err != nil { + t.Fatalf("Ping: %v", err) + } + got := append([]string{}, first.pings...) + got = append(got, second.pings...) + if want := []string{"first", "second"}; !reflect.DeepEqual(got, want) { + t.Fatalf("ping order = %v, want %v", got, want) + } +} + +func TestRegistry_close_visits_all_pools_in_reverse_metadata_order(t *testing.T) { + // Given: registry has three pools. + closeOrder := make([]string, 0, 3) + registry, err := NewRegistry([]RegistryEntry{ + {ID: "first", Name: "First", Pool: &Pool{pool: &orderedPool{name: "first", closes: &closeOrder}, rowCap: 10}}, + {ID: "second", Name: "Second", Pool: &Pool{pool: &orderedPool{name: "second", closes: &closeOrder}, rowCap: 10}}, + {ID: "third", Name: "Third", Pool: &Pool{pool: &orderedPool{name: "third", closes: &closeOrder}, rowCap: 10}}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: registry closes. + registry.Close() + + // Then: every pool closes in deterministic reverse order. + want := []string{"third", "second", "first"} + if !reflect.DeepEqual(closeOrder, want) { + t.Fatalf("close order = %v, want %v", closeOrder, want) + } +} + +func TestNewRegistry_rejects_invalid_entries(t *testing.T) { + tests := []struct { + name string + entries []RegistryEntry + }{ + {name: "empty registry", entries: nil}, + {name: "empty id", entries: []RegistryEntry{{ID: "", Name: "Billing", Pool: &Pool{pool: &fakePool{}}}}}, + {name: "empty name", entries: []RegistryEntry{{ID: "billing", Name: "", Pool: &Pool{pool: &fakePool{}}}}}, + {name: "nil pool", entries: []RegistryEntry{{ID: "billing", Name: "Billing", Pool: nil}}}, + {name: "duplicate id", entries: []RegistryEntry{{ID: "billing", Name: "Billing", Pool: &Pool{pool: &fakePool{}}}, {ID: "billing", Name: "Other", Pool: &Pool{pool: &fakePool{}}}}}, + {name: "multiple defaults", entries: []RegistryEntry{{ID: "billing", Name: "Billing", Pool: &Pool{pool: &fakePool{}}, Default: true}, {ID: "support", Name: "Support", Pool: &Pool{pool: &fakePool{}}, Default: true}}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Given: invalid registry input. + + // When: registry is constructed. + _, err := NewRegistry(tc.entries) + + // Then: construction fails before runtime lookup. + if err == nil { + t.Fatal("NewRegistry expected error") + } + }) + } +} + +type orderedPool struct { + fakePool + name string + pingErr error + pings []string + closes *[]string +} + +func (p *orderedPool) Ping(context.Context) error { + p.pings = append(p.pings, p.name) + return p.pingErr +} + +func (p *orderedPool) Close() { + if p.closes != nil { + *p.closes = append(*p.closes, p.name) + } +} From 82fff1c0b48fe98851e15f21669bdb82f5767190 Mon Sep 17 00:00:00 2001 From: Omer <639682+omercnet@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:40:33 +0000 Subject: [PATCH 03/31] refactor(server): split route helpers --- internal/server/helpers_test.go | 112 ++++ internal/server/http_helpers.go | 86 +++ internal/server/middleware.go | 55 ++ internal/server/server.go | 461 +------------- internal/server/server_routes_test.go | 80 +++ internal/server/server_test.go | 837 -------------------------- 6 files changed, 340 insertions(+), 1291 deletions(-) create mode 100644 internal/server/helpers_test.go create mode 100644 internal/server/http_helpers.go create mode 100644 internal/server/middleware.go create mode 100644 internal/server/server_routes_test.go delete mode 100644 internal/server/server_test.go diff --git a/internal/server/helpers_test.go b/internal/server/helpers_test.go new file mode 100644 index 0000000..bb98c32 --- /dev/null +++ b/internal/server/helpers_test.go @@ -0,0 +1,112 @@ +package server + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + "testing/fstest" + "time" + + "github.com/descope/pgpeek/internal/db" + "github.com/descope/pgpeek/internal/store" +) + +type fakeQuerier struct { + result *db.Result + err error + pingErr error + called bool + lastSQL string + tables []db.TableInfo + cols []db.ColumnInfo + fks []db.ForeignKey + catErr error + lastQuery db.TableQuery + lastArgs struct { + schema, table string + limit, offset int + } +} + +func (f *fakeQuerier) Query(_ context.Context, sql string) (*db.Result, error) { + f.called = true + f.lastSQL = sql + return f.result, f.err +} + +func (f *fakeQuerier) Tables(context.Context) ([]db.TableInfo, error) { + return f.tables, f.catErr +} + +func (f *fakeQuerier) Columns(_ context.Context, schema, table string) ([]db.ColumnInfo, error) { + f.lastArgs.schema, f.lastArgs.table = schema, table + return f.cols, f.catErr +} + +func (f *fakeQuerier) ForeignKeys(_ context.Context, _, _ string) ([]db.ForeignKey, error) { + return f.fks, f.catErr +} + +func (f *fakeQuerier) TableRows(_ context.Context, q db.TableQuery) (*db.Result, error) { + f.lastQuery = q + f.lastArgs.schema, f.lastArgs.table = q.Schema, q.Table + f.lastArgs.limit, f.lastArgs.offset = q.Limit, q.Offset + return f.result, f.err +} + +func (f *fakeQuerier) RowCap() int { return 1000 } + +func (f *fakeQuerier) Ping(_ context.Context) error { return f.pingErr } + +func newTestServer(t *testing.T, q Querier) (*httptest.Server, *store.Store) { + t.Helper() + st, err := store.Open(filepath.Join(t.TempDir(), "t.db")) + if err != nil { + t.Fatalf("store.Open: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + + web := fstest.MapFS{ + "index.html": &fstest.MapFile{Data: []byte("pgpeek")}, + "app.js": &fstest.MapFile{Data: []byte("// js")}, + } + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + srv := New(q, st, web, log, 5*time.Second) + ts := httptest.NewServer(srv.Routes()) + t.Cleanup(ts.Close) + return ts, st +} + +func post(t *testing.T, ts *httptest.Server, path, body string) *http.Response { + t.Helper() + resp, err := http.Post(ts.URL+path, "application/json", strings.NewReader(body)) + if err != nil { + t.Fatalf("POST %s: %v", path, err) + } + return resp +} + +func decode[T any](t *testing.T, resp *http.Response) T { + t.Helper() + defer resp.Body.Close() + var v T + if err := json.NewDecoder(resp.Body).Decode(&v); err != nil { + t.Fatalf("decode: %v", err) + } + return v +} + +func okResult() *db.Result { + return &db.Result{ + Columns: []string{"n"}, + Rows: [][]any{{int64(1)}}, + RowCount: 1, + ElapsedMS: 1, + } +} diff --git a/internal/server/http_helpers.go b/internal/server/http_helpers.go new file mode 100644 index 0000000..6d1fe93 --- /dev/null +++ b/internal/server/http_helpers.go @@ -0,0 +1,86 @@ +package server + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + + "github.com/descope/pgpeek/internal/db" +) + +const maxBodyBytes = 1 << 20 + +func decodeBody(w http.ResponseWriter, r *http.Request, v any) bool { + r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes) + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(v); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON body: "+err.Error()) + return false + } + return true +} + +func safeFilename(name string) string { + out := strings.Map(func(r rune) rune { + switch { + case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', + r == '_', r == '-', r == '.': + return r + default: + return '_' + } + }, name) + if out == "" { + return "table" + } + return out +} + +func parseFilters(raw []string) []db.Filter { + if len(raw) == 0 { + return nil + } + out := make([]db.Filter, 0, len(raw)) + for _, s := range raw { + parts := strings.SplitN(s, ":", 3) + f := db.Filter{Column: parts[0]} + if len(parts) >= 2 { + f.Op = parts[1] + } + if len(parts) == 3 { + f.Value = parts[2] + } + out = append(out, f) + } + return out +} + +func queryInt(r *http.Request, key string, def int) int { + if v := r.URL.Query().Get(key); v != "" { + if n, err := strconv.Atoi(v); err == nil { + return n + } + } + return def +} + +func pathID(w http.ResponseWriter, r *http.Request) (int64, bool) { + id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid id") + return 0, false + } + return id, true +} + +func writeJSON(w http.ResponseWriter, status int, v any) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(v) +} + +func writeError(w http.ResponseWriter, status int, msg string) { + writeJSON(w, status, map[string]string{"error": msg}) +} diff --git a/internal/server/middleware.go b/internal/server/middleware.go new file mode 100644 index 0000000..b90e4fd --- /dev/null +++ b/internal/server/middleware.go @@ -0,0 +1,55 @@ +package server + +import ( + "log/slog" + "net/http" + "time" +) + +func logging(log *slog.Logger, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + sw := &statusWriter{ResponseWriter: w, status: http.StatusOK} + next.ServeHTTP(sw, r) + if r.URL.Path == "/healthz" || r.URL.Path == "/readyz" { + return + } + log.Info("request", + "method", r.Method, + "path", r.URL.Path, + "status", sw.status, + "ms", time.Since(start).Milliseconds(), + ) + }) +} + +type statusWriter struct { + http.ResponseWriter + status int +} + +func (w *statusWriter) WriteHeader(code int) { + w.status = code + w.ResponseWriter.WriteHeader(code) +} + +const contentSecurityPolicy = "default-src 'self'; " + + "script-src 'self' https://cdnjs.cloudflare.com; " + + "style-src 'self' 'unsafe-inline' https://cdnjs.cloudflare.com; " + + "img-src 'self' data:; " + + "connect-src 'self'; " + + "base-uri 'self'; " + + "form-action 'self'; " + + "frame-ancestors 'none'" + +func securityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Set("Content-Security-Policy", contentSecurityPolicy) + h.Set("X-Content-Type-Options", "nosniff") + h.Set("X-Frame-Options", "DENY") + h.Set("Referrer-Policy", "no-referrer") + h.Set("Cross-Origin-Opener-Policy", "same-origin") + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/server.go b/internal/server/server.go index fadc774..7e5f33f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,51 +3,15 @@ package server import ( - "context" - "encoding/csv" - "encoding/json" - "errors" - "io" "io/fs" "log/slog" "net/http" - "strconv" - "strings" "time" - - "github.com/descope-sample-apps/pgpeek/internal/db" - "github.com/descope-sample-apps/pgpeek/internal/guard" - "github.com/descope-sample-apps/pgpeek/internal/store" ) -// maxBodyBytes caps request bodies. Queries are SQL text, not data, so 1 MiB is -// generous while preventing a client from forcing unbounded buffering. -const maxBodyBytes = 1 << 20 - -// Querier runs read-only queries, browses the catalog, and reports database -// health. *db.Pool implements it; tests substitute a fake. -type Querier interface { - Query(ctx context.Context, sql string) (*db.Result, error) - Tables(ctx context.Context) ([]db.TableInfo, error) - Columns(ctx context.Context, schema, table string) ([]db.ColumnInfo, error) - ForeignKeys(ctx context.Context, schema, table string) ([]db.ForeignKey, error) - TableRows(ctx context.Context, q db.TableQuery) (*db.Result, error) - RowCap() int - Ping(ctx context.Context) error -} - -// QueryStore persists saved/preset queries. *store.Store implements it. -type QueryStore interface { - List(ctx context.Context) ([]store.SavedQuery, error) - Get(ctx context.Context, id int64) (store.SavedQuery, error) - Create(ctx context.Context, name, desc, sql string) (store.SavedQuery, error) - Update(ctx context.Context, id int64, name, desc, sql string) (store.SavedQuery, error) - Delete(ctx context.Context, id int64) error -} - // Server holds the dependencies for the HTTP handlers. type Server struct { - pool Querier + registry DatabaseRegistry store QueryStore web fs.FS log *slog.Logger @@ -56,7 +20,11 @@ type Server struct { // New constructs a Server. func New(pool Querier, st QueryStore, web fs.FS, log *slog.Logger, queryWait time.Duration) *Server { - return &Server{pool: pool, store: st, web: web, log: log, queryWait: queryWait} + return NewWithRegistry(NewSingleDatabaseRegistry(pool), st, web, log, queryWait) +} + +func NewWithRegistry(registry DatabaseRegistry, st QueryStore, web fs.FS, log *slog.Logger, queryWait time.Duration) *Server { + return &Server{registry: registry, store: st, web: web, log: log, queryWait: queryWait} } // Routes returns the configured handler. @@ -71,6 +39,7 @@ func (s *Server) Routes() http.Handler { mux.HandleFunc("GET /readyz", s.handleReady) // API + mux.HandleFunc("GET /api/databases", s.handleDatabases) mux.HandleFunc("GET /api/meta", s.handleMeta) mux.HandleFunc("POST /api/query", s.handleQuery) mux.HandleFunc("POST /api/export", s.handleExport) @@ -88,419 +57,3 @@ func (s *Server) Routes() http.Handler { return securityHeaders(logging(s.log, mux)) } - -type queryRequest struct { - SQL string `json:"sql"` -} - -func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) { - var req queryRequest - if !decodeBody(w, r, &req) { - return - } - sql := strings.TrimSpace(req.SQL) - if err := guard.Validate(sql); err != nil { - writeError(w, http.StatusBadRequest, err.Error()) - return - } - - ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) - defer cancel() - - res, err := s.pool.Query(ctx, sql) - if err != nil { - writeError(w, http.StatusBadRequest, "query failed: "+err.Error()) - return - } - writeJSON(w, http.StatusOK, res) -} - -func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) { - var req queryRequest - if !decodeBody(w, r, &req) { - return - } - sql := strings.TrimSpace(req.SQL) - if err := guard.Validate(sql); err != nil { - writeError(w, http.StatusBadRequest, err.Error()) - return - } - - ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) - defer cancel() - - res, err := s.pool.Query(ctx, sql) - if err != nil { - writeError(w, http.StatusBadRequest, "query failed: "+err.Error()) - return - } - - w.Header().Set("Content-Type", "text/csv; charset=utf-8") - w.Header().Set("Content-Disposition", `attachment; filename="pgpeek-export.csv"`) - if err := writeCSV(w, res); err != nil { - // Headers/partial body may already be sent; just log. - s.log.Error("csv export", "err", err) - } -} - -// writeCSV streams the result as CSV. encoding/csv accumulates errors stickily, -// so it's sufficient to check Error() once after Flush. -func writeCSV(w io.Writer, res *db.Result) error { - cw := csv.NewWriter(w) - _ = cw.Write(res.Columns) - row := make([]string, len(res.Columns)) - for _, rec := range res.Rows { - for i, cell := range rec { - row[i] = db.CellString(cell) - } - _ = cw.Write(row) - } - cw.Flush() - return cw.Error() -} - -// handleMeta exposes server-side limits the UI needs (notably the row cap, so -// the client can size its page and pagination correctly). -func (s *Server) handleMeta(w http.ResponseWriter, _ *http.Request) { - writeJSON(w, http.StatusOK, map[string]int{"rowCap": s.pool.RowCap()}) -} - -func (s *Server) handleTables(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) - defer cancel() - tables, err := s.pool.Tables(ctx) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to list tables") - s.log.Error("list tables", "err", err) - return - } - if tables == nil { - tables = []db.TableInfo{} - } - writeJSON(w, http.StatusOK, tables) -} - -func (s *Server) handleColumns(w http.ResponseWriter, r *http.Request) { - if rejectRestrictedRelation(w, r.PathValue("table")) { - return - } - ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) - defer cancel() - cols, err := s.pool.Columns(ctx, r.PathValue("schema"), r.PathValue("table")) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to read columns") - s.log.Error("read columns", "err", err) - return - } - if cols == nil { - cols = []db.ColumnInfo{} - } - writeJSON(w, http.StatusOK, cols) -} - -func (s *Server) handleForeignKeys(w http.ResponseWriter, r *http.Request) { - if rejectRestrictedRelation(w, r.PathValue("table")) { - return - } - ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) - defer cancel() - fks, err := s.pool.ForeignKeys(ctx, r.PathValue("schema"), r.PathValue("table")) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to read foreign keys") - s.log.Error("read foreign keys", "err", err) - return - } - if fks == nil { - fks = []db.ForeignKey{} - } - writeJSON(w, http.StatusOK, fks) -} - -func (s *Server) handleTableData(w http.ResponseWriter, r *http.Request) { - if rejectRestrictedRelation(w, r.PathValue("table")) { - return - } - q := db.TableQuery{ - Schema: r.PathValue("schema"), - Table: r.PathValue("table"), - Search: r.URL.Query().Get("search"), - Sort: r.URL.Query().Get("sort"), - Desc: r.URL.Query().Get("dir") == "desc", - Limit: queryInt(r, "limit", 0), // 0 -> pool clamps to row cap - Offset: queryInt(r, "offset", 0), // negative -> pool clamps to 0 - Filters: parseFilters(r.URL.Query()["f"]), - } - - ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) - defer cancel() - res, err := s.pool.TableRows(ctx, q) - if err != nil { - writeError(w, http.StatusBadRequest, "failed to read rows") - s.log.Error("read rows", "err", err) - return - } - - if r.URL.Query().Get("format") == "csv" { - w.Header().Set("Content-Type", "text/csv; charset=utf-8") - w.Header().Set("Content-Disposition", `attachment; filename="`+safeFilename(r.PathValue("table"))+`.csv"`) - if err := writeCSV(w, res); err != nil { - s.log.Error("csv export", "err", err) - } - return - } - writeJSON(w, http.StatusOK, res) -} - -func (s *Server) handleListQueries(w http.ResponseWriter, r *http.Request) { - qs, err := s.store.List(r.Context()) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to list saved queries") - s.log.Error("list saved queries", "err", err) - return - } - if qs == nil { - qs = []store.SavedQuery{} - } - writeJSON(w, http.StatusOK, qs) -} - -type savedQueryRequest struct { - Name string `json:"name"` - Description string `json:"description"` - SQL string `json:"sql"` -} - -func (s *Server) handleCreateQuery(w http.ResponseWriter, r *http.Request) { - req, ok := decodeSavedQuery(w, r) - if !ok { - return - } - q, err := s.store.Create(r.Context(), req.Name, req.Description, req.SQL) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to save query") - s.log.Error("create saved query", "err", err) - return - } - writeJSON(w, http.StatusCreated, q) -} - -func (s *Server) handleUpdateQuery(w http.ResponseWriter, r *http.Request) { - id, ok := pathID(w, r) - if !ok { - return - } - req, ok := decodeSavedQuery(w, r) - if !ok { - return - } - q, err := s.store.Update(r.Context(), id, req.Name, req.Description, req.SQL) - if errors.Is(err, store.ErrNotFound) { - writeError(w, http.StatusNotFound, "saved query not found") - return - } - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to update query") - s.log.Error("update saved query", "err", err) - return - } - writeJSON(w, http.StatusOK, q) -} - -func (s *Server) handleDeleteQuery(w http.ResponseWriter, r *http.Request) { - id, ok := pathID(w, r) - if !ok { - return - } - err := s.store.Delete(r.Context(), id) - if errors.Is(err, store.ErrNotFound) { - writeError(w, http.StatusNotFound, "saved query not found") - return - } - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to delete query") - s.log.Error("delete saved query", "err", err) - return - } - w.WriteHeader(http.StatusNoContent) -} - -func (s *Server) handleReady(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - defer cancel() - if err := s.pool.Ping(ctx); err != nil { - writeError(w, http.StatusServiceUnavailable, "database not ready") - return - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ready")) -} - -// --- helpers --- - -func decodeSavedQuery(w http.ResponseWriter, r *http.Request) (savedQueryRequest, bool) { - var req savedQueryRequest - if !decodeBody(w, r, &req) { - return req, false - } - req.Name = strings.TrimSpace(req.Name) - req.SQL = strings.TrimSpace(req.SQL) - if req.Name == "" || req.SQL == "" { - writeError(w, http.StatusBadRequest, "name and sql are required") - return req, false - } - if err := guard.Validate(req.SQL); err != nil { - writeError(w, http.StatusBadRequest, "saved query must be read-only: "+err.Error()) - return req, false - } - return req, true -} - -// decodeBody reads a size-capped JSON body into v, rejecting unknown fields. It -// writes the error response itself and returns false on failure. -func decodeBody(w http.ResponseWriter, r *http.Request, v any) bool { - r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes) - dec := json.NewDecoder(r.Body) - dec.DisallowUnknownFields() - if err := dec.Decode(v); err != nil { - writeError(w, http.StatusBadRequest, "invalid JSON body: "+err.Error()) - return false - } - return true -} - -// safeFilename keeps only filename-safe characters, so a table name can't break -// out of the quoted Content-Disposition value or inject a different extension. -func safeFilename(name string) string { - out := strings.Map(func(r rune) rune { - switch { - case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', - r == '_', r == '-', r == '.': - return r - default: - return '_' - } - }, name) - if out == "" { - return "table" - } - return out -} - -// parseFilters turns repeated "f" params of the form "column:op[:value]" into -// db.Filter values. Column/op/value are validated downstream by db.TableRows. -func parseFilters(raw []string) []db.Filter { - if len(raw) == 0 { - return nil - } - out := make([]db.Filter, 0, len(raw)) - for _, s := range raw { - parts := strings.SplitN(s, ":", 3) - f := db.Filter{Column: parts[0]} - if len(parts) >= 2 { - f.Op = parts[1] - } - if len(parts) == 3 { - f.Value = parts[2] - } - out = append(out, f) - } - return out -} - -// queryInt parses a query-string integer, falling back to def on absence or -// parse error (the db layer clamps the actual bounds). -func queryInt(r *http.Request, key string, def int) int { - if v := r.URL.Query().Get(key); v != "" { - if n, err := strconv.Atoi(v); err == nil { - return n - } - } - return def -} - -func pathID(w http.ResponseWriter, r *http.Request) (int64, bool) { - id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) - if err != nil { - writeError(w, http.StatusBadRequest, "invalid id") - return 0, false - } - return id, true -} - -func writeJSON(w http.ResponseWriter, status int, v any) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(v) -} - -func writeError(w http.ResponseWriter, status int, msg string) { - writeJSON(w, status, map[string]string{"error": msg}) -} - -func rejectRestrictedRelation(w http.ResponseWriter, table string) bool { - if !guard.IsRestrictedRelation(table) { - return false - } - writeError(w, http.StatusBadRequest, "restricted system catalog") - return true -} - -// logging is a minimal request logger. It never logs request bodies (which -// contain SQL) at info level beyond method/path. -func logging(log *slog.Logger, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - sw := &statusWriter{ResponseWriter: w, status: http.StatusOK} - next.ServeHTTP(sw, r) - if r.URL.Path == "/healthz" || r.URL.Path == "/readyz" { - return // don't spam probe logs - } - log.Info("request", - "method", r.Method, - "path", r.URL.Path, - "status", sw.status, - "ms", time.Since(start).Milliseconds(), - ) - }) -} - -type statusWriter struct { - http.ResponseWriter - status int -} - -func (w *statusWriter) WriteHeader(code int) { - w.status = code - w.ResponseWriter.WriteHeader(code) -} - -// contentSecurityPolicy allows only the app's own assets. CodeMirror 6 is -// vendored and served from /vendor (no third-party CDN). The page script lives -// in /app.js (no 'unsafe-inline' for scripts); styles permit 'unsafe-inline' -// because CodeMirror injects them at runtime. -const contentSecurityPolicy = "default-src 'self'; " + - "script-src 'self'; " + - "style-src 'self' 'unsafe-inline'; " + - "img-src 'self' data:; " + - "connect-src 'self'; " + - "base-uri 'self'; " + - "form-action 'self'; " + - "frame-ancestors 'none'" - -// securityHeaders sets conservative defaults on every response. -func securityHeaders(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h := w.Header() - h.Set("Content-Security-Policy", contentSecurityPolicy) - h.Set("X-Content-Type-Options", "nosniff") - h.Set("X-Frame-Options", "DENY") - h.Set("Referrer-Policy", "no-referrer") - h.Set("Cross-Origin-Opener-Policy", "same-origin") - // Advertise HSTS only on connections that actually reached us over TLS - // (direct TLS or via a TLS-terminating proxy that sets X-Forwarded-Proto). - if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { - h.Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains") - } - next.ServeHTTP(w, r) - }) -} diff --git a/internal/server/server_routes_test.go b/internal/server/server_routes_test.go new file mode 100644 index 0000000..d6e5358 --- /dev/null +++ b/internal/server/server_routes_test.go @@ -0,0 +1,80 @@ +package server + +import ( + "bytes" + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/descope/pgpeek/internal/guard" + "github.com/descope/pgpeek/internal/store" +) + +func TestHealthz(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := mustGet(t, ts, "/healthz") + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("healthz = %d", resp.StatusCode) + } +} + +func TestReadyz(t *testing.T) { + t.Run("healthy", func(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := mustGet(t, ts, "/readyz") + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("readyz = %d, want 200", resp.StatusCode) + } + }) + t.Run("db down", func(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{pingErr: errors.New("down")}) + resp := mustGet(t, ts, "/readyz") + resp.Body.Close() + if resp.StatusCode != http.StatusServiceUnavailable { + t.Errorf("readyz = %d, want 503", resp.StatusCode) + } + }) +} + +func TestSecurityHeaders(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := mustGet(t, ts, "/healthz") + resp.Body.Close() + want := map[string]string{ + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "Referrer-Policy": "no-referrer", + } + for k, v := range want { + if got := resp.Header.Get(k); got != v { + t.Errorf("%s = %q, want %q", k, got, v) + } + } + if csp := resp.Header.Get("Content-Security-Policy"); !strings.Contains(csp, "default-src 'self'") { + t.Errorf("missing CSP: %q", csp) + } +} + +func TestUIServed(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := mustGet(t, ts, "/") + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if !bytes.Contains(body, []byte("pgpeek")) { + t.Errorf("index not served: %q", body) + } +} + +// DefaultPresets must all survive the read-only guard, otherwise the saved-query +// validation would reject them if a user re-saved one. +func TestDefaultPresetsPassGuard(t *testing.T) { + for _, p := range store.DefaultPresets { + if err := guard.Validate(p.SQL); err != nil { + t.Errorf("preset %q fails guard: %v", p.Name, err) + } + } +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go deleted file mode 100644 index 806565a..0000000 --- a/internal/server/server_test.go +++ /dev/null @@ -1,837 +0,0 @@ -package server - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "io" - "log/slog" - "net/http" - "net/http/httptest" - "path/filepath" - "strconv" - "strings" - "testing" - "testing/fstest" - "time" - - "github.com/descope-sample-apps/pgpeek/internal/db" - "github.com/descope-sample-apps/pgpeek/internal/guard" - "github.com/descope-sample-apps/pgpeek/internal/store" -) - -type fakeQuerier struct { - result *db.Result - err error - pingErr error - called bool - tableRowsCalled bool - lastSQL string - tables []db.TableInfo - cols []db.ColumnInfo - fks []db.ForeignKey - catErr error - lastQuery db.TableQuery - lastArgs struct { - schema, table string - limit, offset int - } -} - -func (f *fakeQuerier) Query(_ context.Context, sql string) (*db.Result, error) { - f.called = true - f.lastSQL = sql - return f.result, f.err -} - -func (f *fakeQuerier) Tables(context.Context) ([]db.TableInfo, error) { - return f.tables, f.catErr -} - -func (f *fakeQuerier) Columns(_ context.Context, schema, table string) ([]db.ColumnInfo, error) { - f.lastArgs.schema, f.lastArgs.table = schema, table - return f.cols, f.catErr -} - -func (f *fakeQuerier) ForeignKeys(_ context.Context, _, _ string) ([]db.ForeignKey, error) { - return f.fks, f.catErr -} - -func (f *fakeQuerier) TableRows(_ context.Context, q db.TableQuery) (*db.Result, error) { - f.tableRowsCalled = true - f.lastQuery = q - f.lastArgs.schema, f.lastArgs.table = q.Schema, q.Table - f.lastArgs.limit, f.lastArgs.offset = q.Limit, q.Offset - return f.result, f.err -} - -func (f *fakeQuerier) RowCap() int { return 1000 } - -func (f *fakeQuerier) Ping(_ context.Context) error { return f.pingErr } - -func newTestServer(t *testing.T, q Querier) (*httptest.Server, *store.Store) { - t.Helper() - st, err := store.Open(filepath.Join(t.TempDir(), "t.db")) - if err != nil { - t.Fatalf("store.Open: %v", err) - } - t.Cleanup(func() { _ = st.Close() }) - - web := fstest.MapFS{ - "index.html": &fstest.MapFile{Data: []byte("pgpeek")}, - "app.js": &fstest.MapFile{Data: []byte("// js")}, - } - log := slog.New(slog.NewTextHandler(io.Discard, nil)) - srv := New(q, st, web, log, 5*time.Second) - ts := httptest.NewServer(srv.Routes()) - t.Cleanup(ts.Close) - return ts, st -} - -func post(t *testing.T, ts *httptest.Server, path, body string) *http.Response { - t.Helper() - resp, err := http.Post(ts.URL+path, "application/json", strings.NewReader(body)) - if err != nil { - t.Fatalf("POST %s: %v", path, err) - } - return resp -} - -func decode[T any](t *testing.T, resp *http.Response) T { - t.Helper() - defer resp.Body.Close() - var v T - if err := json.NewDecoder(resp.Body).Decode(&v); err != nil { - t.Fatalf("decode: %v", err) - } - return v -} - -func okResult() *db.Result { - return &db.Result{ - Columns: []string{"n"}, - Rows: [][]any{{int64(1)}}, - RowCount: 1, - ElapsedMS: 1, - } -} - -func TestQuery_OK(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - - resp := post(t, ts, "/api/query", `{"sql":" SELECT 1 "}`) - if resp.StatusCode != http.StatusOK { - t.Fatalf("status = %d", resp.StatusCode) - } - res := decode[db.Result](t, resp) - if res.RowCount != 1 { - t.Errorf("rowCount = %d", res.RowCount) - } - if q.lastSQL != "SELECT 1" { - t.Errorf("SQL not trimmed before exec: %q", q.lastSQL) - } -} - -func TestQuery_GuardRejectsDML(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - - resp := post(t, ts, "/api/query", `{"sql":"DELETE FROM users"}`) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("status = %d, want 400", resp.StatusCode) - } - if q.called { - t.Error("guard should block the query before it reaches the database") - } -} - -func TestQuery_InvalidJSON(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := post(t, ts, "/api/query", `{not json`) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("status = %d, want 400", resp.StatusCode) - } -} - -func TestQuery_UnknownField(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := post(t, ts, "/api/query", `{"sql":"SELECT 1","evil":true}`) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("status = %d, want 400 (DisallowUnknownFields)", resp.StatusCode) - } -} - -func TestQuery_DBError(t *testing.T) { - q := &fakeQuerier{err: errors.New("boom")} - ts, _ := newTestServer(t, q) - resp := post(t, ts, "/api/query", `{"sql":"SELECT 1"}`) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("status = %d, want 400", resp.StatusCode) - } -} - -func TestQuery_BodyTooLarge(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{result: okResult()}) - huge := strings.Repeat("a", (1<<20)+10) - resp := post(t, ts, "/api/query", `{"sql":"`+huge+`"}`) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("status = %d, want 400 for oversized body", resp.StatusCode) - } -} - -func TestExport_CSV(t *testing.T) { - q := &fakeQuerier{result: &db.Result{ - Columns: []string{"name", "n"}, - Rows: [][]any{{"Acme", int64(2)}, {"Globex,Inc", int64(1)}}, - RowCount: 2, - }} - ts, _ := newTestServer(t, q) - - resp := post(t, ts, "/api/export", `{"sql":"SELECT 1"}`) - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Fatalf("status = %d", resp.StatusCode) - } - if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { - t.Errorf("content-type = %q", ct) - } - if cd := resp.Header.Get("Content-Disposition"); !strings.Contains(cd, "pgpeek-export.csv") { - t.Errorf("content-disposition = %q", cd) - } - body, _ := io.ReadAll(resp.Body) - got := string(body) - if !strings.Contains(got, "name,n") || !strings.Contains(got, "Acme,2") { - t.Errorf("csv body = %q", got) - } - // Field with a comma must be quoted by encoding/csv. - if !strings.Contains(got, `"Globex,Inc"`) { - t.Errorf("comma field not quoted: %q", got) - } -} - -func TestExport_GuardRejects(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - resp := post(t, ts, "/api/export", `{"sql":"DROP TABLE x"}`) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("status = %d, want 400", resp.StatusCode) - } -} - -func TestSavedQueries_CRUD(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{result: okResult()}) - - // Create - resp := post(t, ts, "/api/queries", `{"name":"q","description":"d","sql":"SELECT 1"}`) - if resp.StatusCode != http.StatusCreated { - t.Fatalf("create status = %d", resp.StatusCode) - } - created := decode[store.SavedQuery](t, resp) - if created.ID == 0 { - t.Fatal("no id returned") - } - - // List - resp = mustGet(t, ts, "/api/queries") - list := decode[[]store.SavedQuery](t, resp) - if len(list) != 1 { - t.Fatalf("list len = %d", len(list)) - } - - // Update - req, _ := http.NewRequest(http.MethodPut, ts.URL+"/api/queries/"+itoa(created.ID), - strings.NewReader(`{"name":"q2","sql":"SELECT 2"}`)) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != http.StatusOK { - t.Fatalf("update status = %d", resp.StatusCode) - } - resp.Body.Close() - - // Delete - req, _ = http.NewRequest(http.MethodDelete, ts.URL+"/api/queries/"+itoa(created.ID), nil) - resp, err = http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != http.StatusNoContent { - t.Fatalf("delete status = %d", resp.StatusCode) - } - resp.Body.Close() -} - -func TestSavedQueries_Validation(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - - cases := []struct { - name, body string - }{ - {"missing name", `{"sql":"SELECT 1"}`}, - {"missing sql", `{"name":"x"}`}, - {"non-readonly sql", `{"name":"x","sql":"DELETE FROM t"}`}, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - resp := post(t, ts, "/api/queries", c.body) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } - }) - } -} - -func TestSavedQueries_NotFoundAndBadID(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - - req, _ := http.NewRequest(http.MethodDelete, ts.URL+"/api/queries/999", nil) - resp, _ := http.DefaultClient.Do(req) - if resp.StatusCode != http.StatusNotFound { - t.Errorf("delete missing status = %d, want 404", resp.StatusCode) - } - resp.Body.Close() - - req, _ = http.NewRequest(http.MethodPut, ts.URL+"/api/queries/abc", - strings.NewReader(`{"name":"x","sql":"SELECT 1"}`)) - resp, _ = http.DefaultClient.Do(req) - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("bad id status = %d, want 400", resp.StatusCode) - } - resp.Body.Close() -} - -func TestHealthz(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := mustGet(t, ts, "/healthz") - resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Errorf("healthz = %d", resp.StatusCode) - } -} - -func TestReadyz(t *testing.T) { - t.Run("healthy", func(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := mustGet(t, ts, "/readyz") - resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Errorf("readyz = %d, want 200", resp.StatusCode) - } - }) - t.Run("db down", func(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{pingErr: errors.New("down")}) - resp := mustGet(t, ts, "/readyz") - resp.Body.Close() - if resp.StatusCode != http.StatusServiceUnavailable { - t.Errorf("readyz = %d, want 503", resp.StatusCode) - } - }) -} - -func TestSecurityHeaders(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := mustGet(t, ts, "/healthz") - resp.Body.Close() - want := map[string]string{ - "X-Content-Type-Options": "nosniff", - "X-Frame-Options": "DENY", - "Referrer-Policy": "no-referrer", - } - for k, v := range want { - if got := resp.Header.Get(k); got != v { - t.Errorf("%s = %q, want %q", k, got, v) - } - } - if csp := resp.Header.Get("Content-Security-Policy"); !strings.Contains(csp, "default-src 'self'") { - t.Errorf("missing CSP: %q", csp) - } -} - -func TestHSTS(t *testing.T) { - srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) - - // No TLS signal -> no HSTS header. - rec := httptest.NewRecorder() - srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/healthz", nil)) - if h := rec.Header().Get("Strict-Transport-Security"); h != "" { - t.Errorf("HSTS set on plaintext request: %q", h) - } - - // X-Forwarded-Proto: https (TLS-terminating proxy) -> HSTS present. - req := httptest.NewRequest(http.MethodGet, "/healthz", nil) - req.Header.Set("X-Forwarded-Proto", "https") - rec = httptest.NewRecorder() - srv.Routes().ServeHTTP(rec, req) - if h := rec.Header().Get("Strict-Transport-Security"); h != "max-age=63072000; includeSubDomains" { - t.Errorf("HSTS = %q, want max-age=63072000; includeSubDomains", h) - } -} - -func TestUIServed(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := mustGet(t, ts, "/") - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - if !bytes.Contains(body, []byte("pgpeek")) { - t.Errorf("index not served: %q", body) - } -} - -// DefaultPresets must all survive the read-only guard, otherwise the saved-query -// validation would reject them if a user re-saved one. -func TestDefaultPresetsPassGuard(t *testing.T) { - for _, p := range store.DefaultPresets { - if err := guard.Validate(p.SQL); err != nil { - t.Errorf("preset %q fails guard: %v", p.Name, err) - } - } -} - -// --- error-path coverage with a fake store and failing writer ------------ - -type fakeStore struct { - listErr, getErr, createErr, updateErr, deleteErr error -} - -func (f *fakeStore) List(context.Context) ([]store.SavedQuery, error) { - return nil, f.listErr -} -func (f *fakeStore) Get(context.Context, int64) (store.SavedQuery, error) { - return store.SavedQuery{}, f.getErr -} -func (f *fakeStore) Create(context.Context, string, string, string) (store.SavedQuery, error) { - return store.SavedQuery{}, f.createErr -} -func (f *fakeStore) Update(context.Context, int64, string, string, string) (store.SavedQuery, error) { - return store.SavedQuery{}, f.updateErr -} -func (f *fakeStore) Delete(context.Context, int64) error { return f.deleteErr } - -func serverWithStore(t *testing.T, q Querier, st QueryStore) *Server { - t.Helper() - web := fstest.MapFS{"index.html": &fstest.MapFile{Data: []byte("x")}} - log := slog.New(slog.NewTextHandler(io.Discard, nil)) - return New(q, st, web, log, time.Second) -} - -func TestStoreErrorPaths(t *testing.T) { - boom := errors.New("db down") - srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{ - listErr: boom, createErr: boom, updateErr: boom, deleteErr: boom, - }) - mux := srv.Routes() - - do := func(method, path, body string) int { - var rdr io.Reader - if body != "" { - rdr = strings.NewReader(body) - } - req := httptest.NewRequest(method, path, rdr) - rec := httptest.NewRecorder() - mux.ServeHTTP(rec, req) - return rec.Code - } - - if code := do(http.MethodGet, "/api/queries", ""); code != http.StatusInternalServerError { - t.Errorf("list error code = %d, want 500", code) - } - if code := do(http.MethodPost, "/api/queries", `{"name":"x","sql":"SELECT 1"}`); code != http.StatusInternalServerError { - t.Errorf("create error code = %d, want 500", code) - } - if code := do(http.MethodPut, "/api/queries/1", `{"name":"x","sql":"SELECT 1"}`); code != http.StatusInternalServerError { - t.Errorf("update error code = %d, want 500", code) - } - if code := do(http.MethodDelete, "/api/queries/1", ""); code != http.StatusInternalServerError { - t.Errorf("delete error code = %d, want 500", code) - } -} - -// --- catalog / browse endpoints ----------------------------------------- - -func TestMeta(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := mustGet(t, ts, "/api/meta") - got := decode[map[string]int](t, resp) - if got["rowCap"] != 1000 { - t.Errorf("rowCap = %d, want 1000", got["rowCap"]) - } -} - -func TestSafeFilename(t *testing.T) { - cases := map[string]string{ - "users": "users", - `a"; x`: "a___x", - "sch.tbl": "sch.tbl", - "évil/../x": "_vil_.._x", - "": "table", - "weird name!@#$": "weird_name____", - } - for in, want := range cases { - if got := safeFilename(in); got != want { - t.Errorf("safeFilename(%q) = %q, want %q", in, got, want) - } - } -} - -func TestTables_OK(t *testing.T) { - q := &fakeQuerier{tables: []db.TableInfo{{Schema: "public", Name: "users", Type: "table", EstRows: 5}}} - ts, _ := newTestServer(t, q) - resp := mustGet(t, ts, "/api/tables") - if resp.StatusCode != http.StatusOK { - t.Fatalf("status = %d", resp.StatusCode) - } - got := decode[[]db.TableInfo](t, resp) - if len(got) != 1 || got[0].Name != "users" { - t.Errorf("tables = %+v", got) - } -} - -func TestTables_EmptyReturnsArray(t *testing.T) { - srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) - rec := httptest.NewRecorder() - srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/tables", nil)) - if strings.TrimSpace(rec.Body.String()) != "[]" { - t.Errorf("body = %q, want []", rec.Body.String()) - } -} - -func TestTables_Error(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{catErr: errors.New("boom")}) - resp := mustGet(t, ts, "/api/tables") - resp.Body.Close() - if resp.StatusCode != http.StatusInternalServerError { - t.Errorf("status = %d, want 500", resp.StatusCode) - } -} - -func TestColumns_OK(t *testing.T) { - q := &fakeQuerier{cols: []db.ColumnInfo{{Name: "id", Type: "integer"}}} - ts, _ := newTestServer(t, q) - resp := mustGet(t, ts, "/api/tables/public/users/columns") - got := decode[[]db.ColumnInfo](t, resp) - if len(got) != 1 || got[0].Name != "id" { - t.Errorf("columns = %+v", got) - } - if q.lastArgs.schema != "public" || q.lastArgs.table != "users" { - t.Errorf("path values not passed: %+v", q.lastArgs) - } -} - -func TestColumns_EmptyReturnsArray(t *testing.T) { - srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) - rec := httptest.NewRecorder() - srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/tables/public/users/columns", nil)) - if strings.TrimSpace(rec.Body.String()) != "[]" { - t.Errorf("body = %q, want []", rec.Body.String()) - } -} - -func TestColumns_Error(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{catErr: errors.New("boom")}) - resp := mustGet(t, ts, "/api/tables/public/users/columns") - resp.Body.Close() - if resp.StatusCode != http.StatusInternalServerError { - t.Errorf("status = %d, want 500", resp.StatusCode) - } -} - -func TestForeignKeys_OK(t *testing.T) { - q := &fakeQuerier{fks: []db.ForeignKey{{Column: "company_id", RefSchema: "public", RefTable: "companies", RefColumn: "id"}}} - ts, _ := newTestServer(t, q) - resp := mustGet(t, ts, "/api/tables/public/users/fks") - got := decode[[]db.ForeignKey](t, resp) - if len(got) != 1 || got[0].RefTable != "companies" { - t.Errorf("fks = %+v", got) - } -} - -func TestForeignKeys_EmptyReturnsArray(t *testing.T) { - srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) - rec := httptest.NewRecorder() - srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/tables/public/users/fks", nil)) - if strings.TrimSpace(rec.Body.String()) != "[]" { - t.Errorf("body = %q, want []", rec.Body.String()) - } -} - -func TestForeignKeys_Error(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{catErr: errors.New("boom")}) - resp := mustGet(t, ts, "/api/tables/public/users/fks") - resp.Body.Close() - if resp.StatusCode != http.StatusInternalServerError { - t.Errorf("status = %d, want 500", resp.StatusCode) - } -} - -func TestCatalogEndpoints_RejectRestrictedCatalogs(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - - paths := []string{ - "/api/tables/pg_catalog/pg_authid/columns", - "/api/tables/pg_catalog/pg_shadow/fks", - "/api/tables/pg_catalog/pg_hba_file_rules/data", - } - for _, path := range paths { - resp := mustGet(t, ts, path) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("%s status = %d, want 400", path, resp.StatusCode) - } - } - if q.tableRowsCalled { - t.Error("restricted catalog data request reached TableRows") - } -} - -func TestTableData_OK(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - resp := mustGet(t, ts, "/api/tables/public/users/data?limit=25&offset=50") - if resp.StatusCode != http.StatusOK { - t.Fatalf("status = %d", resp.StatusCode) - } - decode[db.Result](t, resp) - if q.lastArgs.limit != 25 || q.lastArgs.offset != 50 { - t.Errorf("limit/offset = %d/%d", q.lastArgs.limit, q.lastArgs.offset) - } -} - -func TestTableData_ParsesSearchSortFilters(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - resp := mustGet(t, ts, "/api/tables/public/users/data?search=acme&sort=id&dir=desc&f=id:gt:100&f=name:ilike:%25a%25&f=deleted_at:is_null") - resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Fatalf("status = %d", resp.StatusCode) - } - lq := q.lastQuery - if lq.Search != "acme" || lq.Sort != "id" || !lq.Desc { - t.Errorf("search/sort/desc = %+v", lq) - } - if len(lq.Filters) != 3 { - t.Fatalf("filters = %+v", lq.Filters) - } - if lq.Filters[0] != (db.Filter{Column: "id", Op: "gt", Value: "100"}) { - t.Errorf("filter0 = %+v", lq.Filters[0]) - } - if lq.Filters[1] != (db.Filter{Column: "name", Op: "ilike", Value: "%a%"}) { - t.Errorf("filter1 = %+v", lq.Filters[1]) - } - if lq.Filters[2] != (db.Filter{Column: "deleted_at", Op: "is_null"}) { - t.Errorf("filter2 (no value) = %+v", lq.Filters[2]) - } -} - -func TestParseFilters(t *testing.T) { - if parseFilters(nil) != nil { - t.Error("nil input should yield nil") - } - got := parseFilters([]string{"id:gt:100", "name:is_null", "bare"}) - want := []db.Filter{ - {Column: "id", Op: "gt", Value: "100"}, - {Column: "name", Op: "is_null"}, - {Column: "bare"}, - } - if len(got) != len(want) { - t.Fatalf("len = %d", len(got)) - } - for i := range want { - if got[i] != want[i] { - t.Errorf("filter %d = %+v, want %+v", i, got[i], want[i]) - } - } -} - -func TestTableData_DefaultsAndBadParams(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - // Non-numeric params fall back to defaults (0 -> pool clamps). - resp := mustGet(t, ts, "/api/tables/public/users/data?limit=abc") - resp.Body.Close() - if q.lastArgs.limit != 0 || q.lastArgs.offset != 0 { - t.Errorf("expected defaults, got %d/%d", q.lastArgs.limit, q.lastArgs.offset) - } -} - -func TestTableData_Error(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{err: errors.New("no such table")}) - resp := mustGet(t, ts, "/api/tables/public/nope/data") - body, _ := io.ReadAll(resp.Body) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } - if strings.Contains(string(body), "no such table") { - t.Errorf("table data error leaked database detail: %s", body) - } -} - -func TestTableData_CSV(t *testing.T) { - q := &fakeQuerier{result: &db.Result{Columns: []string{"a"}, Rows: [][]any{{"1"}}, RowCount: 1}} - ts, _ := newTestServer(t, q) - resp := mustGet(t, ts, "/api/tables/public/users/data?format=csv") - defer resp.Body.Close() - if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { - t.Errorf("content-type = %q", ct) - } - if cd := resp.Header.Get("Content-Disposition"); !strings.Contains(cd, "users.csv") { - t.Errorf("disposition = %q", cd) - } - body, _ := io.ReadAll(resp.Body) - if !strings.Contains(string(body), "a\n1\n") { - t.Errorf("csv body = %q", body) - } -} - -func TestTableData_CSVWriteFailure(t *testing.T) { - srv := serverWithStore(t, &fakeQuerier{result: okResult()}, &fakeStore{}) - req := httptest.NewRequest(http.MethodGet, "/api/tables/public/users/data?format=csv", nil) - fw := &failingWriter{} - srv.handleTableData(fw, req) - if ct := fw.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { - t.Errorf("content-type = %q", ct) - } -} - -func TestListEmptyReturnsArray(t *testing.T) { - // A successful List of an empty store returns nil; the handler must emit []. - srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) - req := httptest.NewRequest(http.MethodGet, "/api/queries", nil) - rec := httptest.NewRecorder() - srv.Routes().ServeHTTP(rec, req) - if rec.Code != http.StatusOK { - t.Fatalf("code = %d", rec.Code) - } - if body := strings.TrimSpace(rec.Body.String()); body != "[]" { - t.Errorf("body = %q, want []", body) - } -} - -func TestUpdate_InvalidJSON(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - req, _ := http.NewRequest(http.MethodPut, ts.URL+"/api/queries/1", strings.NewReader(`{bad`)) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } -} - -func TestUpdate_NotFound(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - req, _ := http.NewRequest(http.MethodPut, ts.URL+"/api/queries/999", strings.NewReader(`{"name":"x","sql":"SELECT 1"}`)) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - if resp.StatusCode != http.StatusNotFound { - t.Errorf("status = %d, want 404", resp.StatusCode) - } -} - -func TestDelete_BadID(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - req, _ := http.NewRequest(http.MethodDelete, ts.URL+"/api/queries/abc", nil) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } -} - -func TestExport_QueryError(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{err: errors.New("boom")}) - resp := post(t, ts, "/api/export", `{"sql":"SELECT 1"}`) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } -} - -func TestExport_InvalidJSON(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := post(t, ts, "/api/export", `{bad`) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } -} - -func TestSavedQuery_InvalidJSON(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := post(t, ts, "/api/queries", `{bad`) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } -} - -// failingWriter is an http.ResponseWriter whose body writes always fail. -type failingWriter struct { - header http.Header - code int -} - -func (f *failingWriter) Header() http.Header { - if f.header == nil { - f.header = http.Header{} - } - return f.header -} -func (f *failingWriter) Write([]byte) (int, error) { return 0, errors.New("connection reset") } -func (f *failingWriter) WriteHeader(code int) { f.code = code } - -func TestExport_WriteFailureLogged(t *testing.T) { - // Exercise the handler's csv-error branch directly with a writer that fails. - srv := serverWithStore(t, &fakeQuerier{result: okResult()}, &fakeStore{}) - req := httptest.NewRequest(http.MethodPost, "/api/export", strings.NewReader(`{"sql":"SELECT 1"}`)) - fw := &failingWriter{} - srv.handleExport(fw, req) - // Header is set before the failing body write. - if ct := fw.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { - t.Errorf("content-type = %q", ct) - } -} - -func TestWriteCSV(t *testing.T) { - res := &db.Result{Columns: []string{"a"}, Rows: [][]any{{"1"}}} - var buf strings.Builder - if err := writeCSV(&buf, res); err != nil { - t.Fatalf("writeCSV: %v", err) - } - if !strings.Contains(buf.String(), "a\n1\n") { - t.Errorf("csv = %q", buf.String()) - } - // Failing writer -> error surfaces via cw.Error() after Flush. - if err := writeCSV(failWriter{}, res); err == nil { - t.Error("expected error from failing writer") - } -} - -type failWriter struct{} - -func (failWriter) Write([]byte) (int, error) { return 0, errors.New("boom") } - -func mustGet(t *testing.T, ts *httptest.Server, path string) *http.Response { - t.Helper() - resp, err := http.Get(ts.URL + path) - if err != nil { - t.Fatalf("GET %s: %v", path, err) - } - return resp -} - -func itoa(i int64) string { return strconv.FormatInt(i, 10) } From d0f453746cf7ae3ec4b12c7ce104f64f35ffa4cd Mon Sep 17 00:00:00 2001 From: Omer <639682+omercnet@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:40:46 +0000 Subject: [PATCH 04/31] feat(server): route catalog by database --- internal/server/catalog_handlers.go | 119 ++++++++++++ internal/server/catalog_handlers_test.go | 235 +++++++++++++++++++++++ 2 files changed, 354 insertions(+) create mode 100644 internal/server/catalog_handlers.go create mode 100644 internal/server/catalog_handlers_test.go diff --git a/internal/server/catalog_handlers.go b/internal/server/catalog_handlers.go new file mode 100644 index 0000000..77c5bf8 --- /dev/null +++ b/internal/server/catalog_handlers.go @@ -0,0 +1,119 @@ +package server + +import ( + "context" + "net/http" + "time" + + "github.com/descope/pgpeek/internal/db" +) + +func (s *Server) handleMeta(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + writeJSON(w, http.StatusOK, map[string]int{"rowCap": pool.RowCap()}) +} + +func (s *Server) handleTables(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) + defer cancel() + tables, err := pool.Tables(ctx) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to list tables: "+err.Error()) + return + } + if tables == nil { + tables = []db.TableInfo{} + } + writeJSON(w, http.StatusOK, tables) +} + +func (s *Server) handleColumns(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) + defer cancel() + cols, err := pool.Columns(ctx, r.PathValue("schema"), r.PathValue("table")) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to read columns: "+err.Error()) + return + } + if cols == nil { + cols = []db.ColumnInfo{} + } + writeJSON(w, http.StatusOK, cols) +} + +func (s *Server) handleForeignKeys(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) + defer cancel() + fks, err := pool.ForeignKeys(ctx, r.PathValue("schema"), r.PathValue("table")) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to read foreign keys: "+err.Error()) + return + } + if fks == nil { + fks = []db.ForeignKey{} + } + writeJSON(w, http.StatusOK, fks) +} + +func (s *Server) handleTableData(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + q := db.TableQuery{ + Schema: r.PathValue("schema"), + Table: r.PathValue("table"), + Search: r.URL.Query().Get("search"), + Sort: r.URL.Query().Get("sort"), + Desc: r.URL.Query().Get("dir") == "desc", + Limit: queryInt(r, "limit", 0), + Offset: queryInt(r, "offset", 0), + Filters: parseFilters(r.URL.Query()["f"]), + } + ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) + defer cancel() + res, err := pool.TableRows(ctx, q) + if err != nil { + writeError(w, http.StatusBadRequest, "failed to read rows: "+err.Error()) + return + } + if r.URL.Query().Get("format") == "csv" { + w.Header().Set("Content-Type", "text/csv; charset=utf-8") + w.Header().Set("Content-Disposition", `attachment; filename="`+safeFilename(r.PathValue("table"))+`.csv"`) + if err := writeCSV(w, res); err != nil { + s.log.Error("csv export", "err", err) + } + return + } + writeJSON(w, http.StatusOK, res) +} + +func (s *Server) handleReady(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + ctx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + defer cancel() + if err := pool.Ping(ctx); err != nil { + writeError(w, http.StatusServiceUnavailable, "database not ready") + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ready")) +} diff --git a/internal/server/catalog_handlers_test.go b/internal/server/catalog_handlers_test.go new file mode 100644 index 0000000..d11acd6 --- /dev/null +++ b/internal/server/catalog_handlers_test.go @@ -0,0 +1,235 @@ +package server + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/descope/pgpeek/internal/db" +) + +// --- catalog / browse endpoints ----------------------------------------- + +func TestMeta(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := mustGet(t, ts, "/api/meta") + got := decode[map[string]int](t, resp) + if got["rowCap"] != 1000 { + t.Errorf("rowCap = %d, want 1000", got["rowCap"]) + } +} + +func TestSafeFilename(t *testing.T) { + cases := map[string]string{ + "users": "users", + `a"; x`: "a___x", + "sch.tbl": "sch.tbl", + "évil/../x": "_vil_.._x", + "": "table", + "weird name!@#$": "weird_name____", + } + for in, want := range cases { + if got := safeFilename(in); got != want { + t.Errorf("safeFilename(%q) = %q, want %q", in, got, want) + } + } +} + +func TestTables_OK(t *testing.T) { + q := &fakeQuerier{tables: []db.TableInfo{{Schema: "public", Name: "users", Type: "table", EstRows: 5}}} + ts, _ := newTestServer(t, q) + resp := mustGet(t, ts, "/api/tables") + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d", resp.StatusCode) + } + got := decode[[]db.TableInfo](t, resp) + if len(got) != 1 || got[0].Name != "users" { + t.Errorf("tables = %+v", got) + } +} + +func TestTables_EmptyReturnsArray(t *testing.T) { + srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) + rec := httptest.NewRecorder() + srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/tables", nil)) + if strings.TrimSpace(rec.Body.String()) != "[]" { + t.Errorf("body = %q, want []", rec.Body.String()) + } +} + +func TestTables_Error(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{catErr: errors.New("boom")}) + resp := mustGet(t, ts, "/api/tables") + resp.Body.Close() + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("status = %d, want 500", resp.StatusCode) + } +} + +func TestColumns_OK(t *testing.T) { + q := &fakeQuerier{cols: []db.ColumnInfo{{Name: "id", Type: "integer"}}} + ts, _ := newTestServer(t, q) + resp := mustGet(t, ts, "/api/tables/public/users/columns") + got := decode[[]db.ColumnInfo](t, resp) + if len(got) != 1 || got[0].Name != "id" { + t.Errorf("columns = %+v", got) + } + if q.lastArgs.schema != "public" || q.lastArgs.table != "users" { + t.Errorf("path values not passed: %+v", q.lastArgs) + } +} + +func TestColumns_EmptyReturnsArray(t *testing.T) { + srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) + rec := httptest.NewRecorder() + srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/tables/public/users/columns", nil)) + if strings.TrimSpace(rec.Body.String()) != "[]" { + t.Errorf("body = %q, want []", rec.Body.String()) + } +} + +func TestColumns_Error(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{catErr: errors.New("boom")}) + resp := mustGet(t, ts, "/api/tables/public/users/columns") + resp.Body.Close() + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("status = %d, want 500", resp.StatusCode) + } +} + +func TestForeignKeys_OK(t *testing.T) { + q := &fakeQuerier{fks: []db.ForeignKey{{Column: "company_id", RefSchema: "public", RefTable: "companies", RefColumn: "id"}}} + ts, _ := newTestServer(t, q) + resp := mustGet(t, ts, "/api/tables/public/users/fks") + got := decode[[]db.ForeignKey](t, resp) + if len(got) != 1 || got[0].RefTable != "companies" { + t.Errorf("fks = %+v", got) + } +} + +func TestForeignKeys_EmptyReturnsArray(t *testing.T) { + srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) + rec := httptest.NewRecorder() + srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/tables/public/users/fks", nil)) + if strings.TrimSpace(rec.Body.String()) != "[]" { + t.Errorf("body = %q, want []", rec.Body.String()) + } +} + +func TestForeignKeys_Error(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{catErr: errors.New("boom")}) + resp := mustGet(t, ts, "/api/tables/public/users/fks") + resp.Body.Close() + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("status = %d, want 500", resp.StatusCode) + } +} + +func TestTableData_OK(t *testing.T) { + q := &fakeQuerier{result: okResult()} + ts, _ := newTestServer(t, q) + resp := mustGet(t, ts, "/api/tables/public/users/data?limit=25&offset=50") + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d", resp.StatusCode) + } + decode[db.Result](t, resp) + if q.lastArgs.limit != 25 || q.lastArgs.offset != 50 { + t.Errorf("limit/offset = %d/%d", q.lastArgs.limit, q.lastArgs.offset) + } +} + +func TestTableData_ParsesSearchSortFilters(t *testing.T) { + q := &fakeQuerier{result: okResult()} + ts, _ := newTestServer(t, q) + resp := mustGet(t, ts, "/api/tables/public/users/data?search=acme&sort=id&dir=desc&f=id:gt:100&f=name:ilike:%25a%25&f=deleted_at:is_null") + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d", resp.StatusCode) + } + lq := q.lastQuery + if lq.Search != "acme" || lq.Sort != "id" || !lq.Desc { + t.Errorf("search/sort/desc = %+v", lq) + } + if len(lq.Filters) != 3 { + t.Fatalf("filters = %+v", lq.Filters) + } + if lq.Filters[0] != (db.Filter{Column: "id", Op: "gt", Value: "100"}) { + t.Errorf("filter0 = %+v", lq.Filters[0]) + } + if lq.Filters[1] != (db.Filter{Column: "name", Op: "ilike", Value: "%a%"}) { + t.Errorf("filter1 = %+v", lq.Filters[1]) + } + if lq.Filters[2] != (db.Filter{Column: "deleted_at", Op: "is_null"}) { + t.Errorf("filter2 (no value) = %+v", lq.Filters[2]) + } +} + +func TestParseFilters(t *testing.T) { + if parseFilters(nil) != nil { + t.Error("nil input should yield nil") + } + got := parseFilters([]string{"id:gt:100", "name:is_null", "bare"}) + want := []db.Filter{ + {Column: "id", Op: "gt", Value: "100"}, + {Column: "name", Op: "is_null"}, + {Column: "bare"}, + } + if len(got) != len(want) { + t.Fatalf("len = %d", len(got)) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("filter %d = %+v, want %+v", i, got[i], want[i]) + } + } +} + +func TestTableData_DefaultsAndBadParams(t *testing.T) { + q := &fakeQuerier{result: okResult()} + ts, _ := newTestServer(t, q) + // Non-numeric params fall back to defaults (0 -> pool clamps). + resp := mustGet(t, ts, "/api/tables/public/users/data?limit=abc") + resp.Body.Close() + if q.lastArgs.limit != 0 || q.lastArgs.offset != 0 { + t.Errorf("expected defaults, got %d/%d", q.lastArgs.limit, q.lastArgs.offset) + } +} + +func TestTableData_Error(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{err: errors.New("no such table")}) + resp := mustGet(t, ts, "/api/tables/public/nope/data") + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestTableData_CSV(t *testing.T) { + q := &fakeQuerier{result: &db.Result{Columns: []string{"a"}, Rows: [][]any{{"1"}}, RowCount: 1}} + ts, _ := newTestServer(t, q) + resp := mustGet(t, ts, "/api/tables/public/users/data?format=csv") + defer resp.Body.Close() + if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { + t.Errorf("content-type = %q", ct) + } + if cd := resp.Header.Get("Content-Disposition"); !strings.Contains(cd, "users.csv") { + t.Errorf("disposition = %q", cd) + } + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "a\n1\n") { + t.Errorf("csv body = %q", body) + } +} + +func TestTableData_CSVWriteFailure(t *testing.T) { + srv := serverWithStore(t, &fakeQuerier{result: okResult()}, &fakeStore{}) + req := httptest.NewRequest(http.MethodGet, "/api/tables/public/users/data?format=csv", nil) + fw := &failingWriter{} + srv.handleTableData(fw, req) + if ct := fw.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { + t.Errorf("content-type = %q", ct) + } +} From 4e3490d21a32838d78ab8b601704e4cb01283036 Mon Sep 17 00:00:00 2001 From: Omer <639682+omercnet@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:40:54 +0000 Subject: [PATCH 05/31] feat(server): route queries by database --- internal/server/query_handlers.go | 85 ++++++++++++++++ internal/server/query_handlers_test.go | 130 +++++++++++++++++++++++++ 2 files changed, 215 insertions(+) create mode 100644 internal/server/query_handlers.go create mode 100644 internal/server/query_handlers_test.go diff --git a/internal/server/query_handlers.go b/internal/server/query_handlers.go new file mode 100644 index 0000000..0c94813 --- /dev/null +++ b/internal/server/query_handlers.go @@ -0,0 +1,85 @@ +package server + +import ( + "context" + "encoding/csv" + "io" + "net/http" + "strings" + + "github.com/descope/pgpeek/internal/db" + "github.com/descope/pgpeek/internal/guard" +) + +type queryRequest struct { + SQL string `json:"sql"` +} + +func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + sql, ok := decodeReadOnlyQuery(w, r) + if !ok { + return + } + ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) + defer cancel() + res, err := pool.Query(ctx, sql) + if err != nil { + writeError(w, http.StatusBadRequest, "query failed: "+err.Error()) + return + } + writeJSON(w, http.StatusOK, res) +} + +func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + sql, ok := decodeReadOnlyQuery(w, r) + if !ok { + return + } + ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) + defer cancel() + res, err := pool.Query(ctx, sql) + if err != nil { + writeError(w, http.StatusBadRequest, "query failed: "+err.Error()) + return + } + w.Header().Set("Content-Type", "text/csv; charset=utf-8") + w.Header().Set("Content-Disposition", `attachment; filename="pgpeek-export.csv"`) + if err := writeCSV(w, res); err != nil { + s.log.Error("csv export", "err", err) + } +} + +func decodeReadOnlyQuery(w http.ResponseWriter, r *http.Request) (string, bool) { + var req queryRequest + if !decodeBody(w, r, &req) { + return "", false + } + sql := strings.TrimSpace(req.SQL) + if err := guard.Validate(sql); err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return "", false + } + return sql, true +} + +func writeCSV(w io.Writer, res *db.Result) error { + cw := csv.NewWriter(w) + _ = cw.Write(res.Columns) + row := make([]string, len(res.Columns)) + for _, rec := range res.Rows { + for i, cell := range rec { + row[i] = db.CellString(cell) + } + _ = cw.Write(row) + } + cw.Flush() + return cw.Error() +} diff --git a/internal/server/query_handlers_test.go b/internal/server/query_handlers_test.go new file mode 100644 index 0000000..02df0d5 --- /dev/null +++ b/internal/server/query_handlers_test.go @@ -0,0 +1,130 @@ +package server + +import ( + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/descope/pgpeek/internal/db" +) + +func TestQuery_OK(t *testing.T) { + q := &fakeQuerier{result: okResult()} + ts, _ := newTestServer(t, q) + + resp := post(t, ts, "/api/query", `{"sql":" SELECT 1 "}`) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d", resp.StatusCode) + } + res := decode[db.Result](t, resp) + if res.RowCount != 1 { + t.Errorf("rowCount = %d", res.RowCount) + } + if q.lastSQL != "SELECT 1" { + t.Errorf("SQL not trimmed before exec: %q", q.lastSQL) + } +} + +func TestQuery_GuardRejectsDML(t *testing.T) { + q := &fakeQuerier{result: okResult()} + ts, _ := newTestServer(t, q) + + resp := post(t, ts, "/api/query", `{"sql":"DELETE FROM users"}`) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + if q.called { + t.Error("guard should block the query before it reaches the database") + } +} + +func TestQuery_InvalidJSON(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := post(t, ts, "/api/query", `{not json`) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } +} + +func TestQuery_UnknownField(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := post(t, ts, "/api/query", `{"sql":"SELECT 1","evil":true}`) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 (DisallowUnknownFields)", resp.StatusCode) + } +} + +func TestQuery_DBError(t *testing.T) { + q := &fakeQuerier{err: errors.New("boom")} + ts, _ := newTestServer(t, q) + resp := post(t, ts, "/api/query", `{"sql":"SELECT 1"}`) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } +} + +func TestQuery_BodyTooLarge(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{result: okResult()}) + huge := strings.Repeat("a", (1<<20)+10) + resp := post(t, ts, "/api/query", `{"sql":"`+huge+`"}`) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 for oversized body", resp.StatusCode) + } +} + +func TestExport_CSV(t *testing.T) { + q := &fakeQuerier{result: &db.Result{ + Columns: []string{"name", "n"}, + Rows: [][]any{{"Acme", int64(2)}, {"Globex,Inc", int64(1)}}, + RowCount: 2, + }} + ts, _ := newTestServer(t, q) + + resp := post(t, ts, "/api/export", `{"sql":"SELECT 1"}`) + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d", resp.StatusCode) + } + if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { + t.Errorf("content-type = %q", ct) + } + if cd := resp.Header.Get("Content-Disposition"); !strings.Contains(cd, "pgpeek-export.csv") { + t.Errorf("content-disposition = %q", cd) + } + body, _ := io.ReadAll(resp.Body) + got := string(body) + if !strings.Contains(got, "name,n") || !strings.Contains(got, "Acme,2") { + t.Errorf("csv body = %q", got) + } + // Field with a comma must be quoted by encoding/csv. + if !strings.Contains(got, `"Globex,Inc"`) { + t.Errorf("comma field not quoted: %q", got) + } +} + +func TestExport_GuardRejects(t *testing.T) { + q := &fakeQuerier{result: okResult()} + ts, _ := newTestServer(t, q) + resp := post(t, ts, "/api/export", `{"sql":"DROP TABLE x"}`) + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } +} + +func TestExport_DBError(t *testing.T) { + // Given: CSV export receives a read-only query but database execution fails. + q := &fakeQuerier{err: errors.New("boom")} + ts, _ := newTestServer(t, q) + + // When: export is requested. + resp := post(t, ts, "/api/export", `{"sql":"SELECT 1"}`) + defer resp.Body.Close() + + // Then: handler returns the same bad-request contract as query execution. + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } +} From 7162d69e54be5c921a50bc49a79e6ec4367e20da Mon Sep 17 00:00:00 2001 From: Omer <639682+omercnet@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:41:00 +0000 Subject: [PATCH 06/31] refactor(server): split saved query handlers --- internal/server/saved_query_errors_test.go | 149 +++++++++++++++++++ internal/server/saved_query_handlers.go | 101 +++++++++++++ internal/server/saved_query_handlers_test.go | 93 ++++++++++++ internal/server/store_error_test.go | 75 ++++++++++ 4 files changed, 418 insertions(+) create mode 100644 internal/server/saved_query_errors_test.go create mode 100644 internal/server/saved_query_handlers.go create mode 100644 internal/server/saved_query_handlers_test.go create mode 100644 internal/server/store_error_test.go diff --git a/internal/server/saved_query_errors_test.go b/internal/server/saved_query_errors_test.go new file mode 100644 index 0000000..3d4ea86 --- /dev/null +++ b/internal/server/saved_query_errors_test.go @@ -0,0 +1,149 @@ +package server + +import ( + "errors" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/descope/pgpeek/internal/db" +) + +func TestListEmptyReturnsArray(t *testing.T) { + // A successful List of an empty store returns nil; the handler must emit []. + srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) + req := httptest.NewRequest(http.MethodGet, "/api/queries", nil) + rec := httptest.NewRecorder() + srv.Routes().ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("code = %d", rec.Code) + } + if body := strings.TrimSpace(rec.Body.String()); body != "[]" { + t.Errorf("body = %q, want []", body) + } +} + +func TestUpdate_InvalidJSON(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + req, _ := http.NewRequest(http.MethodPut, ts.URL+"/api/queries/1", strings.NewReader(`{bad`)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestUpdate_NotFound(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + req, _ := http.NewRequest(http.MethodPut, ts.URL+"/api/queries/999", strings.NewReader(`{"name":"x","sql":"SELECT 1"}`)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Errorf("status = %d, want 404", resp.StatusCode) + } +} + +func TestDelete_BadID(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + req, _ := http.NewRequest(http.MethodDelete, ts.URL+"/api/queries/abc", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestExport_QueryError(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{err: errors.New("boom")}) + resp := post(t, ts, "/api/export", `{"sql":"SELECT 1"}`) + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestExport_InvalidJSON(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := post(t, ts, "/api/export", `{bad`) + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestSavedQuery_InvalidJSON(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := post(t, ts, "/api/queries", `{bad`) + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +// failingWriter is an http.ResponseWriter whose body writes always fail. +type failingWriter struct { + header http.Header + code int +} + +func (f *failingWriter) Header() http.Header { + if f.header == nil { + f.header = http.Header{} + } + return f.header +} +func (f *failingWriter) Write([]byte) (int, error) { return 0, errors.New("connection reset") } +func (f *failingWriter) WriteHeader(code int) { f.code = code } + +func TestExport_WriteFailureLogged(t *testing.T) { + // Exercise the handler's csv-error branch directly with a writer that fails. + srv := serverWithStore(t, &fakeQuerier{result: okResult()}, &fakeStore{}) + req := httptest.NewRequest(http.MethodPost, "/api/export", strings.NewReader(`{"sql":"SELECT 1"}`)) + fw := &failingWriter{} + srv.handleExport(fw, req) + // Header is set before the failing body write. + if ct := fw.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { + t.Errorf("content-type = %q", ct) + } +} + +func TestWriteCSV(t *testing.T) { + res := &db.Result{Columns: []string{"a"}, Rows: [][]any{{"1"}}} + var buf strings.Builder + if err := writeCSV(&buf, res); err != nil { + t.Fatalf("writeCSV: %v", err) + } + if !strings.Contains(buf.String(), "a\n1\n") { + t.Errorf("csv = %q", buf.String()) + } + // Failing writer -> error surfaces via cw.Error() after Flush. + if err := writeCSV(failWriter{}, res); err == nil { + t.Error("expected error from failing writer") + } +} + +type failWriter struct{} + +func (failWriter) Write([]byte) (int, error) { return 0, errors.New("boom") } + +func mustGet(t *testing.T, ts *httptest.Server, path string) *http.Response { + t.Helper() + resp, err := http.Get(ts.URL + path) + if err != nil { + t.Fatalf("GET %s: %v", path, err) + } + return resp +} + +func itoa(i int64) string { return strconv.FormatInt(i, 10) } diff --git a/internal/server/saved_query_handlers.go b/internal/server/saved_query_handlers.go new file mode 100644 index 0000000..c23f0e2 --- /dev/null +++ b/internal/server/saved_query_handlers.go @@ -0,0 +1,101 @@ +package server + +import ( + "errors" + "net/http" + "strings" + + "github.com/descope/pgpeek/internal/guard" + "github.com/descope/pgpeek/internal/store" +) + +func (s *Server) handleListQueries(w http.ResponseWriter, r *http.Request) { + qs, err := s.store.List(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to list saved queries") + s.log.Error("list saved queries", "err", err) + return + } + if qs == nil { + qs = []store.SavedQuery{} + } + writeJSON(w, http.StatusOK, qs) +} + +type savedQueryRequest struct { + Name string `json:"name"` + Description string `json:"description"` + SQL string `json:"sql"` +} + +func (s *Server) handleCreateQuery(w http.ResponseWriter, r *http.Request) { + req, ok := decodeSavedQuery(w, r) + if !ok { + return + } + q, err := s.store.Create(r.Context(), req.Name, req.Description, req.SQL) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to save query") + s.log.Error("create saved query", "err", err) + return + } + writeJSON(w, http.StatusCreated, q) +} + +func (s *Server) handleUpdateQuery(w http.ResponseWriter, r *http.Request) { + id, ok := pathID(w, r) + if !ok { + return + } + req, ok := decodeSavedQuery(w, r) + if !ok { + return + } + q, err := s.store.Update(r.Context(), id, req.Name, req.Description, req.SQL) + if errors.Is(err, store.ErrNotFound) { + writeError(w, http.StatusNotFound, "saved query not found") + return + } + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to update query") + s.log.Error("update saved query", "err", err) + return + } + writeJSON(w, http.StatusOK, q) +} + +func (s *Server) handleDeleteQuery(w http.ResponseWriter, r *http.Request) { + id, ok := pathID(w, r) + if !ok { + return + } + err := s.store.Delete(r.Context(), id) + if errors.Is(err, store.ErrNotFound) { + writeError(w, http.StatusNotFound, "saved query not found") + return + } + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to delete query") + s.log.Error("delete saved query", "err", err) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func decodeSavedQuery(w http.ResponseWriter, r *http.Request) (savedQueryRequest, bool) { + var req savedQueryRequest + if !decodeBody(w, r, &req) { + return req, false + } + req.Name = strings.TrimSpace(req.Name) + req.SQL = strings.TrimSpace(req.SQL) + if req.Name == "" || req.SQL == "" { + writeError(w, http.StatusBadRequest, "name and sql are required") + return req, false + } + if err := guard.Validate(req.SQL); err != nil { + writeError(w, http.StatusBadRequest, "saved query must be read-only: "+err.Error()) + return req, false + } + return req, true +} diff --git a/internal/server/saved_query_handlers_test.go b/internal/server/saved_query_handlers_test.go new file mode 100644 index 0000000..8eda0d1 --- /dev/null +++ b/internal/server/saved_query_handlers_test.go @@ -0,0 +1,93 @@ +package server + +import ( + "net/http" + "strings" + "testing" + + "github.com/descope/pgpeek/internal/store" +) + +func TestSavedQueries_CRUD(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{result: okResult()}) + + // Create + resp := post(t, ts, "/api/queries", `{"name":"q","description":"d","sql":"SELECT 1"}`) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("create status = %d", resp.StatusCode) + } + created := decode[store.SavedQuery](t, resp) + if created.ID == 0 { + t.Fatal("no id returned") + } + + // List + resp = mustGet(t, ts, "/api/queries") + list := decode[[]store.SavedQuery](t, resp) + if len(list) != 1 { + t.Fatalf("list len = %d", len(list)) + } + + // Update + req, _ := http.NewRequest(http.MethodPut, ts.URL+"/api/queries/"+itoa(created.ID), + strings.NewReader(`{"name":"q2","sql":"SELECT 2"}`)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("update status = %d", resp.StatusCode) + } + resp.Body.Close() + + // Delete + req, _ = http.NewRequest(http.MethodDelete, ts.URL+"/api/queries/"+itoa(created.ID), nil) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("delete status = %d", resp.StatusCode) + } + resp.Body.Close() +} + +func TestSavedQueries_Validation(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + + cases := []struct { + name, body string + }{ + {"missing name", `{"sql":"SELECT 1"}`}, + {"missing sql", `{"name":"x"}`}, + {"non-readonly sql", `{"name":"x","sql":"DELETE FROM t"}`}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + resp := post(t, ts, "/api/queries", c.body) + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } + }) + } +} + +func TestSavedQueries_NotFoundAndBadID(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + + req, _ := http.NewRequest(http.MethodDelete, ts.URL+"/api/queries/999", nil) + resp, _ := http.DefaultClient.Do(req) + if resp.StatusCode != http.StatusNotFound { + t.Errorf("delete missing status = %d, want 404", resp.StatusCode) + } + resp.Body.Close() + + req, _ = http.NewRequest(http.MethodPut, ts.URL+"/api/queries/abc", + strings.NewReader(`{"name":"x","sql":"SELECT 1"}`)) + resp, _ = http.DefaultClient.Do(req) + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("bad id status = %d, want 400", resp.StatusCode) + } + resp.Body.Close() +} diff --git a/internal/server/store_error_test.go b/internal/server/store_error_test.go new file mode 100644 index 0000000..f59b4f0 --- /dev/null +++ b/internal/server/store_error_test.go @@ -0,0 +1,75 @@ +package server + +import ( + "context" + "errors" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + "testing/fstest" + "time" + + "github.com/descope/pgpeek/internal/store" +) + +// --- error-path coverage with a fake store and failing writer ------------ + +type fakeStore struct { + listErr, getErr, createErr, updateErr, deleteErr error +} + +func (f *fakeStore) List(context.Context) ([]store.SavedQuery, error) { + return nil, f.listErr +} +func (f *fakeStore) Get(context.Context, int64) (store.SavedQuery, error) { + return store.SavedQuery{}, f.getErr +} +func (f *fakeStore) Create(context.Context, string, string, string) (store.SavedQuery, error) { + return store.SavedQuery{}, f.createErr +} +func (f *fakeStore) Update(context.Context, int64, string, string, string) (store.SavedQuery, error) { + return store.SavedQuery{}, f.updateErr +} +func (f *fakeStore) Delete(context.Context, int64) error { return f.deleteErr } + +func serverWithStore(t *testing.T, q Querier, st QueryStore) *Server { + t.Helper() + web := fstest.MapFS{"index.html": &fstest.MapFile{Data: []byte("x")}} + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + return New(q, st, web, log, time.Second) +} + +func TestStoreErrorPaths(t *testing.T) { + boom := errors.New("db down") + srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{ + listErr: boom, createErr: boom, updateErr: boom, deleteErr: boom, + }) + mux := srv.Routes() + + do := func(method, path, body string) int { + var rdr io.Reader + if body != "" { + rdr = strings.NewReader(body) + } + req := httptest.NewRequest(method, path, rdr) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + return rec.Code + } + + if code := do(http.MethodGet, "/api/queries", ""); code != http.StatusInternalServerError { + t.Errorf("list error code = %d, want 500", code) + } + if code := do(http.MethodPost, "/api/queries", `{"name":"x","sql":"SELECT 1"}`); code != http.StatusInternalServerError { + t.Errorf("create error code = %d, want 500", code) + } + if code := do(http.MethodPut, "/api/queries/1", `{"name":"x","sql":"SELECT 1"}`); code != http.StatusInternalServerError { + t.Errorf("update error code = %d, want 500", code) + } + if code := do(http.MethodDelete, "/api/queries/1", ""); code != http.StatusInternalServerError { + t.Errorf("delete error code = %d, want 500", code) + } +} From fc927338d4979aef370a391623df217065041acf Mon Sep 17 00:00:00 2001 From: Omer <639682+omercnet@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:41:07 +0000 Subject: [PATCH 07/31] feat(server): add database selection --- internal/server/database_selection_test.go | 247 ++++++++++++++++++ internal/server/registry.go | 102 ++++++++ .../registry_adapter_integration_test.go | 52 ++++ internal/server/registry_unit_test.go | 112 ++++++++ 4 files changed, 513 insertions(+) create mode 100644 internal/server/database_selection_test.go create mode 100644 internal/server/registry.go create mode 100644 internal/server/registry_adapter_integration_test.go create mode 100644 internal/server/registry_unit_test.go diff --git a/internal/server/database_selection_test.go b/internal/server/database_selection_test.go new file mode 100644 index 0000000..f41a6e6 --- /dev/null +++ b/internal/server/database_selection_test.go @@ -0,0 +1,247 @@ +package server + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + "testing/fstest" + "time" + + "github.com/descope/pgpeek/internal/db" + "github.com/descope/pgpeek/internal/store" +) + +type fakeRegistry struct { + defaultID string + metadata []db.PoolMetadata + pools map[string]Querier +} + +func (f fakeRegistry) List() []db.PoolMetadata { return f.metadata } +func (f fakeRegistry) DefaultID() string { return f.defaultID } + +func (f fakeRegistry) Pool(id string) (Querier, error) { + if id == "" { + id = f.defaultID + } + pool, ok := f.pools[id] + if !ok { + return nil, db.ErrPoolNotFound + } + return pool, nil +} + +func (f fakeRegistry) Ping(ctx context.Context) error { + for _, pool := range f.pools { + if err := pool.Ping(ctx); err != nil { + return err + } + } + return nil +} + +type selectedQuerier struct { + rowCap int + used bool + pinged bool +} + +func (q *selectedQuerier) Query(context.Context, string) (*db.Result, error) { + q.used = true + return okResult(), nil +} + +func (q *selectedQuerier) Tables(context.Context) ([]db.TableInfo, error) { + q.used = true + return []db.TableInfo{}, nil +} + +func (q *selectedQuerier) Columns(context.Context, string, string) ([]db.ColumnInfo, error) { + q.used = true + return []db.ColumnInfo{}, nil +} + +func (q *selectedQuerier) ForeignKeys(context.Context, string, string) ([]db.ForeignKey, error) { + q.used = true + return []db.ForeignKey{}, nil +} + +func (q *selectedQuerier) TableRows(context.Context, db.TableQuery) (*db.Result, error) { + q.used = true + return okResult(), nil +} + +func (q *selectedQuerier) RowCap() int { + q.used = true + return q.rowCap +} + +func (q *selectedQuerier) Ping(context.Context) error { + q.pinged = true + return nil +} + +func newRegistryTestServer(t *testing.T, registry DatabaseRegistry) *httptest.Server { + t.Helper() + st, err := store.Open(t.TempDir() + "/t.db") + if err != nil { + t.Fatalf("store.Open: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + web := fstest.MapFS{"index.html": &fstest.MapFile{Data: []byte("x")}} + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + srv := NewWithRegistry(registry, st, web, log, time.Second) + ts := httptest.NewServer(srv.Routes()) + t.Cleanup(ts.Close) + return ts +} + +func TestDatabases_lists_safe_metadata(t *testing.T) { + primary := &selectedQuerier{rowCap: 1000} + registry := fakeRegistry{ + defaultID: "primary", + metadata: []db.PoolMetadata{ + {ID: "primary", Name: "Primary"}, + {ID: "analytics", Name: "Analytics"}, + }, + pools: map[string]Querier{"primary": primary, "analytics": &selectedQuerier{}}, + } + ts := newRegistryTestServer(t, registry) + + resp := mustGet(t, ts, "/api/databases") + got := decode[struct { + DefaultID string `json:"defaultId"` + Databases []db.PoolMetadata `json:"databases"` + }](t, resp) + + if got.DefaultID != "primary" || len(got.Databases) != 2 || got.Databases[0].ID != "primary" || got.Databases[1].Name != "Analytics" { + t.Fatalf("databases = %+v", got) + } + body := marshalString(t, got) + if strings.Contains(body, "postgres://") || strings.Contains(body, "dsn") { + t.Fatalf("database metadata leaked secret material: %s", body) + } +} + +func TestDatabaseSelection_uses_selected_pool_for_db_bound_endpoints(t *testing.T) { + tests := []struct { + name string + method string + path string + body string + }{ + {name: "readyz", method: http.MethodGet, path: "/readyz?db=analytics"}, + {name: "meta", method: http.MethodGet, path: "/api/meta?db=analytics"}, + {name: "query", method: http.MethodPost, path: "/api/query?db=analytics", body: `{"sql":"SELECT 1"}`}, + {name: "export", method: http.MethodPost, path: "/api/export?db=analytics", body: `{"sql":"SELECT 1"}`}, + {name: "tables", method: http.MethodGet, path: "/api/tables?db=analytics"}, + {name: "columns", method: http.MethodGet, path: "/api/tables/public/users/columns?db=analytics"}, + {name: "fks", method: http.MethodGet, path: "/api/tables/public/users/fks?db=analytics"}, + {name: "data", method: http.MethodGet, path: "/api/tables/public/users/data?db=analytics"}, + {name: "data csv", method: http.MethodGet, path: "/api/tables/public/users/data?format=csv&db=analytics"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + primary := &selectedQuerier{rowCap: 1000} + analytics := &selectedQuerier{rowCap: 2000} + ts := newRegistryTestServer(t, fakeRegistry{ + defaultID: "primary", + metadata: []db.PoolMetadata{{ID: "primary", Name: "Primary"}, {ID: "analytics", Name: "Analytics"}}, + pools: map[string]Querier{"primary": primary, "analytics": analytics}, + }) + req, err := http.NewRequest(tt.method, ts.URL+tt.path, strings.NewReader(tt.body)) + if err != nil { + t.Fatal(err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if tt.name == "readyz" { + if !analytics.pinged || primary.pinged { + t.Fatalf("pinged primary=%v analytics=%v", primary.pinged, analytics.pinged) + } + return + } + if !analytics.used || primary.used { + t.Fatalf("used primary=%v analytics=%v", primary.used, analytics.used) + } + }) + } +} + +func TestDatabaseSelection_missing_or_empty_db_uses_default(t *testing.T) { + primary := &selectedQuerier{rowCap: 1000} + analytics := &selectedQuerier{rowCap: 2000} + ts := newRegistryTestServer(t, fakeRegistry{ + defaultID: "primary", + metadata: []db.PoolMetadata{{ID: "primary", Name: "Primary"}, {ID: "analytics", Name: "Analytics"}}, + pools: map[string]Querier{"primary": primary, "analytics": analytics}, + }) + + resp := mustGet(t, ts, "/api/meta?db=") + got := decode[map[string]int](t, resp) + + if got["rowCap"] != 1000 || !primary.used || analytics.used { + t.Fatalf("default selection failed: body=%+v primary=%v analytics=%v", got, primary.used, analytics.used) + } +} + +func TestDatabaseSelection_unknown_db_returns_404(t *testing.T) { + primary := &selectedQuerier{rowCap: 1000} + ts := newRegistryTestServer(t, fakeRegistry{ + defaultID: "primary", + metadata: []db.PoolMetadata{{ID: "primary", Name: "Primary"}}, + pools: map[string]Querier{"primary": primary}, + }) + + resp := mustGet(t, ts, "/api/tables?db=missing") + got := decode[map[string]string](t, resp) + + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("status = %d, want 404", resp.StatusCode) + } + if got["error"] == "" || strings.Contains(got["error"], "postgres://") { + t.Fatalf("bad error body: %+v", got) + } + if primary.used { + t.Fatal("unknown db should not use default pool") + } +} + +func TestDatabaseSelection_saved_query_endpoints_ignore_db(t *testing.T) { + primary := &selectedQuerier{rowCap: 1000} + ts := newRegistryTestServer(t, fakeRegistry{ + defaultID: "primary", + metadata: []db.PoolMetadata{{ID: "primary", Name: "Primary"}}, + pools: map[string]Querier{"primary": primary}, + }) + + resp := post(t, ts, "/api/queries?db=missing", `{"name":"q","sql":"SELECT 1"}`) + created := decode[store.SavedQuery](t, resp) + + if resp.StatusCode != http.StatusCreated || created.ID == 0 { + t.Fatalf("create saved query status=%d query=%+v", resp.StatusCode, created) + } + if primary.used || primary.pinged { + t.Fatal("saved query endpoint should not select a database pool") + } +} + +func marshalString(t *testing.T, v any) string { + t.Helper() + b, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal: %v", err) + } + return string(b) +} diff --git a/internal/server/registry.go b/internal/server/registry.go new file mode 100644 index 0000000..c3c51e4 --- /dev/null +++ b/internal/server/registry.go @@ -0,0 +1,102 @@ +package server + +import ( + "context" + "errors" + "net/http" + "strings" + + "github.com/descope/pgpeek/internal/db" + "github.com/descope/pgpeek/internal/store" +) + +type Querier interface { + Query(ctx context.Context, sql string) (*db.Result, error) + Tables(ctx context.Context) ([]db.TableInfo, error) + Columns(ctx context.Context, schema, table string) ([]db.ColumnInfo, error) + ForeignKeys(ctx context.Context, schema, table string) ([]db.ForeignKey, error) + TableRows(ctx context.Context, q db.TableQuery) (*db.Result, error) + RowCap() int + Ping(ctx context.Context) error +} + +type QueryStore interface { + List(ctx context.Context) ([]store.SavedQuery, error) + Get(ctx context.Context, id int64) (store.SavedQuery, error) + Create(ctx context.Context, name, desc, sql string) (store.SavedQuery, error) + Update(ctx context.Context, id int64, name, desc, sql string) (store.SavedQuery, error) + Delete(ctx context.Context, id int64) error +} + +type DatabaseRegistry interface { + List() []db.PoolMetadata + DefaultID() string + Pool(id string) (Querier, error) + Ping(ctx context.Context) error +} + +type singleDatabaseRegistry struct { + pool Querier +} + +func NewSingleDatabaseRegistry(pool Querier) DatabaseRegistry { + return singleDatabaseRegistry{pool: pool} +} + +func NewDatabaseRegistry(registry *db.Registry) DatabaseRegistry { + return dbRegistryAdapter{registry: registry} +} + +func (r singleDatabaseRegistry) List() []db.PoolMetadata { + return []db.PoolMetadata{{ID: "default", Name: "Default"}} +} + +func (r singleDatabaseRegistry) DefaultID() string { return "default" } + +func (r singleDatabaseRegistry) Pool(id string) (Querier, error) { + if id == "" || id == r.DefaultID() { + return r.pool, nil + } + return nil, db.ErrPoolNotFound +} + +func (r singleDatabaseRegistry) Ping(ctx context.Context) error { + return r.pool.Ping(ctx) +} + +type dbRegistryAdapter struct { + registry *db.Registry +} + +type databasesResponse struct { + DefaultID string `json:"defaultId"` + Databases []db.PoolMetadata `json:"databases"` +} + +func (r dbRegistryAdapter) List() []db.PoolMetadata { return r.registry.List() } + +func (r dbRegistryAdapter) DefaultID() string { return r.registry.DefaultID() } + +func (r dbRegistryAdapter) Pool(id string) (Querier, error) { + return r.registry.Pool(id) +} + +func (r dbRegistryAdapter) Ping(ctx context.Context) error { return r.registry.Ping(ctx) } + +func (s *Server) handleDatabases(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, databasesResponse{DefaultID: s.registry.DefaultID(), Databases: s.registry.List()}) +} + +func (s *Server) poolForRequest(w http.ResponseWriter, r *http.Request) (Querier, bool) { + pool, err := s.registry.Pool(strings.TrimSpace(r.URL.Query().Get("db"))) + if errors.Is(err, db.ErrPoolNotFound) { + writeError(w, http.StatusNotFound, "database not found") + return nil, false + } + if err != nil { + writeError(w, http.StatusInternalServerError, "database unavailable") + s.log.Error("select database", "err", err) + return nil, false + } + return pool, true +} diff --git a/internal/server/registry_adapter_integration_test.go b/internal/server/registry_adapter_integration_test.go new file mode 100644 index 0000000..714aef7 --- /dev/null +++ b/internal/server/registry_adapter_integration_test.go @@ -0,0 +1,52 @@ +//go:build integration + +package server + +import ( + "context" + "os" + "testing" + "time" + + "github.com/descope/pgpeek/internal/db" +) + +func TestDatabaseRegistryAdapter_delegates_to_db_registry(t *testing.T) { + // Given: a real database registry backs the server adapter. + dsn := os.Getenv("PGPEEK_TEST_DATABASE_URL") + if dsn == "" { + t.Skip("PGPEEK_TEST_DATABASE_URL not set") + } + pool, err := db.New(context.Background(), db.Config{ + DSN: dsn, + MaxConns: 1, + StatementTimeout: 5 * time.Second, + IdleTxTimeout: 5 * time.Second, + RowCap: 10, + }) + if err != nil { + t.Fatalf("db.New: %v", err) + } + t.Cleanup(pool.Close) + registry, err := db.NewRegistry([]db.RegistryEntry{{ID: "primary", Name: "Primary", Pool: pool, Default: true}}) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + adapter := NewDatabaseRegistry(registry) + + // When: server code calls adapter methods. + metadata := adapter.List() + querier, poolErr := adapter.Pool("primary") + pingErr := adapter.Ping(context.Background()) + + // Then: calls return the underlying registry data and pool. + if adapter.DefaultID() != "primary" || len(metadata) != 1 || metadata[0].Name != "Primary" { + t.Fatalf("default=%q metadata=%+v", adapter.DefaultID(), metadata) + } + if poolErr != nil || querier == nil { + t.Fatalf("Pool error=%v querier=%v", poolErr, querier) + } + if pingErr != nil { + t.Fatalf("Ping: %v", pingErr) + } +} diff --git a/internal/server/registry_unit_test.go b/internal/server/registry_unit_test.go new file mode 100644 index 0000000..08e9fb3 --- /dev/null +++ b/internal/server/registry_unit_test.go @@ -0,0 +1,112 @@ +package server + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/descope/pgpeek/internal/db" +) + +func TestSingleDatabaseRegistry_methods_select_default_or_reject_unknown(t *testing.T) { + // Given: a single-database registry wraps one querier. + q := &fakeQuerier{} + registry := NewSingleDatabaseRegistry(q) + + // When: callers inspect metadata and select pools. + metadata := registry.List() + defaultPool, defaultErr := registry.Pool("") + namedPool, namedErr := registry.Pool("default") + _, missingErr := registry.Pool("missing") + pingErr := registry.Ping(context.Background()) + + // Then: only the default id is accepted and delegated. + if len(metadata) != 1 || metadata[0].ID != "default" || registry.DefaultID() != "default" { + t.Fatalf("metadata=%+v default=%q", metadata, registry.DefaultID()) + } + if defaultErr != nil || namedErr != nil || pingErr != nil { + t.Fatalf("defaultErr=%v namedErr=%v pingErr=%v", defaultErr, namedErr, pingErr) + } + if defaultPool != q || namedPool != q { + t.Fatalf("pool selection returned wrong querier") + } + if !errors.Is(missingErr, db.ErrPoolNotFound) { + t.Fatalf("missingErr=%v, want ErrPoolNotFound", missingErr) + } +} + +func TestPoolForRequest_returns_500_when_registry_selection_fails(t *testing.T) { + // Given: registry lookup fails with a non-not-found error. + srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) + srv.registry = failingRegistry{err: errors.New("registry unavailable")} + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/meta?db=primary", nil) + + // When: a handler selects a pool for the request. + pool, ok := srv.poolForRequest(rec, req) + + // Then: callers receive a sanitized 500 response and no pool. + if ok || pool != nil { + t.Fatalf("poolForRequest ok=%v pool=%v, want no pool", ok, pool) + } + if rec.Code != http.StatusInternalServerError || strings.Contains(rec.Body.String(), "registry unavailable") { + t.Fatalf("response code=%d body=%q", rec.Code, rec.Body.String()) + } +} + +func TestHandlers_return_not_found_when_selected_database_missing(t *testing.T) { + tests := []struct { + name string + method string + path string + body string + }{ + {name: "meta", method: http.MethodGet, path: "/api/meta?db=missing"}, + {name: "columns", method: http.MethodGet, path: "/api/tables/public/users/columns?db=missing"}, + {name: "foreign keys", method: http.MethodGet, path: "/api/tables/public/users/fks?db=missing"}, + {name: "table data", method: http.MethodGet, path: "/api/tables/public/users/data?db=missing"}, + {name: "query", method: http.MethodPost, path: "/api/query?db=missing", body: `{"sql":"SELECT 1"}`}, + {name: "export", method: http.MethodPost, path: "/api/export?db=missing", body: `{"sql":"SELECT 1"}`}, + {name: "ready", method: http.MethodGet, path: "/readyz?db=missing"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given: server has no database matching the request id. + ts := newRegistryTestServer(t, fakeRegistry{ + defaultID: "primary", + metadata: []db.PoolMetadata{{ID: "primary", Name: "Primary"}}, + pools: map[string]Querier{"primary": &selectedQuerier{rowCap: 1000}}, + }) + req, err := http.NewRequest(tt.method, ts.URL+tt.path, strings.NewReader(tt.body)) + if err != nil { + t.Fatal(err) + } + + // When: the request is served. + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // Then: the handler stops at database selection. + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("status=%d, want 404", resp.StatusCode) + } + }) + } +} + +type failingRegistry struct { + err error +} + +func (f failingRegistry) List() []db.PoolMetadata { return nil } +func (f failingRegistry) DefaultID() string { return "default" } +func (f failingRegistry) Pool(string) (Querier, error) { + return nil, f.err +} +func (f failingRegistry) Ping(context.Context) error { return f.err } From a51e7d7d354570d1e5e5943dc28069be36275f45 Mon Sep 17 00:00:00 2001 From: Omer <639682+omercnet@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:41:14 +0000 Subject: [PATCH 08/31] feat(main): wire database registry --- main.go | 58 +++++++++++++++++++++++++++++++++------------------- main_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 21 deletions(-) diff --git a/main.go b/main.go index 61a1653..0c6b038 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "context" "embed" "errors" + "fmt" "io/fs" "log/slog" "net/http" @@ -50,32 +51,47 @@ func run(ctx context.Context, log *slog.Logger) error { return err } - dbCfg := db.Config{ - DSN: cfg.DB.DSN, - MaxConns: cfg.DB.MaxConns, - StatementTimeout: cfg.DB.StatementTimeout, - IdleTxTimeout: cfg.DB.IdleTxTimeout, - RowCap: cfg.RowCap, + signalCtx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) + defer stop() + + entries := make([]db.RegistryEntry, 0, len(cfg.Databases)) + closeEntries := func() { + for i := len(entries) - 1; i >= 0; i-- { + entries[i].Pool.Close() + } } - if cfg.DB.IAMAuth { - provider, perr := awsauth.New(ctx, cfg.DB.Region) + for _, entry := range cfg.Databases { + dbCfg := db.Config{ + DSN: entry.DSN, + MaxConns: cfg.DB.MaxConns, + StatementTimeout: cfg.DB.StatementTimeout, + IdleTxTimeout: cfg.DB.IdleTxTimeout, + RowCap: cfg.RowCap, + } + if entry.IAMAuth { + provider, perr := awsauth.New(signalCtx, entry.Region) + if perr != nil { + closeEntries() + return fmt.Errorf("create IAM auth provider for database %q: %w", entry.ID, perr) + } + dbCfg.BeforeConnect = provider.BeforeConnect + log.Info("RDS IAM authentication enabled", "databaseID", entry.ID, "databaseName", entry.Name) + } + log.Info("connecting database", "databaseID", entry.ID, "databaseName", entry.Name, "default", entry.ID == cfg.DefaultDatabaseID) + pool, perr := db.New(signalCtx, dbCfg) if perr != nil { - return perr + closeEntries() + return fmt.Errorf("connect database %q: %w", entry.ID, perr) } - dbCfg.BeforeConnect = provider.BeforeConnect - log.Info("RDS IAM authentication enabled", "region", cfg.DB.Region) + entries = append(entries, db.RegistryEntry{ID: entry.ID, Name: entry.Name, Pool: pool, Default: entry.ID == cfg.DefaultDatabaseID}) } - - signalCtx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) - defer stop() - - pool, err := db.New(signalCtx, dbCfg) + registry, err := db.NewRegistry(entries) if err != nil { - return err + closeEntries() + return fmt.Errorf("create database registry: %w", err) } - defer pool.Close() - log.Info("connected to database", "maxConns", cfg.DB.MaxConns, "rowCap", cfg.RowCap, - "statementTimeout", cfg.DB.StatementTimeout.String(), "iamAuth", cfg.DB.IAMAuth) + defer registry.Close() + log.Info("database registry ready", "databaseCount", len(entries), "defaultDatabaseID", registry.DefaultID()) st, err := store.Open(cfg.StorePath) if err != nil { @@ -87,7 +103,7 @@ func run(ctx context.Context, log *slog.Logger) error { } web := mustSubFS(webFiles, "web") - srv := server.New(pool, st, web, log, cfg.DB.StatementTimeout+5*time.Second) + srv := server.NewWithRegistry(server.NewDatabaseRegistry(registry), st, web, log, cfg.DB.StatementTimeout+5*time.Second) httpSrv := &http.Server{ Addr: cfg.Server.Listen, Handler: srv.Routes(), diff --git a/main_test.go b/main_test.go index 2f455dd..afa40b1 100644 --- a/main_test.go +++ b/main_test.go @@ -14,6 +14,8 @@ import ( "net/http" "os" "path/filepath" + "strconv" + "strings" "testing" "time" @@ -29,9 +31,37 @@ func clearAppEnv(t *testing.T) { "PGPEEK_ROW_CAP", "PGPEEK_MAX_CONNS", "PGPEEK_STATEMENT_TIMEOUT", "PGPEEK_DB_IAM_AUTH", "PGPEEK_AWS_REGION", "AWS_REGION", "PGPEEK_TLS_CERT_FILE", "PGPEEK_TLS_KEY_FILE", + "PGPEEK_DATABASE_URLS", "PGPEEK_DATABASE_IDS", "PGPEEK_DATABASE_NAMES", + "PGPEEK_DATABASES_FILE", "PGPEEK_DEFAULT_DATABASE", } { t.Setenv(k, "") } + for i := 1; i <= 64; i++ { + suffix := strconv.Itoa(i) + t.Setenv("PGPEEK_DATABASE_URL_"+suffix, "") + t.Setenv("PGPEEK_DATABASE_URL_"+suffix+"_FILE", "") + t.Setenv("PGPEEK_DATABASE_ID_"+suffix, "") + t.Setenv("PGPEEK_DATABASE_NAME_"+suffix, "") + } +} + +func TestClearAppEnv_clearsMultiDatabaseVars(t *testing.T) { + // Given + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:p@127.0.0.1:1/one") + t.Setenv("PGPEEK_DATABASE_IDS", "one") + t.Setenv("PGPEEK_DATABASE_URL_1", "postgres://u:p@127.0.0.1:1/two") + t.Setenv("PGPEEK_DATABASE_ID_1", "two") + t.Setenv("PGPEEK_DEFAULT_DATABASE", "one") + + // When + clearAppEnv(t) + + // Then + for _, key := range []string{"PGPEEK_DATABASE_URLS", "PGPEEK_DATABASE_IDS", "PGPEEK_DATABASE_URL_1", "PGPEEK_DATABASE_ID_1", "PGPEEK_DEFAULT_DATABASE"} { + if got := os.Getenv(key); got != "" { + t.Fatalf("%s=%q, want empty", key, got) + } + } } // --- serve() ------------------------------------------------------------- @@ -112,6 +142,27 @@ func TestRun_DBConnectError(t *testing.T) { } } +func TestRun_MultiDatabaseConnectError(t *testing.T) { + clearAppEnv(t) + // Given: two configured database entries, second selected as default. + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:p@127.0.0.1:1/one?connect_timeout=1&sslmode=disable,postgres://u:p@127.0.0.1:2/two?connect_timeout=1&sslmode=disable") + t.Setenv("PGPEEK_DATABASE_IDS", "one,two") + t.Setenv("PGPEEK_DATABASE_NAMES", "One,Two") + t.Setenv("PGPEEK_DEFAULT_DATABASE", "two") + t.Setenv("PGPEEK_STORE_PATH", filepath.Join(t.TempDir(), "s.db")) + + // When + err := run(context.Background(), testLogger()) + + // Then + if err == nil { + t.Fatal("expected multi-database connect error") + } + if !strings.Contains(err.Error(), `database "one"`) { + t.Fatalf("error %q, want first database id", err) + } +} + func TestRun_IAMPathThenDBError(t *testing.T) { clearAppEnv(t) t.Setenv("DATABASE_URL", "postgres://u@127.0.0.1:1/db?connect_timeout=1&sslmode=disable") From 356dabc0edd7031132a0819ffdb6d2404d981b7c Mon Sep 17 00:00:00 2001 From: Omer <639682+omercnet@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:41:20 +0000 Subject: [PATCH 09/31] feat(web): add database URL state --- web/api.js | 33 +++++ web/url-state.js | 58 ++++++++ web/url-state.test.js | 303 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 394 insertions(+) create mode 100644 web/api.js create mode 100644 web/url-state.js create mode 100644 web/url-state.test.js diff --git a/web/api.js b/web/api.js new file mode 100644 index 0000000..11c2a91 --- /dev/null +++ b/web/api.js @@ -0,0 +1,33 @@ +// API helpers that inject the active db id as ?db= on every DB-bound +// GET request and include it in JSON bodies for POST requests. + +export function dbUrl(path, dbId) { + if (!dbId) return path; + const sep = path.includes("?") ? "&" : "?"; + return path + sep + "db=" + encodeURIComponent(dbId); +} + +export async function getJSON(url, dbId) { + const r = await fetch(dbUrl(url, dbId)); + const body = await r.json(); + if (!r.ok) throw new Error(body.error || r.statusText); + return body; +} + +export const tablePath = (t) => + "/api/tables/" + encodeURIComponent(t.schema) + "/" + encodeURIComponent(t.name); + +export const tableKey = (t) => t.schema + "." + t.name; + +// appendDataParams adds search/sort/filter params shared by browse + export. +export function appendDataParams(p, search, sort, filters) { + if (search) p.set("search", search); + if (sort) { p.set("sort", sort.col); p.set("dir", sort.dir); } + for (const col of Object.keys(filters)) { + const f = filters[col]; + if (!f.op) continue; + const noVal = f.op === "is_null" || f.op === "is_not_null"; + if (noVal) p.append("f", col + ":" + f.op); + else p.append("f", col + ":" + f.op + ":" + (f.value || "")); + } +} diff --git a/web/url-state.js b/web/url-state.js new file mode 100644 index 0000000..a6d3e21 --- /dev/null +++ b/web/url-state.js @@ -0,0 +1,58 @@ +// URL state helpers — read, push, and replace browser history for pgpeek. +// Stable param names match API names where they exist. + +export function readUrlState() { + const p = new URLSearchParams(window.location.search); + const filters = {}; + for (const f of p.getAll("f")) { + const first = f.indexOf(":"); + if (first < 0) continue; + const col = f.slice(0, first); + const rest = f.slice(first + 1); + const second = rest.indexOf(":"); + const op = second >= 0 ? rest.slice(0, second) : rest; + const value = second >= 0 ? rest.slice(second + 1) : ""; + if (col && op) filters[col] = { op, value }; + } + const sort = p.get("sort") + ? { col: p.get("sort"), dir: p.get("dir") || "asc" } + : null; + return { + db: p.get("db") || null, + tab: ["data", "structure", "sql"].includes(p.get("tab")) ? p.get("tab") : "data", + schema: p.get("schema") || null, + table: p.get("table") || null, + offset: Math.max(0, parseInt(p.get("offset"), 10) || 0), + search: p.get("search") || "", + sort, + filters, + }; +} + +export function buildUrlParams(state) { + const p = new URLSearchParams(); + if (state.db) p.set("db", state.db); + if (state.tab && state.tab !== "data") p.set("tab", state.tab); + if (state.schema) p.set("schema", state.schema); + if (state.table) p.set("table", state.table); + if (state.offset) p.set("offset", String(state.offset)); + if (state.search) p.set("search", state.search); + if (state.sort) { p.set("sort", state.sort.col); p.set("dir", state.sort.dir); } + if (state.filters) { + for (const col of Object.keys(state.filters)) { + const f = state.filters[col]; + if (!f || !f.op) continue; + const noVal = f.op === "is_null" || f.op === "is_not_null"; + p.append("f", noVal ? `${col}:${f.op}` : `${col}:${f.op}:${f.value || ""}`); + } + } + return p; +} + +function qs(p) { const s = p.toString(); return s ? "?" + s : ""; } + +export const pushUrlState = (state) => + window.history.pushState(null, "", window.location.pathname + qs(buildUrlParams(state))); + +export const replaceUrlState = (state) => + window.history.replaceState(null, "", window.location.pathname + qs(buildUrlParams(state))); diff --git a/web/url-state.test.js b/web/url-state.test.js new file mode 100644 index 0000000..724f53e --- /dev/null +++ b/web/url-state.test.js @@ -0,0 +1,303 @@ +// @vitest-environment jsdom +// Tests: URL db param, URL tab/table state, popstate restore, +// url-state and api module helpers, and coverage-completion edge cases. +import { describe, it, expect, beforeEach, afterEach, vi } from "vitest"; +import { + flush, makeResp, TWO_DBS, NO_DBS, SAMPLE_TABLES, + makeInstallFetch, $, click, changeSelect, loadApp, urlOf, + readUrlState, buildUrlParams, dbUrl, +} from "./test-helpers.js"; + +let routes; +function setRoute(key, resp) { routes[key] = resp; } +const installFetch = makeInstallFetch(() => routes); + +function defaultRoutes() { + return { + "GET /api/databases": makeResp({ json: TWO_DBS }), + "GET /api/meta": makeResp({ json: { rowCap: 100 } }), + "GET /api/tables": makeResp({ json: [] }), + "GET /api/tables/*/columns": makeResp({ json: [] }), + "GET /api/tables/*/fks": makeResp({ json: [] }), + "GET /api/queries": makeResp({ json: [] }), + }; +} + +beforeEach(() => { + document.body.innerHTML = '
'; + window.history.replaceState({}, "", "/"); + routes = defaultRoutes(); + installFetch(); + globalThis.prompt = vi.fn(); + globalThis.confirm = vi.fn(); + globalThis.URL.createObjectURL = vi.fn(() => "blob:fake"); + globalThis.URL.revokeObjectURL = vi.fn(); + HTMLAnchorElement.prototype.click = vi.fn(); + Element.prototype.scrollIntoView = vi.fn(); + globalThis.requestAnimationFrame = (cb) => setTimeout(cb, 0); + globalThis.cancelAnimationFrame = (id) => clearTimeout(id); + window.requestAnimationFrame = globalThis.requestAnimationFrame; + window.cancelAnimationFrame = globalThis.cancelAnimationFrame; + delete window.CodeMirror; + delete globalThis.CodeMirror; +}); + +afterEach(() => { + vi.restoreAllMocks(); + window.history.replaceState({}, "", "/"); + delete window.CodeMirror; + delete globalThis.CodeMirror; +}); + +// ── URL db param ────────────────────────────────────────────────────────────── + +describe("URL db param", () => { + it("writes defaultId into URL on load when no ?db= in URL", async () => { + await loadApp(); + expect(new URLSearchParams(window.location.search).get("db")).toBe("pg1"); + }); + + it("uses ?db= from URL when it matches a known database", async () => { + window.history.replaceState({}, "", "/?db=pg2"); + await loadApp(); + expect($("database-select").value).toBe("pg2"); + expect(new URLSearchParams(window.location.search).get("db")).toBe("pg2"); + }); + + it("falls back to defaultId and shows error when ?db= is unknown", async () => { + window.history.replaceState({}, "", "/?db=nonexistent"); + await loadApp(); + expect($("status").textContent).toContain("unknown database"); + expect($("database-select").value).toBe("pg1"); + }); + + it("falls back to first database when defaultId is null and URL has no db", async () => { + setRoute("GET /api/databases", makeResp({ + json: { defaultId: null, databases: [{ id: "first", name: "First" }, { id: "second", name: "Second" }] }, + })); + await loadApp(); + expect(new URLSearchParams(window.location.search).get("db")).toBe("first"); + }); + + it("pushes new db into URL when database is switched", async () => { + await loadApp(); + await changeSelect($("database-select"), "pg2"); + expect(new URLSearchParams(window.location.search).get("db")).toBe("pg2"); + }); + + it("db switch clears schema/table/offset/search from URL", async () => { + setRoute("GET /api/tables", makeResp({ json: SAMPLE_TABLES })); + setRoute("GET /api/tables/*/data", makeResp({ json: { columns: ["id"], rows: [[1]], rowCount: 1, elapsedMs: 1 } })); + await loadApp(); + await click($("tables").querySelectorAll(".tbl")[0]); + expect(new URLSearchParams(window.location.search).get("table")).toBe("users"); + await changeSelect($("database-select"), "pg2"); + const p = new URLSearchParams(window.location.search); + expect(p.has("schema")).toBe(false); + expect(p.has("table")).toBe(false); + expect(p.has("offset")).toBe(false); + }); +}); + +// ── URL state — tab and table ───────────────────────────────────────────────── + +describe("URL state — tab and table", () => { + it("restores tab=sql from URL on initial load", async () => { + window.history.replaceState({}, "", "/?db=pg1&tab=sql"); + await loadApp(); + expect($("panel-sql").hidden).toBe(false); + expect($("panel-data").hidden).toBe(true); + }); + + it("data tab is default; URL has no tab param when on data tab after load", async () => { + await loadApp(); + expect(new URLSearchParams(window.location.search).has("tab")).toBe(false); + }); + + it("pushes tab into URL when tab changes", async () => { + await loadApp(); + await click("tab-sql"); + expect(new URLSearchParams(window.location.search).get("tab")).toBe("sql"); + }); + + it("tab=data is omitted from URL params (canonical form)", async () => { + await loadApp(); + await click("tab-sql"); + await click("tab-data"); + expect(new URLSearchParams(window.location.search).has("tab")).toBe(false); + }); + + it("pushes schema and table into URL when table is selected", async () => { + setRoute("GET /api/tables", makeResp({ json: SAMPLE_TABLES })); + setRoute("GET /api/tables/*/data", makeResp({ json: { columns: ["id"], rows: [[1]], rowCount: 1, elapsedMs: 1 } })); + await loadApp(); + await click($("tables").querySelectorAll(".tbl")[0]); + const p = new URLSearchParams(window.location.search); + expect(p.get("schema")).toBe("public"); + expect(p.get("table")).toBe("users"); + }); + + it("restores table from URL on initial load", async () => { + window.history.replaceState({}, "", "/?db=pg1&schema=public&table=users"); + setRoute("GET /api/tables", makeResp({ json: SAMPLE_TABLES })); + setRoute("GET /api/tables/*/data", makeResp({ json: { columns: ["id"], rows: [[1]], rowCount: 1, elapsedMs: 1 } })); + await loadApp(); + expect($("tab-title").textContent).toBe("public.users"); + }); + + it("gracefully ignores unknown table in URL on initial load", async () => { + window.history.replaceState({}, "", "/?db=pg1&schema=public&table=nope"); + setRoute("GET /api/tables", makeResp({ json: SAMPLE_TABLES })); + await loadApp(); + expect($("tab-title").textContent).toBe("Pick a table"); + }); +}); + +// ── URL state — popstate ────────────────────────────────────────────────────── + +describe("URL state — popstate", () => { + it("restores tab on popstate (sql → data)", async () => { + await loadApp(); + await click("tab-sql"); + expect($("panel-sql").hidden).toBe(false); + window.history.replaceState({}, "", "/?db=pg1"); + window.dispatchEvent(new PopStateEvent("popstate")); + await flush(); + expect($("panel-data").hidden).toBe(false); + expect($("panel-sql").hidden).toBe(true); + }); + + it("restores db on popstate", async () => { + await loadApp(); + await changeSelect($("database-select"), "pg2"); + expect($("database-select").value).toBe("pg2"); + window.history.replaceState({}, "", "/?db=pg1"); + window.dispatchEvent(new PopStateEvent("popstate")); + await flush(); + expect($("database-select").value).toBe("pg1"); + }); +}); + +// ── popstate with table in URL (branch coverage) ────────────────────────────── + +describe("popstate with table in URL", () => { + it("queues table restore when db changes via popstate and URL has schema+table", async () => { + // Covers app.js: !sameDb branch when URL has table. + setRoute("GET /api/tables", makeResp({ json: SAMPLE_TABLES })); + setRoute("GET /api/tables/*/data", makeResp({ json: { columns: ["id"], rows: [[1]], rowCount: 1, elapsedMs: 1 } })); + await loadApp(); + await changeSelect($("database-select"), "pg2"); + window.history.replaceState({}, "", "/?db=pg1&schema=public&table=users"); + window.dispatchEvent(new PopStateEvent("popstate")); + await flush(); + expect($("database-select").value).toBe("pg1"); + }); + + it("restores table from already-loaded list on same-db popstate", async () => { + // Covers app.js: sameDb branch when found. + setRoute("GET /api/tables", makeResp({ json: SAMPLE_TABLES })); + setRoute("GET /api/tables/*/data", makeResp({ json: { columns: ["id"], rows: [[1]], rowCount: 1, elapsedMs: 1 } })); + await loadApp(); + await click($("tables").querySelectorAll(".tbl")[1]); // public.posts + expect($("tab-title").textContent).toBe("public.posts"); + window.history.replaceState({}, "", "/?db=pg1&schema=public&table=users"); + window.dispatchEvent(new PopStateEvent("popstate")); + await flush(); + expect($("tab-title").textContent).toBe("public.users"); + }); + + it("clears current table on same-db popstate when table not found in list", async () => { + // Covers app.js: sameDb branch when NOT found. + setRoute("GET /api/tables", makeResp({ json: SAMPLE_TABLES })); + setRoute("GET /api/tables/*/data", makeResp({ json: { columns: ["id"], rows: [[1]], rowCount: 1, elapsedMs: 1 } })); + await loadApp(); + await click($("tables").querySelectorAll(".tbl")[0]); + expect($("tab-title").textContent).toBe("public.users"); + window.history.replaceState({}, "", "/?db=pg1&schema=public&table=missing"); + window.dispatchEvent(new PopStateEvent("popstate")); + await flush(); + expect($("tab-title").textContent).toBe("Pick a table"); + }); +}); + +// ── url-state module unit tests ─────────────────────────────────────────────── + +describe("url-state helpers", () => { + it("readUrlState parses all supported params", async () => { + window.history.replaceState({}, "", "/?db=x&tab=sql&schema=s&table=t&offset=50&search=foo&sort=id&dir=desc&f=id:eq:5&f=name:is_null"); + const s = readUrlState(); + expect(s.db).toBe("x"); + expect(s.tab).toBe("sql"); + expect(s.schema).toBe("s"); + expect(s.table).toBe("t"); + expect(s.offset).toBe(50); + expect(s.search).toBe("foo"); + expect(s.sort).toEqual({ col: "id", dir: "desc" }); + expect(s.filters["id"]).toEqual({ op: "eq", value: "5" }); + expect(s.filters["name"]).toEqual({ op: "is_null", value: "" }); + }); + + it("readUrlState falls back to 'data' for an invalid tab value", async () => { + window.history.replaceState({}, "", "/?tab=invalid"); + expect(readUrlState().tab).toBe("data"); + }); + + it("buildUrlParams omits tab when data, omits falsy fields", async () => { + const p = buildUrlParams({ db: "x", tab: "data", schema: null, table: null, offset: 0, search: "", sort: null, filters: {} }); + expect(p.has("tab")).toBe(false); + expect(p.has("schema")).toBe(false); + expect(p.get("db")).toBe("x"); + }); + + it("buildUrlParams encodes is_null filter without value segment", async () => { + const p = buildUrlParams({ db: null, tab: "data", schema: null, table: null, offset: 0, search: "", sort: null, + filters: { col: { op: "is_null", value: "" } } }); + expect(p.get("f")).toBe("col:is_null"); + }); +}); + +// ── api module unit tests ───────────────────────────────────────────────────── + +describe("api helpers", () => { + it("dbUrl appends ?db= to paths without query params", async () => { + expect(dbUrl("/api/tables", "pg1")).toBe("/api/tables?db=pg1"); + }); + + it("dbUrl appends &db= to paths that already have query params", async () => { + expect(dbUrl("/api/tables?limit=100", "pg1")).toBe("/api/tables?limit=100&db=pg1"); + }); + + it("dbUrl returns path unchanged when dbId is falsy", async () => { + expect(dbUrl("/api/tables", null)).toBe("/api/tables"); + expect(dbUrl("/api/tables", "")).toBe("/api/tables"); + }); +}); + +// ── url-state edge cases (branch coverage) ──────────────────────────────────── + +describe("url-state edge cases", () => { + it("readUrlState skips malformed filter entries that lack a colon", async () => { + // Covers url-state.js: 'if (first < 0) continue' branch. + window.history.replaceState({}, "", "/?f=nocoion&f=col:eq:5"); + const s = readUrlState(); + expect(s.filters["col"]).toEqual({ op: "eq", value: "5" }); + expect(Object.keys(s.filters)).not.toContain("nocoion"); + }); + + it("readUrlState defaults sort direction to 'asc' when dir param is absent", async () => { + // Covers url-state.js: p.get('dir') || 'asc' false branch. + window.history.replaceState({}, "", "/?sort=id"); + const s = readUrlState(); + expect(s.sort).toEqual({ col: "id", dir: "asc" }); + }); + + it("buildUrlParams skips filter entries with no op (null or falsy)", async () => { + // Covers url-state.js: '!f || !f.op' continue branch. + const p = buildUrlParams({ + db: null, tab: "data", schema: null, table: null, + offset: 0, search: "", sort: null, + filters: { nullcol: null, emptyop: { op: "", value: "x" } }, + }); + expect(p.has("f")).toBe(false); + }); +}); From 855f6f0fdb7a3ddcae93ee5e3d3be7a0d7198105 Mon Sep 17 00:00:00 2001 From: Omer <639682+omercnet@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:41:35 +0000 Subject: [PATCH 10/31] refactor(web): split app modules --- web/app.js | 617 +++++++++++++------------------------------ web/data-tab.js | 179 +++++++++++++ web/sidebar.js | 54 ++++ web/sql-tab.js | 156 +++++++++++ web/structure-tab.js | 29 ++ web/theme.js | 40 +++ 6 files changed, 647 insertions(+), 428 deletions(-) create mode 100644 web/data-tab.js create mode 100644 web/sidebar.js create mode 100644 web/sql-tab.js create mode 100644 web/structure-tab.js create mode 100644 web/theme.js diff --git a/web/app.js b/web/app.js index a40d5c8..7796298 100644 --- a/web/app.js +++ b/web/app.js @@ -1,147 +1,28 @@ -// pgpeek UI — Preact + htm, no build step. Vendored Preact/htm keeps the app a -// set of static files embedded in the Go binary, CSP-safe (no eval), with the -// reactivity that the imperative version was outgrowing. +// pgpeek UI — Preact + htm, no build step. This file is the entrypoint; +// feature modules live in ./theme.js, ./sidebar.js, ./data-tab.js, etc. import { html, render, useState, useEffect, useRef, useCallback, } from "./vendor/preact-htm.js"; +import { ThemeSelect } from "./theme.js"; +import { Sidebar, Tabs } from "./sidebar.js"; +import { DataTab } from "./data-tab.js"; +import { StructureTab } from "./structure-tab.js"; +import { SqlTab } from "./sql-tab.js"; +import { getJSON, tableKey } from "./api.js"; +import { readUrlState, pushUrlState, replaceUrlState } from "./url-state.js"; const PAGE_SIZE = 100; -const THEME_KEY = "pgpeek-theme"; -// Switchable color themes. id "" = built-in default (the :root palette). -const THEMES = [ - ["", "Default"], - ["dark-plus", "Dark+"], - ["light-plus", "Light+"], - ["monokai", "Monokai"], - ["dracula", "Dracula"], - ["one-dark", "One Dark Pro"], - ["nord", "Nord"], - ["solarized-dark", "Solarized Dark"], - ["solarized-light", "Solarized Light"], - ["github-dark", "GitHub Dark"], - ["github-light", "GitHub Light"], - ["catppuccin-mocha", "Catppuccin Mocha"], - ["catppuccin-latte", "Catppuccin Latte"], - ["tokyo-night", "Tokyo Night"], - ["ayu-dark", "Ayu Dark"], - ["ayu-mirage", "Ayu Mirage"], - ["night-owl", "Night Owl"], - ["houston", "Houston"], - ["matcha", "Matcha"], - ["dainty", "Dainty"], -]; - -function getStoredTheme() { - try { return localStorage.getItem(THEME_KEY) || ""; } catch { return ""; } -} - -// applyTheme sets (or clears) the data-theme attribute that selects a palette. -function applyTheme(id) { - const root = document.documentElement; - if (id) root.setAttribute("data-theme", id); - else root.removeAttribute("data-theme"); -} - -// Apply the saved theme at import time to avoid a flash of the default palette. -applyTheme(getStoredTheme()); - -// Allowlisted filter operators (key sent to the server, label shown in the UI). -const OPS = [ - ["", "—"], ["eq", "="], ["ne", "≠"], ["lt", "<"], ["lte", "≤"], - ["gt", ">"], ["gte", "≥"], ["ilike", "ILIKE"], ["like", "LIKE"], - ["is_null", "IS NULL"], ["is_not_null", "NOT NULL"], -]; -const opNeedsValue = (op) => op !== "" && op !== "is_null" && op !== "is_not_null"; - -const tablePath = (t) => "/api/tables/" + encodeURIComponent(t.schema) + "/" + encodeURIComponent(t.name); -const tableKey = (t) => t.schema + "." + t.name; - -async function getJSON(url) { - const r = await fetch(url); - const body = await r.json(); - if (!r.ok) throw new Error(body.error || r.statusText); - return body; -} - -// appendDataParams adds search/sort/filter params shared by browse + export. -function appendDataParams(p, search, sort, filters) { - if (search) p.set("search", search); - if (sort) { p.set("sort", sort.col); p.set("dir", sort.dir); } - for (const col of Object.keys(filters)) { - const f = filters[col]; - if (!f.op) continue; - if (opNeedsValue(f.op)) p.append("f", col + ":" + f.op + ":" + (f.value || "")); - else p.append("f", col + ":" + f.op); - } -} - -function cellText(v) { - if (v === null || v === undefined) return null; - return typeof v === "object" ? JSON.stringify(v) : String(v); -} - -// Cell renders one value, as an FK link when fkRef is set. -function Cell({ value, fkRef, onNavigate }) { - const text = cellText(value); - if (text === null) return html`NULL`; - if (fkRef) { - return html``; - } - return html`${text}`; -} - -function BodyRows({ rows, fkByCol, onNavigate }) { - return rows.map((row) => html`${row.map((v, i) => html`<${Cell} value=${v} fkRef=${fkByCol && fkByCol[i]} onNavigate=${onNavigate} />`)}`); -} - -// ---- Sidebar ---- -function Sidebar({ tables, loaded, currentKey, onSelect }) { - const [filter, setFilter] = useState(""); - const listRef = useRef(); - const f = filter.toLowerCase(); - const items = []; - let schema = null; - - useEffect(() => { - const active = listRef.current && listRef.current.querySelector(".tbl.active"); - if (active && active.scrollIntoView) active.scrollIntoView({ block: "nearest" }); - }, [currentKey, filter]); - - for (const t of tables) { - const label = tableKey(t); - if (f && !label.toLowerCase().includes(f)) continue; - if (t.schema !== schema) { - schema = t.schema; - items.push(html`
${schema}
`); - } - const active = label === currentKey; - const cls = "tbl" + (t.type === "view" ? " view" : "") + (active ? " active" : ""); - items.push(html``); - } +// ---- Database selector (hidden when ≤1 database) ---- +function DatabaseSelect({ databases, currentDb, onSwitch }) { + if (databases.length <= 1) return null; return html` - `; -} - -// ---- Tabs ---- -function Tabs({ tab, setTab, title }) { - const btn = (id, label) => html``; - return html` -
- ${btn("data", "Data")} ${btn("structure", "Structure")} ${btn("sql", "SQL")} - ${title} -
`; + `; } function TableContext({ table }) { @@ -154,345 +35,225 @@ function TableContext({ table }) { return html`
Current ${table.type === "view" ? "view" : "table"} ${table.schema}.${table.name} - ${table.estRows >= 0 ? html`~${table.estRows} rows` : html`row count unavailable`} + ${table.estRows >= 0 + ? html`~${table.estRows} rows` + : html`row count unavailable`}
`; } -// ---- Data tab ---- -function DataTab({ table, pageSize, initialFilters, onNavigate, setStatus }) { - const [offset, setOffset] = useState(0); - const [search, setSearch] = useState(""); - const [searchBox, setSearchBox] = useState(""); - const [filters, setFilters] = useState(initialFilters || {}); - const [draft, setDraft] = useState(initialFilters || {}); - const [sort, setSort] = useState(null); - const [data, setData] = useState(null); - const [fks, setFks] = useState({}); +// ---- App ---- +function App() { + const [databases, setDatabases] = useState([]); + const [dbsLoaded, setDbsLoaded] = useState(false); + const [currentDb, setCurrentDb] = useState(null); + const [tables, setTables] = useState([]); + const [tablesLoaded, setTablesLoaded] = useState(false); + const [rowCap, setRowCap] = useState(PAGE_SIZE); + const [saved, setSaved] = useState([]); + const [tab, setTabState] = useState("data"); + const [current, setCurrent] = useState(null); + const [navKey, setNavKey] = useState(0); + const [pendingFilters, setPendingFilters] = useState(null); + const [urlInit, setUrlInit] = useState(null); + const [status, setStatus] = useState({ text: "Ready.", cls: "ok" }); + // Refs so popstate handler always sees the latest values. + const urlStateRef = useRef({}); + const dbRef = useRef(null); + const tablesRef = useRef([]); + + const setTab = useCallback((newTab) => { + setTabState(newTab); + const s = { ...urlStateRef.current, tab: newTab === "data" ? null : newTab }; + pushUrlState(s); + urlStateRef.current = s; + }, []); - useEffect(() => { - let live = true; - (async () => { - try { - const list = await getJSON(tablePath(table) + "/fks"); - if (!live) return; - const m = {}; - for (const fk of list) m[fk.column] = { schema: fk.refSchema, table: fk.refTable, column: fk.refColumn }; - setFks(m); - } catch { /* no FK links */ } - })(); - return () => { live = false; }; - }, [table]); + const reloadSaved = useCallback(async () => { + try { setSaved(await getJSON("/api/queries")); } + catch (e) { setStatus({ text: "✗ failed to load saved queries: " + e.message, cls: "error" }); } + }, []); + // Phase 1: fetch /api/databases, resolve active db, restore URL state, + // install popstate listener. useEffect(() => { - let live = true; - setStatus({ text: "Loading " + tableKey(table) + "…", cls: "ok" }); - const p = new URLSearchParams(); - p.set("limit", pageSize); - p.set("offset", offset); - appendDataParams(p, search, sort, filters); + const urlState = readUrlState(); + urlStateRef.current = urlState; + (async () => { + let dbId = null; try { - const d = await getJSON(tablePath(table) + "/data?" + p.toString()); - if (!live) return; - setData(d); - setStatus({ text: "✓ " + d.rowCount + " row" + (d.rowCount === 1 ? "" : "s") + " in " + d.elapsedMs + " ms", cls: "ok" }); + const r = await fetch("/api/databases"); + + if (!r.ok) throw new Error(r.statusText || "failed"); + const result = await r.json(); + const dbs = Array.isArray(result.databases) ? result.databases : []; + setDatabases(dbs); + const urlDb = urlState.db; + const valid = dbs.find((d) => d.id === urlDb); + dbId = valid ? urlDb : (result.defaultId || (dbs[0] && dbs[0].id) || null); + if (urlDb && !valid && dbs.length > 0) { + setStatus({ text: "✗ unknown database in URL, using default", cls: "error" }); + } } catch (e) { - if (live) setStatus({ text: "✗ " + e.message, cls: "error" }); - } - })(); - return () => { live = false; }; - }, [table, offset, search, JSON.stringify(filters), sort && sort.col, sort && sort.dir, pageSize]); - const applyDraft = useCallback((next) => { - const clean = {}; - for (const c of Object.keys(next)) if (next[c] && next[c].op) clean[c] = next[c]; - setFilters(clean); - setOffset(0); - }, []); - - const toggleSort = (col) => { - setOffset(0); - setSort((s) => (s && s.col === col ? { col, dir: s.dir === "asc" ? "desc" : "asc" } : { col, dir: "asc" })); - }; + setStatus({ text: "✗ failed to load databases: " + e.message, cls: "error" }); + } - const exportURL = () => { - const p = new URLSearchParams(); - p.set("format", "csv"); - appendDataParams(p, search, sort, filters); - return tablePath(table) + "/data?" + p.toString(); - }; - let grid; - if (!data || data.columns === undefined) { - grid = html`
Loading…
`; - } else if (!data.columns.length) { - grid = html`
No columns.
`; - } else { - const fkByCol = data.columns.map((c) => fks[c] || null); - grid = html` - - - ${data.columns.map((c) => html``)} - ${data.columns.map((c) => { - const d = draft[c] || {}; - return html``; - })} - - <${BodyRows} rows=${data.rows} fkByCol=${fkByCol} onNavigate=${onNavigate} /> -
toggleSort(c)}> - ${c}${sort && sort.col === c ? (sort.dir === "desc" ? " ▼" : " ▲") : ""}
- - setDraft({ ...draft, [c]: { op: d.op || "", value: e.target.value } })} - onKeyDown=${(e) => { if (e.key === "Enter") applyDraft({ ...draft, [c]: { op: d.op || "", value: e.target.value } }); }} /> -
- ${data.rows.length ? "" : html`
0 rows.
`}`; - } + // Restore tab from URL; build pending table-restore if schema+table present. + setTabState(urlState.tab); + const finalState = { ...urlState, db: dbId }; + replaceUrlState(finalState); + urlStateRef.current = finalState; + dbRef.current = dbId; - const rowCount = data && data.rowCount ? data.rowCount : 0; - const from = rowCount ? offset + 1 : 0; + if (urlState.schema && urlState.table) { + setUrlInit({ + schema: urlState.schema, table: urlState.table, + offset: urlState.offset, search: urlState.search, + sort: urlState.sort, filters: urlState.filters, + }); + } + setCurrentDb(dbId); + setDbsLoaded(true); + })(); - return html` -
- setSearchBox(e.target.value)} - onKeyDown=${(e) => { if (e.key === "Enter") { setOffset(0); setSearch(searchBox.trim()); } }} /> - - - - - ${from}–${offset + rowCount} - Export CSV -
-
${grid}
`; -} + const onPopstate = () => { + const s = readUrlState(); + const sameDb = s.db === dbRef.current; + urlStateRef.current = s; + dbRef.current = s.db; + setTabState(s.tab); + setCurrentDb(s.db); + if (!sameDb) { + // Tables effect will reload; queue table restore if URL has a table. + if (s.schema && s.table) { + setUrlInit({ schema: s.schema, table: s.table, + offset: s.offset, search: s.search, sort: s.sort, filters: s.filters }); + } else { + setCurrent(null); setUrlInit(null); + } + } else if (s.schema && s.table) { + const found = tablesRef.current.find((x) => x.schema === s.schema && x.name === s.table); + if (found) { setCurrent(found); setNavKey((k) => k + 1); setUrlInit(s); } + else setCurrent(null); + } else { + setCurrent(null); + } + }; + window.addEventListener("popstate", onPopstate); + return () => { window.removeEventListener("popstate", onPopstate); }; + }, []); -// ---- Structure tab ---- -function StructureTab({ table, setStatus }) { - const [cols, setCols] = useState(null); + // Phase 2: reload tables + meta whenever the resolved db changes. useEffect(() => { - let live = true; + if (!dbsLoaded) return; + setTablesLoaded(false); + tablesRef.current = []; + setTables([]); + const db = currentDb; (async () => { try { - const c = await getJSON(tablePath(table) + "/columns"); - if (live) setCols(c); + const t = await getJSON("/api/tables", db); + tablesRef.current = t; + setTables(t); + // Consume pending URL table-restore (set during initial load or popstate). + setUrlInit((prev) => { + if (!prev) return null; + const found = t.find((x) => x.schema === prev.schema && x.name === prev.table); + if (found) { setCurrent(found); setNavKey((k) => k + 1); } + return null; + }); } catch (e) { - if (live) setStatus({ text: "✗ " + e.message, cls: "error" }); - } + setStatus({ text: "✗ failed to load tables: " + e.message, cls: "error" }); + } finally { setTablesLoaded(true); } })(); - return () => { live = false; }; - }, [table]); - - let body; - if (cols === null) body = html`
Loading…
`; - else if (!cols.length) body = html`
No columns.
`; - else body = html` - - ${cols.map((c) => html` - `)} -
ColumnTypeNullableDefault
${c.name}${c.type}${c.nullable ? "YES" : "NO"}${c.default == null ? "" : c.default}
`; - return html`
${body}
`; -} - -// ---- SQL tab ---- -function SqlTab({ active, saved, reloadSaved, setStatus }) { - const wrapRef = useRef(); - const taRef = useRef(); - const cmRef = useRef(); - const [result, setResult] = useState(null); - const [lastSQL, setLastSQL] = useState(""); - const [selected, setSelected] = useState(""); - const [running, setRunning] = useState(false); - const runningRef = useRef(false); - - const getSQL = () => (cmRef.current ? cmRef.current.getValue() : taRef.current.value).trim(); - const setSQL = (v) => { if (cmRef.current) cmRef.current.setValue(v); else taRef.current.value = v; }; - - const run = useCallback(async () => { - const sql = getSQL(); - if (!sql) return; - if (runningRef.current) return; - runningRef.current = true; setRunning(true); - setStatus({ text: "Running…", cls: "ok" }); - try { - const r = await fetch("/api/query", { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ sql }) }); - const d = await r.json(); - if (!r.ok) { setStatus({ text: "✗ " + (d.error || r.statusText), cls: "error" }); setResult(null); return; } - setLastSQL(sql); setResult(d); - const base = "✓ " + d.rowCount + " row" + (d.rowCount === 1 ? "" : "s") + " in " + d.elapsedMs + " ms"; - setStatus(d.truncated ? { text: base, cls: "ok", warn: "· capped (more rows available — add LIMIT or refine)" } : { text: base, cls: "ok" }); - } catch (e) { - setStatus({ text: "✗ " + e.message, cls: "error" }); - } finally { - runningRef.current = false; setRunning(false); - } - }, []); - - // Init the SQL editor once, into a Preact-stable wrapper it fully owns. Uses - // the vendored CodeMirror 6 bundle (window.cm6) when present, else degrades - // to a plain