diff --git a/README.md b/README.md index 12789d8..5eaf362 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,28 @@ Every command is a method on the loaded `*goway.Flyway` value and takes a | `Repair` | Removes failed entries and realigns recorded checksums. | | `Clean` | Drops every object in the managed schemas. Disabled by default. | +### Callbacks + +Register lifecycle callbacks either as SQL scripts placed in the configured +locations (`beforeMigrate.sql`, `afterMigrate.sql`, `beforeEachMigrate.sql`, +`afterEachMigrate__description.sql`) or programmatically through +`Configure().Callbacks(...)` with a value implementing `Callback`, or the +`CallbackFunc` adapter. + +### Non-transactional migrations + +A migration that cannot run inside a transaction, such as one using PostgreSQL's +`CREATE INDEX CONCURRENTLY` or SQLite's `VACUUM`, opts out of the per-migration +transaction with a directive on the first lines of the script: + +```sql +-- goway:noTransaction +CREATE INDEX CONCURRENTLY idx_users_email ON users (email); +``` + +The statements then run directly on a dedicated connection and the history row +is still recorded. + ## Command line tool A command line front end lives in the `cmd/goway` module. @@ -171,12 +193,12 @@ go -C integration test ./... -count=1 Implemented: versioned and repeatable SQL migrations, the schema history table, migrate, info, validate, baseline, repair and clean, placeholder replacement, -multiple locations, embedded file systems, schema creation, and a command line -tool. +multiple locations, embedded file systems, schema creation, lifecycle callbacks +(SQL scripts and programmatic), per-script non-transactional execution, the +superseded state for repeatable migrations, and a command line tool. -Not yet implemented: Java style code based migrations, lifecycle callbacks, undo -migrations, and listing every historical run of a repeatable migration (only the -latest run is reported). +Not yet implemented: Go code based migrations, undo migrations, and the grouped +and mixed transaction modes. ## Acknowledgements and License diff --git a/callback.go b/callback.go new file mode 100644 index 0000000..059eb31 --- /dev/null +++ b/callback.go @@ -0,0 +1,83 @@ +package goway + +import ( + "context" + "database/sql" + "strings" +) + +// CallbackEvent identifies a point in the migrate lifecycle at which callbacks +// are invoked. The values match the corresponding Flyway event names. +type CallbackEvent string + +const ( + // EventBeforeMigrate fires once before any migration is applied. + EventBeforeMigrate CallbackEvent = "beforeMigrate" + + // EventAfterMigrate fires once after all migrations have been applied. + EventAfterMigrate CallbackEvent = "afterMigrate" + + // EventBeforeEachMigrate fires before each individual migration, inside the + // migration's transaction when one is used. + EventBeforeEachMigrate CallbackEvent = "beforeEachMigrate" + + // EventAfterEachMigrate fires after each individual migration, inside the + // migration's transaction when one is used. + EventAfterEachMigrate CallbackEvent = "afterEachMigrate" +) + +// Execer is the minimal interface needed to run a statement. It is satisfied by +// *sql.DB, *sql.Tx and *sql.Conn, so a callback can execute SQL on whichever +// handle is active for the current event. +type Execer interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) +} + +// Callback receives lifecycle events during a migrate run. For the per-migration +// events the migration argument describes the migration being processed; it is +// nil for the run level events. +type Callback interface { + Handle(ctx context.Context, event CallbackEvent, exec Execer, migration *MigrationInfo) error +} + +// CallbackFunc adapts an ordinary function to the Callback interface. +type CallbackFunc func(ctx context.Context, event CallbackEvent, exec Execer, migration *MigrationInfo) error + +// Handle calls the underlying function. +func (f CallbackFunc) Handle(ctx context.Context, event CallbackEvent, exec Execer, migration *MigrationInfo) error { + return f(ctx, event, exec, migration) +} + +// callbackEventsByName maps the lower cased event name to its canonical value. +var callbackEventsByName = map[string]CallbackEvent{ + "beforemigrate": EventBeforeMigrate, + "aftermigrate": EventAfterMigrate, + "beforeeachmigrate": EventBeforeEachMigrate, + "aftereachmigrate": EventAfterEachMigrate, +} + +// sqlCallback is a callback backed by a SQL script discovered in the configured +// locations. +type sqlCallback struct { + event CallbackEvent + script string + read func() ([]byte, error) +} + +// parseCallbackName reports whether a script file name denotes a callback and, +// if so, which event it handles. The name without its suffix must equal an event +// name, optionally followed by the separator and a description, for example +// "afterEachMigrate__seed.sql". Matching is case insensitive. +func parseCallbackName(fileName, separator string, suffixes []string) (CallbackEvent, bool) { + suffix := matchSuffix(fileName, suffixes) + if suffix == "" { + return "", false + } + stem := fileName[:len(fileName)-len(suffix)] + name := stem + if index := strings.Index(stem, separator); index >= 0 { + name = stem[:index] + } + event, ok := callbackEventsByName[strings.ToLower(name)] + return event, ok +} diff --git a/callback_test.go b/callback_test.go new file mode 100644 index 0000000..093b616 --- /dev/null +++ b/callback_test.go @@ -0,0 +1,52 @@ +package goway + +import ( + "context" + "testing" +) + +func TestParseCallbackName(t *testing.T) { + suffixes := []string{".sql"} + recognized := map[string]CallbackEvent{ + "beforeMigrate.sql": EventBeforeMigrate, + "afterMigrate.sql": EventAfterMigrate, + "beforeEachMigrate.sql": EventBeforeEachMigrate, + "afterEachMigrate__seed.sql": EventAfterEachMigrate, + "AFTERMIGRATE.sql": EventAfterMigrate, + } + for name, want := range recognized { + got, ok := parseCallbackName(name, "__", suffixes) + if !ok || got != want { + t.Errorf("parseCallbackName(%q) = (%q, %v), want (%q, true)", name, got, ok, want) + } + } + + rejected := []string{ + "V1__create.sql", // versioned migration + "R__view.sql", // repeatable migration + "random.sql", // unknown stem + "beforeMigrate.txt", // wrong suffix + "afterMigrateNow.sql", // stem is not exactly an event name + } + for _, name := range rejected { + if _, ok := parseCallbackName(name, "__", suffixes); ok { + t.Errorf("parseCallbackName(%q) was recognized, want rejected", name) + } + } +} + +// TestCallbackFuncImplementsCallback verifies the function adapter satisfies the +// Callback interface and forwards its arguments. +func TestCallbackFuncImplementsCallback(t *testing.T) { + var seen CallbackEvent + var callback Callback = CallbackFunc(func(_ context.Context, event CallbackEvent, _ Execer, _ *MigrationInfo) error { + seen = event + return nil + }) + if err := callback.Handle(context.Background(), EventBeforeMigrate, nil, nil); err != nil { + t.Fatalf("Handle returned error: %v", err) + } + if seen != EventBeforeMigrate { + t.Errorf("callback saw event %q, want %q", seen, EventBeforeMigrate) + } +} diff --git a/config.go b/config.go index 3c5b2d2..0b983d1 100644 --- a/config.go +++ b/config.go @@ -49,6 +49,8 @@ type Configuration struct { target *Version installedBy string + callbacks []Callback + configErr error } @@ -228,6 +230,13 @@ func (c *Configuration) Target(version string) *Configuration { return c } +// Callbacks registers programmatic callbacks invoked during a migrate run, in +// addition to any SQL callback scripts found in the configured locations. +func (c *Configuration) Callbacks(callbacks ...Callback) *Configuration { + c.callbacks = append(c.callbacks, callbacks...) + return c +} + // InstalledBy overrides the user recorded for applied migrations. func (c *Configuration) InstalledBy(user string) *Configuration { c.installedBy = user @@ -275,7 +284,7 @@ func (c *Configuration) LoadContext(ctx context.Context) (*Migrator, error) { dialect = detected } - resolved, err := resolveMigrations(c) + resolved, callbacks, err := resolveMigrations(c) if err != nil { return nil, err } @@ -284,5 +293,6 @@ func (c *Configuration) LoadContext(ctx context.Context) (*Migrator, error) { configuration: c, dialect: dialect, resolved: resolved, + sqlCallbacks: callbacks, }, nil } diff --git a/dialect.go b/dialect.go index e188ccd..a3fe771 100644 --- a/dialect.go +++ b/dialect.go @@ -79,6 +79,12 @@ type Dialect interface { // string when the dialect has no such concept. setSearchPathSQL(schema string) string + // sessionSearchPathSQL returns a statement that makes the given schema the + // default for the whole session rather than a single transaction, for use by + // migrations that run without a transaction. It returns an empty string when + // the dialect has no such concept. + sessionSearchPathSQL(schema string) string + // cleanStatements returns the statements that drop every object in the given // schema, querying the database when the set of objects must be discovered // dynamically. diff --git a/dialect_postgres.go b/dialect_postgres.go index 42dedac..e17f739 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -95,6 +95,13 @@ func (d postgresDialect) setSearchPathSQL(schema string) string { return "SET LOCAL search_path TO " + d.quoteIdentifier(schema) } +func (d postgresDialect) sessionSearchPathSQL(schema string) string { + if schema == "" { + return "" + } + return "SET search_path TO " + d.quoteIdentifier(schema) +} + // cleanStatements drops the schema and recreates it, which removes every object // it contains. This is simpler and more robust than enumerating each object. func (d postgresDialect) cleanStatements(_ context.Context, _ querier, schema string) ([]string, error) { diff --git a/dialect_sqlite.go b/dialect_sqlite.go index 777979d..9d1daec 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -78,6 +78,8 @@ func (sqliteDialect) splitStatements(sql string) ([]string, error) { func (sqliteDialect) setSearchPathSQL(string) string { return "" } +func (sqliteDialect) sessionSearchPathSQL(string) string { return "" } + // cleanStatements enumerates the user defined objects from the SQLite catalog // and returns statements to drop each of them. Internal objects whose names // begin with the reserved prefix are skipped. diff --git a/docs/design.md b/docs/design.md index 35124ce..cc86fc2 100644 --- a/docs/design.md +++ b/docs/design.md @@ -118,14 +118,28 @@ splitter also recognizes backtick and square bracket identifiers and tracks `BEGIN`, `CASE` and `END` so the inner statements of a trigger body are kept together. +## Callbacks and non-transactional migrations + +Lifecycle callbacks fire around the migrate run (`beforeMigrate`, +`afterMigrate`) and around each applied migration (`beforeEachMigrate`, +`afterEachMigrate`). They are discovered as SQL scripts in the configured +locations by their file name, or registered programmatically through the +`Callback` interface. The per-migration callbacks run on the same executor as +the migration, so within its transaction when one is used. + +A migration whose first comment lines carry a `-- goway:noTransaction` directive +runs outside the per-migration transaction, on a dedicated connection, for +statements such as PostgreSQL's `CREATE INDEX CONCURRENTLY` or SQLite's +`VACUUM`. Because there is no transaction to roll back, a failure is recorded as +a failed history row that the repair command can clear. + ## Known divergences -- Only the latest run of a repeatable migration is reported by `Info`; Flyway - lists every historical run and marks the superseded ones. - Placeholder checksums are always computed on the raw content; Flyway computes them on the replaced content for repeatable migrations. -- Code based migrations, lifecycle callbacks, and undo migrations are not - implemented. +- Go code based migrations and undo migrations are not implemented. +- The grouped and mixed transaction modes are not implemented; each migration + runs in its own transaction unless it opts out. ## Naming and trademark diff --git a/info.go b/info.go index 7827435..9e07879 100644 --- a/info.go +++ b/info.go @@ -45,7 +45,7 @@ func computeInfos(resolved []*resolvedMigration, applied []appliedMigration, con } appliedVersioned := make(map[string]*appliedMigration) - repeatableApplied := make(map[string]*appliedMigration) + var appliedRepeatables []*appliedMigration var baseline *appliedMigration var current *Version noteCurrent := func(version *Version) { @@ -72,7 +72,7 @@ func computeInfos(resolved []*resolvedMigration, applied []appliedMigration, con noteCurrent(record.version) } default: - repeatableApplied[record.description] = record + appliedRepeatables = append(appliedRepeatables, record) } } @@ -101,11 +101,11 @@ func computeInfos(resolved []*resolvedMigration, applied []appliedMigration, con versionedEntries := make([]*migrationInfoEntry, 0, len(versionKeys)) for _, key := range versionKeys { - resolvedMigration := resolvedVersioned[key] - appliedMigration := appliedVersioned[key] + resolvedForKey := resolvedVersioned[key] + appliedForKey := appliedVersioned[key] entry := &migrationInfoEntry{ - resolved: resolvedMigration, - applied: appliedMigration, + resolved: resolvedForKey, + applied: appliedForKey, version: versionSet[key], } populateVersionedEntry(entry, configuration, current, baselineVersion, maxResolved) @@ -127,7 +127,14 @@ func computeInfos(resolved []*resolvedMigration, applied []appliedMigration, con }) entries = append(entries, versionedEntries...) - // Repeatable entries are the union of resolved and applied descriptions. + // Repeatable entries keep every applied run plus any resolved repeatable that + // has never been applied. For each description the run with the highest + // installed rank is current; older runs are marked superseded. + appliedByDescription := make(map[string][]*appliedMigration) + for _, record := range appliedRepeatables { + appliedByDescription[record.description] = append(appliedByDescription[record.description], record) + } + repeatableKeys := make([]string, 0) repeatableSet := make(map[string]bool) resolvedRepeatableByDescription := make(map[string]*resolvedMigration) @@ -138,7 +145,7 @@ func computeInfos(resolved []*resolvedMigration, applied []appliedMigration, con repeatableSet[migration.description] = true } } - for description := range repeatableApplied { + for description := range appliedByDescription { if !repeatableSet[description] { repeatableKeys = append(repeatableKeys, description) repeatableSet[description] = true @@ -147,15 +154,42 @@ func computeInfos(resolved []*resolvedMigration, applied []appliedMigration, con sort.Strings(repeatableKeys) for _, description := range repeatableKeys { - resolvedMigration := resolvedRepeatableByDescription[description] - appliedMigration := repeatableApplied[description] - entry := &migrationInfoEntry{ - resolved: resolvedMigration, - applied: appliedMigration, - description: description, + resolvedForKey := resolvedRepeatableByDescription[description] + runs := appliedByDescription[description] + sort.SliceStable(runs, func(i, j int) bool { + return runs[i].installedRank < runs[j].installedRank + }) + + latestRank := -1 + for _, run := range runs { + if run.installedRank > latestRank { + latestRank = run.installedRank + } + } + + for _, run := range runs { + entry := &migrationInfoEntry{ + applied: run, + description: description, + script: run.script, + } + if run.installedRank == latestRank { + entry.resolved = resolvedForKey + populateRepeatableEntry(entry) + } else { + entry.state = StateSuperseded + } + entries = append(entries, entry) + } + + if len(runs) == 0 && resolvedForKey != nil { + entry := &migrationInfoEntry{ + resolved: resolvedForKey, + description: description, + } + populateRepeatableEntry(entry) + entries = append(entries, entry) } - populateRepeatableEntry(entry) - entries = append(entries, entry) } return &migrationInfoService{entries: entries, current: current} diff --git a/info_superseded_test.go b/info_superseded_test.go new file mode 100644 index 0000000..9a12955 --- /dev/null +++ b/info_superseded_test.go @@ -0,0 +1,85 @@ +package goway + +import "testing" + +func i32(value int32) *int32 { return &value } + +// findInfo returns the first migration info whose script and installed rank +// match, or fails the test. +func findInfo(t *testing.T, infos []MigrationInfo, script string, rank int) MigrationInfo { + t.Helper() + for _, info := range infos { + if info.Script == script && info.InstalledRank == rank { + return info + } + } + t.Fatalf("no migration info for script %q rank %d in %+v", script, rank, infos) + return MigrationInfo{} +} + +func TestComputeInfosRepeatableSupersededWhenChecksumMatches(t *testing.T) { + configuration := Configure() + resolved := []*resolvedMigration{{ + description: "view", + script: "R__view.sql", + checksum: 100, + repeatable: true, + migrationType: MigrationTypeSQL, + }} + applied := []appliedMigration{ + {installedRank: 1, description: "view", script: "R__view.sql", migrationType: "SQL", checksum: i32(50), success: true}, + {installedRank: 2, description: "view", script: "R__view.sql", migrationType: "SQL", checksum: i32(100), success: true}, + } + + service := computeInfos(resolved, applied, configuration) + infos := service.infos() + + if got := findInfo(t, infos, "R__view.sql", 1).State; got != StateSuperseded { + t.Errorf("older run state = %q, want %q", got, StateSuperseded) + } + if got := findInfo(t, infos, "R__view.sql", 2).State; got != StateSuccess { + t.Errorf("latest run state = %q, want %q", got, StateSuccess) + } + + // A superseded run must never be a validation error, and the matching latest + // run keeps validation clean. + if problems := service.validate(false); len(problems) != 0 { + t.Errorf("validate reported problems for a superseded repeatable: %+v", problems) + } + + // Nothing is pending: the latest run matches the resolved checksum. + if pending := service.pending(); len(pending) != 0 { + t.Errorf("expected no pending migrations, got %d", len(pending)) + } +} + +func TestComputeInfosRepeatableLatestOutdated(t *testing.T) { + configuration := Configure() + resolved := []*resolvedMigration{{ + description: "view", + script: "R__view.sql", + checksum: 777, + repeatable: true, + migrationType: MigrationTypeSQL, + }} + applied := []appliedMigration{ + {installedRank: 1, description: "view", script: "R__view.sql", migrationType: "SQL", checksum: i32(50), success: true}, + {installedRank: 2, description: "view", script: "R__view.sql", migrationType: "SQL", checksum: i32(100), success: true}, + } + + service := computeInfos(resolved, applied, configuration) + infos := service.infos() + + if got := findInfo(t, infos, "R__view.sql", 1).State; got != StateSuperseded { + t.Errorf("older run state = %q, want %q", got, StateSuperseded) + } + if got := findInfo(t, infos, "R__view.sql", 2).State; got != StateOutdated { + t.Errorf("latest run state = %q, want %q", got, StateOutdated) + } + + // The outdated latest run is pending re-execution. + pending := service.pending() + if len(pending) != 1 { + t.Fatalf("expected exactly one pending migration, got %d", len(pending)) + } +} diff --git a/integration/harness_test.go b/integration/harness_test.go index 4bcf9f5..7dc4c27 100644 --- a/integration/harness_test.go +++ b/integration/harness_test.go @@ -20,7 +20,7 @@ import ( // migrationsFS holds the test migration scripts embedded into the test binary. // -//go:embed testdata/shared/*.sql testdata/pg/*.sql testdata/sqlite/*.sql testdata/placeholder/*.sql +//go:embed testdata/shared/*.sql testdata/pg/*.sql testdata/sqlite/*.sql testdata/placeholder/*.sql testdata/notx_pg/*.sql testdata/notx_sqlite/*.sql testdata/callbacks/*.sql var migrationsFS embed.FS const defaultPostgresImage = "postgres:18-alpine" diff --git a/integration/parity_test.go b/integration/parity_test.go new file mode 100644 index 0000000..3dc78f0 --- /dev/null +++ b/integration/parity_test.go @@ -0,0 +1,264 @@ +package integration + +import ( + "context" + "testing" + + "github.com/cgardev/goway" +) + +// TestPostgresNoTransactionConcurrentIndex verifies that a script marked with +// the no-transaction directive can run CREATE INDEX CONCURRENTLY, which fails +// inside a transaction block. +func TestPostgresNoTransactionConcurrentIndex(t *testing.T) { + database := openPostgres(t) + ctx := context.Background() + const schema = "test_notx" + migrator, err := goway.Configure(). + DataSource(database). + Dialect(goway.Postgres()). + Locations(). + FS(migrationsFS, "testdata/notx_pg"). + Schemas(schema). + CreateSchemas(true). + Load() + if err != nil { + t.Fatalf("load migrator: %v", err) + } + + result, err := migrator.Migrate(ctx) + if err != nil { + t.Fatalf("migrate with CREATE INDEX CONCURRENTLY: %v", err) + } + if result.MigrationsExecuted != 2 { + t.Fatalf("executed %d migrations, want 2", result.MigrationsExecuted) + } + + var indexCount int + err = database.QueryRow( + `SELECT COUNT(*) FROM pg_indexes WHERE schemaname = $1 AND indexname = 'idx_items_name'`, + schema).Scan(&indexCount) + if err != nil { + t.Fatalf("query index: %v", err) + } + if indexCount != 1 { + t.Fatalf("concurrent index count = %d, want 1", indexCount) + } +} + +// TestPostgresConcurrentIndexFailsWithoutDirective is a control: the same +// statement without the directive runs inside a transaction and is rejected by +// PostgreSQL, proving the directive is what makes the success case work. +func TestPostgresConcurrentIndexFailsWithoutDirective(t *testing.T) { + database := openPostgres(t) + ctx := context.Background() + const schema = "test_notx_control" + + setup, err := goway.Configure(). + DataSource(database).Dialect(goway.Postgres()). + Locations().FS(migrationsFS, "testdata/notx_pg"). + Schemas(schema).CreateSchemas(true). + SQLMigrationSuffixes(".never"). + Load() + if err != nil { + t.Fatalf("load setup migrator: %v", err) + } + _ = setup + + // Create the table directly, then attempt the concurrent index inside a + // transaction to confirm PostgreSQL rejects it. + if _, err := database.Exec(`CREATE SCHEMA IF NOT EXISTS test_notx_control`); err != nil { + t.Fatalf("create schema: %v", err) + } + if _, err := database.Exec(`CREATE TABLE test_notx_control.items (id INTEGER PRIMARY KEY, name VARCHAR(100))`); err != nil { + t.Fatalf("create table: %v", err) + } + tx, err := database.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("begin: %v", err) + } + defer tx.Rollback() + if _, err := tx.ExecContext(ctx, `CREATE INDEX CONCURRENTLY idx_control ON test_notx_control.items (name)`); err == nil { + t.Fatal("expected CREATE INDEX CONCURRENTLY to fail inside a transaction") + } +} + +// TestSQLiteNoTransactionVacuum verifies a no-transaction migration runs VACUUM, +// which cannot run inside a transaction on SQLite. +func TestSQLiteNoTransactionVacuum(t *testing.T) { + database := openSQLite(t) + ctx := context.Background() + migrator, err := goway.Configure(). + DataSource(database). + Dialect(goway.SQLite()). + Locations(). + FS(migrationsFS, "testdata/notx_sqlite"). + Load() + if err != nil { + t.Fatalf("load migrator: %v", err) + } + + result, err := migrator.Migrate(ctx) + if err != nil { + t.Fatalf("migrate with VACUUM: %v", err) + } + if result.MigrationsExecuted != 2 { + t.Fatalf("executed %d migrations, want 2", result.MigrationsExecuted) + } + // The table from V1 must still exist after the vacuum. + if _, err := database.Exec("INSERT INTO notes (id, body) VALUES (1, 'hi')"); err != nil { + t.Fatalf("insert after vacuum: %v", err) + } +} + +// TestSQLiteSQLCallbacks verifies that beforeMigrate and afterEachMigrate SQL +// callback scripts are discovered and fired. +func TestSQLiteSQLCallbacks(t *testing.T) { + database := openSQLite(t) + ctx := context.Background() + migrator, err := goway.Configure(). + DataSource(database). + Dialect(goway.SQLite()). + Locations(). + FS(migrationsFS, "testdata/callbacks"). + Load() + if err != nil { + t.Fatalf("load migrator: %v", err) + } + + if _, err := migrator.Migrate(ctx); err != nil { + t.Fatalf("migrate with callbacks: %v", err) + } + + // beforeMigrate inserts one 'before' row; afterEachMigrate inserts one 'each' + // row per applied migration (one migration here). + count := func(note string) int { + var n int + if err := database.QueryRow(`SELECT COUNT(*) FROM audit WHERE note = ?`, note).Scan(&n); err != nil { + t.Fatalf("count %q: %v", note, err) + } + return n + } + if got := count("before"); got != 1 { + t.Errorf("beforeMigrate rows = %d, want 1", got) + } + if got := count("each"); got != 1 { + t.Errorf("afterEachMigrate rows = %d, want 1", got) + } +} + +// TestSQLiteProgrammaticCallback verifies a callback registered through the +// configuration receives the lifecycle events. +func TestSQLiteProgrammaticCallback(t *testing.T) { + database := openSQLite(t) + ctx := context.Background() + + events := map[goway.CallbackEvent]int{} + recorder := goway.CallbackFunc(func(_ context.Context, event goway.CallbackEvent, _ goway.Execer, _ *goway.MigrationInfo) error { + events[event]++ + return nil + }) + + migrator, err := goway.Configure(). + DataSource(database). + Dialect(goway.SQLite()). + Locations(). + FS(migrationsFS, "testdata/shared"). + Callbacks(recorder). + Load() + if err != nil { + t.Fatalf("load migrator: %v", err) + } + + if _, err := migrator.Migrate(ctx); err != nil { + t.Fatalf("migrate: %v", err) + } + + if events[goway.EventBeforeMigrate] != 1 { + t.Errorf("beforeMigrate fired %d times, want 1", events[goway.EventBeforeMigrate]) + } + if events[goway.EventAfterMigrate] != 1 { + t.Errorf("afterMigrate fired %d times, want 1", events[goway.EventAfterMigrate]) + } + // testdata/shared has three migrations (V1, V2, R), so each-migrate fires 3x. + if events[goway.EventBeforeEachMigrate] != 3 { + t.Errorf("beforeEachMigrate fired %d times, want 3", events[goway.EventBeforeEachMigrate]) + } + if events[goway.EventAfterEachMigrate] != 3 { + t.Errorf("afterEachMigrate fired %d times, want 3", events[goway.EventAfterEachMigrate]) + } +} + +// TestSQLiteRepeatableSuperseded applies a repeatable migration, changes it on +// disk via a second migrator with a different script, and verifies that re-runs +// accumulate history rows of which only the newest is current. +func TestSQLiteRepeatableSuperseded(t *testing.T) { + database := openSQLite(t) + ctx := context.Background() + + migrate := func() *goway.Migrator { + m, err := goway.Configure(). + DataSource(database). + Dialect(goway.SQLite()). + Locations(). + FS(migrationsFS, "testdata/shared"). + Load() + if err != nil { + t.Fatalf("load: %v", err) + } + return m + } + + if _, err := migrate().Migrate(ctx); err != nil { + t.Fatalf("first migrate: %v", err) + } + + // Force the repeatable migration to be re-applied by tampering its recorded + // checksum so it appears outdated, then migrate again. + if _, err := database.Exec( + `UPDATE flyway_schema_history SET checksum = checksum + 1 WHERE version IS NULL AND type = 'SQL'`); err != nil { + t.Fatalf("tamper repeatable checksum: %v", err) + } + if _, err := migrate().Migrate(ctx); err != nil { + t.Fatalf("second migrate: %v", err) + } + + // There must now be two runs of the repeatable migration in history. + var runs int + if err := database.QueryRow( + `SELECT COUNT(*) FROM flyway_schema_history WHERE script = 'R__widget_summary.sql'`).Scan(&runs); err != nil { + t.Fatalf("count repeatable runs: %v", err) + } + if runs != 2 { + t.Fatalf("repeatable runs in history = %d, want 2", runs) + } + + // Info must report exactly one Superseded run and one current run. + info, err := migrate().Info(ctx) + if err != nil { + t.Fatalf("info: %v", err) + } + var superseded, current int + for _, migration := range info.Migrations { + if migration.Script != "R__widget_summary.sql" { + continue + } + switch migration.State { + case goway.StateSuperseded: + superseded++ + case goway.StateSuccess, goway.StateOutdated: + current++ + } + } + if superseded != 1 { + t.Errorf("superseded repeatable runs = %d, want 1", superseded) + } + if current != 1 { + t.Errorf("current repeatable runs = %d, want 1", current) + } + + // Validation must not treat a superseded run as an error. + if result, err := migrate().Validate(ctx); err != nil || !result.Valid { + t.Fatalf("validate after supersede: valid=%v err=%v", result.Valid, err) + } +} diff --git a/integration/testdata/callbacks/V1__create_people.sql b/integration/testdata/callbacks/V1__create_people.sql new file mode 100644 index 0000000..66170e3 --- /dev/null +++ b/integration/testdata/callbacks/V1__create_people.sql @@ -0,0 +1,4 @@ +CREATE TABLE people ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +); diff --git a/integration/testdata/callbacks/afterEachMigrate.sql b/integration/testdata/callbacks/afterEachMigrate.sql new file mode 100644 index 0000000..763cbba --- /dev/null +++ b/integration/testdata/callbacks/afterEachMigrate.sql @@ -0,0 +1 @@ +INSERT INTO audit (note) VALUES ('each'); diff --git a/integration/testdata/callbacks/beforeMigrate.sql b/integration/testdata/callbacks/beforeMigrate.sql new file mode 100644 index 0000000..296805e --- /dev/null +++ b/integration/testdata/callbacks/beforeMigrate.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS audit ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + note TEXT NOT NULL +); + +INSERT INTO audit (note) VALUES ('before'); diff --git a/integration/testdata/notx_pg/V1__create_items.sql b/integration/testdata/notx_pg/V1__create_items.sql new file mode 100644 index 0000000..ace9c53 --- /dev/null +++ b/integration/testdata/notx_pg/V1__create_items.sql @@ -0,0 +1,4 @@ +CREATE TABLE items ( + id INTEGER PRIMARY KEY, + name VARCHAR(100) NOT NULL +); diff --git a/integration/testdata/notx_pg/V2__concurrent_index.sql b/integration/testdata/notx_pg/V2__concurrent_index.sql new file mode 100644 index 0000000..8360696 --- /dev/null +++ b/integration/testdata/notx_pg/V2__concurrent_index.sql @@ -0,0 +1,2 @@ +-- goway:noTransaction +CREATE INDEX CONCURRENTLY idx_items_name ON items (name); diff --git a/integration/testdata/notx_sqlite/V1__create_notes.sql b/integration/testdata/notx_sqlite/V1__create_notes.sql new file mode 100644 index 0000000..ca9a07c --- /dev/null +++ b/integration/testdata/notx_sqlite/V1__create_notes.sql @@ -0,0 +1,4 @@ +CREATE TABLE notes ( + id INTEGER PRIMARY KEY, + body TEXT NOT NULL +); diff --git a/integration/testdata/notx_sqlite/V2__vacuum.sql b/integration/testdata/notx_sqlite/V2__vacuum.sql new file mode 100644 index 0000000..e2a622f --- /dev/null +++ b/integration/testdata/notx_sqlite/V2__vacuum.sql @@ -0,0 +1,2 @@ +-- goway:noTransaction +VACUUM; diff --git a/migrate.go b/migrate.go index 6026c2f..0657562 100644 --- a/migrate.go +++ b/migrate.go @@ -72,6 +72,10 @@ func (f *Migrator) Migrate(ctx context.Context) (*MigrateResult, error) { } } + if err := f.fireCallbacks(ctx, db, EventBeforeMigrate, nil, schema); err != nil { + return nil, err + } + for _, entry := range service.pending() { info, err := f.applyMigration(ctx, db, history, entry, schema, installedBy) if err != nil { @@ -82,13 +86,18 @@ func (f *Migrator) Migrate(ctx context.Context) (*MigrateResult, error) { result.MigrationsExecuted++ } + if err := f.fireCallbacks(ctx, db, EventAfterMigrate, nil, schema); err != nil { + result.TargetSchemaVersion = f.currentVersionString(ctx, history) + return result, err + } + result.TargetSchemaVersion = f.currentVersionString(ctx, history) return result, nil } -// applyMigration executes a single migration inside its own transaction and -// records it in the schema history. The history row is written in the same -// transaction as the migration statements, so a failure rolls back both. +// applyMigration prepares a migration's statements and applies them either +// within a transaction or, when the script opted out, directly on a dedicated +// connection. func (f *Migrator) applyMigration(ctx context.Context, db *sql.DB, history *schemaHistory, entry *migrationInfoEntry, schema, installedBy string) (MigrationInfo, error) { migration := entry.resolved @@ -107,6 +116,22 @@ func (f *Migrator) applyMigration(ctx context.Context, db *sql.DB, history *sche return MigrationInfo{}, err } + info := MigrationInfo{ + Version: migration.version.String(), + Description: migration.description, + Type: string(migration.migrationType), + Script: migration.script, + } + + if migration.noTransaction { + return f.applyWithoutTransaction(ctx, db, history, migration, statements, schema, installedBy, info, entry.outOfOrder) + } + return f.applyWithinTransaction(ctx, db, history, migration, statements, schema, installedBy, info, entry.outOfOrder) +} + +// applyWithinTransaction runs the migration and records its history row inside a +// single transaction, so a failure rolls back both, leaving no partial state. +func (f *Migrator) applyWithinTransaction(ctx context.Context, db *sql.DB, history *schemaHistory, migration *resolvedMigration, statements []string, schema, installedBy string, info MigrationInfo, outOfOrder bool) (MigrationInfo, error) { transaction, err := db.BeginTx(ctx, nil) if err != nil { return MigrationInfo{}, err @@ -123,12 +148,20 @@ func (f *Migrator) applyMigration(ctx context.Context, db *sql.DB, history *sche } start := time.Now() + if err := f.fireCallbacks(ctx, transaction, EventBeforeEachMigrate, &info, schema); err != nil { + _ = transaction.Rollback() + return MigrationInfo{}, err + } for _, statement := range statements { if _, err := transaction.ExecContext(ctx, statement); err != nil { _ = transaction.Rollback() return MigrationInfo{}, fmt.Errorf("goway: applying migration %s: %w", migration.script, err) } } + if err := f.fireCallbacks(ctx, transaction, EventAfterEachMigrate, &info, schema); err != nil { + _ = transaction.Rollback() + return MigrationInfo{}, err + } executionTime := int(time.Since(start).Milliseconds()) rank, err := history.nextInstalledRank(ctx, transaction) @@ -136,9 +169,75 @@ func (f *Migrator) applyMigration(ctx context.Context, db *sql.DB, history *sche _ = transaction.Rollback() return MigrationInfo{}, err } + if err := history.insert(ctx, transaction, f.buildRecord(migration, rank, installedBy, executionTime, true)); err != nil { + _ = transaction.Rollback() + return MigrationInfo{}, err + } + if err := transaction.Commit(); err != nil { + return MigrationInfo{}, fmt.Errorf("goway: committing migration %s: %w", migration.script, err) + } + return f.finishInfo(info, migration, rank, installedBy, executionTime, outOfOrder), nil +} + +// applyWithoutTransaction runs a migration that opted out of the transaction, +// for statements such as PostgreSQL's CREATE INDEX CONCURRENTLY or SQLite's +// VACUUM that cannot run inside a transaction block. It uses a dedicated +// connection and, on failure, records a failed history row since there is no +// transaction to roll back. +func (f *Migrator) applyWithoutTransaction(ctx context.Context, db *sql.DB, history *schemaHistory, migration *resolvedMigration, statements []string, schema, installedBy string, info MigrationInfo, outOfOrder bool) (MigrationInfo, error) { + connection, err := db.Conn(ctx) + if err != nil { + return MigrationInfo{}, err + } + defer connection.Close() + + if searchPath := f.dialect.sessionSearchPathSQL(schema); searchPath != "" { + if _, err := connection.ExecContext(ctx, searchPath); err != nil { + return MigrationInfo{}, fmt.Errorf("goway: setting search path for %s: %w", migration.script, err) + } + } + + start := time.Now() + runErr := f.runStatements(ctx, connection, migration, statements, &info, schema) + executionTime := int(time.Since(start).Milliseconds()) + + rank, err := history.nextInstalledRank(ctx, connection) + if err != nil { + if runErr != nil { + return MigrationInfo{}, runErr + } + return MigrationInfo{}, err + } + + if runErr != nil { + // Best effort: record the failure so it is visible and can be repaired. + _ = history.insert(ctx, connection, f.buildRecord(migration, rank, installedBy, executionTime, false)) + return MigrationInfo{}, runErr + } + if err := history.insert(ctx, connection, f.buildRecord(migration, rank, installedBy, executionTime, true)); err != nil { + return MigrationInfo{}, err + } + return f.finishInfo(info, migration, rank, installedBy, executionTime, outOfOrder), nil +} +// runStatements fires the per-migration callbacks around the migration's own +// statements on the given executor. +func (f *Migrator) runStatements(ctx context.Context, exec Execer, migration *resolvedMigration, statements []string, info *MigrationInfo, schema string) error { + if err := f.fireCallbacks(ctx, exec, EventBeforeEachMigrate, info, schema); err != nil { + return err + } + for _, statement := range statements { + if _, err := exec.ExecContext(ctx, statement); err != nil { + return fmt.Errorf("goway: applying migration %s: %w", migration.script, err) + } + } + return f.fireCallbacks(ctx, exec, EventAfterEachMigrate, info, schema) +} + +// buildRecord assembles the schema history row for an applied migration. +func (f *Migrator) buildRecord(migration *resolvedMigration, rank int, installedBy string, executionTime int, success bool) appliedMigration { checksum := migration.checksum - record := appliedMigration{ + return appliedMigration{ installedRank: rank, version: migration.version, description: migration.description, @@ -147,33 +246,69 @@ func (f *Migrator) applyMigration(ctx context.Context, db *sql.DB, history *sche checksum: &checksum, installedBy: installedBy, executionTime: executionTime, - success: true, - } - if err := history.insert(ctx, transaction, record); err != nil { - _ = transaction.Rollback() - return MigrationInfo{}, err - } - if err := transaction.Commit(); err != nil { - return MigrationInfo{}, fmt.Errorf("goway: committing migration %s: %w", migration.script, err) + success: success, } +} +// finishInfo completes the public migration info for a successfully applied +// migration. +func (f *Migrator) finishInfo(info MigrationInfo, migration *resolvedMigration, rank int, installedBy string, executionTime int, outOfOrder bool) MigrationInfo { state := StateSuccess - if entry.outOfOrder { + if outOfOrder { state = StateOutOfOrder } + checksum := migration.checksum installedOn := nowUTC() - return MigrationInfo{ - Version: migration.version.String(), - Description: migration.description, - Type: string(migration.migrationType), - Script: migration.script, - Checksum: &checksum, - State: state, - InstalledRank: rank, - InstalledOn: &installedOn, - InstalledBy: installedBy, - ExecutionTime: executionTime, - }, nil + info.Checksum = &checksum + info.State = state + info.InstalledRank = rank + info.InstalledOn = &installedOn + info.InstalledBy = installedBy + info.ExecutionTime = executionTime + return info +} + +// fireCallbacks runs the SQL callback scripts and the programmatic callbacks +// registered for the given event, in that order. +func (f *Migrator) fireCallbacks(ctx context.Context, exec Execer, event CallbackEvent, migration *MigrationInfo, schema string) error { + for _, callback := range f.sqlCallbacks { + if callback.event != event { + continue + } + if err := f.runSQLCallback(ctx, exec, callback, schema); err != nil { + return err + } + } + for _, callback := range f.configuration.callbacks { + if err := callback.Handle(ctx, event, exec, migration); err != nil { + return fmt.Errorf("goway: callback for %s: %w", event, err) + } + } + return nil +} + +// runSQLCallback executes the statements of a single SQL callback script. +func (f *Migrator) runSQLCallback(ctx context.Context, exec Execer, callback sqlCallback, schema string) error { + content, err := callback.read() + if err != nil { + return fmt.Errorf("goway: reading callback %s: %w", callback.script, err) + } + script, err := replacePlaceholders(string(content), + f.effectivePlaceholders(schema, "", callback.script), + f.configuration.placeholderPrefix, f.configuration.placeholderSuffix) + if err != nil { + return err + } + statements, err := f.dialect.splitStatements(script) + if err != nil { + return err + } + for _, statement := range statements { + if _, err := exec.ExecContext(ctx, statement); err != nil { + return fmt.Errorf("goway: callback %s: %w", callback.script, err) + } + } + return nil } // recordSchemaCreation writes the synthetic entry that documents schemas created diff --git a/migration.go b/migration.go index a477ba7..669499b 100644 --- a/migration.go +++ b/migration.go @@ -45,6 +45,10 @@ const ( // it was last applied and will be re-applied. StateOutdated MigrationState = "Outdated" + // StateSuperseded indicates an older applied run of a repeatable migration + // that has since been re-applied; only the newest run is current. + StateSuperseded MigrationState = "Superseded" + // StateMissing indicates a successfully applied migration that can no longer // be resolved from the configured locations. StateMissing MigrationState = "Missing" @@ -93,6 +97,10 @@ type resolvedMigration struct { // migrationType classifies the migration for the history table. migrationType MigrationType + // noTransaction reports whether the script opted out of the per-migration + // transaction through a comment directive. + noTransaction bool + // read returns the raw script content. read func() ([]byte, error) } diff --git a/migrator.go b/migrator.go index dee7f15..9aad445 100644 --- a/migrator.go +++ b/migrator.go @@ -12,6 +12,7 @@ type Migrator struct { configuration *Configuration dialect Dialect resolved []*resolvedMigration + sqlCallbacks []sqlCallback } // Dialect returns the dialect that was configured or detected. diff --git a/resolver.go b/resolver.go index 1dbf152..b284c76 100644 --- a/resolver.go +++ b/resolver.go @@ -8,20 +8,22 @@ import ( // resolveMigrations scans every configured source, parses each candidate file // name, computes its checksum, and returns the resolved migrations sorted in // the order they would be applied: versioned migrations by ascending version, -// followed by repeatable migrations by ascending description. Duplicate +// followed by repeatable migrations by ascending description. Callback scripts +// are returned separately, in the order they were discovered. Duplicate // versions or repeatable descriptions are rejected. -func resolveMigrations(configuration *Configuration) ([]*resolvedMigration, error) { +func resolveMigrations(configuration *Configuration) ([]*resolvedMigration, []sqlCallback, error) { var scanned []scannedFile for _, source := range configuration.sources() { files, err := source.scan() if err != nil { - return nil, err + return nil, nil, err } scanned = append(scanned, files...) } var versioned []*resolvedMigration var repeatable []*resolvedMigration + var callbacks []sqlCallback seenVersions := make(map[string]string) seenRepeatable := make(map[string]string) @@ -34,12 +36,16 @@ func resolveMigrations(configuration *Configuration) ([]*resolvedMigration, erro configuration.sqlMigrationSuffixes, ) if !parsed.valid { + // A file that is not a migration may still be a callback script. + if event, ok := parseCallbackName(file.name, configuration.sqlMigrationSeparator, configuration.sqlMigrationSuffixes); ok { + callbacks = append(callbacks, sqlCallback{event: event, script: file.name, read: file.read}) + } continue } content, err := file.read() if err != nil { - return nil, fmt.Errorf("goway: reading migration %s: %w", file.location, err) + return nil, nil, fmt.Errorf("goway: reading migration %s: %w", file.location, err) } checksum := calculateChecksum(content) reader := file.read @@ -50,12 +56,13 @@ func resolveMigrations(configuration *Configuration) ([]*resolvedMigration, erro checksum: checksum, repeatable: parsed.repeatable, migrationType: MigrationTypeSQL, + noTransaction: scriptRequestsNoTransaction(string(content)), read: reader, } if parsed.repeatable { if previous, exists := seenRepeatable[parsed.description]; exists { - return nil, fmt.Errorf("%w: %q and %q both describe %q", + return nil, nil, fmt.Errorf("%w: %q and %q both describe %q", ErrDuplicateRepeatable, previous, file.name, parsed.description) } seenRepeatable[parsed.description] = file.name @@ -65,13 +72,13 @@ func resolveMigrations(configuration *Configuration) ([]*resolvedMigration, erro version, err := parseVersion(parsed.rawVersion) if err != nil { - return nil, fmt.Errorf("goway: migration %s: %w", file.name, err) + return nil, nil, fmt.Errorf("goway: migration %s: %w", file.name, err) } migration.version = version key := normalizedVersionKey(version) if previous, exists := seenVersions[key]; exists { - return nil, fmt.Errorf("%w: %q and %q both target version %s", + return nil, nil, fmt.Errorf("%w: %q and %q both target version %s", ErrDuplicateVersion, previous, file.name, version) } seenVersions[key] = file.name @@ -88,7 +95,7 @@ func resolveMigrations(configuration *Configuration) ([]*resolvedMigration, erro resolved := make([]*resolvedMigration, 0, len(versioned)+len(repeatable)) resolved = append(resolved, versioned...) resolved = append(resolved, repeatable...) - return resolved, nil + return resolved, callbacks, nil } // normalizedVersionKey produces a canonical key for a version that ignores diff --git a/script.go b/script.go new file mode 100644 index 0000000..ee6d104 --- /dev/null +++ b/script.go @@ -0,0 +1,32 @@ +package goway + +import "strings" + +// noTransactionMarkers are the comment directives that request a migration to +// run outside a transaction. The goway form is preferred; the flyway form is +// accepted for familiarity. +var noTransactionMarkers = []string{ + "goway:notransaction", + "flyway:executeintransaction=false", +} + +// scriptRequestsNoTransaction reports whether a SQL script opts out of the +// per-migration transaction through a comment directive. Only comment lines are +// inspected, so a marker that happens to appear inside a statement or string +// literal does not trigger the behavior. Detection is case insensitive and runs +// on the raw script before placeholder replacement. +func scriptRequestsNoTransaction(script string) bool { + for _, raw := range strings.Split(script, "\n") { + line := strings.TrimSpace(raw) + if !strings.HasPrefix(line, "--") { + continue + } + lower := strings.ToLower(line) + for _, marker := range noTransactionMarkers { + if strings.Contains(lower, marker) { + return true + } + } + } + return false +} diff --git a/script_test.go b/script_test.go new file mode 100644 index 0000000..1a0f657 --- /dev/null +++ b/script_test.go @@ -0,0 +1,29 @@ +package goway + +import "testing" + +func TestScriptRequestsNoTransaction(t *testing.T) { + requests := []string{ + "-- goway:noTransaction\nCREATE INDEX CONCURRENTLY idx ON t (c);", + "--goway:noTransaction\nVACUUM;", + "-- GOWAY:NOTRANSACTION\nSELECT 1;", + "-- flyway:executeInTransaction=false\nSELECT 1;", + "-- a normal comment\n-- goway:noTransaction\nSELECT 1;", + } + for _, script := range requests { + if !scriptRequestsNoTransaction(script) { + t.Errorf("expected no-transaction directive to be detected in:\n%s", script) + } + } + + plain := []string{ + "CREATE TABLE t (id INTEGER);", + "-- just a comment\nSELECT 1;", + "SELECT 'goway:noTransaction';", // appears in a statement, not a comment + } + for _, script := range plain { + if scriptRequestsNoTransaction(script) { + t.Errorf("did not expect a directive to be detected in:\n%s", script) + } + } +}