From 0f83082b7110428474d0ab61ff85eaa7e4098aa6 Mon Sep 17 00:00:00 2001 From: Volodymyr Kohut Date: Thu, 16 Apr 2026 15:50:51 +0100 Subject: [PATCH 1/2] add lib hdx specific --- .github/CODEOWNERS | 14 - .github/issue_commands.json | 10 - .github/workflows/ci.yml | 41 -- .github/workflows/issue_commands.yml | 26 - completion.go | 116 --- completion_test.go | 205 ------ connector.go | 229 +++--- connector_test.go | 476 ++++++++++++ dataframe_test.go | 14 +- datasource.go | 131 ++-- datasource_connect_test.go | 71 +- datasource_middleware_test.go | 55 -- datasource_rowlimit_test.go | 138 ++-- datasource_test.go | 282 +------- driver-mock.go | 5 - driver.go | 1 - driver_round_time_test.go | 63 ++ go.mod | 17 +- go.sum | 35 +- health.go | 2 +- health_test.go | 33 +- interpolator.go | 328 +++++++++ interpolator_test.go | 463 ++++++++++++ macros.go | 456 +++++++++++- macros_test.go | 1004 ++++++++++++++++++++++++++ metadata.go | 216 ++++++ metadata_test.go | 122 ++++ mock/csv/csv_mock.go | 7 +- models/settings.go | 211 ++++++ models/settings_test.go | 207 ++++++ query_integration_test.go | 91 --- test/driver.go | 17 +- 32 files changed, 3896 insertions(+), 1190 deletions(-) delete mode 100644 .github/CODEOWNERS delete mode 100644 .github/issue_commands.json delete mode 100644 .github/workflows/ci.yml delete mode 100644 .github/workflows/issue_commands.yml delete mode 100644 completion.go delete mode 100644 completion_test.go create mode 100644 connector_test.go delete mode 100644 datasource_middleware_test.go create mode 100644 driver_round_time_test.go create mode 100644 interpolator.go create mode 100644 interpolator_test.go create mode 100644 macros_test.go create mode 100644 metadata.go create mode 100644 metadata_test.go create mode 100644 models/settings.go create mode 100644 models/settings_test.go delete mode 100644 query_integration_test.go diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index c58a5f9..0000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1,14 +0,0 @@ -# Lines starting with '#' are comments. -# Each line is a file pattern followed by one or more owners. - -# More details are here: https://help.github.com/articles/about-codeowners/ - -# The '*' pattern is global owners. - -# Order is important. The last matching pattern has the most precedence. -# The folders are ordered as follows: - -# In each subsection folders are ordered first by depth, then alphabetically. -# This should make it easy to add new rules without breaking existing ones. - -* @grafana/data-sources-plugins diff --git a/.github/issue_commands.json b/.github/issue_commands.json deleted file mode 100644 index a1f4ec2..0000000 --- a/.github/issue_commands.json +++ /dev/null @@ -1,10 +0,0 @@ -[ - { - "type": "label", - "name": "type/docs", - "action": "addToProject", - "addToProject": { - "url": "https://github.com/orgs/grafana/projects/69" - } - } -] diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 10a9829..0000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,41 +0,0 @@ -name: ci -on: - pull_request: - push: - branches: - - main -permissions: - contents: read - -jobs: - tests: - runs-on: ubuntu-latest - services: - mysql: - image: mysql:9.6@sha256:6b18d01fb632c0f568ace1cc1ebffb42d1d21bc1de86f6d3e8b7eb18278444d9 - env: - MYSQL_ALLOW_EMPTY_PASSWORD: yes - MYSQL_DATABASE: mysql - MYSQL_USER: mysql - MYSQL_PASSWORD: mysql - MYSQL_HOST: 127.0.0.1 - ports: - - 3306:3306 - options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 - steps: - - name: checkout - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - name: setup Go environment - uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0 - with: - go-version-file: "go.mod" - cache-dependency-path: "**/*.sum" - - name: Test - run: go test -v ./... - - name: Integration tests - run: go test -v ./... - env: - INTEGRATION_TESTS: "true" - MYSQL_URL: "mysql:mysql@tcp(127.0.0.1:3306)/mysql" diff --git a/.github/workflows/issue_commands.yml b/.github/workflows/issue_commands.yml deleted file mode 100644 index 94cd510..0000000 --- a/.github/workflows/issue_commands.yml +++ /dev/null @@ -1,26 +0,0 @@ -name: Run commands when issues are labeled -permissions: {} - -on: - issues: - types: [labeled] -jobs: - main: - runs-on: ubuntu-latest - steps: - - name: Checkout Actions - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - repository: "grafana/grafana-github-actions" - persist-credentials: false - path: ./actions - ref: main - - name: Install Actions - run: npm install --production --prefix ./actions - - name: Run Commands - uses: ./actions/commands - env: - ISSUE_COMMANDS_TOKEN: ${{secrets.ISSUE_COMMANDS_TOKEN}} - with: - token: ${ISSUE_COMMANDS_TOKEN} - configPath: issue_commands diff --git a/completion.go b/completion.go deleted file mode 100644 index 717f05b..0000000 --- a/completion.go +++ /dev/null @@ -1,116 +0,0 @@ -package sqlds - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - - "github.com/grafana/grafana-plugin-sdk-go/backend" -) - -const ( - schemas = "schemas" - tables = "tables" - columns = "columns" -) - -var ( - // ErrorNotImplemented is returned if the function is not implemented by the provided Driver (the Completable pointer is nil) - ErrorNotImplemented = errors.New("not implemented") - // ErrorWrongOptions when trying to parse Options with a invalid JSON - ErrorWrongOptions = errors.New("error reading query options") -) - -// Options are used to query schemas, tables and columns. They will be encoded in the request body (e.g. {"database": "mydb"}) -type Options map[string]string - -// Completable will be used to autocomplete Tables Schemas and Columns for SQL languages -type Completable interface { - Schemas(ctx context.Context, options Options) ([]string, error) - Tables(ctx context.Context, options Options) ([]string, error) - Columns(ctx context.Context, options Options) ([]string, error) -} - -func handleError(rw http.ResponseWriter, err error) { - rw.WriteHeader(http.StatusBadRequest) - _, err = rw.Write([]byte(err.Error())) - if err != nil { - backend.Logger.Error(err.Error()) - } -} - -func sendResourceResponse(rw http.ResponseWriter, res []string) { - rw.Header().Add("Content-Type", "application/json") - if err := json.NewEncoder(rw).Encode(res); err != nil { - handleError(rw, err) - return - } -} - -func (ds *SQLDatasource) getResources(rtype string) func(rw http.ResponseWriter, req *http.Request) { - return func(rw http.ResponseWriter, req *http.Request) { - if ds.Completable == nil { - handleError(rw, ErrorNotImplemented) - return - } - - options := Options{} - if req.Body != nil { - err := json.NewDecoder(req.Body).Decode(&options) - if err != nil { - handleError(rw, err) - return - } - } - - var res []string - var err error - switch rtype { - case schemas: - res, err = ds.Completable.Schemas(req.Context(), options) - case tables: - res, err = ds.Completable.Tables(req.Context(), options) - case columns: - res, err = ds.Completable.Columns(req.Context(), options) - default: - err = fmt.Errorf("unexpected resource type: %s", rtype) - } - if err != nil { - handleError(rw, err) - return - } - - sendResourceResponse(rw, res) - } -} - -func (ds *SQLDatasource) registerRoutes(mux *http.ServeMux) error { - defaultRoutes := map[string]func(http.ResponseWriter, *http.Request){ - "/tables": ds.getResources(tables), - "/schemas": ds.getResources(schemas), - "/columns": ds.getResources(columns), - } - for route, handler := range defaultRoutes { - mux.HandleFunc(route, handler) - } - for route, handler := range ds.CustomRoutes { - if _, ok := defaultRoutes[route]; ok { - return fmt.Errorf("unable to redefine %s, use the Completable interface instead", route) - } - mux.HandleFunc(route, handler) - } - return nil -} - -func ParseOptions(rawOptions json.RawMessage) (Options, error) { - args := Options{} - if rawOptions != nil { - err := json.Unmarshal(rawOptions, &args) - if err != nil { - return nil, fmt.Errorf("%w: %v", ErrorWrongOptions, err) - } - } - return args, nil -} diff --git a/completion_test.go b/completion_test.go deleted file mode 100644 index 7f9ca5e..0000000 --- a/completion_test.go +++ /dev/null @@ -1,205 +0,0 @@ -package sqlds - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func Test_handleError(t *testing.T) { - t.Run("it should write an error code and a message", func(t *testing.T) { - w := httptest.NewRecorder() - handleError(w, fmt.Errorf("test!")) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("expecting code %v got %v", http.StatusBadRequest, resp.StatusCode) - } - if string(body) != "test!" { - t.Errorf("expecting response test! got %v", string(body)) - } - }) -} - -func Test_sendResourceResponse(t *testing.T) { - t.Run("it should send a JSON response", func(t *testing.T) { - w := httptest.NewRecorder() - sendResourceResponse(w, []string{"foo", "bar"}) - - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("expecting code %v got %v", http.StatusBadRequest, http.StatusOK) - } - expectedResult := `["foo","bar"]` + "\n" - if string(body) != expectedResult { - t.Errorf("expecting response %v got %v", expectedResult, string(body)) - } - if resp.Header.Get("Content-Type") != "application/json" { - t.Errorf("expecting content-type application/json got %v", resp.Header.Get("Content-Type")) - } - }) -} - -type fakeCompletable struct { - schemas map[string][]string - tables map[string][]string - columns map[string][]string - err error -} - -func (f *fakeCompletable) Schemas(ctx context.Context, options Options) ([]string, error) { - return f.schemas[options["database"]], f.err -} - -func (f *fakeCompletable) Tables(ctx context.Context, options Options) ([]string, error) { - return f.tables[options["schema"]], f.err -} - -func (f *fakeCompletable) Columns(ctx context.Context, options Options) ([]string, error) { - return f.columns[options["table"]], f.err -} - -func TestCompletable(t *testing.T) { - tests := []struct { - description string - method string - fakeImpl *fakeCompletable - reqBody string - expectedRes string - }{ - { - "it should return schemas", - schemas, - &fakeCompletable{schemas: map[string][]string{"foobar": {"foo", "bar"}}}, - `{"database":"foobar"}`, - `["foo","bar"]` + "\n", - }, - { - "it should return tables of a schema", - tables, - &fakeCompletable{tables: map[string][]string{"foobar": {"foo", "bar"}}}, - `{"schema":"foobar"}`, - `["foo","bar"]` + "\n", - }, - { - "it should return columns of a table", - columns, - &fakeCompletable{columns: map[string][]string{"foobar": {"foo", "bar"}}}, - `{"table":"foobar"}`, - `["foo","bar"]` + "\n", - }, - } - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { - w := httptest.NewRecorder() - - sqlds := &SQLDatasource{} - sqlds.Completable = test.fakeImpl - - b := io.NopCloser(bytes.NewReader([]byte(test.reqBody))) - sqlds.getResources(test.method)(w, &http.Request{Body: b}) - resp := w.Result() - body, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - t.Errorf("expecting code %v got %v", http.StatusOK, resp.StatusCode) - } - if string(body) != test.expectedRes { - t.Errorf("expecting response %v got %v", test.expectedRes, string(body)) - } - if resp.Header.Get("Content-Type") != "application/json" { - t.Errorf("expecting content-type application/json got %v", resp.Header.Get("Content-Type")) - } - }) - } -} - -func Test_registerRoutes(t *testing.T) { - t.Run("it should add a new route", func(t *testing.T) { - sqlds := &SQLDatasource{} - sqlds.CustomRoutes = map[string]func(http.ResponseWriter, *http.Request){ - "/foo": func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("bar")) - if err != nil { - t.Fatal((err)) - } - }, - } - - mux := http.NewServeMux() - err := sqlds.registerRoutes(mux) - if err != nil { - t.Fatalf("unexpected error %v", err) - } - resp := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/foo", nil) - if err != nil { - t.Fatalf("unexpected error %v", err) - } - mux.ServeHTTP(resp, req) - - respByte, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("unexpected error %v", err) - } - if string(respByte) != "bar" { - t.Errorf("unexpected response %s", string(respByte)) - } - }) - - t.Run("it error if tried to add an existing route", func(t *testing.T) { - sqlds := &SQLDatasource{} - sqlds.CustomRoutes = map[string]func(http.ResponseWriter, *http.Request){ - "/tables": func(w http.ResponseWriter, r *http.Request) {}, - } - - mux := http.NewServeMux() - err := sqlds.registerRoutes(mux) - if err == nil || err.Error() != "unable to redefine /tables, use the Completable interface instead" { - t.Errorf("unexpected error %v", err) - } - }) -} - -func TestParseOptions(t *testing.T) { - tests := []struct { - err error - result Options - description string - input json.RawMessage - }{ - { - description: "parses input", - input: json.RawMessage(`{"foo":"bar"}`), - result: Options{"foo": "bar"}, - }, - { - description: "returns an error", - input: json.RawMessage(`not a json`), - err: ErrorWrongOptions, - }, - } - for _, tc := range tests { - t.Run(tc.description, func(t *testing.T) { - res, err := ParseOptions(tc.input) - if (err != nil || tc.err != nil) && !errors.Is(err, tc.err) { - t.Errorf("unexpected error %v", err) - } - if !cmp.Equal(res, tc.result) { - t.Errorf("unexpected result: %v", cmp.Diff(res, tc.result)) - } - }) - } -} diff --git a/connector.go b/connector.go index 2628928..c2c244e 100644 --- a/connector.go +++ b/connector.go @@ -3,62 +3,99 @@ package sqlds import ( "context" "database/sql" - "fmt" + "encoding/json" + "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" + "github.com/hydrolix/sqlds/v5/models" + "github.com/jellydator/ttlcache/v3" "net/http" "strings" - "sync" "time" "github.com/grafana/grafana-plugin-sdk-go/backend" ) -type Connector struct { - UID string - connections sync.Map - driver Driver - driverSettings DriverSettings - // Enabling multiple connections may cause that concurrent connection limits - // are hit. The datasource enabling this should make sure connections are cached - // if necessary. - enableMultipleConnections bool +type Connector interface { + Connect(ctx context.Context, headers http.Header) (*dbConnection, error) + connectWithRetries(ctx context.Context, conn dbConnection, key string, headers http.Header) error + connect(conn dbConnection) error + ping(conn dbConnection) error + Reconnect(ctx context.Context, dbConn dbConnection, q *sqlutil.Query, cacheKey string) (*sql.DB, error) + getDBConnection(key string) (dbConnection, bool) + storeDBConnection(key string, dbConn dbConnection) + Dispose() + GetConnectionFromQuery(ctx context.Context, q *sqlutil.Query) (string, dbConnection, error) + GetDriver() Driver + GetUID() string + getDriverSettings() DriverSettings + getInstanceSettings() backend.DataSourceInstanceSettings } -func NewConnector(ctx context.Context, driver Driver, settings backend.DataSourceInstanceSettings, enableMultipleConnections bool) (*Connector, error) { - ds := driver.Settings(ctx, settings) - db, err := driver.Connect(ctx, settings, nil) +type HydrolixConnector struct { + UID string + connections *ttlcache.Cache[string, dbConnection] + Driver Driver + driverSettings DriverSettings + instanceSettings backend.DataSourceInstanceSettings + pluginSettings models.PluginSettings +} + +func NewConnector(ctx context.Context, driver Driver, settings backend.DataSourceInstanceSettings) (*HydrolixConnector, error) { + pluginSettings, err := models.NewPluginSettings(ctx, settings) if err != nil { return nil, backend.DownstreamError(err) } + ds := driver.Settings(ctx, settings) + connections := ttlcache.New[string, dbConnection](ttlcache.WithTTL[string, dbConnection](time.Hour)) + connections.OnEviction(func(ctx context.Context, reason ttlcache.EvictionReason, i *ttlcache.Item[string, dbConnection]) { + _ = i.Value().db.Close() + }) - conn := &Connector{ - UID: settings.UID, - driver: driver, - driverSettings: ds, - enableMultipleConnections: enableMultipleConnections, + conn := &HydrolixConnector{ + UID: settings.UID, + Driver: driver, + driverSettings: ds, + instanceSettings: settings, + pluginSettings: pluginSettings, + connections: connections, + } + if pluginSettings.CredentialsType != "forwardOAuth" { + key := defaultKey(settings.UID) + db, err := driver.Connect(ctx, settings, nil) + if err != nil { + return nil, backend.DownstreamError(err) + } + conn.storeDBConnectionWithTTL(key, dbConnection{db, settings}, ttlcache.NoTTL) } - key := defaultKey(settings.UID) - conn.storeDBConnection(key, dbConnection{db, settings}) return conn, nil } -func (c *Connector) Connect(ctx context.Context, headers http.Header) (*dbConnection, error) { - key := defaultKey(c.UID) +func (c *HydrolixConnector) Connect(ctx context.Context, headers http.Header) (*dbConnection, error) { + key := "" + if c.pluginSettings.CredentialsType == "forwardOAuth" { + key = keyWithConnectionArgs(c.UID, getOAuthConnectionArgs(headers.Get(backend.OAuthIdentityTokenHeaderName))) + } else { + key = defaultKey(c.UID) + } dbConn, ok := c.getDBConnection(key) if !ok { - return nil, ErrorMissingDBConnection + db, err := c.Driver.Connect(ctx, c.instanceSettings, getOAuthConnectionArgs(headers.Get(backend.OAuthIdentityTokenHeaderName))) + if err != nil { + return nil, err + } + // Assign this connection in the cache + dbConn = dbConnection{db, c.instanceSettings} + c.storeDBConnection(key, dbConn) } - if c.driverSettings.Retries == 0 { - err := c.connect(ctx, dbConn) - return nil, err + err := c.connect(dbConn) + return &dbConn, err } - err := c.connectWithRetries(ctx, dbConn, key, headers) return &dbConn, err } -func (c *Connector) connectWithRetries(ctx context.Context, conn dbConnection, key string, headers http.Header) error { - q := &Query{} +func (c *HydrolixConnector) connectWithRetries(ctx context.Context, connection dbConnection, key string, headers http.Header) error { + q := &sqlutil.Query{} if c.driverSettings.ForwardHeaders { applyHeaders(q, headers) } @@ -66,15 +103,15 @@ func (c *Connector) connectWithRetries(ctx context.Context, conn dbConnection, k var db *sql.DB var err error for i := 0; i < c.driverSettings.Retries; i++ { - db, err = c.Reconnect(ctx, conn, q, key) + db, err = c.Reconnect(ctx, connection, q, key) if err != nil { return err } conn := dbConnection{ db: db, - settings: conn.settings, + settings: connection.settings, } - err = c.connect(ctx, conn) + err = c.connect(conn) if err == nil { break } @@ -90,98 +127,118 @@ func (c *Connector) connectWithRetries(ctx context.Context, conn dbConnection, k if c.driverSettings.Pause > 0 { time.Sleep(time.Duration(c.driverSettings.Pause * int(time.Second))) } - backend.Logger.Warn(fmt.Sprintf("connect failed: %s. Retrying %d times", err.Error(), i+1)) + backend.Logger.Warn("connect failed", "error", err.Error(), "retry", i+1) } return err } -func (c *Connector) connect(ctx context.Context, conn dbConnection) error { - if err := c.ping(ctx, conn); err != nil { +func (c *HydrolixConnector) connect(conn dbConnection) error { + if err := c.ping(conn); err != nil { return backend.DownstreamError(err) } return nil } -func (c *Connector) ping(ctx context.Context, conn dbConnection) error { - if c.driverSettings.Timeout == 0 { - return conn.db.PingContext(ctx) - } +func (c *HydrolixConnector) ping(conn dbConnection) error { - ctx, cancel := context.WithTimeout(ctx, c.driverSettings.Timeout) - defer cancel() - - return conn.db.PingContext(ctx) + return conn.db.Ping() } -func (c *Connector) Reconnect(ctx context.Context, dbConn dbConnection, q *Query, cacheKey string) (*sql.DB, error) { +func (c *HydrolixConnector) Reconnect(ctx context.Context, dbConn dbConnection, q *sqlutil.Query, cacheKey string) (*sql.DB, error) { if err := dbConn.db.Close(); err != nil { - backend.Logger.Warn(fmt.Sprintf("closing existing connection failed: %s", err.Error())) + backend.Logger.Warn("closing existing connection failed", "error", err.Error()) } - db, err := c.driver.Connect(ctx, dbConn.settings, q.ConnectionArgs) + db, err := c.Driver.Connect(ctx, dbConn.settings, q.ConnectionArgs) if err != nil { + if db != nil { + _ = db.Close() + } return nil, backend.DownstreamError(err) } c.storeDBConnection(cacheKey, dbConnection{db, dbConn.settings}) return db, nil } -func (ds *Connector) getDBConnection(key string) (dbConnection, bool) { - conn, ok := ds.connections.Load(key) - if !ok { +func (c *HydrolixConnector) getDBConnection(key string) (dbConnection, bool) { + conn := c.connections.Get(key) + if conn == nil { return dbConnection{}, false } - return conn.(dbConnection), true + return conn.Value(), true } -func (ds *Connector) storeDBConnection(key string, dbConn dbConnection) { - ds.connections.Store(key, dbConn) +func (c *HydrolixConnector) storeDBConnectionWithTTL(key string, dbConn dbConnection, ttl time.Duration) { + c.connections.Set(key, dbConn, ttl) +} +func (c *HydrolixConnector) storeDBConnection(key string, dbConn dbConnection) { + c.storeDBConnectionWithTTL(key, dbConn, ttlcache.DefaultTTL) } // Dispose is called when an existing SQLDatasource needs to be replaced -func (c *Connector) Dispose() { - c.connections.Range(func(_, conn interface{}) bool { - _ = conn.(dbConnection).db.Close() - return true - }) - c.connections.Clear() +func (c *HydrolixConnector) Dispose() { + c.connections.DeleteAll() + c.connections.Stop() } -func (c *Connector) GetConnectionFromQuery(ctx context.Context, q *Query) (string, dbConnection, error) { - if !c.enableMultipleConnections && !c.driverSettings.ForwardHeaders && len(q.ConnectionArgs) > 0 && string(q.ConnectionArgs) != "{}" { - return "", dbConnection{}, MissingMultipleConnectionsConfig - } +func (c *HydrolixConnector) getDriverSettings() DriverSettings { + return c.driverSettings +} + +func (c *HydrolixConnector) GetDriver() Driver { + return c.Driver +} +func (c *HydrolixConnector) GetUID() string { + return c.UID +} +func (c *HydrolixConnector) getInstanceSettings() backend.DataSourceInstanceSettings { + return c.instanceSettings +} + +func (c *HydrolixConnector) GetConnectionFromQuery(ctx context.Context, q *sqlutil.Query) (string, dbConnection, error) { + // The database connection may vary depending on query arguments // The raw arguments are used as key to store the db connection in memory so they can be reused - key := defaultKey(c.UID) - dbConn, ok := c.getDBConnection(key) - if !ok { - return "", dbConnection{}, MissingDBConnection - } - if !c.enableMultipleConnections || len(q.ConnectionArgs) == 0 { - backend.Logger.Debug("using single user connection") + if len(q.ConnectionArgs) == 0 { + key := defaultKey(c.UID) + dbConn, ok := c.getDBConnection(key) + + if !ok { + // Connection not in cache (expired or never created), establish a new one + db, err := c.Driver.Connect(ctx, c.instanceSettings, nil) + if err != nil { + return "", dbConnection{}, backend.DownstreamError(err) + } + dbConn = dbConnection{db, c.instanceSettings} + c.storeDBConnection(key, dbConn) + } return key, dbConn, nil - } - - key = keyWithConnectionArgs(c.UID, q.ConnectionArgs) - if cachedConn, ok := c.getDBConnection(key); ok { - backend.Logger.Debug("cached connection") - return key, cachedConn, nil - } + } else { + key := keyWithConnectionArgs(c.UID, q.ConnectionArgs) + if cachedConn, ok := c.getDBConnection(key); ok { + return key, cachedConn, nil + } - db, err := c.driver.Connect(ctx, dbConn.settings, q.ConnectionArgs) - if err != nil { - backend.Logger.Debug("connect error " + err.Error()) - return "", dbConnection{}, backend.DownstreamError(err) + db, err := c.Driver.Connect(ctx, c.instanceSettings, q.ConnectionArgs) + if err != nil { + return "", dbConnection{}, backend.DownstreamError(err) + } + // Assign this connection in the cache + dbConn := dbConnection{db, c.instanceSettings} + c.storeDBConnection(key, dbConn) + + return key, dbConn, nil } - backend.Logger.Debug("new connection(multiple) created") - // Assign this connection in the cache - dbConn = dbConnection{db, dbConn.settings} - c.storeDBConnection(key, dbConn) +} - return key, dbConn, nil +func getOAuthConnectionArgs(header string) json.RawMessage { + q := &sqlutil.Query{} + headers := http.Header{} + headers.Set(backend.OAuthIdentityTokenHeaderName, header) + applyHeaders(q, headers) + return q.ConnectionArgs } func shouldRetry(retryOn []string, err string) bool { diff --git a/connector_test.go b/connector_test.go new file mode 100644 index 0000000..7c000a0 --- /dev/null +++ b/connector_test.go @@ -0,0 +1,476 @@ +package sqlds + +import ( + "context" + "database/sql" + "encoding/json" + "github.com/DATA-DOG/go-sqlmock" + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" + "github.com/hydrolix/sqlds/v5/models" + "net/http" + "sync/atomic" + "testing" + "time" +) + +// --- helpers --- + +type stubDriver struct { + settings DriverSettings + connectDBs []*sql.DB + connectErrs []error + connectCalls int32 +} + +func (d *stubDriver) Settings(_ context.Context, _ backend.DataSourceInstanceSettings) DriverSettings { + return d.settings +} + +func (d *stubDriver) Connect(_ context.Context, _ backend.DataSourceInstanceSettings, _ json.RawMessage) (*sql.DB, error) { + i := int(atomic.AddInt32(&d.connectCalls, 1)) - 1 + var db *sql.DB + var err error + if i < len(d.connectDBs) { + db = d.connectDBs[i] + } + if i < len(d.connectErrs) { + err = d.connectErrs[i] + } + // Fallback when arrays shorter: last provided + if db == nil && len(d.connectDBs) > 0 { + db = d.connectDBs[len(d.connectDBs)-1] + } + if err == nil && len(d.connectErrs) > 0 && i >= len(d.connectErrs) { + err = d.connectErrs[len(d.connectErrs)-1] + } + return db, err +} +func (d *stubDriver) Macros() sqlutil.Macros { + return make(sqlutil.Macros) +} +func (d *stubDriver) Converters() []sqlutil.Converter { + return []sqlutil.Converter{} +} + +func newSqlmockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock.New() failed: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + return db, mock +} + +// Dummy instance settings +func inst(uid string) backend.DataSourceInstanceSettings { + return backend.DataSourceInstanceSettings{UID: uid} +} + +// --- tests --- + +func TestShouldRetry(t *testing.T) { + cases := []struct { + retryOn []string + err string + want bool + }{ + {[]string{"timeout", "deadlock"}, "query timeout occurred", true}, + {[]string{"temporary"}, "temporary network issue", true}, + {[]string{"temporary"}, "permanent failure", false}, + {nil, "anything", false}, + } + for _, c := range cases { + if got := shouldRetry(c.retryOn, c.err); got != c.want { + t.Fatalf("shouldRetry(%v,%q)=%v want %v", c.retryOn, c.err, got, c.want) + } + } +} + +func TestApplyHeaders(t *testing.T) { + q := &sqlutil.Query{} + h := http.Header{} + h.Set("X-Auth", "abc") + h.Add("X-Auth", "def") + h.Set("X-User", "alice") + + out := applyHeaders(q, h) + if string(out.ConnectionArgs) == "" { + t.Fatalf("ConnectionArgs empty") + } + var args map[string]any + if err := json.Unmarshal(out.ConnectionArgs, &args); err != nil { + t.Fatalf("json.Unmarshal: %v", err) + } + raw, ok := args[HeaderKey] + if !ok { + t.Fatalf("expected %q key in ConnectionArgs", HeaderKey) + } + // http.Header marshals as map[string][]string + m, ok := raw.(map[string]any) + if !ok { + t.Fatalf("expected header map, got %T", raw) + } + if _, ok := m["X-Auth"]; !ok { + t.Fatalf("missing X-Auth in headers") + } + if _, ok := m["X-User"]; !ok { + t.Fatalf("missing X-User in headers") + } +} + +func TestReconnectClosesAndReplacesConnection(t *testing.T) { + // initial connection (created by NewConnector) + initDB, initMock := newSqlmockDB(t) + initMock.ExpectClose().WillReturnError(nil) + + // new connection returned by Reconnect + newDB, _ := newSqlmockDB(t) + + driver := &stubDriver{ + settings: DriverSettings{}, + connectDBs: []*sql.DB{initDB, newDB}, + } + connector, err := NewConnector(context.Background(), driver, buildInstanceSettings()) + if err != nil { + t.Fatalf("NewConnector: %v", err) + } + key := defaultKey(connector.GetUID()) + dbConn, _ := connector.getDBConnection(key) + + gotDB, err := connector.Reconnect(context.Background(), dbConn, &sqlutil.Query{}, key) + if err != nil { + t.Fatalf("Reconnect: %v", err) + } + if gotDB != newDB { + t.Fatalf("Reconnect returned wrong db") + } + // Ensure close on old was called + if err := initMock.ExpectationsWereMet(); err != nil { + t.Fatalf("sqlmock expectations: %v", err) + } +} + +func TestGetConnectionFromQuery_NoArgs_ReturnsDefault(t *testing.T) { + db, _ := newSqlmockDB(t) + driver := &stubDriver{ + settings: DriverSettings{}, + connectDBs: []*sql.DB{db}, + } + connector, err := NewConnector(context.Background(), driver, buildInstanceSettings()) + if err != nil { + t.Fatalf("NewConnector: %v", err) + } + + q := &sqlutil.Query{} // no args + key, dbConn, err := connector.GetConnectionFromQuery(context.Background(), q) + if err != nil { + t.Fatalf("GetConnectionFromQuery: %v", err) + } + if key == "" { + t.Fatalf("expected non-empty key") + } + if dbConn.db == nil { + t.Fatalf("expected non-nil db") + } +} + +func TestGetConnectionFromQuery_NewArgs_CachesPerArgs(t *testing.T) { + // initial connection created by NewConnector + initDB, _ := newSqlmockDB(t) + // two distinct new DBs for two distinct arg sets (only first used twice) + dbA1, _ := newSqlmockDB(t) + dbA2, _ := newSqlmockDB(t) // this should NOT be used because first is cached + dbB, _ := newSqlmockDB(t) + + driver := &stubDriver{ + settings: DriverSettings{}, + connectDBs: []*sql.DB{initDB, dbA1, dbA2, dbB}, + } + connector, err := NewConnector(context.Background(), driver, buildInstanceSettings()) + if err != nil { + t.Fatalf("NewConnector: %v", err) + } + + ctx := context.Background() + qA := &sqlutil.Query{ConnectionArgs: []byte(`{"tenant":"A"}`)} + qB := &sqlutil.Query{ConnectionArgs: []byte(`{"tenant":"B"}`)} + + // First time with A -> creates and caches + keyA1, connA1, err := connector.GetConnectionFromQuery(ctx, qA) + if err != nil { + t.Fatalf("GetConnectionFromQuery A1: %v", err) + } + // Second time with same args -> should be cached (no extra Connect) + keyA2, connA2, err := connector.GetConnectionFromQuery(ctx, qA) + if err != nil { + t.Fatalf("GetConnectionFromQuery A2: %v", err) + } + if keyA1 != keyA2 || connA1.db != connA2.db { + t.Fatalf("expected cached connection for same args") + } + + // Different args -> new connection + keyB, connB, err := connector.GetConnectionFromQuery(ctx, qB) + if err != nil { + t.Fatalf("GetConnectionFromQuery B: %v", err) + } + if keyB == keyA1 || connB.db == connA1.db { + t.Fatalf("expected different key/connection for different args") + } +} + +func TestDispose_ClosesAllAndClears(t *testing.T) { + db1, mock1 := newSqlmockDB(t) + db2, mock2 := newSqlmockDB(t) + mock1.ExpectClose().WillReturnError(nil) + mock2.ExpectClose().WillReturnError(nil) + + driver := &stubDriver{ + settings: DriverSettings{}, + connectDBs: []*sql.DB{db1}, + } + connector, err := NewConnector(context.Background(), driver, buildInstanceSettings()) + if err != nil { + t.Fatalf("NewConnector: %v", err) + } + + // Manually store another connection to ensure both are closed + connector.storeDBConnection("extra", dbConnection{db2, inst("uid9")}) + + // Dispose should close both and clear map + connector.Dispose() + + // sleep while ttlcache calls eviction callback for connection + time.Sleep(100 * time.Millisecond) + + // Both closes must have been hit + if err := mock1.ExpectationsWereMet(); err != nil { + t.Fatalf("db1 expectations: %v", err) + } + if err := mock2.ExpectationsWereMet(); err != nil { + t.Fatalf("db2 expectations: %v", err) + } + + // After Clear, we shouldn't find previous keys + if _, ok := connector.getDBConnection(defaultKey(connector.GetUID())); ok { + t.Fatalf("expected connections map to be cleared") + } + if _, ok := connector.getDBConnection("extra"); ok { + t.Fatalf("expected connections map to be cleared") + } +} + +func TestNewConnector_ForwardOAuth_SkipsInitialConnect(t *testing.T) { + driver := &stubDriver{ + settings: DriverSettings{ForwardHeaders: true}, + connectDBs: []*sql.DB{}, // no DBs provided — Connect should NOT be called + } + connector, err := NewConnector(context.Background(), driver, buildForwardOAuthInstanceSettings()) + if err != nil { + t.Fatalf("NewConnector: %v", err) + } + + // No connection should be cached (forwardOAuth skips initial connect) + key := defaultKey(connector.GetUID()) + _, ok := connector.getDBConnection(key) + if ok { + t.Fatalf("expected no cached connection for forwardOAuth, but found one") + } + + // Driver.Connect should not have been called + if driver.connectCalls != 0 { + t.Fatalf("expected 0 Connect calls for forwardOAuth, got %d", driver.connectCalls) + } +} + +func TestNewConnector_UserAccount_ConnectsImmediately(t *testing.T) { + db, _ := newSqlmockDB(t) + driver := &stubDriver{ + settings: DriverSettings{}, + connectDBs: []*sql.DB{db}, + } + connector, err := NewConnector(context.Background(), driver, buildInstanceSettings()) + if err != nil { + t.Fatalf("NewConnector: %v", err) + } + + // Connection should be cached + key := defaultKey(connector.GetUID()) + dbConn, ok := connector.getDBConnection(key) + if !ok { + t.Fatalf("expected cached connection for userAccount") + } + if dbConn.db != db { + t.Fatalf("cached DB does not match the one provided by driver") + } + + // Driver.Connect should have been called once + if driver.connectCalls != 1 { + t.Fatalf("expected 1 Connect call, got %d", driver.connectCalls) + } +} + +func TestGetOAuthConnectionArgs(t *testing.T) { + args := getOAuthConnectionArgs("Bearer my-oauth-token") + if args == nil { + t.Fatalf("expected non-nil ConnectionArgs") + } + + var parsed map[string]any + if err := json.Unmarshal(args, &parsed); err != nil { + t.Fatalf("json.Unmarshal: %v", err) + } + + raw, ok := parsed[HeaderKey] + if !ok { + t.Fatalf("expected %q key in ConnectionArgs", HeaderKey) + } + + m, ok := raw.(map[string]any) + if !ok { + t.Fatalf("expected header map, got %T", raw) + } + if _, ok := m["Authorization"]; !ok { + t.Fatalf("missing Authorization in headers") + } +} + +func TestGetOAuthConnectionArgs_EmptyHeaders(t *testing.T) { + args := getOAuthConnectionArgs("") + if args == nil { + t.Fatalf("expected non-nil ConnectionArgs even for empty headers") + } + + var parsed map[string]any + if err := json.Unmarshal(args, &parsed); err != nil { + t.Fatalf("json.Unmarshal: %v", err) + } + + _, ok := parsed[HeaderKey] + if !ok { + t.Fatalf("expected %q key in ConnectionArgs even for empty headers", HeaderKey) + } +} + +func TestGetConnectionFromQuery_WithArgs_CreatesNewConnection(t *testing.T) { + initDB, _ := newSqlmockDB(t) + newDB, _ := newSqlmockDB(t) + + driver := &stubDriver{ + settings: DriverSettings{}, + connectDBs: []*sql.DB{initDB, newDB}, + } + connector, err := NewConnector(context.Background(), driver, buildInstanceSettings()) + if err != nil { + t.Fatalf("NewConnector: %v", err) + } + + q := &sqlutil.Query{ConnectionArgs: []byte(`{"tenant":"A"}`)} + _, dbConn, err := connector.GetConnectionFromQuery(context.Background(), q) + if err != nil { + t.Fatalf("GetConnectionFromQuery: %v", err) + } + if dbConn.db != newDB { + t.Fatalf("expected new connection for new args") + } +} + +func buildForwardOAuthInstanceSettings() backend.DataSourceInstanceSettings { + settings := models.PluginSettings{ + Host: "localhost", + Port: 80, + Protocol: "http", + UserName: "", + Password: "", + CredentialsType: "forwardOAuth", + Secure: true, + Path: "/query", + SkipTlsVerify: true, + DialTimeout: "10", + QueryTimeout: "20", + DefaultDatabase: "foo", + } + jsonData, _ := json.Marshal(settings) + + return backend.DataSourceInstanceSettings{ + Name: "test-hydrolix-oauth-datasource", + JSONData: jsonData, + DecryptedSecureJSONData: map[string]string{}, + } +} + +type MockConnector struct { + db *sql.DB + uid string + connCalls int +} + +func (m *MockConnector) Connect(_ context.Context, _ http.Header) (*dbConnection, error) { + return &dbConnection{db: m.db}, nil +} +func (m *MockConnector) connectWithRetries(_ context.Context, _ dbConnection, _ string, _ http.Header) error { + return nil +} +func (m *MockConnector) connect(_ dbConnection) error { return nil } +func (m *MockConnector) ping(_ dbConnection) error { return nil } + +func (m *MockConnector) Reconnect(_ context.Context, _ dbConnection, _ *sqlutil.Query, _ string) (*sql.DB, error) { + return m.db, nil +} + +func (m *MockConnector) getDBConnection(_ string) (dbConnection, bool) { + m.connCalls++ + return dbConnection{db: m.db}, true +} + +func (m *MockConnector) storeDBConnection(_ string, _ dbConnection) {} + +func (m *MockConnector) Dispose() {} + +func (m *MockConnector) GetConnectionFromQuery(_ context.Context, _ *sqlutil.Query) (string, dbConnection, error) { + m.connCalls++ + return "key", dbConnection{db: m.db}, nil +} + +func (m *MockConnector) GetDriver() Driver { + return &stubDriver{ + settings: DriverSettings{}, + connectDBs: []*sql.DB{}, + } +} + +func (m *MockConnector) GetUID() string { return m.uid } + +func (m *MockConnector) getDriverSettings() DriverSettings { return DriverSettings{} } + +func (m *MockConnector) getInstanceSettings() backend.DataSourceInstanceSettings { + return buildInstanceSettings() +} + +func buildInstanceSettingsWithUID(uid string) backend.DataSourceInstanceSettings { + settings := models.PluginSettings{ + Host: "localhost", + Port: 80, + Protocol: "http", + UserName: "default", + Password: "pass", + Secure: true, + Path: "/query", + SkipTlsVerify: true, + DialTimeout: "10", + QueryTimeout: "20", + DefaultDatabase: "foo", + } + jsonData, _ := json.Marshal(settings) + return backend.DataSourceInstanceSettings{ + UID: uid, + Name: "test-hydrolix-http-datasource", + JSONData: jsonData, + DecryptedSecureJSONData: map[string]string{"password": settings.Password}, + } +} + +func buildInstanceSettings() backend.DataSourceInstanceSettings { + return buildInstanceSettingsWithUID("uid1") +} diff --git a/dataframe_test.go b/dataframe_test.go index cef1d73..829722e 100644 --- a/dataframe_test.go +++ b/dataframe_test.go @@ -8,8 +8,8 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" - "github.com/grafana/sqlds/v5" - "github.com/grafana/sqlds/v5/test" + "github.com/hydrolix/sqlds/v5" + "github.com/hydrolix/sqlds/v5/test" "github.com/stretchr/testify/require" ) @@ -158,12 +158,14 @@ func TestNoRowsFrame(t *testing.T) { for _, tt := range tts { t.Run(tt.name, func(t *testing.T) { id := "empty-frames" + tt.name - driver, _ := test.NewDriver(id, tt.data, nil, test.DriverOpts{}, nil) - ds := sqlds.NewDatasource(driver) + driver, _ := test.NewDriver(id, tt.data, nil, test.DriverOpts{}) - settings := backend.DataSourceInstanceSettings{UID: id, JSONData: []byte("{}")} - _, err := ds.NewDatasource(context.Background(), settings) + settings := backend.DataSourceInstanceSettings{UID: id, JSONData: []byte(`{"host":"localhost","port":9000,"protocol":"native"}`)} + connector, err := sqlds.NewConnector(context.Background(), driver, settings) + require.NoError(t, err) + ds := &sqlds.HydrolixDatasource{Connector: connector} + _, err = ds.NewDatasource(context.Background(), settings) require.NoError(t, err) req := backend.QueryDataRequest{ diff --git a/datasource.go b/datasource.go index 34e3e1b..d6d8ba1 100644 --- a/datasource.go +++ b/datasource.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/grafana/grafana-plugin-sdk-go/backend/log" "net/http" "os" "runtime/debug" @@ -14,13 +15,11 @@ import ( "sync" "time" - "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" - "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt" - "github.com/grafana/grafana-plugin-sdk-go/backend/log" "github.com/grafana/grafana-plugin-sdk-go/backend/resource/httpadapter" "github.com/grafana/grafana-plugin-sdk-go/data" + "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" ) const defaultKeySuffix = "default" @@ -28,13 +27,8 @@ const defaultRowLimit = int64(-1) const envRowLimit = "GF_DATAPROXY_ROW_LIMIT" var ( - ErrorMissingMultipleConnectionsConfig = backend.PluginError(errors.New("received connection arguments but the feature is not enabled")) - ErrorMissingDBConnection = backend.PluginError(errors.New("unable to get default db connection")) - HeaderKey = "grafana-http-headers" - // Deprecated: ErrorMissingMultipleConnectionsConfig should be used instead - MissingMultipleConnectionsConfig = ErrorMissingMultipleConnectionsConfig - // Deprecated: ErrorMissingDBConnection should be used instead - MissingDBConnection = ErrorMissingDBConnection + HeaderKey = "grafana-http-headers" + ErrorParsingMacroBrackets = errors.New("failed to parse macro arguments (missing close bracket?)") ) func defaultKey(datasourceUID string) string { @@ -51,74 +45,54 @@ type dbConnection struct { settings backend.DataSourceInstanceSettings } -type SQLDatasource struct { - Completable +type HydrolixDatasource struct { backend.CallResourceHandler - connector *Connector - CustomRoutes map[string]func(http.ResponseWriter, *http.Request) - metrics Metrics - EnableMultipleConnections bool + Connector Connector + ID string + Interpolator Interpolator + metrics Metrics // EnableRowLimit: enables using the dataproxy.row_limit setting to limit the number of rows returned by the query // https://grafana.com/docs/grafana/latest/setup-grafana/configure-grafana/#row_limit EnableRowLimit bool rowLimit int64 - // PreCheckHealth (optional). Performs custom health check before the Connect method - PreCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult - // PostCheckHealth (optional).Performs custom health check after the Connect method - PostCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult - // ResourceMiddleware (optional). Allows interception to CallResource before it is passed to sqlds - ResourceMiddleware func(next backend.CallResourceHandler) backend.CallResourceHandler } // NewDatasource creates a new `SQLDatasource`. // It uses the provided settings argument to call the ds.Driver to connect to the SQL server -func (ds *SQLDatasource) NewDatasource(ctx context.Context, settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) { - conn, err := NewConnector(ctx, ds.driver(), settings, ds.EnableMultipleConnections) - if err != nil { - return nil, backend.DownstreamError(err) - } - - ds.connector = conn - mux := http.NewServeMux() - err = ds.registerRoutes(mux) - if err != nil { - return nil, backend.PluginError(err) - } - - ds.CallResourceHandler = httpadapter.New(mux) - if ds.ResourceMiddleware != nil { - ds.CallResourceHandler = ds.ResourceMiddleware(ds.CallResourceHandler) - } +func (ds *HydrolixDatasource) NewDatasource(ctx context.Context, settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) { + ds.Interpolator = NewInterpolator(ds) + ds.ID = settings.UID ds.metrics = NewMetrics(settings.Name, settings.Type, EndpointQuery) - ds.rowLimit = ds.newRowLimit(ctx, conn) + ds.rowLimit = ds.newRowLimit(ctx, ds.Connector) return ds, nil } -// NewDatasource initializes the Datasource wrapper and instance manager -func NewDatasource(c Driver) *SQLDatasource { - return &SQLDatasource{ - connector: &Connector{driver: c}, +func (ds *HydrolixDatasource) RegisterRoutes(customRoutes map[string]func(http.ResponseWriter, *http.Request)) { + mux := http.NewServeMux() + for route, handler := range customRoutes { + mux.HandleFunc(route, handler) } + + ds.CallResourceHandler = httpadapter.New(mux) } // Dispose cleans up datasource instance resources. // Note: Called when testing and saving a datasource -func (ds *SQLDatasource) Dispose() { - ds.connector.Dispose() +func (ds *HydrolixDatasource) Dispose() { + ds.Connector.Dispose() } // QueryData creates the Responses list and executes each query -func (ds *SQLDatasource) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { - headers := req.GetHTTPHeaders() +func (ds *HydrolixDatasource) QueryData(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + headers := req.GetHTTPHeaders() var ( response = NewResponse(backend.NewQueryDataResponse()) wg = sync.WaitGroup{} ) - wg.Add(len(req.Queries)) if queryDataMutator, ok := ds.driver().(QueryDataMutator); ok { @@ -172,9 +146,7 @@ func (ds *SQLDatasource) QueryData(ctx context.Context, req *backend.QueryDataRe }) }(q) } - wg.Wait() - errs := ds.errors(response) if ds.DriverSettings().Errors { return response.Response(), errs @@ -183,13 +155,14 @@ func (ds *SQLDatasource) QueryData(ctx context.Context, req *backend.QueryDataRe return response.Response(), nil } -func (ds *SQLDatasource) GetDBFromQuery(ctx context.Context, q *Query) (*sql.DB, error) { - _, dbConn, err := ds.connector.GetConnectionFromQuery(ctx, q) +func (ds *HydrolixDatasource) GetDBFromQuery(ctx context.Context, q *sqlutil.Query) (*sql.DB, error) { + + _, dbConn, err := ds.Connector.GetConnectionFromQuery(ctx, q) return dbConn.db, err } // handleQuery will call query, and attempt to reconnect if the query failed -func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, headers http.Header) (data.Frames, error) { +func (ds *HydrolixDatasource) handleQuery(ctx context.Context, req backend.DataQuery, headers http.Header) (data.Frames, error) { if queryMutator, ok := ds.driver().(QueryMutator); ok { ctx, req = queryMutator.MutateQuery(ctx, req) } @@ -200,10 +173,13 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, return nil, err } - // Apply supported macros to the query - q.RawSQL, err = Interpolate(ds.driver(), q) + hdxQuery, err := GetHdxQuery(req, headers, nil, nil) + if err != nil { + return nil, err + } + q.RawSQL, err = ds.Interpolator.Interpolate(ctx, hdxQuery) if err != nil { - if errors.Is(err, sqlutil.ErrorBadArgumentCount) || err.Error() == ErrorParsingMacroBrackets.Error() { + if errors.Is(err, sqlutil.ErrorBadArgumentCount) || errors.Is(err, ErrorParsingMacroBrackets) { err = backend.DownstreamError(err) } return sqlutil.ErrorFrameFromQuery(q), fmt.Errorf("%s: %w", "Could not apply macros", err) @@ -216,7 +192,7 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, } // Retrieve the database connection - cacheKey, dbConn, err := ds.connector.GetConnectionFromQuery(ctx, q) + cacheKey, dbConn, err := ds.Connector.GetConnectionFromQuery(ctx, q) if err != nil { return sqlutil.ErrorFrameFromQuery(q), err } @@ -240,7 +216,7 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, // FIXES: // * Some datasources (snowflake) expire connections or have an authentication token that expires if not used in 1 or 4 hours. - // Because the datasource driver does not include an option for permanent connections, we retry the connection + // Because the datasource Driver does not include an option for permanent connections, we retry the connection // if the query fails. NOTE: this does not include some errors like "ErrNoRows" dbQuery := NewQuery(dbConn.db, dbConn.settings, ds.driver().Converters(), fillMode, ds.rowLimit) res, err := dbQuery.Run(ctx, q, queryErrorMutator, args...) @@ -258,8 +234,8 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, // only retry on messages that contain specific errors if shouldRetry(ds.DriverSettings().RetryOn, err.Error()) { for i := 0; i < ds.DriverSettings().Retries; i++ { - backend.Logger.Warn(fmt.Sprintf("query failed: %s. Retrying %d times", err.Error(), i)) - db, err := ds.connector.Reconnect(ctx, dbConn, q, cacheKey) + backend.Logger.Warn("query failed", "error", err.Error(), "retry", i) + db, err := ds.Connector.Reconnect(ctx, dbConn, q, cacheKey) if err != nil { return nil, backend.DownstreamError(err) } @@ -276,7 +252,7 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, if !shouldRetry(ds.DriverSettings().RetryOn, err.Error()) { return res, err } - backend.Logger.Warn(fmt.Sprintf("Retry failed: %s", err.Error())) + backend.Logger.Warn("Retry failed", "error", err.Error()) } } } @@ -292,8 +268,8 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, // allow retries on timeouts if errors.Is(err, context.DeadlineExceeded) { for i := 0; i < ds.DriverSettings().Retries; i++ { - backend.Logger.Warn(fmt.Sprintf("connection timed out. retrying %d times", i)) - db, err := ds.connector.Reconnect(ctx, dbConn, q, cacheKey) + backend.Logger.Warn("connection timed out", "retry", i) + db, err := ds.Connector.Reconnect(ctx, dbConn, q, cacheKey) if err != nil { continue } @@ -310,28 +286,27 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, } // CheckHealth pings the connected SQL database -func (ds *SQLDatasource) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { +func (ds *HydrolixDatasource) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { if checkHealthMutator, ok := ds.driver().(CheckHealthMutator); ok { ctx, req = checkHealthMutator.MutateCheckHealth(ctx, req) } healthChecker := &HealthChecker{ - Connector: ds.connector, - Metrics: ds.metrics.WithEndpoint(EndpointHealth), - PreCheckHealth: ds.PreCheckHealth, - PostCheckHealth: ds.PostCheckHealth, + Connector: ds.Connector, + Metrics: ds.metrics.WithEndpoint(EndpointHealth), } return healthChecker.Check(ctx, req) } -func (ds *SQLDatasource) DriverSettings() DriverSettings { - return ds.connector.driverSettings +func (ds *HydrolixDatasource) DriverSettings() DriverSettings { + return ds.Connector.getDriverSettings() } -func (ds *SQLDatasource) driver() Driver { - return ds.connector.driver +func (ds *HydrolixDatasource) driver() Driver { + + return ds.Connector.GetDriver() } -func (ds *SQLDatasource) errors(response *Response) error { +func (ds *HydrolixDatasource) errors(response *Response) error { if response == nil { return nil } @@ -349,11 +324,11 @@ func (ds *SQLDatasource) errors(response *Response) error { return err } -func (ds *SQLDatasource) GetRowLimit() int64 { +func (ds *HydrolixDatasource) GetRowLimit() int64 { return ds.rowLimit } -func (ds *SQLDatasource) SetDefaultRowLimit(limit int64) { +func (ds *HydrolixDatasource) SetDefaultRowLimit(limit int64) { ds.EnableRowLimit = true ds.rowLimit = limit } @@ -364,13 +339,13 @@ func (ds *SQLDatasource) SetDefaultRowLimit(limit int64) { // 2. set via the environment variable // 3. set is set on grafana_ini and passed via grafana context // 4. default row limit set by SetDefaultRowLimit -func (ds *SQLDatasource) newRowLimit(ctx context.Context, conn *Connector) int64 { +func (ds *HydrolixDatasource) newRowLimit(ctx context.Context, conn Connector) int64 { if !ds.EnableRowLimit { return defaultRowLimit } // Handles when row limit is set in the datasource configuration page - settingsLimit := conn.driverSettings.RowLimit + settingsLimit := conn.getDriverSettings().RowLimit if settingsLimit != 0 { return settingsLimit } diff --git a/datasource_connect_test.go b/datasource_connect_test.go index 1150173..723f1ff 100644 --- a/datasource_connect_test.go +++ b/datasource_connect_test.go @@ -5,7 +5,6 @@ import ( "database/sql" "database/sql/driver" "encoding/json" - "errors" "testing" "github.com/grafana/grafana-plugin-sdk-go/backend" @@ -22,8 +21,8 @@ func (d fakeDriver) Connect(_ context.Context, _ backend.DataSourceInstanceSetti return d.openDBfn(msg) } -func (d fakeDriver) Macros() Macros { - return Macros{} +func (d fakeDriver) Settings(context.Context, backend.DataSourceInstanceSettings) DriverSettings { + return DriverSettings{} } func (d fakeDriver) Converters() []sqlutil.Converter { @@ -78,11 +77,13 @@ func Test_getDBConnectionFromQuery(t *testing.T) { } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - conn := &Connector{UID: tt.dsUID, driver: d, enableMultipleConnections: true, driverSettings: DriverSettings{}} - settings := backend.DataSourceInstanceSettings{UID: tt.dsUID} + settings := buildInstanceSettings() + + conn, err := NewConnector(context.Background(), d, settings) key := defaultKey(tt.dsUID) // Add the mandatory default db conn.storeDBConnection(key, dbConnection{db, settings}) + if tt.existingDB != nil { key = keyWithConnectionArgs(tt.dsUID, []byte(tt.args)) conn.storeDBConnection(key, dbConnection{tt.existingDB, settings}) @@ -101,39 +102,31 @@ func Test_getDBConnectionFromQuery(t *testing.T) { }) } - t.Run("it should return an error if connection args are used without enabling multiple connections", func(t *testing.T) { - conn := &Connector{driver: d, enableMultipleConnections: false} - _, _, err := conn.GetConnectionFromQuery(context.Background(), &Query{ConnectionArgs: json.RawMessage("foo")}) - if err == nil || !errors.Is(err, MissingMultipleConnectionsConfig) { - t.Errorf("expecting error: %v", MissingMultipleConnectionsConfig) - } - }) - - t.Run("it should return an error if the default connection is missing", func(t *testing.T) { - conn := &Connector{driver: d} - _, _, err := conn.GetConnectionFromQuery(context.Background(), &Query{}) - if err == nil || !errors.Is(err, MissingDBConnection) { - t.Errorf("expecting error: %v", MissingDBConnection) - } - }) + //t.Run("it should return an error if the default connection is missing", func(t *testing.T) { + // conn := &HydrolixConnector{Driver: d} + // _, _, err := conn.GetConnectionFromQuery(context.Background(), &Query{}) + // if err == nil || !errors.Is(err, MissingDBConnection) { + // t.Errorf("expecting error: %v", MissingDBConnection) + // } + //}) } -func Test_Dispose(t *testing.T) { - t.Run("it should close connections", func(t *testing.T) { - db := sql.OpenDB(fakeSQLConnector{}) - d := &fakeDriver{openDBfn: func(msg json.RawMessage) (*sql.DB, error) { return db, nil }} - conn := &Connector{driver: d} - ds := &SQLDatasource{connector: conn} - conn.connections.Store(defaultKey("uid1"), dbConnection{db: db}) - conn.connections.Store("foo", dbConnection{db: db}) - ds.Dispose() - count := 0 - conn.connections.Range(func(key, value interface{}) bool { - count++ - return true - }) - if count != 0 { - t.Errorf("did not close all connections") - } - }) -} +//func Test_Dispose(t *testing.T) { +// t.Run("it should close connections", func(t *testing.T) { +// db := sql.OpenDB(fakeSQLConnector{}) +// d := &fakeDriver{openDBfn: func(msg json.RawMessage) (*sql.DB, error) { return db, nil }} +// conn := &HydrolixDatasource{Driver: d} +// ds := &SQLDatasource{connector: conn} +// conn.connections.Store(defaultKey("uid1"), dbConnection{db: db}) +// conn.connections.Store("foo", dbConnection{db: db}) +// ds.Dispose() +// count := 0 +// conn.connections.Range(func(key, value interface{}) bool { +// count++ +// return true +// }) +// if count != 0 { +// t.Errorf("did not close all connections") +// } +// }) +//} diff --git a/datasource_middleware_test.go b/datasource_middleware_test.go deleted file mode 100644 index 5c7de2a..0000000 --- a/datasource_middleware_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package sqlds_test - -import ( - "context" - "testing" - - "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/sqlds/v5" - "github.com/grafana/sqlds/v5/test" - "github.com/stretchr/testify/assert" -) - -func Test_resource_middleware_is_applied(t *testing.T) { - called := false - middleware := func(next backend.CallResourceHandler) backend.CallResourceHandler { - return backend.CallResourceHandlerFunc(func(ctx context.Context, req *backend.CallResourceRequest, sender backend.CallResourceResponseSender) error { - called = true - return next.CallResource(ctx, req, sender) - }) - } - - driver, _ := test.NewDriver("middleware-applied", test.Data{}, nil, test.DriverOpts{}, nil) - ds := sqlds.NewDatasource(driver) - ds.ResourceMiddleware = middleware - - settings := backend.DataSourceInstanceSettings{UID: "middleware-applied", JSONData: []byte("{}")} - _, err := ds.NewDatasource(context.Background(), settings) - assert.Nil(t, err) - - sender := &fakeResourceSender{} - err = ds.CallResource(context.Background(), &backend.CallResourceRequest{Path: "tables", Method: "GET"}, sender) - assert.Nil(t, err) - assert.True(t, called, "expected ResourceMiddleware to be called") -} - -func Test_resource_middleware_nil_is_skipped(t *testing.T) { - driver, _ := test.NewDriver("middleware-nil", test.Data{}, nil, test.DriverOpts{}, nil) - ds := sqlds.NewDatasource(driver) - // ResourceMiddleware intentionally left nil - - settings := backend.DataSourceInstanceSettings{UID: "middleware-nil", JSONData: []byte("{}")} - _, err := ds.NewDatasource(context.Background(), settings) - assert.Nil(t, err) - assert.NotNil(t, ds.CallResourceHandler, "expected CallResourceHandler to be set even without middleware") -} - -// fakeResourceSender captures the last sent response. -type fakeResourceSender struct { - response *backend.CallResourceResponse -} - -func (f *fakeResourceSender) Send(resp *backend.CallResourceResponse) error { - f.response = resp - return nil -} diff --git a/datasource_rowlimit_test.go b/datasource_rowlimit_test.go index 1784880..c4fffc2 100644 --- a/datasource_rowlimit_test.go +++ b/datasource_rowlimit_test.go @@ -6,8 +6,8 @@ import ( "testing" "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/sqlds/v5" - "github.com/grafana/sqlds/v5/test" + "github.com/hydrolix/sqlds/v5" + "github.com/hydrolix/sqlds/v5/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -41,35 +41,19 @@ func TestRowLimitFromConfig(t *testing.T) { // Create datasource with row limit enabled driver := &mockDriver{} - ds := sqlds.NewDatasource(driver) - ds.EnableRowLimit = true - - // Create settings and initialize datasource settings := backend.DataSourceInstanceSettings{UID: "rowlimit-config", JSONData: []byte("{}")} - instance, err := ds.NewDatasource(ctx, settings) - require.NoError(t, err) + ds := newRowLimitTestDatasource(t, driver, settings, ctx, true) - // Verify row limit was set correctly from config - sqlDS, ok := instance.(*sqlds.SQLDatasource) - require.True(t, ok) - assert.Equal(t, int64(200), sqlDS.GetRowLimit()) + assert.Equal(t, int64(200), ds.GetRowLimit()) } func TestRowLimitFromDriverSettings(t *testing.T) { // Create datasource with driver that has row limit driver := &mockDriver{rowLimit: 300} - ds := sqlds.NewDatasource(driver) - ds.EnableRowLimit = true - - // Create settings and initialize datasource settings := backend.DataSourceInstanceSettings{UID: "rowlimit-driver", JSONData: []byte("{}")} - instance, err := ds.NewDatasource(context.Background(), settings) - require.NoError(t, err) + ds := newRowLimitTestDatasource(t, driver, settings, context.Background(), true) - // Verify driver settings row limit was used - sqlDS, ok := instance.(*sqlds.SQLDatasource) - require.True(t, ok) - assert.Equal(t, int64(300), sqlDS.GetRowLimit()) + assert.Equal(t, int64(300), ds.GetRowLimit()) } func TestRowLimitPrecedence(t *testing.T) { @@ -81,18 +65,10 @@ func TestRowLimitPrecedence(t *testing.T) { // Create datasource with driver that has row limit driver := &mockDriver{rowLimit: 300} - ds := sqlds.NewDatasource(driver) - ds.EnableRowLimit = true - - // Create settings and initialize datasource settings := backend.DataSourceInstanceSettings{UID: "rowlimit-precedence", JSONData: []byte("{}")} - instance, err := ds.NewDatasource(ctx, settings) - require.NoError(t, err) + ds := newRowLimitTestDatasource(t, driver, settings, ctx, true) - // Verify driver settings take precedence over config - sqlDS, ok := instance.(*sqlds.SQLDatasource) - require.True(t, ok) - assert.Equal(t, int64(300), sqlDS.GetRowLimit()) + assert.Equal(t, int64(300), ds.GetRowLimit()) } func TestRowLimitDisabled(t *testing.T) { @@ -103,18 +79,10 @@ func TestRowLimitDisabled(t *testing.T) { // Create datasource with row limit disabled driver := &mockDriver{} - ds := sqlds.NewDatasource(driver) - ds.EnableRowLimit = false - - // Create settings and initialize datasource settings := backend.DataSourceInstanceSettings{UID: "rowlimit-disabled", JSONData: []byte("{}")} - instance, err := ds.NewDatasource(ctx, settings) - require.NoError(t, err) + ds := newRowLimitTestDatasource(t, driver, settings, ctx, false) - // Verify default row limit is used when feature is disabled - sqlDS, ok := instance.(*sqlds.SQLDatasource) - require.True(t, ok) - assert.Equal(t, int64(-1), sqlDS.GetRowLimit()) + assert.Equal(t, int64(-1), ds.GetRowLimit()) } func TestRowLimitDefault(t *testing.T) { @@ -126,39 +94,21 @@ func TestRowLimitDefault(t *testing.T) { // Create datasource with row limit disabled driver := &mockDriver{} - ds := sqlds.NewDatasource(driver) - - // Create settings and initialize datasource settings := backend.DataSourceInstanceSettings{UID: "rowlimit-disabled", JSONData: []byte("{}")} - instance, err := ds.NewDatasource(ctx, settings) - require.NoError(t, err) + ds := newRowLimitTestDatasource(t, driver, settings, ctx, false) - // Verify default row limit is used when feature is disabled - sqlDS, ok := instance.(*sqlds.SQLDatasource) - require.True(t, ok) - assert.Equal(t, int64(-1), sqlDS.GetRowLimit()) + assert.Equal(t, int64(-1), ds.GetRowLimit()) } func TestSetDefaultRowLimit(t *testing.T) { - // Create datasource driver := &mockDriver{} - ds := sqlds.NewDatasource(driver) - - // Initialize datasource settings := backend.DataSourceInstanceSettings{UID: "rowlimit-set", JSONData: []byte("{}")} - instance, err := ds.NewDatasource(context.Background(), settings) - require.NoError(t, err) - - // Cast to SQLDatasource - sqlDS, ok := instance.(*sqlds.SQLDatasource) - require.True(t, ok) + ds := newRowLimitTestDatasource(t, driver, settings, context.Background(), false) - // Set row limit - sqlDS.SetDefaultRowLimit(500) + ds.SetDefaultRowLimit(500) - // Verify row limit was set correctly - assert.Equal(t, int64(500), sqlDS.GetRowLimit()) - assert.True(t, sqlDS.EnableRowLimit) + assert.Equal(t, int64(500), ds.GetRowLimit()) + assert.True(t, ds.EnableRowLimit) } func TestRowLimitPassedToQuery(t *testing.T) { @@ -176,18 +126,10 @@ func TestRowLimitPassedToQuery(t *testing.T) { } // Create datasource with row limit - driver, _ := test.NewDriver("rowlimit-query", testData, nil, test.DriverOpts{}, nil) - ds := sqlds.NewDatasource(driver) - - // Create settings and initialize datasource + driver, _ := test.NewDriver("rowlimit-query", testData, nil, test.DriverOpts{}) settings := backend.DataSourceInstanceSettings{UID: "rowlimit-query", JSONData: []byte("{}")} - instance, err := ds.NewDatasource(context.Background(), settings) - require.NoError(t, err) - - // Cast to SQLDatasource and set row limit - sqlDS, ok := instance.(*sqlds.SQLDatasource) - require.True(t, ok) - sqlDS.SetDefaultRowLimit(2) + ds := newRowLimitTestDatasource(t, driver, settings, context.Background(), false) + ds.SetDefaultRowLimit(2) // Create query request req := &backend.QueryDataRequest{ @@ -203,7 +145,7 @@ func TestRowLimitPassedToQuery(t *testing.T) { } // Execute query - resp, err := sqlDS.QueryData(context.Background(), req) + resp, err := ds.QueryData(context.Background(), req) assert.NoError(t, err) // Verify response @@ -269,30 +211,44 @@ func TestRowLimitFromEnvVar(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Set env var for test os.Setenv("GF_DATAPROXY_ROW_LIMIT", tt.envValue) - // Create context with config if needed ctx := context.Background() if tt.configValue != "" { mockConfig := getMockGrafanaCfg(tt.configValue) ctx = backend.WithGrafanaConfig(ctx, mockConfig) } - // Create datasource with driver that may have row limit driver := &mockDriver{rowLimit: tt.driverRowLimit} - ds := sqlds.NewDatasource(driver) - ds.EnableRowLimit = true - - // Create settings and initialize datasource settings := backend.DataSourceInstanceSettings{UID: "rowlimit-env-" + tt.name, JSONData: []byte("{}")} - instance, err := ds.NewDatasource(ctx, settings) - require.NoError(t, err) + ds := newRowLimitTestDatasource(t, driver, settings, ctx, true) - // Verify row limit was set correctly - sqlDS, ok := instance.(*sqlds.SQLDatasource) - require.True(t, ok) - assert.Equal(t, tt.expectedLimit, sqlDS.GetRowLimit()) + assert.Equal(t, tt.expectedLimit, ds.GetRowLimit()) }) } } + +// validTestSettings returns DataSourceInstanceSettings with minimal valid plugin JSON +// to pass NewPluginSettings validation (host, port, protocol are required). +func validTestSettings(uid string) backend.DataSourceInstanceSettings { + return backend.DataSourceInstanceSettings{ + UID: uid, + JSONData: []byte(`{"host":"localhost","port":9000,"protocol":"native"}`), + } +} + +// newRowLimitTestDatasource creates a HydrolixDatasource for rowlimit testing. +func newRowLimitTestDatasource(t *testing.T, driver sqlds.Driver, settings backend.DataSourceInstanceSettings, ctx context.Context, enableRowLimit bool) *sqlds.HydrolixDatasource { + // Use valid plugin settings for NewConnector, but keep original settings for DriverSettings + connSettings := validTestSettings(settings.UID) + _, err := driver.Connect(context.Background(), connSettings, nil) + require.NoError(t, err) + conn, err := sqlds.NewConnector(context.Background(), driver, connSettings) + require.NoError(t, err) + + ds := &sqlds.HydrolixDatasource{Connector: conn, EnableRowLimit: enableRowLimit} + _, err = ds.NewDatasource(ctx, settings) + require.NoError(t, err) + + return ds +} diff --git a/datasource_test.go b/datasource_test.go index 3ca5256..0ea679c 100644 --- a/datasource_test.go +++ b/datasource_test.go @@ -2,26 +2,20 @@ package sqlds_test import ( "context" - "database/sql" - "database/sql/driver" "encoding/json" "errors" - "fmt" - "testing" - "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" - "github.com/grafana/sqlds/v5" - "github.com/grafana/sqlds/v5/mock" - "github.com/grafana/sqlds/v5/test" + "github.com/hydrolix/sqlds/v5" + "github.com/hydrolix/sqlds/v5/test" "github.com/stretchr/testify/assert" + "testing" ) func Test_health_retries(t *testing.T) { opts := test.DriverOpts{ ConnectError: errors.New("foo"), } - cfg := `{ "timeout": 0, "retries": 5, "retryOn": ["foo"] }` + cfg := `{ "timeout": 0, "retries": 5, "retryOn": ["foo"], "host": "localhost", "port": 9000, "protocol": "native" }` req, handler, ds := healthRequest(t, "timeout", opts, cfg) _, err := ds.CheckHealth(context.Background(), &req) @@ -31,12 +25,12 @@ func Test_health_retries(t *testing.T) { } func Test_query_retries(t *testing.T) { - cfg := `{ "timeout": 0, "retries": 5, "retryOn": ["foo"] }` + cfg := `{ "timeout": 0, "retries": 5, "retryOn": ["foo"], "host": "localhost", "port": 9000, "protocol": "native" }` opts := test.DriverOpts{ QueryError: errors.New("foo"), } - req, handler, ds := queryRequest(t, "error", opts, cfg, nil) + req, handler, ds := queryRequest(t, "error", opts, cfg) data, err := ds.QueryData(context.Background(), req) assert.Nil(t, err) @@ -59,9 +53,9 @@ func Test_query_apply_headers(t *testing.T) { QueryFailTimes: 1, // first check always fails since headers are not available on initial connect OnConnect: onConnect, } - cfg := `{ "timeout": 0, "retries": 1, "retryOn": ["missing token"], "forwardHeaders": true }` + cfg := `{ "timeout": 0, "retries": 1, "retryOn": ["missing token"], "forwardHeaders": true, "host": "localhost", "port": 9000, "protocol": "native" }` - req, handler, ds := queryRequest(t, "headers", opts, cfg, nil) + req, handler, ds := queryRequest(t, "headers", opts, cfg) req.SetHTTPHeader("foo", "bar") @@ -83,7 +77,7 @@ func Test_check_health_with_headers(t *testing.T) { ConnectFailTimes: 1, // first check always fails since headers are not available on initial connect OnConnect: onConnect, } - cfg := `{ "timeout": 0, "retries": 2, "retryOn": ["missing token"], "forwardHeaders": true }` + cfg := `{ "timeout": 0, "retries": 2, "retryOn": ["missing token"], "forwardHeaders": true, "host": "localhost", "port": 9000, "protocol": "native" }` req, handler, ds := healthRequest(t, "health-headers", opts, cfg) r := &req r.SetHTTPHeader("foo", "bar") @@ -96,7 +90,7 @@ func Test_check_health_with_headers(t *testing.T) { } func Test_no_errors(t *testing.T) { - req, _, ds := healthRequest(t, "pass", test.DriverOpts{}, "{}") + req, _, ds := healthRequest(t, "pass", test.DriverOpts{}, `{"host":"localhost","port":9000,"protocol":"native"}`) result, err := ds.CheckHealth(context.Background(), &req) assert.Nil(t, err) @@ -104,198 +98,12 @@ func Test_no_errors(t *testing.T) { assert.Equal(t, expected, result.Message) } -func Test_custom_marco_errors(t *testing.T) { - cfg := `{ "timeout": 0, "retries": 0, "retryOn": ["foo"], query: "badArgumentCount" }` - opts := test.DriverOpts{} - - badArgumentCountFunc := func(query *sqlds.Query, args []string) (string, error) { - return "", sqlutil.ErrorBadArgumentCount - } - macros := sqlds.Macros{ - "foo": badArgumentCountFunc, - } - - req, _, ds := queryRequest(t, "interpolate", opts, cfg, macros) - - req.Queries[0].JSON = []byte(`{ "rawSql": "select $__foo from bar;" }`) - - data, err := ds.QueryData(context.Background(), req) - assert.Nil(t, err) - - res := data.Responses["foo"] - assert.NotNil(t, res.Error) - assert.Equal(t, backend.ErrorSourceDownstream, res.ErrorSource) - assert.Contains(t, res.Error.Error(), sqlutil.ErrorBadArgumentCount.Error()) -} - -func Test_default_macro_errors(t *testing.T) { - tests := []struct { - name string - rawSQL string - wantError string - }{ - { - name: "missing parameters", - rawSQL: "select * from bar where $__timeGroup(", - wantError: "missing close bracket", - }, - { - name: "incorrect argument count 0 - timeGroup", - rawSQL: "select * from bar where $__timeGroup()", - wantError: sqlutil.ErrorBadArgumentCount.Error(), - }, - { - name: "incorrect argument count 3 - timeGroup", - rawSQL: "select * from bar where $__timeGroup(1,2,3)", - wantError: sqlutil.ErrorBadArgumentCount.Error(), - }, - { - name: "incorrect argument count 0 - timeFilter", - rawSQL: "select * from bar where $__timeFilter", - wantError: sqlutil.ErrorBadArgumentCount.Error(), - }, - { - name: "incorrect argument count 3 - timeFilter", - rawSQL: "select * from bar where $__timeFilter(1,2,3)", - wantError: sqlutil.ErrorBadArgumentCount.Error(), - }, - { - name: "incorrect argument count 0 - timeFrom", - rawSQL: "select * from bar where $__timeFrom", - wantError: sqlutil.ErrorBadArgumentCount.Error(), - }, - { - name: "incorrect argument count 3 - timeFrom", - rawSQL: "select * from bar where $__timeFrom(1,2,3)", - wantError: sqlutil.ErrorBadArgumentCount.Error(), - }, - } - - // Common test configuration - cfg := `{ "timeout": 0, "retries": 0, "retryOn": ["foo"], query: "badArgumentCount" }` - opts := test.DriverOpts{} - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Setup request - req, _, ds := queryRequest(t, "interpolate", opts, cfg, nil) - req.Queries[0].JSON = []byte(fmt.Sprintf(`{ "rawSql": "%s" }`, tt.rawSQL)) - - // Execute query - data, err := ds.QueryData(context.Background(), req) - assert.Nil(t, err) - - // Verify response - res := data.Responses["foo"] - assert.NotNil(t, res.Error) - assert.Equal(t, backend.ErrorSourceDownstream, res.ErrorSource) - assert.Contains(t, res.Error.Error(), tt.wantError) - }) - } -} - -func Test_query_panic_recovery(t *testing.T) { - cfg := `{ "timeout": 0, "retries": 0, "retryOn": [] }` - opts := test.DriverOpts{} - - // Create a macro that triggers a panic - panicMacro := func(query *sqlds.Query, args []string) (string, error) { - panic("Random panic for testing purposes") - } - macros := sqlds.Macros{ - "panicTest": panicMacro, - } - - req, _, ds := queryRequest(t, "panic-test", opts, cfg, macros) - - // Set up a query that uses the panic-triggering macro - req.Queries[0].JSON = []byte(`{ "rawSql": "SELECT $__panicTest() FROM test_table;" }`) - - // Execute the query - data, err := ds.QueryData(context.Background(), req) - - // Verify that the panic was caught and converted to an error - assert.Nil(t, err) - assert.NotNil(t, data.Responses) - - res := data.Responses["foo"] - assert.NotNil(t, res.Error) - assert.Equal(t, backend.ErrorSourcePlugin, res.ErrorSource) - assert.Contains(t, res.Error.Error(), "SQL datasource query execution panic") - assert.Contains(t, res.Error.Error(), "Random panic for testing purposes") - assert.Nil(t, res.Frames) -} - -func Test_query_panic_in_rows_validation(t *testing.T) { - cfg := `{ "timeout": 0, "retries": 0, "retryOn": [] }` - opts := test.DriverOpts{} - - // Set up a driver that returns rows which will cause a panic when accessing columns - // This will be caught by our validateRows function - customQueryFunc := func(args []driver.Value) (driver.Rows, error) { - // Return a panicking rows implementation that will cause a panic when columns are accessed - return &panickingRows{}, nil - } - - // Create a custom driver with a wrapper for the query method - driverName := "panic-rows-test" - - // Create and register the handler - handler := test.NewDriverHandler(test.Data{}, opts) - - // Create a custom handler that overrides the Query method - customHandler := &panickingDBHandler{ - SqlHandler: handler, - customQueryFunc: customQueryFunc, - } - - mock.RegisterDriver(driverName, customHandler) - - // Create datasource with the custom driver - testDS := &testDriver{driverName: driverName} - ds := sqlds.NewDatasource(testDS) - - // Set up the query request - req, settings := setupQueryRequest("panic-rows-validation", cfg) - _, err := ds.NewDatasource(context.Background(), settings) - assert.Nil(t, err) - - // Execute the query - data, err := ds.QueryData(context.Background(), req) - - // Verify that the panic was caught and converted to an error - assert.Nil(t, err) - assert.NotNil(t, data.Responses) - - res := data.Responses["foo"] - assert.NotNil(t, res.Error) - assert.Contains(t, res.Error.Error(), "SQL rows validation failed") - assert.NotNil(t, res.Frames) // Error frame is returned, not nil -} - -// panickingRows is a custom rows implementation that panics when columns are accessed -type panickingRows struct{} - -func (r *panickingRows) Columns() []string { - panic("panic in Columns method") -} - -func (r *panickingRows) Close() error { - return nil -} - -func (r *panickingRows) Next(dest []driver.Value) error { - return nil -} - -func queryRequest(t *testing.T, name string, opts test.DriverOpts, cfg string, marcos sqlds.Macros) (*backend.QueryDataRequest, *test.SqlHandler, *sqlds.SQLDatasource) { - driver, handler := test.NewDriver(name, test.Data{}, nil, opts, marcos) - ds := sqlds.NewDatasource(driver) - +func queryRequest(t *testing.T, name string, opts test.DriverOpts, cfg string) (*backend.QueryDataRequest, *test.SqlHandler, *sqlds.HydrolixDatasource) { + driver, handler := test.NewDriver(name, test.Data{}, nil, opts) req, settings := setupQueryRequest(name, cfg) - _, err := ds.NewDatasource(context.Background(), settings) - assert.Equal(t, nil, err) + ds := newTestDatasource(t, driver, settings) + return req, handler, ds } @@ -314,14 +122,12 @@ func setupQueryRequest(id string, cfg string) (*backend.QueryDataRequest, backen }, s } -func healthRequest(t *testing.T, name string, opts test.DriverOpts, cfg string) (backend.CheckHealthRequest, *test.SqlHandler, *sqlds.SQLDatasource) { - driver, handler := test.NewDriver(name, test.Data{}, nil, opts, nil) - ds := sqlds.NewDatasource(driver) - +func healthRequest(t *testing.T, name string, opts test.DriverOpts, cfg string) (backend.CheckHealthRequest, *test.SqlHandler, *sqlds.HydrolixDatasource) { + driver, handler := test.NewDriver(name, test.Data{}, nil, opts) req, settings := setupHealthRequest(name, cfg) - _, err := ds.NewDatasource(context.Background(), settings) - assert.Equal(t, nil, err) + ds := newTestDatasource(t, driver, settings) + return req, handler, ds } @@ -335,49 +141,17 @@ func setupHealthRequest(id string, cfg string) (backend.CheckHealthRequest, back return req, settings } -// testDriver implements sqlds.Driver interface for testing -type testDriver struct { - driverName string -} - -func (d *testDriver) Connect(ctx context.Context, cfg backend.DataSourceInstanceSettings, msg json.RawMessage) (*sql.DB, error) { - return sql.Open(d.driverName, "") -} - -func (d *testDriver) Settings(ctx context.Context, config backend.DataSourceInstanceSettings) sqlds.DriverSettings { - settings, _ := test.LoadSettings(ctx, config) - return settings -} - -func (d *testDriver) Macros() sqlds.Macros { - return nil -} - -func (d *testDriver) Converters() []sqlutil.Converter { - return nil -} - -// panickingDBHandler implements mock.DBHandler and causes panics when querying -type panickingDBHandler struct { - test.SqlHandler - customQueryFunc func(args []driver.Value) (driver.Rows, error) -} - -func (h *panickingDBHandler) Query(args []driver.Value) (driver.Rows, error) { - if h.customQueryFunc != nil { - return h.customQueryFunc(args) - } - return h.SqlHandler.Query(args) -} +// newTestDatasource creates a HydrolixDatasource for testing, bypassing NewConnector's plugin settings validation. +func newTestDatasource(t *testing.T, driver sqlds.Driver, settings backend.DataSourceInstanceSettings) *sqlds.HydrolixDatasource { -func (h *panickingDBHandler) Ping(ctx context.Context) error { - return nil -} + _, err := driver.Connect(context.Background(), settings, nil) + assert.Nil(t, err) + conn, err := sqlds.NewConnector(context.Background(), driver, settings) + assert.Nil(t, err) -func (h *panickingDBHandler) Columns() []string { - return []string{"test_column"} -} + ds := &sqlds.HydrolixDatasource{Connector: conn} + _, err = ds.NewDatasource(context.Background(), settings) + assert.Nil(t, err) -func (h *panickingDBHandler) Next(dest []driver.Value) error { - return errors.New("no more rows") + return ds } diff --git a/driver-mock.go b/driver-mock.go index 2719251..1bffbe2 100644 --- a/driver-mock.go +++ b/driver-mock.go @@ -70,8 +70,3 @@ func (h *SQLMock) Connect(_ context.Context, _ backend.DataSourceInstanceSetting func (h *SQLMock) Converters() []sqlutil.Converter { return []sqlutil.Converter{} } - -// Macros returns list of macro functions convert the macros of raw query -func (h *SQLMock) Macros() Macros { - return Macros{} -} diff --git a/driver.go b/driver.go index 14aa126..2e7ffa1 100644 --- a/driver.go +++ b/driver.go @@ -30,7 +30,6 @@ type Driver interface { Connect(context.Context, backend.DataSourceInstanceSettings, json.RawMessage) (*sql.DB, error) // Settings are read whenever the plugin is initialized, or after the data source settings are updated Settings(context.Context, backend.DataSourceInstanceSettings) DriverSettings - Macros() Macros Converters() []sqlutil.Converter } diff --git a/driver_round_time_test.go b/driver_round_time_test.go new file mode 100644 index 0000000..1c5bd43 --- /dev/null +++ b/driver_round_time_test.go @@ -0,0 +1,63 @@ +package sqlds + +import ( + "github.com/grafana/grafana-plugin-sdk-go/backend" + "testing" + "time" +) + +var defaultTimeRange = backend.TimeRange{ + To: time.Unix(1740678412, 123456789), + From: time.Unix(1740674812, 123456789), +} + +func TestRoundToSecond(t *testing.T) { + timeRange := RoundTimeRange(defaultTimeRange, "1s") + if timeRange.To != time.Unix(1740678412, 0) { + t.Error("To time should be rounded to 1s") + } + if timeRange.From != time.Unix(1740674812, 0) { + t.Error("From time should be rounded to 1s") + } +} + +func TestRoundToMinute(t *testing.T) { + timeRange := RoundTimeRange(defaultTimeRange, "1m") + if timeRange.To != time.Unix(1740678420, 0) { + t.Error("To time should be rounded to 1m") + } + if timeRange.From != time.Unix(1740674820, 0) { + t.Error("From time should be rounded to 1m") + } +} + +func TestRoundToHour(t *testing.T) { + timeRange := RoundTimeRange(defaultTimeRange, "1h") + if timeRange.To != time.Unix(1740679200, 0) { + t.Error("To time should be rounded to 1h") + } + if timeRange.From != time.Unix(1740675600, 0) { + t.Error("From time should be rounded to 1h") + } +} + +func TestRoundToZero(t *testing.T) { + timeRange := RoundTimeRange(defaultTimeRange, "0") + if timeRange != defaultTimeRange { + t.Error("TimeRange should not be rounded") + } +} + +func TestRoundEmpty(t *testing.T) { + timeRange := RoundTimeRange(defaultTimeRange, "") + if timeRange != defaultTimeRange { + t.Error("TimeRange should not be rounded") + } +} + +func TestRoundInvalid(t *testing.T) { + timeRange := RoundTimeRange(defaultTimeRange, "not valid duration") + if timeRange != defaultTimeRange { + t.Error("TimeRange should not be rounded") + } +} diff --git a/go.mod b/go.mod index 190c2bc..a2d2381 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,14 @@ -module github.com/grafana/sqlds/v5 +module github.com/hydrolix/sqlds/v5 go 1.25.7 require ( - github.com/go-sql-driver/mysql v1.9.3 + github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/google/go-cmp v0.7.0 github.com/grafana/dataplane/sdata v0.0.9 github.com/grafana/grafana-plugin-sdk-go v0.290.1 + github.com/hydrolix/clickhouse-sql-parser v0.3.0 + github.com/jellydator/ttlcache/v3 v3.4.0 github.com/mithrandie/csvq-driver v1.7.0 github.com/prometheus/client_golang v1.23.2 github.com/stretchr/testify v1.11.1 @@ -25,12 +27,12 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.65.0 // indirect go.opentelemetry.io/contrib/propagators/jaeger v1.40.0 // indirect go.opentelemetry.io/contrib/samplers/jaegerremote v0.34.0 // indirect - go.opentelemetry.io/otel v1.40.0 // indirect + go.opentelemetry.io/otel v1.41.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 // indirect - go.opentelemetry.io/otel/metric v1.40.0 // indirect + go.opentelemetry.io/otel/metric v1.41.0 // indirect go.opentelemetry.io/otel/sdk v1.40.0 // indirect - go.opentelemetry.io/otel/trace v1.40.0 // indirect + go.opentelemetry.io/otel/trace v1.41.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect golang.org/x/mod v0.32.0 // indirect golang.org/x/tools v0.41.0 // indirect @@ -39,7 +41,6 @@ require ( ) require ( - filippo.io/edwards25519 v1.1.1 // indirect github.com/apache/arrow-go/v18 v18.5.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect @@ -63,7 +64,7 @@ require ( github.com/jaegertracing/jaeger-idl v0.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/jszwedko/go-datemath v0.1.1-0.20230526204004-640a500621d6 // indirect - github.com/klauspost/compress v1.18.2 // indirect + github.com/klauspost/compress v1.18.3 // indirect github.com/mattetti/filebuffer v1.0.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -82,7 +83,7 @@ require ( github.com/olekukonko/ll v0.1.4-0.20260115111900-9e59c2286df0 // indirect github.com/olekukonko/tablewriter v1.1.3 // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect - github.com/pierrec/lz4/v4 v4.1.23 // indirect + github.com/pierrec/lz4/v4 v4.1.25 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.5 // indirect diff --git a/go.sum b/go.sum index ef8edd5..1c3821e 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw= -filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/apache/arrow-go/v18 v18.5.1 h1:yaQ6zxMGgf9YCYw4/oaeOU3AULySDlAYDOcnr4LdHdI= @@ -37,8 +37,6 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= -github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/gogo/googleapis v1.4.1 h1:1Yx4Myt7BxzvUr5ldGSbwYiZG6t9wGBZ+8/fX3Wvtq0= @@ -77,8 +75,12 @@ github.com/hashicorp/go-plugin v1.7.0 h1:YghfQH/0QmPNc/AZMTFE3ac8fipZyZECHdDPshf github.com/hashicorp/go-plugin v1.7.0/go.mod h1:BExt6KEaIYx804z8k4gRzRLEvxKVb+kn0NMcihqOqb8= github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= +github.com/hydrolix/clickhouse-sql-parser v0.3.0 h1:H1dTFltNcIp5KWToEh8WsSkVDkvCyaf5K9MDt8jDy1s= +github.com/hydrolix/clickhouse-sql-parser v0.3.0/go.mod h1:XM12SuNWq8DPbgWQ6mB8tIxg2BucGZrX6636VuSk2No= github.com/jaegertracing/jaeger-idl v0.6.0 h1:LOVQfVby9ywdMPI9n3hMwKbyLVV3BL1XH2QqsP5KTMk= github.com/jaegertracing/jaeger-idl v0.6.0/go.mod h1:mpW0lZfG907/+o5w5OlnNnig7nHJGT3SfKmRqC42HGQ= +github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY= +github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4= github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5XumQh94= github.com/jhump/protoreflect v1.17.0/go.mod h1:h9+vUUL38jiBzck8ck+6G/aeMX8Z4QUY/NiJPwPNi+8= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -87,10 +89,11 @@ github.com/jszwedko/go-datemath v0.1.1-0.20230526204004-640a500621d6 h1:SwcnSwBR github.com/jszwedko/go-datemath v0.1.1-0.20230526204004-640a500621d6/go.mod h1:WrYiIuiXUMIvTDAQw97C+9l0CnBmCcvosPjN3XDqS/o= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= -github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= -github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw= +github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -148,8 +151,8 @@ github.com/olekukonko/tablewriter v1.1.3 h1:VSHhghXxrP0JHl+0NnKid7WoEmd9/urKRJLy github.com/olekukonko/tablewriter v1.1.3/go.mod h1:9VU0knjhmMkXjnMKrZ3+L2JhhtsQ/L38BbL3CRNE8tM= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= -github.com/pierrec/lz4/v4 v4.1.23 h1:oJE7T90aYBGtFNrI8+KbETnPymobAhzRrR8Mu8n1yfU= -github.com/pierrec/lz4/v4 v4.1.23/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4= +github.com/pierrec/lz4/v4 v4.1.25 h1:kocOqRffaIbU5djlIBr7Wh+cx82C0vtFb0fOurZHqD0= +github.com/pierrec/lz4/v4 v4.1.25/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -163,6 +166,10 @@ github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzM github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y= +github.com/sebdah/goldie/v2 v2.5.3/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI= +github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -192,23 +199,23 @@ go.opentelemetry.io/contrib/propagators/jaeger v1.40.0/go.mod h1:ioMePqe6k6c/ovX go.opentelemetry.io/contrib/samplers/jaegerremote v0.34.0 h1:RZjNfF9OoR4oPLEWaP+Memql2MNVkZvnwjB2N5tR3cA= go.opentelemetry.io/contrib/samplers/jaegerremote v0.34.0/go.mod h1:b5U9IcSnv+lMvEcSOXZB61kXSf0KkwickleKWuAQclw= go.opentelemetry.io/otel v1.21.0/go.mod h1:QZzNPQPm1zLX4gZK4cMi+71eaorMSGT3A4znnUvNNEo= -go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms= -go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g= +go.opentelemetry.io/otel v1.41.0 h1:YlEwVsGAlCvczDILpUXpIpPSL/VPugt7zHThEMLce1c= +go.opentelemetry.io/otel v1.41.0/go.mod h1:Yt4UwgEKeT05QbLwbyHXEwhnjxNO6D8L5PQP51/46dE= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDOPmSWQHWywQS6lKL+pb8s3gBLOZUtw4N+mavW1I= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0/go.mod h1:EtekO9DEJb4/jRyN4v4Qjc2yA7AtfCBuz2FynRUWTXs= go.opentelemetry.io/otel/metric v1.21.0/go.mod h1:o1p3CA8nNHW8j5yuQLdc1eeqEaPfzug24uvsyIEJRWM= -go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g= -go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc= +go.opentelemetry.io/otel/metric v1.41.0 h1:rFnDcs4gRzBcsO9tS8LCpgR0dxg4aaxWlJxCno7JlTQ= +go.opentelemetry.io/otel/metric v1.41.0/go.mod h1:xPvCwd9pU0VN8tPZYzDZV/BMj9CM9vs00GuBjeKhJps= go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E= go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8= go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE= go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw= go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg= go.opentelemetry.io/otel/trace v1.21.0/go.mod h1:LGbsEB0f9LGjN+OZaQQ26sohbOmiMR+BaslueVtS/qQ= -go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw= -go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA= +go.opentelemetry.io/otel/trace v1.41.0 h1:Vbk2co6bhj8L59ZJ6/xFTskY+tGAbOnCtQGVVa9TIN0= +go.opentelemetry.io/otel/trace v1.41.0/go.mod h1:U1NU4ULCoxeDKc09yCWdWe+3QoyweJcISEVa1RBzOis= go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= diff --git a/health.go b/health.go index a1a8a94..b8272d9 100644 --- a/health.go +++ b/health.go @@ -8,7 +8,7 @@ import ( ) type HealthChecker struct { - Connector *Connector + Connector Connector Metrics Metrics PreCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult PostCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult diff --git a/health_test.go b/health_test.go index 1aa0cea..7d42877 100644 --- a/health_test.go +++ b/health_test.go @@ -1,26 +1,35 @@ -package sqlds_test +package sqlds import ( "context" + "database/sql" "testing" "github.com/grafana/grafana-plugin-sdk-go/backend" - sqlds "github.com/grafana/sqlds/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func getFakeConnector(t *testing.T, shouldFail bool) *sqlds.Connector { +func getFakeConnector(t *testing.T, shouldFail bool) Connector { t.Helper() - c, _ := sqlds.NewConnector(context.TODO(), &sqlds.SQLMock{ShouldFailToConnect: shouldFail}, backend.DataSourceInstanceSettings{}, false) + db, _ := newSqlmockDB(t) + if shouldFail { + db.Close() + } + + driver := &stubDriver{ + settings: DriverSettings{}, + connectDBs: []*sql.DB{db}, + } + c, _ := NewConnector(context.TODO(), driver, buildInstanceSettings()) return c } func TestHealthChecker_Check(t *testing.T) { tests := []struct { name string - Connector *sqlds.Connector - Metrics sqlds.Metrics + Connector Connector + Metrics Metrics PreCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult PostCheckHealth func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult ctx context.Context @@ -56,7 +65,7 @@ func TestHealthChecker_Check(t *testing.T) { PostCheckHealth: func(ctx context.Context, req *backend.CheckHealthRequest) *backend.CheckHealthResult { return &backend.CheckHealthResult{Status: backend.HealthStatusOk} }, - want: &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "unable to get default db connection"}, + want: &backend.CheckHealthResult{Status: backend.HealthStatusError, Message: "sql: database is closed"}, }, { name: "should not error when post check succeed", @@ -78,7 +87,7 @@ func TestHealthChecker_Check(t *testing.T) { t.Run(tt.name, func(t *testing.T) { connector := tt.Connector if connector == nil { - connector = &sqlds.Connector{} + connector = getFakeConnector(t, false) } req := tt.req if req == nil { @@ -88,17 +97,13 @@ func TestHealthChecker_Check(t *testing.T) { if want == nil { want = &backend.CheckHealthResult{Status: backend.HealthStatusOk, Message: "Data source is working"} } - ctx := tt.ctx - if ctx == nil { - ctx = context.Background() - } - hc := &sqlds.HealthChecker{ + hc := &HealthChecker{ Connector: connector, Metrics: tt.Metrics, PreCheckHealth: tt.PreCheckHealth, PostCheckHealth: tt.PostCheckHealth, } - got, err := hc.Check(ctx, req) + got, err := hc.Check(tt.ctx, req) if tt.wantErr != nil { require.NotNil(t, err) assert.Equal(t, tt.wantErr.Error(), err.Error()) diff --git a/interpolator.go b/interpolator.go new file mode 100644 index 0000000..23bd75d --- /dev/null +++ b/interpolator.go @@ -0,0 +1,328 @@ +package sqlds + +import ( + "context" + "encoding/json" + "fmt" + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/log" + "github.com/hydrolix/clickhouse-sql-parser/parser" + "github.com/hydrolix/sqlds/v5/models" + "net/http" + "regexp" + "slices" + "sort" + "strings" + "time" +) + +type Interpolator struct { + md *MetaDataProvider + macros map[string]MacroFunc +} + +type macroMatch struct { + full string + name string + args []string + escaped bool + pos parser.Pos +} + +func NewInterpolator(ds *HydrolixDatasource) Interpolator { + return Interpolator{NewMetaDataProvider(ds), Macros} +} + +// getMacroMatches extracts macro strings with their respective arguments from the sql input given +// It manually parses the string to find the closing parenthesis of the macro (because regex has no memory) +func getMacroMatches(input string, name string, positions []parser.Pos) ([]macroMatch, error) { + rgx, err := regexp.Compile(fmt.Sprintf(`\$+__%s\b`, name)) + + if err != nil { + return nil, err + } + + var matches []macroMatch + for _, window := range rgx.FindAllStringIndex(input, -1) { + start, end := window[0], window[1] + args, length := parseArgs(input[end:]) + if length < 0 { + return nil, fmt.Errorf("failed to parse macro arguments (missing close bracket?)") + } + if positions == nil || slices.Contains(positions, parser.Pos(start)) { + matches = append(matches, macroMatch{full: input[start : end+length], args: args, escaped: input[start+1] == '$', pos: parser.Pos(start), name: name}) + } + } + return matches, nil +} + +func getMacroPositions(input string) ([]parser.Pos, error) { + exps, err := parser.NewParser(input).ParseStmts() + if err != nil { + return nil, err + } + positions := make([]parser.Pos, 0) + mVisitor := macroVisitor{macros: make([]MacroId, 0)} + + for _, expr := range exps { + err = expr.Accept(&mVisitor) + if err != nil { + return nil, err + } + } + for _, m := range mVisitor.macros { + positions = append(positions, m.Index) + } + + return positions, nil +} + +// parseArgs looks for a bracketed argument list at the beginning of argString. +// If one is present, returns a list of whitespace-trimmed arguments and the +// length of the string comprising the bracketed argument list. +func parseArgs(argString string) ([]string, int) { + if !strings.HasPrefix(argString, "(") { + return nil, 0 // single empty arg for backwards compatibility + } + + var args []string + depth := 0 + arg := []rune{} + + for i, r := range argString { + switch r { + case '(': + depth++ + if depth == 1 { + // don't include the outer bracket in the arg + continue + } + case ')': + depth-- + if depth == 0 { + // closing bracket + args = append(args, strings.TrimSpace(string(arg))) + return args, i + 1 + } + case ',': + if depth == 1 { + // a comma at this level is separating args + args = append(args, strings.TrimSpace(string(arg))) + arg = []rune{} + continue + } + } + arg = append(arg, r) + } + // If we get here, we have seen an open bracket but not a close bracket. This + // would formerly cause a panic; now it is treated as an error. + return nil, -1 +} + +// Interpolate returns an interpolated query string given a backend.DataQuery +func (i Interpolator) Interpolate(ctx context.Context, query *HDXQuery) (string, error) { + if query.Round != "" && query.Round != "0" { + query.TimeRange = RoundTimeRange(query.TimeRange, query.Round) + } + + // sort macros so longer macros are applied first to prevent it from being + // overridden by a shorter macro that is a substring of the longer one + sortedMacroKeys := make([]string, 0, len(i.macros)) + for key := range i.macros { + sortedMacroKeys = append(sortedMacroKeys, key) + } + sort.Slice(sortedMacroKeys, func(i, j int) bool { + return len(sortedMacroKeys[i]) > len(sortedMacroKeys[j]) + }) + rawSQL := query.RawSQL + macroMatches := make([]macroMatch, 0) + positions, err := getMacroPositions(rawSQL) + if err != nil { + positions = nil + } + for _, key := range sortedMacroKeys { + matches, err := getMacroMatches(rawSQL, key, positions) + if err != nil { + return rawSQL, err + } + macroMatches = append(macroMatches, matches...) + } + + sort.Slice(macroMatches, func(i, j int) bool { + return macroMatches[i].pos > macroMatches[j].pos + }) + for _, match := range macroMatches { + if match.escaped { + rawSQL = rawSQL[0:match.pos] + strings.Replace(rawSQL[match.pos:], "$", "", 1) + } else { + macro := i.macros[match.name] + res, err := macro(ctx, query.WithSQL(rawSQL), match.args, match.pos, i.md) + if err != nil { + return rawSQL, err + } + + rawSQL = rawSQL[0:match.pos] + strings.Replace(rawSQL[match.pos:], match.full, res, 1) + } + } + return rawSQL, nil +} + +type MacroId struct { + Name string `json:"name"` + Index parser.Pos `json:"index"` +} + +type CTE struct { + Macro string `json:"macro"` + MacroPos parser.Pos `json:"macroPos"` + CTE string `json:"cte"` + Table string `json:"table"` + Database string `json:"database"` + Pos parser.Pos `json:"pos"` +} + +// RoundTimeRange rounds the time range to provided time interval +func RoundTimeRange(timeRange backend.TimeRange, interval string) backend.TimeRange { + if dInterval, err := time.ParseDuration(interval); err == nil && dInterval.Seconds() >= 1 { + To := timeRange.To.Round(dInterval) + From := timeRange.From.Round(dInterval) + + log.DefaultLogger.Debug("Time range rounded", "original", timeRange, "from", From, "to", To, "interval", interval) + return backend.TimeRange{To: To, From: From} + } + + log.DefaultLogger.Warn("Using default time range, provided round interval is invalid", "interval", interval) + return timeRange +} + +type macroVisitor struct { + parser.DefaultASTVisitor + macros []MacroId +} + +func (v *macroVisitor) VisitIdent(expr *parser.Ident) error { + if strings.HasPrefix(expr.Name, "$__") { + v.macros = append(v.macros, MacroId{Name: expr.Name, Index: expr.NamePos}) + } + return nil +} + +type tableVisitor struct { + parser.DefaultASTVisitor + pos parser.Pos + table string + database string +} + +func (v *tableVisitor) VisitTableIdentifier(expr *parser.TableIdentifier) error { + if v.pos == expr.Pos() { + if expr.Table != nil { + v.table = expr.Table.String() + } + if expr.Database != nil { + v.database = expr.Database.String() + } else { + v.database = "" + } + + } + return nil +} + +type queryVisitor struct { + parser.DefaultASTVisitor + macroIds map[MacroId]CTE +} + +func (v *queryVisitor) VisitSelectQuery(expr *parser.SelectQuery) error { + if expr.From != nil { + pos := expr.Pos() + cte := expr.From.Expr.String() + tPos := expr.From.Expr.Pos() + tVisitor := tableVisitor{pos: tPos} + _ = expr.Accept(&tVisitor) + mVisitor := macroVisitor{macros: make([]MacroId, 0)} + _ = expr.Accept(&mVisitor) + for _, macro := range mVisitor.macros { + if existing, ok := v.macroIds[macro]; !ok || existing.Pos < pos { + v.macroIds[macro] = CTE{Macro: macro.Name, MacroPos: macro.Index, CTE: cte, Pos: pos, Database: tVisitor.database, Table: tVisitor.table} + } + + } + } + return nil +} + +func GetMacroCTEs(ast []parser.Expr) (map[MacroId]CTE, error) { + visitor := queryVisitor{macroIds: make(map[MacroId]CTE)} + for _, expr := range ast { + err := expr.Accept(&visitor) + if err != nil { + return nil, err + } + } + return visitor.macroIds, nil +} + +func GetHdxQuery(query backend.DataQuery, headers http.Header, timeRange *backend.TimeRange, interval *time.Duration) (*HDXQuery, error) { + q := &HDXQuery{} + + if err := json.Unmarshal(query.JSON, &q); err != nil { + return nil, backend.DownstreamError(fmt.Errorf("error unmarshaling query JSON to the Query Model: %v", err)) + } + if timeRange == nil { + timeRange = &query.TimeRange + } + + if interval == nil { + interval = &query.Interval + } + + // Copy directly from the well typed query + return &HDXQuery{ + RawSQL: q.RawSQL, + Format: q.Format, + Round: q.Round, + QuerySettings: q.QuerySettings, + Filters: q.Filters, + Meta: q.Meta, + TimeRange: *timeRange, + Interval: *interval, + Headers: headers, + }, nil +} + +func (q *HDXQuery) WithSQL(rawSql string) *HDXQuery { + return &HDXQuery{ + RawSQL: rawSql, + Format: q.Format, + Round: q.Round, + QuerySettings: q.QuerySettings, + Filters: q.Filters, + Meta: q.Meta, + TimeRange: q.TimeRange, + Interval: q.Interval, + Headers: q.Headers, + } +} + +type HDXQuery struct { + RawSQL string `json:"rawSql"` + Format int `json:"format"` + Round string `json:"round,omitempty"` + QuerySettings []models.QuerySetting `json:"querySettings,omitempty"` + Filters []AdHocFilter `json:"filters,omitempty"` + Meta struct { + TimeZone string `json:"timezone"` + } `json:"meta"` + TimeRange backend.TimeRange `json:"-"` + Interval time.Duration `json:"-"` + Headers http.Header `json:"-"` +} + +type AdHocFilter struct { + Key string `json:"key"` + Operator string `json:"operator"` + Value string `json:"value"` + Values []string `json:"values,omitempty"` +} diff --git a/interpolator_test.go b/interpolator_test.go new file mode 100644 index 0000000..007116c --- /dev/null +++ b/interpolator_test.go @@ -0,0 +1,463 @@ +package sqlds + +import ( + "context" + "fmt" + "maps" + "net/http" + "slices" + "testing" + "time" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/hydrolix/clickhouse-sql-parser/parser" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetHdxQuery_PreservesHeaders(t *testing.T) { + headers := http.Header{} + headers.Set("Authorization", "Bearer test-token") + headers.Set("X-Grafana-Org-Id", "42") + + queryJSON := []byte(`{"rawSql": "SELECT 1", "format": 0}`) + from := time.Now().Add(-time.Hour) + to := time.Now() + dataQuery := backend.DataQuery{ + RefID: "A", + JSON: queryJSON, + TimeRange: backend.TimeRange{From: from, To: to}, + Interval: time.Second, + } + + hdxQuery, err := GetHdxQuery(dataQuery, headers, nil, nil) + require.NoError(t, err) + assert.Equal(t, "SELECT 1", hdxQuery.RawSQL) + assert.Equal(t, "Bearer test-token", hdxQuery.Headers.Get("Authorization")) + assert.Equal(t, "42", hdxQuery.Headers.Get("X-Grafana-Org-Id")) +} + +func TestGetHdxQuery_NilHeaders(t *testing.T) { + queryJSON := []byte(`{"rawSql": "SELECT 1", "format": 0}`) + dataQuery := backend.DataQuery{ + RefID: "A", + JSON: queryJSON, + TimeRange: backend.TimeRange{From: time.Now().Add(-time.Hour), To: time.Now()}, + Interval: time.Second, + } + + hdxQuery, err := GetHdxQuery(dataQuery, nil, nil, nil) + require.NoError(t, err) + assert.Nil(t, hdxQuery.Headers) +} + +func TestWithSQL_PreservesHeaders(t *testing.T) { + headers := http.Header{} + headers.Set("Authorization", "Bearer token123") + + q := &HDXQuery{ + RawSQL: "SELECT 1", + Headers: headers, + TimeRange: backend.TimeRange{ + From: time.Now().Add(-time.Hour), + To: time.Now(), + }, + Interval: time.Second, + } + + newQ := q.WithSQL("SELECT 2") + assert.Equal(t, "SELECT 2", newQ.RawSQL) + assert.Equal(t, "Bearer token123", newQ.Headers.Get("Authorization")) +} + +func TestGetMacroCTEs(t *testing.T) { + type test struct { + name string + input string + result string + } + + tests := []test{ + {input: "SELECT * FROM table WHERE $__macro()", result: "table", name: "should return the table for filter"}, + {input: "SELECT * FROM schema.table WHERE $__macro()", result: "schema.table", name: "should return the table with schema for filter"}, + {input: "SELECT * FROM schema.table as t1 WHERE $__macro()", result: "schema.table AS t1", name: "should return the table and schema with alias for filter"}, + {input: "SELECT * FROM (Select * from table2 where 1=1) WHERE $__macro()", result: "(SELECT * FROM table2 WHERE 1 = 1)", name: "should return the subquery for filter"}, + {input: "SELECT * FROM (Select * from table2 where l in (select * from table2)) WHERE $__macro()", result: "(SELECT * FROM table2 WHERE l IN (SELECT * FROM table2))", name: "should return subqueries for filter"}, + {input: "WITH\n top_50_reqPath AS (\n SELECT\n topK (50) (reqPath)\n FROM\n table\n WHERE\n $__macro() \n )\nSELECT\n *\nFROM\n top_50_reqPath", result: "table", name: "should return the table and ignore with alias for filter"}, + + {input: "SELECT $__macro() FROM table WHERE 1=1", result: "table", name: "should return the table for value"}, + {input: "SELECT $__macro() FROM schema.table WHERE 1=1", result: "schema.table", name: "should return the table with schema for value"}, + {input: "SELECT $__macro() FROM schema.table as t1 WHERE 1=1", result: "schema.table AS t1", name: "should return the table and schema with alias for value"}, + {input: "SELECT $__macro() FROM (Select * from table2 where 1=1) WHERE 1=1", result: "(SELECT * FROM table2 WHERE 1 = 1)", name: "should return the subquery for value"}, + {input: "SELECT $__macro() FROM (Select * from table2 where l in (select * from table2)) WHERE 1=1", result: "(SELECT * FROM table2 WHERE l IN (SELECT * FROM table2))", name: "should return subqueries for value"}, + {input: "WITH\n top_50_reqPath AS (\n SELECT\n $__macro()\n FROM\n table\n WHERE\n 1=1\n )\nSELECT\n *\nFROM\n top_50_reqPath", result: "table", name: "should return the table and ignore with alias for value"}, + } + + for i, tc := range tests { + + t.Run(fmt.Sprintf("[%d/%d] %s", i+1, len(tests), tc.name), func(t *testing.T) { + expr, _ := parser.NewParser(tc.input).ParseStmts() + res, err := GetMacroCTEs(expr) + require.NoError(t, err) + fmt.Println(res) + assert.Equal(t, len(res), 1) + require.Nil(t, err) + v := slices.Collect(maps.Values(res))[0] + assert.Equal(t, tc.result, v.CTE) + }) + } + +} + +func TestGetMacroCTEsForComplexQuery(t *testing.T) { + expected := []string{"logs AS subquery", "logs AS subquery", "akamai.logs AS main_query", "akamai.logs AS main_query", "akamai.logs AS subquery", "akamai.logs AS subquery"} + sql := "SELECT\n main_query.reqTimeSec,\n (\n SELECT COUNT(*)\n FROM logs AS subquery\n WHERE $__timeFilter(reqTimeSec) AND $__adHocFilter() \n )\nFROM\n akamai.logs AS main_query\nWHERE\n$__timeFilter(reqTimeSec) AND $__adHocFilter() AND\n reqId IN (\n SELECT\n reqId\n FROM\n akamai.logs AS subquery\n WHERE\n statusCode = 404\n AND reqMethod = 'GET'\n AND $__timeFilter(reqTimeSec) AND $__adHocFilter() \n );" + expr, _ := parser.NewParser(sql).ParseStmts() + res, err := GetMacroCTEs(expr) + require.NoError(t, err) + fmt.Println(res) + for i, v := range slices.SortedFunc(maps.Values(res), func(a, b CTE) int { return int(a.MacroPos) - int(b.MacroPos) }) { + assert.Equal(t, expected[i], v.CTE, fmt.Sprintf("For macro %s at index %d", v.Macro, v.MacroPos)) + } +} + +func TestGetMacroMatches(t *testing.T) { + type test struct { + name string + input string + macro string + expected []macroMatch + } + + tests := []test{ + { + name: "should match unescaped macro", + input: "SELECT * FROM table WHERE $__timeFilter(timestamp)", + macro: "timeFilter", + expected: []macroMatch{ + { + full: "$__timeFilter(timestamp)", + name: "timeFilter", + args: []string{"timestamp"}, + escaped: false, + pos: parser.Pos(26), + }, + }, + }, + { + name: "should match escaped macro with double dollar sign", + input: "SELECT * FROM table WHERE $$__timeFilter(timestamp)", + macro: "timeFilter", + expected: []macroMatch{ + { + full: "$$__timeFilter(timestamp)", + name: "timeFilter", + args: []string{"timestamp"}, + escaped: true, + pos: parser.Pos(26), + }, + }, + }, + { + name: "should match multiple unescaped macros", + input: "SELECT $__timeInterval(value) FROM table WHERE $__timeFilter(timestamp)", + macro: "timeFilter", + expected: []macroMatch{ + { + full: "$__timeFilter(timestamp)", + name: "timeFilter", + args: []string{"timestamp"}, + escaped: false, + pos: parser.Pos(47), + }, + }, + }, + { + name: "should match macro with no arguments", + input: "SELECT * FROM table WHERE $__adHocFilter()", + macro: "adHocFilter", + expected: []macroMatch{ + { + full: "$__adHocFilter()", + name: "adHocFilter", + args: []string{""}, + escaped: false, + pos: parser.Pos(26), + }, + }, + }, + { + name: "should match escaped macro with no arguments", + input: "SELECT * FROM table WHERE $$__adHocFilter()", + macro: "adHocFilter", + expected: []macroMatch{ + { + full: "$$__adHocFilter()", + name: "adHocFilter", + args: []string{""}, + escaped: true, + pos: parser.Pos(26), + }, + }, + }, + { + name: "should match macro with multiple arguments", + input: "SELECT * FROM table WHERE $__dateTimeFilter(timestamp, created_at)", + macro: "dateTimeFilter", + expected: []macroMatch{ + { + full: "$__dateTimeFilter(timestamp, created_at)", + name: "dateTimeFilter", + args: []string{"timestamp", "created_at"}, + escaped: false, + pos: parser.Pos(26), + }, + }, + }, + { + name: "should match escaped macro with multiple arguments", + input: "SELECT * FROM table WHERE $$__dateTimeFilter(timestamp, created_at)", + macro: "dateTimeFilter", + expected: []macroMatch{ + { + full: "$$__dateTimeFilter(timestamp, created_at)", + name: "dateTimeFilter", + args: []string{"timestamp", "created_at"}, + escaped: true, + pos: parser.Pos(26), + }, + }, + }, + { + name: "should not match macro without dollar sign", + input: "SELECT * FROM table WHERE __timeFilter(timestamp)", + macro: "timeFilter", + expected: []macroMatch{}, + }, + } + + for i, tc := range tests { + t.Run(fmt.Sprintf("[%d/%d] %s", i+1, len(tests), tc.name), func(t *testing.T) { + matches, err := getMacroMatches(tc.input, tc.macro, nil) + require.NoError(t, err) + assert.Equal(t, len(tc.expected), len(matches)) + for j, expected := range tc.expected { + assert.Equal(t, expected.full, matches[j].full, "full macro text should match") + assert.Equal(t, expected.name, matches[j].name, "macro name should match") + assert.Equal(t, expected.args, matches[j].args, "macro args should match") + assert.Equal(t, expected.escaped, matches[j].escaped, "escaped flag should match") + assert.Equal(t, expected.pos, matches[j].pos, "position should match") + } + }) + } +} + +func TestGetMacroMatches_ErrorCases(t *testing.T) { + type test struct { + name string + input string + macro string + expectErr bool + } + + tests := []test{ + { + name: "should return error for unclosed macro", + input: "SELECT * FROM table WHERE $__timeFilter(timestamp", + macro: "timeFilter", + expectErr: true, + }, + { + name: "should return error for unclosed nested parenthesis", + input: "SELECT * FROM table WHERE $__macro(func(arg)", + macro: "macro", + expectErr: true, + }, + } + + for i, tc := range tests { + t.Run(fmt.Sprintf("[%d/%d] %s", i+1, len(tests), tc.name), func(t *testing.T) { + _, err := getMacroMatches(tc.input, tc.macro, nil) + if tc.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestInterpolateMacroEscaping(t *testing.T) { + type test struct { + name string + input string + output string + } + + tests := []test{ + { + name: "should escape macro with double dollar sign", + input: "SELECT * FROM table WHERE $$__timeFilter(timestamp)", + output: "SELECT * FROM table WHERE $__timeFilter(timestamp)", + }, + { + name: "should escape multiple macros", + input: "SELECT $$__timeInterval(value) FROM table WHERE $$__timeFilter(timestamp)", + output: "SELECT $__timeInterval(value) FROM table WHERE $__timeFilter(timestamp)", + }, + { + name: "should handle mix of escaped and unescaped macros", + input: "SELECT * FROM table WHERE $__timeFilter(timestamp) AND $$__adHocFilter()", + output: "SELECT * FROM table WHERE timestamp >= toDateTime(1415792726) AND timestamp <= toDateTime(1447328726) AND $__adHocFilter()", + }, + { + name: "should escape macro with no arguments", + input: "SELECT * FROM table WHERE $$__adHocFilter()", + output: "SELECT * FROM table WHERE $__adHocFilter()", + }, + { + name: "should escape macro with multiple arguments", + input: "SELECT * FROM table WHERE $$__dateTimeFilter(timestamp, created_at)", + output: "SELECT * FROM table WHERE $__dateTimeFilter(timestamp, created_at)", + }, + { + name: "should process unescaped macro normally", + input: "SELECT * FROM table WHERE $__fromTime", + output: "SELECT * FROM table WHERE toDateTime(1415792726)", + }, + { + name: "should escape fromTime and toTime macros", + input: "SELECT * FROM table WHERE $$__fromTime AND $$__toTime", + output: "SELECT * FROM table WHERE $__fromTime AND $__toTime", + }, + { + name: "should handle escaped timeFilter_ms", + input: "SELECT * FROM table WHERE $$__timeFilter_ms(timestamp)", + output: "SELECT * FROM table WHERE $__timeFilter_ms(timestamp)", + }, + { + name: "should handle complex query with mix of escaped and unescaped", + input: "SELECT * FROM table WHERE $__timeFilter(timestamp) AND status = 'active' OR $$__timeFilter_ms(created_at)", + output: "SELECT * FROM table WHERE timestamp >= toDateTime(1415792726) AND timestamp <= toDateTime(1447328726) AND status = 'active' OR $__timeFilter_ms(created_at)", + }, + { + name: "should handle multiple escaped macros in different positions", + input: "SELECT $$__fromTime, $$__toTime, $__interval_s FROM table WHERE $$__timeFilter(timestamp)", + output: "SELECT $__fromTime, $__toTime, 1 FROM table WHERE $__timeFilter(timestamp)", + }, + { + name: "should handle macros escaped multiple times", + input: "SELECT * FROM table WHERE $$$$__timeFilter(timestamp)", + output: "SELECT * FROM table WHERE $$$__timeFilter(timestamp)", + }, + { + name: "should handle macros escaped multiple times", + input: "SELECT * FROM table WHERE $$$__timeFilter(timestamp)", + output: "SELECT * FROM table WHERE $$__timeFilter(timestamp)", + }, + } + + from, _ := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.123Z") + to, _ := time.Parse("2006-01-02T15:04:05.000Z", "2015-11-12T11:45:26.456Z") + + for i, tc := range tests { + interpolator := NewInterpolator(&HydrolixDatasource{ + Connector: &MockConnector{ + uid: "uid-123", + }, + }) + t.Run(fmt.Sprintf("[%d/%d] %s", i+1, len(tests), tc.name), func(t *testing.T) { + query := &HDXQuery{ + RawSQL: tc.input, + TimeRange: backend.TimeRange{ + From: from, + To: to, + }, + Interval: time.Duration(1000000000), + } + interpolatedQuery, err := interpolator.Interpolate(context.Background(), query) + require.NoError(t, err) + assert.Equal(t, tc.output, interpolatedQuery) + }) + } +} + +func TestInterpolateMacroInStringLiteral(t *testing.T) { + type test struct { + name string + input string + output string + } + + tests := []test{ + { + name: "should not interpolate macro inside string literal", + input: "SELECT '$__fromTime' FROM table", + output: "SELECT '$__fromTime' FROM table", + }, + { + name: "should not interpolate macro with brackets inside string literal", + input: "SELECT '$__timeFilter(timestamp)' FROM table", + output: "SELECT '$__timeFilter(timestamp)' FROM table", + }, + { + name: "should interpolate real macro but not one inside string literal", + input: "SELECT '$__fromTime' FROM table WHERE $__fromTime", + output: "SELECT '$__fromTime' FROM table WHERE toDateTime(1415792726)", + }, + { + name: "should interpolate real macro but not one inside string literal with brackets", + input: "SELECT '$__timeFilter(timestamp)' FROM table WHERE $__timeFilter(timestamp)", + output: "SELECT '$__timeFilter(timestamp)' FROM table WHERE timestamp >= toDateTime(1415792726) AND timestamp <= toDateTime(1447328726)", + }, + { + name: "should not interpolate macro inside line comment", + input: "SELECT * FROM table -- $__fromTime", + output: "SELECT * FROM table -- $__fromTime", + }, + { + name: "should not interpolate macro with brackets inside line comment", + input: "SELECT * FROM table -- $__timeFilter(timestamp)", + output: "SELECT * FROM table -- $__timeFilter(timestamp)", + }, + { + name: "should not interpolate macro inside block comment", + input: "SELECT * FROM table /* $__fromTime */", + output: "SELECT * FROM table /* $__fromTime */", + }, + { + name: "should interpolate real macro but not one inside line comment", + input: "SELECT * FROM table WHERE $__fromTime -- $__fromTime", + output: "SELECT * FROM table WHERE toDateTime(1415792726) -- $__fromTime", + }, + { + name: "should interpolate real macro but not one inside block comment", + input: "SELECT * FROM table WHERE $__fromTime /* $__fromTime */", + output: "SELECT * FROM table WHERE toDateTime(1415792726) /* $__fromTime */", + }, + } + + from, _ := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.123Z") + to, _ := time.Parse("2006-01-02T15:04:05.000Z", "2015-11-12T11:45:26.456Z") + + for i, tc := range tests { + interpolator := NewInterpolator(&HydrolixDatasource{ + Connector: &MockConnector{ + uid: "uid-123", + }, + }) + t.Run(fmt.Sprintf("[%d/%d] %s", i+1, len(tests), tc.name), func(t *testing.T) { + query := &HDXQuery{ + RawSQL: tc.input, + TimeRange: backend.TimeRange{ + From: from, + To: to, + }, + Interval: time.Duration(1000000000), + } + interpolatedQuery, err := interpolator.Interpolate(context.Background(), query) + require.NoError(t, err) + assert.Equal(t, tc.output, interpolatedQuery) + }) + } +} diff --git a/macros.go b/macros.go index 059bd84..7fd7826 100644 --- a/macros.go +++ b/macros.go @@ -1,30 +1,452 @@ package sqlds import ( - "errors" + "context" + "fmt" + "github.com/hydrolix/clickhouse-sql-parser/parser" + "maps" + "math" + "net/http" + "regexp" + "slices" + "strings" + "time" + "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" ) -var ( - ErrorParsingMacroBrackets = errors.New("failed to parse macro arguments (missing close bracket?)") +const ( + SyntheticNull = "__null__" + SyntheticEmpty = "__empty__" + RegexPrefix = "regex:" ) -// MacroFunc defines a signature for applying a query macro -// Query macro implementations are defined by users / consumers of this package -// Deprecated: use sqlutil.MacroFunc directly -type MacroFunc = sqlutil.MacroFunc +var mapTypeFilterKey = regexp.MustCompile("^(.*)\\['.*']$") -// Macros is a list of MacroFuncs. -// The "string" key is the name of the macro function. This name has to be regex friendly. -// Deprecated: use sqlutil.Macros directly -type Macros = sqlutil.Macros +type MacroFunc func(context.Context, *HDXQuery, []string, parser.Pos, *MetaDataProvider) (string, error) -// Deprecated: use sqlutil.DefaultMacros directly -var DefaultMacros = sqlutil.DefaultMacros +// Converts a time.Time to a Date +func timeToDate(t time.Time) string { + return fmt.Sprintf("toDate('%s')", t.Format("2006-01-02")) +} + +// Converts a time.Time to a UTC DateTime with seconds precision +func timeToDateTime(t time.Time) string { + return fmt.Sprintf("toDateTime(%d)", t.Unix()) +} + +// Converts a time.Time to a UTC DateTime64 with milliseconds precision +func timeToDateTime64(t time.Time) string { + return fmt.Sprintf("fromUnixTimestamp64Milli(%d)", t.UnixMilli()) +} + +// FromTimeFilter returns a time filter expression based on grafana's timepicker's "from" time in seconds +func FromTimeFilter(_ context.Context, query *HDXQuery, _ []string, _ parser.Pos, _ *MetaDataProvider) (string, error) { + return timeToDateTime(query.TimeRange.From), nil +} + +// ToTimeFilter returns a time filter expression based on grafana's timepicker's "to" time in seconds +func ToTimeFilter(_ context.Context, query *HDXQuery, _ []string, _ parser.Pos, _ *MetaDataProvider) (string, error) { + return timeToDateTime(query.TimeRange.To), nil +} + +// FromTimeFilterMs returns a time filter expression based on grafana's timepicker's "from" time in milliseconds +func FromTimeFilterMs(_ context.Context, query *HDXQuery, _ []string, _ parser.Pos, _ *MetaDataProvider) (string, error) { + return timeToDateTime64(query.TimeRange.From), nil +} + +// ToTimeFilterMs returns a time filter expression based on grafana's timepicker's "to" time in milliseconds +func ToTimeFilterMs(_ context.Context, query *HDXQuery, _ []string, _ parser.Pos, _ *MetaDataProvider) (string, error) { + return timeToDateTime64(query.TimeRange.To), nil +} + +func TimeFilter(context context.Context, query *HDXQuery, args []string, pos parser.Pos, mdProvider *MetaDataProvider) (string, error) { + if len(args) > 1 { + return "", backend.DownstreamError(fmt.Errorf("%w: expected 0 or 1 argument, received %d", sqlutil.ErrorBadArgumentCount, len(args))) + } + + var ( + column string + from = query.TimeRange.From + to = query.TimeRange.To + ) + + if len(args) == 1 && args[0] != "" { + column = args[0] + } else { + pk, err := getPK(context, query.RawSQL, pos, mdProvider, query.Headers) + if err != nil { + return "", err + } + column = pk + } + + return fmt.Sprintf("%s >= %s AND %s <= %s", column, timeToDateTime(from), column, timeToDateTime(to)), nil +} + +func TimeFilterMs(context context.Context, query *HDXQuery, args []string, pos parser.Pos, mdProvider *MetaDataProvider) (string, error) { + if len(args) > 1 { + return "", backend.DownstreamError(fmt.Errorf("%w: expected 0 or 1 argument, received %d", sqlutil.ErrorBadArgumentCount, len(args))) + } + + var ( + column string + from = query.TimeRange.From + to = query.TimeRange.To + ) + + if len(args) == 1 && args[0] != "" { + column = args[0] + } else { + pk, err := getPK(context, query.RawSQL, pos, mdProvider, query.Headers) + if err != nil { + return "", err + } + column = pk + } + + return fmt.Sprintf("%s >= %s AND %s <= %s", column, timeToDateTime64(from), column, timeToDateTime64(to)), nil +} + +func DateFilter(_ context.Context, query *HDXQuery, args []string, _ parser.Pos, _ *MetaDataProvider) (string, error) { + if len(args) != 1 { + return "", backend.DownstreamError(fmt.Errorf("%w: expected 1 argument, received %d", sqlutil.ErrorBadArgumentCount, len(args))) + } + var ( + column = args[0] + from = query.TimeRange.From + to = query.TimeRange.To + ) + + return fmt.Sprintf("%s >= %s AND %s <= %s", column, timeToDate(from), column, timeToDate(to)), nil +} + +func DateTimeFilter(_ context.Context, query *HDXQuery, args []string, _ parser.Pos, _ *MetaDataProvider) (string, error) { + if len(args) != 2 { + return "", backend.DownstreamError(fmt.Errorf("%w: expected 2 arguments, received %d", sqlutil.ErrorBadArgumentCount, len(args))) + } + var ( + dateColumn = args[0] + timeColumn = args[1] + from = query.TimeRange.From + to = query.TimeRange.To + ) + + dateFilter := fmt.Sprintf("(%s >= %s AND %s <= %s)", dateColumn, timeToDate(from), dateColumn, timeToDate(to)) + timeFilter := fmt.Sprintf("(%s >= %s AND %s <= %s)", timeColumn, timeToDateTime(from), timeColumn, timeToDateTime(to)) + return fmt.Sprintf("%s AND %s", dateFilter, timeFilter), nil +} + +func TimeInterval(context context.Context, query *HDXQuery, args []string, pos parser.Pos, mdProvider *MetaDataProvider) (string, error) { + if len(args) > 1 { + return "", backend.DownstreamError(fmt.Errorf("%w: expected 0 or 1 argument, received %d", sqlutil.ErrorBadArgumentCount, len(args))) + } + var ( + column string + ) + + if len(args) == 1 && args[0] != "" { + column = args[0] + } else { + pk, err := getPK(context, query.RawSQL, pos, mdProvider, query.Headers) + if err != nil { + return "", err + } + column = pk + } + + seconds := math.Max(query.Interval.Seconds(), 1) + return fmt.Sprintf("toStartOfInterval(toDateTime(%s), INTERVAL %d second)", column, int(seconds)), nil +} + +func TimeIntervalMs(context context.Context, query *HDXQuery, args []string, pos parser.Pos, mdProvider *MetaDataProvider) (string, error) { + if len(args) > 1 { + return "", backend.DownstreamError(fmt.Errorf("%w: expected 0 or 1 argument, received %d", sqlutil.ErrorBadArgumentCount, len(args))) + } + var ( + column string + ) + + if len(args) == 1 && args[0] != "" { + column = args[0] + } else { + pk, err := getPK(context, query.RawSQL, pos, mdProvider, query.Headers) + if err != nil { + return "", err + } + column = pk + } + milliseconds := math.Max(float64(query.Interval.Milliseconds()), 1) + return fmt.Sprintf("toStartOfInterval(toDateTime64(%s, 3), INTERVAL %d millisecond)", column, int(milliseconds)), nil +} + +func IntervalSeconds(_ context.Context, query *HDXQuery, _ []string, _ parser.Pos, _ *MetaDataProvider) (string, error) { + seconds := math.Max(query.Interval.Seconds(), 1) + return fmt.Sprintf("%d", int(seconds)), nil +} + +// AdHocFilterMacro implements the $__adHocFilter() macro +func AdHocFilterMacro(ctx context.Context, query *HDXQuery, params []string, pos parser.Pos, mdProvider *MetaDataProvider) (string, error) { + if query.Filters == nil || len(query.Filters) == 0 { + return "1=1", nil + } + if len(params) > 1 { + return "", backend.DownstreamError(fmt.Errorf("%w: expected 0 or 1 argument, received %d", sqlutil.ErrorBadArgumentCount, len(params))) + } + + var cte = "" + if len(params) == 1 { + cte = params[0] + } + + if cte == "" { + expr, err := parser.NewParser(query.RawSQL).ParseStmts() + if err != nil { + return "", err + } + + macroCTEs, err := GetMacroCTEs(expr) + if err != nil { + return "", err + } + + for _, macroCTE := range macroCTEs { + if macroCTE.MacroPos == pos { + cte = macroCTE.CTE + break + } + } + } + if cte == "" { + return "", fmt.Errorf("cannot apply ad hoc filters: unable to resolve tableName for ad hoc filter at index %d", pos) + } + keys, err := mdProvider.GetKeys(ctx, query.Headers, cte) + + if err != nil { + return "", fmt.Errorf("cannot apply ad hoc filters: unable to resolve keys for cte: %s", cte) + } + var conditions []string + keyNames := slices.Collect(maps.Keys(keys)) + + for _, filter := range query.Filters { + column := filter.Key + if mapTypeFilterKey.MatchString(filter.Key) { + column = mapTypeFilterKey.FindStringSubmatch(filter.Key)[1] + } + if slices.Contains(keyNames, column) { + keyType := keys[column] + condition, err := buildFilterCondition(filter, keyType) + if err != nil { + return "", fmt.Errorf("error building filter condition for key '%s': %w", filter.Key, err) + } + if condition != "" { + conditions = append(conditions, condition) + } + } + } + + if len(conditions) == 0 { + return "1=1", nil + } + + return strings.Join(conditions, " AND "), nil +} + +func buildArrayCondition(filter AdHocFilter) (string, error) { + key := filter.Key + value := filter.Value + operator := filter.Operator + if operator == "=|" { + var buffer []string + for _, v := range filter.Values { + buffer = append(buffer, fmt.Sprintf("has(%s, $$%s$$)", key, v)) + } + return fmt.Sprintf("(%s)", strings.Join(buffer, " OR ")), nil + } else if operator == "!=|" { + var buffer []string + for _, v := range filter.Values { + buffer = append(buffer, fmt.Sprintf("not has(%s, $$%s$$)", key, v)) + } + return fmt.Sprintf("(%s)", strings.Join(buffer, " OR ")), nil + } else if operator == "!=" { + return fmt.Sprintf("not has(%s, $$%s$$)", key, value), nil + } else if operator == "=" { + return fmt.Sprintf("has(%s, $$%s$$)", key, value), nil + + } else { + return "", fmt.Errorf("operator %s unsupported for Array value", operator) + } +} + +// buildFilterCondition creates a SQL condition from an ad-hoc filter +func buildFilterCondition(filter AdHocFilter, keyType string) (string, error) { + isString := strings.Contains(strings.ToLower(keyType), "string)") || strings.ToLower(keyType) == "string" + isArray := strings.Contains(strings.ToLower(keyType), "array") + isMap := strings.Contains(strings.ToLower(keyType), "map") + if isArray { + return buildArrayCondition(filter) + } + + key := filter.Key + value := filter.Value + operator := filter.Operator + if operator == "=|" { + if isMap && !isString { + return "", fmt.Errorf("cannot apply =| operator over non string map values") + } + values, hasNull := getJoinedValues(filter.Values) + + var parts []string + if hasNull { + parts = append(parts, fmt.Sprintf("%s IS NULL", key)) + } + + if values != "" { + parts = append(parts, fmt.Sprintf("%s IN (%s)", key, values)) + } + if len(parts) == 0 { + return "", nil + } else if len(parts) == 1 { + return parts[0], nil + } else { + return fmt.Sprintf("(%s)", strings.Join(parts, " OR ")), nil + } + + } else if operator == "!=|" { + if isMap && !isString { + return "", fmt.Errorf("cannot apply !=| operator over non string map values") + } + values, hasNull := getJoinedValues(filter.Values) + + var parts []string + if hasNull { + parts = append(parts, fmt.Sprintf("%s IS NOT NULL", key)) + } + + if values != "" { + parts = append(parts, fmt.Sprintf("%s NOT IN (%s)", key, values)) + } + + return strings.Join(parts, " AND "), nil + } else if strings.ToUpper(value) == "NULL" || value == SyntheticNull { + if operator == "=" && isString { + return fmt.Sprintf("(%s IS NULL OR %s = '%s')", key, key, SyntheticNull), nil + } else if operator == "!=" && isString { + return fmt.Sprintf("(%s IS NOT NULL OR %s != '%s')", key, key, SyntheticNull), nil + } else if operator == "=" { + return fmt.Sprintf("%s IS NULL", key), nil + } else if operator == "!=" { + return fmt.Sprintf("%s IS NOT NULL", key), nil + } else { + return "", fmt.Errorf("%s: operator '%s' can not be applied to NULL value", key, operator) + } + } else if value == "" || value == SyntheticEmpty { + if operator == "=" { + return fmt.Sprintf("(%s = '' OR %s = '%s')", key, key, SyntheticEmpty), nil + } else if operator == "!=" { + return fmt.Sprintf("(%s != '' AND %s != '%s')", key, key, SyntheticEmpty), nil + } else { + return "", fmt.Errorf("%s: operator '%s' can not be applied to __empty__ value", key, operator) + } + + } else if operator == "=~" { + regex, isRegex := getRegexValue(value) + if isRegex { + return fmt.Sprintf("match(toString(%s), $$%s$$)", key, regex), nil + } else { + return fmt.Sprintf("toString(%s) LIKE $$%s$$", key, escapeWildcard(value)), nil + } + } else if operator == "!~" { + regex, isRegex := getRegexValue(value) + if isRegex { + return fmt.Sprintf("not match(toString(%s), $$%s$$)", key, regex), nil + } else { + return fmt.Sprintf("toString(%s) NOT LIKE $$%s$$", key, escapeWildcard(value)), nil + } + } else { + return fmt.Sprintf("%s %s $$%s$$", key, operator, value), nil + } +} + +func getRegexValue(value string) (string, bool) { + + isRegex := strings.HasPrefix(value, RegexPrefix) + if isRegex { + return value[len(RegexPrefix):], true + } else { + return "", false + } +} + +func getJoinedValues(values []string) (string, bool) { + var buffer []string + hasNull := false + for _, v := range values { + if strings.ToUpper(v) == "NULL" || v == SyntheticNull { + hasNull = true + } else if v == SyntheticEmpty { + buffer = append(buffer, "$$$$") + } else { + buffer = append(buffer, fmt.Sprintf("$$%s$$", v)) + } + } + return strings.Join(buffer, ", "), hasNull +} + +// escapeWildcard prepares wildcard patterns for LIKE queries +func escapeWildcard(v string) string { + chars := []rune(v) + for i := range len(chars) { + if chars[i] == '*' && (i == 0 || chars[i-1] != '\\') { + chars[i] = '%' + } + } + v = string(chars) + v = strings.ReplaceAll(v, `\*`, "*") + return v +} + +func Stub(_ context.Context, _ *HDXQuery, _ []string, _ parser.Pos, _ *MetaDataProvider) (string, error) { + return "1=1", nil +} + +func getPK(context context.Context, rawSQL string, pos parser.Pos, mdProvider *MetaDataProvider, headers http.Header) (string, error) { + expr, err := parser.NewParser(rawSQL).ParseStmts() + if err != nil { + return rawSQL, err + } + macroIds, err := GetMacroCTEs(expr) + if err != nil { + return rawSQL, err + } + var cte *CTE + for _, macroCTE := range macroIds { + if macroCTE.MacroPos == pos { + cte = ¯oCTE + break + } + } + if cte == nil { + return rawSQL, fmt.Errorf("no CTE found for macro at pos %d", pos) + } + return mdProvider.GetPK(context, headers, cte.Database, cte.Table) +} -// Interpolate wraps sqlutil.Interpolate for temporary backwards-compatibility -// Deprecated: use sqlutil.Interpolate directly -func Interpolate(driver Driver, query *Query) (string, error) { - return sqlutil.Interpolate(query, driver.Macros()) +// Macros is a map of all macro functions +var Macros = map[string]MacroFunc{ + "adHocFilter": AdHocFilterMacro, + "conditionalAll": Stub, + "fromTime": FromTimeFilter, + "toTime": ToTimeFilter, + "fromTime_ms": FromTimeFilterMs, + "toTime_ms": ToTimeFilterMs, + "timeFilter": TimeFilter, + "timeFilter_ms": TimeFilterMs, + "dateFilter": DateFilter, + "dateTimeFilter": DateTimeFilter, + "dt": DateTimeFilter, + "timeInterval": TimeInterval, + "timeInterval_ms": TimeIntervalMs, + "interval_s": IntervalSeconds, } diff --git a/macros_test.go b/macros_test.go new file mode 100644 index 0000000..70cb635 --- /dev/null +++ b/macros_test.go @@ -0,0 +1,1004 @@ +package sqlds + +import ( + "context" + "fmt" + "github.com/DATA-DOG/go-sqlmock" + "testing" + "time" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTimeToDate(t *testing.T) { + d, _ := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.371Z") + + expected := "toDate('2014-11-12')" + result := timeToDate(d) + + if expected != result { + t.Errorf("unexpected output. expected: %s got: %s", expected, result) + } +} + +func TestTimeToDateTime(t *testing.T) { + dt := time.Unix(1708430068, 0) + + expected := "toDateTime(1708430068)" + result := timeToDateTime(dt) + + if expected != result { + t.Errorf("unexpected output. expected: %s got: %s", expected, result) + } +} + +func TestTimeToDateTime64(t *testing.T) { + dt := time.UnixMilli(1708430068123) + + expected := "fromUnixTimestamp64Milli(1708430068123)" + result := timeToDateTime64(dt) + + if expected != result { + t.Errorf("unexpected output. expected: %s got: %s", expected, result) + } +} + +func TestMacroFromTimeFilter(t *testing.T) { + from, _ := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.371Z") + to, _ := time.Parse("2006-01-02T15:04:05.000Z", "2015-11-12T11:45:26.371Z") + query := HDXQuery{ + TimeRange: backend.TimeRange{ + From: from, + To: to, + }, + RawSQL: "select foo from foo where bar > $__fromTime", + } + tests := []struct { + want string + wantErr bool + name string + }{ + { + name: "should return timeFilter", + want: "toDateTime(1415792726)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := FromTimeFilter(context.Background(), &query, []string{}, 0, &MetaDataProvider{}) + if (err != nil) != tt.wantErr { + t.Errorf("macroFromTimeFilter() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestMacroToTimeFilter(t *testing.T) { + from, _ := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.371Z") + to, _ := time.Parse("2006-01-02T15:04:05.000Z", "2015-11-12T11:45:26.371Z") + query := HDXQuery{ + TimeRange: backend.TimeRange{ + From: from, + To: to, + }, + RawSQL: "select foo from foo where bar > $__toTime", + } + tests := []struct { + want string + wantErr bool + name string + }{ + { + name: "should return timeFilter", + want: "toDateTime(1447328726)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ToTimeFilter(context.Background(), &query, []string{}, 0, &MetaDataProvider{}) + if (err != nil) != tt.wantErr { + t.Errorf("macroToTimeFilter() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestMacroFromTimeFilterMs(t *testing.T) { + from, _ := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.371Z") + to, _ := time.Parse("2006-01-02T15:04:05.000Z", "2015-11-12T11:45:26.371Z") + query := HDXQuery{ + TimeRange: backend.TimeRange{ + From: from, + To: to, + }, + RawSQL: "select foo from foo where bar > $__fromTime", + } + tests := []struct { + want string + wantErr bool + name string + }{ + { + name: "should return timeFilter_ms", + want: "fromUnixTimestamp64Milli(1415792726371)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := FromTimeFilterMs(context.Background(), &query, []string{}, 0, &MetaDataProvider{}) + if (err != nil) != tt.wantErr { + t.Errorf("macroFromTimeFilterMs() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestMacroToTimeFilterMs(t *testing.T) { + from, _ := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.371Z") + to, _ := time.Parse("2006-01-02T15:04:05.000Z", "2015-11-12T11:45:26.371Z") + query := HDXQuery{ + TimeRange: backend.TimeRange{ + From: from, + To: to, + }, + RawSQL: "select foo from foo where bar > $__toTime", + } + tests := []struct { + want string + wantErr bool + name string + }{ + { + name: "should return timeFilter_ms", + want: "fromUnixTimestamp64Milli(1447328726371)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ToTimeFilterMs(context.Background(), &query, []string{}, 0, &MetaDataProvider{}) + if (err != nil) != tt.wantErr { + t.Errorf("macroToTimeFilterMs() error = %v, wantErr %v", err, tt.wantErr) + return + } + assert.Equal(t, tt.want, got) + }) + } +} + +func TestMacroDateFilter(t *testing.T) { + from, _ := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.371Z") + to, _ := time.Parse("2006-01-02T15:04:05.000Z", "2015-11-12T11:45:26.371Z") + query := HDXQuery{ + TimeRange: backend.TimeRange{ + From: from, + To: to, + }, + } + got, err := DateFilter(context.Background(), &query, []string{"dateCol"}, 0, &MetaDataProvider{}) + assert.Nil(t, err) + assert.Equal(t, "dateCol >= toDate('2014-11-12') AND dateCol <= toDate('2015-11-12')", got) +} + +func TestMacroDateTimeFilter(t *testing.T) { + from, _ := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.371Z") + to, _ := time.Parse("2006-01-02T15:04:05.000Z", "2015-11-12T11:45:26.371Z") + query := HDXQuery{ + TimeRange: backend.TimeRange{ + From: from, + To: to, + }, + } + got, err := DateTimeFilter(context.Background(), &query, []string{"dateCol", "timeCol"}, 0, &MetaDataProvider{}) + assert.Nil(t, err) + assert.Equal(t, "(dateCol >= toDate('2014-11-12') AND dateCol <= toDate('2015-11-12')) AND (timeCol >= toDateTime(1415792726) AND timeCol <= toDateTime(1447328726))", got) +} + +func TestMacroTimeInterval(t *testing.T) { + query := HDXQuery{ + RawSQL: "select $__timeInterval(col) from foo", + Interval: time.Duration(20000000000), + } + got, err := TimeInterval(context.Background(), &query, []string{"col"}, 0, &MetaDataProvider{}) + assert.Nil(t, err) + assert.Equal(t, "toStartOfInterval(toDateTime(col), INTERVAL 20 second)", got) +} + +func TestMacroTimeIntervalMs(t *testing.T) { + query := HDXQuery{ + RawSQL: "select $__timeInterval_ms(col) from foo", + Interval: time.Duration(20000000000), + } + got, err := TimeIntervalMs(context.Background(), &query, []string{"col"}, 0, &MetaDataProvider{}) + assert.Nil(t, err) + assert.Equal(t, "toStartOfInterval(toDateTime64(col, 3), INTERVAL 20000 millisecond)", got) +} + +func TestMacroIntervalSeconds(t *testing.T) { + query := HDXQuery{ + RawSQL: "select toStartOfInterval(col, INTERVAL $__interval_s second) AS time from foo", + Interval: time.Duration(20000000000), + } + got, err := IntervalSeconds(context.Background(), &query, []string{}, 0, &MetaDataProvider{}) + assert.Nil(t, err) + assert.Equal(t, "20", got) +} + +// test sqlds query interpolation with clickhouse filters used +func TestInterpolate(t *testing.T) { + from, _ := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.123Z") + to, _ := time.Parse("2006-01-02T15:04:05.000Z", "2015-11-12T11:45:26.456Z") + + type test struct { + name string + input string + output string + } + + tests := []test{ + {input: "select * from foo where $__timeFilter(cast(sth as timestamp))", output: "select * from foo where cast(sth as timestamp) >= toDateTime(1415792726) AND cast(sth as timestamp) <= toDateTime(1447328726)", name: "clickhouse timeFilter"}, + {input: "select * from foo where $__timeFilter(cast(sth as timestamp) )", output: "select * from foo where cast(sth as timestamp) >= toDateTime(1415792726) AND cast(sth as timestamp) <= toDateTime(1447328726)", name: "clickhouse timeFilter with empty spaces"}, + {input: "select * from foo where $__timeFilter_ms(cast(sth as timestamp))", output: "select * from foo where cast(sth as timestamp) >= fromUnixTimestamp64Milli(1415792726123) AND cast(sth as timestamp) <= fromUnixTimestamp64Milli(1447328726456)", name: "clickhouse timeFilter_ms"}, + {input: "select * from foo where $__timeFilter_ms(cast(sth as timestamp) )", output: "select * from foo where cast(sth as timestamp) >= fromUnixTimestamp64Milli(1415792726123) AND cast(sth as timestamp) <= fromUnixTimestamp64Milli(1447328726456)", name: "clickhouse timeFilter_ms with empty spaces"}, + {input: "select * from foo where ( date >= $__fromTime and date <= $__toTime ) limit 100", output: "select * from foo where ( date >= toDateTime(1415792726) and date <= toDateTime(1447328726) ) limit 100", name: "clickhouse fromTime and toTime"}, + {input: "select * from foo where ( date >= $__fromTime ) and ( date <= $__toTime ) limit 100", output: "select * from foo where ( date >= toDateTime(1415792726) ) and ( date <= toDateTime(1447328726) ) limit 100", name: "clickhouse fromTime and toTime inside a complex clauses"}, + {input: "select * from foo where ( date >= $__fromTime_ms and date <= $__toTime_ms ) limit 100", output: "select * from foo where ( date >= fromUnixTimestamp64Milli(1415792726123) and date <= fromUnixTimestamp64Milli(1447328726456) ) limit 100", name: "clickhouse fromTime_ms and toTime_ms"}, + {input: "select * from foo where ( date >= $__fromTime_ms ) and ( date <= $__toTime_ms ) limit 100", output: "select * from foo where ( date >= fromUnixTimestamp64Milli(1415792726123) ) and ( date <= fromUnixTimestamp64Milli(1447328726456) ) limit 100", name: "clickhouse fromTime_ms and toTime_ms inside a complex clauses"}, + } + + for i, tc := range tests { + db, _, _ := sqlmock.New() + interpolator := NewInterpolator(&HydrolixDatasource{ + + Connector: &MockConnector{ + db: db, + uid: "uid-123", + }, + }) + t.Run(fmt.Sprintf("[%d/%d] %s", i+1, len(tests), tc.name), func(t *testing.T) { + query := &HDXQuery{ + RawSQL: tc.input, + TimeRange: backend.TimeRange{ + From: from, + To: to, + }, + } + interpolatedQuery, err := interpolator.Interpolate(context.Background(), query) + require.Nil(t, err) + assert.Equal(t, tc.output, interpolatedQuery) + }) + } +} + +// test sqlds query interpolation with clickhouse filters used +func TestInterpolateWithAutomaticParams(t *testing.T) { + from, _ := time.Parse("2006-01-02T15:04:05.000Z", "2014-11-12T11:45:26.123Z") + to, _ := time.Parse("2006-01-02T15:04:05.000Z", "2015-11-12T11:45:26.456Z") + + type test struct { + name string + input string + output string + } + + tests := []test{ + {input: "select * from foo.bar where $__timeFilter()", output: "select * from foo.bar where timestamp >= toDateTime(1415792726) AND timestamp <= toDateTime(1447328726)", name: "timeFilter auto timestamp empty param"}, + {input: "select * from bar where $__timeFilter()", output: "select * from bar where timestamp >= toDateTime(1415792726) AND timestamp <= toDateTime(1447328726)", name: "timeFilter auto timestamp empty param default db"}, + {input: "select * from foo.bar where $__timeFilter_ms()", output: "select * from foo.bar where timestamp >= fromUnixTimestamp64Milli(1415792726123) AND timestamp <= fromUnixTimestamp64Milli(1447328726456)", name: "timeFilter_ms auto timestamp empty param"}, + {input: "select * from bar where $__timeFilter_ms()", output: "select * from bar where timestamp >= fromUnixTimestamp64Milli(1415792726123) AND timestamp <= fromUnixTimestamp64Milli(1447328726456)", name: "timeFilter_ms auto timestamp empty param default db"}, + {input: "select $__timeInterval() from foo.bar", output: "select toStartOfInterval(toDateTime(timestamp), INTERVAL 1 second) from foo.bar", name: "timeInterval auto timestamp empty param"}, + {input: "select $__timeInterval() from bar", output: "select toStartOfInterval(toDateTime(timestamp), INTERVAL 1 second) from bar", name: "timeInterval auto timestamp empty param default db"}, + {input: "select $__timeInterval_ms() from foo.bar", output: "select toStartOfInterval(toDateTime64(timestamp, 3), INTERVAL 1 millisecond) from foo.bar", name: "timeInterval_ms auto timestamp empty param"}, + {input: "select $__timeInterval_ms() from bar", output: "select toStartOfInterval(toDateTime64(timestamp, 3), INTERVAL 1 millisecond) from bar", name: "timeInterval_ms auto timestamp empty param default db"}, + } + + for i, tc := range tests { + db, mock, _ := sqlmock.New() + + rows := sqlmock.NewRows([]string{"primary_key"}).AddRow("timestamp") + mock.ExpectQuery(fmt.Sprintf(PRIMARY_KEY_QUERY_STRING, "foo", "bar")). + WillReturnRows(rows) + interpolator := NewInterpolator(&HydrolixDatasource{ + + Connector: &MockConnector{ + db: db, + uid: "uid-123", + }, + rowLimit: defaultRowLimit, + }) + t.Run(fmt.Sprintf("[%d/%d] %s", i+1, len(tests), tc.name), func(t *testing.T) { + query := &HDXQuery{ + RawSQL: tc.input, + TimeRange: backend.TimeRange{ + From: from, + To: to, + }, + } + interpolatedQuery, err := interpolator.Interpolate(context.Background(), query) + require.Nil(t, err) + assert.Equal(t, tc.output, interpolatedQuery) + }) + } +} + +func TestNegativeCases(t *testing.T) { + + type test struct { + name string + input string + error string + } + + tests := []test{ + {input: "select * from foo.bar where $__timeFilter(arg1, arg2)", error: "unexpected number of arguments: expected 0 or 1 argument, received 2", name: "timeFilter auto timestamp empty param"}, + {input: "select * from foo.bar where $__timeFilter(arg1, arg2", error: "failed to parse macro arguments (missing close bracket?)", name: "timeFilter auto timestamp empty param"}, + } + + for i, tc := range tests { + + interpolator := NewInterpolator(&HydrolixDatasource{ + Connector: &MockConnector{ + uid: "uid-123", + }, + }) + t.Run(fmt.Sprintf("[%d/%d] %s", i+1, len(tests), tc.name), func(t *testing.T) { + query := &HDXQuery{ + RawSQL: tc.input, + } + _, err := interpolator.Interpolate(context.Background(), query) + require.Error(t, err, tc.error) + require.Equal(t, err.Error(), tc.error) + }) + } +} + +func TestAdHocFilterMacro(t *testing.T) { + type test struct { + name string + input string + output string + filters []AdHocFilter + } + + tests := []test{ + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where 1=1", + filters: []AdHocFilter{}, + name: "no filters test", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where column = $$test$$", + filters: []AdHocFilter{{Key: "column", Operator: "=", Value: "test"}}, + name: "single equals filter", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where column = $$test$$ AND column2 != $$value2$$", + filters: []AdHocFilter{ + {Key: "column", Operator: "=", Value: "test"}, + {Key: "column2", Operator: "!=", Value: "value2"}, + }, + name: "multiple filters", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where column2 IS NULL", + filters: []AdHocFilter{{Key: "column2", Operator: "=", Value: "null"}}, + name: "null value filter", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where (column = '' OR column = '__empty__')", + filters: []AdHocFilter{{Key: "column", Operator: "=", Value: ""}}, + name: "empty value filter", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where toString(column) LIKE $$%pattern%$$", + filters: []AdHocFilter{{Key: "column", Operator: "=~", Value: "*pattern*"}}, + name: "regex wildcard filter", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where toString(column) NOT LIKE $$pattern$$", + filters: []AdHocFilter{{Key: "column", Operator: "!~", Value: "pattern"}}, + name: "regex not match filter", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where column IN ($$a$$, $$b$$, $$c$$)", + filters: []AdHocFilter{{Key: "column", Operator: "=|", Values: []string{"a", "b", "c"}}}, + name: "multi-value IN filter", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where column NOT IN ($$a$$, $$b$$, $$c$$)", + filters: []AdHocFilter{{Key: "column", Operator: "!=|", Values: []string{"a", "b", "c"}}}, + name: "multi-value NOT IN filter", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where (column IS NULL OR column IN ($$a$$, $$c$$))", + filters: []AdHocFilter{{Key: "column", Operator: "=|", Values: []string{"a", "null", "c"}}}, + name: "multi-value IN with null", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where column IS NOT NULL AND column NOT IN ($$a$$, $$c$$)", + filters: []AdHocFilter{{Key: "column", Operator: "!=|", Values: []string{"a", "null", "c"}}}, + name: "multi-value NOT IN with null", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where 1=1", + filters: []AdHocFilter{{Key: "nonexistent", Operator: "=", Value: "test"}}, + name: "filter on non-existent column", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where column = $$val'ue$$", + filters: []AdHocFilter{{Key: "column", Operator: "=", Value: "val'ue"}}, + name: "value with single quotes", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where column IN ($$$$, $$b$$)", + filters: []AdHocFilter{{Key: "column", Operator: "=|", Values: []string{"", "b"}}}, + name: "multi-value IN with empty string", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where has(arrayColumn, $$value$$)", + filters: []AdHocFilter{{Key: "arrayColumn", Operator: "=", Value: "value"}}, + name: "array column with equals", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where not has(arrayColumn, $$test$$)", + filters: []AdHocFilter{{Key: "arrayColumn", Operator: "!=", Value: "test"}}, + name: "array column with not equals", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where (has(arrayColumn, $$a$$) OR has(arrayColumn, $$b$$) OR has(arrayColumn, $$c$$))", + filters: []AdHocFilter{{Key: "arrayColumn", Operator: "=|", Values: []string{"a", "b", "c"}}}, + name: "array column with multi-value IN", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where (not has(arrayColumn, $$x$$) OR not has(arrayColumn, $$y$$))", + filters: []AdHocFilter{{Key: "arrayColumn", Operator: "!=|", Values: []string{"x", "y"}}}, + name: "array column with multi-value NOT IN", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where column = $$test$$ AND has(arrayColumn, $$prod$$)", + filters: []AdHocFilter{ + {Key: "column", Operator: "=", Value: "test"}, + {Key: "arrayColumn", Operator: "=", Value: "prod"}, + }, + name: "mixed string and array columns", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where mapColumn['key1'] = $$value1$$", + filters: []AdHocFilter{{Key: "mapColumn['key1']", Operator: "=", Value: "value1"}}, + name: "map column with key syntax", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where mapColumn['status'] IN ($$active$$, $$pending$$)", + filters: []AdHocFilter{{Key: "mapColumn['status']", Operator: "=|", Values: []string{"active", "pending"}}}, + name: "map column with multi-value IN", + }, + { + input: "select * from foo where $__adHocFilter()", + output: "select * from foo where column = $$test$$ AND mapColumn['env'] = $$prod$$", + filters: []AdHocFilter{ + {Key: "column", Operator: "=", Value: "test"}, + {Key: "mapColumn['env']", Operator: "=", Value: "prod"}, + }, + name: "mixed string and map columns", + }, + } + for i, tc := range tests { + db, mock, _ := sqlmock.New() + + rows := sqlmock.NewRows([]string{"name", "type"}). + AddRow("column", "Nullable(String)"). + AddRow("column2", "UInt64"). + AddRow("arrayColumn", "Array(String)"). + AddRow("mapColumn", "Map(String, String)") + mock.ExpectQuery(fmt.Sprintf(AD_HOC_KEY_QUERY, "foo")). + WillReturnRows(rows) + interpolator := NewInterpolator(&HydrolixDatasource{ + + Connector: &MockConnector{ + db: db, + uid: "uid-123", + }, + rowLimit: defaultRowLimit, + }) + t.Run(fmt.Sprintf("[%d/%d] %s", i+1, len(tests), tc.name), func(t *testing.T) { + query := &HDXQuery{ + RawSQL: tc.input, + Filters: tc.filters, + } + interpolatedQuery, err := interpolator.Interpolate(context.Background(), query) + require.Nil(t, err) + assert.Equal(t, tc.output, interpolatedQuery) + }) + } +} + +func TestBuildFilterCondition(t *testing.T) { + tests := []struct { + name string + filter AdHocFilter + keyType string + expected string + wantErr bool + }{ + { + name: "equals operator", + filter: AdHocFilter{Key: "column", Operator: "=", Value: "value"}, + keyType: "String", + expected: "column = $$value$$", + }, + { + name: "equals with empty string", + filter: AdHocFilter{Key: "column", Operator: "=", Value: ""}, + keyType: "String", + expected: "(column = '' OR column = '__empty__')", + }, + { + name: "equals with null string", + filter: AdHocFilter{Key: "column", Operator: "=", Value: "null"}, + keyType: "String", + expected: "(column IS NULL OR column = '__null__')", + }, + { + name: "not equals operator", + filter: AdHocFilter{Key: "column", Operator: "!=", Value: "value"}, + keyType: "String", + expected: "column != $$value$$", + }, + { + name: "not equals with empty string", + filter: AdHocFilter{Key: "column", Operator: "!=", Value: ""}, + keyType: "String", + expected: "(column != '' AND column != '__empty__')", + }, + { + name: "not equals with null", + filter: AdHocFilter{Key: "column", Operator: "!=", Value: "null"}, + keyType: "String", + expected: "(column IS NOT NULL OR column != '__null__')", + }, + { + name: "regex match", + filter: AdHocFilter{Key: "column", Operator: "=~", Value: "pattern"}, + keyType: "String", + expected: "toString(column) LIKE $$pattern$$", + }, + { + name: "regex match with wildcards", + filter: AdHocFilter{Key: "column", Operator: "=~", Value: "*test*"}, + keyType: "String", + expected: "toString(column) LIKE $$%test%$$", + }, + { + name: "regex not match", + filter: AdHocFilter{Key: "column", Operator: "!~", Value: "pattern"}, + keyType: "String", + expected: "toString(column) NOT LIKE $$pattern$$", + }, + { + name: "multi-value IN", + filter: AdHocFilter{Key: "column", Operator: "=|", Values: []string{"a", "b", "c"}}, + keyType: "String", + expected: "column IN ($$a$$, $$b$$, $$c$$)", + }, + { + name: "multi-value IN with null", + filter: AdHocFilter{Key: "column", Operator: "=|", Values: []string{"a", "null", "c"}}, + keyType: "String", + expected: "(column IS NULL OR column IN ($$a$$, $$c$$))", + }, + { + name: "multi-value IN with empty", + filter: AdHocFilter{Key: "column", Operator: "=|", Values: []string{"a", "", "c"}}, + keyType: "String", + expected: "column IN ($$a$$, $$$$, $$c$$)", + }, + { + name: "multi-value NOT IN", + filter: AdHocFilter{Key: "column", Operator: "!=|", Values: []string{"a", "b", "c"}}, + keyType: "String", + expected: "column NOT IN ($$a$$, $$b$$, $$c$$)", + }, + { + name: "multi-value NOT IN with null", + filter: AdHocFilter{Key: "column", Operator: "!=|", Values: []string{"a", "null", "c"}}, + keyType: "String", + expected: "column IS NOT NULL AND column NOT IN ($$a$$, $$c$$)", + }, + { + name: "single quote escaping", + filter: AdHocFilter{Key: "column", Operator: "=", Value: "val'ue"}, + keyType: "String", + expected: "column = $$val'ue$$", + }, + { + name: "less than operator", + filter: AdHocFilter{Key: "column", Operator: "<", Value: "100"}, + keyType: "UInt32", + expected: "column < $$100$$", + }, + { + name: "greater than operator", + filter: AdHocFilter{Key: "column", Operator: ">", Value: "50"}, + keyType: "UInt32", + expected: "column > $$50$$", + }, + { + name: "multi-value IN only null", + filter: AdHocFilter{Key: "column", Operator: "=|", Values: []string{"null"}}, + keyType: "String", + expected: "column IS NULL", + }, + { + name: "multi-value NOT IN only null", + filter: AdHocFilter{Key: "column", Operator: "!=|", Values: []string{"null"}}, + keyType: "String", + expected: "column IS NOT NULL", + }, + { + name: "array type with has operator", + filter: AdHocFilter{Key: "column", Operator: "=", Value: "value"}, + keyType: "Array(String)", + expected: "has(column, $$value$$)", + }, + { + name: "array type with not has operator", + filter: AdHocFilter{Key: "column", Operator: "!=", Value: "value"}, + keyType: "Array(String)", + expected: "not has(column, $$value$$)", + }, + { + name: "array type with multi-value IN", + filter: AdHocFilter{Key: "column", Operator: "=|", Values: []string{"a", "b", "c"}}, + keyType: "Array(String)", + expected: "(has(column, $$a$$) OR has(column, $$b$$) OR has(column, $$c$$))", + }, + { + name: "array type with multi-value NOT IN", + filter: AdHocFilter{Key: "column", Operator: "!=|", Values: []string{"a", "b", "c"}}, + keyType: "Array(String)", + expected: "(not has(column, $$a$$) OR not has(column, $$b$$) OR not has(column, $$c$$))", + }, + { + name: "nullable array type", + filter: AdHocFilter{Key: "column", Operator: "=", Value: "value"}, + keyType: "Array(Nullable(String))", + expected: "has(column, $$value$$)", + }, + { + name: "array type with unsupported operator", + filter: AdHocFilter{Key: "column", Operator: "=~", Value: "pattern"}, + keyType: "Array(String)", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := buildFilterCondition(tt.filter, tt.keyType) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestBuildArrayCondition(t *testing.T) { + tests := []struct { + name string + filter AdHocFilter + expected string + wantErr bool + }{ + { + name: "array equals operator", + filter: AdHocFilter{Key: "tags", Operator: "=", Value: "production"}, + expected: "has(tags, $$production$$)", + }, + { + name: "array not equals operator", + filter: AdHocFilter{Key: "tags", Operator: "!=", Value: "test"}, + expected: "not has(tags, $$test$$)", + }, + { + name: "array multi-value IN", + filter: AdHocFilter{Key: "tags", Operator: "=|", Values: []string{"prod", "staging", "dev"}}, + expected: "(has(tags, $$prod$$) OR has(tags, $$staging$$) OR has(tags, $$dev$$))", + }, + { + name: "array multi-value NOT IN", + filter: AdHocFilter{Key: "tags", Operator: "!=|", Values: []string{"prod", "staging"}}, + expected: "(not has(tags, $$prod$$) OR not has(tags, $$staging$$))", + }, + { + name: "array with single value in multi-value IN", + filter: AdHocFilter{Key: "tags", Operator: "=|", Values: []string{"production"}}, + expected: "(has(tags, $$production$$))", + }, + { + name: "array with less than operator (unsupported)", + filter: AdHocFilter{Key: "tags", Operator: "<", Value: "100"}, + wantErr: true, + }, + { + name: "array with greater than operator (unsupported)", + filter: AdHocFilter{Key: "tags", Operator: ">", Value: "50"}, + wantErr: true, + }, + { + name: "array with regex operator (unsupported)", + filter: AdHocFilter{Key: "tags", Operator: "=~", Value: "pattern"}, + wantErr: true, + }, + { + name: "array with regex not match operator (unsupported)", + filter: AdHocFilter{Key: "tags", Operator: "!~", Value: "pattern"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := buildArrayCondition(tt.filter) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported") + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestBuildFilterConditionWithMaps(t *testing.T) { + tests := []struct { + name string + filter AdHocFilter + keyType string + expected string + wantErr bool + }{ + { + name: "map string type with equals", + filter: AdHocFilter{Key: "labels['env']", Operator: "=", Value: "prod"}, + keyType: "Map(String, String)", + expected: "labels['env'] = $$prod$$", + }, + { + name: "map string type with not equals", + filter: AdHocFilter{Key: "labels['status']", Operator: "!=", Value: "inactive"}, + keyType: "Map(String, String)", + expected: "labels['status'] != $$inactive$$", + }, + { + name: "map string type with multi-value IN", + filter: AdHocFilter{Key: "labels['region']", Operator: "=|", Values: []string{"us-east", "us-west"}}, + keyType: "Map(String, String)", + expected: "labels['region'] IN ($$us-east$$, $$us-west$$)", + }, + { + name: "map string type with multi-value NOT IN", + filter: AdHocFilter{Key: "labels['env']", Operator: "!=|", Values: []string{"dev", "test"}}, + keyType: "Map(String, String)", + expected: "labels['env'] NOT IN ($$dev$$, $$test$$)", + }, + { + name: "map nullable string type", + filter: AdHocFilter{Key: "metadata['key']", Operator: "=", Value: "value"}, + keyType: "Map(String, Nullable(String))", + expected: "metadata['key'] = $$value$$", + }, + { + name: "map uint type with multi-value IN (error)", + filter: AdHocFilter{Key: "counts['total']", Operator: "=|", Values: []string{"100", "200"}}, + keyType: "Map(String, UInt32)", + wantErr: true, + }, + { + name: "map uint type with multi-value NOT IN (error)", + filter: AdHocFilter{Key: "counts['total']", Operator: "!=|", Values: []string{"100", "200"}}, + keyType: "Map(String, UInt32)", + wantErr: true, + }, + { + name: "map uint type with equals (allowed)", + filter: AdHocFilter{Key: "counts['total']", Operator: "=", Value: "100"}, + keyType: "Map(String, UInt32)", + expected: "counts['total'] = $$100$$", + }, + { + name: "map with regex match", + filter: AdHocFilter{Key: "labels['name']", Operator: "=~", Value: "*prod*"}, + keyType: "Map(String, String)", + expected: "toString(labels['name']) LIKE $$%prod%$$", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := buildFilterCondition(tt.filter, tt.keyType) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot apply") + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestEscapeWildcard(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "no wildcards", + input: "pattern", + expected: "pattern", + }, + { + name: "single wildcard", + input: "patt*ern", + expected: "patt%ern", + }, + { + name: "multiple wildcards", + input: "*pattern*", + expected: "%pattern%", + }, + { + name: "only wildcards", + input: "***", + expected: "%%%", + }, + { + name: "escaped wildcard", + input: "foo\\*bar", + expected: "foo*bar", + }, + { + name: "escaped and unescaped wildcard", + input: "a*b\\*c*d", + expected: "a%b*c%d", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := escapeWildcard(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestAdHocFilterMacroWithExplicitTable tests the AdHocFilterMacro with an explicit table parameter +func TestAdHocFilterMacroWithExplicitTable(t *testing.T) { + type test struct { + name string + input string + output string + filters []AdHocFilter + table string + } + + tests := []test{ + { + input: "select * from bar where $__adHocFilter(foo)", + output: "select * from bar where column = $$test$$", + filters: []AdHocFilter{{Key: "column", Operator: "=", Value: "test"}}, + table: "foo", + name: "explicit table parameter", + }, + { + input: "select * from bar where $__adHocFilter(baz)", + output: "select * from bar where column = $$test$$ AND column2 != $$value2$$", + filters: []AdHocFilter{ + {Key: "column", Operator: "=", Value: "test"}, + {Key: "column2", Operator: "!=", Value: "value2"}, + }, + table: "baz", + name: "explicit table with multiple filters", + }, + { + input: "select * from bar where $__adHocFilter(myTable)", + output: "select * from bar where 1=1", + filters: []AdHocFilter{}, + table: "myTable", + name: "explicit table with no filters", + }, + { + input: "select * from bar where $__adHocFilter(foo)", + output: "select * from bar where 1=1", + filters: []AdHocFilter{{Key: "nonexistent", Operator: "=", Value: "test"}}, + table: "foo", + name: "explicit table with filter on non-existent column", + }, + } + + for i, tc := range tests { + db, mock, _ := sqlmock.New() + + rows := sqlmock.NewRows([]string{"name", "type"}). + AddRow("column", "Nullable(String)"). + AddRow("column2", "UInt64") + mock.ExpectQuery(fmt.Sprintf(AD_HOC_KEY_QUERY, tc.table)). + WillReturnRows(rows) + + interpolator := NewInterpolator(&HydrolixDatasource{ + Connector: &MockConnector{ + db: db, + uid: "uid-123", + }, + rowLimit: defaultRowLimit, + }) + + t.Run(fmt.Sprintf("[%d/%d] %s", i+1, len(tests), tc.name), func(t *testing.T) { + query := &HDXQuery{ + RawSQL: tc.input, + Filters: tc.filters, + } + interpolatedQuery, err := interpolator.Interpolate(context.Background(), query) + require.Nil(t, err) + assert.Equal(t, tc.output, interpolatedQuery) + }) + } +} + +// TestAdHocFilterMacroWithTooManyParams tests that AdHocFilterMacro returns an error with too many parameters +func TestAdHocFilterMacroWithTooManyParams(t *testing.T) { + db, _, _ := sqlmock.New() + + interpolator := NewInterpolator(&HydrolixDatasource{ + Connector: &MockConnector{ + db: db, + uid: "uid-123", + }, + rowLimit: defaultRowLimit, + }) + + query := &HDXQuery{ + RawSQL: "select * from foo where $__adHocFilter(table1, table2)", + Filters: []AdHocFilter{{Key: "column", Operator: "=", Value: "test"}}, + } + + _, err := interpolator.Interpolate(context.Background(), query) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected 0 or 1 argument, received 2") +} diff --git a/metadata.go b/metadata.go new file mode 100644 index 0000000..8623595 --- /dev/null +++ b/metadata.go @@ -0,0 +1,216 @@ +package sqlds + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/log" + "github.com/grafana/grafana-plugin-sdk-go/data" + "github.com/hydrolix/sqlds/v5/models" + "github.com/jellydator/ttlcache/v3" + "net/http" + "strings" + "time" +) + +var ( + PRIMARY_KEY_QUERY_STRING = "SELECT primary_key FROM system.tables WHERE database='%s' AND table ='%s'" + AD_HOC_KEY_QUERY = "DESCRIBE %s" + PRIMARY_KEY_NOT_FOUND_ERROR = backend.PluginError(errors.New("primary key not found")) + KEYS_NOT_FOUND_ERROR = backend.PluginError(errors.New("adHocFilter keys not found")) +) + +type MetaDataProvider struct { + ds *HydrolixDatasource + pkCache *ttlcache.Cache[string, string] + keyCache *ttlcache.Cache[string, map[string]string] +} + +func NewMetaDataProvider(ds *HydrolixDatasource) *MetaDataProvider { + pkCache := ttlcache.New[string, string](ttlcache.WithTTL[string, string](time.Hour)) + keyCache := ttlcache.New[string, map[string]string](ttlcache.WithTTL[string, map[string]string](time.Hour)) + return &MetaDataProvider{ds: ds, pkCache: pkCache, keyCache: keyCache} +} + +func (p *MetaDataProvider) GetPK(context context.Context, headers http.Header, database string, table string) (string, error) { + + if database == "" { + defaultDB, err := p.getDefaultDatabase(context) + if err != nil { + return "", err + } + database = defaultDB + } + + cacheKey := fmt.Sprintf("%s_%s", database, table) + + entry := p.pkCache.Get(cacheKey) + if entry == nil { + log.DefaultLogger.Debug("Cache miss", "key", cacheKey) + pk, err := p.QueryPK(context, headers, database, table) + if err != nil { + return "", err + } + p.pkCache.Set(cacheKey, pk, ttlcache.DefaultTTL) + + return pk, nil + } else { + log.DefaultLogger.Debug("Cache hit", "key", cacheKey) + return entry.Value(), nil + } + +} + +func (p *MetaDataProvider) GetKeys(context context.Context, headers http.Header, cte string) (map[string]string, error) { + cacheKey := cte + + entry := p.keyCache.Get(cacheKey) + if entry == nil { + log.DefaultLogger.Debug("Cache miss", "key", cacheKey) + keys, err := p.QueryKeys(context, headers, cte) + if err != nil { + return nil, err + } + p.keyCache.Set(cacheKey, keys, ttlcache.DefaultTTL) + + return keys, nil + } else { + log.DefaultLogger.Debug("Cache hit", "key", cacheKey) + return entry.Value(), nil + } +} + +func (p *MetaDataProvider) getDefaultDatabase(context context.Context) (string, error) { + settings, err := models.NewPluginSettings(context, p.ds.Connector.getInstanceSettings()) + if err != nil { + return "", err + } + return settings.DefaultDatabase, nil +} + +// executeQuery executes a SQL query using the QueryData method and returns the resulting frame +func (p *MetaDataProvider) executeQuery(ctx context.Context, headers http.Header, sql string, queryID string) (*data.Frame, error) { + // Create a query using QueryData method + queryJSON, err := json.Marshal(map[string]interface{}{ + "rawSql": sql, + "format": 1, + }) + if err != nil { + return nil, err + } + + newHeaders := make(map[string]string, len(headers)) + for k, _ := range headers { + newHeaders[k] = headers.Get(k) + } + + dataQuery := backend.DataQuery{ + RefID: queryID, + JSON: queryJSON, + } + + settings := p.ds.Connector.getInstanceSettings() + req := &backend.QueryDataRequest{ + PluginContext: backend.PluginContext{ + DataSourceInstanceSettings: &settings, + }, + Queries: []backend.DataQuery{dataQuery}, + Headers: newHeaders, + } + + // Execute the query using QueryData + response, err := p.ds.QueryData(ctx, req) + if err != nil { + return nil, err + } + + // Check for errors in the response + dataResponse, ok := response.Responses[dataQuery.RefID] + if !ok { + return nil, fmt.Errorf("no response for query %s", queryID) + } + if dataResponse.Error != nil { + return nil, dataResponse.Error + } + + // Extract the frame from the response + if len(dataResponse.Frames) == 0 { + return nil, fmt.Errorf("no frames in response") + } + + return dataResponse.Frames[0], nil +} + +func (p *MetaDataProvider) QueryPK(ctx context.Context, headers http.Header, database string, table string) (string, error) { + // Format the SQL query with actual parameter values + formattedSQL := fmt.Sprintf(PRIMARY_KEY_QUERY_STRING, database, table) + + frame, err := p.executeQuery(ctx, headers, formattedSQL, "pk_query") + if err != nil { + return "", err + } + + if len(frame.Fields) == 0 { + return "", PRIMARY_KEY_NOT_FOUND_ERROR + } + + field := frame.Fields[0] + if field.Len() == 0 { + return "", PRIMARY_KEY_NOT_FOUND_ERROR + } + + v, err := p.GetStringSafe(field.At(0)) + + return v, err +} + +func (p *MetaDataProvider) QueryKeys(ctx context.Context, headers http.Header, cte string) (map[string]string, error) { + if strings.Contains(strings.ToUpper(cte), "SELECT") { + cte = fmt.Sprintf("(%s)", cte) + } + formattedSQL := fmt.Sprintf(AD_HOC_KEY_QUERY, cte) + + frame, err := p.executeQuery(ctx, headers, formattedSQL, "key_query") + if err != nil { + return nil, err + } + if len(frame.Fields) < 2 { + return nil, KEYS_NOT_FOUND_ERROR + } + keyFiled := frame.Fields[0] + typeFiled := frame.Fields[1] + + keys := make(map[string]string, keyFiled.Len()) + + for i := range keyFiled.Len() { + key, err := p.GetStringSafe(keyFiled.At(i)) + if err != nil { + return nil, err + } + keyType, err := p.GetStringSafe(typeFiled.At(i)) + if err != nil { + return nil, err + } + keys[key] = keyType + + } + + return keys, err +} + +func (p *MetaDataProvider) GetStringSafe(v any) (string, error) { + + switch x := v.(type) { + case string: + return x, nil + case *string: + if x == nil { + return "", nil + } + return *x, nil + + } + return "", errors.New("invalid type") +} diff --git a/metadata_test.go b/metadata_test.go new file mode 100644 index 0000000..7318946 --- /dev/null +++ b/metadata_test.go @@ -0,0 +1,122 @@ +package sqlds + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +func TestQueryPK_Success(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock.New(): %v", err) + } + defer db.Close() + p := "id" + rows := mock.NewRows([]string{"primary_key"}).AddRow(p) + + mock.ExpectQuery(fmt.Sprintf(PRIMARY_KEY_QUERY_STRING, "db1", "tbl1")). + WillReturnRows(rows) + + ds := &HydrolixDatasource{ + Connector: &MockConnector{ + db: db, + uid: "uid-123", + }, + rowLimit: defaultRowLimit, + } + provider := &MetaDataProvider{ds: ds} + pk, err := provider.QueryPK(context.Background(), http.Header{}, "db1", "tbl1") + + if err != nil { + t.Fatalf("QueryPK returned error: %v", err) + } + if pk != "id" { + t.Fatalf("expected pk 'id', got %q", pk) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet SQL expectations: %v", err) + } +} + +func TestQueryPK_NoRows_ReturnsNotFound(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock.New(): %v", err) + } + defer db.Close() + + // Zero rows -> field.Len()==0 -> PRIMARY_KEY_NOT_FOUND_ERROR + rows := sqlmock.NewRows([]string{"primary_key"}) + mock.ExpectQuery(fmt.Sprintf(PRIMARY_KEY_QUERY_STRING, "db2", "tbl2")). + WillReturnRows(rows) + + ds := &HydrolixDatasource{ + Connector: &MockConnector{ + db: db, + uid: "uid-abc", + }, + rowLimit: defaultRowLimit, + } + provider := &MetaDataProvider{ds: ds} + + _, err = provider.QueryPK(context.Background(), http.Header{}, "db2", "tbl2") + if err == nil || err.Error() != PRIMARY_KEY_NOT_FOUND_ERROR.Error() { + t.Fatalf("expected PRIMARY_KEY_NOT_FOUND_ERROR, got %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet SQL expectations: %v", err) + } +} + +func TestGetPK_UsesCache_AvoidsSecondQuery(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock.New(): %v", err) + } + defer db.Close() + + // First call: cache miss -> DB hit + rows := sqlmock.NewRows([]string{"primary_key"}).AddRow("event_id") + mock.ExpectQuery(fmt.Sprintf(PRIMARY_KEY_QUERY_STRING, "analytics", "events")). + WillReturnRows(rows) + mc := &MockConnector{ + db: db, + uid: "uid-cache", + } + ds := &HydrolixDatasource{Connector: mc, rowLimit: defaultRowLimit} + provider := NewMetaDataProvider(ds) + + ctx := context.Background() + + // First call populates cache + pk1, err := provider.GetPK(ctx, http.Header{}, "analytics", "events") + if err != nil { + t.Fatalf("GetPK (first) error: %v", err) + } + if pk1 != "event_id" { + t.Fatalf("expected 'event_id', got %q", pk1) + } + + // Second call should be a cache hit -> no new DB call + pk2, err := provider.GetPK(ctx, http.Header{}, "analytics", "events") + if err != nil { + t.Fatalf("GetPK (second) error: %v", err) + } + if pk2 != "event_id" { + t.Fatalf("expected 'event_id' on cache hit, got %q", pk2) + } + + if mc.connCalls != 1 { + t.Fatalf("expected exactly 1 getDBConnection call, got %d", mc.connCalls) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet SQL expectations: %v", err) + } +} diff --git a/mock/csv/csv_mock.go b/mock/csv/csv_mock.go index 9e2d339..2080002 100644 --- a/mock/csv/csv_mock.go +++ b/mock/csv/csv_mock.go @@ -13,7 +13,7 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/data" "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" - "github.com/grafana/sqlds/v5" + "github.com/hydrolix/sqlds/v5" _ "github.com/mithrandie/csvq-driver" ) @@ -89,8 +89,3 @@ func (h *SQLCSVMock) Connect(_ context.Context, _ backend.DataSourceInstanceSett func (h *SQLCSVMock) Converters() []sqlutil.Converter { return []sqlutil.Converter{} } - -// Macros returns list of macro functions convert the macros of raw query -func (h *SQLCSVMock) Macros() sqlds.Macros { - return sqlds.Macros{} -} diff --git a/models/settings.go b/models/settings.go new file mode 100644 index 0000000..f0812cd --- /dev/null +++ b/models/settings.go @@ -0,0 +1,211 @@ +// Package models provides Hydrolix plugin's configuration settings +package models + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/grafana/grafana-plugin-sdk-go/backend" + "math" + "reflect" + "slices" + "strconv" + "strings" +) + +// QuerySettings validation errors +var ( + ErrorMessageInvalidJSON = errors.New("invalid settings json") + ErrorMessageInvalidHost = errors.New("Server address is missing") + ErrorMessageInvalidPort = errors.New("Server port is missing") + ErrorMessageInvalidProtocol = errors.New("Protocol should be either native or http") + ErrorMessageInvalidQueryTimeout = errors.New("Invalid Query Timeout") + ErrorMessageInvalidDialTimeout = errors.New("Invalid Connect Timeout") +) + +// PluginSettings structure represent data source configuration options +type PluginSettings struct { + Host string `json:"host"` + UserName string `json:"username"` + Port uint16 `json:"port"` + Protocol string `json:"protocol"` + Password string `json:"-"` + Token string `json:"-"` + CredentialsType string `json:"credentialsType"` + Secure bool `json:"secure"` + Path string `json:"path,omitempty"` + SkipTlsVerify bool `json:"skipTlsVerify,omitempty"` + DialTimeout string `json:"dialTimeout,omitempty"` + QueryTimeout string `json:"queryTimeout,omitempty"` + DefaultDatabase string `json:"defaultDatabase,omitempty"` + QuerySettings []QuerySetting `json:"querySettings,omitempty"` + Other map[string]any `json:"-"` +} + +type QuerySetting struct { + Setting string `json:"setting"` + Value string `json:"value"` +} + +// IsValid validates configuration data correctness +func (settings *PluginSettings) IsValid() error { + if settings.Host == "" { + return backend.DownstreamError(ErrorMessageInvalidHost) + } + if settings.Port == 0 { + return backend.DownstreamError(ErrorMessageInvalidPort) + } + + if !slices.Contains([]string{"http", "native"}, settings.Protocol) { + return backend.DownstreamError(ErrorMessageInvalidProtocol) + } + + if _, err := strconv.Atoi(settings.DialTimeout); err != nil { + return backend.DownstreamError(ErrorMessageInvalidDialTimeout) + } + + if _, err := strconv.Atoi(settings.QueryTimeout); err != nil { + return backend.DownstreamError(ErrorMessageInvalidQueryTimeout) + } + + return nil +} + +// SetDefaults applies default values to not defined options +func (settings *PluginSettings) SetDefaults() { + if strings.TrimSpace(settings.DialTimeout) == "" { + settings.DialTimeout = "10" + } + if strings.TrimSpace(settings.QueryTimeout) == "" { + settings.QueryTimeout = "60" + } +} + +// NewPluginSettings initializes PluginSettings with data provided by Grafana +func NewPluginSettings(_ context.Context, source backend.DataSourceInstanceSettings) (settings PluginSettings, e error) { + var jsonData map[string]interface{} + if err := json.Unmarshal(source.JSONData, &jsonData); err != nil { + return settings, fmt.Errorf("%s: %w", err.Error(), ErrorMessageInvalidJSON) + } + + if jsonData["host"] != nil { + settings.Host = jsonData["host"].(string) + } + + if jsonData["port"] != nil { + port, err := parseUint(jsonData["port"]) + if err != nil { + return settings, err + } + settings.Port = port + } + + if jsonData["protocol"] != nil { + settings.Protocol = jsonData["protocol"].(string) + } + + if jsonData["credentialsType"] != nil { + settings.CredentialsType = jsonData["credentialsType"].(string) + } + + if jsonData["secure"] != nil { + secure, err := parseBool(jsonData["secure"]) + if err != nil { + return settings, err + } + settings.Secure = secure + } + + if jsonData["path"] != nil { + settings.Path = jsonData["path"].(string) + } + + if jsonData["username"] != nil { + settings.UserName = jsonData["username"].(string) + } + + if jsonData["defaultDatabase"] != nil { + settings.DefaultDatabase = jsonData["defaultDatabase"].(string) + } + + if jsonData["dialTimeout"] != nil { + settings.DialTimeout = jsonData["dialTimeout"].(string) + } + + if jsonData["queryTimeout"] != nil { + settings.QueryTimeout = jsonData["queryTimeout"].(string) + } + + if jsonData["skipTlsVerify"] != nil { + skipTlsVerify, err := parseBool(jsonData["skipTlsVerify"]) + if err != nil { + return settings, err + } + settings.SkipTlsVerify = skipTlsVerify + } + + if jsonData["querySettings"] != nil { + settings.QuerySettings = []QuerySetting{} + rv := reflect.ValueOf(jsonData["querySettings"]) + if rv.Kind() == reflect.Slice { + for i := 0; i < rv.Len(); i++ { + qs := rv.Index(i).Interface().(map[string]interface{}) + settings.QuerySettings = append(settings.QuerySettings, QuerySetting{Value: qs["value"].(string), Setting: qs["setting"].(string)}) + } + } + } + + if password, ok := source.DecryptedSecureJSONData["password"]; ok { + settings.Password = password + } + + if token, ok := source.DecryptedSecureJSONData["token"]; ok { + settings.Token = token + } + + settings.SetDefaults() + + return settings, settings.IsValid() +} + +// parseBool parses boolean value +func parseBool(in any) (bool, error) { + switch v := in.(type) { + case bool: + return v, nil + case string: + return strconv.ParseBool(v) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return strconv.ParseBool(fmt.Sprintf("%d", v)) + case float32, float64: + v64, _ := strconv.ParseFloat(fmt.Sprintf("%f", v), 64) + if math.Trunc(v64) == v64 { + return strconv.ParseBool(strconv.FormatFloat(v64, 'f', -1, 64)) + } + return false, backend.DownstreamError(fmt.Errorf("could not parse bool value: %s", in)) + default: + return false, backend.DownstreamError(fmt.Errorf("could not parse bool value: %s", in)) + } +} + +// parseUint parses unsigned integer value +func parseUint(in any) (uint16, error) { + switch v := in.(type) { + case string: + port, err := strconv.ParseUint(v, 10, 16) + return uint16(port), err + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + v64, err := strconv.ParseUint(fmt.Sprintf("%d", v), 10, 16) + return uint16(v64), err + case float32, float64: + v64, _ := strconv.ParseFloat(fmt.Sprintf("%f", v), 64) + if math.Trunc(v64) == v64 { + return uint16(v64), nil + } + return 0, backend.DownstreamError(fmt.Errorf("could not parse bool value: %s", in)) + default: + return 0, backend.DownstreamError(fmt.Errorf("could not parse uint value: %s", in)) + } + +} diff --git a/models/settings_test.go b/models/settings_test.go new file mode 100644 index 0000000..f7e4196 --- /dev/null +++ b/models/settings_test.go @@ -0,0 +1,207 @@ +package models + +import ( + "context" + "encoding/json" + "fmt" + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/stretchr/testify/assert" + "reflect" + "strings" + "testing" +) + +func TestPluginSettings(t *testing.T) { + t.Run("parse grafana ds plugin settings", func(t *testing.T) { + settings := PluginSettings{ + Host: "localhost", + Port: 80, + Protocol: "native", + UserName: "default", + Password: "pass", + Secure: true, + Path: "/query", + SkipTlsVerify: true, + DialTimeout: "10", + QueryTimeout: "20", + DefaultDatabase: "dbdb", + Other: nil, + } + jsonData, err := json.Marshal(settings) + if err != nil { + t.Fatal(err) + } + + dsSettings := backend.DataSourceInstanceSettings{ + Name: "test-hydrolix-http-datasource", + JSONData: jsonData, + DecryptedSecureJSONData: map[string]string{"password": settings.Password}, + } + newSettings, err := NewPluginSettings(context.Background(), dsSettings) + assert.NoError(t, err) + assert.Equal(t, settings, newSettings) + + }) + + t.Run("parse ds plugin settings various types", func(t *testing.T) { + settings := PluginSettings{ + Host: "localhost", + Port: 80, + Protocol: "native", + UserName: "default", + Password: "pass", + Secure: true, + Path: "/query", + SkipTlsVerify: true, + DialTimeout: "10", + QueryTimeout: "20", + DefaultDatabase: "dbdb", + Other: nil, + } + originalSettings, err := json.Marshal(settings) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + val any + res any + }{ + {"secure", "true", true}, + {"secure", "True", true}, + {"secure", "1", true}, + {"secure", 1, true}, + {"secure", uint16(1), true}, + {"secure", int32(1), true}, + {"secure", int64(1), true}, + {"secure", float32(1), true}, + {"secure", float64(1), true}, + {"secure", "false", false}, + {"secure", "False", false}, + {"secure", 0, false}, + {"secure", int32(0), false}, + {"secure", int64(0), false}, + {"secure", float32(0), false}, + {"secure", float64(0), false}, + {"skipTlsVerify", "true", true}, + {"skipTlsVerify", "True", true}, + {"skipTlsVerify", "1", true}, + {"skipTlsVerify", 1, true}, + {"skipTlsVerify", uint16(1), true}, + {"skipTlsVerify", int64(1), true}, + {"skipTlsVerify", float32(1), true}, + {"skipTlsVerify", float64(1), true}, + {"skipTlsVerify", "false", false}, + {"skipTlsVerify", "False", false}, + {"skipTlsVerify", 0, false}, + {"skipTlsVerify", uint16(0), false}, + {"skipTlsVerify", int64(0), false}, + {"skipTlsVerify", int16(0), false}, + {"skipTlsVerify", float64(0), false}, + {"skipTlsVerify", float32(0), false}, + {"port", uint16(80), uint16(80)}, + {"port", int32(80), uint16(80)}, + {"port", int64(80), uint16(80)}, + {"port", float64(80), uint16(80)}, + {"port", float32(80), uint16(80)}, + {"port", "80", uint16(80)}, + } + for _, test := range tests { + t.Run(fmt.Sprintf("%s-%v-%t", test.name, test.val, test.val), func(t *testing.T) { + var rawJson map[string]any + err = json.Unmarshal(originalSettings, &rawJson) + rawJson[test.name] = test.val + jsonSettigns, _ := json.Marshal(rawJson) + dsSettings := backend.DataSourceInstanceSettings{JSONData: jsonSettigns} + + newSettings, err := NewPluginSettings(context.Background(), dsSettings) + assert.NoError(t, err) + + fieldName := strings.Title(test.name) + assert.Equal(t, test.res, getField(newSettings, fieldName)) + + switch test.res.(type) { + case bool: + v, err := parseBool(test.val) + assert.NoError(t, err) + assert.Equal(t, test.res, v) + default: + v, err := parseUint(test.val) + assert.NoError(t, err) + assert.Equal(t, test.res, v) + } + }) + } + + }) + t.Run("parse invalid grafana ds plugin settings", func(t *testing.T) { + dsSettings := backend.DataSourceInstanceSettings{ + Name: "test-hydrolix-http-datasource", + JSONData: []byte("invalid"), + } + _, err := NewPluginSettings(context.Background(), dsSettings) + assert.Error(t, err, "invalid json should return an error") + + }) + t.Run("validate mandatory plugin settings", func(t *testing.T) { + settings := PluginSettings{ + Host: "localhost", + Port: 80, + Protocol: "native", + UserName: "default", + Password: "pass", + Secure: true, + Path: "/query", + SkipTlsVerify: true, + DialTimeout: "10", + QueryTimeout: "20", + DefaultDatabase: "dbdb", + Other: nil, + } + assert.NoError(t, settings.IsValid(), "plugin settings should be valid") + + errSettings := settings + errSettings.Host = "" + assert.Error(t, errSettings.IsValid(), ErrorMessageInvalidHost) + + errSettings = settings + errSettings.Port = 0 + assert.Error(t, errSettings.IsValid(), ErrorMessageInvalidPort) + + errSettings = settings + errSettings.Protocol = "" + assert.Error(t, errSettings.IsValid(), ErrorMessageInvalidProtocol) + errSettings = settings + errSettings.Protocol = "https" + assert.Error(t, errSettings.IsValid(), ErrorMessageInvalidProtocol) + errSettings = settings + errSettings.Protocol = "native " + assert.Error(t, errSettings.IsValid(), ErrorMessageInvalidProtocol) + + errSettings = settings + errSettings.DialTimeout = "" + assert.Error(t, errSettings.IsValid(), "property should be validated") + errSettings.SetDefaults() + assert.NoError(t, errSettings.IsValid(), "plugin settings should be valid") + errSettings = settings + errSettings.DialTimeout = "a" + assert.Error(t, errSettings.IsValid(), ErrorMessageInvalidDialTimeout) + + errSettings = settings + errSettings.QueryTimeout = "" + assert.Error(t, errSettings.IsValid(), "property should be validated") + errSettings.SetDefaults() + assert.NoError(t, errSettings.IsValid(), "plugin settings should be valid") + errSettings = settings + errSettings.QueryTimeout = "b" + assert.Error(t, errSettings.IsValid(), ErrorMessageInvalidQueryTimeout) + + }) +} + +func getField(v any, field string) any { + r := reflect.ValueOf(v) + f := reflect.Indirect(r).FieldByName(field) + return f.Interface() +} diff --git a/query_integration_test.go b/query_integration_test.go deleted file mode 100644 index 2bf6d79..0000000 --- a/query_integration_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package sqlds - -import ( - "context" - "database/sql" - "errors" - "os" - "strings" - "testing" - "time" - - "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" - - _ "github.com/go-sql-driver/mysql" -) - -type testArgs struct { - MySQLURL string - RunIntegrationTests bool -} - -func testEnvArgs(t *testing.T) testArgs { - t.Helper() - var args testArgs - if val, ok := os.LookupEnv("MYSQL_URL"); ok { - args.MySQLURL = val - } else { - args.MySQLURL = "mysql:mysql@/mysql" - } - - if _, ok := os.LookupEnv("INTEGRATION_TESTS"); ok { - args.RunIntegrationTests = true - } - - return args -} - -func TestQuery_MySQL(t *testing.T) { - var ( - args = testEnvArgs(t) - ctx = context.Background() - - db *sql.DB - ) - - if !args.RunIntegrationTests { - t.SkipNow() - } - - ticker := time.NewTicker(time.Second * 5) - defer ticker.Stop() - - // Attempt to connect multiple times because these tests are ran in Drone, where the mysql server may not be immediately available when this test is ran. - limit := 10 - for i := 0; i < limit; i++ { - t.Log("Attempting mysql connection...") - d, err := sql.Open("mysql", args.MySQLURL) - if err == nil { - if err := d.Ping(); err == nil { - db = d - break - } - } - - <-ticker.C - } - defer db.Close() - - t.Run("The query should return a context.Canceled if it exceeds the timeout", func(t *testing.T) { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - - q := &Query{ - RawSQL: "SELECT SLEEP(5)", - } - - settings := backend.DataSourceInstanceSettings{ - Name: "foo", - } - - sqlQuery := NewQuery(db, settings, []sqlutil.Converter{}, nil, defaultRowLimit) - _, err := sqlQuery.Run(ctx, q, nil) - if err == nil { - t.Fatal("expected an error but received none") - } - if !(errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "context deadline exceeded")) { - t.Fatal("expected a context.Canceled error but received:", err) - } - }) -} diff --git a/test/driver.go b/test/driver.go index e065f44..0651b99 100644 --- a/test/driver.go +++ b/test/driver.go @@ -10,16 +10,17 @@ import ( "reflect" "time" + "github.com/hydrolix/sqlds/v5" + "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" - "github.com/grafana/sqlds/v5" - "github.com/grafana/sqlds/v5/mock" + "github.com/hydrolix/sqlds/v5/mock" ) var registered = map[string]*SqlHandler{} // NewDriver creates and registers a new test datasource driver -func NewDriver(name string, dbdata Data, converters []sqlutil.Converter, opts DriverOpts, macros sqlds.Macros) (TestDS, *SqlHandler) { +func NewDriver(name string, dbdata Data, converters []sqlutil.Converter, opts DriverOpts) (TestDS, *SqlHandler) { if registered[name] == nil { handler := NewDriverHandler(dbdata, opts) registered[name] = &handler @@ -34,16 +35,14 @@ func NewDriver(name string, dbdata Data, converters []sqlutil.Converter, opts Dr return sql.Open(name, "") }, converters, - macros, ), registered[name] } // NewTestDS creates a new test datasource driver -func NewTestDS(openDBfn func(msg json.RawMessage) (*sql.DB, error), converters []sqlutil.Converter, macros sqlds.Macros) TestDS { +func NewTestDS(openDBfn func(msg json.RawMessage) (*sql.DB, error), converters []sqlutil.Converter) TestDS { return TestDS{ openDBfn: openDBfn, converters: converters, - macros: macros, } } @@ -148,7 +147,6 @@ type Column struct { type TestDS struct { openDBfn func(msg json.RawMessage) (*sql.DB, error) converters []sqlutil.Converter - macros sqlds.Macros sqlds.Driver } @@ -172,11 +170,6 @@ func (s TestDS) Settings(ctx context.Context, config backend.DataSourceInstanceS return settings } -// Macros - Macros for the test database -func (s TestDS) Macros() sqlds.Macros { - return s.macros -} - // Converters - Converters for the test database func (s TestDS) Converters() []sqlutil.Converter { return nil From 62e0df7a57485fe6613a27a3ae933cc1dcf0698e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 16 Apr 2026 14:52:26 +0000 Subject: [PATCH 2/2] Bump go.opentelemetry.io/otel/sdk from 1.40.0 to 1.43.0 Bumps [go.opentelemetry.io/otel/sdk](https://github.com/open-telemetry/opentelemetry-go) from 1.40.0 to 1.43.0. - [Release notes](https://github.com/open-telemetry/opentelemetry-go/releases) - [Changelog](https://github.com/open-telemetry/opentelemetry-go/blob/main/CHANGELOG.md) - [Commits](https://github.com/open-telemetry/opentelemetry-go/compare/v1.40.0...v1.43.0) --- updated-dependencies: - dependency-name: go.opentelemetry.io/otel/sdk dependency-version: 1.43.0 dependency-type: indirect ... Signed-off-by: dependabot[bot] --- go.mod | 12 ++++++------ go.sum | 24 ++++++++++++------------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index a2d2381..49b27c0 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.25.7 require ( github.com/DATA-DOG/go-sqlmock v1.5.2 - github.com/google/go-cmp v0.7.0 github.com/grafana/dataplane/sdata v0.0.9 github.com/grafana/grafana-plugin-sdk-go v0.290.1 github.com/hydrolix/clickhouse-sql-parser v0.3.0 @@ -27,12 +26,12 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.65.0 // indirect go.opentelemetry.io/contrib/propagators/jaeger v1.40.0 // indirect go.opentelemetry.io/contrib/samplers/jaegerremote v0.34.0 // indirect - go.opentelemetry.io/otel v1.41.0 // indirect + go.opentelemetry.io/otel v1.43.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 // indirect - go.opentelemetry.io/otel/metric v1.41.0 // indirect - go.opentelemetry.io/otel/sdk v1.40.0 // indirect - go.opentelemetry.io/otel/trace v1.41.0 // indirect + go.opentelemetry.io/otel/metric v1.43.0 // indirect + go.opentelemetry.io/otel/sdk v1.43.0 // indirect + go.opentelemetry.io/otel/trace v1.43.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect golang.org/x/mod v0.32.0 // indirect golang.org/x/tools v0.41.0 // indirect @@ -54,6 +53,7 @@ require ( github.com/gogo/googleapis v1.4.1 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/flatbuffers v25.12.19+incompatible // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/grafana/otel-profiling-go v0.5.1 // indirect github.com/grafana/pyroscope-go/godeltaprof v0.1.9 // indirect github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.1.0 // indirect @@ -94,7 +94,7 @@ require ( golang.org/x/exp v0.0.0-20251002181428-27f1f14c8bb9 // indirect golang.org/x/net v0.51.0 // indirect golang.org/x/sync v0.19.0 // indirect - golang.org/x/sys v0.41.0 // indirect + golang.org/x/sys v0.42.0 // indirect golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2 // indirect golang.org/x/term v0.40.0 // indirect golang.org/x/text v0.34.0 // indirect diff --git a/go.sum b/go.sum index 1c3821e..8bf2fcd 100644 --- a/go.sum +++ b/go.sum @@ -199,23 +199,23 @@ go.opentelemetry.io/contrib/propagators/jaeger v1.40.0/go.mod h1:ioMePqe6k6c/ovX go.opentelemetry.io/contrib/samplers/jaegerremote v0.34.0 h1:RZjNfF9OoR4oPLEWaP+Memql2MNVkZvnwjB2N5tR3cA= go.opentelemetry.io/contrib/samplers/jaegerremote v0.34.0/go.mod h1:b5U9IcSnv+lMvEcSOXZB61kXSf0KkwickleKWuAQclw= go.opentelemetry.io/otel v1.21.0/go.mod h1:QZzNPQPm1zLX4gZK4cMi+71eaorMSGT3A4znnUvNNEo= -go.opentelemetry.io/otel v1.41.0 h1:YlEwVsGAlCvczDILpUXpIpPSL/VPugt7zHThEMLce1c= -go.opentelemetry.io/otel v1.41.0/go.mod h1:Yt4UwgEKeT05QbLwbyHXEwhnjxNO6D8L5PQP51/46dE= +go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDOPmSWQHWywQS6lKL+pb8s3gBLOZUtw4N+mavW1I= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0/go.mod h1:EtekO9DEJb4/jRyN4v4Qjc2yA7AtfCBuz2FynRUWTXs= go.opentelemetry.io/otel/metric v1.21.0/go.mod h1:o1p3CA8nNHW8j5yuQLdc1eeqEaPfzug24uvsyIEJRWM= -go.opentelemetry.io/otel/metric v1.41.0 h1:rFnDcs4gRzBcsO9tS8LCpgR0dxg4aaxWlJxCno7JlTQ= -go.opentelemetry.io/otel/metric v1.41.0/go.mod h1:xPvCwd9pU0VN8tPZYzDZV/BMj9CM9vs00GuBjeKhJps= +go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E= -go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8= -go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE= -go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw= -go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg= +go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= +go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= go.opentelemetry.io/otel/trace v1.21.0/go.mod h1:LGbsEB0f9LGjN+OZaQQ26sohbOmiMR+BaslueVtS/qQ= -go.opentelemetry.io/otel/trace v1.41.0 h1:Vbk2co6bhj8L59ZJ6/xFTskY+tGAbOnCtQGVVa9TIN0= -go.opentelemetry.io/otel/trace v1.41.0/go.mod h1:U1NU4ULCoxeDKc09yCWdWe+3QoyweJcISEVa1RBzOis= +go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -254,8 +254,8 @@ golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2 h1:O1cMQHRfwNpDfDJerqRoE2oD+AFlyid87D40L/OkkJo= golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2/go.mod h1:b7fPSJ0pKZ3ccUh8gnTONJxhn3c/PS6tyzQvyqw4iA8= golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=