diff --git a/rest/config.go b/rest/config.go index eb5fdb0ba234..61364155c3c7 100644 --- a/rest/config.go +++ b/rest/config.go @@ -35,6 +35,34 @@ type ( PrivateKeys []PrivateKeyConf } + // JWTConf Key and expiration time configuration required for JWT authentication + JWTConf struct { + AccessSecret string + AccessExpire int64 + // TokenLookup is a slice in the form of ":" that is used + // to extract token from the request. + // Optional. + // Possible values: + // - "header:" + // - "query:" + // - "form:" + TokenLookup []string `json:",optional"` + } + + // A JWTTransConf is a jwtTrans config. + JWTTransConf struct { + Secret string + PrevSecret string + // TokenLookup is a slice in the form of ":" that is used + // to extract token from the request. + // Optional. + // Possible values: + // - "header:" + // - "query:" + // - "form:" + TokenLookup []string `json:",optional"` + } + // A RestConf is a http service config. // Why not name it as Conf, because we need to consider usage like: // type Config struct { diff --git a/rest/engine.go b/rest/engine.go index e57786caf205..d69164d89228 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -66,14 +66,17 @@ func (ng *engine) addRoutes(r featuredRoutes) { func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain, verifier func(chain.Chain) chain.Chain) chain.Chain { if fr.jwt.enabled { - if len(fr.jwt.prevSecret) == 0 { - chn = chn.Append(handler.Authorize(fr.jwt.secret, - handler.WithUnauthorizedCallback(ng.unauthorizedCallback))) - } else { - chn = chn.Append(handler.Authorize(fr.jwt.secret, - handler.WithPrevSecret(fr.jwt.prevSecret), - handler.WithUnauthorizedCallback(ng.unauthorizedCallback))) + authOpts := []handler.AuthorizeOption{ + handler.WithUnauthorizedCallback(ng.unauthorizedCallback), + } + if len(fr.jwt.prevSecret) > 0 { + authOpts = append(authOpts, handler.WithPrevSecret(fr.jwt.prevSecret)) } + if len(fr.jwt.tokenLookups) > 0 { + authOpts = append(authOpts, handler.WithTokenLookups(fr.jwt.tokenLookups)) + } + + chn = chn.Append(handler.Authorize(fr.jwt.secret, authOpts...)) } return verifier(chn) diff --git a/rest/engine_test.go b/rest/engine_test.go index 4f86d2173efd..f2feaf3845e7 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -114,7 +114,8 @@ Verbose: true { priority: true, jwt: jwtSetting{ - enabled: true, + enabled: true, + tokenLookups: []string{"header:Token", "query:Token", "form:Token"}, }, signature: signatureSetting{}, routes: []Route{{ diff --git a/rest/handler/authhandler.go b/rest/handler/authhandler.go index ab781bbfa2a9..cd3a485bfc59 100644 --- a/rest/handler/authhandler.go +++ b/rest/handler/authhandler.go @@ -31,8 +31,9 @@ var ( type ( // An AuthorizeOptions is authorize options. AuthorizeOptions struct { - PrevSecret string - Callback UnauthorizedCallback + PrevSecret string + Callback UnauthorizedCallback + TokenLookups []string } // UnauthorizedCallback defines the method of unauthorized callback. @@ -48,7 +49,12 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H opt(&authOpts) } - parser := token.NewTokenParser() + var parseOpts []token.ParseOption + if len(authOpts.TokenLookups) > 0 { + parseOpts = append(parseOpts, token.WithExtractor(authOpts.TokenLookups)) + } + + parser := token.NewTokenParser(parseOpts...) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tok, err := parser.ParseToken(r, secret, authOpts.PrevSecret) @@ -97,6 +103,13 @@ func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption { } } +// WithTokenLookups used to set the source of the token +func WithTokenLookups(tokenLookups []string) AuthorizeOption { + return func(opts *AuthorizeOptions) { + opts.TokenLookups = tokenLookups + } +} + func detailAuthLog(r *http.Request, reason string) { // discard dump error, only for debug purpose details, _ := httputil.DumpRequest(r, true) diff --git a/rest/handler/authhandler_test.go b/rest/handler/authhandler_test.go index 27347a5c6eab..ce478e5e347f 100644 --- a/rest/handler/authhandler_test.go +++ b/rest/handler/authhandler_test.go @@ -57,6 +57,32 @@ func TestAuthHandler(t *testing.T) { assert.Equal(t, "content", resp.Body.String()) } +func TestAuthHandler_WithTokenLookups(t *testing.T) { + const key = "B63F477D-BBA3-4E52-96D3-C0034C27694A" + req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) + token, err := buildToken(key, map[string]any{ + "key": "value", + }, 3600) + assert.Nil(t, err) + req.Header.Set("X-Token", token) + handler := Authorize(key, WithTokenLookups([]string{"header:X-Token"}))( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "test") + _, err := w.Write([]byte("content")) + assert.Nil(t, err) + + flusher, ok := w.(http.Flusher) + assert.True(t, ok) + flusher.Flush() + })) + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, "test", resp.Header().Get("X-Test")) + assert.Equal(t, "content", resp.Body.String()) +} + func TestAuthHandlerWithPrevSecret(t *testing.T) { const ( key = "14F17379-EB8F-411B-8F12-6929002DCA76" diff --git a/rest/server.go b/rest/server.go index b1e5487bd8a5..297c511e6f00 100644 --- a/rest/server.go +++ b/rest/server.go @@ -191,24 +191,27 @@ func WithFileServer(path string, fs http.FileSystem) RunOption { } // WithJwt returns a func to enable jwt authentication in given route. -func WithJwt(secret string) RouteOption { +func WithJwt(jwt JWTConf) RouteOption { return func(r *featuredRoutes) { - validateSecret(secret) + validateSecret(jwt.AccessSecret) r.jwt.enabled = true - r.jwt.secret = secret + r.jwt.secret = jwt.AccessSecret + r.jwt.tokenLookups = jwt.TokenLookup } } // WithJwtTransition returns a func to enable jwt authentication as well as jwt secret transition. // Which means old and new jwt secrets work together for a period. -func WithJwtTransition(secret, prevSecret string) RouteOption { +func WithJwtTransition(jwt JWTTransConf) RouteOption { return func(r *featuredRoutes) { // why not validate prevSecret, because prevSecret is an already used one, // even it not meet our requirement, we still need to allow the transition. - validateSecret(secret) + validateSecret(jwt.Secret) r.jwt.enabled = true - r.jwt.secret = secret - r.jwt.prevSecret = prevSecret + r.jwt.secret = jwt.Secret + r.jwt.prevSecret = jwt.PrevSecret + r.jwt.tokenLookups = jwt.TokenLookup + } } diff --git a/rest/server_test.go b/rest/server_test.go index 9a92d58f8203..97faa646eaf1 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -90,8 +90,8 @@ Port: 0 Method: http.MethodGet, Path: "/", Handler: nil, - }, WithJwt("thesecret"), WithSignature(SignatureConf{}), - WithJwtTransition("preivous", "thenewone")) + }, WithJwt(JWTConf{AccessSecret: "thesecret"}), WithSignature(SignatureConf{}), + WithJwtTransition(JWTTransConf{Secret: "preivous", PrevSecret: "thenewone"})) func() { defer func() { diff --git a/rest/token/tokenparser.go b/rest/token/tokenparser.go index 775ef49af531..be3eb924d340 100644 --- a/rest/token/tokenparser.go +++ b/rest/token/tokenparser.go @@ -1,17 +1,26 @@ package token import ( + "fmt" "net/http" + "strings" "sync" "sync/atomic" "time" "github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4/request" + "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/timex" ) -const claimHistoryResetDuration = time.Hour * 24 +const ( + claimHistoryResetDuration = time.Hour * 24 + + jwtLookupHeader = "header" + jwtLookupQuery = "query" + jwtLookupForm = "form" +) type ( // ParseOption defines the method to customize a TokenParser. @@ -22,6 +31,7 @@ type ( resetTime time.Duration resetDuration time.Duration history sync.Map + extractor request.MultiExtractor } ) @@ -30,6 +40,7 @@ func NewTokenParser(opts ...ParseOption) *TokenParser { parser := &TokenParser{ resetTime: timex.Now(), resetDuration: claimHistoryResetDuration, + extractor: request.MultiExtractor{request.AuthorizationHeaderExtractor}, } for _, opt := range opts { @@ -79,10 +90,11 @@ func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (* } func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) { - return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor, - func(token *jwt.Token) (any, error) { - return []byte(secret), nil - }, request.WithParser(newParser())) + keyFunc := func(token *jwt.Token) (any, error) { + return []byte(secret), nil + } + + return request.ParseFromRequest(r, tp.extractor, keyFunc, request.WithParser(newParser())) } func (tp *TokenParser) incrementCount(secret string) { @@ -119,6 +131,34 @@ func WithResetDuration(duration time.Duration) ParseOption { } } +// WithExtractor used to configure the token extraction method of the TokenParser. +func WithExtractor(tokenLookups []string) ParseOption { + return func(parser *TokenParser) { + var headerNames, argumentNames []string + for _, lookup := range tokenLookups { + parts := strings.Split(strings.TrimSpace(lookup), ":") + if len(parts) < 2 { + logx.Must(fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", lookup)) + } + + source := strings.TrimSpace(parts[0]) + name := strings.TrimSpace(parts[1]) + switch source { + case jwtLookupHeader: + headerNames = append(headerNames, name) + case jwtLookupQuery, jwtLookupForm: + argumentNames = append(argumentNames, name) + } + } + + parser.extractor = request.MultiExtractor{ + request.HeaderExtractor(headerNames), + request.ArgumentExtractor(argumentNames), + request.AuthorizationHeaderExtractor, + } + } +} + func newParser() *jwt.Parser { return jwt.NewParser(jwt.WithJSONNumber()) } diff --git a/rest/token/tokenparser_test.go b/rest/token/tokenparser_test.go index 147d64380580..3c2cac3a8c8f 100644 --- a/rest/token/tokenparser_test.go +++ b/rest/token/tokenparser_test.go @@ -3,6 +3,8 @@ package token import ( "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "time" @@ -45,6 +47,82 @@ func TestTokenParser(t *testing.T) { } } +func TestTokenParser_CustomHeader(t *testing.T) { + const ( + key = "14F17379-EB8F-411B-8F12-6929002DCA76" + prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A" + ) + req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) + token, err := buildToken(key, map[string]any{"key": "value"}, 3600) + assert.Nil(t, err) + req.Header.Set("Token", token) + + parser := NewTokenParser(WithExtractor([]string{"header:Token"})) + tok, err := parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + parser.resetTime = timex.Now() - time.Hour + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) +} + +func TestTokenParser_URLArgument(t *testing.T) { + const ( + key = "14F17379-EB8F-411B-8F12-6929002DCA76" + prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A" + ) + token, err := buildToken(key, map[string]any{"key": "value"}, 3600) + assert.Nil(t, err) + + req := httptest.NewRequest(http.MethodGet, "http://localhost?token="+token, http.NoBody) + + parser := NewTokenParser(WithExtractor([]string{"query:token"})) + tok, err := parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + parser.resetTime = timex.Now() - time.Hour + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) +} + +func TestTokenParser_FormArgument(t *testing.T) { + const ( + key = "14F17379-EB8F-411B-8F12-6929002DCA76" + prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A" + ) + token, err := buildToken(key, map[string]any{"key": "value"}, 3600) + assert.Nil(t, err) + + // create form data + form := url.Values{} + form.Add("form_token", token) + + // Using httptest.NewRequest to create a fake POST request + req := httptest.NewRequest(http.MethodPost, "http://localhost", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + parser := NewTokenParser(WithExtractor([]string{"form:form_token"})) + tok, err := parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + parser.resetTime = timex.Now() - time.Hour + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + +} + func TestTokenParser_Expired(t *testing.T) { const ( key = "14F17379-EB8F-411B-8F12-6929002DCA76" diff --git a/rest/types.go b/rest/types.go index f7be7996432b..ed72cf57f16a 100644 --- a/rest/types.go +++ b/rest/types.go @@ -20,9 +20,10 @@ type ( RouteOption func(r *featuredRoutes) jwtSetting struct { - enabled bool - secret string - prevSecret string + enabled bool + secret string + prevSecret string + tokenLookups []string } signatureSetting struct { diff --git a/tools/goctl/api/gogen/genconfig.go b/tools/goctl/api/gogen/genconfig.go index 0f3920d7f567..b7f10432e449 100644 --- a/tools/goctl/api/gogen/genconfig.go +++ b/tools/goctl/api/gogen/genconfig.go @@ -14,16 +14,8 @@ import ( const ( configFile = "config" - jwtTemplate = ` struct { - AccessSecret string - AccessExpire int64 - } -` - jwtTransTemplate = ` struct { - Secret string - PrevSecret string - } -` + jwtTemplate = ` rest.JWTConf` + jwtTransTemplate = ` rest.JWTTransConf` ) //go:embed config.tpl diff --git a/tools/goctl/api/gogen/genroutes.go b/tools/goctl/api/gogen/genroutes.go index 9770a57e1341..41af7bcb8a8e 100644 --- a/tools/goctl/api/gogen/genroutes.go +++ b/tools/goctl/api/gogen/genroutes.go @@ -118,10 +118,10 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error var jwt string if g.jwtEnabled { - jwt = fmt.Sprintf("\n rest.WithJwt(serverCtx.Config.%s.AccessSecret),", g.authName) + jwt = fmt.Sprintf("\n rest.WithJwt(serverCtx.Config.%s),", g.authName) } if len(g.jwtTrans) > 0 { - jwt = jwt + fmt.Sprintf("\n rest.WithJwtTransition(serverCtx.Config.%s.PrevSecret,serverCtx.Config.%s.Secret),", g.jwtTrans, g.jwtTrans) + jwt = jwt + fmt.Sprintf("\n rest.WithJwtTransition(serverCtx.Config.%s),", g.jwtTrans) } var signature, prefix string if g.signatureEnabled {