Skip to content

Commit

Permalink
feat: custom unset error add context
Browse files Browse the repository at this point in the history
  • Loading branch information
SpectatorNan committed Aug 1, 2023
1 parent 4f789e7 commit 0f838eb
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 30 deletions.
13 changes: 12 additions & 1 deletion core/mapping/unmarshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ func NewUnmarshaler(key string, opts ...UnmarshalOption) *Unmarshaler {
return &unmarshaler
}

func WithOpts(u *Unmarshaler, opts ...UnmarshalOption) *Unmarshaler {
if u == nil {
return u
}
for _, opt := range opts {
opt(&u.opts)
}

return u
}

// UnmarshalKey unmarshals m into v with tag key.
func UnmarshalKey(m map[string]any, v any) error {
return keyUnmarshaler.Unmarshal(m, v)
Expand Down Expand Up @@ -929,7 +940,7 @@ func WithCanonicalKeyFunc(f func(string) string) UnmarshalOption {
}

// WithCustomFieldUnsetErr customizes an Unmarshaler with custom field unset error.
func WithCustomFieldUnsetErr(f func(string) error) UnmarshalOption {
func WithCustomFieldUnsetErr(f func(fullName string) error) UnmarshalOption {
return func(opt *unmarshalOptions) {
opt.customFieldUnsetErr = f
}
Expand Down
33 changes: 21 additions & 12 deletions rest/httpx/requests.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package httpx

import (
"context"
"io"
"net/http"
"strings"
Expand All @@ -26,7 +27,7 @@ var (
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
validator atomic.Value
customFieldUnsetErr func(key string) error
customFieldUnsetErr func(ctx context.Context, key string) error
)

// Validator defines the interface for validating the request.
Expand Down Expand Up @@ -73,8 +74,8 @@ func ParseForm(r *http.Request, v any) error {
if err != nil {
return err
}

return formUnmarshaler.Unmarshal(params, v)
unmarshaler := mapping.WithOpts(formUnmarshaler, getUnmarshalOptions(r)...)
return unmarshaler.Unmarshal(params, v)
}

// ParseHeader parses the request header and returns a map.
Expand All @@ -101,10 +102,7 @@ func ParseHeader(headerValue string) map[string]string {

// ParseJsonBody parses the post request which contains json in body.
func ParseJsonBody(r *http.Request, v any) error {
var opts []mapping.UnmarshalOption
if customFieldUnsetErr != nil {
opts = append(opts, mapping.WithCustomFieldUnsetErr(customFieldUnsetErr))
}
opts := getUnmarshalOptions(r)
if withJsonBody(r) {
reader := io.LimitReader(r.Body, maxBodyLen)
return mapping.UnmarshalJsonReader(reader, v, opts...)
Expand All @@ -121,8 +119,8 @@ func ParsePath(r *http.Request, v any) error {
for k, v := range vars {
m[k] = v
}

return pathUnmarshaler.Unmarshal(m, v)
unmarshaler := mapping.WithOpts(pathUnmarshaler, getUnmarshalOptions(r)...)
return unmarshaler.Unmarshal(m, v)
}

// SetValidator sets the validator.
Expand All @@ -136,8 +134,19 @@ func withJsonBody(r *http.Request) bool {
return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson)
}

func SetCustomUnsetError(f func(string) error) {
func getUnmarshalOptions(r *http.Request) []mapping.UnmarshalOption {
var opts []mapping.UnmarshalOption
if customFieldUnsetErr != nil {
unsetErrFun := func(key string) error {
return customFieldUnsetErr(r.Context(), key)
}
opts = append(opts, mapping.WithCustomFieldUnsetErr(unsetErrFun))
}
return opts
}

func SetCustomUnsetError(f func(ctx context.Context, fullName string) error) {
customFieldUnsetErr = f
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues(), mapping.WithCustomFieldUnsetErr(f))
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues(), mapping.WithCustomFieldUnsetErr(f))
//formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues(), mapping.WithCustomFieldUnsetErr(f))
//pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues(), mapping.WithCustomFieldUnsetErr(f))
}
44 changes: 27 additions & 17 deletions rest/httpx/requests_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package httpx

import (
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -259,26 +260,35 @@ func TestParseJsonBody(t *testing.T) {
}

func TestParseCustomUnsetErr(t *testing.T) {
SetCustomUnsetError(func(tag string) error {
return fmt.Errorf("custom %s unset error", tag)
startCtx := context.Background()
ctxKey := "method"

SetCustomUnsetError(func(ctx context.Context, tag string) error {
return fmt.Errorf("%s: custom %s unset error", ctx.Value(ctxKey).(string), tag)
})
v := struct {
Name string `form:"name"`
Percent float64 `form:"percent"`
}{}
t.Run("request get", func(t *testing.T) {
v := struct {
Name string `form:"name"`
Percent float64 `form:"percent"`
}{}

gr, err := http.NewRequest(http.MethodGet, "/a?name=hello", http.NoBody)
assert.Nil(t, err)
assert.EqualErrorf(t, Parse(gr, &v), "custom percent unset error", "custom unset error")
gr, err := http.NewRequest(http.MethodGet, "/a?name=hello", http.NoBody)
gr = gr.WithContext(context.WithValue(startCtx, ctxKey, "GET"))
assert.Nil(t, err)
assert.EqualErrorf(t, Parse(gr, &v), "GET: custom percent unset error", "custom unset error")
})

pv := struct {
Name string `json:"name"`
Percent float64 `json:"percent"`
}{}
body := `{"name":"hello"}`
pr := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
pr.Header.Set(ContentType, header.JsonContentType)
assert.EqualErrorf(t, Parse(pr, &pv), "custom percent unset error", "custom unset error")
t.Run("request post", func(t *testing.T) {
pv := struct {
Name string `json:"name"`
Percent float64 `json:"percent"`
}{}
body := `{"name":"hello"}`
pr := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
pr.Header.Set(ContentType, header.JsonContentType)
pr = pr.WithContext(context.WithValue(startCtx, ctxKey, "POST"))
assert.EqualErrorf(t, Parse(pr, &pv), "POST: custom percent unset error", "custom unset error")
})
}

func TestParseRequired(t *testing.T) {
Expand Down

0 comments on commit 0f838eb

Please sign in to comment.