Skip to content

Commit

Permalink
feat(controller+service/get): GetField with conditions
Browse files Browse the repository at this point in the history
- add service.GetAssociations, service.CountAssociations: query associations
- improve controller.GetFieldHandler: use GetAssociations instead of preload to supports
GetField with conditions, pagination and ordering (GetRequestBody)
  • Loading branch information
cdfmlr committed Jul 21, 2022
1 parent 93d1fd8 commit 669573f
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 12 deletions.
82 changes: 71 additions & 11 deletions controller/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package controller

import (
"context"
"fmt"
"github.com/cdfmlr/crud/orm"
"github.com/cdfmlr/crud/service"
"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -31,7 +32,7 @@ func GetListHandler[T any]() gin.HandlerFunc {
return
}

options := buildQueryOptions[T](request)
options := buildQueryOptions(request)

var dest []*T
err := service.GetMany[T](c, &dest, options...)
Expand Down Expand Up @@ -69,7 +70,7 @@ func GetByIDHandler[T orm.Model](idParam string) gin.HandlerFunc {
return
}

options := buildQueryOptions[T](request)
options := buildQueryOptions(request)

dest, err := getModelByID[T](c, idParam, options...)
if err != nil {
Expand All @@ -84,8 +85,18 @@ func GetByIDHandler[T orm.Model](idParam string) gin.HandlerFunc {

// GetFieldHandler handles
// GET /T/:idParam/field
// All GetRequestBody will be conditions for the field, for example:
// GET /user/123/order?preload=Product
// Preloads User.Order.Product instead of User.Product.
func GetFieldHandler[T orm.Model](idParam string, field string) gin.HandlerFunc {
return func(c *gin.Context) {
logger.WithContext(c).
WithField("model", fmt.Sprintf("%T", *new(T))).
WithField("idParam", idParam).
WithField("field", field).
Trace("GetFieldHandler")

// 0. bind request options
var request GetRequestBody
if err := c.ShouldBind(&request); err != nil {
logger.WithContext(c).WithError(err).
Expand All @@ -94,33 +105,82 @@ func GetFieldHandler[T orm.Model](idParam string, field string) gin.HandlerFunc
return
}

model, err := getModelByID[T](c, idParam, service.PreloadAll())
// 1. check the model exists?
model, err := getModelByID[T](c, idParam)
if err != nil {
logger.WithContext(c).WithError(err).
Warn("GetFieldHandler: getModelByID failed")
ResponseError(c, CodeProcessFailed, err)
return
}
logger.WithField("model", model).Debug("GetFieldHandler: model found")

//field := strings.ToUpper(field)[:1] + field[1:]
// 2. find out the field's type F
// If F is Struct / *Struct / []Struct,
// then we GetAssociations to query it.
// Otherwise, we just return field value from getModelByID.
field := NameToField(field, model)

// model.field
fieldValue := reflect.ValueOf(model).
Elem(). // because model is a pointer
FieldByName(field)

// TODO other GetRequestBody options
// use subquery to get children models instead of preload
if fieldValue.Type().Kind() == reflect.Ptr {
fieldValue = fieldValue.Elem() // *F => F
}

var elemType reflect.Type
switch fieldValue.Type().Kind() {
case reflect.Slice, reflect.Array:
elemType = fieldValue.Type().Elem() // []F => F
default:
elemType = fieldValue.Type() // keep F
}
// for []*Struct, we have to unwrap it again
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem() // *F => F
}

//logger.WithField("fieldType", fieldValue.Type().Kind()).
// WithField("elemType", elemType).
// Debug("GetFieldHandler: elemType found")

if elemType.Kind() != reflect.Struct {
// not a model, return the value
ResponseSuccess(c, fieldValue.Interface())
return
}

// Slice or Struct
// 3. GetAssociations
dest := fieldValue.Interface()
err = service.GetAssociations(c, model, field, &dest, buildQueryOptions(request)...)
if err != nil {
logger.WithContext(c).WithError(err).
Warn("GetFieldHandler: GetAssociations failed")
ResponseError(c, CodeProcessFailed, err)
return
}

// 4. Count
var addition []gin.H
if request.Total && fieldValue.Kind() == reflect.Slice {
addition = append(addition, gin.H{"total": fieldValue.Len()})
if request.Total {
total, err := service.CountAssociations(c, model, field, buildQueryOptions(request)...)
if err != nil {
logger.WithContext(c).WithError(err).
Warn("GetFieldHandler: CountAssociations failed")
addition = append(addition, gin.H{"totalError": err.Error()})
} else {
addition = append(addition, gin.H{"total": total})
}
}

ResponseSuccess(c, fieldValue.Interface(), addition...)
ResponseSuccess(c, dest)
}
}

func buildQueryOptions[T any](request GetRequestBody) []service.QueryOption {
func buildQueryOptions(request GetRequestBody) []service.QueryOption {
var options []service.QueryOption
if request.Limit > 0 {
options = append(options, service.WithPage(request.Limit, request.Offset))
Expand All @@ -132,7 +192,7 @@ func buildQueryOptions[T any](request GetRequestBody) []service.QueryOption {
options = append(options, service.FilterBy(request.FilterBy, request.FilterValue))
}
for _, field := range request.Preload {
logger.WithField("field", field).Debug("Preload field")
// logger.WithField("field", field).Debug("Preload field")
options = append(options, service.Preload(field))
}
return options
Expand Down
2 changes: 1 addition & 1 deletion log/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func WithReportCaller(reportCaller bool) LoggerOption {

func WithHook(hook logrus.Hook) LoggerOption {
return func(logger *logrus.Logger) {
logger.Debugf("WithHook: %v", hook)
//logger.Debugf("WithHook: %v", hook)
logger.AddHook(hook)
}
}
Expand Down
37 changes: 37 additions & 0 deletions service/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,43 @@ func Count[T any](ctx context.Context, options ...QueryOption) (count int64, err
return count, ret.Error
}

// GetAssociations find matched associations (model.field) into dest.
func GetAssociations(ctx context.Context, model any, field string, dest any, options ...QueryOption) error {
logger := logger.WithContext(ctx).
WithField("model", fmt.Sprintf("%T", model)).
WithField("field", field).
WithField("dest", fmt.Sprintf("%T", dest))

logger.Trace("GetAssociation: Get association into dest")

err := associationQuery(ctx, model, field, options...).Find(dest)
if err != nil {
logger.WithError(err).
Warn("GetAssociation: Get association into dest failed")
}
return err
}

// CountAssociations count matched associations (model.field).
func CountAssociations(ctx context.Context, model any, field string, options ...QueryOption) (count int64, err error) {
logger.WithContext(ctx).
WithField("model", fmt.Sprintf("%T", model)).
WithField("field", field).
Trace("CountAssociations: Count associations")

count = associationQuery(ctx, model, field, options...).Count()
return count, err
}

// associationQuery builds a gorm association query
func associationQuery(ctx context.Context, model any, field string, options ...QueryOption) *gorm.Association {
query := orm.DB.WithContext(ctx).Model(model)
for _, option := range options {
query = option(query)
}
return query.Association(field)
}

// QueryOption is a function that can be used to construct a query.
type QueryOption func(tx *gorm.DB) *gorm.DB

Expand Down

0 comments on commit 669573f

Please sign in to comment.