diff --git a/README.md b/README.md index 0ab1533..ed5491e 100644 --- a/README.md +++ b/README.md @@ -71,15 +71,23 @@ browser ── HTTP ──> pgpeek (Go, single static binary) ## Configuration (env vars) -Everything is configured via the environment. Any value can also be supplied -from a **mounted file** by setting `_FILE` to a path (Docker secrets / k8s -projected volumes); the file's trimmed contents become the value. This is wired -for the secret-bearing `DATABASE_URL` (use `DATABASE_URL_FILE`). +Everything is configured via the environment. Single-database deployments can +keep using `DATABASE_URL`; multi-database deployments can use a URL list, +numbered env vars, or a mounted JSON config file. Secret-bearing URLs can be +supplied from mounted files so they do not live in manifests. | Variable | Default | Notes | | ---------------------------- | -------------------- | --------------------------------------------------------------------- | -| `DATABASE_URL` | _(required)_ | Postgres DSN. Use the read-only role. **Never logged.** Aurora: include `?sslmode=require`. | +| `DATABASE_URL` | single-DB required | Postgres DSN for single-database installs. Use the read-only role. **Never logged.** Aurora: include `?sslmode=require`. | | `DATABASE_URL_FILE` | — | Path to a file holding the DSN (mounted-secret alternative). | +| `PGPEEK_DATABASE_URLS` | — | Comma- or semicolon-separated DSNs for multiple databases. Quoted CSV values are supported. | +| `PGPEEK_DATABASE_IDS` | `db1`, `db2`, … | Optional comma/semicolon IDs matching `PGPEEK_DATABASE_URLS`; URL-safe (`A-Z`, `a-z`, `0-9`, `_`, `-`, `.`). | +| `PGPEEK_DATABASE_NAMES` | `Database N` | Optional display names matching `PGPEEK_DATABASE_URLS`. | +| `PGPEEK_DATABASE_URL_1` | — | Numbered DSN form. Continue with `_2`, `_3`, …; each also supports `_FILE`. | +| `PGPEEK_DATABASE_ID_1` | `db1` | Optional ID for numbered database 1. | +| `PGPEEK_DATABASE_NAME_1` | `Database 1` | Optional display name for numbered database 1. | +| `PGPEEK_DATABASES_FILE` | — | Path to a mounted JSON config file with database entries. | +| `PGPEEK_DEFAULT_DATABASE` | first configured DB | Default database ID when the URL has no `db=` parameter. | | `PGPEEK_LISTEN` | `:8080` | Listen address. | | `PGPEEK_ROW_CAP` | `1000` | Max rows returned/exported per query. | | `PGPEEK_STATEMENT_TIMEOUT` | `30s` | Per-query DB statement timeout. | @@ -95,6 +103,79 @@ for the secret-bearing `DATABASE_URL` (use `DATABASE_URL_FILE`). | `PGPEEK_DB_IAM_AUTH` | `false` | Use RDS/Aurora IAM auth instead of a password (see below). | | `PGPEEK_AWS_REGION` | `$AWS_REGION` | AWS region for IAM token signing (required when IAM auth is on). | +### Multiple databases / clusters + +The UI shows a database selector. The selected ID is kept in the URL as +`?db=` alongside table, tab, filter, sort, and pagination state, so links are +bookmarkable and shareable. + +Same-env list form: + +```bash +export PGPEEK_DATABASE_URLS='postgres://reader:PASSWORD@prod:5432/app?sslmode=require;postgres://reader:PASSWORD@analytics:5432/warehouse?sslmode=require' +export PGPEEK_DATABASE_IDS='prod;analytics' +export PGPEEK_DATABASE_NAMES='Production;Analytics' +export PGPEEK_DEFAULT_DATABASE=prod +``` + +Numbered env var form: + +```bash +export PGPEEK_DATABASE_URL_1_FILE=/run/secrets/prod-url +export PGPEEK_DATABASE_ID_1=prod +export PGPEEK_DATABASE_NAME_1=Production +export PGPEEK_DATABASE_URL_2_FILE=/run/secrets/analytics-url +export PGPEEK_DATABASE_ID_2=analytics +export PGPEEK_DATABASE_NAME_2=Analytics +``` + +Mounted config file form (`PGPEEK_DATABASES_FILE=/config/pgpeek/databases.json`): + +```json +{ + "default": "prod", + "databases": [ + { "id": "prod", "name": "Production", "urlFile": "/secrets/prod-url" }, + { "id": "analytics", "name": "Analytics", "urlFile": "/secrets/analytics-url" } + ] +} +``` + +Kubernetes example (ConfigMap-mounted config + Secret-mounted DSNs; illustrative +only, not an extra manifest to commit): + +```yaml +env: + - name: PGPEEK_DATABASES_FILE + value: /config/pgpeek/databases.json +volumeMounts: + - name: pgpeek-db-config + mountPath: /config/pgpeek + readOnly: true + - name: pgpeek-db-urls + mountPath: /secrets + readOnly: true +volumes: + - name: pgpeek-db-config + configMap: + name: pgpeek-db-config + - name: pgpeek-db-urls + secret: + secretName: pgpeek-db-urls +``` + +Docker Compose example (volume-mounted JSON + secret files; illustrative only): + +```yaml +services: + pgpeek: + environment: + PGPEEK_DATABASES_FILE: /config/pgpeek/databases.json + volumes: + - ./pgpeek-config:/config/pgpeek:ro + - ./pgpeek-secrets:/secrets:ro +``` + ### RDS / Aurora IAM authentication Set `PGPEEK_DB_IAM_AUTH=true` and `PGPEEK_AWS_REGION`. The `DATABASE_URL` then @@ -218,13 +299,14 @@ Two ways: | Method & path | Purpose | | --------------------------------------------- | ---------------------------------------------- | -| `POST /api/query` | Run a query → JSON `{columns, rows, …}`. | -| `POST /api/export` | Run a query → CSV download. | -| `GET /api/meta` | Server limits the UI needs (`{rowCap}`). | -| `GET /api/tables` | List browsable tables/views (+ row estimate). | -| `GET /api/tables/{schema}/{table}/columns` | Column structure (name, type, nullable, default). | -| `GET /api/tables/{schema}/{table}/fks` | Single-column foreign keys (for click-through). | -| `GET /api/tables/{schema}/{table}/data` | Paged rows; `?limit=&offset=&search=&sort=&dir=&f=col:op:val` (`&format=csv`). | +| `GET /api/databases` | List configured databases → `{defaultId, databases:[{id,name}]}`. | +| `POST /api/query?db=` | Run a query → JSON `{columns, rows, …}`. | +| `POST /api/export?db=` | Run a query → CSV download. | +| `GET /api/meta?db=` | Server limits the UI needs (`{rowCap}`). | +| `GET /api/tables?db=` | List browsable tables/views (+ row estimate). | +| `GET /api/tables/{schema}/{table}/columns?db=` | Column structure (name, type, nullable, default). | +| `GET /api/tables/{schema}/{table}/fks?db=` | Single-column foreign keys (for click-through). | +| `GET /api/tables/{schema}/{table}/data?db=` | Paged rows; `&limit=&offset=&search=&sort=&dir=&f=col:op:val` (`&format=csv`). | | `GET /api/queries` | List saved/preset queries. | | `POST /api/queries` | Create a saved query. | | `PUT /api/queries/{id}` | Update a saved query. | diff --git a/docs/assets/img/multi-database.png b/docs/assets/img/multi-database.png new file mode 100644 index 0000000..38cc73f Binary files /dev/null and b/docs/assets/img/multi-database.png differ diff --git a/docs/index.html b/docs/index.html index 66f7421..f5e7c35 100644 --- a/docs/index.html +++ b/docs/index.html @@ -216,8 +216,27 @@

Foreign keys you can click through

-
+
+
+
?db=analytics&schema=public&table=users
+ pgpeek with the database selector set to Analytics cluster while browsing public.users, with the selected database encoded in the URL. +
+
+
+ 03 +

Switch databases without losing the link

+
    +
  • Pick from named databases or clusters in the header.
  • +
  • The selected database is encoded as db=... in the URL.
  • +
  • Table, tab, search, filter, sort, and pagination state remain bookmarkable.
  • +
+
+
+ + +
Structure
@@ -226,7 +245,7 @@

Foreign keys you can click through

- 03 + 04

Structure at a glance

  • Every column: name, type, nullable, default.
  • @@ -237,7 +256,7 @@

    Structure at a glance

-
+
SQL
@@ -246,7 +265,7 @@

Structure at a glance

- 04 + 05

A SQL scratchpad with memory

  • CodeMirror editor (pgsql mode), gracefully degrades to a textarea.
  • @@ -524,14 +543,18 @@

    Running in under a minute

    Everything is an environment variable

    -

    Any value can also come from a mounted file via <VAR>_FILE — Docker / Kubernetes secrets friendly.

    +

    Single-database installs keep using DATABASE_URL. Multi-database installs can use URL lists, numbered env vars, or a mounted JSON config file.

    - + + + + + @@ -556,6 +579,13 @@

    Everything is an environment variable

    RDS / Aurora IAM auth. Set PGPEEK_DB_IAM_AUTH=true and a region, drop the password from the DSN, and pgpeek mints a short-lived IAM token from the default AWS credential chain (env / web-identity / IRSA / instance role) before every new connection — no static DB password stored anywhere. + +
    + +
    + Multiple clusters are bookmarkable. The UI writes the selected database into the URL as ?db=prod, alongside table, tab, filter, sort, and pagination params. For Kubernetes, mount a ConfigMap at /config/pgpeek/databases.json and Secret files at /secrets/*; for Compose, mount local directories with ./pgpeek-config:/config/pgpeek:ro and ./pgpeek-secrets:/secrets:ro. These are examples only — no extra deploy files are required. +
    +
    @@ -572,13 +602,14 @@

    A small, predictable surface

    VariableDefaultNotes
    DATABASE_URLrequiredPostgres DSN. Use the read-only role. Never logged. (DATABASE_URL_FILE reads it from a mounted secret.)
    DATABASE_URLsingle-DB requiredPostgres DSN for single-database installs. Use the read-only role. Never logged. (DATABASE_URL_FILE reads it from a mounted secret.)
    PGPEEK_DATABASE_URLSComma/semicolon-separated DSNs for multiple databases. Pair with PGPEEK_DATABASE_IDS and PGPEEK_DATABASE_NAMES.
    PGPEEK_DATABASE_URL_1Numbered multi-DB form; continue with _2, _3, … and use _FILE for mounted secrets.
    PGPEEK_DATABASES_FILEMounted JSON config file with {default, databases:[{id,name,urlFile}]}.
    PGPEEK_DEFAULT_DATABASEfirst DBDefault database ID when the URL has no db=.
    PGPEEK_LISTEN:8080Listen address.
    PGPEEK_ROW_CAP1000Max rows returned/exported per query.
    PGPEEK_STATEMENT_TIMEOUT30sPer-query DB statement timeout.
    - - - - - - - + + + + + + + + diff --git a/internal/config/config.go b/internal/config/config.go index 046bbd6..2ada393 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,10 +14,12 @@ import ( // Config is the fully-resolved application configuration. type Config struct { - Server Server - DB DB - StorePath string - RowCap int + Server Server + DB DB + Databases []DatabaseEntry + DefaultDatabaseID string + StorePath string + RowCap int } // Server holds HTTP server settings. @@ -51,15 +53,13 @@ type DB struct { // Load reads and validates configuration from the environment. func Load() (*Config, error) { - dsn, err := envOrFile("DATABASE_URL") + stmtTimeout := envDur("PGPEEK_STATEMENT_TIMEOUT", 30*time.Second) + iamAuth := envBool("PGPEEK_DB_IAM_AUTH", false) + region := env("PGPEEK_AWS_REGION", os.Getenv("AWS_REGION")) + databases, defaultDatabaseID, err := loadDatabases(iamAuth, region) if err != nil { return nil, err } - if dsn == "" { - return nil, errors.New("DATABASE_URL (or DATABASE_URL_FILE) is required") - } - - stmtTimeout := envDur("PGPEEK_STATEMENT_TIMEOUT", 30*time.Second) c := &Config{ Server: Server{ @@ -72,15 +72,20 @@ func Load() (*Config, error) { TLSKeyFile: os.Getenv("PGPEEK_TLS_KEY_FILE"), }, DB: DB{ - DSN: dsn, + DSN: "", MaxConns: int32(envInt("PGPEEK_MAX_CONNS", 8)), StatementTimeout: stmtTimeout, IdleTxTimeout: envDur("PGPEEK_IDLE_TX_TIMEOUT", 30*time.Second), - IAMAuth: envBool("PGPEEK_DB_IAM_AUTH", false), - Region: env("PGPEEK_AWS_REGION", os.Getenv("AWS_REGION")), + IAMAuth: iamAuth, + Region: region, }, - StorePath: env("PGPEEK_STORE_PATH", "/data/pgpeek.db"), - RowCap: envInt("PGPEEK_ROW_CAP", 1000), + Databases: databases, + DefaultDatabaseID: defaultDatabaseID, + StorePath: env("PGPEEK_STORE_PATH", "/data/pgpeek.db"), + RowCap: envInt("PGPEEK_ROW_CAP", 1000), + } + if err := applyDefaultDatabase(c); err != nil { + return nil, err } if err := c.validate(); err != nil { return nil, err @@ -98,6 +103,9 @@ func (c *Config) validate() error { if c.DB.IAMAuth && c.DB.Region == "" { return errors.New("PGPEEK_DB_IAM_AUTH requires PGPEEK_AWS_REGION (or AWS_REGION)") } + if err := validateDatabases(c.Databases); err != nil { + return err + } if (c.Server.TLSCertFile == "") != (c.Server.TLSKeyFile == "") { return errors.New("PGPEEK_TLS_CERT_FILE and PGPEEK_TLS_KEY_FILE must be set together") } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 3e75d50..6f16e62 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -3,6 +3,7 @@ package config import ( "os" "path/filepath" + "strconv" "testing" "time" ) @@ -18,9 +19,17 @@ func clearEnv(t *testing.T) { "PGPEEK_MAX_CONNS", "PGPEEK_STATEMENT_TIMEOUT", "PGPEEK_IDLE_TX_TIMEOUT", "PGPEEK_DB_IAM_AUTH", "PGPEEK_AWS_REGION", "AWS_REGION", "PGPEEK_STORE_PATH", "PGPEEK_ROW_CAP", + "PGPEEK_DATABASE_URLS", "PGPEEK_DATABASE_IDS", "PGPEEK_DATABASE_NAMES", + "PGPEEK_DATABASES_FILE", "PGPEEK_DEFAULT_DATABASE", } { t.Setenv(k, "") } + for i := 1; i <= maxNumberedDatabases; i++ { + t.Setenv("PGPEEK_DATABASE_URL_"+strconv.Itoa(i), "") + t.Setenv("PGPEEK_DATABASE_URL_"+strconv.Itoa(i)+"_FILE", "") + t.Setenv("PGPEEK_DATABASE_ID_"+strconv.Itoa(i), "") + t.Setenv("PGPEEK_DATABASE_NAME_"+strconv.Itoa(i), "") + } } func TestLoad_Defaults(t *testing.T) { @@ -53,6 +62,12 @@ func TestLoad_Defaults(t *testing.T) { if c.Server.TLSEnabled() { t.Error("TLS should be disabled by default") } + if c.DefaultDatabaseID != "default" || len(c.Databases) != 1 { + t.Fatalf("default database registry = %q/%d", c.DefaultDatabaseID, len(c.Databases)) + } + if c.Databases[0].ID != "default" || c.Databases[0].Name != "Default" || c.Databases[0].DSN != c.DB.DSN { + t.Errorf("default database entry = %+v", c.Databases[0]) + } } func TestLoad_Overrides(t *testing.T) { diff --git a/internal/config/databases.go b/internal/config/databases.go new file mode 100644 index 0000000..c4a30f8 --- /dev/null +++ b/internal/config/databases.go @@ -0,0 +1,273 @@ +package config + +import ( + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "os" + "strings" +) + +const maxNumberedDatabases = 64 + +// DatabaseEntry holds one named database target. +type DatabaseEntry struct { + ID string + Name string + DSN string + IAMAuth bool + Region string +} + +func loadDatabases(globalIAMAuth bool, globalRegion string) ([]DatabaseEntry, string, error) { + if err := validateDatabaseSourceFamily(); err != nil { + return nil, "", err + } + entries := make([]DatabaseEntry, 0, 4) + fileDefault, err := appendDatabaseFileEntries(&entries, globalIAMAuth, globalRegion) + if err != nil { + return nil, "", err + } + if err := appendListDatabaseEntries(&entries, globalIAMAuth, globalRegion); err != nil { + return nil, "", err + } + if err := appendNumberedDatabaseEntries(&entries, globalIAMAuth, globalRegion); err != nil { + return nil, "", err + } + if len(entries) == 0 { + dsn, err := envOrFile("DATABASE_URL") + if err != nil { + return nil, "", err + } + if dsn == "" { + return nil, "", errors.New("DATABASE_URL (or DATABASE_URL_FILE) is required") + } + entries = append(entries, DatabaseEntry{ID: "default", Name: "Default", DSN: dsn, IAMAuth: globalIAMAuth, Region: globalRegion}) + } + defaultID := env("PGPEEK_DEFAULT_DATABASE", fileDefault) + if defaultID == "" && len(entries) > 0 { + defaultID = entries[0].ID + } + return entries, defaultID, nil +} + +func validateDatabaseSourceFamily() error { + sources := 0 + if os.Getenv("PGPEEK_DATABASES_FILE") != "" { + sources++ + } + if strings.TrimSpace(os.Getenv("PGPEEK_DATABASE_URLS")) != "" { + sources++ + } + for i := 1; i <= maxNumberedDatabases; i++ { + key := fmt.Sprintf("PGPEEK_DATABASE_URL_%d", i) + if os.Getenv(key) != "" || os.Getenv(key+"_FILE") != "" { + sources++ + break + } + } + if sources > 1 { + return errors.New("configure only one multi-database source: PGPEEK_DATABASES_FILE, PGPEEK_DATABASE_URLS, or PGPEEK_DATABASE_URL_N") + } + return nil +} + +func appendDatabaseFileEntries(entries *[]DatabaseEntry, globalIAMAuth bool, globalRegion string) (string, error) { + if os.Getenv("PGPEEK_DATABASES_FILE") == "" { + return "", nil + } + b, err := readOperatorFile(os.Getenv("PGPEEK_DATABASES_FILE"), "PGPEEK_DATABASES_FILE") + if err != nil { + return "", err + } + var file struct { + Default string `json:"default"` + DefaultDatabaseID string `json:"defaultDatabaseID"` + Databases []struct { + ID, Name, URL, URLFile, Region string + IAMAuth bool + } `json:"databases"` + } + if err := json.Unmarshal(b, &file); err != nil { + return "", fmt.Errorf("parse PGPEEK_DATABASES_FILE: %w", err) + } + for i, item := range file.Databases { + dsn, err := databaseFileDSN(item.URL, item.URLFile) + if err != nil { + return "", err + } + *entries = append(*entries, DatabaseEntry{ID: item.ID, Name: entryName(item.Name, i+1), DSN: dsn, IAMAuth: item.IAMAuth || globalIAMAuth, Region: pick(item.Region, globalRegion)}) + } + if file.DefaultDatabaseID != "" { + return file.DefaultDatabaseID, nil + } + return file.Default, nil +} + +func databaseFileDSN(url, urlFile string) (string, error) { + dsn := strings.TrimSpace(url) + if dsn != "" || urlFile == "" { + return dsn, nil + } + b, err := readOperatorFile(urlFile, "database urlFile") + if err != nil { + return "", err + } + return strings.TrimSpace(string(b)), nil +} + +func appendListDatabaseEntries(entries *[]DatabaseEntry, globalIAMAuth bool, globalRegion string) error { + urls, err := splitDatabaseList(os.Getenv("PGPEEK_DATABASE_URLS")) + if err != nil || len(urls) == 0 { + return err + } + ids, err := splitDatabaseList(os.Getenv("PGPEEK_DATABASE_IDS")) + if err != nil { + return err + } + names, err := splitDatabaseList(os.Getenv("PGPEEK_DATABASE_NAMES")) + if err != nil { + return err + } + for i, url := range urls { + *entries = append(*entries, DatabaseEntry{ID: listValue(ids, i, fmt.Sprintf("db%d", i+1)), Name: entryName(listValue(names, i, ""), i+1), DSN: strings.TrimSpace(url), IAMAuth: globalIAMAuth, Region: globalRegion}) + } + return nil +} + +func appendNumberedDatabaseEntries(entries *[]DatabaseEntry, globalIAMAuth bool, globalRegion string) error { + for i := 1; i <= maxNumberedDatabases; i++ { + key := fmt.Sprintf("PGPEEK_DATABASE_URL_%d", i) + dsn, err := envOrFile(key) + if err != nil { + return err + } + if dsn == "" { + break + } + *entries = append(*entries, DatabaseEntry{ID: env(fmt.Sprintf("PGPEEK_DATABASE_ID_%d", i), fmt.Sprintf("db%d", i)), Name: entryName(os.Getenv(fmt.Sprintf("PGPEEK_DATABASE_NAME_%d", i)), i), DSN: dsn, IAMAuth: globalIAMAuth, Region: globalRegion}) + } + return nil +} + +func applyDefaultDatabase(c *Config) error { + for _, db := range c.Databases { + if db.ID == c.DefaultDatabaseID { + c.DB.DSN = db.DSN + c.DB.IAMAuth = db.IAMAuth + c.DB.Region = db.Region + return nil + } + } + return fmt.Errorf("PGPEEK_DEFAULT_DATABASE %q does not match a configured database", c.DefaultDatabaseID) +} + +func validateDatabases(databases []DatabaseEntry) error { + if len(databases) == 0 { + return errors.New("at least one database is required") + } + seen := make(map[string]struct{}, len(databases)) + for _, db := range databases { + if err := validateDatabaseEntry(db, seen); err != nil { + return err + } + } + return nil +} + +func validateDatabaseEntry(db DatabaseEntry, seen map[string]struct{}) error { + if db.ID == "" { + return errors.New("database ID must be non-empty") + } + if !isDatabaseID(db.ID) { + return fmt.Errorf("database ID %q must contain only letters, numbers, dot, underscore, or dash", db.ID) + } + if _, ok := seen[db.ID]; ok { + return fmt.Errorf("database ID %q is duplicated", db.ID) + } + seen[db.ID] = struct{}{} + if db.DSN == "" { + return fmt.Errorf("database %q requires url or urlFile", db.ID) + } + if db.IAMAuth && db.Region == "" { + return fmt.Errorf("database %q IAM auth requires region", db.ID) + } + return nil +} + +func splitDatabaseList(value string) ([]string, error) { + if strings.TrimSpace(value) == "" { + return nil, nil + } + reader := csv.NewReader(strings.NewReader(normalizeListSeparators(value))) + reader.TrimLeadingSpace = true + items, err := reader.Read() + if err != nil { + return nil, fmt.Errorf("parse database list: %w", err) + } + result := make([]string, 0, len(items)) + for _, item := range items { + if trimmed := strings.TrimSpace(item); trimmed != "" { + result = append(result, trimmed) + } + } + return result, nil +} + +func normalizeListSeparators(value string) string { + inQuotes := false + var b strings.Builder + for i := 0; i < len(value); i++ { + ch := value[i] + if ch == '"' { + inQuotes = !inQuotes + } + if ch == ';' && !inQuotes { + ch = ',' + } + b.WriteByte(ch) + } + return b.String() +} + +func listValue(values []string, index int, fallback string) string { + if index < len(values) && values[index] != "" { + return values[index] + } + return fallback +} + +func entryName(value string, index int) string { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + return fmt.Sprintf("Database %d", index) +} + +func pick(value, fallback string) string { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + return fallback +} + +func isDatabaseID(value string) bool { + for _, r := range value { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '-' || r == '.' { + continue + } + return false + } + return value != "" +} + +func readOperatorFile(path, label string) ([]byte, error) { + // The path is supplied by the operator (env var / mounted-secret convention), + // not by any request input — this is the intended use. + b, err := os.ReadFile(path) //nolint:gosec // operator-controlled config/secret path + if err != nil { + return nil, fmt.Errorf("read %s: %w", label, err) + } + return b, nil +} diff --git a/internal/config/databases_errors_test.go b/internal/config/databases_errors_test.go new file mode 100644 index 0000000..50bed99 --- /dev/null +++ b/internal/config/databases_errors_test.go @@ -0,0 +1,216 @@ +package config + +import ( + "os" + "path/filepath" + "strconv" + "strings" + "testing" +) + +func TestLoadDatabases_returns_required_error_when_no_sources_configured(t *testing.T) { + // Given: no database URL source is configured. + clearEnv(t) + + // When: database entries are loaded. + _, _, err := loadDatabases(false, "") + + // Then: configuration fails with the public missing-URL message. + if err == nil || !strings.Contains(err.Error(), "DATABASE_URL") { + t.Fatalf("loadDatabases error = %v, want DATABASE_URL requirement", err) + } +} + +func TestAppendDatabaseFileEntries_uses_default_database_id_when_present(t *testing.T) { + // Given: a mounted database config uses the explicit defaultDatabaseID field. + clearEnv(t) + dir := t.TempDir() + configPath := filepath.Join(dir, "databases.json") + body := `{"default":"legacy","defaultDatabaseID":"primary","databases":[{"id":"primary","url":"postgres://u:p@h/db"}]}` + if err := os.WriteFile(configPath, []byte(body), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASES_FILE", configPath) + entries := []DatabaseEntry{} + + // When: database entries are loaded from the file. + defaultID, err := appendDatabaseFileEntries(&entries, false, "us-east-1") + + // Then: defaultDatabaseID takes precedence over legacy default. + if err != nil { + t.Fatalf("appendDatabaseFileEntries: %v", err) + } + if defaultID != "primary" || len(entries) != 1 || entries[0].Region != "us-east-1" { + t.Fatalf("default=%q entries=%+v", defaultID, entries) + } +} + +func TestAppendDatabaseFileEntries_returns_parse_error_when_file_invalid(t *testing.T) { + // Given: the mounted database config is not valid JSON. + clearEnv(t) + configPath := filepath.Join(t.TempDir(), "databases.json") + if err := os.WriteFile(configPath, []byte(`{"databases":`), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASES_FILE", configPath) + + // When: databases are loaded. + _, _, err := loadDatabases(false, "") + + // Then: the parse error names the operator-provided file variable. + if err == nil || !strings.Contains(err.Error(), "parse PGPEEK_DATABASES_FILE") { + t.Fatalf("appendDatabaseFileEntries error = %v", err) + } +} + +func TestAppendDatabaseFileEntries_returns_read_error_when_file_missing(t *testing.T) { + // Given: the mounted database config file is missing. + clearEnv(t) + t.Setenv("PGPEEK_DATABASES_FILE", filepath.Join(t.TempDir(), "missing.json")) + + // When: databases are loaded. + _, _, err := loadDatabases(false, "") + + // Then: the operator file read error is returned. + if err == nil || !strings.Contains(err.Error(), "read PGPEEK_DATABASES_FILE") { + t.Fatalf("loadDatabases error = %v", err) + } +} + +func TestDatabaseFileDSN_returns_read_error_when_url_file_missing(t *testing.T) { + // Given: an entry references a missing urlFile. + clearEnv(t) + dir := t.TempDir() + missing := filepath.Join(dir, "missing-url") + configPath := filepath.Join(dir, "databases.json") + body := `{"databases":[{"id":"primary","urlFile":` + strconv.Quote(missing) + `}]}` + if err := os.WriteFile(configPath, []byte(body), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASES_FILE", configPath) + + // When: databases are loaded. + _, _, err := loadDatabases(false, "") + + // Then: the read error is returned without exposing a DSN. + if err == nil || !strings.Contains(err.Error(), "read database urlFile") { + t.Fatalf("databaseFileDSN error = %v", err) + } +} + +func TestAppendListDatabaseEntries_returns_parse_error_when_csv_invalid(t *testing.T) { + // Given: list-form URLs contain malformed CSV quoting. + clearEnv(t) + t.Setenv("PGPEEK_DATABASE_URLS", `"postgres://u:p@h/db`) + + // When: databases are loaded. + _, _, err := loadDatabases(false, "") + + // Then: parsing fails before any entry is appended. + if err == nil || !strings.Contains(err.Error(), "parse database list") { + t.Fatalf("appendListDatabaseEntries error = %v", err) + } +} + +func TestAppendListDatabaseEntries_returns_parse_error_when_ids_or_names_invalid(t *testing.T) { + tests := []struct { + name string + env string + }{ + {name: "ids", env: "PGPEEK_DATABASE_IDS"}, + {name: "names", env: "PGPEEK_DATABASE_NAMES"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given: list-form URLs are valid but a companion list has malformed CSV quoting. + clearEnv(t) + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:p@h/db") + t.Setenv(tt.env, `"unterminated`) + + // When: databases are loaded. + _, _, err := loadDatabases(false, "") + + // Then: companion-list parsing fails. + if err == nil || !strings.Contains(err.Error(), "parse database list") { + t.Fatalf("loadDatabases error = %v", err) + } + }) + } +} + +func TestAppendNumberedDatabaseEntries_returns_file_error_when_url_file_missing(t *testing.T) { + // Given: numbered env configuration points at a missing mounted secret. + clearEnv(t) + t.Setenv("PGPEEK_DATABASE_URL_1_FILE", filepath.Join(t.TempDir(), "missing-url")) + + // When: databases are loaded. + _, _, err := loadDatabases(false, "") + + // Then: the read error is returned. + if err == nil || !strings.Contains(err.Error(), "read PGPEEK_DATABASE_URL_1") { + t.Fatalf("appendNumberedDatabaseEntries error = %v", err) + } +} + +func TestLoadDatabasesRejectsMultipleMultiDatabaseSources(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) + }{ + {name: "file and list", setup: func(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "databases.json") + if err := os.WriteFile(configPath, []byte(`{"databases":[]}`), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASES_FILE", configPath) + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:p@h/db") + }}, + {name: "list and numbered", setup: func(t *testing.T) { + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:p@h/db") + t.Setenv("PGPEEK_DATABASE_URL_1", "postgres://u:p@h/other") + }}, + {name: "file and numbered file", setup: func(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "databases.json") + if err := os.WriteFile(configPath, []byte(`{"databases":[]}`), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASES_FILE", configPath) + t.Setenv("PGPEEK_DATABASE_URL_1_FILE", filepath.Join(t.TempDir(), "db-url")) + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clearEnv(t) + tt.setup(t) + + _, _, err := loadDatabases(false, "") + + if err == nil || !strings.Contains(err.Error(), "configure only one multi-database source") { + t.Fatalf("loadDatabases error = %v", err) + } + }) + } +} + +func TestValidateDatabases_rejects_empty_or_incomplete_entries(t *testing.T) { + tests := []struct { + name string + dbs []DatabaseEntry + }{ + {name: "empty registry", dbs: nil}, + {name: "empty id", dbs: []DatabaseEntry{{DSN: "postgres://u:p@h/db"}}}, + {name: "empty dsn", dbs: []DatabaseEntry{{ID: "primary"}}}, + {name: "iam missing region", dbs: []DatabaseEntry{{ID: "primary", DSN: "postgres://u:p@h/db", IAMAuth: true}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // When: database entries are validated. + err := validateDatabases(tt.dbs) + + // Then: invalid entries are rejected. + if err == nil { + t.Fatal("validateDatabases expected error") + } + }) + } +} diff --git a/internal/config/databases_test.go b/internal/config/databases_test.go new file mode 100644 index 0000000..b3bfd47 --- /dev/null +++ b/internal/config/databases_test.go @@ -0,0 +1,196 @@ +package config + +import ( + "os" + "path/filepath" + "strconv" + "strings" + "testing" +) + +func TestLoad_DatabaseURLsList(t *testing.T) { + clearEnv(t) + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:secret@h/one, postgres://u:secret@h/two;postgres://u:secret@h/three") + t.Setenv("PGPEEK_DATABASE_IDS", "analytics, billing") + t.Setenv("PGPEEK_DATABASE_NAMES", "Analytics, Billing") + t.Setenv("PGPEEK_DEFAULT_DATABASE", "billing") + + c, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + if c.DefaultDatabaseID != "billing" || c.DB.DSN != "postgres://u:secret@h/two" { + t.Fatalf("default database selection failed; id = %q", c.DefaultDatabaseID) + } + entries := c.Databases + if len(entries) != 3 { + t.Fatalf("databases = %d", len(entries)) + } + if entries[0].ID != "analytics" || entries[0].Name != "Analytics" { + t.Errorf("first entry id/name = %q/%q", entries[0].ID, entries[0].Name) + } + if entries[1].ID != "billing" || entries[1].Name != "Billing" { + t.Errorf("second entry id/name = %q/%q", entries[1].ID, entries[1].Name) + } + if entries[2].ID != "db3" || entries[2].Name != "Database 3" { + t.Errorf("derived entry id/name = %q/%q", entries[2].ID, entries[2].Name) + } + if strings.Contains(entries[2].Name, "secret") || strings.Contains(entries[2].Name, "postgres://") { + t.Errorf("derived name exposes DSN: %q", entries[2].Name) + } +} + +func TestLoad_DatabaseURLsListQuotedSeparators(t *testing.T) { + clearEnv(t) + t.Setenv("PGPEEK_DATABASE_URLS", `"postgres://u:p@h/one?options=a,b";"postgres://u:p@h/two?options=x;y"`) + t.Setenv("PGPEEK_DATABASE_IDS", `"one";"two"`) + t.Setenv("PGPEEK_DATABASE_NAMES", `"One, Primary";"Two; Replica"`) + t.Setenv("PGPEEK_DEFAULT_DATABASE", "two") + + c, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + if len(c.Databases) != 2 { + t.Fatalf("databases = %d", len(c.Databases)) + } + if c.Databases[0].ID != "one" || c.Databases[0].Name != "One, Primary" || c.Databases[0].DSN != "postgres://u:p@h/one?options=a,b" { + t.Errorf("quoted comma entry = id %q dsn %q", c.Databases[0].ID, c.Databases[0].DSN) + } + if c.Databases[1].ID != "two" || c.Databases[1].Name != "Two; Replica" || c.DB.DSN != "postgres://u:p@h/two?options=x;y" { + t.Errorf("quoted semicolon entry = id %q", c.Databases[1].ID) + } +} + +func TestLoad_NumberedDatabaseURLs(t *testing.T) { + clearEnv(t) + dir := t.TempDir() + path := filepath.Join(dir, "two") + if err := os.WriteFile(path, []byte(" postgres://u:secret@h/two\n"), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASE_URL_1", "postgres://u:secret@h/one") + t.Setenv("PGPEEK_DATABASE_ID_1", "one") + t.Setenv("PGPEEK_DATABASE_NAME_1", "One") + t.Setenv("PGPEEK_DATABASE_URL_2_FILE", path) + t.Setenv("PGPEEK_DATABASE_ID_2", "two") + + c, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + if len(c.Databases) != 2 { + t.Fatalf("databases = %d", len(c.Databases)) + } + if c.Databases[1].ID != "two" || c.Databases[1].Name != "Database 2" { + t.Errorf("file-backed numbered entry id/name = %q/%q", c.Databases[1].ID, c.Databases[1].Name) + } + if c.Databases[1].DSN != "postgres://u:secret@h/two" { + t.Error("file-backed numbered entry DSN was not read from file") + } +} + +func TestLoad_DatabasesJSONFile(t *testing.T) { + clearEnv(t) + dir := t.TempDir() + secretPath := filepath.Join(dir, "west-url") + if err := os.WriteFile(secretPath, []byte("postgres://u:secret@h/west\n"), 0o600); err != nil { + t.Fatal(err) + } + configPath := filepath.Join(dir, "databases.json") + body := `{"default":"west","databases":[{"id":"east","name":"East","url":"postgres://u:secret@h/east","iamAuth":true,"region":"us-east-1"},{"id":"west","name":"West","urlFile":` + strconv.Quote(secretPath) + `,"iamAuth":true}]}` + if err := os.WriteFile(configPath, []byte(body), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASES_FILE", configPath) + t.Setenv("AWS_REGION", "us-west-2") + + c, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + if c.DefaultDatabaseID != "west" || c.DB.DSN != "postgres://u:secret@h/west" || !c.DB.IAMAuth || c.DB.Region != "us-west-2" { + t.Fatalf("default DB selection failed; id = %q iam = %v region = %q", c.DefaultDatabaseID, c.DB.IAMAuth, c.DB.Region) + } + if c.Databases[0].Region != "us-east-1" || c.Databases[1].Region != "us-west-2" { + t.Errorf("entry regions = %q/%q", c.Databases[0].Region, c.Databases[1].Region) + } +} + +func TestLoad_DatabasesJSONFileInheritsGlobalIAMAuth(t *testing.T) { + clearEnv(t) + configPath := filepath.Join(t.TempDir(), "databases.json") + body := `{"databases":[{"id":"prod","url":"postgres://u@h/prod"}]}` + if err := os.WriteFile(configPath, []byte(body), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASES_FILE", configPath) + t.Setenv("PGPEEK_DB_IAM_AUTH", "true") + t.Setenv("PGPEEK_AWS_REGION", "us-east-1") + + c, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + if !c.Databases[0].IAMAuth || c.Databases[0].Region != "us-east-1" { + t.Fatalf("file database did not inherit global IAM config: %+v", c.Databases[0]) + } +} + +func TestLoad_DatabasesJSONFileItemIAMAuthOverridesGlobalFalse(t *testing.T) { + clearEnv(t) + configPath := filepath.Join(t.TempDir(), "databases.json") + body := `{"databases":[{"id":"prod","url":"postgres://u@h/prod","iamAuth":true}]}` + if err := os.WriteFile(configPath, []byte(body), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("PGPEEK_DATABASES_FILE", configPath) + t.Setenv("PGPEEK_AWS_REGION", "us-east-1") + + c, err := Load() + if err != nil { + t.Fatalf("Load: %v", err) + } + if !c.Databases[0].IAMAuth { + t.Fatalf("file database item IAM setting lost: %+v", c.Databases[0]) + } +} + +func TestLoad_DatabaseValidationErrors(t *testing.T) { + cases := []struct { + name string + setup func(t *testing.T) + secretDSN string + }{ + {"duplicate id", func(t *testing.T) { + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:secret@h/one,postgres://u:secret@h/two") + t.Setenv("PGPEEK_DATABASE_IDS", "same,same") + }, "postgres://u:secret@h/one"}, + {"invalid id", func(t *testing.T) { + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:secret@h/one") + t.Setenv("PGPEEK_DATABASE_IDS", "bad id") + }, "postgres://u:secret@h/one"}, + {"unknown default", func(t *testing.T) { + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:secret@h/one") + t.Setenv("PGPEEK_DEFAULT_DATABASE", "missing") + }, "postgres://u:secret@h/one"}, + {"iam missing region", func(t *testing.T) { + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:secret@h/one") + t.Setenv("PGPEEK_DATABASES_FILE", "") + t.Setenv("PGPEEK_DB_IAM_AUTH", "true") + }, "postgres://u:secret@h/one"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + clearEnv(t) + tc.setup(t) + _, err := Load() + if err == nil { + t.Fatal("expected validation error") + } + if strings.Contains(err.Error(), tc.secretDSN) || strings.Contains(err.Error(), "secret") { + t.Fatalf("error exposes DSN secret: %v", err) + } + }) + } +} diff --git a/internal/db/registry.go b/internal/db/registry.go new file mode 100644 index 0000000..253a4ff --- /dev/null +++ b/internal/db/registry.go @@ -0,0 +1,109 @@ +package db + +import ( + "context" + "errors" + "fmt" +) + +var ErrPoolNotFound = errors.New("database pool not found") + +type PoolMetadata struct { + ID string `json:"id"` + Name string `json:"name"` +} + +type RegistryEntry struct { + ID string + Name string + Pool *Pool + Default bool +} + +type Registry struct { + entries []registeredPool + byID map[string]*Pool + defaultID string +} + +type registeredPool struct { + metadata PoolMetadata + pool *Pool +} + +func NewRegistry(entries []RegistryEntry) (*Registry, error) { + if len(entries) == 0 { + return nil, errors.New("database registry requires at least one pool") + } + + registered := make([]registeredPool, 0, len(entries)) + byID := make(map[string]*Pool, len(entries)) + defaultID := "" + for _, entry := range entries { + if entry.ID == "" { + return nil, errors.New("database registry entry id is required") + } + if entry.Name == "" { + return nil, fmt.Errorf("database registry entry %q name is required", entry.ID) + } + if entry.Pool == nil { + return nil, fmt.Errorf("database registry entry %q pool is required", entry.ID) + } + if _, exists := byID[entry.ID]; exists { + return nil, fmt.Errorf("database registry entry %q is duplicated", entry.ID) + } + if entry.Default { + if defaultID != "" { + return nil, errors.New("database registry has multiple defaults") + } + defaultID = entry.ID + } + + byID[entry.ID] = entry.Pool + registered = append(registered, registeredPool{ + metadata: PoolMetadata{ID: entry.ID, Name: entry.Name}, + pool: entry.Pool, + }) + } + if defaultID == "" { + defaultID = entries[0].ID + } + + return &Registry{entries: registered, byID: byID, defaultID: defaultID}, nil +} + +func (r *Registry) List() []PoolMetadata { + metadata := make([]PoolMetadata, 0, len(r.entries)) + for _, entry := range r.entries { + metadata = append(metadata, entry.metadata) + } + return metadata +} + +func (r *Registry) DefaultID() string { return r.defaultID } + +func (r *Registry) Pool(id string) (*Pool, error) { + if id == "" { + id = r.defaultID + } + pool, ok := r.byID[id] + if !ok { + return nil, fmt.Errorf("%w: %s", ErrPoolNotFound, id) + } + return pool, nil +} + +func (r *Registry) Ping(ctx context.Context) error { + for _, entry := range r.entries { + if err := entry.pool.Ping(ctx); err != nil { + return fmt.Errorf("ping database pool %q: %w", entry.metadata.ID, err) + } + } + return nil +} + +func (r *Registry) Close() { + for i := len(r.entries) - 1; i >= 0; i-- { + r.entries[i].pool.Close() + } +} diff --git a/internal/db/registry_test.go b/internal/db/registry_test.go new file mode 100644 index 0000000..6a19770 --- /dev/null +++ b/internal/db/registry_test.go @@ -0,0 +1,221 @@ +package db + +import ( + "context" + "errors" + "reflect" + "testing" +) + +func TestNewRegistry_lists_metadata_without_dsn_when_entries_have_private_names(t *testing.T) { + // Given: registry entries carry public ids/names and pool handles only. + registry, err := NewRegistry([]RegistryEntry{ + {ID: "billing", Name: "Billing", Pool: &Pool{pool: &fakePool{}, rowCap: 100}, Default: true}, + {ID: "support", Name: "Support", Pool: &Pool{pool: &fakePool{}, rowCap: 200}}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: callers request API-safe metadata. + got := registry.List() + + // Then: only id/name are exposed, in entry order. + want := []PoolMetadata{{ID: "billing", Name: "Billing"}, {ID: "support", Name: "Support"}} + if !reflect.DeepEqual(got, want) { + t.Fatalf("List() = %#v, want %#v", got, want) + } +} + +func TestRegistry_pool_returns_default_when_selection_empty(t *testing.T) { + // Given: registry has explicit default and another pool. + defaultPool := &Pool{pool: &fakePool{}, rowCap: 100} + otherPool := &Pool{pool: &fakePool{}, rowCap: 200} + registry, err := NewRegistry([]RegistryEntry{ + {ID: "billing", Name: "Billing", Pool: defaultPool, Default: true}, + {ID: "support", Name: "Support", Pool: otherPool}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: caller selects empty id. + got, err := registry.Pool("") + + // Then: default pool is returned, not an unknown-id error. + if err != nil { + t.Fatalf("Pool(empty): %v", err) + } + if got != defaultPool { + t.Fatalf("Pool(empty) returned %p, want default %p", got, defaultPool) + } + if registry.DefaultID() != "billing" { + t.Fatalf("DefaultID() = %q, want billing", registry.DefaultID()) + } +} + +func TestRegistry_pool_returns_requested_pool_when_id_known(t *testing.T) { + // Given: registry has two named pools. + defaultPool := &Pool{pool: &fakePool{}, rowCap: 100} + otherPool := &Pool{pool: &fakePool{}, rowCap: 200} + registry, err := NewRegistry([]RegistryEntry{ + {ID: "billing", Name: "Billing", Pool: defaultPool}, + {ID: "support", Name: "Support", Pool: otherPool}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: caller selects a known id. + got, err := registry.Pool("support") + + // Then: matching pool is returned. + if err != nil { + t.Fatalf("Pool(support): %v", err) + } + if got != otherPool { + t.Fatalf("Pool(support) returned %p, want %p", got, otherPool) + } +} + +func TestRegistry_pool_returns_not_found_when_non_empty_id_unknown(t *testing.T) { + // Given: registry has one pool. + registry, err := NewRegistry([]RegistryEntry{{ID: "billing", Name: "Billing", Pool: &Pool{pool: &fakePool{}, rowCap: 100}}}) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: caller selects an unknown non-empty id. + _, err = registry.Pool("missing") + + // Then: caller can distinguish unknown id from default selection. + if !errors.Is(err, ErrPoolNotFound) { + t.Fatalf("Pool(missing) error = %v, want ErrPoolNotFound", err) + } +} + +func TestRegistry_ping_visits_pools_in_metadata_order_and_stops_on_error(t *testing.T) { + // Given: second pool fails ping. + first := &orderedPool{name: "first"} + second := &orderedPool{name: "second", pingErr: errors.New("down")} + third := &orderedPool{name: "third"} + registry, err := NewRegistry([]RegistryEntry{ + {ID: "first", Name: "First", Pool: &Pool{pool: first, rowCap: 10}}, + {ID: "second", Name: "Second", Pool: &Pool{pool: second, rowCap: 10}}, + {ID: "third", Name: "Third", Pool: &Pool{pool: third, rowCap: 10}}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: readiness pings all pools. + err = registry.Ping(context.Background()) + + // Then: order is deterministic and ping stops at failing pool. + if err == nil { + t.Fatal("Ping() expected error") + } + got := append([]string{}, first.pings...) + got = append(got, second.pings...) + if want := []string{"first", "second"}; !reflect.DeepEqual(got, want) { + t.Fatalf("ping order = %v, want %v", got, want) + } + if len(third.pings) != 0 { + t.Fatalf("third pool pinged after error: %v", third.pings) + } +} + +func TestRegistry_ping_returns_nil_when_all_pools_healthy(t *testing.T) { + // Given: every registered pool can be pinged. + first := &orderedPool{name: "first"} + second := &orderedPool{name: "second"} + registry, err := NewRegistry([]RegistryEntry{ + {ID: "first", Name: "First", Pool: &Pool{pool: first, rowCap: 10}}, + {ID: "second", Name: "Second", Pool: &Pool{pool: second, rowCap: 10}}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: readiness pings the registry. + err = registry.Ping(context.Background()) + + // Then: all pools are checked and no error is returned. + if err != nil { + t.Fatalf("Ping: %v", err) + } + got := append([]string{}, first.pings...) + got = append(got, second.pings...) + if want := []string{"first", "second"}; !reflect.DeepEqual(got, want) { + t.Fatalf("ping order = %v, want %v", got, want) + } +} + +func TestRegistry_close_visits_all_pools_in_reverse_metadata_order(t *testing.T) { + // Given: registry has three pools. + closeOrder := make([]string, 0, 3) + registry, err := NewRegistry([]RegistryEntry{ + {ID: "first", Name: "First", Pool: &Pool{pool: &orderedPool{name: "first", closes: &closeOrder}, rowCap: 10}}, + {ID: "second", Name: "Second", Pool: &Pool{pool: &orderedPool{name: "second", closes: &closeOrder}, rowCap: 10}}, + {ID: "third", Name: "Third", Pool: &Pool{pool: &orderedPool{name: "third", closes: &closeOrder}, rowCap: 10}}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + + // When: registry closes. + registry.Close() + + // Then: every pool closes in deterministic reverse order. + want := []string{"third", "second", "first"} + if !reflect.DeepEqual(closeOrder, want) { + t.Fatalf("close order = %v, want %v", closeOrder, want) + } +} + +func TestNewRegistry_rejects_invalid_entries(t *testing.T) { + tests := []struct { + name string + entries []RegistryEntry + }{ + {name: "empty registry", entries: nil}, + {name: "empty id", entries: []RegistryEntry{{ID: "", Name: "Billing", Pool: &Pool{pool: &fakePool{}}}}}, + {name: "empty name", entries: []RegistryEntry{{ID: "billing", Name: "", Pool: &Pool{pool: &fakePool{}}}}}, + {name: "nil pool", entries: []RegistryEntry{{ID: "billing", Name: "Billing", Pool: nil}}}, + {name: "duplicate id", entries: []RegistryEntry{{ID: "billing", Name: "Billing", Pool: &Pool{pool: &fakePool{}}}, {ID: "billing", Name: "Other", Pool: &Pool{pool: &fakePool{}}}}}, + {name: "multiple defaults", entries: []RegistryEntry{{ID: "billing", Name: "Billing", Pool: &Pool{pool: &fakePool{}}, Default: true}, {ID: "support", Name: "Support", Pool: &Pool{pool: &fakePool{}}, Default: true}}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Given: invalid registry input. + + // When: registry is constructed. + _, err := NewRegistry(tc.entries) + + // Then: construction fails before runtime lookup. + if err == nil { + t.Fatal("NewRegistry expected error") + } + }) + } +} + +type orderedPool struct { + fakePool + name string + pingErr error + pings []string + closes *[]string +} + +func (p *orderedPool) Ping(context.Context) error { + p.pings = append(p.pings, p.name) + return p.pingErr +} + +func (p *orderedPool) Close() { + if p.closes != nil { + *p.closes = append(*p.closes, p.name) + } +} diff --git a/internal/guard/guard_test.go b/internal/guard/guard_test.go index 815b152..e781133 100644 --- a/internal/guard/guard_test.go +++ b/internal/guard/guard_test.go @@ -93,6 +93,15 @@ func TestValidate_AllowsRestrictedCatalogNamesInStringsAndComments(t *testing.T) } } +func TestIsRestrictedRelation(t *testing.T) { + if !IsRestrictedRelation("pg_shadow") || !IsRestrictedRelation("PG_AUTHID") { + t.Fatal("expected sensitive catalogs to be restricted") + } + if IsRestrictedRelation("users") { + t.Fatal("ordinary table should not be restricted") + } +} + // FuzzValidate asserts the guard never panics on arbitrary input, and that any // input it *accepts* really is a single statement beginning with an allowed // keyword and containing no forbidden keyword (in masked form). This is the diff --git a/internal/server/catalog_handlers.go b/internal/server/catalog_handlers.go new file mode 100644 index 0000000..5a0d9aa --- /dev/null +++ b/internal/server/catalog_handlers.go @@ -0,0 +1,123 @@ +package server + +import ( + "context" + "net/http" + "time" + + "github.com/descope-sample-apps/pgpeek/internal/db" +) + +func (s *Server) handleMeta(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + writeJSON(w, http.StatusOK, map[string]int{"rowCap": pool.RowCap()}) +} + +func (s *Server) handleTables(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) + defer cancel() + tables, err := pool.Tables(ctx) + if err != nil { + s.log.Error("list tables", "err", err) + writeError(w, http.StatusInternalServerError, "failed to list tables") + return + } + if tables == nil { + tables = []db.TableInfo{} + } + writeJSON(w, http.StatusOK, tables) +} + +func (s *Server) handleColumns(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) + defer cancel() + cols, err := pool.Columns(ctx, r.PathValue("schema"), r.PathValue("table")) + if err != nil { + s.log.Error("read columns", "err", err) + writeError(w, http.StatusInternalServerError, "failed to read columns") + return + } + if cols == nil { + cols = []db.ColumnInfo{} + } + writeJSON(w, http.StatusOK, cols) +} + +func (s *Server) handleForeignKeys(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) + defer cancel() + fks, err := pool.ForeignKeys(ctx, r.PathValue("schema"), r.PathValue("table")) + if err != nil { + s.log.Error("read foreign keys", "err", err) + writeError(w, http.StatusInternalServerError, "failed to read foreign keys") + return + } + if fks == nil { + fks = []db.ForeignKey{} + } + writeJSON(w, http.StatusOK, fks) +} + +func (s *Server) handleTableData(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + q := db.TableQuery{ + Schema: r.PathValue("schema"), + Table: r.PathValue("table"), + Search: r.URL.Query().Get("search"), + Sort: r.URL.Query().Get("sort"), + Desc: r.URL.Query().Get("dir") == "desc", + Limit: queryInt(r, "limit", 0), + Offset: queryInt(r, "offset", 0), + Filters: parseFilters(r.URL.Query()["f"]), + } + ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) + defer cancel() + res, err := pool.TableRows(ctx, q) + if err != nil { + s.log.Error("read rows", "err", err) + writeError(w, http.StatusBadRequest, "failed to read rows") + return + } + if r.URL.Query().Get("format") == "csv" { + w.Header().Set("Content-Type", "text/csv; charset=utf-8") + w.Header().Set("Content-Disposition", `attachment; filename="`+safeFilename(r.PathValue("table"))+`.csv"`) + if err := writeCSV(w, res); err != nil { + s.log.Error("csv export", "err", err) + } + return + } + writeJSON(w, http.StatusOK, res) +} + +func (s *Server) handleReady(w http.ResponseWriter, r *http.Request) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return + } + ctx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + defer cancel() + if err := pool.Ping(ctx); err != nil { + writeError(w, http.StatusServiceUnavailable, "database not ready") + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ready")) +} diff --git a/internal/server/catalog_handlers_test.go b/internal/server/catalog_handlers_test.go new file mode 100644 index 0000000..e2e4b9e --- /dev/null +++ b/internal/server/catalog_handlers_test.go @@ -0,0 +1,250 @@ +package server + +import ( + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/descope-sample-apps/pgpeek/internal/db" +) + +// --- catalog / browse endpoints ----------------------------------------- + +func TestMeta(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := mustGet(t, ts, "/api/meta") + got := decode[map[string]int](t, resp) + if got["rowCap"] != 1000 { + t.Errorf("rowCap = %d, want 1000", got["rowCap"]) + } +} + +func TestSafeFilename(t *testing.T) { + cases := map[string]string{ + "users": "users", + `a"; x`: "a___x", + "sch.tbl": "sch.tbl", + "évil/../x": "_vil_.._x", + "": "table", + "weird name!@#$": "weird_name____", + } + for in, want := range cases { + if got := safeFilename(in); got != want { + t.Errorf("safeFilename(%q) = %q, want %q", in, got, want) + } + } +} + +func TestTables_OK(t *testing.T) { + q := &fakeQuerier{tables: []db.TableInfo{{Schema: "public", Name: "users", Type: "table", EstRows: 5}}} + ts, _ := newTestServer(t, q) + resp := mustGet(t, ts, "/api/tables") + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d", resp.StatusCode) + } + got := decode[[]db.TableInfo](t, resp) + if len(got) != 1 || got[0].Name != "users" { + t.Errorf("tables = %+v", got) + } +} + +func TestTables_EmptyReturnsArray(t *testing.T) { + srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) + rec := httptest.NewRecorder() + srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/tables", nil)) + if strings.TrimSpace(rec.Body.String()) != "[]" { + t.Errorf("body = %q, want []", rec.Body.String()) + } +} + +func TestTables_Error(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{catErr: errors.New("postgres://secret-host/hidden: boom")}) + resp := mustGet(t, ts, "/api/tables") + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("status = %d, want 500", resp.StatusCode) + } + got := decode[map[string]string](t, resp) + if got["error"] != "failed to list tables" { + t.Fatalf("error = %q, want sanitized tables error", got["error"]) + } +} + +func TestColumns_OK(t *testing.T) { + q := &fakeQuerier{cols: []db.ColumnInfo{{Name: "id", Type: "integer"}}} + ts, _ := newTestServer(t, q) + resp := mustGet(t, ts, "/api/tables/public/users/columns") + got := decode[[]db.ColumnInfo](t, resp) + if len(got) != 1 || got[0].Name != "id" { + t.Errorf("columns = %+v", got) + } + if q.lastArgs.schema != "public" || q.lastArgs.table != "users" { + t.Errorf("path values not passed: %+v", q.lastArgs) + } +} + +func TestColumns_EmptyReturnsArray(t *testing.T) { + srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) + rec := httptest.NewRecorder() + srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/tables/public/users/columns", nil)) + if strings.TrimSpace(rec.Body.String()) != "[]" { + t.Errorf("body = %q, want []", rec.Body.String()) + } +} + +func TestColumns_Error(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{catErr: errors.New("postgres://secret-host/hidden: boom")}) + resp := mustGet(t, ts, "/api/tables/public/users/columns") + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("status = %d, want 500", resp.StatusCode) + } + got := decode[map[string]string](t, resp) + if got["error"] != "failed to read columns" { + t.Fatalf("error = %q, want sanitized columns error", got["error"]) + } +} + +func TestForeignKeys_OK(t *testing.T) { + q := &fakeQuerier{fks: []db.ForeignKey{{Column: "company_id", RefSchema: "public", RefTable: "companies", RefColumn: "id"}}} + ts, _ := newTestServer(t, q) + resp := mustGet(t, ts, "/api/tables/public/users/fks") + got := decode[[]db.ForeignKey](t, resp) + if len(got) != 1 || got[0].RefTable != "companies" { + t.Errorf("fks = %+v", got) + } + if q.lastArgs.schema != "public" || q.lastArgs.table != "users" { + t.Errorf("path values not passed: %+v", q.lastArgs) + } +} + +func TestForeignKeys_EmptyReturnsArray(t *testing.T) { + srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) + rec := httptest.NewRecorder() + srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/tables/public/users/fks", nil)) + if strings.TrimSpace(rec.Body.String()) != "[]" { + t.Errorf("body = %q, want []", rec.Body.String()) + } +} + +func TestForeignKeys_Error(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{catErr: errors.New("postgres://secret-host/hidden: boom")}) + resp := mustGet(t, ts, "/api/tables/public/users/fks") + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("status = %d, want 500", resp.StatusCode) + } + got := decode[map[string]string](t, resp) + if got["error"] != "failed to read foreign keys" { + t.Fatalf("error = %q, want sanitized foreign-key error", got["error"]) + } +} + +func TestTableData_OK(t *testing.T) { + q := &fakeQuerier{result: okResult()} + ts, _ := newTestServer(t, q) + resp := mustGet(t, ts, "/api/tables/public/users/data?limit=25&offset=50") + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d", resp.StatusCode) + } + decode[db.Result](t, resp) + if q.lastArgs.limit != 25 || q.lastArgs.offset != 50 { + t.Errorf("limit/offset = %d/%d", q.lastArgs.limit, q.lastArgs.offset) + } +} + +func TestTableData_ParsesSearchSortFilters(t *testing.T) { + q := &fakeQuerier{result: okResult()} + ts, _ := newTestServer(t, q) + resp := mustGet(t, ts, "/api/tables/public/users/data?search=acme&sort=id&dir=desc&f=id:gt:100&f=name:ilike:%25a%25&f=deleted_at:is_null") + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d", resp.StatusCode) + } + lq := q.lastQuery + if lq.Search != "acme" || lq.Sort != "id" || !lq.Desc { + t.Errorf("search/sort/desc = %+v", lq) + } + if len(lq.Filters) != 3 { + t.Fatalf("filters = %+v", lq.Filters) + } + if lq.Filters[0] != (db.Filter{Column: "id", Op: "gt", Value: "100"}) { + t.Errorf("filter0 = %+v", lq.Filters[0]) + } + if lq.Filters[1] != (db.Filter{Column: "name", Op: "ilike", Value: "%a%"}) { + t.Errorf("filter1 = %+v", lq.Filters[1]) + } + if lq.Filters[2] != (db.Filter{Column: "deleted_at", Op: "is_null"}) { + t.Errorf("filter2 (no value) = %+v", lq.Filters[2]) + } +} + +func TestParseFilters(t *testing.T) { + if parseFilters(nil) != nil { + t.Error("nil input should yield nil") + } + got := parseFilters([]string{"id:gt:100", "name:is_null", "bare"}) + want := []db.Filter{ + {Column: "id", Op: "gt", Value: "100"}, + {Column: "name", Op: "is_null"}, + {Column: "bare"}, + } + if len(got) != len(want) { + t.Fatalf("len = %d", len(got)) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("filter %d = %+v, want %+v", i, got[i], want[i]) + } + } +} + +func TestTableData_DefaultsAndBadParams(t *testing.T) { + q := &fakeQuerier{result: okResult()} + ts, _ := newTestServer(t, q) + // Non-numeric params fall back to defaults (0 -> pool clamps). + resp := mustGet(t, ts, "/api/tables/public/users/data?limit=abc") + resp.Body.Close() + if q.lastArgs.limit != 0 || q.lastArgs.offset != 0 { + t.Errorf("expected defaults, got %d/%d", q.lastArgs.limit, q.lastArgs.offset) + } +} + +func TestTableData_Error(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{err: errors.New("postgres://secret-host/hidden: no such table")}) + resp := mustGet(t, ts, "/api/tables/public/nope/data") + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } + got := decode[map[string]string](t, resp) + if got["error"] != "failed to read rows" { + t.Fatalf("error = %q, want sanitized row error", got["error"]) + } +} + +func TestTableData_CSV(t *testing.T) { + q := &fakeQuerier{result: &db.Result{Columns: []string{"a"}, Rows: [][]any{{"1"}}, RowCount: 1}} + ts, _ := newTestServer(t, q) + resp := mustGet(t, ts, "/api/tables/public/users/data?format=csv") + defer resp.Body.Close() + if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { + t.Errorf("content-type = %q", ct) + } + if cd := resp.Header.Get("Content-Disposition"); !strings.Contains(cd, "users.csv") { + t.Errorf("disposition = %q", cd) + } + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "a\n1\n") { + t.Errorf("csv body = %q", body) + } +} + +func TestTableData_CSVWriteFailure(t *testing.T) { + srv := serverWithStore(t, &fakeQuerier{result: okResult()}, &fakeStore{}) + req := httptest.NewRequest(http.MethodGet, "/api/tables/public/users/data?format=csv", nil) + fw := &failingWriter{} + srv.handleTableData(fw, req) + if ct := fw.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { + t.Errorf("content-type = %q", ct) + } +} diff --git a/internal/server/database_selection_test.go b/internal/server/database_selection_test.go new file mode 100644 index 0000000..6d5b3bb --- /dev/null +++ b/internal/server/database_selection_test.go @@ -0,0 +1,247 @@ +package server + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + "testing/fstest" + "time" + + "github.com/descope-sample-apps/pgpeek/internal/db" + "github.com/descope-sample-apps/pgpeek/internal/store" +) + +type fakeRegistry struct { + defaultID string + metadata []db.PoolMetadata + pools map[string]Querier +} + +func (f fakeRegistry) List() []db.PoolMetadata { return f.metadata } +func (f fakeRegistry) DefaultID() string { return f.defaultID } + +func (f fakeRegistry) Pool(id string) (Querier, error) { + if id == "" { + id = f.defaultID + } + pool, ok := f.pools[id] + if !ok { + return nil, db.ErrPoolNotFound + } + return pool, nil +} + +func (f fakeRegistry) Ping(ctx context.Context) error { + for _, pool := range f.pools { + if err := pool.Ping(ctx); err != nil { + return err + } + } + return nil +} + +type selectedQuerier struct { + rowCap int + used bool + pinged bool +} + +func (q *selectedQuerier) Query(context.Context, string) (*db.Result, error) { + q.used = true + return okResult(), nil +} + +func (q *selectedQuerier) Tables(context.Context) ([]db.TableInfo, error) { + q.used = true + return []db.TableInfo{}, nil +} + +func (q *selectedQuerier) Columns(context.Context, string, string) ([]db.ColumnInfo, error) { + q.used = true + return []db.ColumnInfo{}, nil +} + +func (q *selectedQuerier) ForeignKeys(context.Context, string, string) ([]db.ForeignKey, error) { + q.used = true + return []db.ForeignKey{}, nil +} + +func (q *selectedQuerier) TableRows(context.Context, db.TableQuery) (*db.Result, error) { + q.used = true + return okResult(), nil +} + +func (q *selectedQuerier) RowCap() int { + q.used = true + return q.rowCap +} + +func (q *selectedQuerier) Ping(context.Context) error { + q.pinged = true + return nil +} + +func newRegistryTestServer(t *testing.T, registry DatabaseRegistry) *httptest.Server { + t.Helper() + st, err := store.Open(t.TempDir() + "/t.db") + if err != nil { + t.Fatalf("store.Open: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + web := fstest.MapFS{"index.html": &fstest.MapFile{Data: []byte("x")}} + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + srv := NewWithRegistry(registry, st, web, log, time.Second) + ts := httptest.NewServer(srv.Routes()) + t.Cleanup(ts.Close) + return ts +} + +func TestDatabases_lists_safe_metadata(t *testing.T) { + primary := &selectedQuerier{rowCap: 1000} + registry := fakeRegistry{ + defaultID: "primary", + metadata: []db.PoolMetadata{ + {ID: "primary", Name: "Primary"}, + {ID: "analytics", Name: "Analytics"}, + }, + pools: map[string]Querier{"primary": primary, "analytics": &selectedQuerier{}}, + } + ts := newRegistryTestServer(t, registry) + + resp := mustGet(t, ts, "/api/databases") + got := decode[struct { + DefaultID string `json:"defaultId"` + Databases []db.PoolMetadata `json:"databases"` + }](t, resp) + + if got.DefaultID != "primary" || len(got.Databases) != 2 || got.Databases[0].ID != "primary" || got.Databases[1].Name != "Analytics" { + t.Fatalf("databases = %+v", got) + } + body := marshalString(t, got) + if strings.Contains(body, "postgres://") || strings.Contains(body, "dsn") { + t.Fatalf("database metadata leaked secret material: %s", body) + } +} + +func TestDatabaseSelection_uses_selected_pool_for_db_bound_endpoints(t *testing.T) { + tests := []struct { + name string + method string + path string + body string + }{ + {name: "readyz", method: http.MethodGet, path: "/readyz?db=analytics"}, + {name: "meta", method: http.MethodGet, path: "/api/meta?db=analytics"}, + {name: "query", method: http.MethodPost, path: "/api/query?db=analytics", body: `{"sql":"SELECT 1"}`}, + {name: "export", method: http.MethodPost, path: "/api/export?db=analytics", body: `{"sql":"SELECT 1"}`}, + {name: "tables", method: http.MethodGet, path: "/api/tables?db=analytics"}, + {name: "columns", method: http.MethodGet, path: "/api/tables/public/users/columns?db=analytics"}, + {name: "fks", method: http.MethodGet, path: "/api/tables/public/users/fks?db=analytics"}, + {name: "data", method: http.MethodGet, path: "/api/tables/public/users/data?db=analytics"}, + {name: "data csv", method: http.MethodGet, path: "/api/tables/public/users/data?format=csv&db=analytics"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + primary := &selectedQuerier{rowCap: 1000} + analytics := &selectedQuerier{rowCap: 2000} + ts := newRegistryTestServer(t, fakeRegistry{ + defaultID: "primary", + metadata: []db.PoolMetadata{{ID: "primary", Name: "Primary"}, {ID: "analytics", Name: "Analytics"}}, + pools: map[string]Querier{"primary": primary, "analytics": analytics}, + }) + req, err := http.NewRequest(tt.method, ts.URL+tt.path, strings.NewReader(tt.body)) + if err != nil { + t.Fatal(err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if tt.name == "readyz" { + if !analytics.pinged || primary.pinged { + t.Fatalf("pinged primary=%v analytics=%v", primary.pinged, analytics.pinged) + } + return + } + if !analytics.used || primary.used { + t.Fatalf("used primary=%v analytics=%v", primary.used, analytics.used) + } + }) + } +} + +func TestDatabaseSelection_missing_or_empty_db_uses_default(t *testing.T) { + primary := &selectedQuerier{rowCap: 1000} + analytics := &selectedQuerier{rowCap: 2000} + ts := newRegistryTestServer(t, fakeRegistry{ + defaultID: "primary", + metadata: []db.PoolMetadata{{ID: "primary", Name: "Primary"}, {ID: "analytics", Name: "Analytics"}}, + pools: map[string]Querier{"primary": primary, "analytics": analytics}, + }) + + resp := mustGet(t, ts, "/api/meta?db=") + got := decode[map[string]int](t, resp) + + if got["rowCap"] != 1000 || !primary.used || analytics.used { + t.Fatalf("default selection failed: body=%+v primary=%v analytics=%v", got, primary.used, analytics.used) + } +} + +func TestDatabaseSelection_unknown_db_returns_404(t *testing.T) { + primary := &selectedQuerier{rowCap: 1000} + ts := newRegistryTestServer(t, fakeRegistry{ + defaultID: "primary", + metadata: []db.PoolMetadata{{ID: "primary", Name: "Primary"}}, + pools: map[string]Querier{"primary": primary}, + }) + + resp := mustGet(t, ts, "/api/tables?db=missing") + got := decode[map[string]string](t, resp) + + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("status = %d, want 404", resp.StatusCode) + } + if got["error"] == "" || strings.Contains(got["error"], "postgres://") { + t.Fatalf("bad error body: %+v", got) + } + if primary.used { + t.Fatal("unknown db should not use default pool") + } +} + +func TestDatabaseSelection_saved_query_endpoints_ignore_db(t *testing.T) { + primary := &selectedQuerier{rowCap: 1000} + ts := newRegistryTestServer(t, fakeRegistry{ + defaultID: "primary", + metadata: []db.PoolMetadata{{ID: "primary", Name: "Primary"}}, + pools: map[string]Querier{"primary": primary}, + }) + + resp := post(t, ts, "/api/queries?db=missing", `{"name":"q","sql":"SELECT 1"}`) + created := decode[store.SavedQuery](t, resp) + + if resp.StatusCode != http.StatusCreated || created.ID == 0 { + t.Fatalf("create saved query status=%d query=%+v", resp.StatusCode, created) + } + if primary.used || primary.pinged { + t.Fatal("saved query endpoint should not select a database pool") + } +} + +func marshalString(t *testing.T, v any) string { + t.Helper() + b, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal: %v", err) + } + return string(b) +} diff --git a/internal/server/helpers_test.go b/internal/server/helpers_test.go new file mode 100644 index 0000000..1d356ad --- /dev/null +++ b/internal/server/helpers_test.go @@ -0,0 +1,113 @@ +package server + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + "testing/fstest" + "time" + + "github.com/descope-sample-apps/pgpeek/internal/db" + "github.com/descope-sample-apps/pgpeek/internal/store" +) + +type fakeQuerier struct { + result *db.Result + err error + pingErr error + called bool + lastSQL string + tables []db.TableInfo + cols []db.ColumnInfo + fks []db.ForeignKey + catErr error + lastQuery db.TableQuery + lastArgs struct { + schema, table string + limit, offset int + } +} + +func (f *fakeQuerier) Query(_ context.Context, sql string) (*db.Result, error) { + f.called = true + f.lastSQL = sql + return f.result, f.err +} + +func (f *fakeQuerier) Tables(context.Context) ([]db.TableInfo, error) { + return f.tables, f.catErr +} + +func (f *fakeQuerier) Columns(_ context.Context, schema, table string) ([]db.ColumnInfo, error) { + f.lastArgs.schema, f.lastArgs.table = schema, table + return f.cols, f.catErr +} + +func (f *fakeQuerier) ForeignKeys(_ context.Context, schema, table string) ([]db.ForeignKey, error) { + f.lastArgs.schema, f.lastArgs.table = schema, table + return f.fks, f.catErr +} + +func (f *fakeQuerier) TableRows(_ context.Context, q db.TableQuery) (*db.Result, error) { + f.lastQuery = q + f.lastArgs.schema, f.lastArgs.table = q.Schema, q.Table + f.lastArgs.limit, f.lastArgs.offset = q.Limit, q.Offset + return f.result, f.err +} + +func (f *fakeQuerier) RowCap() int { return 1000 } + +func (f *fakeQuerier) Ping(_ context.Context) error { return f.pingErr } + +func newTestServer(t *testing.T, q Querier) (*httptest.Server, *store.Store) { + t.Helper() + st, err := store.Open(filepath.Join(t.TempDir(), "t.db")) + if err != nil { + t.Fatalf("store.Open: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + + web := fstest.MapFS{ + "index.html": &fstest.MapFile{Data: []byte("pgpeek")}, + "app.js": &fstest.MapFile{Data: []byte("// js")}, + } + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + srv := New(q, st, web, log, 5*time.Second) + ts := httptest.NewServer(srv.Routes()) + t.Cleanup(ts.Close) + return ts, st +} + +func post(t *testing.T, ts *httptest.Server, path, body string) *http.Response { + t.Helper() + resp, err := http.Post(ts.URL+path, "application/json", strings.NewReader(body)) + if err != nil { + t.Fatalf("POST %s: %v", path, err) + } + return resp +} + +func decode[T any](t *testing.T, resp *http.Response) T { + t.Helper() + defer resp.Body.Close() + var v T + if err := json.NewDecoder(resp.Body).Decode(&v); err != nil { + t.Fatalf("decode: %v", err) + } + return v +} + +func okResult() *db.Result { + return &db.Result{ + Columns: []string{"n"}, + Rows: [][]any{{int64(1)}}, + RowCount: 1, + ElapsedMS: 1, + } +} diff --git a/internal/server/http_helpers.go b/internal/server/http_helpers.go new file mode 100644 index 0000000..0e58a60 --- /dev/null +++ b/internal/server/http_helpers.go @@ -0,0 +1,86 @@ +package server + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + + "github.com/descope-sample-apps/pgpeek/internal/db" +) + +const maxBodyBytes = 1 << 20 + +func decodeBody(w http.ResponseWriter, r *http.Request, v any) bool { + r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes) + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(v); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON body: "+err.Error()) + return false + } + return true +} + +func safeFilename(name string) string { + out := strings.Map(func(r rune) rune { + switch { + case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', + r == '_', r == '-', r == '.': + return r + default: + return '_' + } + }, name) + if out == "" { + return "table" + } + return out +} + +func parseFilters(raw []string) []db.Filter { + if len(raw) == 0 { + return nil + } + out := make([]db.Filter, 0, len(raw)) + for _, s := range raw { + parts := strings.SplitN(s, ":", 3) + f := db.Filter{Column: parts[0]} + if len(parts) >= 2 { + f.Op = parts[1] + } + if len(parts) == 3 { + f.Value = parts[2] + } + out = append(out, f) + } + return out +} + +func queryInt(r *http.Request, key string, def int) int { + if v := r.URL.Query().Get(key); v != "" { + if n, err := strconv.Atoi(v); err == nil { + return n + } + } + return def +} + +func pathID(w http.ResponseWriter, r *http.Request) (int64, bool) { + id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid id") + return 0, false + } + return id, true +} + +func writeJSON(w http.ResponseWriter, status int, v any) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(v) +} + +func writeError(w http.ResponseWriter, status int, msg string) { + writeJSON(w, status, map[string]string{"error": msg}) +} diff --git a/internal/server/middleware.go b/internal/server/middleware.go new file mode 100644 index 0000000..b90e4fd --- /dev/null +++ b/internal/server/middleware.go @@ -0,0 +1,55 @@ +package server + +import ( + "log/slog" + "net/http" + "time" +) + +func logging(log *slog.Logger, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + sw := &statusWriter{ResponseWriter: w, status: http.StatusOK} + next.ServeHTTP(sw, r) + if r.URL.Path == "/healthz" || r.URL.Path == "/readyz" { + return + } + log.Info("request", + "method", r.Method, + "path", r.URL.Path, + "status", sw.status, + "ms", time.Since(start).Milliseconds(), + ) + }) +} + +type statusWriter struct { + http.ResponseWriter + status int +} + +func (w *statusWriter) WriteHeader(code int) { + w.status = code + w.ResponseWriter.WriteHeader(code) +} + +const contentSecurityPolicy = "default-src 'self'; " + + "script-src 'self' https://cdnjs.cloudflare.com; " + + "style-src 'self' 'unsafe-inline' https://cdnjs.cloudflare.com; " + + "img-src 'self' data:; " + + "connect-src 'self'; " + + "base-uri 'self'; " + + "form-action 'self'; " + + "frame-ancestors 'none'" + +func securityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Set("Content-Security-Policy", contentSecurityPolicy) + h.Set("X-Content-Type-Options", "nosniff") + h.Set("X-Frame-Options", "DENY") + h.Set("Referrer-Policy", "no-referrer") + h.Set("Cross-Origin-Opener-Policy", "same-origin") + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/multi_database_integration_test.go b/internal/server/multi_database_integration_test.go new file mode 100644 index 0000000..e4d609e --- /dev/null +++ b/internal/server/multi_database_integration_test.go @@ -0,0 +1,153 @@ +//go:build integration + +package server + +import ( + "context" + "io" + "log/slog" + "net/http/httptest" + "net/url" + "os" + "strconv" + "testing" + "testing/fstest" + "time" + + "github.com/jackc/pgx/v5" + + "github.com/descope-sample-apps/pgpeek/internal/db" + "github.com/descope-sample-apps/pgpeek/internal/store" +) + +func TestIntegrationMultiDatabaseHTTPSelection(t *testing.T) { + baseDSN := os.Getenv("PGPEEK_TEST_DATABASE_URL") + if baseDSN == "" { + t.Skip("PGPEEK_TEST_DATABASE_URL not set") + } + + ctx := context.Background() + maintenance, err := pgx.Connect(ctx, baseDSN) + if err != nil { + t.Fatalf("maintenance connect: %v", err) + } + defer maintenance.Close(ctx) + + suffix := strconv.FormatInt(time.Now().UnixNano(), 36) + leftName := "pgpeek_multi_left_" + suffix + rightName := "pgpeek_multi_right_" + suffix + createDatabase(t, maintenance, leftName) + createDatabase(t, maintenance, rightName) + t.Cleanup(func() { + dropDatabase(t, maintenance, leftName) + dropDatabase(t, maintenance, rightName) + }) + + leftDSN := databaseDSN(t, baseDSN, leftName) + rightDSN := databaseDSN(t, baseDSN, rightName) + seedMarker(t, leftDSN, "left") + seedMarker(t, rightDSN, "right") + + leftPool := newIntegrationPool(t, leftDSN) + rightPool := newIntegrationPool(t, rightDSN) + registry, err := db.NewRegistry([]db.RegistryEntry{ + {ID: "left", Name: "Left", Pool: leftPool, Default: true}, + {ID: "right", Name: "Right", Pool: rightPool}, + }) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + t.Cleanup(registry.Close) + + st, err := store.Open(t.TempDir() + "/queries.db") + if err != nil { + t.Fatalf("store.Open: %v", err) + } + t.Cleanup(func() { _ = st.Close() }) + + srv := NewWithRegistry( + NewDatabaseRegistry(registry), + st, + fstest.MapFS{"index.html": &fstest.MapFile{Data: []byte("pgpeek")}}, + slog.New(slog.NewTextHandler(io.Discard, nil)), + 5*time.Second, + ) + ts := httptest.NewServer(srv.Routes()) + t.Cleanup(ts.Close) + + left := runMarkerQuery(t, ts, "left") + right := runMarkerQuery(t, ts, "right") + + if left != "left" || right != "right" { + t.Fatalf("markers = %q/%q, want left/right", left, right) + } +} + +func createDatabase(t *testing.T, conn *pgx.Conn, name string) { + t.Helper() + if _, err := conn.Exec(context.Background(), "CREATE DATABASE "+pgx.Identifier{name}.Sanitize()); err != nil { + t.Fatalf("create database %s: %v", name, err) + } +} + +func dropDatabase(t *testing.T, conn *pgx.Conn, name string) { + t.Helper() + _, _ = conn.Exec(context.Background(), "DROP DATABASE IF EXISTS "+pgx.Identifier{name}.Sanitize()+" WITH (FORCE)") +} + +func databaseDSN(t *testing.T, baseDSN, dbName string) string { + t.Helper() + u, err := url.Parse(baseDSN) + if err != nil || u.Scheme == "" || u.Host == "" { + t.Skip("PGPEEK_TEST_DATABASE_URL must be a postgres URL for multi-database integration") + } + u.Path = "/" + dbName + return u.String() +} + +func seedMarker(t *testing.T, dsn, marker string) { + t.Helper() + conn, err := pgx.Connect(context.Background(), dsn) + if err != nil { + t.Fatalf("seed connect: %v", err) + } + defer conn.Close(context.Background()) + if _, err := conn.Exec(context.Background(), `CREATE TABLE pgpeek_marker (value text not null)`); err != nil { + t.Fatalf("create marker table: %v", err) + } + if _, err := conn.Exec(context.Background(), `INSERT INTO pgpeek_marker VALUES ($1)`, marker); err != nil { + t.Fatalf("insert marker: %v", err) + } +} + +func newIntegrationPool(t *testing.T, dsn string) *db.Pool { + t.Helper() + pool, err := db.New(context.Background(), db.Config{ + DSN: dsn, + MaxConns: 2, + StatementTimeout: 5 * time.Second, + IdleTxTimeout: 5 * time.Second, + RowCap: 10, + }) + if err != nil { + t.Fatalf("db.New: %v", err) + } + return pool +} + +func runMarkerQuery(t *testing.T, ts *httptest.Server, id string) string { + t.Helper() + resp := post(t, ts, "/api/query?db="+id, `{"sql":"SELECT value FROM pgpeek_marker"}`) + if resp.StatusCode != 200 { + t.Fatalf("query %s status = %d", id, resp.StatusCode) + } + result := decode[db.Result](t, resp) + if result.RowCount != 1 { + t.Fatalf("query %s rowCount = %d, want 1", id, result.RowCount) + } + marker, ok := result.Rows[0][0].(string) + if !ok { + t.Fatalf("query %s marker type = %T", id, result.Rows[0][0]) + } + return marker +} diff --git a/internal/server/query_handlers.go b/internal/server/query_handlers.go new file mode 100644 index 0000000..ac7ce5e --- /dev/null +++ b/internal/server/query_handlers.go @@ -0,0 +1,83 @@ +package server + +import ( + "context" + "encoding/csv" + "io" + "net/http" + "strings" + + "github.com/descope-sample-apps/pgpeek/internal/db" + "github.com/descope-sample-apps/pgpeek/internal/guard" +) + +type queryRequest struct { + SQL string `json:"sql"` +} + +func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) { + res, ok := s.readOnlyResult(w, r) + if !ok { + return + } + writeJSON(w, http.StatusOK, res) +} + +func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) { + res, ok := s.readOnlyResult(w, r) + if !ok { + return + } + w.Header().Set("Content-Type", "text/csv; charset=utf-8") + w.Header().Set("Content-Disposition", `attachment; filename="pgpeek-export.csv"`) + if err := writeCSV(w, res); err != nil { + s.log.Error("csv export", "err", err) + } +} + +func (s *Server) readOnlyResult(w http.ResponseWriter, r *http.Request) (*db.Result, bool) { + pool, ok := s.poolForRequest(w, r) + if !ok { + return nil, false + } + sql, ok := decodeReadOnlyQuery(w, r) + if !ok { + return nil, false + } + ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) + defer cancel() + res, err := pool.Query(ctx, sql) + if err != nil { + s.log.Error("query", "err", err) + writeError(w, http.StatusBadRequest, "query failed") + return nil, false + } + return res, true +} + +func decodeReadOnlyQuery(w http.ResponseWriter, r *http.Request) (string, bool) { + var req queryRequest + if !decodeBody(w, r, &req) { + return "", false + } + sql := strings.TrimSpace(req.SQL) + if err := guard.Validate(sql); err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return "", false + } + return sql, true +} + +func writeCSV(w io.Writer, res *db.Result) error { + cw := csv.NewWriter(w) + _ = cw.Write(res.Columns) + row := make([]string, len(res.Columns)) + for _, rec := range res.Rows { + for i, cell := range rec { + row[i] = db.CellString(cell) + } + _ = cw.Write(row) + } + cw.Flush() + return cw.Error() +} diff --git a/internal/server/query_handlers_test.go b/internal/server/query_handlers_test.go new file mode 100644 index 0000000..a4bb725 --- /dev/null +++ b/internal/server/query_handlers_test.go @@ -0,0 +1,138 @@ +package server + +import ( + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/descope-sample-apps/pgpeek/internal/db" +) + +func TestQuery_OK(t *testing.T) { + q := &fakeQuerier{result: okResult()} + ts, _ := newTestServer(t, q) + + resp := post(t, ts, "/api/query", `{"sql":" SELECT 1 "}`) + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d", resp.StatusCode) + } + res := decode[db.Result](t, resp) + if res.RowCount != 1 { + t.Errorf("rowCount = %d", res.RowCount) + } + if q.lastSQL != "SELECT 1" { + t.Errorf("SQL not trimmed before exec: %q", q.lastSQL) + } +} + +func TestQuery_GuardRejectsDML(t *testing.T) { + q := &fakeQuerier{result: okResult()} + ts, _ := newTestServer(t, q) + + resp := post(t, ts, "/api/query", `{"sql":"DELETE FROM users"}`) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + if q.called { + t.Error("guard should block the query before it reaches the database") + } +} + +func TestQuery_InvalidJSON(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := post(t, ts, "/api/query", `{not json`) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } +} + +func TestQuery_UnknownField(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := post(t, ts, "/api/query", `{"sql":"SELECT 1","evil":true}`) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 (DisallowUnknownFields)", resp.StatusCode) + } +} + +func TestQuery_DBError(t *testing.T) { + q := &fakeQuerier{err: errors.New("postgres://secret-host/hidden: boom")} + ts, _ := newTestServer(t, q) + resp := post(t, ts, "/api/query", `{"sql":"SELECT 1"}`) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + got := decode[map[string]string](t, resp) + if got["error"] != "query failed" { + t.Fatalf("error = %q, want sanitized query failed", got["error"]) + } +} + +func TestQuery_BodyTooLarge(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{result: okResult()}) + huge := strings.Repeat("a", (1<<20)+10) + resp := post(t, ts, "/api/query", `{"sql":"`+huge+`"}`) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400 for oversized body", resp.StatusCode) + } +} + +func TestExport_CSV(t *testing.T) { + q := &fakeQuerier{result: &db.Result{ + Columns: []string{"name", "n"}, + Rows: [][]any{{"Acme", int64(2)}, {"Globex,Inc", int64(1)}}, + RowCount: 2, + }} + ts, _ := newTestServer(t, q) + + resp := post(t, ts, "/api/export", `{"sql":"SELECT 1"}`) + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d", resp.StatusCode) + } + if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { + t.Errorf("content-type = %q", ct) + } + if cd := resp.Header.Get("Content-Disposition"); !strings.Contains(cd, "pgpeek-export.csv") { + t.Errorf("content-disposition = %q", cd) + } + body, _ := io.ReadAll(resp.Body) + got := string(body) + if !strings.Contains(got, "name,n") || !strings.Contains(got, "Acme,2") { + t.Errorf("csv body = %q", got) + } + // Field with a comma must be quoted by encoding/csv. + if !strings.Contains(got, `"Globex,Inc"`) { + t.Errorf("comma field not quoted: %q", got) + } +} + +func TestExport_GuardRejects(t *testing.T) { + q := &fakeQuerier{result: okResult()} + ts, _ := newTestServer(t, q) + resp := post(t, ts, "/api/export", `{"sql":"DROP TABLE x"}`) + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } +} + +func TestExport_DBError(t *testing.T) { + // Given: CSV export receives a read-only query but database execution fails. + q := &fakeQuerier{err: errors.New("postgres://secret-host/hidden: boom")} + ts, _ := newTestServer(t, q) + + // When: export is requested. + resp := post(t, ts, "/api/export", `{"sql":"SELECT 1"}`) + defer resp.Body.Close() + + // Then: handler returns the same bad-request contract as query execution. + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + got := decode[map[string]string](t, resp) + if got["error"] != "query failed" { + t.Fatalf("error = %q, want sanitized query failed", got["error"]) + } +} diff --git a/internal/server/registry.go b/internal/server/registry.go new file mode 100644 index 0000000..92395bf --- /dev/null +++ b/internal/server/registry.go @@ -0,0 +1,102 @@ +package server + +import ( + "context" + "errors" + "net/http" + "strings" + + "github.com/descope-sample-apps/pgpeek/internal/db" + "github.com/descope-sample-apps/pgpeek/internal/store" +) + +type Querier interface { + Query(ctx context.Context, sql string) (*db.Result, error) + Tables(ctx context.Context) ([]db.TableInfo, error) + Columns(ctx context.Context, schema, table string) ([]db.ColumnInfo, error) + ForeignKeys(ctx context.Context, schema, table string) ([]db.ForeignKey, error) + TableRows(ctx context.Context, q db.TableQuery) (*db.Result, error) + RowCap() int + Ping(ctx context.Context) error +} + +type QueryStore interface { + List(ctx context.Context) ([]store.SavedQuery, error) + Get(ctx context.Context, id int64) (store.SavedQuery, error) + Create(ctx context.Context, name, desc, sql string) (store.SavedQuery, error) + Update(ctx context.Context, id int64, name, desc, sql string) (store.SavedQuery, error) + Delete(ctx context.Context, id int64) error +} + +type DatabaseRegistry interface { + List() []db.PoolMetadata + DefaultID() string + Pool(id string) (Querier, error) + Ping(ctx context.Context) error +} + +type singleDatabaseRegistry struct { + pool Querier +} + +func NewSingleDatabaseRegistry(pool Querier) DatabaseRegistry { + return singleDatabaseRegistry{pool: pool} +} + +func NewDatabaseRegistry(registry *db.Registry) DatabaseRegistry { + return dbRegistryAdapter{registry: registry} +} + +func (r singleDatabaseRegistry) List() []db.PoolMetadata { + return []db.PoolMetadata{{ID: "default", Name: "Default"}} +} + +func (r singleDatabaseRegistry) DefaultID() string { return "default" } + +func (r singleDatabaseRegistry) Pool(id string) (Querier, error) { + if id == "" || id == r.DefaultID() { + return r.pool, nil + } + return nil, db.ErrPoolNotFound +} + +func (r singleDatabaseRegistry) Ping(ctx context.Context) error { + return r.pool.Ping(ctx) +} + +type dbRegistryAdapter struct { + registry *db.Registry +} + +type databasesResponse struct { + DefaultID string `json:"defaultId"` + Databases []db.PoolMetadata `json:"databases"` +} + +func (r dbRegistryAdapter) List() []db.PoolMetadata { return r.registry.List() } + +func (r dbRegistryAdapter) DefaultID() string { return r.registry.DefaultID() } + +func (r dbRegistryAdapter) Pool(id string) (Querier, error) { + return r.registry.Pool(id) +} + +func (r dbRegistryAdapter) Ping(ctx context.Context) error { return r.registry.Ping(ctx) } + +func (s *Server) handleDatabases(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, databasesResponse{DefaultID: s.registry.DefaultID(), Databases: s.registry.List()}) +} + +func (s *Server) poolForRequest(w http.ResponseWriter, r *http.Request) (Querier, bool) { + pool, err := s.registry.Pool(strings.TrimSpace(r.URL.Query().Get("db"))) + if errors.Is(err, db.ErrPoolNotFound) { + writeError(w, http.StatusNotFound, "database not found") + return nil, false + } + if err != nil { + writeError(w, http.StatusInternalServerError, "database unavailable") + s.log.Error("select database", "err", err) + return nil, false + } + return pool, true +} diff --git a/internal/server/registry_adapter_integration_test.go b/internal/server/registry_adapter_integration_test.go new file mode 100644 index 0000000..96dec78 --- /dev/null +++ b/internal/server/registry_adapter_integration_test.go @@ -0,0 +1,52 @@ +//go:build integration + +package server + +import ( + "context" + "os" + "testing" + "time" + + "github.com/descope-sample-apps/pgpeek/internal/db" +) + +func TestDatabaseRegistryAdapter_delegates_to_db_registry(t *testing.T) { + // Given: a real database registry backs the server adapter. + dsn := os.Getenv("PGPEEK_TEST_DATABASE_URL") + if dsn == "" { + t.Skip("PGPEEK_TEST_DATABASE_URL not set") + } + pool, err := db.New(context.Background(), db.Config{ + DSN: dsn, + MaxConns: 1, + StatementTimeout: 5 * time.Second, + IdleTxTimeout: 5 * time.Second, + RowCap: 10, + }) + if err != nil { + t.Fatalf("db.New: %v", err) + } + t.Cleanup(pool.Close) + registry, err := db.NewRegistry([]db.RegistryEntry{{ID: "primary", Name: "Primary", Pool: pool, Default: true}}) + if err != nil { + t.Fatalf("NewRegistry: %v", err) + } + adapter := NewDatabaseRegistry(registry) + + // When: server code calls adapter methods. + metadata := adapter.List() + querier, poolErr := adapter.Pool("primary") + pingErr := adapter.Ping(context.Background()) + + // Then: calls return the underlying registry data and pool. + if adapter.DefaultID() != "primary" || len(metadata) != 1 || metadata[0].Name != "Primary" { + t.Fatalf("default=%q metadata=%+v", adapter.DefaultID(), metadata) + } + if poolErr != nil || querier == nil { + t.Fatalf("Pool error=%v querier=%v", poolErr, querier) + } + if pingErr != nil { + t.Fatalf("Ping: %v", pingErr) + } +} diff --git a/internal/server/registry_unit_test.go b/internal/server/registry_unit_test.go new file mode 100644 index 0000000..7b7b516 --- /dev/null +++ b/internal/server/registry_unit_test.go @@ -0,0 +1,112 @@ +package server + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/descope-sample-apps/pgpeek/internal/db" +) + +func TestSingleDatabaseRegistry_methods_select_default_or_reject_unknown(t *testing.T) { + // Given: a single-database registry wraps one querier. + q := &fakeQuerier{} + registry := NewSingleDatabaseRegistry(q) + + // When: callers inspect metadata and select pools. + metadata := registry.List() + defaultPool, defaultErr := registry.Pool("") + namedPool, namedErr := registry.Pool("default") + _, missingErr := registry.Pool("missing") + pingErr := registry.Ping(context.Background()) + + // Then: only the default id is accepted and delegated. + if len(metadata) != 1 || metadata[0].ID != "default" || registry.DefaultID() != "default" { + t.Fatalf("metadata=%+v default=%q", metadata, registry.DefaultID()) + } + if defaultErr != nil || namedErr != nil || pingErr != nil { + t.Fatalf("defaultErr=%v namedErr=%v pingErr=%v", defaultErr, namedErr, pingErr) + } + if defaultPool != q || namedPool != q { + t.Fatalf("pool selection returned wrong querier") + } + if !errors.Is(missingErr, db.ErrPoolNotFound) { + t.Fatalf("missingErr=%v, want ErrPoolNotFound", missingErr) + } +} + +func TestPoolForRequest_returns_500_when_registry_selection_fails(t *testing.T) { + // Given: registry lookup fails with a non-not-found error. + srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) + srv.registry = failingRegistry{err: errors.New("registry unavailable")} + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/meta?db=primary", nil) + + // When: a handler selects a pool for the request. + pool, ok := srv.poolForRequest(rec, req) + + // Then: callers receive a sanitized 500 response and no pool. + if ok || pool != nil { + t.Fatalf("poolForRequest ok=%v pool=%v, want no pool", ok, pool) + } + if rec.Code != http.StatusInternalServerError || strings.Contains(rec.Body.String(), "registry unavailable") { + t.Fatalf("response code=%d body=%q", rec.Code, rec.Body.String()) + } +} + +func TestHandlers_return_not_found_when_selected_database_missing(t *testing.T) { + tests := []struct { + name string + method string + path string + body string + }{ + {name: "meta", method: http.MethodGet, path: "/api/meta?db=missing"}, + {name: "columns", method: http.MethodGet, path: "/api/tables/public/users/columns?db=missing"}, + {name: "foreign keys", method: http.MethodGet, path: "/api/tables/public/users/fks?db=missing"}, + {name: "table data", method: http.MethodGet, path: "/api/tables/public/users/data?db=missing"}, + {name: "query", method: http.MethodPost, path: "/api/query?db=missing", body: `{"sql":"SELECT 1"}`}, + {name: "export", method: http.MethodPost, path: "/api/export?db=missing", body: `{"sql":"SELECT 1"}`}, + {name: "ready", method: http.MethodGet, path: "/readyz?db=missing"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given: server has no database matching the request id. + ts := newRegistryTestServer(t, fakeRegistry{ + defaultID: "primary", + metadata: []db.PoolMetadata{{ID: "primary", Name: "Primary"}}, + pools: map[string]Querier{"primary": &selectedQuerier{rowCap: 1000}}, + }) + req, err := http.NewRequest(tt.method, ts.URL+tt.path, strings.NewReader(tt.body)) + if err != nil { + t.Fatal(err) + } + + // When: the request is served. + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // Then: the handler stops at database selection. + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("status=%d, want 404", resp.StatusCode) + } + }) + } +} + +type failingRegistry struct { + err error +} + +func (f failingRegistry) List() []db.PoolMetadata { return nil } +func (f failingRegistry) DefaultID() string { return "default" } +func (f failingRegistry) Pool(string) (Querier, error) { + return nil, f.err +} +func (f failingRegistry) Ping(context.Context) error { return f.err } diff --git a/internal/server/saved_query_errors_test.go b/internal/server/saved_query_errors_test.go new file mode 100644 index 0000000..aadea10 --- /dev/null +++ b/internal/server/saved_query_errors_test.go @@ -0,0 +1,149 @@ +package server + +import ( + "errors" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/descope-sample-apps/pgpeek/internal/db" +) + +func TestListEmptyReturnsArray(t *testing.T) { + // A successful List of an empty store returns nil; the handler must emit []. + srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) + req := httptest.NewRequest(http.MethodGet, "/api/queries", nil) + rec := httptest.NewRecorder() + srv.Routes().ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("code = %d", rec.Code) + } + if body := strings.TrimSpace(rec.Body.String()); body != "[]" { + t.Errorf("body = %q, want []", body) + } +} + +func TestUpdate_InvalidJSON(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + req, _ := http.NewRequest(http.MethodPut, ts.URL+"/api/queries/1", strings.NewReader(`{bad`)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestUpdate_NotFound(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + req, _ := http.NewRequest(http.MethodPut, ts.URL+"/api/queries/999", strings.NewReader(`{"name":"x","sql":"SELECT 1"}`)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Errorf("status = %d, want 404", resp.StatusCode) + } +} + +func TestDelete_BadID(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + req, _ := http.NewRequest(http.MethodDelete, ts.URL+"/api/queries/abc", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestExport_QueryError(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{err: errors.New("boom")}) + resp := post(t, ts, "/api/export", `{"sql":"SELECT 1"}`) + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestExport_InvalidJSON(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := post(t, ts, "/api/export", `{bad`) + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +func TestSavedQuery_InvalidJSON(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := post(t, ts, "/api/queries", `{bad`) + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } +} + +// failingWriter is an http.ResponseWriter whose body writes always fail. +type failingWriter struct { + header http.Header + code int +} + +func (f *failingWriter) Header() http.Header { + if f.header == nil { + f.header = http.Header{} + } + return f.header +} +func (f *failingWriter) Write([]byte) (int, error) { return 0, errors.New("connection reset") } +func (f *failingWriter) WriteHeader(code int) { f.code = code } + +func TestExport_WriteFailureLogged(t *testing.T) { + // Exercise the handler's csv-error branch directly with a writer that fails. + srv := serverWithStore(t, &fakeQuerier{result: okResult()}, &fakeStore{}) + req := httptest.NewRequest(http.MethodPost, "/api/export", strings.NewReader(`{"sql":"SELECT 1"}`)) + fw := &failingWriter{} + srv.handleExport(fw, req) + // Header is set before the failing body write. + if ct := fw.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { + t.Errorf("content-type = %q", ct) + } +} + +func TestWriteCSV(t *testing.T) { + res := &db.Result{Columns: []string{"a"}, Rows: [][]any{{"1"}}} + var buf strings.Builder + if err := writeCSV(&buf, res); err != nil { + t.Fatalf("writeCSV: %v", err) + } + if !strings.Contains(buf.String(), "a\n1\n") { + t.Errorf("csv = %q", buf.String()) + } + // Failing writer -> error surfaces via cw.Error() after Flush. + if err := writeCSV(failWriter{}, res); err == nil { + t.Error("expected error from failing writer") + } +} + +type failWriter struct{} + +func (failWriter) Write([]byte) (int, error) { return 0, errors.New("boom") } + +func mustGet(t *testing.T, ts *httptest.Server, path string) *http.Response { + t.Helper() + resp, err := http.Get(ts.URL + path) + if err != nil { + t.Fatalf("GET %s: %v", path, err) + } + return resp +} + +func itoa(i int64) string { return strconv.FormatInt(i, 10) } diff --git a/internal/server/saved_query_handlers.go b/internal/server/saved_query_handlers.go new file mode 100644 index 0000000..5891546 --- /dev/null +++ b/internal/server/saved_query_handlers.go @@ -0,0 +1,101 @@ +package server + +import ( + "errors" + "net/http" + "strings" + + "github.com/descope-sample-apps/pgpeek/internal/guard" + "github.com/descope-sample-apps/pgpeek/internal/store" +) + +func (s *Server) handleListQueries(w http.ResponseWriter, r *http.Request) { + qs, err := s.store.List(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to list saved queries") + s.log.Error("list saved queries", "err", err) + return + } + if qs == nil { + qs = []store.SavedQuery{} + } + writeJSON(w, http.StatusOK, qs) +} + +type savedQueryRequest struct { + Name string `json:"name"` + Description string `json:"description"` + SQL string `json:"sql"` +} + +func (s *Server) handleCreateQuery(w http.ResponseWriter, r *http.Request) { + req, ok := decodeSavedQuery(w, r) + if !ok { + return + } + q, err := s.store.Create(r.Context(), req.Name, req.Description, req.SQL) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to save query") + s.log.Error("create saved query", "err", err) + return + } + writeJSON(w, http.StatusCreated, q) +} + +func (s *Server) handleUpdateQuery(w http.ResponseWriter, r *http.Request) { + id, ok := pathID(w, r) + if !ok { + return + } + req, ok := decodeSavedQuery(w, r) + if !ok { + return + } + q, err := s.store.Update(r.Context(), id, req.Name, req.Description, req.SQL) + if errors.Is(err, store.ErrNotFound) { + writeError(w, http.StatusNotFound, "saved query not found") + return + } + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to update query") + s.log.Error("update saved query", "err", err) + return + } + writeJSON(w, http.StatusOK, q) +} + +func (s *Server) handleDeleteQuery(w http.ResponseWriter, r *http.Request) { + id, ok := pathID(w, r) + if !ok { + return + } + err := s.store.Delete(r.Context(), id) + if errors.Is(err, store.ErrNotFound) { + writeError(w, http.StatusNotFound, "saved query not found") + return + } + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to delete query") + s.log.Error("delete saved query", "err", err) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func decodeSavedQuery(w http.ResponseWriter, r *http.Request) (savedQueryRequest, bool) { + var req savedQueryRequest + if !decodeBody(w, r, &req) { + return req, false + } + req.Name = strings.TrimSpace(req.Name) + req.SQL = strings.TrimSpace(req.SQL) + if req.Name == "" || req.SQL == "" { + writeError(w, http.StatusBadRequest, "name and sql are required") + return req, false + } + if err := guard.Validate(req.SQL); err != nil { + writeError(w, http.StatusBadRequest, "saved query must be read-only: "+err.Error()) + return req, false + } + return req, true +} diff --git a/internal/server/saved_query_handlers_test.go b/internal/server/saved_query_handlers_test.go new file mode 100644 index 0000000..1296664 --- /dev/null +++ b/internal/server/saved_query_handlers_test.go @@ -0,0 +1,99 @@ +package server + +import ( + "net/http" + "strings" + "testing" + + "github.com/descope-sample-apps/pgpeek/internal/store" +) + +func TestSavedQueries_CRUD(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{result: okResult()}) + + // Create + resp := post(t, ts, "/api/queries", `{"name":"q","description":"d","sql":"SELECT 1"}`) + if resp.StatusCode != http.StatusCreated { + t.Fatalf("create status = %d", resp.StatusCode) + } + created := decode[store.SavedQuery](t, resp) + if created.ID == 0 { + t.Fatal("no id returned") + } + + // List + resp = mustGet(t, ts, "/api/queries") + list := decode[[]store.SavedQuery](t, resp) + if len(list) != 1 { + t.Fatalf("list len = %d", len(list)) + } + + // Update + req, _ := http.NewRequest(http.MethodPut, ts.URL+"/api/queries/"+itoa(created.ID), + strings.NewReader(`{"name":"q2","sql":"SELECT 2"}`)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("update status = %d", resp.StatusCode) + } + resp.Body.Close() + + // Delete + req, _ = http.NewRequest(http.MethodDelete, ts.URL+"/api/queries/"+itoa(created.ID), nil) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("delete status = %d", resp.StatusCode) + } + resp.Body.Close() +} + +func TestSavedQueries_Validation(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + + cases := []struct { + name, body string + }{ + {"missing name", `{"sql":"SELECT 1"}`}, + {"missing sql", `{"name":"x"}`}, + {"non-readonly sql", `{"name":"x","sql":"DELETE FROM t"}`}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + resp := post(t, ts, "/api/queries", c.body) + resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("status = %d, want 400", resp.StatusCode) + } + }) + } +} + +func TestSavedQueries_NotFoundAndBadID(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + + req, _ := http.NewRequest(http.MethodDelete, ts.URL+"/api/queries/999", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusNotFound { + t.Errorf("delete missing status = %d, want 404", resp.StatusCode) + } + resp.Body.Close() + + req, _ = http.NewRequest(http.MethodPut, ts.URL+"/api/queries/abc", + strings.NewReader(`{"name":"x","sql":"SELECT 1"}`)) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("bad id status = %d, want 400", resp.StatusCode) + } + resp.Body.Close() +} diff --git a/internal/server/server.go b/internal/server/server.go index fadc774..7e5f33f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,51 +3,15 @@ package server import ( - "context" - "encoding/csv" - "encoding/json" - "errors" - "io" "io/fs" "log/slog" "net/http" - "strconv" - "strings" "time" - - "github.com/descope-sample-apps/pgpeek/internal/db" - "github.com/descope-sample-apps/pgpeek/internal/guard" - "github.com/descope-sample-apps/pgpeek/internal/store" ) -// maxBodyBytes caps request bodies. Queries are SQL text, not data, so 1 MiB is -// generous while preventing a client from forcing unbounded buffering. -const maxBodyBytes = 1 << 20 - -// Querier runs read-only queries, browses the catalog, and reports database -// health. *db.Pool implements it; tests substitute a fake. -type Querier interface { - Query(ctx context.Context, sql string) (*db.Result, error) - Tables(ctx context.Context) ([]db.TableInfo, error) - Columns(ctx context.Context, schema, table string) ([]db.ColumnInfo, error) - ForeignKeys(ctx context.Context, schema, table string) ([]db.ForeignKey, error) - TableRows(ctx context.Context, q db.TableQuery) (*db.Result, error) - RowCap() int - Ping(ctx context.Context) error -} - -// QueryStore persists saved/preset queries. *store.Store implements it. -type QueryStore interface { - List(ctx context.Context) ([]store.SavedQuery, error) - Get(ctx context.Context, id int64) (store.SavedQuery, error) - Create(ctx context.Context, name, desc, sql string) (store.SavedQuery, error) - Update(ctx context.Context, id int64, name, desc, sql string) (store.SavedQuery, error) - Delete(ctx context.Context, id int64) error -} - // Server holds the dependencies for the HTTP handlers. type Server struct { - pool Querier + registry DatabaseRegistry store QueryStore web fs.FS log *slog.Logger @@ -56,7 +20,11 @@ type Server struct { // New constructs a Server. func New(pool Querier, st QueryStore, web fs.FS, log *slog.Logger, queryWait time.Duration) *Server { - return &Server{pool: pool, store: st, web: web, log: log, queryWait: queryWait} + return NewWithRegistry(NewSingleDatabaseRegistry(pool), st, web, log, queryWait) +} + +func NewWithRegistry(registry DatabaseRegistry, st QueryStore, web fs.FS, log *slog.Logger, queryWait time.Duration) *Server { + return &Server{registry: registry, store: st, web: web, log: log, queryWait: queryWait} } // Routes returns the configured handler. @@ -71,6 +39,7 @@ func (s *Server) Routes() http.Handler { mux.HandleFunc("GET /readyz", s.handleReady) // API + mux.HandleFunc("GET /api/databases", s.handleDatabases) mux.HandleFunc("GET /api/meta", s.handleMeta) mux.HandleFunc("POST /api/query", s.handleQuery) mux.HandleFunc("POST /api/export", s.handleExport) @@ -88,419 +57,3 @@ func (s *Server) Routes() http.Handler { return securityHeaders(logging(s.log, mux)) } - -type queryRequest struct { - SQL string `json:"sql"` -} - -func (s *Server) handleQuery(w http.ResponseWriter, r *http.Request) { - var req queryRequest - if !decodeBody(w, r, &req) { - return - } - sql := strings.TrimSpace(req.SQL) - if err := guard.Validate(sql); err != nil { - writeError(w, http.StatusBadRequest, err.Error()) - return - } - - ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) - defer cancel() - - res, err := s.pool.Query(ctx, sql) - if err != nil { - writeError(w, http.StatusBadRequest, "query failed: "+err.Error()) - return - } - writeJSON(w, http.StatusOK, res) -} - -func (s *Server) handleExport(w http.ResponseWriter, r *http.Request) { - var req queryRequest - if !decodeBody(w, r, &req) { - return - } - sql := strings.TrimSpace(req.SQL) - if err := guard.Validate(sql); err != nil { - writeError(w, http.StatusBadRequest, err.Error()) - return - } - - ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) - defer cancel() - - res, err := s.pool.Query(ctx, sql) - if err != nil { - writeError(w, http.StatusBadRequest, "query failed: "+err.Error()) - return - } - - w.Header().Set("Content-Type", "text/csv; charset=utf-8") - w.Header().Set("Content-Disposition", `attachment; filename="pgpeek-export.csv"`) - if err := writeCSV(w, res); err != nil { - // Headers/partial body may already be sent; just log. - s.log.Error("csv export", "err", err) - } -} - -// writeCSV streams the result as CSV. encoding/csv accumulates errors stickily, -// so it's sufficient to check Error() once after Flush. -func writeCSV(w io.Writer, res *db.Result) error { - cw := csv.NewWriter(w) - _ = cw.Write(res.Columns) - row := make([]string, len(res.Columns)) - for _, rec := range res.Rows { - for i, cell := range rec { - row[i] = db.CellString(cell) - } - _ = cw.Write(row) - } - cw.Flush() - return cw.Error() -} - -// handleMeta exposes server-side limits the UI needs (notably the row cap, so -// the client can size its page and pagination correctly). -func (s *Server) handleMeta(w http.ResponseWriter, _ *http.Request) { - writeJSON(w, http.StatusOK, map[string]int{"rowCap": s.pool.RowCap()}) -} - -func (s *Server) handleTables(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) - defer cancel() - tables, err := s.pool.Tables(ctx) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to list tables") - s.log.Error("list tables", "err", err) - return - } - if tables == nil { - tables = []db.TableInfo{} - } - writeJSON(w, http.StatusOK, tables) -} - -func (s *Server) handleColumns(w http.ResponseWriter, r *http.Request) { - if rejectRestrictedRelation(w, r.PathValue("table")) { - return - } - ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) - defer cancel() - cols, err := s.pool.Columns(ctx, r.PathValue("schema"), r.PathValue("table")) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to read columns") - s.log.Error("read columns", "err", err) - return - } - if cols == nil { - cols = []db.ColumnInfo{} - } - writeJSON(w, http.StatusOK, cols) -} - -func (s *Server) handleForeignKeys(w http.ResponseWriter, r *http.Request) { - if rejectRestrictedRelation(w, r.PathValue("table")) { - return - } - ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) - defer cancel() - fks, err := s.pool.ForeignKeys(ctx, r.PathValue("schema"), r.PathValue("table")) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to read foreign keys") - s.log.Error("read foreign keys", "err", err) - return - } - if fks == nil { - fks = []db.ForeignKey{} - } - writeJSON(w, http.StatusOK, fks) -} - -func (s *Server) handleTableData(w http.ResponseWriter, r *http.Request) { - if rejectRestrictedRelation(w, r.PathValue("table")) { - return - } - q := db.TableQuery{ - Schema: r.PathValue("schema"), - Table: r.PathValue("table"), - Search: r.URL.Query().Get("search"), - Sort: r.URL.Query().Get("sort"), - Desc: r.URL.Query().Get("dir") == "desc", - Limit: queryInt(r, "limit", 0), // 0 -> pool clamps to row cap - Offset: queryInt(r, "offset", 0), // negative -> pool clamps to 0 - Filters: parseFilters(r.URL.Query()["f"]), - } - - ctx, cancel := context.WithTimeout(r.Context(), s.queryWait) - defer cancel() - res, err := s.pool.TableRows(ctx, q) - if err != nil { - writeError(w, http.StatusBadRequest, "failed to read rows") - s.log.Error("read rows", "err", err) - return - } - - if r.URL.Query().Get("format") == "csv" { - w.Header().Set("Content-Type", "text/csv; charset=utf-8") - w.Header().Set("Content-Disposition", `attachment; filename="`+safeFilename(r.PathValue("table"))+`.csv"`) - if err := writeCSV(w, res); err != nil { - s.log.Error("csv export", "err", err) - } - return - } - writeJSON(w, http.StatusOK, res) -} - -func (s *Server) handleListQueries(w http.ResponseWriter, r *http.Request) { - qs, err := s.store.List(r.Context()) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to list saved queries") - s.log.Error("list saved queries", "err", err) - return - } - if qs == nil { - qs = []store.SavedQuery{} - } - writeJSON(w, http.StatusOK, qs) -} - -type savedQueryRequest struct { - Name string `json:"name"` - Description string `json:"description"` - SQL string `json:"sql"` -} - -func (s *Server) handleCreateQuery(w http.ResponseWriter, r *http.Request) { - req, ok := decodeSavedQuery(w, r) - if !ok { - return - } - q, err := s.store.Create(r.Context(), req.Name, req.Description, req.SQL) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to save query") - s.log.Error("create saved query", "err", err) - return - } - writeJSON(w, http.StatusCreated, q) -} - -func (s *Server) handleUpdateQuery(w http.ResponseWriter, r *http.Request) { - id, ok := pathID(w, r) - if !ok { - return - } - req, ok := decodeSavedQuery(w, r) - if !ok { - return - } - q, err := s.store.Update(r.Context(), id, req.Name, req.Description, req.SQL) - if errors.Is(err, store.ErrNotFound) { - writeError(w, http.StatusNotFound, "saved query not found") - return - } - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to update query") - s.log.Error("update saved query", "err", err) - return - } - writeJSON(w, http.StatusOK, q) -} - -func (s *Server) handleDeleteQuery(w http.ResponseWriter, r *http.Request) { - id, ok := pathID(w, r) - if !ok { - return - } - err := s.store.Delete(r.Context(), id) - if errors.Is(err, store.ErrNotFound) { - writeError(w, http.StatusNotFound, "saved query not found") - return - } - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to delete query") - s.log.Error("delete saved query", "err", err) - return - } - w.WriteHeader(http.StatusNoContent) -} - -func (s *Server) handleReady(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - defer cancel() - if err := s.pool.Ping(ctx); err != nil { - writeError(w, http.StatusServiceUnavailable, "database not ready") - return - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ready")) -} - -// --- helpers --- - -func decodeSavedQuery(w http.ResponseWriter, r *http.Request) (savedQueryRequest, bool) { - var req savedQueryRequest - if !decodeBody(w, r, &req) { - return req, false - } - req.Name = strings.TrimSpace(req.Name) - req.SQL = strings.TrimSpace(req.SQL) - if req.Name == "" || req.SQL == "" { - writeError(w, http.StatusBadRequest, "name and sql are required") - return req, false - } - if err := guard.Validate(req.SQL); err != nil { - writeError(w, http.StatusBadRequest, "saved query must be read-only: "+err.Error()) - return req, false - } - return req, true -} - -// decodeBody reads a size-capped JSON body into v, rejecting unknown fields. It -// writes the error response itself and returns false on failure. -func decodeBody(w http.ResponseWriter, r *http.Request, v any) bool { - r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes) - dec := json.NewDecoder(r.Body) - dec.DisallowUnknownFields() - if err := dec.Decode(v); err != nil { - writeError(w, http.StatusBadRequest, "invalid JSON body: "+err.Error()) - return false - } - return true -} - -// safeFilename keeps only filename-safe characters, so a table name can't break -// out of the quoted Content-Disposition value or inject a different extension. -func safeFilename(name string) string { - out := strings.Map(func(r rune) rune { - switch { - case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', - r == '_', r == '-', r == '.': - return r - default: - return '_' - } - }, name) - if out == "" { - return "table" - } - return out -} - -// parseFilters turns repeated "f" params of the form "column:op[:value]" into -// db.Filter values. Column/op/value are validated downstream by db.TableRows. -func parseFilters(raw []string) []db.Filter { - if len(raw) == 0 { - return nil - } - out := make([]db.Filter, 0, len(raw)) - for _, s := range raw { - parts := strings.SplitN(s, ":", 3) - f := db.Filter{Column: parts[0]} - if len(parts) >= 2 { - f.Op = parts[1] - } - if len(parts) == 3 { - f.Value = parts[2] - } - out = append(out, f) - } - return out -} - -// queryInt parses a query-string integer, falling back to def on absence or -// parse error (the db layer clamps the actual bounds). -func queryInt(r *http.Request, key string, def int) int { - if v := r.URL.Query().Get(key); v != "" { - if n, err := strconv.Atoi(v); err == nil { - return n - } - } - return def -} - -func pathID(w http.ResponseWriter, r *http.Request) (int64, bool) { - id, err := strconv.ParseInt(r.PathValue("id"), 10, 64) - if err != nil { - writeError(w, http.StatusBadRequest, "invalid id") - return 0, false - } - return id, true -} - -func writeJSON(w http.ResponseWriter, status int, v any) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(v) -} - -func writeError(w http.ResponseWriter, status int, msg string) { - writeJSON(w, status, map[string]string{"error": msg}) -} - -func rejectRestrictedRelation(w http.ResponseWriter, table string) bool { - if !guard.IsRestrictedRelation(table) { - return false - } - writeError(w, http.StatusBadRequest, "restricted system catalog") - return true -} - -// logging is a minimal request logger. It never logs request bodies (which -// contain SQL) at info level beyond method/path. -func logging(log *slog.Logger, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - sw := &statusWriter{ResponseWriter: w, status: http.StatusOK} - next.ServeHTTP(sw, r) - if r.URL.Path == "/healthz" || r.URL.Path == "/readyz" { - return // don't spam probe logs - } - log.Info("request", - "method", r.Method, - "path", r.URL.Path, - "status", sw.status, - "ms", time.Since(start).Milliseconds(), - ) - }) -} - -type statusWriter struct { - http.ResponseWriter - status int -} - -func (w *statusWriter) WriteHeader(code int) { - w.status = code - w.ResponseWriter.WriteHeader(code) -} - -// contentSecurityPolicy allows only the app's own assets. CodeMirror 6 is -// vendored and served from /vendor (no third-party CDN). The page script lives -// in /app.js (no 'unsafe-inline' for scripts); styles permit 'unsafe-inline' -// because CodeMirror injects them at runtime. -const contentSecurityPolicy = "default-src 'self'; " + - "script-src 'self'; " + - "style-src 'self' 'unsafe-inline'; " + - "img-src 'self' data:; " + - "connect-src 'self'; " + - "base-uri 'self'; " + - "form-action 'self'; " + - "frame-ancestors 'none'" - -// securityHeaders sets conservative defaults on every response. -func securityHeaders(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - h := w.Header() - h.Set("Content-Security-Policy", contentSecurityPolicy) - h.Set("X-Content-Type-Options", "nosniff") - h.Set("X-Frame-Options", "DENY") - h.Set("Referrer-Policy", "no-referrer") - h.Set("Cross-Origin-Opener-Policy", "same-origin") - // Advertise HSTS only on connections that actually reached us over TLS - // (direct TLS or via a TLS-terminating proxy that sets X-Forwarded-Proto). - if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { - h.Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains") - } - next.ServeHTTP(w, r) - }) -} diff --git a/internal/server/server_routes_test.go b/internal/server/server_routes_test.go new file mode 100644 index 0000000..760b673 --- /dev/null +++ b/internal/server/server_routes_test.go @@ -0,0 +1,80 @@ +package server + +import ( + "bytes" + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/descope-sample-apps/pgpeek/internal/guard" + "github.com/descope-sample-apps/pgpeek/internal/store" +) + +func TestHealthz(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := mustGet(t, ts, "/healthz") + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("healthz = %d", resp.StatusCode) + } +} + +func TestReadyz(t *testing.T) { + t.Run("healthy", func(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := mustGet(t, ts, "/readyz") + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("readyz = %d, want 200", resp.StatusCode) + } + }) + t.Run("db down", func(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{pingErr: errors.New("down")}) + resp := mustGet(t, ts, "/readyz") + resp.Body.Close() + if resp.StatusCode != http.StatusServiceUnavailable { + t.Errorf("readyz = %d, want 503", resp.StatusCode) + } + }) +} + +func TestSecurityHeaders(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := mustGet(t, ts, "/healthz") + resp.Body.Close() + want := map[string]string{ + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "Referrer-Policy": "no-referrer", + } + for k, v := range want { + if got := resp.Header.Get(k); got != v { + t.Errorf("%s = %q, want %q", k, got, v) + } + } + if csp := resp.Header.Get("Content-Security-Policy"); !strings.Contains(csp, "default-src 'self'") { + t.Errorf("missing CSP: %q", csp) + } +} + +func TestUIServed(t *testing.T) { + ts, _ := newTestServer(t, &fakeQuerier{}) + resp := mustGet(t, ts, "/") + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if !bytes.Contains(body, []byte("pgpeek")) { + t.Errorf("index not served: %q", body) + } +} + +// DefaultPresets must all survive the read-only guard, otherwise the saved-query +// validation would reject them if a user re-saved one. +func TestDefaultPresetsPassGuard(t *testing.T) { + for _, p := range store.DefaultPresets { + if err := guard.Validate(p.SQL); err != nil { + t.Errorf("preset %q fails guard: %v", p.Name, err) + } + } +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go deleted file mode 100644 index 806565a..0000000 --- a/internal/server/server_test.go +++ /dev/null @@ -1,837 +0,0 @@ -package server - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "io" - "log/slog" - "net/http" - "net/http/httptest" - "path/filepath" - "strconv" - "strings" - "testing" - "testing/fstest" - "time" - - "github.com/descope-sample-apps/pgpeek/internal/db" - "github.com/descope-sample-apps/pgpeek/internal/guard" - "github.com/descope-sample-apps/pgpeek/internal/store" -) - -type fakeQuerier struct { - result *db.Result - err error - pingErr error - called bool - tableRowsCalled bool - lastSQL string - tables []db.TableInfo - cols []db.ColumnInfo - fks []db.ForeignKey - catErr error - lastQuery db.TableQuery - lastArgs struct { - schema, table string - limit, offset int - } -} - -func (f *fakeQuerier) Query(_ context.Context, sql string) (*db.Result, error) { - f.called = true - f.lastSQL = sql - return f.result, f.err -} - -func (f *fakeQuerier) Tables(context.Context) ([]db.TableInfo, error) { - return f.tables, f.catErr -} - -func (f *fakeQuerier) Columns(_ context.Context, schema, table string) ([]db.ColumnInfo, error) { - f.lastArgs.schema, f.lastArgs.table = schema, table - return f.cols, f.catErr -} - -func (f *fakeQuerier) ForeignKeys(_ context.Context, _, _ string) ([]db.ForeignKey, error) { - return f.fks, f.catErr -} - -func (f *fakeQuerier) TableRows(_ context.Context, q db.TableQuery) (*db.Result, error) { - f.tableRowsCalled = true - f.lastQuery = q - f.lastArgs.schema, f.lastArgs.table = q.Schema, q.Table - f.lastArgs.limit, f.lastArgs.offset = q.Limit, q.Offset - return f.result, f.err -} - -func (f *fakeQuerier) RowCap() int { return 1000 } - -func (f *fakeQuerier) Ping(_ context.Context) error { return f.pingErr } - -func newTestServer(t *testing.T, q Querier) (*httptest.Server, *store.Store) { - t.Helper() - st, err := store.Open(filepath.Join(t.TempDir(), "t.db")) - if err != nil { - t.Fatalf("store.Open: %v", err) - } - t.Cleanup(func() { _ = st.Close() }) - - web := fstest.MapFS{ - "index.html": &fstest.MapFile{Data: []byte("pgpeek")}, - "app.js": &fstest.MapFile{Data: []byte("// js")}, - } - log := slog.New(slog.NewTextHandler(io.Discard, nil)) - srv := New(q, st, web, log, 5*time.Second) - ts := httptest.NewServer(srv.Routes()) - t.Cleanup(ts.Close) - return ts, st -} - -func post(t *testing.T, ts *httptest.Server, path, body string) *http.Response { - t.Helper() - resp, err := http.Post(ts.URL+path, "application/json", strings.NewReader(body)) - if err != nil { - t.Fatalf("POST %s: %v", path, err) - } - return resp -} - -func decode[T any](t *testing.T, resp *http.Response) T { - t.Helper() - defer resp.Body.Close() - var v T - if err := json.NewDecoder(resp.Body).Decode(&v); err != nil { - t.Fatalf("decode: %v", err) - } - return v -} - -func okResult() *db.Result { - return &db.Result{ - Columns: []string{"n"}, - Rows: [][]any{{int64(1)}}, - RowCount: 1, - ElapsedMS: 1, - } -} - -func TestQuery_OK(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - - resp := post(t, ts, "/api/query", `{"sql":" SELECT 1 "}`) - if resp.StatusCode != http.StatusOK { - t.Fatalf("status = %d", resp.StatusCode) - } - res := decode[db.Result](t, resp) - if res.RowCount != 1 { - t.Errorf("rowCount = %d", res.RowCount) - } - if q.lastSQL != "SELECT 1" { - t.Errorf("SQL not trimmed before exec: %q", q.lastSQL) - } -} - -func TestQuery_GuardRejectsDML(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - - resp := post(t, ts, "/api/query", `{"sql":"DELETE FROM users"}`) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("status = %d, want 400", resp.StatusCode) - } - if q.called { - t.Error("guard should block the query before it reaches the database") - } -} - -func TestQuery_InvalidJSON(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := post(t, ts, "/api/query", `{not json`) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("status = %d, want 400", resp.StatusCode) - } -} - -func TestQuery_UnknownField(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := post(t, ts, "/api/query", `{"sql":"SELECT 1","evil":true}`) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("status = %d, want 400 (DisallowUnknownFields)", resp.StatusCode) - } -} - -func TestQuery_DBError(t *testing.T) { - q := &fakeQuerier{err: errors.New("boom")} - ts, _ := newTestServer(t, q) - resp := post(t, ts, "/api/query", `{"sql":"SELECT 1"}`) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("status = %d, want 400", resp.StatusCode) - } -} - -func TestQuery_BodyTooLarge(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{result: okResult()}) - huge := strings.Repeat("a", (1<<20)+10) - resp := post(t, ts, "/api/query", `{"sql":"`+huge+`"}`) - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("status = %d, want 400 for oversized body", resp.StatusCode) - } -} - -func TestExport_CSV(t *testing.T) { - q := &fakeQuerier{result: &db.Result{ - Columns: []string{"name", "n"}, - Rows: [][]any{{"Acme", int64(2)}, {"Globex,Inc", int64(1)}}, - RowCount: 2, - }} - ts, _ := newTestServer(t, q) - - resp := post(t, ts, "/api/export", `{"sql":"SELECT 1"}`) - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Fatalf("status = %d", resp.StatusCode) - } - if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { - t.Errorf("content-type = %q", ct) - } - if cd := resp.Header.Get("Content-Disposition"); !strings.Contains(cd, "pgpeek-export.csv") { - t.Errorf("content-disposition = %q", cd) - } - body, _ := io.ReadAll(resp.Body) - got := string(body) - if !strings.Contains(got, "name,n") || !strings.Contains(got, "Acme,2") { - t.Errorf("csv body = %q", got) - } - // Field with a comma must be quoted by encoding/csv. - if !strings.Contains(got, `"Globex,Inc"`) { - t.Errorf("comma field not quoted: %q", got) - } -} - -func TestExport_GuardRejects(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - resp := post(t, ts, "/api/export", `{"sql":"DROP TABLE x"}`) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("status = %d, want 400", resp.StatusCode) - } -} - -func TestSavedQueries_CRUD(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{result: okResult()}) - - // Create - resp := post(t, ts, "/api/queries", `{"name":"q","description":"d","sql":"SELECT 1"}`) - if resp.StatusCode != http.StatusCreated { - t.Fatalf("create status = %d", resp.StatusCode) - } - created := decode[store.SavedQuery](t, resp) - if created.ID == 0 { - t.Fatal("no id returned") - } - - // List - resp = mustGet(t, ts, "/api/queries") - list := decode[[]store.SavedQuery](t, resp) - if len(list) != 1 { - t.Fatalf("list len = %d", len(list)) - } - - // Update - req, _ := http.NewRequest(http.MethodPut, ts.URL+"/api/queries/"+itoa(created.ID), - strings.NewReader(`{"name":"q2","sql":"SELECT 2"}`)) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != http.StatusOK { - t.Fatalf("update status = %d", resp.StatusCode) - } - resp.Body.Close() - - // Delete - req, _ = http.NewRequest(http.MethodDelete, ts.URL+"/api/queries/"+itoa(created.ID), nil) - resp, err = http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != http.StatusNoContent { - t.Fatalf("delete status = %d", resp.StatusCode) - } - resp.Body.Close() -} - -func TestSavedQueries_Validation(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - - cases := []struct { - name, body string - }{ - {"missing name", `{"sql":"SELECT 1"}`}, - {"missing sql", `{"name":"x"}`}, - {"non-readonly sql", `{"name":"x","sql":"DELETE FROM t"}`}, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - resp := post(t, ts, "/api/queries", c.body) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } - }) - } -} - -func TestSavedQueries_NotFoundAndBadID(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - - req, _ := http.NewRequest(http.MethodDelete, ts.URL+"/api/queries/999", nil) - resp, _ := http.DefaultClient.Do(req) - if resp.StatusCode != http.StatusNotFound { - t.Errorf("delete missing status = %d, want 404", resp.StatusCode) - } - resp.Body.Close() - - req, _ = http.NewRequest(http.MethodPut, ts.URL+"/api/queries/abc", - strings.NewReader(`{"name":"x","sql":"SELECT 1"}`)) - resp, _ = http.DefaultClient.Do(req) - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("bad id status = %d, want 400", resp.StatusCode) - } - resp.Body.Close() -} - -func TestHealthz(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := mustGet(t, ts, "/healthz") - resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Errorf("healthz = %d", resp.StatusCode) - } -} - -func TestReadyz(t *testing.T) { - t.Run("healthy", func(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := mustGet(t, ts, "/readyz") - resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Errorf("readyz = %d, want 200", resp.StatusCode) - } - }) - t.Run("db down", func(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{pingErr: errors.New("down")}) - resp := mustGet(t, ts, "/readyz") - resp.Body.Close() - if resp.StatusCode != http.StatusServiceUnavailable { - t.Errorf("readyz = %d, want 503", resp.StatusCode) - } - }) -} - -func TestSecurityHeaders(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := mustGet(t, ts, "/healthz") - resp.Body.Close() - want := map[string]string{ - "X-Content-Type-Options": "nosniff", - "X-Frame-Options": "DENY", - "Referrer-Policy": "no-referrer", - } - for k, v := range want { - if got := resp.Header.Get(k); got != v { - t.Errorf("%s = %q, want %q", k, got, v) - } - } - if csp := resp.Header.Get("Content-Security-Policy"); !strings.Contains(csp, "default-src 'self'") { - t.Errorf("missing CSP: %q", csp) - } -} - -func TestHSTS(t *testing.T) { - srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) - - // No TLS signal -> no HSTS header. - rec := httptest.NewRecorder() - srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/healthz", nil)) - if h := rec.Header().Get("Strict-Transport-Security"); h != "" { - t.Errorf("HSTS set on plaintext request: %q", h) - } - - // X-Forwarded-Proto: https (TLS-terminating proxy) -> HSTS present. - req := httptest.NewRequest(http.MethodGet, "/healthz", nil) - req.Header.Set("X-Forwarded-Proto", "https") - rec = httptest.NewRecorder() - srv.Routes().ServeHTTP(rec, req) - if h := rec.Header().Get("Strict-Transport-Security"); h != "max-age=63072000; includeSubDomains" { - t.Errorf("HSTS = %q, want max-age=63072000; includeSubDomains", h) - } -} - -func TestUIServed(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := mustGet(t, ts, "/") - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - if !bytes.Contains(body, []byte("pgpeek")) { - t.Errorf("index not served: %q", body) - } -} - -// DefaultPresets must all survive the read-only guard, otherwise the saved-query -// validation would reject them if a user re-saved one. -func TestDefaultPresetsPassGuard(t *testing.T) { - for _, p := range store.DefaultPresets { - if err := guard.Validate(p.SQL); err != nil { - t.Errorf("preset %q fails guard: %v", p.Name, err) - } - } -} - -// --- error-path coverage with a fake store and failing writer ------------ - -type fakeStore struct { - listErr, getErr, createErr, updateErr, deleteErr error -} - -func (f *fakeStore) List(context.Context) ([]store.SavedQuery, error) { - return nil, f.listErr -} -func (f *fakeStore) Get(context.Context, int64) (store.SavedQuery, error) { - return store.SavedQuery{}, f.getErr -} -func (f *fakeStore) Create(context.Context, string, string, string) (store.SavedQuery, error) { - return store.SavedQuery{}, f.createErr -} -func (f *fakeStore) Update(context.Context, int64, string, string, string) (store.SavedQuery, error) { - return store.SavedQuery{}, f.updateErr -} -func (f *fakeStore) Delete(context.Context, int64) error { return f.deleteErr } - -func serverWithStore(t *testing.T, q Querier, st QueryStore) *Server { - t.Helper() - web := fstest.MapFS{"index.html": &fstest.MapFile{Data: []byte("x")}} - log := slog.New(slog.NewTextHandler(io.Discard, nil)) - return New(q, st, web, log, time.Second) -} - -func TestStoreErrorPaths(t *testing.T) { - boom := errors.New("db down") - srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{ - listErr: boom, createErr: boom, updateErr: boom, deleteErr: boom, - }) - mux := srv.Routes() - - do := func(method, path, body string) int { - var rdr io.Reader - if body != "" { - rdr = strings.NewReader(body) - } - req := httptest.NewRequest(method, path, rdr) - rec := httptest.NewRecorder() - mux.ServeHTTP(rec, req) - return rec.Code - } - - if code := do(http.MethodGet, "/api/queries", ""); code != http.StatusInternalServerError { - t.Errorf("list error code = %d, want 500", code) - } - if code := do(http.MethodPost, "/api/queries", `{"name":"x","sql":"SELECT 1"}`); code != http.StatusInternalServerError { - t.Errorf("create error code = %d, want 500", code) - } - if code := do(http.MethodPut, "/api/queries/1", `{"name":"x","sql":"SELECT 1"}`); code != http.StatusInternalServerError { - t.Errorf("update error code = %d, want 500", code) - } - if code := do(http.MethodDelete, "/api/queries/1", ""); code != http.StatusInternalServerError { - t.Errorf("delete error code = %d, want 500", code) - } -} - -// --- catalog / browse endpoints ----------------------------------------- - -func TestMeta(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := mustGet(t, ts, "/api/meta") - got := decode[map[string]int](t, resp) - if got["rowCap"] != 1000 { - t.Errorf("rowCap = %d, want 1000", got["rowCap"]) - } -} - -func TestSafeFilename(t *testing.T) { - cases := map[string]string{ - "users": "users", - `a"; x`: "a___x", - "sch.tbl": "sch.tbl", - "évil/../x": "_vil_.._x", - "": "table", - "weird name!@#$": "weird_name____", - } - for in, want := range cases { - if got := safeFilename(in); got != want { - t.Errorf("safeFilename(%q) = %q, want %q", in, got, want) - } - } -} - -func TestTables_OK(t *testing.T) { - q := &fakeQuerier{tables: []db.TableInfo{{Schema: "public", Name: "users", Type: "table", EstRows: 5}}} - ts, _ := newTestServer(t, q) - resp := mustGet(t, ts, "/api/tables") - if resp.StatusCode != http.StatusOK { - t.Fatalf("status = %d", resp.StatusCode) - } - got := decode[[]db.TableInfo](t, resp) - if len(got) != 1 || got[0].Name != "users" { - t.Errorf("tables = %+v", got) - } -} - -func TestTables_EmptyReturnsArray(t *testing.T) { - srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) - rec := httptest.NewRecorder() - srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/tables", nil)) - if strings.TrimSpace(rec.Body.String()) != "[]" { - t.Errorf("body = %q, want []", rec.Body.String()) - } -} - -func TestTables_Error(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{catErr: errors.New("boom")}) - resp := mustGet(t, ts, "/api/tables") - resp.Body.Close() - if resp.StatusCode != http.StatusInternalServerError { - t.Errorf("status = %d, want 500", resp.StatusCode) - } -} - -func TestColumns_OK(t *testing.T) { - q := &fakeQuerier{cols: []db.ColumnInfo{{Name: "id", Type: "integer"}}} - ts, _ := newTestServer(t, q) - resp := mustGet(t, ts, "/api/tables/public/users/columns") - got := decode[[]db.ColumnInfo](t, resp) - if len(got) != 1 || got[0].Name != "id" { - t.Errorf("columns = %+v", got) - } - if q.lastArgs.schema != "public" || q.lastArgs.table != "users" { - t.Errorf("path values not passed: %+v", q.lastArgs) - } -} - -func TestColumns_EmptyReturnsArray(t *testing.T) { - srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) - rec := httptest.NewRecorder() - srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/tables/public/users/columns", nil)) - if strings.TrimSpace(rec.Body.String()) != "[]" { - t.Errorf("body = %q, want []", rec.Body.String()) - } -} - -func TestColumns_Error(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{catErr: errors.New("boom")}) - resp := mustGet(t, ts, "/api/tables/public/users/columns") - resp.Body.Close() - if resp.StatusCode != http.StatusInternalServerError { - t.Errorf("status = %d, want 500", resp.StatusCode) - } -} - -func TestForeignKeys_OK(t *testing.T) { - q := &fakeQuerier{fks: []db.ForeignKey{{Column: "company_id", RefSchema: "public", RefTable: "companies", RefColumn: "id"}}} - ts, _ := newTestServer(t, q) - resp := mustGet(t, ts, "/api/tables/public/users/fks") - got := decode[[]db.ForeignKey](t, resp) - if len(got) != 1 || got[0].RefTable != "companies" { - t.Errorf("fks = %+v", got) - } -} - -func TestForeignKeys_EmptyReturnsArray(t *testing.T) { - srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) - rec := httptest.NewRecorder() - srv.Routes().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/tables/public/users/fks", nil)) - if strings.TrimSpace(rec.Body.String()) != "[]" { - t.Errorf("body = %q, want []", rec.Body.String()) - } -} - -func TestForeignKeys_Error(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{catErr: errors.New("boom")}) - resp := mustGet(t, ts, "/api/tables/public/users/fks") - resp.Body.Close() - if resp.StatusCode != http.StatusInternalServerError { - t.Errorf("status = %d, want 500", resp.StatusCode) - } -} - -func TestCatalogEndpoints_RejectRestrictedCatalogs(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - - paths := []string{ - "/api/tables/pg_catalog/pg_authid/columns", - "/api/tables/pg_catalog/pg_shadow/fks", - "/api/tables/pg_catalog/pg_hba_file_rules/data", - } - for _, path := range paths { - resp := mustGet(t, ts, path) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("%s status = %d, want 400", path, resp.StatusCode) - } - } - if q.tableRowsCalled { - t.Error("restricted catalog data request reached TableRows") - } -} - -func TestTableData_OK(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - resp := mustGet(t, ts, "/api/tables/public/users/data?limit=25&offset=50") - if resp.StatusCode != http.StatusOK { - t.Fatalf("status = %d", resp.StatusCode) - } - decode[db.Result](t, resp) - if q.lastArgs.limit != 25 || q.lastArgs.offset != 50 { - t.Errorf("limit/offset = %d/%d", q.lastArgs.limit, q.lastArgs.offset) - } -} - -func TestTableData_ParsesSearchSortFilters(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - resp := mustGet(t, ts, "/api/tables/public/users/data?search=acme&sort=id&dir=desc&f=id:gt:100&f=name:ilike:%25a%25&f=deleted_at:is_null") - resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Fatalf("status = %d", resp.StatusCode) - } - lq := q.lastQuery - if lq.Search != "acme" || lq.Sort != "id" || !lq.Desc { - t.Errorf("search/sort/desc = %+v", lq) - } - if len(lq.Filters) != 3 { - t.Fatalf("filters = %+v", lq.Filters) - } - if lq.Filters[0] != (db.Filter{Column: "id", Op: "gt", Value: "100"}) { - t.Errorf("filter0 = %+v", lq.Filters[0]) - } - if lq.Filters[1] != (db.Filter{Column: "name", Op: "ilike", Value: "%a%"}) { - t.Errorf("filter1 = %+v", lq.Filters[1]) - } - if lq.Filters[2] != (db.Filter{Column: "deleted_at", Op: "is_null"}) { - t.Errorf("filter2 (no value) = %+v", lq.Filters[2]) - } -} - -func TestParseFilters(t *testing.T) { - if parseFilters(nil) != nil { - t.Error("nil input should yield nil") - } - got := parseFilters([]string{"id:gt:100", "name:is_null", "bare"}) - want := []db.Filter{ - {Column: "id", Op: "gt", Value: "100"}, - {Column: "name", Op: "is_null"}, - {Column: "bare"}, - } - if len(got) != len(want) { - t.Fatalf("len = %d", len(got)) - } - for i := range want { - if got[i] != want[i] { - t.Errorf("filter %d = %+v, want %+v", i, got[i], want[i]) - } - } -} - -func TestTableData_DefaultsAndBadParams(t *testing.T) { - q := &fakeQuerier{result: okResult()} - ts, _ := newTestServer(t, q) - // Non-numeric params fall back to defaults (0 -> pool clamps). - resp := mustGet(t, ts, "/api/tables/public/users/data?limit=abc") - resp.Body.Close() - if q.lastArgs.limit != 0 || q.lastArgs.offset != 0 { - t.Errorf("expected defaults, got %d/%d", q.lastArgs.limit, q.lastArgs.offset) - } -} - -func TestTableData_Error(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{err: errors.New("no such table")}) - resp := mustGet(t, ts, "/api/tables/public/nope/data") - body, _ := io.ReadAll(resp.Body) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } - if strings.Contains(string(body), "no such table") { - t.Errorf("table data error leaked database detail: %s", body) - } -} - -func TestTableData_CSV(t *testing.T) { - q := &fakeQuerier{result: &db.Result{Columns: []string{"a"}, Rows: [][]any{{"1"}}, RowCount: 1}} - ts, _ := newTestServer(t, q) - resp := mustGet(t, ts, "/api/tables/public/users/data?format=csv") - defer resp.Body.Close() - if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { - t.Errorf("content-type = %q", ct) - } - if cd := resp.Header.Get("Content-Disposition"); !strings.Contains(cd, "users.csv") { - t.Errorf("disposition = %q", cd) - } - body, _ := io.ReadAll(resp.Body) - if !strings.Contains(string(body), "a\n1\n") { - t.Errorf("csv body = %q", body) - } -} - -func TestTableData_CSVWriteFailure(t *testing.T) { - srv := serverWithStore(t, &fakeQuerier{result: okResult()}, &fakeStore{}) - req := httptest.NewRequest(http.MethodGet, "/api/tables/public/users/data?format=csv", nil) - fw := &failingWriter{} - srv.handleTableData(fw, req) - if ct := fw.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { - t.Errorf("content-type = %q", ct) - } -} - -func TestListEmptyReturnsArray(t *testing.T) { - // A successful List of an empty store returns nil; the handler must emit []. - srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{}) - req := httptest.NewRequest(http.MethodGet, "/api/queries", nil) - rec := httptest.NewRecorder() - srv.Routes().ServeHTTP(rec, req) - if rec.Code != http.StatusOK { - t.Fatalf("code = %d", rec.Code) - } - if body := strings.TrimSpace(rec.Body.String()); body != "[]" { - t.Errorf("body = %q, want []", body) - } -} - -func TestUpdate_InvalidJSON(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - req, _ := http.NewRequest(http.MethodPut, ts.URL+"/api/queries/1", strings.NewReader(`{bad`)) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } -} - -func TestUpdate_NotFound(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - req, _ := http.NewRequest(http.MethodPut, ts.URL+"/api/queries/999", strings.NewReader(`{"name":"x","sql":"SELECT 1"}`)) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - if resp.StatusCode != http.StatusNotFound { - t.Errorf("status = %d, want 404", resp.StatusCode) - } -} - -func TestDelete_BadID(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - req, _ := http.NewRequest(http.MethodDelete, ts.URL+"/api/queries/abc", nil) - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } -} - -func TestExport_QueryError(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{err: errors.New("boom")}) - resp := post(t, ts, "/api/export", `{"sql":"SELECT 1"}`) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } -} - -func TestExport_InvalidJSON(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := post(t, ts, "/api/export", `{bad`) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } -} - -func TestSavedQuery_InvalidJSON(t *testing.T) { - ts, _ := newTestServer(t, &fakeQuerier{}) - resp := post(t, ts, "/api/queries", `{bad`) - resp.Body.Close() - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("status = %d, want 400", resp.StatusCode) - } -} - -// failingWriter is an http.ResponseWriter whose body writes always fail. -type failingWriter struct { - header http.Header - code int -} - -func (f *failingWriter) Header() http.Header { - if f.header == nil { - f.header = http.Header{} - } - return f.header -} -func (f *failingWriter) Write([]byte) (int, error) { return 0, errors.New("connection reset") } -func (f *failingWriter) WriteHeader(code int) { f.code = code } - -func TestExport_WriteFailureLogged(t *testing.T) { - // Exercise the handler's csv-error branch directly with a writer that fails. - srv := serverWithStore(t, &fakeQuerier{result: okResult()}, &fakeStore{}) - req := httptest.NewRequest(http.MethodPost, "/api/export", strings.NewReader(`{"sql":"SELECT 1"}`)) - fw := &failingWriter{} - srv.handleExport(fw, req) - // Header is set before the failing body write. - if ct := fw.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/csv") { - t.Errorf("content-type = %q", ct) - } -} - -func TestWriteCSV(t *testing.T) { - res := &db.Result{Columns: []string{"a"}, Rows: [][]any{{"1"}}} - var buf strings.Builder - if err := writeCSV(&buf, res); err != nil { - t.Fatalf("writeCSV: %v", err) - } - if !strings.Contains(buf.String(), "a\n1\n") { - t.Errorf("csv = %q", buf.String()) - } - // Failing writer -> error surfaces via cw.Error() after Flush. - if err := writeCSV(failWriter{}, res); err == nil { - t.Error("expected error from failing writer") - } -} - -type failWriter struct{} - -func (failWriter) Write([]byte) (int, error) { return 0, errors.New("boom") } - -func mustGet(t *testing.T, ts *httptest.Server, path string) *http.Response { - t.Helper() - resp, err := http.Get(ts.URL + path) - if err != nil { - t.Fatalf("GET %s: %v", path, err) - } - return resp -} - -func itoa(i int64) string { return strconv.FormatInt(i, 10) } diff --git a/internal/server/store_error_test.go b/internal/server/store_error_test.go new file mode 100644 index 0000000..9ae40eb --- /dev/null +++ b/internal/server/store_error_test.go @@ -0,0 +1,75 @@ +package server + +import ( + "context" + "errors" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + "testing/fstest" + "time" + + "github.com/descope-sample-apps/pgpeek/internal/store" +) + +// --- error-path coverage with a fake store and failing writer ------------ + +type fakeStore struct { + listErr, getErr, createErr, updateErr, deleteErr error +} + +func (f *fakeStore) List(context.Context) ([]store.SavedQuery, error) { + return nil, f.listErr +} +func (f *fakeStore) Get(context.Context, int64) (store.SavedQuery, error) { + return store.SavedQuery{}, f.getErr +} +func (f *fakeStore) Create(context.Context, string, string, string) (store.SavedQuery, error) { + return store.SavedQuery{}, f.createErr +} +func (f *fakeStore) Update(context.Context, int64, string, string, string) (store.SavedQuery, error) { + return store.SavedQuery{}, f.updateErr +} +func (f *fakeStore) Delete(context.Context, int64) error { return f.deleteErr } + +func serverWithStore(t *testing.T, q Querier, st QueryStore) *Server { + t.Helper() + web := fstest.MapFS{"index.html": &fstest.MapFile{Data: []byte("x")}} + log := slog.New(slog.NewTextHandler(io.Discard, nil)) + return New(q, st, web, log, time.Second) +} + +func TestStoreErrorPaths(t *testing.T) { + boom := errors.New("db down") + srv := serverWithStore(t, &fakeQuerier{}, &fakeStore{ + listErr: boom, createErr: boom, updateErr: boom, deleteErr: boom, + }) + mux := srv.Routes() + + do := func(method, path, body string) int { + var rdr io.Reader + if body != "" { + rdr = strings.NewReader(body) + } + req := httptest.NewRequest(method, path, rdr) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + return rec.Code + } + + if code := do(http.MethodGet, "/api/queries", ""); code != http.StatusInternalServerError { + t.Errorf("list error code = %d, want 500", code) + } + if code := do(http.MethodPost, "/api/queries", `{"name":"x","sql":"SELECT 1"}`); code != http.StatusInternalServerError { + t.Errorf("create error code = %d, want 500", code) + } + if code := do(http.MethodPut, "/api/queries/1", `{"name":"x","sql":"SELECT 1"}`); code != http.StatusInternalServerError { + t.Errorf("update error code = %d, want 500", code) + } + if code := do(http.MethodDelete, "/api/queries/1", ""); code != http.StatusInternalServerError { + t.Errorf("delete error code = %d, want 500", code) + } +} diff --git a/main.go b/main.go index 61a1653..0c6b038 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "context" "embed" "errors" + "fmt" "io/fs" "log/slog" "net/http" @@ -50,32 +51,47 @@ func run(ctx context.Context, log *slog.Logger) error { return err } - dbCfg := db.Config{ - DSN: cfg.DB.DSN, - MaxConns: cfg.DB.MaxConns, - StatementTimeout: cfg.DB.StatementTimeout, - IdleTxTimeout: cfg.DB.IdleTxTimeout, - RowCap: cfg.RowCap, + signalCtx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) + defer stop() + + entries := make([]db.RegistryEntry, 0, len(cfg.Databases)) + closeEntries := func() { + for i := len(entries) - 1; i >= 0; i-- { + entries[i].Pool.Close() + } } - if cfg.DB.IAMAuth { - provider, perr := awsauth.New(ctx, cfg.DB.Region) + for _, entry := range cfg.Databases { + dbCfg := db.Config{ + DSN: entry.DSN, + MaxConns: cfg.DB.MaxConns, + StatementTimeout: cfg.DB.StatementTimeout, + IdleTxTimeout: cfg.DB.IdleTxTimeout, + RowCap: cfg.RowCap, + } + if entry.IAMAuth { + provider, perr := awsauth.New(signalCtx, entry.Region) + if perr != nil { + closeEntries() + return fmt.Errorf("create IAM auth provider for database %q: %w", entry.ID, perr) + } + dbCfg.BeforeConnect = provider.BeforeConnect + log.Info("RDS IAM authentication enabled", "databaseID", entry.ID, "databaseName", entry.Name) + } + log.Info("connecting database", "databaseID", entry.ID, "databaseName", entry.Name, "default", entry.ID == cfg.DefaultDatabaseID) + pool, perr := db.New(signalCtx, dbCfg) if perr != nil { - return perr + closeEntries() + return fmt.Errorf("connect database %q: %w", entry.ID, perr) } - dbCfg.BeforeConnect = provider.BeforeConnect - log.Info("RDS IAM authentication enabled", "region", cfg.DB.Region) + entries = append(entries, db.RegistryEntry{ID: entry.ID, Name: entry.Name, Pool: pool, Default: entry.ID == cfg.DefaultDatabaseID}) } - - signalCtx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) - defer stop() - - pool, err := db.New(signalCtx, dbCfg) + registry, err := db.NewRegistry(entries) if err != nil { - return err + closeEntries() + return fmt.Errorf("create database registry: %w", err) } - defer pool.Close() - log.Info("connected to database", "maxConns", cfg.DB.MaxConns, "rowCap", cfg.RowCap, - "statementTimeout", cfg.DB.StatementTimeout.String(), "iamAuth", cfg.DB.IAMAuth) + defer registry.Close() + log.Info("database registry ready", "databaseCount", len(entries), "defaultDatabaseID", registry.DefaultID()) st, err := store.Open(cfg.StorePath) if err != nil { @@ -87,7 +103,7 @@ func run(ctx context.Context, log *slog.Logger) error { } web := mustSubFS(webFiles, "web") - srv := server.New(pool, st, web, log, cfg.DB.StatementTimeout+5*time.Second) + srv := server.NewWithRegistry(server.NewDatabaseRegistry(registry), st, web, log, cfg.DB.StatementTimeout+5*time.Second) httpSrv := &http.Server{ Addr: cfg.Server.Listen, Handler: srv.Routes(), diff --git a/main_test.go b/main_test.go index 2f455dd..afa40b1 100644 --- a/main_test.go +++ b/main_test.go @@ -14,6 +14,8 @@ import ( "net/http" "os" "path/filepath" + "strconv" + "strings" "testing" "time" @@ -29,9 +31,37 @@ func clearAppEnv(t *testing.T) { "PGPEEK_ROW_CAP", "PGPEEK_MAX_CONNS", "PGPEEK_STATEMENT_TIMEOUT", "PGPEEK_DB_IAM_AUTH", "PGPEEK_AWS_REGION", "AWS_REGION", "PGPEEK_TLS_CERT_FILE", "PGPEEK_TLS_KEY_FILE", + "PGPEEK_DATABASE_URLS", "PGPEEK_DATABASE_IDS", "PGPEEK_DATABASE_NAMES", + "PGPEEK_DATABASES_FILE", "PGPEEK_DEFAULT_DATABASE", } { t.Setenv(k, "") } + for i := 1; i <= 64; i++ { + suffix := strconv.Itoa(i) + t.Setenv("PGPEEK_DATABASE_URL_"+suffix, "") + t.Setenv("PGPEEK_DATABASE_URL_"+suffix+"_FILE", "") + t.Setenv("PGPEEK_DATABASE_ID_"+suffix, "") + t.Setenv("PGPEEK_DATABASE_NAME_"+suffix, "") + } +} + +func TestClearAppEnv_clearsMultiDatabaseVars(t *testing.T) { + // Given + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:p@127.0.0.1:1/one") + t.Setenv("PGPEEK_DATABASE_IDS", "one") + t.Setenv("PGPEEK_DATABASE_URL_1", "postgres://u:p@127.0.0.1:1/two") + t.Setenv("PGPEEK_DATABASE_ID_1", "two") + t.Setenv("PGPEEK_DEFAULT_DATABASE", "one") + + // When + clearAppEnv(t) + + // Then + for _, key := range []string{"PGPEEK_DATABASE_URLS", "PGPEEK_DATABASE_IDS", "PGPEEK_DATABASE_URL_1", "PGPEEK_DATABASE_ID_1", "PGPEEK_DEFAULT_DATABASE"} { + if got := os.Getenv(key); got != "" { + t.Fatalf("%s=%q, want empty", key, got) + } + } } // --- serve() ------------------------------------------------------------- @@ -112,6 +142,27 @@ func TestRun_DBConnectError(t *testing.T) { } } +func TestRun_MultiDatabaseConnectError(t *testing.T) { + clearAppEnv(t) + // Given: two configured database entries, second selected as default. + t.Setenv("PGPEEK_DATABASE_URLS", "postgres://u:p@127.0.0.1:1/one?connect_timeout=1&sslmode=disable,postgres://u:p@127.0.0.1:2/two?connect_timeout=1&sslmode=disable") + t.Setenv("PGPEEK_DATABASE_IDS", "one,two") + t.Setenv("PGPEEK_DATABASE_NAMES", "One,Two") + t.Setenv("PGPEEK_DEFAULT_DATABASE", "two") + t.Setenv("PGPEEK_STORE_PATH", filepath.Join(t.TempDir(), "s.db")) + + // When + err := run(context.Background(), testLogger()) + + // Then + if err == nil { + t.Fatal("expected multi-database connect error") + } + if !strings.Contains(err.Error(), `database "one"`) { + t.Fatalf("error %q, want first database id", err) + } +} + func TestRun_IAMPathThenDBError(t *testing.T) { clearAppEnv(t) t.Setenv("DATABASE_URL", "postgres://u@127.0.0.1:1/db?connect_timeout=1&sslmode=disable") diff --git a/package-lock.json b/package-lock.json index 98b9c92..f2a95c6 100644 --- a/package-lock.json +++ b/package-lock.json @@ -2293,490 +2293,6 @@ } } }, - "node_modules/vite/node_modules/@esbuild/aix-ppc64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.27.7.tgz", - "integrity": "sha512-EKX3Qwmhz1eMdEJokhALr0YiD0lhQNwDqkPYyPhiSwKrh7/4KRjQc04sZ8db+5DVVnZ1LmbNDI1uAMPEUBnQPg==", - "cpu": [ - "ppc64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "aix" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/android-arm": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.27.7.tgz", - "integrity": "sha512-jbPXvB4Yj2yBV7HUfE2KHe4GJX51QplCN1pGbYjvsyCZbQmies29EoJbkEc+vYuU5o45AfQn37vZlyXy4YJ8RQ==", - "cpu": [ - "arm" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "android" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/android-arm64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.27.7.tgz", - "integrity": "sha512-62dPZHpIXzvChfvfLJow3q5dDtiNMkwiRzPylSCfriLvZeq0a1bWChrGx/BbUbPwOrsWKMn8idSllklzBy+dgQ==", - "cpu": [ - "arm64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "android" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/android-x64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.27.7.tgz", - "integrity": "sha512-x5VpMODneVDb70PYV2VQOmIUUiBtY3D3mPBG8NxVk5CogneYhkR7MmM3yR/uMdITLrC1ml/NV1rj4bMJuy9MCg==", - "cpu": [ - "x64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "android" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/darwin-arm64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.27.7.tgz", - "integrity": "sha512-5lckdqeuBPlKUwvoCXIgI2D9/ABmPq3Rdp7IfL70393YgaASt7tbju3Ac+ePVi3KDH6N2RqePfHnXkaDtY9fkw==", - "cpu": [ - "arm64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "darwin" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/darwin-x64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.27.7.tgz", - "integrity": "sha512-rYnXrKcXuT7Z+WL5K980jVFdvVKhCHhUwid+dDYQpH+qu+TefcomiMAJpIiC2EM3Rjtq0sO3StMV/+3w3MyyqQ==", - "cpu": [ - "x64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "darwin" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/freebsd-arm64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.27.7.tgz", - "integrity": "sha512-B48PqeCsEgOtzME2GbNM2roU29AMTuOIN91dsMO30t+Ydis3z/3Ngoj5hhnsOSSwNzS+6JppqWsuhTp6E82l2w==", - "cpu": [ - "arm64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "freebsd" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/freebsd-x64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.27.7.tgz", - "integrity": "sha512-jOBDK5XEjA4m5IJK3bpAQF9/Lelu/Z9ZcdhTRLf4cajlB+8VEhFFRjWgfy3M1O4rO2GQ/b2dLwCUGpiF/eATNQ==", - "cpu": [ - "x64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "freebsd" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/linux-arm": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.27.7.tgz", - "integrity": "sha512-RkT/YXYBTSULo3+af8Ib0ykH8u2MBh57o7q/DAs3lTJlyVQkgQvlrPTnjIzzRPQyavxtPtfg0EopvDyIt0j1rA==", - "cpu": [ - "arm" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/linux-arm64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.27.7.tgz", - "integrity": "sha512-RZPHBoxXuNnPQO9rvjh5jdkRmVizktkT7TCDkDmQ0W2SwHInKCAV95GRuvdSvA7w4VMwfCjUiPwDi0ZO6Nfe9A==", - "cpu": [ - "arm64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/linux-ia32": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.27.7.tgz", - "integrity": "sha512-GA48aKNkyQDbd3KtkplYWT102C5sn/EZTY4XROkxONgruHPU72l+gW+FfF8tf2cFjeHaRbWpOYa/uRBz/Xq1Pg==", - "cpu": [ - "ia32" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/linux-loong64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.27.7.tgz", - "integrity": "sha512-a4POruNM2oWsD4WKvBSEKGIiWQF8fZOAsycHOt6JBpZ+JN2n2JH9WAv56SOyu9X5IqAjqSIPTaJkqN8F7XOQ5Q==", - "cpu": [ - "loong64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/linux-mips64el": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.27.7.tgz", - "integrity": "sha512-KabT5I6StirGfIz0FMgl1I+R1H73Gp0ofL9A3nG3i/cYFJzKHhouBV5VWK1CSgKvVaG4q1RNpCTR2LuTVB3fIw==", - "cpu": [ - "mips64el" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/linux-ppc64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.27.7.tgz", - "integrity": "sha512-gRsL4x6wsGHGRqhtI+ifpN/vpOFTQtnbsupUF5R5YTAg+y/lKelYR1hXbnBdzDjGbMYjVJLJTd2OFmMewAgwlQ==", - "cpu": [ - "ppc64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/linux-riscv64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.27.7.tgz", - "integrity": "sha512-hL25LbxO1QOngGzu2U5xeXtxXcW+/GvMN3ejANqXkxZ/opySAZMrc+9LY/WyjAan41unrR3YrmtTsUpwT66InQ==", - "cpu": [ - "riscv64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/linux-s390x": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.27.7.tgz", - "integrity": "sha512-2k8go8Ycu1Kb46vEelhu1vqEP+UeRVj2zY1pSuPdgvbd5ykAw82Lrro28vXUrRmzEsUV0NzCf54yARIK8r0fdw==", - "cpu": [ - "s390x" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/linux-x64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.27.7.tgz", - "integrity": "sha512-hzznmADPt+OmsYzw1EE33ccA+HPdIqiCRq7cQeL1Jlq2gb1+OyWBkMCrYGBJ+sxVzve2ZJEVeePbLM2iEIZSxA==", - "cpu": [ - "x64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/netbsd-arm64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.27.7.tgz", - "integrity": "sha512-b6pqtrQdigZBwZxAn1UpazEisvwaIDvdbMbmrly7cDTMFnw/+3lVxxCTGOrkPVnsYIosJJXAsILG9XcQS+Yu6w==", - "cpu": [ - "arm64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "netbsd" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/netbsd-x64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.27.7.tgz", - "integrity": "sha512-OfatkLojr6U+WN5EDYuoQhtM+1xco+/6FSzJJnuWiUw5eVcicbyK3dq5EeV/QHT1uy6GoDhGbFpprUiHUYggrw==", - "cpu": [ - "x64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "netbsd" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/openbsd-arm64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.27.7.tgz", - "integrity": "sha512-AFuojMQTxAz75Fo8idVcqoQWEHIXFRbOc1TrVcFSgCZtQfSdc1RXgB3tjOn/krRHENUB4j00bfGjyl2mJrU37A==", - "cpu": [ - "arm64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "openbsd" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/openbsd-x64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.27.7.tgz", - "integrity": "sha512-+A1NJmfM8WNDv5CLVQYJ5PshuRm/4cI6WMZRg1by1GwPIQPCTs1GLEUHwiiQGT5zDdyLiRM/l1G0Pv54gvtKIg==", - "cpu": [ - "x64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "openbsd" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/openharmony-arm64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/openharmony-arm64/-/openharmony-arm64-0.27.7.tgz", - "integrity": "sha512-+KrvYb/C8zA9CU/g0sR6w2RBw7IGc5J2BPnc3dYc5VJxHCSF1yNMxTV5LQ7GuKteQXZtspjFbiuW5/dOj7H4Yw==", - "cpu": [ - "arm64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "openharmony" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/sunos-x64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.27.7.tgz", - "integrity": "sha512-ikktIhFBzQNt/QDyOL580ti9+5mL/YZeUPKU2ivGtGjdTYoqz6jObj6nOMfhASpS4GU4Q/Clh1QtxWAvcYKamA==", - "cpu": [ - "x64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "sunos" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/win32-arm64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.27.7.tgz", - "integrity": "sha512-7yRhbHvPqSpRUV7Q20VuDwbjW5kIMwTHpptuUzV+AA46kiPze5Z7qgt6CLCK3pWFrHeNfDd1VKgyP4O+ng17CA==", - "cpu": [ - "arm64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/win32-ia32": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.27.7.tgz", - "integrity": "sha512-SmwKXe6VHIyZYbBLJrhOoCJRB/Z1tckzmgTLfFYOfpMAx63BJEaL9ExI8x7v0oAO3Zh6D/Oi1gVxEYr5oUCFhw==", - "cpu": [ - "ia32" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/@esbuild/win32-x64": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.27.7.tgz", - "integrity": "sha512-56hiAJPhwQ1R4i+21FVF7V8kSD5zZTdHcVuRFMW0hn753vVfQN8xlx4uOPT4xoGH0Z/oVATuR82AiqSTDIpaHg==", - "cpu": [ - "x64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">=18" - } - }, - "node_modules/vite/node_modules/esbuild": { - "version": "0.27.7", - "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.27.7.tgz", - "integrity": "sha512-IxpibTjyVnmrIQo5aqNpCgoACA/dTKLTlhMHihVHhdkxKyPO1uBBthumT0rdHmcsk9uMonIWS0m4FljWzILh3w==", - "dev": true, - "hasInstallScript": true, - "license": "MIT", - "bin": { - "esbuild": "bin/esbuild" - }, - "engines": { - "node": ">=18" - }, - "optionalDependencies": { - "@esbuild/aix-ppc64": "0.27.7", - "@esbuild/android-arm": "0.27.7", - "@esbuild/android-arm64": "0.27.7", - "@esbuild/android-x64": "0.27.7", - "@esbuild/darwin-arm64": "0.27.7", - "@esbuild/darwin-x64": "0.27.7", - "@esbuild/freebsd-arm64": "0.27.7", - "@esbuild/freebsd-x64": "0.27.7", - "@esbuild/linux-arm": "0.27.7", - "@esbuild/linux-arm64": "0.27.7", - "@esbuild/linux-ia32": "0.27.7", - "@esbuild/linux-loong64": "0.27.7", - "@esbuild/linux-mips64el": "0.27.7", - "@esbuild/linux-ppc64": "0.27.7", - "@esbuild/linux-riscv64": "0.27.7", - "@esbuild/linux-s390x": "0.27.7", - "@esbuild/linux-x64": "0.27.7", - "@esbuild/netbsd-arm64": "0.27.7", - "@esbuild/netbsd-x64": "0.27.7", - "@esbuild/openbsd-arm64": "0.27.7", - "@esbuild/openbsd-x64": "0.27.7", - "@esbuild/openharmony-arm64": "0.27.7", - "@esbuild/sunos-x64": "0.27.7", - "@esbuild/win32-arm64": "0.27.7", - "@esbuild/win32-ia32": "0.27.7", - "@esbuild/win32-x64": "0.27.7" - } - }, "node_modules/vitest": { "version": "4.1.9", "resolved": "https://registry.npmjs.org/vitest/-/vitest-4.1.9.tgz", diff --git a/package.json b/package.json index cac9d8c..96bfd43 100644 --- a/package.json +++ b/package.json @@ -25,5 +25,8 @@ }, "constraints": { "golang": "^1.26" + }, + "overrides": { + "esbuild": "0.28.1" } } diff --git a/vitest.config.js b/vitest.config.js index a6ec8c3..9e0ed34 100644 --- a/vitest.config.js +++ b/vitest.config.js @@ -6,7 +6,7 @@ export default defineConfig({ include: ["web/**/*.test.js"], coverage: { provider: "v8", - include: ["web/app.js"], + include: ["web/app.js", "web/api.js", "web/url-state.js", "web/theme.js", "web/sidebar.js", "web/data-tab.js", "web/structure-tab.js", "web/sql-tab.js"], reporter: ["text", "html"], thresholds: { lines: 100, diff --git a/web/api.js b/web/api.js new file mode 100644 index 0000000..c7f2121 --- /dev/null +++ b/web/api.js @@ -0,0 +1,32 @@ +export function dbUrl(path, dbId) { + if (!dbId) return path; + const sep = path.includes("?") ? "&" : "?"; + return path + sep + "db=" + encodeURIComponent(dbId); +} + +export async function getJSON(url, dbId) { + const r = await fetch(dbUrl(url, dbId)); + if (!r.ok) { + const body = await r.json().catch(() => ({})); + throw new Error(body.error || r.statusText); + } + const body = await r.json(); + return body; +} + +export const tablePath = (t) => + "/api/tables/" + encodeURIComponent(t.schema) + "/" + encodeURIComponent(t.name); + +export const tableKey = (t) => t.schema + "." + t.name; + +// appendDataParams adds search/sort/filter params shared by browse + export. +export function appendDataParams(p, search, sort, filters) { + if (search) p.set("search", search); + if (sort) { p.set("sort", sort.col); p.set("dir", sort.dir); } + for (const f of filters) { + if (!f.column || !f.op) continue; + const noVal = f.op === "is_null" || f.op === "is_not_null"; + if (noVal) p.append("f", f.column + ":" + f.op); + else p.append("f", f.column + ":" + f.op + ":" + (f.value || "")); + } +} diff --git a/web/app.js b/web/app.js index a40d5c8..3ce1c18 100644 --- a/web/app.js +++ b/web/app.js @@ -1,147 +1,28 @@ -// pgpeek UI — Preact + htm, no build step. Vendored Preact/htm keeps the app a -// set of static files embedded in the Go binary, CSP-safe (no eval), with the -// reactivity that the imperative version was outgrowing. +// pgpeek UI — Preact + htm, no build step. This file is the entrypoint; +// feature modules live in ./theme.js, ./sidebar.js, ./data-tab.js, etc. import { html, render, useState, useEffect, useRef, useCallback, } from "./vendor/preact-htm.js"; +import { ThemeSelect } from "./theme.js"; +import { Sidebar, Tabs } from "./sidebar.js"; +import { DataTab } from "./data-tab.js"; +import { StructureTab } from "./structure-tab.js"; +import { SqlTab } from "./sql-tab.js"; +import { getJSON, tableKey } from "./api.js"; +import { readUrlState, pushUrlState, replaceUrlState } from "./url-state.js"; const PAGE_SIZE = 100; -const THEME_KEY = "pgpeek-theme"; -// Switchable color themes. id "" = built-in default (the :root palette). -const THEMES = [ - ["", "Default"], - ["dark-plus", "Dark+"], - ["light-plus", "Light+"], - ["monokai", "Monokai"], - ["dracula", "Dracula"], - ["one-dark", "One Dark Pro"], - ["nord", "Nord"], - ["solarized-dark", "Solarized Dark"], - ["solarized-light", "Solarized Light"], - ["github-dark", "GitHub Dark"], - ["github-light", "GitHub Light"], - ["catppuccin-mocha", "Catppuccin Mocha"], - ["catppuccin-latte", "Catppuccin Latte"], - ["tokyo-night", "Tokyo Night"], - ["ayu-dark", "Ayu Dark"], - ["ayu-mirage", "Ayu Mirage"], - ["night-owl", "Night Owl"], - ["houston", "Houston"], - ["matcha", "Matcha"], - ["dainty", "Dainty"], -]; - -function getStoredTheme() { - try { return localStorage.getItem(THEME_KEY) || ""; } catch { return ""; } -} - -// applyTheme sets (or clears) the data-theme attribute that selects a palette. -function applyTheme(id) { - const root = document.documentElement; - if (id) root.setAttribute("data-theme", id); - else root.removeAttribute("data-theme"); -} - -// Apply the saved theme at import time to avoid a flash of the default palette. -applyTheme(getStoredTheme()); - -// Allowlisted filter operators (key sent to the server, label shown in the UI). -const OPS = [ - ["", "—"], ["eq", "="], ["ne", "≠"], ["lt", "<"], ["lte", "≤"], - ["gt", ">"], ["gte", "≥"], ["ilike", "ILIKE"], ["like", "LIKE"], - ["is_null", "IS NULL"], ["is_not_null", "NOT NULL"], -]; -const opNeedsValue = (op) => op !== "" && op !== "is_null" && op !== "is_not_null"; - -const tablePath = (t) => "/api/tables/" + encodeURIComponent(t.schema) + "/" + encodeURIComponent(t.name); -const tableKey = (t) => t.schema + "." + t.name; - -async function getJSON(url) { - const r = await fetch(url); - const body = await r.json(); - if (!r.ok) throw new Error(body.error || r.statusText); - return body; -} - -// appendDataParams adds search/sort/filter params shared by browse + export. -function appendDataParams(p, search, sort, filters) { - if (search) p.set("search", search); - if (sort) { p.set("sort", sort.col); p.set("dir", sort.dir); } - for (const col of Object.keys(filters)) { - const f = filters[col]; - if (!f.op) continue; - if (opNeedsValue(f.op)) p.append("f", col + ":" + f.op + ":" + (f.value || "")); - else p.append("f", col + ":" + f.op); - } -} - -function cellText(v) { - if (v === null || v === undefined) return null; - return typeof v === "object" ? JSON.stringify(v) : String(v); -} - -// Cell renders one value, as an FK link when fkRef is set. -function Cell({ value, fkRef, onNavigate }) { - const text = cellText(value); - if (text === null) return html``; - if (fkRef) { - return html``; - } - return html``; -} - -function BodyRows({ rows, fkByCol, onNavigate }) { - return rows.map((row) => html`${row.map((v, i) => html`<${Cell} value=${v} fkRef=${fkByCol && fkByCol[i]} onNavigate=${onNavigate} />`)}`); -} - -// ---- Sidebar ---- -function Sidebar({ tables, loaded, currentKey, onSelect }) { - const [filter, setFilter] = useState(""); - const listRef = useRef(); - const f = filter.toLowerCase(); - const items = []; - let schema = null; - - useEffect(() => { - const active = listRef.current && listRef.current.querySelector(".tbl.active"); - if (active && active.scrollIntoView) active.scrollIntoView({ block: "nearest" }); - }, [currentKey, filter]); - - for (const t of tables) { - const label = tableKey(t); - if (f && !label.toLowerCase().includes(f)) continue; - if (t.schema !== schema) { - schema = t.schema; - items.push(html`
    ${schema}
    `); - } - const active = label === currentKey; - const cls = "tbl" + (t.type === "view" ? " view" : "") + (active ? " active" : ""); - items.push(html``); - } - return html` - `; -} - -// ---- Tabs ---- -function Tabs({ tab, setTab, title }) { - const btn = (id, label) => html``; +// ---- Database selector (hidden when ≤1 database) ---- +function DatabaseSelect({ databases, currentDb, onSwitch }) { + if (databases.length <= 1) return null; return html` -
    - ${btn("data", "Data")} ${btn("structure", "Structure")} ${btn("sql", "SQL")} - ${title} -
    `; + `; } function TableContext({ table }) { @@ -154,345 +35,228 @@ function TableContext({ table }) { return html`
    Current ${table.type === "view" ? "view" : "table"} ${table.schema}.${table.name} - ${table.estRows >= 0 ? html`~${table.estRows} rows` : html`row count unavailable`} + ${table.estRows >= 0 + ? html`~${table.estRows} rows` + : html`row count unavailable`}
    `; } -// ---- Data tab ---- -function DataTab({ table, pageSize, initialFilters, onNavigate, setStatus }) { - const [offset, setOffset] = useState(0); - const [search, setSearch] = useState(""); - const [searchBox, setSearchBox] = useState(""); - const [filters, setFilters] = useState(initialFilters || {}); - const [draft, setDraft] = useState(initialFilters || {}); - const [sort, setSort] = useState(null); - const [data, setData] = useState(null); - const [fks, setFks] = useState({}); +// ---- App ---- +function App() { + const [databases, setDatabases] = useState([]); + const [dbsLoaded, setDbsLoaded] = useState(false); + const [currentDb, setCurrentDb] = useState(null); + const [tables, setTables] = useState([]); + const [tablesLoaded, setTablesLoaded] = useState(false); + const [rowCap, setRowCap] = useState(PAGE_SIZE); + const [saved, setSaved] = useState([]); + const [tab, setTabState] = useState("data"); + const [current, setCurrent] = useState(null); + const [navKey, setNavKey] = useState(0); + const [pendingFilters, setPendingFilters] = useState(null); + const [urlInit, setUrlInit] = useState(null); + const [status, setStatus] = useState({ text: "Ready.", cls: "ok" }); + // Refs so popstate handler always sees the latest values. + const urlStateRef = useRef({}); + const dbRef = useRef(null); + const tablesRef = useRef([]); + + const setTab = useCallback((newTab) => { + setTabState(newTab); + const s = { ...urlStateRef.current, tab: newTab === "data" ? null : newTab }; + pushUrlState(s); + urlStateRef.current = s; + }, []); - useEffect(() => { - let live = true; - (async () => { - try { - const list = await getJSON(tablePath(table) + "/fks"); - if (!live) return; - const m = {}; - for (const fk of list) m[fk.column] = { schema: fk.refSchema, table: fk.refTable, column: fk.refColumn }; - setFks(m); - } catch { /* no FK links */ } - })(); - return () => { live = false; }; - }, [table]); + const reloadSaved = useCallback(async () => { + try { setSaved(await getJSON("/api/queries")); } + catch (e) { setStatus({ text: "✗ failed to load saved queries: " + e.message, cls: "error" }); } + }, []); + // Phase 1: fetch /api/databases, resolve active db, restore URL state, + // install popstate listener. useEffect(() => { - let live = true; - setStatus({ text: "Loading " + tableKey(table) + "…", cls: "ok" }); - const p = new URLSearchParams(); - p.set("limit", pageSize); - p.set("offset", offset); - appendDataParams(p, search, sort, filters); + const urlState = readUrlState(); + urlStateRef.current = urlState; + (async () => { + let dbId = null; try { - const d = await getJSON(tablePath(table) + "/data?" + p.toString()); - if (!live) return; - setData(d); - setStatus({ text: "✓ " + d.rowCount + " row" + (d.rowCount === 1 ? "" : "s") + " in " + d.elapsedMs + " ms", cls: "ok" }); + const r = await fetch("/api/databases"); + + if (!r.ok) throw new Error(r.statusText || "failed"); + const result = await r.json(); + const dbs = Array.isArray(result.databases) ? result.databases : []; + setDatabases(dbs); + const urlDb = urlState.db; + const valid = dbs.find((d) => d.id === urlDb); + dbId = valid ? urlDb : (result.defaultId || (dbs[0] && dbs[0].id) || null); + if (urlDb && !valid && dbs.length > 0) { + setStatus({ text: "✗ unknown database in URL, using default", cls: "error" }); + } } catch (e) { - if (live) setStatus({ text: "✗ " + e.message, cls: "error" }); - } - })(); - return () => { live = false; }; - }, [table, offset, search, JSON.stringify(filters), sort && sort.col, sort && sort.dir, pageSize]); - const applyDraft = useCallback((next) => { - const clean = {}; - for (const c of Object.keys(next)) if (next[c] && next[c].op) clean[c] = next[c]; - setFilters(clean); - setOffset(0); - }, []); + setStatus({ text: "✗ failed to load databases: " + e.message, cls: "error" }); + } - const toggleSort = (col) => { - setOffset(0); - setSort((s) => (s && s.col === col ? { col, dir: s.dir === "asc" ? "desc" : "asc" } : { col, dir: "asc" })); - }; - const exportURL = () => { - const p = new URLSearchParams(); - p.set("format", "csv"); - appendDataParams(p, search, sort, filters); - return tablePath(table) + "/data?" + p.toString(); - }; + // Restore tab from URL; build pending table-restore if schema+table present. + setTabState(urlState.tab); + const finalState = { ...urlState, db: dbId }; + replaceUrlState(finalState); + urlStateRef.current = finalState; + dbRef.current = dbId; - let grid; - if (!data || data.columns === undefined) { - grid = html`
    Loading…
    `; - } else if (!data.columns.length) { - grid = html`
    No columns.
    `; - } else { - const fkByCol = data.columns.map((c) => fks[c] || null); - grid = html` -
    Method & pathPurpose
    POST /api/queryRun a query → JSON {columns, rows, …}.
    POST /api/exportRun a query → CSV download.
    GET /api/metaServer limits the UI needs ({rowCap}).
    GET /api/tablesList browsable tables/views (+ row estimate).
    GET /api/tables/{schema}/{table}/columnsColumn structure.
    GET /api/tables/{schema}/{table}/fksSingle-column foreign keys (for click-through).
    GET /api/tables/{schema}/{table}/dataPaged rows: ?limit=&offset=&search=&sort=&dir=&f=col:op:val (&format=csv).
    GET /api/databasesList configured databases → {defaultId, databases:[{id,name}]}.
    POST /api/query?db=<id>Run a query → JSON {columns, rows, …}.
    POST /api/export?db=<id>Run a query → CSV download.
    GET /api/meta?db=<id>Server limits the UI needs ({rowCap}).
    GET /api/tables?db=<id>List browsable tables/views (+ row estimate).
    GET /api/tables/{schema}/{table}/columns?db=<id>Column structure.
    GET /api/tables/{schema}/{table}/fks?db=<id>Single-column foreign keys (for click-through).
    GET /api/tables/{schema}/{table}/data?db=<id>Paged rows: &limit=&offset=&search=&sort=&dir=&f=col:op:val (&format=csv).
    GET /api/queriesList saved/preset queries.
    POST /api/queriesCreate a saved query.
    PUT /api/queries/{id}Update a saved query.
    NULL${text}
    - - ${data.columns.map((c) => html``)} - ${data.columns.map((c) => { - const d = draft[c] || {}; - return html``; - })} - - <${BodyRows} rows=${data.rows} fkByCol=${fkByCol} onNavigate=${onNavigate} /> -
    toggleSort(c)}> - ${c}${sort && sort.col === c ? (sort.dir === "desc" ? " ▼" : " ▲") : ""}
    - - setDraft({ ...draft, [c]: { op: d.op || "", value: e.target.value } })} - onKeyDown=${(e) => { if (e.key === "Enter") applyDraft({ ...draft, [c]: { op: d.op || "", value: e.target.value } }); }} /> -
    - ${data.rows.length ? "" : html`
    0 rows.
    `}`; - } - - const rowCount = data && data.rowCount ? data.rowCount : 0; - const from = rowCount ? offset + 1 : 0; + if (urlState.schema && urlState.table) { + setUrlInit({ + schema: urlState.schema, table: urlState.table, + offset: urlState.offset, search: urlState.search, + sort: urlState.sort, filters: urlState.filters, + }); + } + setCurrentDb(dbId); + setDbsLoaded(true); + })(); - return html` -
    - setSearchBox(e.target.value)} - onKeyDown=${(e) => { if (e.key === "Enter") { setOffset(0); setSearch(searchBox.trim()); } }} /> - - - - - ${from}–${offset + rowCount} - Export CSV -
    -
    ${grid}
    `; -} + const onPopstate = () => { + const s = readUrlState(); + const sameDb = s.db === dbRef.current; + urlStateRef.current = s; + dbRef.current = s.db; + setTabState(s.tab); + setCurrentDb(s.db); + if (!sameDb) { + setCurrent(null); + // Tables effect will reload; queue table restore if URL has a table. + if (s.schema && s.table) { + setUrlInit({ schema: s.schema, table: s.table, + offset: s.offset, search: s.search, sort: s.sort, filters: s.filters }); + } else setUrlInit(null); + } else if (s.schema && s.table) { + const found = tablesRef.current.find((x) => x.schema === s.schema && x.name === s.table); + if (found) { setCurrent(found); setNavKey((k) => k + 1); setUrlInit(s); } + else setCurrent(null); + } else { + setCurrent(null); + } + }; + window.addEventListener("popstate", onPopstate); + return () => { window.removeEventListener("popstate", onPopstate); }; + }, []); -// ---- Structure tab ---- -function StructureTab({ table, setStatus }) { - const [cols, setCols] = useState(null); + // Phase 2: reload tables + meta whenever the resolved db changes. useEffect(() => { + if (!dbsLoaded) return; let live = true; + setTablesLoaded(false); + tablesRef.current = []; + setTables([]); + const db = currentDb; (async () => { try { - const c = await getJSON(tablePath(table) + "/columns"); - if (live) setCols(c); + const t = await getJSON("/api/tables", db); + if (!live) return; + tablesRef.current = t; + setTables(t); + // Consume pending URL table-restore (set during initial load or popstate). + setUrlInit((prev) => { + if (!prev) return null; + const found = t.find((x) => x.schema === prev.schema && x.name === prev.table); + if (found) { setCurrent(found); setNavKey((k) => k + 1); } + return null; + }); } catch (e) { - if (live) setStatus({ text: "✗ " + e.message, cls: "error" }); - } + if (live) setStatus({ text: "✗ failed to load tables: " + e.message, cls: "error" }); + } finally { if (live) setTablesLoaded(true); } })(); + (async () => { + try { + const m = await getJSON("/api/meta", db); + if (live && m && m.rowCap > 0) setRowCap(m.rowCap); + } catch { /* keep default */ } + })(); + reloadSaved(); return () => { live = false; }; - }, [table]); - - let body; - if (cols === null) body = html`
    Loading…
    `; - else if (!cols.length) body = html`
    No columns.
    `; - else body = html` - - ${cols.map((c) => html` - `)} -
    ColumnTypeNullableDefault
    ${c.name}${c.type}${c.nullable ? "YES" : "NO"}${c.default == null ? "" : c.default}
    `; - return html`
    ${body}
    `; -} - -// ---- SQL tab ---- -function SqlTab({ active, saved, reloadSaved, setStatus }) { - const wrapRef = useRef(); - const taRef = useRef(); - const cmRef = useRef(); - const [result, setResult] = useState(null); - const [lastSQL, setLastSQL] = useState(""); - const [selected, setSelected] = useState(""); - const [running, setRunning] = useState(false); - const runningRef = useRef(false); - - const getSQL = () => (cmRef.current ? cmRef.current.getValue() : taRef.current.value).trim(); - const setSQL = (v) => { if (cmRef.current) cmRef.current.setValue(v); else taRef.current.value = v; }; - - const run = useCallback(async () => { - const sql = getSQL(); - if (!sql) return; - if (runningRef.current) return; - runningRef.current = true; setRunning(true); - setStatus({ text: "Running…", cls: "ok" }); - try { - const r = await fetch("/api/query", { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ sql }) }); - const d = await r.json(); - if (!r.ok) { setStatus({ text: "✗ " + (d.error || r.statusText), cls: "error" }); setResult(null); return; } - setLastSQL(sql); setResult(d); - const base = "✓ " + d.rowCount + " row" + (d.rowCount === 1 ? "" : "s") + " in " + d.elapsedMs + " ms"; - setStatus(d.truncated ? { text: base, cls: "ok", warn: "· capped (more rows available — add LIMIT or refine)" } : { text: base, cls: "ok" }); - } catch (e) { - setStatus({ text: "✗ " + e.message, cls: "error" }); - } finally { - runningRef.current = false; setRunning(false); - } - }, []); - - // Init the SQL editor once, into a Preact-stable wrapper it fully owns. Uses - // the vendored CodeMirror 6 bundle (window.cm6) when present, else degrades - // to a plain