Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored code to have a Provider. #360

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 41 additions & 19 deletions create.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,38 @@ import (
)

type tmplVars struct {
Version string
CamelName string
Version string
CamelName string
PackageName string
ProviderVar string
}

var (
sequential = false
)

// SetSequential set whether to use sequential versioning instead of timestamp based versioning
func SetSequential(s bool) {
sequential = s
defaultProvider.SetSequential(s)
}

// Create writes a new blank migration file.
// SetSequential set's whether to use sequential versioning instead of timestamp based versioning
func (p *Provider) SetSequential(s bool) { p.sequential = s }

// CreateWithTemplate writes a new blank migration file.
func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, migrationType string) error {
var version string
if sequential {
return defaultProvider.CreateWithTemplate(db, dir, tmpl, name, migrationType)
}

// CreateWithTemplate writes a new blank migration file.
func (p *Provider) CreateWithTemplate(_ *sql.DB, dir string, tmpl *template.Template, name, migrationType string) error {
timefn := p.timeFn
if p.timeFn == nil {
timefn = time.Now
}
version := timefn().Format(p.timestampFormat)
if p.baseDir != "" && (dir == "" || dir == ".") {
dir = p.baseDir
}
if p.sequential {
// always use DirFS here because it's modifying operation
migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion)
migrations, err := p.collectMigrationsFS(osFS{}, dir, minVersion, maxVersion)
if err != nil {
return err
}
Expand All @@ -43,8 +56,6 @@ func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, m
} else {
version = fmt.Sprintf(seqVersionTemplate, int64(1))
}
} else {
version = time.Now().Format(timestampFormat)
}

filename := fmt.Sprintf("%v_%v.%v", version, snakeCase(name), migrationType)
Expand All @@ -69,20 +80,27 @@ func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, m
defer f.Close()

vars := tmplVars{
Version: version,
CamelName: camelCase(name),
PackageName: p.packageName,
ProviderVar: p.providerVarName,
Version: version,
CamelName: camelCase(name),
}
if err := tmpl.Execute(f, vars); err != nil {
return fmt.Errorf("failed to execute tmpl: %w", err)
}

log.Printf("Created new file: %s\n", f.Name())
p.log.Printf("Created new file: %s\n", f.Name())
return nil
}

// Create writes a new blank migration file.
func Create(db *sql.DB, dir, name, migrationType string) error {
return CreateWithTemplate(db, dir, nil, name, migrationType)
return defaultProvider.Create(db, dir, name, migrationType)
}

// Create writes a new blank migration file.
func (p *Provider) Create(db *sql.DB, dir, name, migrationType string) error {
return p.CreateWithTemplate(db, dir, nil, name, migrationType)
}

var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Parse(`-- +goose Up
Expand All @@ -96,15 +114,19 @@ SELECT 'down SQL query';
-- +goose StatementEnd
`))

var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package migrations
var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package {{.PackageName}}

import (
"database/sql"
"github.com/pressly/goose/v3"
{{if eq .ProviderVar ""}} "github.com/pressly/goose/v3" {{end}}
)

func init() {
{{- if eq .ProviderVar "" }}
goose.AddMigration(up{{.CamelName}}, down{{.CamelName}})
{{- else }}
{{.ProviderVar}}.AddMigration(up{{.CamelName}}, down{{.CamelName}})
{{end }}
}

func up{{.CamelName}}(tx *sql.Tx) error {
Expand Down
8 changes: 7 additions & 1 deletion db.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@ import (
// OpenDBWithDriver creates a connection to a database, and modifies goose
// internals to be compatible with the supplied driver by calling SetDialect.
func OpenDBWithDriver(driver string, dbstring string) (*sql.DB, error) {
if err := SetDialect(driver); err != nil {
return defaultProvider.OpenDBWithDriver(driver, dbstring)
}

// OpenDBWithDriver creates a connection to a database, and modifies goose
// internals to be compatible with the supplied driver by calling SetDialect.
func (p *Provider) OpenDBWithDriver(driver string, dbstring string) (*sql.DB, error) {
if err := p.SetDialect(driver); err != nil {
return nil, err
}

Expand Down
Loading