From 497acb407f47099986d3978c1d10acc5d0d3c380 Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Sun, 29 Oct 2023 21:55:41 -0400 Subject: [PATCH] feat(experimental): prefactor provider and cleanup (#626) --- Dockerfile.local | 8 -- goose.go | 2 +- internal/provider/collect.go | 65 +++++--------- internal/provider/collect_test.go | 123 +++++++++++++++----------- internal/provider/migration.go | 8 +- internal/provider/misc.go | 12 +-- internal/provider/provider.go | 61 +++++++------ internal/provider/provider_options.go | 26 +++--- internal/provider/provider_test.go | 16 ++-- internal/provider/run.go | 30 +++---- internal/provider/run_test.go | 112 ++++++++++++----------- 11 files changed, 236 insertions(+), 227 deletions(-) delete mode 100644 Dockerfile.local diff --git a/Dockerfile.local b/Dockerfile.local deleted file mode 100644 index 1c66de73e..000000000 --- a/Dockerfile.local +++ /dev/null @@ -1,8 +0,0 @@ -FROM golang:1.17-buster@sha256:3e663ba6af8281b04975b0a34a14d538cdd7d284213f83f05aaf596b80a8c725 as builder - -COPY . /src -WORKDIR /src -RUN CGO_ENABLED=0 make dist - -FROM scratch AS exporter -COPY --from=builder /src/bin/ / \ No newline at end of file diff --git a/goose.go b/goose.go index e952041b0..daf059366 100644 --- a/goose.go +++ b/goose.go @@ -8,7 +8,7 @@ import ( "strconv" ) -// Deprecated: VERSION will no longer be supported in v4. +// Deprecated: VERSION will no longer be supported in the next major release. const VERSION = "v3.2.0" var ( diff --git a/internal/provider/collect.go b/internal/provider/collect.go index a4a73c0e6..345da0d06 100644 --- a/internal/provider/collect.go +++ b/internal/provider/collect.go @@ -7,17 +7,10 @@ import ( "os" "path/filepath" "sort" - "strconv" "strings" -) -func NewSource(t MigrationType, fullpath string, version int64) Source { - return Source{ - Type: t, - Path: fullpath, - Version: version, - } -} + "github.com/pressly/goose/v3" +) // fileSources represents a collection of migration files on the filesystem. type fileSources struct { @@ -44,16 +37,16 @@ func (s *fileSources) lookup(t MigrationType, version int64) *Source { return nil } -// collectFileSources scans the file system for migration files that have a numeric prefix (greater -// than one) followed by an underscore and a file extension of either .go or .sql. fsys may be nil, -// in which case an empty fileSources is returned. +// collectFilesystemSources scans the file system for migration files that have a numeric prefix +// (greater than one) followed by an underscore and a file extension of either .go or .sql. fsys may +// be nil, in which case an empty fileSources is returned. // // If strict is true, then any error parsing the numeric component of the filename will result in an // error. The file is skipped otherwise. // // This function DOES NOT parse SQL migrations or merge registered Go migrations. It only collects // migration sources from the filesystem. -func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) { +func collectFilesystemSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fileSources, error) { if fsys == nil { return new(fileSources), nil } @@ -78,7 +71,7 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil // filenames, but still have versioned migrations within the same directory. For // example, a user could have a helpers.go file which contains unexported helper // functions for migrations. - version, err := NumericComponent(base) + version, err := goose.NumericComponent(base) if err != nil { if strict { return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err) @@ -95,9 +88,17 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil } switch filepath.Ext(base) { case ".sql": - sources.sqlSources = append(sources.sqlSources, NewSource(TypeSQL, fullpath, version)) + sources.sqlSources = append(sources.sqlSources, Source{ + Type: TypeSQL, + Path: fullpath, + Version: version, + }) case ".go": - sources.goSources = append(sources.goSources, NewSource(TypeGo, fullpath, version)) + sources.goSources = append(sources.goSources, Source{ + Type: TypeGo, + Path: fullpath, + Version: version, + }) default: // Should never happen since we already filtered out all other file types. return nil, fmt.Errorf("unknown migration type: %s", base) @@ -165,9 +166,12 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration ) } m := &migration{ - // Note, the fullpath may be empty if the migration was registered manually. - Source: NewSource(TypeGo, fullpath, version), - Go: r, + Source: Source{ + Type: TypeGo, + Path: fullpath, // May be empty if migration was registered manually. + Version: version, + }, + Go: r, } migrations = append(migrations, m) migrationLookup[version] = m @@ -207,26 +211,3 @@ var _ fs.FS = noopFS{} func (f noopFS) Open(name string) (fs.File, error) { return nil, os.ErrNotExist } - -// NumericComponent parses the version from the migration file name. -// -// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of -// migration, either .sql or .go. -func NumericComponent(filename string) (int64, error) { - base := filepath.Base(filename) - if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" { - return 0, errors.New("migration file does not have .sql or .go file extension") - } - idx := strings.Index(base, "_") - if idx < 0 { - return 0, errors.New("no filename separator '_' found") - } - n, err := strconv.ParseInt(base[:idx], 10, 64) - if err != nil { - return 0, err - } - if n < 1 { - return 0, errors.New("migration version must be greater than zero") - } - return n, nil -} diff --git a/internal/provider/collect_test.go b/internal/provider/collect_test.go index b1983d76d..e696ab005 100644 --- a/internal/provider/collect_test.go +++ b/internal/provider/collect_test.go @@ -12,14 +12,21 @@ import ( func TestCollectFileSources(t *testing.T) { t.Parallel() t.Run("nil_fsys", func(t *testing.T) { - sources, err := collectFileSources(nil, false, nil) + sources, err := collectFilesystemSources(nil, false, nil) + check.NoError(t, err) + check.Bool(t, sources != nil, true) + check.Number(t, len(sources.goSources), 0) + check.Number(t, len(sources.sqlSources), 0) + }) + t.Run("noop_fsys", func(t *testing.T) { + sources, err := collectFilesystemSources(noopFS{}, false, nil) check.NoError(t, err) check.Bool(t, sources != nil, true) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) }) t.Run("empty_fsys", func(t *testing.T) { - sources, err := collectFileSources(fstest.MapFS{}, false, nil) + sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil) check.NoError(t, err) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) @@ -30,28 +37,28 @@ func TestCollectFileSources(t *testing.T) { "00000_foo.sql": sqlMapFile, } // strict disable - should not error - sources, err := collectFileSources(mapFS, false, nil) + sources, err := collectFilesystemSources(mapFS, false, nil) check.NoError(t, err) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) // strict enabled - should error - _, err = collectFileSources(mapFS, true, nil) + _, err = collectFilesystemSources(mapFS, true, nil) check.HasError(t, err) check.Contains(t, err.Error(), "migration version must be greater than zero") }) t.Run("collect", func(t *testing.T) { fsys, err := fs.Sub(newSQLOnlyFS(), "migrations") check.NoError(t, err) - sources, err := collectFileSources(fsys, false, nil) + sources, err := collectFilesystemSources(fsys, false, nil) check.NoError(t, err) check.Number(t, len(sources.sqlSources), 4) check.Number(t, len(sources.goSources), 0) expected := fileSources{ sqlSources: []Source{ - NewSource(TypeSQL, "00001_foo.sql", 1), - NewSource(TypeSQL, "00002_bar.sql", 2), - NewSource(TypeSQL, "00003_baz.sql", 3), - NewSource(TypeSQL, "00110_qux.sql", 110), + newSource(TypeSQL, "00001_foo.sql", 1), + newSource(TypeSQL, "00002_bar.sql", 2), + newSource(TypeSQL, "00003_baz.sql", 3), + newSource(TypeSQL, "00110_qux.sql", 110), }, } for i := 0; i < len(sources.sqlSources); i++ { @@ -61,7 +68,7 @@ func TestCollectFileSources(t *testing.T) { t.Run("excludes", func(t *testing.T) { fsys, err := fs.Sub(newSQLOnlyFS(), "migrations") check.NoError(t, err) - sources, err := collectFileSources( + sources, err := collectFilesystemSources( fsys, false, // exclude 2 files explicitly @@ -75,8 +82,8 @@ func TestCollectFileSources(t *testing.T) { check.Number(t, len(sources.goSources), 0) expected := fileSources{ sqlSources: []Source{ - NewSource(TypeSQL, "00001_foo.sql", 1), - NewSource(TypeSQL, "00003_baz.sql", 3), + newSource(TypeSQL, "00001_foo.sql", 1), + newSource(TypeSQL, "00003_baz.sql", 3), }, } for i := 0; i < len(sources.sqlSources); i++ { @@ -89,7 +96,7 @@ func TestCollectFileSources(t *testing.T) { mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")} fsys, err := fs.Sub(mapFS, "migrations") check.NoError(t, err) - _, err = collectFileSources(fsys, true, nil) + _, err = collectFilesystemSources(fsys, true, nil) check.HasError(t, err) check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`) }) @@ -101,7 +108,7 @@ func TestCollectFileSources(t *testing.T) { "4_qux.sql": sqlMapFile, "5_foo_test.go": {Data: []byte(`package goose_test`)}, } - sources, err := collectFileSources(mapFS, false, nil) + sources, err := collectFilesystemSources(mapFS, false, nil) check.NoError(t, err) check.Number(t, len(sources.sqlSources), 4) check.Number(t, len(sources.goSources), 0) @@ -116,7 +123,7 @@ func TestCollectFileSources(t *testing.T) { "no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)}, "some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)}, } - sources, err := collectFileSources(mapFS, false, nil) + sources, err := collectFilesystemSources(mapFS, false, nil) check.NoError(t, err) check.Number(t, len(sources.sqlSources), 2) check.Number(t, len(sources.goSources), 1) @@ -135,7 +142,7 @@ func TestCollectFileSources(t *testing.T) { "001_foo.sql": sqlMapFile, "01_bar.sql": sqlMapFile, } - _, err := collectFileSources(mapFS, false, nil) + _, err := collectFilesystemSources(mapFS, false, nil) check.HasError(t, err) check.Contains(t, err.Error(), "found duplicate migration version 1") }) @@ -151,7 +158,7 @@ func TestCollectFileSources(t *testing.T) { t.Helper() f, err := fs.Sub(mapFS, dirpath) check.NoError(t, err) - got, err := collectFileSources(f, false, nil) + got, err := collectFilesystemSources(f, false, nil) check.NoError(t, err) check.Number(t, len(got.sqlSources), len(sqlSources)) check.Number(t, len(got.goSources), 0) @@ -160,15 +167,15 @@ func TestCollectFileSources(t *testing.T) { } } assertDirpath(".", []Source{ - NewSource(TypeSQL, "876_a.sql", 876), + newSource(TypeSQL, "876_a.sql", 876), }) assertDirpath("dir1", []Source{ - NewSource(TypeSQL, "101_a.sql", 101), - NewSource(TypeSQL, "102_b.sql", 102), - NewSource(TypeSQL, "103_c.sql", 103), + newSource(TypeSQL, "101_a.sql", 101), + newSource(TypeSQL, "102_b.sql", 102), + newSource(TypeSQL, "103_c.sql", 103), }) assertDirpath("dir2", []Source{ - NewSource(TypeSQL, "201_a.sql", 201), + newSource(TypeSQL, "201_a.sql", 201), }) assertDirpath("dir3", nil) }) @@ -187,7 +194,7 @@ func TestMerge(t *testing.T) { } fsys, err := fs.Sub(mapFS, "migrations") check.NoError(t, err) - sources, err := collectFileSources(fsys, false, nil) + sources, err := collectFilesystemSources(fsys, false, nil) check.NoError(t, err) check.Equal(t, len(sources.sqlSources), 1) check.Equal(t, len(sources.goSources), 2) @@ -205,9 +212,9 @@ func TestMerge(t *testing.T) { }) check.NoError(t, err) check.Number(t, len(migrations), 3) - assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1)) - assertMigration(t, migrations[1], NewSource(TypeGo, "00002_bar.go", 2)) - assertMigration(t, migrations[2], NewSource(TypeGo, "00003_baz.go", 3)) + assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3)) }) t.Run("unregistered_all", func(t *testing.T) { _, err := merge(sources, nil) @@ -243,7 +250,7 @@ func TestMerge(t *testing.T) { } fsys, err := fs.Sub(mapFS, "migrations") check.NoError(t, err) - sources, err := collectFileSources(fsys, false, nil) + sources, err := collectFilesystemSources(fsys, false, nil) check.NoError(t, err) t.Run("unregistered_all", func(t *testing.T) { migrations, err := merge(sources, map[int64]*goMigration{ @@ -253,11 +260,11 @@ func TestMerge(t *testing.T) { }) check.NoError(t, err) check.Number(t, len(migrations), 5) - assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1)) - assertMigration(t, migrations[1], NewSource(TypeSQL, "00002_bar.sql", 2)) - assertMigration(t, migrations[2], NewSource(TypeGo, "", 3)) - assertMigration(t, migrations[3], NewSource(TypeSQL, "00005_baz.sql", 5)) - assertMigration(t, migrations[4], NewSource(TypeGo, "", 6)) + assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2)) + assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) + assertMigration(t, migrations[3], newSource(TypeSQL, "00005_baz.sql", 5)) + assertMigration(t, migrations[4], newSource(TypeGo, "", 6)) }) }) t.Run("partial_go_files_on_disk", func(t *testing.T) { @@ -267,7 +274,7 @@ func TestMerge(t *testing.T) { } fsys, err := fs.Sub(mapFS, "migrations") check.NoError(t, err) - sources, err := collectFileSources(fsys, false, nil) + sources, err := collectFilesystemSources(fsys, false, nil) check.NoError(t, err) t.Run("unregistered_all", func(t *testing.T) { migrations, err := merge(sources, map[int64]*goMigration{ @@ -279,15 +286,15 @@ func TestMerge(t *testing.T) { }) check.NoError(t, err) check.Number(t, len(migrations), 4) - assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1)) - assertMigration(t, migrations[1], NewSource(TypeGo, "00002_bar.go", 2)) - assertMigration(t, migrations[2], NewSource(TypeGo, "", 3)) - assertMigration(t, migrations[3], NewSource(TypeGo, "", 6)) + assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) + assertMigration(t, migrations[3], newSource(TypeGo, "", 6)) }) }) } -func TestFindMissingMigrations(t *testing.T) { +func TestCheckMissingMigrations(t *testing.T) { t.Parallel() t.Run("db_has_max_version", func(t *testing.T) { @@ -302,24 +309,24 @@ func TestFindMissingMigrations(t *testing.T) { {Version: 7}, // <-- database max version_id } fsMigrations := []*migration{ - newMigration(1), - newMigration(2), // missing migration - newMigration(3), - newMigration(4), - newMigration(5), - newMigration(6), // missing migration - newMigration(7), // ----- database max version_id ----- - newMigration(8), // new migration + newMigrationVersion(1), + newMigrationVersion(2), // missing migration + newMigrationVersion(3), + newMigrationVersion(4), + newMigrationVersion(5), + newMigrationVersion(6), // missing migration + newMigrationVersion(7), // ----- database max version_id ----- + newMigrationVersion(8), // new migration } - got := findMissingMigrations(dbMigrations, fsMigrations) + got := checkMissingMigrations(dbMigrations, fsMigrations) check.Number(t, len(got), 2) check.Number(t, got[0].versionID, 2) check.Number(t, got[1].versionID, 6) // Sanity check. - check.Number(t, len(findMissingMigrations(nil, nil)), 0) - check.Number(t, len(findMissingMigrations(dbMigrations, nil)), 0) - check.Number(t, len(findMissingMigrations(nil, fsMigrations)), 0) + check.Number(t, len(checkMissingMigrations(nil, nil)), 0) + check.Number(t, len(checkMissingMigrations(dbMigrations, nil)), 0) + check.Number(t, len(checkMissingMigrations(nil, fsMigrations)), 0) }) t.Run("fs_has_max_version", func(t *testing.T) { dbMigrations := []*database.ListMigrationsResult{ @@ -328,17 +335,17 @@ func TestFindMissingMigrations(t *testing.T) { {Version: 2}, } fsMigrations := []*migration{ - newMigration(3), // new migration - newMigration(4), // new migration + newMigrationVersion(3), // new migration + newMigrationVersion(4), // new migration } - got := findMissingMigrations(dbMigrations, fsMigrations) + got := checkMissingMigrations(dbMigrations, fsMigrations) check.Number(t, len(got), 2) check.Number(t, got[0].versionID, 3) check.Number(t, got[1].versionID, 4) }) } -func newMigration(version int64) *migration { +func newMigrationVersion(version int64) *migration { return &migration{ Source: Source{ Version: version, @@ -368,6 +375,14 @@ func newSQLOnlyFS() fstest.MapFS { } } +func newSource(t MigrationType, fullpath string, version int64) Source { + return Source{ + Type: t, + Path: fullpath, + Version: version, + } +} + var ( sqlMapFile = &fstest.MapFile{Data: []byte(`-- +goose Up`)} ) diff --git a/internal/provider/migration.go b/internal/provider/migration.go index 2ace5f93d..07508ff68 100644 --- a/internal/provider/migration.go +++ b/internal/provider/migration.go @@ -44,7 +44,7 @@ func (m *migration) useTx(direction bool) bool { func (m *migration) isEmpty(direction bool) bool { switch m.Source.Type { case TypeSQL: - return m.SQL == nil || m.SQL.IsEmpty(direction) + return m.SQL == nil || m.SQL.isEmpty(direction) case TypeGo: return m.Go == nil || m.Go.isEmpty(direction) } @@ -102,7 +102,7 @@ func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool) type goMigration struct { fullpath string - up, down *GoMigration + up, down *GoMigrationFunc } func (g *goMigration) isEmpty(direction bool) bool { @@ -115,7 +115,7 @@ func (g *goMigration) isEmpty(direction bool) bool { return g.down == nil } -func newGoMigration(fullpath string, up, down *GoMigration) *goMigration { +func newGoMigration(fullpath string, up, down *GoMigrationFunc) *goMigration { return &goMigration{ fullpath: fullpath, up: up, @@ -163,7 +163,7 @@ type sqlMigration struct { DownStatements []string } -func (s *sqlMigration) IsEmpty(direction bool) bool { +func (s *sqlMigration) isEmpty(direction bool) bool { if direction { return len(s.UpStatements) == 0 } diff --git a/internal/provider/misc.go b/internal/provider/misc.go index 717edff58..e20fbad18 100644 --- a/internal/provider/misc.go +++ b/internal/provider/misc.go @@ -5,9 +5,11 @@ import ( "database/sql" "errors" "fmt" + + "github.com/pressly/goose/v3" ) -type Migration struct { +type MigrationCopy struct { Version int64 Source string // path to .sql script or go file Registered bool @@ -15,13 +17,13 @@ type Migration struct { UpFnNoTxContext, DownFnNoTxContext func(context.Context, *sql.DB) error } -var registeredGoMigrations = make(map[int64]*Migration) +var registeredGoMigrations = make(map[int64]*MigrationCopy) // SetGlobalGoMigrations registers the given go migrations globally. It returns an error if any of // the migrations are nil or if a migration with the same version has already been registered. // // Not safe for concurrent use. -func SetGlobalGoMigrations(migrations []*Migration) error { +func SetGlobalGoMigrations(migrations []*MigrationCopy) error { for _, m := range migrations { if m == nil { return errors.New("cannot register nil go migration") @@ -35,7 +37,7 @@ func SetGlobalGoMigrations(migrations []*Migration) error { if m.Source != "" { // If the source is set, expect it to be a file path with a numeric component that // matches the version. - version, err := NumericComponent(m.Source) + version, err := goose.NumericComponent(m.Source) if err != nil { return err } @@ -62,5 +64,5 @@ func SetGlobalGoMigrations(migrations []*Migration) error { // // Not safe for concurrent use. func ResetGlobalGoMigrations() { - registeredGoMigrations = make(map[int64]*Migration) + registeredGoMigrations = make(map[int64]*MigrationCopy) } diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 89c7444bd..bd68e2ff1 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -12,6 +12,21 @@ import ( "github.com/pressly/goose/v3/database" ) +// Provider is a goose migration provider. +type Provider struct { + // mu protects all accesses to the provider and must be held when calling operations on the + // database. + mu sync.Mutex + + db *sql.DB + fsys fs.FS + cfg config + store database.Store + + // migrations are ordered by version in ascending order. + migrations []*migration +} + // NewProvider returns a new goose Provider. // // The caller is responsible for matching the database dialect with the database/sql driver. For @@ -46,11 +61,13 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi return nil, err } } + // Allow users to specify a custom store implementation, but only if they don't specify a + // dialect. If they specify a dialect, we'll use the default store implementation. if dialect == "" && cfg.store == nil { return nil, errors.New("dialect must not be empty") } if dialect != "" && cfg.store != nil { - return nil, errors.New("cannot set both dialect and store") + return nil, errors.New("cannot set both dialect and custom store") } var store database.Store if dialect != "" { @@ -65,6 +82,16 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi if store.Tablename() == "" { return nil, errors.New("invalid store implementation: table name must not be empty") } + return newProvider(db, store, fsys, cfg, registeredGoMigrations /* global */) +} + +func newProvider( + db *sql.DB, + store database.Store, + fsys fs.FS, + cfg config, + global map[int64]*MigrationCopy, +) (*Provider, error) { // Collect migrations from the filesystem and merge with registered migrations. // // Note, neither of these functions parse SQL migrations by default. SQL migrations are parsed @@ -73,13 +100,10 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi // TODO(mf): we should expose a way to parse SQL migrations eagerly. This would allow us to // return an error if there are any SQL parsing errors. This adds a bit overhead to startup // though, so we should make it optional. - sources, err := collectFileSources(fsys, false, cfg.excludes) + filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludes) if err != nil { return nil, err } - // - // TODO(mf): move the merging of Go migrations into a separate function. - // registered := make(map[int64]*goMigration) // Add user-registered Go migrations. for version, m := range cfg.registered { @@ -87,7 +111,7 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi } // Add init() functions. This is a bit ugly because we need to convert from the old Migration // struct to the new goMigration struct. - for version, m := range registeredGoMigrations { + for version, m := range global { if _, ok := registered[version]; ok { return nil, fmt.Errorf("go migration with version %d already registered", version) } @@ -103,27 +127,27 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi } // Up if m.UpFnContext != nil { - g.up = &GoMigration{ + g.up = &GoMigrationFunc{ Run: m.UpFnContext, } } else if m.UpFnNoTxContext != nil { - g.up = &GoMigration{ + g.up = &GoMigrationFunc{ RunNoTx: m.UpFnNoTxContext, } } // Down if m.DownFnContext != nil { - g.down = &GoMigration{ + g.down = &GoMigrationFunc{ Run: m.DownFnContext, } } else if m.DownFnNoTxContext != nil { - g.down = &GoMigration{ + g.down = &GoMigrationFunc{ RunNoTx: m.DownFnNoTxContext, } } registered[version] = g } - migrations, err := merge(sources, registered) + migrations, err := merge(filesystemSources, registered) if err != nil { return nil, err } @@ -139,21 +163,6 @@ func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...Provi }, nil } -// Provider is a goose migration provider. -type Provider struct { - // mu protects all accesses to the provider and must be held when calling operations on the - // database. - mu sync.Mutex - - db *sql.DB - fsys fs.FS - cfg config - store database.Store - - // migrations are ordered by version in ascending order. - migrations []*migration -} - // Status returns the status of all migrations, merging the list of migrations from the database and // filesystem. The returned items are ordered by version, in ascending order. func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { diff --git a/internal/provider/provider_options.go b/internal/provider/provider_options.go index 50b4d3f84..dd29ee4a9 100644 --- a/internal/provider/provider_options.go +++ b/internal/provider/provider_options.go @@ -96,8 +96,8 @@ func WithExcludes(excludes []string) ProviderOption { }) } -// GoMigration is a user-defined Go migration, registered using the option [WithGoMigration]. -type GoMigration struct { +// GoMigrationFunc is a user-defined Go migration, registered using the option [WithGoMigration]. +type GoMigrationFunc struct { // One of the following must be set: Run func(context.Context, *sql.Tx) error // -- OR -- @@ -109,7 +109,7 @@ type GoMigration struct { // If WithGoMigration is called multiple times with the same version, an error is returned. Both up // and down [GoMigration] may be nil. But if set, exactly one of Run or RunNoTx functions must be // set. -func WithGoMigration(version int64, up, down *GoMigration) ProviderOption { +func WithGoMigration(version int64, up, down *GoMigrationFunc) ProviderOption { return configFunc(func(c *config) error { if version < 1 { return errors.New("version must be greater than zero") @@ -143,25 +143,27 @@ func WithGoMigration(version int64, up, down *GoMigration) ProviderOption { }) } -// WithAllowMissing allows the provider to apply missing (out-of-order) migrations. +// WithAllowedMissing allows the provider to apply missing (out-of-order) migrations. By default, +// goose will raise an error if it encounters a missing migration. // // Example: migrations 1,3 are applied and then version 2,6 are introduced. If this option is true, // then goose will apply 2 (missing) and 6 (new) instead of raising an error. The final order of // applied migrations will be: 1,3,2,6. Out-of-order migrations are always applied first, followed // by new migrations. -func WithAllowMissing(b bool) ProviderOption { +func WithAllowedMissing(b bool) ProviderOption { return configFunc(func(c *config) error { c.allowMissing = b return nil }) } -// WithNoVersioning disables versioning. Disabling versioning allows applying migrations without -// tracking the versions in the database schema table. Useful for tests, seeding a database or -// running ad-hoc queries. -func WithNoVersioning(b bool) ProviderOption { +// WithDisabledVersioning disables versioning. Disabling versioning allows applying migrations +// without tracking the versions in the database schema table. Useful for tests, seeding a database +// or running ad-hoc queries. By default, goose will track all versions in the database schema +// table. +func WithDisabledVersioning(b bool) ProviderOption { return configFunc(func(c *config) error { - c.noVersioning = b + c.disableVersioning = b return nil }) } @@ -181,8 +183,8 @@ type config struct { sessionLocker lock.SessionLocker // Feature - noVersioning bool - allowMissing bool + disableVersioning bool + allowMissing bool } type configFunc func(*config) error diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index 81ec517fc..3c1268854 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -35,12 +35,12 @@ func TestProvider(t *testing.T) { check.NoError(t, err) sources := p.ListSources() check.Equal(t, len(sources), 2) - check.Equal(t, sources[0], provider.NewSource(provider.TypeSQL, "001_foo.sql", 1)) - check.Equal(t, sources[1], provider.NewSource(provider.TypeSQL, "002_bar.sql", 2)) + check.Equal(t, sources[0], newSource(provider.TypeSQL, "001_foo.sql", 1)) + check.Equal(t, sources[1], newSource(provider.TypeSQL, "002_bar.sql", 2)) t.Run("duplicate_go", func(t *testing.T) { // Not parallel because it modifies global state. - register := []*provider.Migration{ + register := []*provider.MigrationCopy{ { Version: 1, Source: "00001_users_table.go", Registered: true, UpFnContext: nil, @@ -62,13 +62,13 @@ func TestProvider(t *testing.T) { db := newDB(t) // explicit _, err := provider.NewProvider(database.DialectSQLite3, db, nil, - provider.WithGoMigration(1, &provider.GoMigration{Run: nil}, &provider.GoMigration{Run: nil}), + provider.WithGoMigration(1, &provider.GoMigrationFunc{Run: nil}, &provider.GoMigrationFunc{Run: nil}), ) check.HasError(t, err) check.Contains(t, err.Error(), "go migration with version 1 must have an up function") }) t.Run("duplicate_up", func(t *testing.T) { - err := provider.SetGlobalGoMigrations([]*provider.Migration{ + err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{ { Version: 1, Source: "00001_users_table.go", Registered: true, UpFnContext: func(context.Context, *sql.Tx) error { return nil }, @@ -80,7 +80,7 @@ func TestProvider(t *testing.T) { check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext") }) t.Run("duplicate_down", func(t *testing.T) { - err := provider.SetGlobalGoMigrations([]*provider.Migration{ + err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{ { Version: 1, Source: "00001_users_table.go", Registered: true, DownFnContext: func(context.Context, *sql.Tx) error { return nil }, @@ -92,7 +92,7 @@ func TestProvider(t *testing.T) { check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext") }) t.Run("not_registered", func(t *testing.T) { - err := provider.SetGlobalGoMigrations([]*provider.Migration{ + err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{ { Version: 1, Source: "00001_users_table.go", }, @@ -102,7 +102,7 @@ func TestProvider(t *testing.T) { check.Contains(t, err.Error(), "migration must be registered") }) t.Run("zero_not_allowed", func(t *testing.T) { - err := provider.SetGlobalGoMigrations([]*provider.Migration{ + err := provider.SetGlobalGoMigrations([]*provider.MigrationCopy{ { Version: 0, }, diff --git a/internal/provider/run.go b/internal/provider/run.go index 79a1b2c50..c5f63f13b 100644 --- a/internal/provider/run.go +++ b/internal/provider/run.go @@ -34,13 +34,13 @@ func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*Mi return nil, nil } var apply []*migration - if p.cfg.noVersioning { + if p.cfg.disableVersioning { apply = p.migrations } else { - // optimize(mf): Listing all migrations from the database isn't great. This is only required to - // support the allow missing (out-of-order) feature. For users that don't use this feature, we - // could just query the database for the current max version and then apply migrations greater - // than that version. + // optimize(mf): Listing all migrations from the database isn't great. This is only required + // to support the allow missing (out-of-order) feature. For users that don't use this + // feature, we could just query the database for the current max version and then apply + // migrations greater than that version. dbMigrations, err := p.store.ListMigrations(ctx, conn) if err != nil { return nil, err @@ -76,13 +76,13 @@ func (p *Provider) resolveUpMigrations( dbMaxVersion = m.Version } } - missingMigrations := findMissingMigrations(dbVersions, p.migrations) + missingMigrations := checkMissingMigrations(dbVersions, p.migrations) // feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing // migrations entirely. At the moment this is not supported, but leaving this comment because // that's where that logic would be handled. // - // For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3. - // Not sure if this is a common use case, but it's possible. + // For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3. Not + // sure if this is a common use case, but it's possible. if len(missingMigrations) > 0 && !p.cfg.allowMissing { var collected []string for _, v := range missingMigrations { @@ -127,7 +127,7 @@ func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ [ if len(p.migrations) == 0 { return nil, nil } - if p.cfg.noVersioning { + if p.cfg.disableVersioning { downMigrations := p.migrations if downByOne { last := p.migrations[len(p.migrations)-1] @@ -245,7 +245,7 @@ func (p *Provider) runIndividually( if err := m.run(ctx, tx, direction); err != nil { return err } - if p.cfg.noVersioning { + if p.cfg.disableVersioning { return nil } if direction { @@ -268,7 +268,7 @@ func (p *Provider) runIndividually( return err } } - if p.cfg.noVersioning { + if p.cfg.disableVersioning { return nil } if direction { @@ -329,7 +329,7 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err } // If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't // need the version table because there is no versioning. - if !p.cfg.noVersioning { + if !p.cfg.disableVersioning { if err := p.ensureVersionTable(ctx, conn); err != nil { return nil, nil, multierr.Append(err, cleanup()) } @@ -370,7 +370,7 @@ func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retE if err := p.store.CreateVersionTable(ctx, tx); err != nil { return err } - if p.cfg.noVersioning { + if p.cfg.disableVersioning { return nil } return p.store.Insert(ctx, tx, database.InsertRequest{Version: 0}) @@ -382,9 +382,9 @@ type missingMigration struct { filename string } -// findMissingMigrations returns a list of migrations that are missing from the database. A missing +// checkMissingMigrations returns a list of migrations that are missing from the database. A missing // migration is one that has a version less than the max version in the database. -func findMissingMigrations( +func checkMissingMigrations( dbMigrations []*database.ListMigrationsResult, fsMigrations []*migration, ) []missingMigration { diff --git a/internal/provider/run_test.go b/internal/provider/run_test.go index 8d3e4463f..b88069c87 100644 --- a/internal/provider/run_test.go +++ b/internal/provider/run_test.go @@ -78,24 +78,24 @@ func TestProviderRun(t *testing.T) { res, err := p.Up(ctx) check.NoError(t, err) check.Number(t, len(res), numCount) - assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false) - assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false) - assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up", false) - assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up", false) - assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up", false) - assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up", true) - assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up", true) + assertResult(t, res[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false) + assertResult(t, res[1], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false) + assertResult(t, res[2], newSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up", false) + assertResult(t, res[3], newSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up", false) + assertResult(t, res[4], newSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up", false) + assertResult(t, res[5], newSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up", true) + assertResult(t, res[6], newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up", true) // Test Down res, err = p.DownTo(ctx, 0) check.NoError(t, err) check.Number(t, len(res), numCount) - assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down", true) - assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down", true) - assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down", false) - assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down", false) - assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down", false) - assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down", false) - assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down", false) + assertResult(t, res[0], newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down", true) + assertResult(t, res[1], newSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down", true) + assertResult(t, res[2], newSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down", false) + assertResult(t, res[3], newSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down", false) + assertResult(t, res[4], newSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down", false) + assertResult(t, res[5], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down", false) + assertResult(t, res[6], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "down", false) }) t.Run("up_and_down_by_one", func(t *testing.T) { ctx := context.Background() @@ -149,8 +149,8 @@ func TestProviderRun(t *testing.T) { results, err := p.UpTo(ctx, upToVersion) check.NoError(t, err) check.Number(t, len(results), upToVersion) - assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false) - assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false) + assertResult(t, results[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false) + assertResult(t, results[1], newSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up", false) // Fetch the goose version from DB currentVersion, err := p.GetDBVersion(ctx) check.NoError(t, err) @@ -272,26 +272,26 @@ func TestProviderRun(t *testing.T) { status, err := p.Status(ctx) check.NoError(t, err) check.Number(t, len(status), numCount) - assertStatus(t, status[0], provider.StatePending, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), true) - assertStatus(t, status[1], provider.StatePending, provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), true) - assertStatus(t, status[2], provider.StatePending, provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), true) - assertStatus(t, status[3], provider.StatePending, provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), true) - assertStatus(t, status[4], provider.StatePending, provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), true) - assertStatus(t, status[5], provider.StatePending, provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), true) - assertStatus(t, status[6], provider.StatePending, provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), true) + assertStatus(t, status[0], provider.StatePending, newSource(provider.TypeSQL, "00001_users_table.sql", 1), true) + assertStatus(t, status[1], provider.StatePending, newSource(provider.TypeSQL, "00002_posts_table.sql", 2), true) + assertStatus(t, status[2], provider.StatePending, newSource(provider.TypeSQL, "00003_comments_table.sql", 3), true) + assertStatus(t, status[3], provider.StatePending, newSource(provider.TypeSQL, "00004_insert_data.sql", 4), true) + assertStatus(t, status[4], provider.StatePending, newSource(provider.TypeSQL, "00005_posts_view.sql", 5), true) + assertStatus(t, status[5], provider.StatePending, newSource(provider.TypeSQL, "00006_empty_up.sql", 6), true) + assertStatus(t, status[6], provider.StatePending, newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), true) // Apply all migrations _, err = p.Up(ctx) check.NoError(t, err) status, err = p.Status(ctx) check.NoError(t, err) check.Number(t, len(status), numCount) - assertStatus(t, status[0], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), false) - assertStatus(t, status[1], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), false) - assertStatus(t, status[2], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), false) - assertStatus(t, status[3], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), false) - assertStatus(t, status[4], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), false) - assertStatus(t, status[5], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), false) - assertStatus(t, status[6], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), false) + assertStatus(t, status[0], provider.StateApplied, newSource(provider.TypeSQL, "00001_users_table.sql", 1), false) + assertStatus(t, status[1], provider.StateApplied, newSource(provider.TypeSQL, "00002_posts_table.sql", 2), false) + assertStatus(t, status[2], provider.StateApplied, newSource(provider.TypeSQL, "00003_comments_table.sql", 3), false) + assertStatus(t, status[3], provider.StateApplied, newSource(provider.TypeSQL, "00004_insert_data.sql", 4), false) + assertStatus(t, status[4], provider.StateApplied, newSource(provider.TypeSQL, "00005_posts_view.sql", 5), false) + assertStatus(t, status[5], provider.StateApplied, newSource(provider.TypeSQL, "00006_empty_up.sql", 6), false) + assertStatus(t, status[6], provider.StateApplied, newSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), false) }) t.Run("tx_partial_errors", func(t *testing.T) { countOwners := func(db *sql.DB) (int, error) { @@ -333,7 +333,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3'); check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)") // Check Results field check.Number(t, len(expected.Applied), 1) - assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false) + assertResult(t, expected.Applied[0], newSource(provider.TypeSQL, "00001_users_table.sql", 1), "up", false) // Check Failed field check.Bool(t, expected.Failed != nil, true) assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2) @@ -351,9 +351,9 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-3'); status, err := p.Status(ctx) check.NoError(t, err) check.Number(t, len(status), 3) - assertStatus(t, status[0], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), false) - assertStatus(t, status[1], provider.StatePending, provider.NewSource(provider.TypeSQL, "00002_partial_error.sql", 2), true) - assertStatus(t, status[2], provider.StatePending, provider.NewSource(provider.TypeSQL, "00003_insert_data.sql", 3), true) + assertStatus(t, status[0], provider.StateApplied, newSource(provider.TypeSQL, "00001_users_table.sql", 1), false) + assertStatus(t, status[1], provider.StatePending, newSource(provider.TypeSQL, "00002_partial_error.sql", 2), true) + assertStatus(t, status[2], provider.StatePending, newSource(provider.TypeSQL, "00003_insert_data.sql", 3), true) }) } @@ -488,7 +488,7 @@ func TestNoVersioning(t *testing.T) { ) p, err := provider.NewProvider(database.DialectSQLite3, db, fsys, provider.WithVerbose(testing.Verbose()), - provider.WithNoVersioning(false), // This is the default. + provider.WithDisabledVersioning(false), // This is the default. ) check.Number(t, len(p.ListSources()), 3) check.NoError(t, err) @@ -501,7 +501,7 @@ func TestNoVersioning(t *testing.T) { fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed")) p, err := provider.NewProvider(database.DialectSQLite3, db, fsys, provider.WithVerbose(testing.Verbose()), - provider.WithNoVersioning(true), // Provider with no versioning. + provider.WithDisabledVersioning(true), // Provider with no versioning. ) check.NoError(t, err) check.Number(t, len(p.ListSources()), 2) @@ -553,7 +553,7 @@ func TestAllowMissing(t *testing.T) { t.Run("missing_now_allowed", func(t *testing.T) { db := newDB(t) p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), - provider.WithAllowMissing(false), + provider.WithAllowedMissing(false), ) check.NoError(t, err) @@ -608,7 +608,7 @@ func TestAllowMissing(t *testing.T) { t.Run("missing_allowed", func(t *testing.T) { db := newDB(t) p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), - provider.WithAllowMissing(true), + provider.WithAllowedMissing(true), ) check.NoError(t, err) @@ -703,7 +703,7 @@ func TestGoOnly(t *testing.T) { t.Run("with_tx", func(t *testing.T) { ctx := context.Background() - register := []*provider.Migration{ + register := []*provider.MigrationCopy{ { Version: 1, Source: "00001_users_table.go", Registered: true, UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), @@ -718,8 +718,8 @@ func TestGoOnly(t *testing.T) { p, err := provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithGoMigration( 2, - &provider.GoMigration{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, - &provider.GoMigration{Run: newTxFn("DELETE FROM users")}, + &provider.GoMigrationFunc{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, + &provider.GoMigrationFunc{Run: newTxFn("DELETE FROM users")}, ), ) check.NoError(t, err) @@ -730,29 +730,29 @@ func TestGoOnly(t *testing.T) { // Apply migration 1 res, err := p.UpByOne(ctx) check.NoError(t, err) - assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up", false) + assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "up", false) check.Number(t, countUser(db), 0) check.Bool(t, tableExists(t, db, "users"), true) // Apply migration 2 res, err = p.UpByOne(ctx) check.NoError(t, err) - assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "up", false) + assertResult(t, res, newSource(provider.TypeGo, "", 2), "up", false) check.Number(t, countUser(db), 3) // Rollback migration 2 res, err = p.Down(ctx) check.NoError(t, err) - assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "down", false) + assertResult(t, res, newSource(provider.TypeGo, "", 2), "down", false) check.Number(t, countUser(db), 0) // Rollback migration 1 res, err = p.Down(ctx) check.NoError(t, err) - assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down", false) + assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "down", false) // Check table does not exist check.Bool(t, tableExists(t, db, "users"), false) }) t.Run("with_db", func(t *testing.T) { ctx := context.Background() - register := []*provider.Migration{ + register := []*provider.MigrationCopy{ { Version: 1, Source: "00001_users_table.go", Registered: true, UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), @@ -767,8 +767,8 @@ func TestGoOnly(t *testing.T) { p, err := provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithGoMigration( 2, - &provider.GoMigration{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, - &provider.GoMigration{RunNoTx: newDBFn("DELETE FROM users")}, + &provider.GoMigrationFunc{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, + &provider.GoMigrationFunc{RunNoTx: newDBFn("DELETE FROM users")}, ), ) check.NoError(t, err) @@ -779,23 +779,23 @@ func TestGoOnly(t *testing.T) { // Apply migration 1 res, err := p.UpByOne(ctx) check.NoError(t, err) - assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up", false) + assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "up", false) check.Number(t, countUser(db), 0) check.Bool(t, tableExists(t, db, "users"), true) // Apply migration 2 res, err = p.UpByOne(ctx) check.NoError(t, err) - assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "up", false) + assertResult(t, res, newSource(provider.TypeGo, "", 2), "up", false) check.Number(t, countUser(db), 3) // Rollback migration 2 res, err = p.Down(ctx) check.NoError(t, err) - assertResult(t, res, provider.NewSource(provider.TypeGo, "", 2), "down", false) + assertResult(t, res, newSource(provider.TypeGo, "", 2), "down", false) check.Number(t, countUser(db), 0) // Rollback migration 1 res, err = p.Down(ctx) check.NoError(t, err) - assertResult(t, res, provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down", false) + assertResult(t, res, newSource(provider.TypeGo, "00001_users_table.go", 1), "down", false) // Check table does not exist check.Bool(t, tableExists(t, db, "users"), false) }) @@ -1148,6 +1148,14 @@ func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, } } +func newSource(t provider.MigrationType, fullpath string, version int64) provider.Source { + return provider.Source{ + Type: t, + Path: fullpath, + Version: version, + } +} + func newMapFile(data string) *fstest.MapFile { return &fstest.MapFile{ Data: []byte(data),