-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
87 changed files
with
30,721 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,9 @@ | |
|
||
# Dependency directories (remove the comment below to include it) | ||
# vendor/ | ||
|
||
# IDE | ||
.idea | ||
|
||
# sys | ||
go.sum |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package clauses | ||
|
||
import ( | ||
"gorm.io/gorm/clause" | ||
) | ||
|
||
type Merge struct { | ||
Table clause.Table | ||
Using []clause.Interface | ||
On []clause.Expression | ||
} | ||
|
||
func (merge Merge) Name() string { | ||
return "MERGE" | ||
} | ||
|
||
func MergeDefaultExcludeName() string { | ||
return "exclude" | ||
} | ||
|
||
// Build build from clause | ||
func (merge Merge) Build(builder clause.Builder) { | ||
clause.Insert{}.Build(builder) | ||
builder.WriteString(" USING (") | ||
for idx, iface := range merge.Using { | ||
if idx > 0 { | ||
builder.WriteByte(' ') | ||
} | ||
builder.WriteString(iface.Name()) | ||
builder.WriteByte(' ') | ||
iface.Build(builder) | ||
} | ||
builder.WriteString(") ") | ||
builder.WriteString(MergeDefaultExcludeName()) | ||
builder.WriteString(" ON (") | ||
for idx, on := range merge.On { | ||
if idx > 0 { | ||
builder.WriteString(", ") | ||
} | ||
on.Build(builder) | ||
} | ||
builder.WriteString(")") | ||
} | ||
|
||
// MergeClause merge values clauses | ||
func (merge Merge) MergeClause(clause *clause.Clause) { | ||
clause.Name = merge.Name() | ||
clause.Expression = merge | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package clauses | ||
|
||
import ( | ||
"gorm.io/gorm/clause" | ||
) | ||
|
||
type ReturningInto struct { | ||
Variables []clause.Column | ||
Into []*clause.Values | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package clauses | ||
|
||
import ( | ||
"gorm.io/gorm/clause" | ||
) | ||
|
||
type WhenMatched struct { | ||
clause.Set | ||
Where, Delete clause.Where | ||
} | ||
|
||
func (w WhenMatched) Name() string { | ||
return "WHEN MATCHED" | ||
} | ||
|
||
func (w WhenMatched) Build(builder clause.Builder) { | ||
if len(w.Set) > 0 { | ||
builder.WriteString(" THEN") | ||
builder.WriteString(" UPDATE ") | ||
builder.WriteString(w.Name()) | ||
builder.WriteByte(' ') | ||
w.Build(builder) | ||
|
||
buildWhere := func(where clause.Where) { | ||
builder.WriteString(where.Name()) | ||
builder.WriteByte(' ') | ||
where.Build(builder) | ||
} | ||
|
||
if len(w.Where.Exprs) > 0 { | ||
buildWhere(w.Where) | ||
} | ||
|
||
if len(w.Delete.Exprs) > 0 { | ||
builder.WriteString(" DELETE ") | ||
buildWhere(w.Delete) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package clauses | ||
|
||
import ( | ||
"gorm.io/gorm/clause" | ||
) | ||
|
||
type WhenNotMatched struct { | ||
clause.Values | ||
Where clause.Where | ||
} | ||
|
||
func (w WhenNotMatched) Name() string { | ||
return "WHEN NOT MATCHED" | ||
} | ||
|
||
func (w WhenNotMatched) Build(builder clause.Builder) { | ||
if len(w.Columns) > 0 { | ||
if len(w.Values.Values) != 1 { | ||
panic("cannot insert more than one rows due to DM SQL language restriction") | ||
} | ||
|
||
builder.WriteString(" THEN") | ||
builder.WriteString(" INSERT ") | ||
w.Build(builder) | ||
|
||
if len(w.Where.Exprs) > 0 { | ||
builder.WriteString(w.Where.Name()) | ||
builder.WriteByte(' ') | ||
w.Where.Build(builder) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
package dm | ||
|
||
import ( | ||
"bytes" | ||
"database/sql" | ||
"reflect" | ||
|
||
"github.com/thoas/go-funk" | ||
"gorm.io/gorm" | ||
"gorm.io/gorm/callbacks" | ||
"gorm.io/gorm/clause" | ||
gormSchema "gorm.io/gorm/schema" | ||
|
||
"github.com/nfjBill/gorm-driver-dm/clauses" | ||
) | ||
|
||
func Create(db *gorm.DB) { | ||
stmt := db.Statement | ||
schema := stmt.Schema | ||
boundVars := make(map[string]int) | ||
|
||
if stmt == nil || schema == nil { | ||
return | ||
} | ||
|
||
hasDefaultValues := len(schema.FieldsWithDefaultDBValue) > 0 | ||
|
||
if !stmt.Unscoped { | ||
for _, c := range schema.CreateClauses { | ||
stmt.AddClause(c) | ||
} | ||
} | ||
|
||
if stmt.SQL.String() == "" { | ||
values := callbacks.ConvertToCreateValues(stmt) | ||
onConflict, hasConflict := stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict) | ||
// are all columns in value the primary fields in schema only? | ||
if hasConflict && funk.Contains( | ||
funk.Map(values.Columns, func(c clause.Column) string { return c.Name }), | ||
funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) string { return field.DBName }), | ||
) { | ||
stmt.AddClauseIfNotExists(clauses.Merge{ | ||
Using: []clause.Interface{ | ||
clause.Select{ | ||
Columns: funk.Map(values.Columns, func(column clause.Column) clause.Column { | ||
// HACK: I can not come up with a better alternative for now | ||
// I want to add a value to the list of variable and then capture the bind variable position as well | ||
buf := bytes.NewBufferString("") | ||
stmt.Vars = append(stmt.Vars, values.Values[0][funk.IndexOf(values.Columns, column)]) | ||
stmt.BindVarTo(buf, stmt, nil) | ||
|
||
column.Alias = column.Name | ||
// then the captured bind var will be the name | ||
column.Name = buf.String() | ||
return column | ||
}).([]clause.Column), | ||
}, | ||
clause.From{ | ||
Tables: []clause.Table{{Name: db.Dialector.(Dialector).DummyTableName()}}, | ||
}, | ||
}, | ||
On: funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) clause.Expression { | ||
return clause.Eq{ | ||
Column: clause.Column{Table: stmt.Table, Name: field.DBName}, | ||
Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName}, | ||
} | ||
}).([]clause.Expression), | ||
}) | ||
stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates}) | ||
stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values}) | ||
|
||
stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED") | ||
} else { | ||
stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Table}}) | ||
stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}}) | ||
if hasDefaultValues { | ||
stmt.AddClauseIfNotExists(clause.Returning{ | ||
Columns: funk.Map(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) clause.Column { | ||
return clause.Column{Name: field.DBName} | ||
}).([]clause.Column), | ||
}) | ||
} | ||
stmt.Build("INSERT", "VALUES") | ||
// 返回自增主键 | ||
//stmt.Build("INSERT", "VALUES", "RETURNING") | ||
//if hasDefaultValues { | ||
// stmt.WriteString(" INTO ") | ||
// for idx, field := range schema.FieldsWithDefaultDBValue { | ||
// if idx > 0 { | ||
// stmt.WriteByte(',') | ||
// } | ||
// boundVars[field.Name] = len(stmt.Vars) | ||
// stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()}) | ||
// } | ||
//} | ||
} | ||
|
||
if !db.DryRun { | ||
for idx, vals := range values.Values { | ||
// HACK HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned | ||
for idx, val := range vals { | ||
switch v := val.(type) { | ||
case bool: | ||
if v { | ||
val = 1 | ||
} else { | ||
val = 0 | ||
} | ||
} | ||
|
||
stmt.Vars[idx] = val | ||
} | ||
// and then we insert each row one by one then put the returning values back (i.e. last return id => smart insert) | ||
// we keep track of the index so that the sub-reflected value is also correct | ||
|
||
// BIG BUG: what if any of the transactions failed? some result might already be inserted that dm is so | ||
// sneaky that some transaction inserts will exceed the buffer and so will be pushed at unknown point, | ||
// resulting in dangling row entries, so we might need to delete them if an error happens | ||
|
||
switch result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); err { | ||
case nil: // success | ||
db.RowsAffected, _ = result.RowsAffected() | ||
|
||
insertTo := stmt.ReflectValue | ||
switch insertTo.Kind() { | ||
case reflect.Slice, reflect.Array: | ||
insertTo = insertTo.Index(idx) | ||
} | ||
|
||
if hasDefaultValues { | ||
// bind returning value back to reflected value in the respective fields | ||
funk.ForEach( | ||
funk.Filter(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) bool { | ||
return funk.Contains(boundVars, field.Name) | ||
}), | ||
func(field *gormSchema.Field) { | ||
switch insertTo.Kind() { | ||
case reflect.Struct: | ||
if err = field.Set(insertTo, stmt.Vars[boundVars[field.Name]].(sql.Out).Dest); err != nil { | ||
db.AddError(err) | ||
} | ||
case reflect.Map: | ||
// todo 设置id的值 | ||
} | ||
}, | ||
) | ||
} | ||
default: // failure | ||
db.AddError(err) | ||
} | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.