Skip to content

Commit

Permalink
fix clauses
Browse files Browse the repository at this point in the history
  • Loading branch information
nfjBill committed Feb 22, 2022
1 parent 0aecfe8 commit 3a396bc
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 32 deletions.
7 changes: 6 additions & 1 deletion clauses/when_not_matched.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (w WhenNotMatched) Build(builder clause.Builder) {

builder.WriteString(" THEN")
builder.WriteString(" INSERT ")
w.Build(builder)
w.Values.Build(builder)

if len(w.Where.Exprs) > 0 {
builder.WriteString(w.Where.Name())
Expand All @@ -30,3 +30,8 @@ func (w WhenNotMatched) Build(builder clause.Builder) {
}
}
}

func (w WhenNotMatched) MergeClause(clause *clause.Clause) {
clause.Name = w.Name()
clause.Expression = w
}
18 changes: 8 additions & 10 deletions create.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ func Create(db *gorm.DB) {
if stmt.SQL.String() == "" {
values := callbacks.ConvertToCreateValues(stmt)
onConflict, hasConflict := stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
// are all columns in value the primary fields in schema only?
if hasConflict && funk.Contains(
funk.Map(values.Columns, func(c clause.Column) string { return c.Name }),
funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) string { return field.DBName }),
) {
if hasConflict {
stmt.AddClauseIfNotExists(clauses.Merge{
Using: []clause.Interface{
clause.Select{
Expand All @@ -56,16 +52,17 @@ func Create(db *gorm.DB) {
}).([]clause.Column),
},
clause.From{
Tables: []clause.Table{{Name: db.Dialector.(Dialector).DummyTableName()}},
Tables: []clause.Table{{Name: "DUAL"}},
},
},
On: funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) clause.Expression {
On: funk.Map(onConflict.Columns, func(field clause.Column) clause.Expression {
return clause.Eq{
Column: clause.Column{Table: stmt.Table, Name: field.DBName},
Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName},
Column: clause.Column{Table: stmt.Table, Name: field.Name},
Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.Name},
}
}).([]clause.Expression),
})

stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates})
stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values})

Expand Down Expand Up @@ -117,7 +114,8 @@ func Create(db *gorm.DB) {
// sneaky that some transaction inserts will exceed the buffer and so will be pushed at unknown point,
// resulting in dangling row entries, so we might need to delete them if an error happens

switch result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); err {
sqlStr := stmt.Explain(stmt.SQL.String(), stmt.Vars...)
switch result, err := stmt.ConnPool.ExecContext(stmt.Context, sqlStr, stmt.Vars...); err {
case nil: // success
db.RowsAffected, _ = result.RowsAffected()

Expand Down
3 changes: 2 additions & 1 deletion dm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var db *gorm.DB
func init() {
var err error
//dsn := "dm://sysdba:[email protected]:5236?autoCommit=true"
dsn := "dm://sysdba:SYSDBA@fe-repo.inner.px.nfjbill.ren:5237?autoCommit=true"
dsn := "dm://sysdba:SYSDBA@192.168.0.105:5236?autoCommit=true"
db, err = gorm.Open(Open(dsn), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
Expand Down Expand Up @@ -134,6 +134,7 @@ func TestDelete(t *testing.T) {
// err
func TestClausesAssignmentColumns(t *testing.T) {
err := Table(&User{Key: "2", Content: "EEE"}).ClausesAssignmentColumns("KEY", []string{"DELETED_AT", "CONTENT"})
err = Table(&User{Key: "4", Name: "Jinzhu5", Content: "FFF", Birthday: time.Now()}).ClausesAssignmentColumns("KEY", []string{"DELETED_AT", "CONTENT", "BIRTHDAY"})

if err != nil {
fmt.Printf("Error: failed to ClausesAssignmentColumns: %v\n", err)
Expand Down
25 changes: 5 additions & 20 deletions mask.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dm
import (
"database/sql"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"reflect"
"strings"
)
Expand Down Expand Up @@ -54,26 +55,10 @@ func (stb *STable) GetAll(dest interface{}) error {
}

func (stb *STable) ClausesAssignmentColumns(name string, doUpdates []string) error {
we := RefInclude(RefClone(stb.Table), []string{name})
up := RefInclude(RefClone(stb.Table), doUpdates)

var data []interface{}
err := stb.Conn.Model(we).Select("ID").Where(we).Find(&data).Error
if len(data) > 0 {
tx := stb.Conn.Begin()
if err = tx.Where(we).Updates(up).Error; err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
} else {
return stb.Conn.Create(stb.Table).Error
}

//return stb.Conn.Clauses(clause.OnConflict{
// Columns: []clause.Column{{Name: name}},
// DoUpdates: clause.AssignmentColumns(doUpdates),
//}).Create(stb.Table).Error
return stb.Conn.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: name}},
DoUpdates: clause.AssignmentColumns(doUpdates),
}).Create(stb.Table).Error
}

func (stb *STable) Delete() error {
Expand Down

0 comments on commit 3a396bc

Please sign in to comment.