Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
nfjBill committed Feb 12, 2022
1 parent a56a249 commit 524e87a
Show file tree
Hide file tree
Showing 87 changed files with 30,721 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,9 @@

# Dependency directories (remove the comment below to include it)
# vendor/

# IDE
.idea

# sys
go.sum
49 changes: 49 additions & 0 deletions clauses/merge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package clauses

import (
"gorm.io/gorm/clause"
)

type Merge struct {
Table clause.Table
Using []clause.Interface
On []clause.Expression
}

func (merge Merge) Name() string {
return "MERGE"
}

func MergeDefaultExcludeName() string {
return "exclude"
}

// Build build from clause
func (merge Merge) Build(builder clause.Builder) {
clause.Insert{}.Build(builder)
builder.WriteString(" USING (")
for idx, iface := range merge.Using {
if idx > 0 {
builder.WriteByte(' ')
}
builder.WriteString(iface.Name())
builder.WriteByte(' ')
iface.Build(builder)
}
builder.WriteString(") ")
builder.WriteString(MergeDefaultExcludeName())
builder.WriteString(" ON (")
for idx, on := range merge.On {
if idx > 0 {
builder.WriteString(", ")
}
on.Build(builder)
}
builder.WriteString(")")
}

// MergeClause merge values clauses
func (merge Merge) MergeClause(clause *clause.Clause) {
clause.Name = merge.Name()
clause.Expression = merge
}
10 changes: 10 additions & 0 deletions clauses/returning_into.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package clauses

import (
"gorm.io/gorm/clause"
)

type ReturningInto struct {
Variables []clause.Column
Into []*clause.Values
}
39 changes: 39 additions & 0 deletions clauses/when_matched.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package clauses

import (
"gorm.io/gorm/clause"
)

type WhenMatched struct {
clause.Set
Where, Delete clause.Where
}

func (w WhenMatched) Name() string {
return "WHEN MATCHED"
}

func (w WhenMatched) Build(builder clause.Builder) {
if len(w.Set) > 0 {
builder.WriteString(" THEN")
builder.WriteString(" UPDATE ")
builder.WriteString(w.Name())
builder.WriteByte(' ')
w.Build(builder)

buildWhere := func(where clause.Where) {
builder.WriteString(where.Name())
builder.WriteByte(' ')
where.Build(builder)
}

if len(w.Where.Exprs) > 0 {
buildWhere(w.Where)
}

if len(w.Delete.Exprs) > 0 {
builder.WriteString(" DELETE ")
buildWhere(w.Delete)
}
}
}
32 changes: 32 additions & 0 deletions clauses/when_not_matched.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package clauses

import (
"gorm.io/gorm/clause"
)

type WhenNotMatched struct {
clause.Values
Where clause.Where
}

func (w WhenNotMatched) Name() string {
return "WHEN NOT MATCHED"
}

func (w WhenNotMatched) Build(builder clause.Builder) {
if len(w.Columns) > 0 {
if len(w.Values.Values) != 1 {
panic("cannot insert more than one rows due to DM SQL language restriction")
}

builder.WriteString(" THEN")
builder.WriteString(" INSERT ")
w.Build(builder)

if len(w.Where.Exprs) > 0 {
builder.WriteString(w.Where.Name())
builder.WriteByte(' ')
w.Where.Build(builder)
}
}
}
154 changes: 154 additions & 0 deletions create.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package dm

import (
"bytes"
"database/sql"
"reflect"

"github.com/thoas/go-funk"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
gormSchema "gorm.io/gorm/schema"

"github.com/nfjBill/gorm-driver-dm/clauses"
)

func Create(db *gorm.DB) {
stmt := db.Statement
schema := stmt.Schema
boundVars := make(map[string]int)

if stmt == nil || schema == nil {
return
}

hasDefaultValues := len(schema.FieldsWithDefaultDBValue) > 0

if !stmt.Unscoped {
for _, c := range schema.CreateClauses {
stmt.AddClause(c)
}
}

if stmt.SQL.String() == "" {
values := callbacks.ConvertToCreateValues(stmt)
onConflict, hasConflict := stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
// are all columns in value the primary fields in schema only?
if hasConflict && funk.Contains(
funk.Map(values.Columns, func(c clause.Column) string { return c.Name }),
funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) string { return field.DBName }),
) {
stmt.AddClauseIfNotExists(clauses.Merge{
Using: []clause.Interface{
clause.Select{
Columns: funk.Map(values.Columns, func(column clause.Column) clause.Column {
// HACK: I can not come up with a better alternative for now
// I want to add a value to the list of variable and then capture the bind variable position as well
buf := bytes.NewBufferString("")
stmt.Vars = append(stmt.Vars, values.Values[0][funk.IndexOf(values.Columns, column)])
stmt.BindVarTo(buf, stmt, nil)

column.Alias = column.Name
// then the captured bind var will be the name
column.Name = buf.String()
return column
}).([]clause.Column),
},
clause.From{
Tables: []clause.Table{{Name: db.Dialector.(Dialector).DummyTableName()}},
},
},
On: funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) clause.Expression {
return clause.Eq{
Column: clause.Column{Table: stmt.Table, Name: field.DBName},
Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName},
}
}).([]clause.Expression),
})
stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates})
stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values})

stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED")
} else {
stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Table}})
stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}})
if hasDefaultValues {
stmt.AddClauseIfNotExists(clause.Returning{
Columns: funk.Map(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) clause.Column {
return clause.Column{Name: field.DBName}
}).([]clause.Column),
})
}
stmt.Build("INSERT", "VALUES")
// 返回自增主键
//stmt.Build("INSERT", "VALUES", "RETURNING")
//if hasDefaultValues {
// stmt.WriteString(" INTO ")
// for idx, field := range schema.FieldsWithDefaultDBValue {
// if idx > 0 {
// stmt.WriteByte(',')
// }
// boundVars[field.Name] = len(stmt.Vars)
// stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()})
// }
//}
}

if !db.DryRun {
for idx, vals := range values.Values {
// HACK HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
for idx, val := range vals {
switch v := val.(type) {
case bool:
if v {
val = 1
} else {
val = 0
}
}

stmt.Vars[idx] = val
}
// and then we insert each row one by one then put the returning values back (i.e. last return id => smart insert)
// we keep track of the index so that the sub-reflected value is also correct

// BIG BUG: what if any of the transactions failed? some result might already be inserted that dm is so
// sneaky that some transaction inserts will exceed the buffer and so will be pushed at unknown point,
// resulting in dangling row entries, so we might need to delete them if an error happens

switch result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); err {
case nil: // success
db.RowsAffected, _ = result.RowsAffected()

insertTo := stmt.ReflectValue
switch insertTo.Kind() {
case reflect.Slice, reflect.Array:
insertTo = insertTo.Index(idx)
}

if hasDefaultValues {
// bind returning value back to reflected value in the respective fields
funk.ForEach(
funk.Filter(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) bool {
return funk.Contains(boundVars, field.Name)
}),
func(field *gormSchema.Field) {
switch insertTo.Kind() {
case reflect.Struct:
if err = field.Set(insertTo, stmt.Vars[boundVars[field.Name]].(sql.Out).Dest); err != nil {
db.AddError(err)
}
case reflect.Map:
// todo 设置id的值
}
},
)
}
default: // failure
db.AddError(err)
}
}
}
}
}
Loading

0 comments on commit 524e87a

Please sign in to comment.