Skip to content

Commit

Permalink
[Feature] Make apply up/down work (#14)
Browse files Browse the repository at this point in the history
Added driver function RemoveMigration
Added test migrations for up/down
Made apply up/down work by running all migrations with the type specified
  • Loading branch information
atedesch1 authored May 18, 2022
1 parent d718c7f commit 9994744
Show file tree
Hide file tree
Showing 13 changed files with 201 additions and 47 deletions.
15 changes: 15 additions & 0 deletions applications/logs.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package applications

import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"

Expand Down Expand Up @@ -30,3 +32,16 @@ func GetMigrationNumber(itemName string) (int, error) {
}
return migrationVersion, nil
}

func GetMigrationType(fileName string) (string, error) {
re := regexp.MustCompile(`\.(.*?)\.`) // gets string in between dots
match := re.FindStringSubmatch(fileName)
if len(match) > 1 {
migrationType := match[1]
if migrationType != "up" && migrationType != "down" {
return "", errors.New("migration type should be either up/down")
}
return migrationType, nil
}
return "", errors.New("migration file name in wrong format")
}
197 changes: 158 additions & 39 deletions cmd/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import (
"fmt"
"io/ioutil"
"os"
"path"
"path/filepath"
"sort"
"strings"
"time"

Expand All @@ -15,22 +16,39 @@ import (

type apply struct{}

type MigrationFile struct {
fullPath string
name string
}

type CommandArgs struct {
migrationFiles []MigrationFile
migrationType string
}

type Migrations struct {
files []MigrationFile
isUpType bool
}

func (a *apply) execute(args []string, databaseURL string, driver domain.Driver) error {
migrationType := args[0]
if migrationType != "up" && migrationType != "down" {
return errors.New("Apply's first argument should be either up/down.")
commandArgs, err := parseArgs(args)
if err != nil {
return err
}

isUpMigration := migrationType == "up"

folderName := args[1]
return driver.ExecuteTransaction(databaseURL, func() error {
previousMigrationNumber, err := applications.GetPreviousMigrationNumber(driver)
if err != nil {
return err
}

latestMigrationNumber, err := a.runFolderMigrations(isUpMigration, folderName, previousMigrationNumber, driver)
migrationsToRun, err := getMigrationsToRun(commandArgs)
if err != nil {
return err
}

latestMigrationNumber, err := a.runMigrations(migrationsToRun, previousMigrationNumber, driver)
if err != nil {
return err
}
Expand All @@ -43,57 +61,158 @@ func (a *apply) execute(args []string, databaseURL string, driver domain.Driver)
})
}

func (a *apply) runFolderMigrations(isUpMigration bool, folderName string, previousMigrationNumber int, driver domain.Driver) (int, error) {
latestMigrationNumber := 0
items, err := ioutil.ReadDir(folderName)
func parseArgs(args []string) (CommandArgs, error) {
var commandArgs CommandArgs

if len(args) < 2 {
return commandArgs, errors.New("arguments missing")
}

migrationType := args[0]
if migrationType != "up" && migrationType != "down" {
return commandArgs, errors.New("apply's first argument should be either up/down")
}

dir := args[1]
migrationFiles, err := getMigrationsFiles(dir)
if err != nil {
return 0, err
return commandArgs, err
}

username_service := applications.NewUserNameService()
username, err := username_service.GetUserName()
commandArgs.migrationType = migrationType
commandArgs.migrationFiles = migrationFiles

return commandArgs, nil
}

// reads directory and returns an array containing full paths of files inside
func getMigrationsFiles(dir string) ([]MigrationFile, error) {
migrationFiles := []MigrationFile{}

dirPath, err := filepath.Abs(dir)
if err != nil {
return 0, err
return []MigrationFile{}, err
}
fmt.Println("User detected: " + username)

for _, item := range items {
fileName := item.Name()
fullName := path.Join(folderName, fileName)
fileInfos, err := ioutil.ReadDir(dir)
if err != nil {
return []MigrationFile{}, err
}

for _, fileInfo := range fileInfos {
var migrationFile MigrationFile
migrationFile.name = fileInfo.Name()
migrationFile.fullPath = filepath.Join(dirPath, fileInfo.Name())
migrationFiles = append(migrationFiles, migrationFile)
}

return migrationFiles, err
}

// returns sorted migration files
// if migration of type up orders ascending, descending otherwise
func sortMigrationFiles(files []MigrationFile, isUpType bool) []MigrationFile {
if isUpType {
// sort by ascending
sort.Slice(files, func(i, j int) bool {
iNum, _ := applications.GetMigrationNumber(files[i].name)
jNum, _ := applications.GetMigrationNumber(files[j].name)
return iNum < jNum
})
} else {
// sort by descending
sort.Slice(files, func(i, j int) bool {
iNum, _ := applications.GetMigrationNumber(files[i].name)
jNum, _ := applications.GetMigrationNumber(files[j].name)
return iNum >= jNum
})
}
return files
}

// returns migrations files in folder that match type specified (up/down)
func getMigrationsToRun(args CommandArgs) (Migrations, error) {
var migrations Migrations

isUpType := args.migrationType == "up"
var files []MigrationFile

itemMigrationNumber, err := applications.GetMigrationNumber(fileName)
for _, file := range args.migrationFiles {
migrationType, err := applications.GetMigrationType(file.name)
if err != nil {
continue
return migrations, err
}
if itemMigrationNumber > latestMigrationNumber {
latestMigrationNumber = itemMigrationNumber
}
if itemMigrationNumber <= previousMigrationNumber {
valid, err := validateFileMigration(itemMigrationNumber, fullName, driver)
if err != nil {
return 0, err
}
if !valid {
return 0, fmt.Errorf("❌ invalid migration file %s", fileName)
}
continue

if migrationType == args.migrationType {
files = append(files, file)
}
err = a.applyMigrationScript(driver, fullName)
}

migrations.files = sortMigrationFiles(files, isUpType)
migrations.isUpType = isUpType

return migrations, nil
}

func (a *apply) runMigrations(migrations Migrations, previousMigrationNumber int, driver domain.Driver) (int, error) {
version := previousMigrationNumber

username_service := applications.NewUserNameService()
username, err := username_service.GetUserName()
if err != nil {
return 0, err
}
fmt.Println("User detected: " + username)

for _, file := range migrations.files {
migrationNum, err := applications.GetMigrationNumber(file.name)
if err != nil {
return 0, err
}

currentDate := time.Now().Format("2006-01-02 15:04:05")

hash, err := applications.GetSqlHash(fullName)
hash, err := applications.GetSqlHash(file.fullPath)
if err != nil {
return 0, err
}
err = driver.InsertLatestMigration(latestMigrationNumber, username, currentDate, hash)
if err != nil {
return 0, err

if migrations.isUpType {
if migrationNum == version+1 {
err = a.applyMigrationScript(driver, file.fullPath)
if err != nil {
return 0, err
}

version = version + 1
err = driver.InsertLatestMigration(version, username, currentDate, hash)
if err != nil {
return 0, err
}
} else {
valid, err := validateFileMigration(migrationNum, file.fullPath, driver)
if err != nil {
return 0, err
}
if !valid {
return 0, fmt.Errorf("❌ invalid migration file %s", file.name)
}
}
} else if !migrations.isUpType && migrationNum == version {
err = a.applyMigrationScript(driver, file.fullPath)
if err != nil {
return 0, err
}

err = driver.RemoveMigration(version)
if err != nil {
return 0, err
}
version = version - 1
}
}
return latestMigrationNumber, nil

return version, nil
}

func (a *apply) applyMigrationScript(driver domain.Driver, scriptName string) error {
Expand Down
4 changes: 2 additions & 2 deletions cmd/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ func validateFolderMigrations(folderName string, previousMigrationNumber int, dr
return 0, nil
}

func validateFileMigration(version int, fileName string, driver domain.Driver) (bool, error) {
hash_file, err := applications.GetSqlHash(fileName)
func validateFileMigration(version int, filePath string, driver domain.Driver) (bool, error) {
hash_file, err := applications.GetSqlHash(filePath)
if err != nil {
return false, err
}
Expand Down
4 changes: 2 additions & 2 deletions domain/constants.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package domain


const (
LogsTableName string = "migration_log"
LogsTableName string = "migration_log"
AppliedTableName string = "applied_migrations"
)
1 change: 1 addition & 0 deletions domain/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ type Driver interface {
GetLatestMigration() (int, error)
GetVersionHashing(version int) (string, error)
InsertLatestMigration(int, string, string, string) error
RemoveMigration(int) error
CreateBaseTable() error
HasBaseTable() (bool, error)

Expand Down
10 changes: 8 additions & 2 deletions drivers/mysql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package mysql

import (
"database/sql"
"log"

"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/types"
"log"

_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
Expand Down Expand Up @@ -91,6 +92,11 @@ func (d *mySqlDriver) InsertLatestMigration(version int, username string, curren
return err
}

func (d *mySqlDriver) RemoveMigration(migrationNum int) error {
_, err := d.tx.Exec(`DELETE FROM migration_log WHERE version = $1`, migrationNum)
return err
}

func (d *mySqlDriver) HasBaseTable() (bool, error) {
var installed bool
err := d.tx.QueryRow(`
Expand Down Expand Up @@ -145,7 +151,7 @@ func (x *extractor) parseTable(tableName string, stmt *ast.CreateTableStmt) *dom
columns[c.Name.Name.O] = x.parseColumn(c)
}
return &domain.Table{
Name: tableName,
Name: tableName,
Columns: columns,
}
}
Expand Down
13 changes: 11 additions & 2 deletions drivers/postgres/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func NewPostgresDriver() *postgresDriver {
return &postgresDriver{}
}

func (d *postgresDriver) Deparser() domain.Deparser{
func (d *postgresDriver) Deparser() domain.Deparser {
return &deparser{}
}

Expand Down Expand Up @@ -68,6 +68,9 @@ func (d *postgresDriver) GetLatestMigration() (int, error) {
var version int
err := d.tx.QueryRow(`SELECT version FROM migration_log ORDER BY version DESC LIMIT 1`).Scan(&version)
if err != nil {
if err == sql.ErrNoRows {
return 0, nil
}
return 0, err
}
return version, nil
Expand All @@ -87,6 +90,11 @@ func (d *postgresDriver) InsertLatestMigration(version int, username string, cur
return err
}

func (d *postgresDriver) RemoveMigration(migrationNum int) error {
_, err := d.tx.Exec(`DELETE FROM migration_log WHERE version = $1`, migrationNum)
return err
}

func (d *postgresDriver) HasBaseTable() (bool, error) {
var installed bool
err := d.tx.QueryRow(`SELECT EXISTS (
Expand Down Expand Up @@ -143,10 +151,11 @@ func (d *postgresDriver) parseTable(tableName string, parsedStatement *pg_query.
columns[columnDefinition.Colname] = d.parseColumn(columnDefinition)
}
return &domain.Table{
Name: tableName,
Name: tableName,
Columns: columns,
}
}

func (d *postgresDriver) parseView(parsedStatement *pg_query.ViewStmt) *domain.View {
// TODO
return &domain.View{
Expand Down
1 change: 1 addition & 0 deletions migrations/0001_test_migration.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP TABLE users;
File renamed without changes.
2 changes: 2 additions & 0 deletions migrations/0002_test_migration.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
DROP VIEW user_phones;
ALTER TABLE users DROP COLUMN ddi;
File renamed without changes.
1 change: 1 addition & 0 deletions migrations/0003_test_migration.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE users DROP COLUMN abc;
File renamed without changes.

0 comments on commit 9994744

Please sign in to comment.