Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(auth): support configurable token extraction lookups
Browse files Browse the repository at this point in the history
Extracting JWT from different request sources (headers, query params, form data) is now
configurable via `TokenLookup`.
ch3nnn authored and kevwan committed Jan 30, 2025
1 parent 80609c2 commit ae51115
Showing 8 changed files with 100 additions and 29 deletions.
20 changes: 16 additions & 4 deletions rest/config.go
Original file line number Diff line number Diff line change
@@ -39,16 +39,28 @@ type (
JWTConf struct {
AccessSecret string
AccessExpire int64
// extract a jwt from custom request header or url arguments
TokenKeys []string `json:",optional"`
// TokenLookup is a slice in the form of "<source>:<name>" that is used
// to extract token from the request.
// Optional.
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "form:<name>"
TokenLookup []string `json:",optional"`
}

// A JWTTransConf is a jwtTrans config.
JWTTransConf struct {
Secret string
PrevSecret string
// extract a jwt from custom request header or url arguments
TokenKeys []string `json:",optional"`
// TokenLookup is a slice in the form of "<source>:<name>" that is used
// to extract token from the request.
// Optional.
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "form:<name>"
TokenLookup []string `json:",optional"`
}

// A RestConf is a http service config.
4 changes: 2 additions & 2 deletions rest/engine.go
Original file line number Diff line number Diff line change
@@ -72,8 +72,8 @@ func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
if len(fr.jwt.prevSecret) > 0 {
authOpts = append(authOpts, handler.WithPrevSecret(fr.jwt.prevSecret))
}
if len(fr.jwt.tokenKeys) > 0 {
authOpts = append(authOpts, handler.WithTokenKeys(fr.jwt.tokenKeys))
if len(fr.jwt.tokenLookups) > 0 {
authOpts = append(authOpts, handler.WithTokenLookups(fr.jwt.tokenLookups))
}

chn = chn.Append(handler.Authorize(fr.jwt.secret, authOpts...))
16 changes: 8 additions & 8 deletions rest/handler/authhandler.go
Original file line number Diff line number Diff line change
@@ -31,9 +31,9 @@ var (
type (
// An AuthorizeOptions is authorize options.
AuthorizeOptions struct {
PrevSecret string
Callback UnauthorizedCallback
TokenKeys []string
PrevSecret string
Callback UnauthorizedCallback
TokenLookups []string
}

// UnauthorizedCallback defines the method of unauthorized callback.
@@ -50,8 +50,8 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H
}

var parseOpts []token.ParseOption
if len(authOpts.TokenKeys) > 0 {
parseOpts = append(parseOpts, token.WithExtractor(authOpts.TokenKeys))
if len(authOpts.TokenLookups) > 0 {
parseOpts = append(parseOpts, token.WithExtractor(authOpts.TokenLookups))
}

parser := token.NewTokenParser(parseOpts...)
@@ -103,10 +103,10 @@ func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption {
}
}

// WithTokenKeys custom token key
func WithTokenKeys(tokenKeys []string) AuthorizeOption {
// WithTokenLookups used to set the source of the token
func WithTokenLookups(tokenLookups []string) AuthorizeOption {
return func(opts *AuthorizeOptions) {
opts.TokenKeys = tokenKeys
opts.TokenLookups = tokenLookups
}
}

4 changes: 2 additions & 2 deletions rest/handler/authhandler_test.go
Original file line number Diff line number Diff line change
@@ -57,15 +57,15 @@ func TestAuthHandler(t *testing.T) {
assert.Equal(t, "content", resp.Body.String())
}

func TestAuthHandler_WithTokenKeys(t *testing.T) {
func TestAuthHandler_WithTokenLookups(t *testing.T) {
const key = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
token, err := buildToken(key, map[string]any{
"key": "value",
}, 3600)
assert.Nil(t, err)
req.Header.Set("X-Token", token)
handler := Authorize(key, WithTokenKeys([]string{"Token", "X-Token"}))(
handler := Authorize(key, WithTokenLookups([]string{"header:X-Token"}))(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content"))
4 changes: 2 additions & 2 deletions rest/server.go
Original file line number Diff line number Diff line change
@@ -196,7 +196,7 @@ func WithJwt(jwt JWTConf) RouteOption {
validateSecret(jwt.AccessSecret)
r.jwt.enabled = true
r.jwt.secret = jwt.AccessSecret
r.jwt.tokenKeys = jwt.TokenKeys
r.jwt.tokenLookups = jwt.TokenLookup
}
}

@@ -210,7 +210,7 @@ func WithJwtTransition(jwt JWTTransConf) RouteOption {
r.jwt.enabled = true
r.jwt.secret = jwt.Secret
r.jwt.prevSecret = jwt.PrevSecret
r.jwt.tokenKeys = jwt.TokenKeys
r.jwt.tokenLookups = jwt.TokenLookup

}
}
37 changes: 32 additions & 5 deletions rest/token/tokenparser.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
package token

import (
"fmt"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/golang-jwt/jwt/v4/request"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/timex"
)

const claimHistoryResetDuration = time.Hour * 24
const (
claimHistoryResetDuration = time.Hour * 24

jwtLookupHeader = "header"
jwtLookupQuery = "query"
jwtLookupForm = "form"
)

type (
// ParseOption defines the method to customize a TokenParser.
@@ -122,12 +131,30 @@ func WithResetDuration(duration time.Duration) ParseOption {
}
}

func WithExtractor(tokenKeys []string) ParseOption {
// WithExtractor used to configure the token extraction method of the TokenParser.
func WithExtractor(tokenLookups []string) ParseOption {
return func(parser *TokenParser) {
var headerNames, argumentNames []string
for _, lookup := range tokenLookups {
parts := strings.Split(strings.TrimSpace(lookup), ":")
if len(parts) < 2 {
logx.Must(fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", lookup))
}

source := strings.TrimSpace(parts[0])
name := strings.TrimSpace(parts[1])
switch source {
case jwtLookupHeader:
headerNames = append(headerNames, name)
case jwtLookupQuery, jwtLookupForm:
argumentNames = append(argumentNames, name)
}
}

parser.extractor = request.MultiExtractor{
request.HeaderExtractor(tokenKeys),
request.ArgumentExtractor(tokenKeys),
parser.extractor,
request.HeaderExtractor(headerNames),
request.ArgumentExtractor(argumentNames),
request.AuthorizationHeaderExtractor,
}
}
}
36 changes: 34 additions & 2 deletions rest/token/tokenparser_test.go
Original file line number Diff line number Diff line change
@@ -3,6 +3,8 @@ package token
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

@@ -55,7 +57,7 @@ func TestTokenParser_CustomHeader(t *testing.T) {
assert.Nil(t, err)
req.Header.Set("Token", token)

parser := NewTokenParser(WithExtractor([]string{"Token"}))
parser := NewTokenParser(WithExtractor([]string{"header:Token"}))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
@@ -78,7 +80,7 @@ func TestTokenParser_URLArgument(t *testing.T) {

req := httptest.NewRequest(http.MethodGet, "http://localhost?token="+token, http.NoBody)

parser := NewTokenParser(WithExtractor([]string{"token"}))
parser := NewTokenParser(WithExtractor([]string{"query:token"}))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
@@ -91,6 +93,36 @@ func TestTokenParser_URLArgument(t *testing.T) {
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
}

func TestTokenParser_FormArgument(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
token, err := buildToken(key, map[string]any{"key": "value"}, 3600)
assert.Nil(t, err)

// create form data
form := url.Values{}
form.Add("form_token", token)

// Using httptest.NewRequest to create a fake POST request
req := httptest.NewRequest(http.MethodPost, "http://localhost", strings.NewReader(form.Encode()))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

parser := NewTokenParser(WithExtractor([]string{"form:form_token"}))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
parser.resetTime = timex.Now() - time.Hour
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])

}

func TestTokenParser_Expired(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
8 changes: 4 additions & 4 deletions rest/types.go
Original file line number Diff line number Diff line change
@@ -20,10 +20,10 @@ type (
RouteOption func(r *featuredRoutes)

jwtSetting struct {
enabled bool
secret string
prevSecret string
tokenKeys []string
enabled bool
secret string
prevSecret string
tokenLookups []string
}

signatureSetting struct {

0 comments on commit ae51115

Please sign in to comment.