Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent interface type array from causing runtime errors #7361

Closed
wants to merge 10 commits into from
14 changes: 7 additions & 7 deletions finisher_api.go
Original file line number Diff line number Diff line change
@@ -19,10 +19,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize)
}

tx = db.getInstance()
tx.Statement.Dest = value
return tx.callbacks.Create().Execute(tx)
return db.create(value)
}

// CreateInBatches inserts value in batches of batchSize
@@ -63,12 +60,15 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {

tx.RowsAffected = rowsAffected
default:
tx = db.getInstance()
tx.Statement.Dest = value
tx = tx.callbacks.Create().Execute(tx)
db.create(value)
}
return
}
func (db *DB) create(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
return tx.callbacks.Create().Execute(tx)
}

// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
func (db *DB) Save(value interface{}) (tx *DB) {
7 changes: 6 additions & 1 deletion scan.go
Original file line number Diff line number Diff line change
@@ -202,6 +202,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
switch reflectValueType.Kind() {
case reflect.Array, reflect.Slice:
reflectValueType = reflectValueType.Elem()
if reflectValueType.Kind() == reflect.Interface && reflectValue.Len() > 0 {
reflectValueType = reflect.Indirect(reflectValue.Index(0)).Elem().Type()
}
}
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
@@ -318,7 +321,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
} else {
elem = reflect.New(reflectValueType)
}

if elem.Type().Kind() == reflect.Interface {
elem = elem.Elem()
}
db.scanIntoStruct(rows, elem, values, fields, joinFields)

if !update {
3 changes: 3 additions & 0 deletions schema/field.go
Original file line number Diff line number Diff line change
@@ -462,6 +462,9 @@ func (field *Field) setupValuerAndSetter() {
default:
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
if v.Kind() == reflect.Interface {
v = reflect.Indirect(v)
}
for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
4 changes: 3 additions & 1 deletion schema/schema.go
Original file line number Diff line number Diff line change
@@ -136,8 +136,10 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam

for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
if modelType.Kind() == reflect.Interface && value.Len() > 0 {
modelType = reflect.Indirect(value.Index(0)).Elem().Type()
}
}

if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
62 changes: 62 additions & 0 deletions tests/create_test.go
Original file line number Diff line number Diff line change
@@ -791,3 +791,65 @@ func TestCreateFromMapWithTable(t *testing.T) {
t.Errorf("failed to create data from map with table, @id != id")
}
}

func TestCreateWithInterfaceType(t *testing.T) {
user := *GetUser("create", Config{})
type UserInterface interface{}
var userInterface UserInterface = &user

if results := DB.Create(userInterface); results.Error != nil {
t.Fatalf("errors happened when create: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
}

if user.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
}

if user.CreatedAt.IsZero() {
t.Errorf("user's created at should be not zero")
}

if user.UpdatedAt.IsZero() {
t.Errorf("user's updated at should be not zero")
}

var newUser User
if err := DB.Where("id = ?", user.ID).First(&newUser).Error; err != nil {
t.Fatalf("errors happened when query: %v", err)
} else {
CheckUser(t, newUser, user)
}
}

func TestCreateWithInterfaceArrayTypeWithTable(t *testing.T) {
user := *GetUser("create", Config{})
type UserInterface interface{}
var userInterface UserInterface = user

if results := DB.Table("users").Create([]UserInterface{userInterface}); results.Error != nil {
t.Fatalf("errors happened when create: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
}

if user.ID == 0 {
t.Errorf("user's primary key should has value after create, got : %v", user.ID)
}

if user.CreatedAt.IsZero() {
t.Errorf("user's created at should be not zero")
}

if user.UpdatedAt.IsZero() {
t.Errorf("user's updated at should be not zero")
}

var newUser User
if err := DB.Table("users").Where("id = ?", user.ID).First(&newUser).Error; err != nil {
t.Fatalf("errors happened when query: %v", err)
} else {
CheckUser(t, newUser, user)
}
}