From 96712d5865895e6d16b238b923a4b24d51e5a151 Mon Sep 17 00:00:00 2001 From: ch3nnn Date: Fri, 23 Aug 2024 18:28:54 +0800 Subject: [PATCH] feat: support multiple extract token key Implement support for multiple custom token keys and simplify the JWT authentication configuration. `WithTokenKeys` function enables setting token keys, improving the authentication process by accommodating various token header extraction strategies. by accommodating various token header extraction strategies. --- rest/config.go | 16 +++++++++++ rest/engine.go | 17 ++++++----- rest/handler/authhandler.go | 15 +++++++++- rest/server.go | 17 ++++++----- rest/server_test.go | 4 +-- rest/token/tokenparser.go | 21 +++++++++++--- rest/token/tokenparser_test.go | 46 ++++++++++++++++++++++++++++++ rest/types.go | 1 + tools/goctl/api/gogen/genconfig.go | 12 ++------ tools/goctl/api/gogen/genroutes.go | 4 +-- 10 files changed, 120 insertions(+), 33 deletions(-) diff --git a/rest/config.go b/rest/config.go index eb5fdb0ba234..2d86086b6596 100644 --- a/rest/config.go +++ b/rest/config.go @@ -35,6 +35,22 @@ type ( PrivateKeys []PrivateKeyConf } + // JWTConf Key and expiration time configuration required for JWT authentication + JWTConf struct { + AccessSecret string + AccessExpire int64 + // extract a jwt from custom request header or url arguments + TokenKeys []string `json:",optional"` + } + + // A JWTTransConf is a jwtTrans config. + JWTTransConf struct { + Secret string + PrevSecret string + // extract a jwt from custom request header or url arguments + TokenKeys []string `json:",optional"` + } + // A RestConf is a http service config. // Why not name it as Conf, because we need to consider usage like: // type Config struct { diff --git a/rest/engine.go b/rest/engine.go index e57786caf205..adcc9c67b597 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -66,14 +66,17 @@ func (ng *engine) addRoutes(r featuredRoutes) { func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain, verifier func(chain.Chain) chain.Chain) chain.Chain { if fr.jwt.enabled { - if len(fr.jwt.prevSecret) == 0 { - chn = chn.Append(handler.Authorize(fr.jwt.secret, - handler.WithUnauthorizedCallback(ng.unauthorizedCallback))) - } else { - chn = chn.Append(handler.Authorize(fr.jwt.secret, - handler.WithPrevSecret(fr.jwt.prevSecret), - handler.WithUnauthorizedCallback(ng.unauthorizedCallback))) + authOpts := []handler.AuthorizeOption{ + handler.WithUnauthorizedCallback(ng.unauthorizedCallback), + } + if len(fr.jwt.prevSecret) > 0 { + authOpts = append(authOpts, handler.WithPrevSecret(fr.jwt.prevSecret)) } + if len(fr.jwt.tokenKeys) > 0 { + authOpts = append(authOpts, handler.WithTokenKeys(fr.jwt.tokenKeys)) + } + + chn = chn.Append(handler.Authorize(fr.jwt.secret, authOpts...)) } return verifier(chn) diff --git a/rest/handler/authhandler.go b/rest/handler/authhandler.go index f72e2113a96f..6337d33af4c3 100644 --- a/rest/handler/authhandler.go +++ b/rest/handler/authhandler.go @@ -33,6 +33,7 @@ type ( AuthorizeOptions struct { PrevSecret string Callback UnauthorizedCallback + TokenKeys []string } // UnauthorizedCallback defines the method of unauthorized callback. @@ -48,7 +49,12 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H opt(&authOpts) } - parser := token.NewTokenParser() + var parseOpts []token.ParseOption + if len(authOpts.TokenKeys) > 0 { + parseOpts = append(parseOpts, token.WithExtractor(authOpts.TokenKeys)) + } + + parser := token.NewTokenParser(parseOpts...) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tok, err := parser.ParseToken(r, secret, authOpts.PrevSecret) @@ -97,6 +103,13 @@ func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption { } } +// WithTokenKeys custom token key +func WithTokenKeys(tokenKeys []string) AuthorizeOption { + return func(opts *AuthorizeOptions) { + opts.TokenKeys = tokenKeys + } +} + func detailAuthLog(r *http.Request, reason string) { // discard dump error, only for debug purpose details, _ := httputil.DumpRequest(r, true) diff --git a/rest/server.go b/rest/server.go index b1e5487bd8a5..d9f8ce140d39 100644 --- a/rest/server.go +++ b/rest/server.go @@ -191,24 +191,27 @@ func WithFileServer(path string, fs http.FileSystem) RunOption { } // WithJwt returns a func to enable jwt authentication in given route. -func WithJwt(secret string) RouteOption { +func WithJwt(jwt JWTConf) RouteOption { return func(r *featuredRoutes) { - validateSecret(secret) + validateSecret(jwt.AccessSecret) r.jwt.enabled = true - r.jwt.secret = secret + r.jwt.secret = jwt.AccessSecret + r.jwt.tokenKeys = jwt.TokenKeys } } // WithJwtTransition returns a func to enable jwt authentication as well as jwt secret transition. // Which means old and new jwt secrets work together for a period. -func WithJwtTransition(secret, prevSecret string) RouteOption { +func WithJwtTransition(jwt JWTTransConf) RouteOption { return func(r *featuredRoutes) { // why not validate prevSecret, because prevSecret is an already used one, // even it not meet our requirement, we still need to allow the transition. - validateSecret(secret) + validateSecret(jwt.Secret) r.jwt.enabled = true - r.jwt.secret = secret - r.jwt.prevSecret = prevSecret + r.jwt.secret = jwt.Secret + r.jwt.prevSecret = jwt.PrevSecret + r.jwt.tokenKeys = jwt.TokenKeys + } } diff --git a/rest/server_test.go b/rest/server_test.go index 9a92d58f8203..97faa646eaf1 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -90,8 +90,8 @@ Port: 0 Method: http.MethodGet, Path: "/", Handler: nil, - }, WithJwt("thesecret"), WithSignature(SignatureConf{}), - WithJwtTransition("preivous", "thenewone")) + }, WithJwt(JWTConf{AccessSecret: "thesecret"}), WithSignature(SignatureConf{}), + WithJwtTransition(JWTTransConf{Secret: "preivous", PrevSecret: "thenewone"})) func() { defer func() { diff --git a/rest/token/tokenparser.go b/rest/token/tokenparser.go index 775ef49af531..2eb2e547e9ac 100644 --- a/rest/token/tokenparser.go +++ b/rest/token/tokenparser.go @@ -22,6 +22,7 @@ type ( resetTime time.Duration resetDuration time.Duration history sync.Map + extractor request.MultiExtractor } ) @@ -30,6 +31,7 @@ func NewTokenParser(opts ...ParseOption) *TokenParser { parser := &TokenParser{ resetTime: timex.Now(), resetDuration: claimHistoryResetDuration, + extractor: request.MultiExtractor{request.AuthorizationHeaderExtractor}, } for _, opt := range opts { @@ -79,10 +81,11 @@ func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (* } func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) { - return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor, - func(token *jwt.Token) (any, error) { - return []byte(secret), nil - }, request.WithParser(newParser())) + keyFunc := func(token *jwt.Token) (any, error) { + return []byte(secret), nil + } + + return request.ParseFromRequest(r, tp.extractor, keyFunc, request.WithParser(newParser())) } func (tp *TokenParser) incrementCount(secret string) { @@ -119,6 +122,16 @@ func WithResetDuration(duration time.Duration) ParseOption { } } +func WithExtractor(tokenKeys []string) ParseOption { + return func(parser *TokenParser) { + parser.extractor = request.MultiExtractor{ + request.HeaderExtractor(tokenKeys), + request.ArgumentExtractor(tokenKeys), + parser.extractor, + } + } +} + func newParser() *jwt.Parser { return jwt.NewParser(jwt.WithJSONNumber()) } diff --git a/rest/token/tokenparser_test.go b/rest/token/tokenparser_test.go index 147d64380580..3687a7842498 100644 --- a/rest/token/tokenparser_test.go +++ b/rest/token/tokenparser_test.go @@ -45,6 +45,52 @@ func TestTokenParser(t *testing.T) { } } +func TestTokenParser_CustomHeader(t *testing.T) { + const ( + key = "14F17379-EB8F-411B-8F12-6929002DCA76" + prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A" + ) + req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody) + token, err := buildToken(key, map[string]any{"key": "value"}, 3600) + assert.Nil(t, err) + req.Header.Set("Token", token) + + parser := NewTokenParser(WithExtractor([]string{"Token"})) + tok, err := parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + parser.resetTime = timex.Now() - time.Hour + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) +} + +func TestTokenParser_URLArgument(t *testing.T) { + const ( + key = "14F17379-EB8F-411B-8F12-6929002DCA76" + prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A" + ) + token, err := buildToken(key, map[string]any{"key": "value"}, 3600) + assert.Nil(t, err) + + req := httptest.NewRequest(http.MethodGet, "http://localhost?token="+token, http.NoBody) + + parser := NewTokenParser(WithExtractor([]string{"token"})) + tok, err := parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) + parser.resetTime = timex.Now() - time.Hour + tok, err = parser.ParseToken(req, key, prevKey) + assert.Nil(t, err) + assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"]) +} + func TestTokenParser_Expired(t *testing.T) { const ( key = "14F17379-EB8F-411B-8F12-6929002DCA76" diff --git a/rest/types.go b/rest/types.go index f7be7996432b..52b41486841c 100644 --- a/rest/types.go +++ b/rest/types.go @@ -23,6 +23,7 @@ type ( enabled bool secret string prevSecret string + tokenKeys []string } signatureSetting struct { diff --git a/tools/goctl/api/gogen/genconfig.go b/tools/goctl/api/gogen/genconfig.go index 0f3920d7f567..b7f10432e449 100644 --- a/tools/goctl/api/gogen/genconfig.go +++ b/tools/goctl/api/gogen/genconfig.go @@ -14,16 +14,8 @@ import ( const ( configFile = "config" - jwtTemplate = ` struct { - AccessSecret string - AccessExpire int64 - } -` - jwtTransTemplate = ` struct { - Secret string - PrevSecret string - } -` + jwtTemplate = ` rest.JWTConf` + jwtTransTemplate = ` rest.JWTTransConf` ) //go:embed config.tpl diff --git a/tools/goctl/api/gogen/genroutes.go b/tools/goctl/api/gogen/genroutes.go index 6d9295f7fb07..90e74e3a1fb9 100644 --- a/tools/goctl/api/gogen/genroutes.go +++ b/tools/goctl/api/gogen/genroutes.go @@ -118,10 +118,10 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error var jwt string if g.jwtEnabled { - jwt = fmt.Sprintf("\n rest.WithJwt(serverCtx.Config.%s.AccessSecret),", g.authName) + jwt = fmt.Sprintf("\n rest.WithJwt(serverCtx.Config.%s),", g.authName) } if len(g.jwtTrans) > 0 { - jwt = jwt + fmt.Sprintf("\n rest.WithJwtTransition(serverCtx.Config.%s.PrevSecret,serverCtx.Config.%s.Secret),", g.jwtTrans, g.jwtTrans) + jwt = jwt + fmt.Sprintf("\n rest.WithJwtTransition(serverCtx.Config.%s),", g.jwtTrans) } var signature, prefix string if g.signatureEnabled {