From e2a6c74467d47838773963a9376e77310f3376a7 Mon Sep 17 00:00:00 2001 From: Zherebko Dmitry Date: Thu, 12 May 2022 23:21:15 +0300 Subject: [PATCH] add ability to replace golang migration registry with custom one --- migrate.go | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/migrate.go b/migrate.go index 3986746ca..20f4ea2d3 100644 --- a/migrate.go +++ b/migrate.go @@ -125,20 +125,34 @@ func (ms Migrations) String() string { // AddMigration adds a migration. func AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { - _, filename, _, _ := runtime.Caller(1) - AddNamedMigration(filename, up, down) + AddMigrationToRegistry(registeredGoMigrations, up, down) } // AddNamedMigration : Add a named migration. func AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) { + AddNamedMigrationToRegistry(registeredGoMigrations, filename, up, down) +} + +// AddMigration adds a migration. +func AddMigrationToRegistry(registry map[int64]*Migration, up func(*sql.Tx) error, down func(*sql.Tx) error) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationToRegistry(registry, filename, up, down) +} + +// AddNamedMigration : Add a named migration. +func AddNamedMigrationToRegistry(registry map[int64]*Migration, filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) { v, _ := NumericComponent(filename) migration := &Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename} - if existing, ok := registeredGoMigrations[v]; ok { + if existing, ok := registry[v]; ok { panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source)) } - registeredGoMigrations[v] = migration + registry[v] = migration +} + +func SetMigrationsRegistry(m map[int64]*Migration) { + registeredGoMigrations = m } func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Migrations, error) {