diff --git a/rest/config.go b/rest/config.go index 2d86086b6596..61364155c3c7 100644 --- a/rest/config.go +++ b/rest/config.go @@ -39,16 +39,28 @@ type ( JWTConf struct { AccessSecret string AccessExpire int64 - // extract a jwt from custom request header or url arguments - TokenKeys []string `json:",optional"` + // TokenLookup is a slice in the form of ":" that is used + // to extract token from the request. + // Optional. + // Possible values: + // - "header:" + // - "query:" + // - "form:" + TokenLookup []string `json:",optional"` } // A JWTTransConf is a jwtTrans config. JWTTransConf struct { Secret string PrevSecret string - // extract a jwt from custom request header or url arguments - TokenKeys []string `json:",optional"` + // TokenLookup is a slice in the form of ":" that is used + // to extract token from the request. + // Optional. + // Possible values: + // - "header:" + // - "query:" + // - "form:" + TokenLookup []string `json:",optional"` } // A RestConf is a http service config. diff --git a/rest/engine.go b/rest/engine.go index adcc9c67b597..d69164d89228 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -72,8 +72,8 @@ func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain, if len(fr.jwt.prevSecret) > 0 { authOpts = append(authOpts, handler.WithPrevSecret(fr.jwt.prevSecret)) } - if len(fr.jwt.tokenKeys) > 0 { - authOpts = append(authOpts, handler.WithTokenKeys(fr.jwt.tokenKeys)) + if len(fr.jwt.tokenLookups) > 0 { + authOpts = append(authOpts, handler.WithTokenLookups(fr.jwt.tokenLookups)) } chn = chn.Append(handler.Authorize(fr.jwt.secret, authOpts...)) diff --git a/rest/handler/authhandler.go b/rest/handler/authhandler.go index 44f1e985378a..cd3a485bfc59 100644 --- a/rest/handler/authhandler.go +++ b/rest/handler/authhandler.go @@ -31,9 +31,9 @@ var ( type ( // An AuthorizeOptions is authorize options. AuthorizeOptions struct { - PrevSecret string - Callback UnauthorizedCallback - TokenKeys []string + PrevSecret string + Callback UnauthorizedCallback + TokenLookups []string } // UnauthorizedCallback defines the method of unauthorized callback. @@ -50,8 +50,8 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H } var parseOpts []token.ParseOption - if len(authOpts.TokenKeys) > 0 { - parseOpts = append(parseOpts, token.WithExtractor(authOpts.TokenKeys)) + if len(authOpts.TokenLookups) > 0 { + parseOpts = append(parseOpts, token.WithExtractor(authOpts.TokenLookups)) } parser := token.NewTokenParser(parseOpts...) @@ -103,10 +103,10 @@ func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption { } } -// WithTokenKeys custom token key -func WithTokenKeys(tokenKeys []string) AuthorizeOption { +// WithTokenLookups used to set the source of the token +func WithTokenLookups(tokenLookups []string) AuthorizeOption { return func(opts *AuthorizeOptions) { - opts.TokenKeys = tokenKeys + opts.TokenLookups = tokenLookups } } diff --git a/rest/handler/authhandler_test.go b/rest/handler/authhandler_test.go index 8046c3de55f2..ce478e5e347f 100644 --- a/rest/handler/authhandler_test.go +++ b/rest/handler/authhandler_test.go @@ -57,7 +57,7 @@ func TestAuthHandler(t *testing.T) { assert.Equal(t, "content", resp.Body.String()) } -func TestAuthHandler_WithTokenKeys(t *testing.T) { +func TestAuthHandler_WithTokenLookups(t *testing.T) { const key = "B63F477D-BBA3-4E52-96D3-C0034C27694A" req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) token, err := buildToken(key, map[string]any{ @@ -65,7 +65,7 @@ func TestAuthHandler_WithTokenKeys(t *testing.T) { }, 3600) assert.Nil(t, err) req.Header.Set("X-Token", token) - handler := Authorize(key, WithTokenKeys([]string{"Token", "X-Token"}))( + handler := Authorize(key, WithTokenLookups([]string{"header:X-Token"}))( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Test", "test") _, err := w.Write([]byte("content")) diff --git a/rest/server.go b/rest/server.go index d9f8ce140d39..297c511e6f00 100644 --- a/rest/server.go +++ b/rest/server.go @@ -196,7 +196,7 @@ func WithJwt(jwt JWTConf) RouteOption { validateSecret(jwt.AccessSecret) r.jwt.enabled = true r.jwt.secret = jwt.AccessSecret - r.jwt.tokenKeys = jwt.TokenKeys + r.jwt.tokenLookups = jwt.TokenLookup } } @@ -210,7 +210,7 @@ func WithJwtTransition(jwt JWTTransConf) RouteOption { r.jwt.enabled = true r.jwt.secret = jwt.Secret r.jwt.prevSecret = jwt.PrevSecret - r.jwt.tokenKeys = jwt.TokenKeys + r.jwt.tokenLookups = jwt.TokenLookup } } diff --git a/rest/token/tokenparser.go b/rest/token/tokenparser.go index 2eb2e547e9ac..be3eb924d340 100644 --- a/rest/token/tokenparser.go +++ b/rest/token/tokenparser.go @@ -1,17 +1,26 @@ package token import ( + "fmt" "net/http" + "strings" "sync" "sync/atomic" "time" "github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4/request" + "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/timex" ) -const claimHistoryResetDuration = time.Hour * 24 +const ( + claimHistoryResetDuration = time.Hour * 24 + + jwtLookupHeader = "header" + jwtLookupQuery = "query" + jwtLookupForm = "form" +) type ( // ParseOption defines the method to customize a TokenParser. @@ -122,12 +131,30 @@ func WithResetDuration(duration time.Duration) ParseOption { } } -func WithExtractor(tokenKeys []string) ParseOption { +// WithExtractor used to configure the token extraction method of the TokenParser. +func WithExtractor(tokenLookups []string) ParseOption { return func(parser *TokenParser) { + var headerNames, argumentNames []string + for _, lookup := range tokenLookups { + parts := strings.Split(strings.TrimSpace(lookup), ":") + if len(parts) < 2 { + logx.Must(fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", lookup)) + } + + source := strings.TrimSpace(parts[0]) + name := strings.TrimSpace(parts[1]) + switch source { + case jwtLookupHeader: + headerNames = append(headerNames, name) + case jwtLookupQuery, jwtLookupForm: + argumentNames = append(argumentNames, name) + } + } + parser.extractor = request.MultiExtractor{ - request.HeaderExtractor(tokenKeys), - request.ArgumentExtractor(tokenKeys), - parser.extractor, + request.HeaderExtractor(headerNames), + request.ArgumentExtractor(argumentNames), + request.AuthorizationHeaderExtractor, } } } diff --git a/rest/token/tokenparser_test.go b/rest/token/tokenparser_test.go index 3687a7842498..3c2cac3a8c8f 100644 --- a/rest/token/tokenparser_test.go +++ b/rest/token/tokenparser_test.go @@ -3,6 +3,8 @@ package token import ( "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "time" @@ -55,7 +57,7 @@ func TestTokenParser_CustomHeader(t *testing.T) { assert.Nil(t, err) req.Header.Set("Token", token) - parser := NewTokenParser(WithExtractor([]string{"Token"})) + parser := NewTokenParser(WithExtractor([]string{"header:Token"})) tok, err := parser.ParseToken(req, key, prevKey) assert.Nil(t, err) assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) @@ -78,7 +80,7 @@ func TestTokenParser_URLArgument(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "http://localhost?token="+token, http.NoBody) - parser := NewTokenParser(WithExtractor([]string{"token"})) + parser := NewTokenParser(WithExtractor([]string{"query:token"})) tok, err := parser.ParseToken(req, key, prevKey) assert.Nil(t, err) assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) @@ -91,6 +93,36 @@ func TestTokenParser_URLArgument(t *testing.T) { assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) } +func TestTokenParser_FormArgument(t *testing.T) { + const ( + key = "14F17379-EB8F-411B-8F12-6929002DCA76" + prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A" + ) + token, err := buildToken(key, map[string]any{"key": "value"}, 3600) + assert.Nil(t, err) + + // create form data + form := url.Values{} + form.Add("form_token", token) + + // Using httptest.NewRequest to create a fake POST request + req := httptest.NewRequest(http.MethodPost, "http://localhost", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + parser := NewTokenParser(WithExtractor([]string{"form:form_token"})) + tok, err := parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + parser.resetTime = timex.Now() - time.Hour + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + +} + func TestTokenParser_Expired(t *testing.T) { const ( key = "14F17379-EB8F-411B-8F12-6929002DCA76" diff --git a/rest/types.go b/rest/types.go index 52b41486841c..ed72cf57f16a 100644 --- a/rest/types.go +++ b/rest/types.go @@ -20,10 +20,10 @@ type ( RouteOption func(r *featuredRoutes) jwtSetting struct { - enabled bool - secret string - prevSecret string - tokenKeys []string + enabled bool + secret string + prevSecret string + tokenLookups []string } signatureSetting struct {