Skip to content

Commit

Permalink
feat(auth): support configurable token extraction lookups
Browse files Browse the repository at this point in the history
Extracting JWT from different request sources (headers, query params, form data) is now
configurable via `TokenLookup`.
  • Loading branch information
ch3nnn committed Aug 29, 2024
1 parent a3f44cf commit 5fad085
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 29 deletions.
20 changes: 16 additions & 4 deletions rest/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,28 @@ type (
JWTConf struct {
AccessSecret string
AccessExpire int64
// extract a jwt from custom request header or url arguments
TokenKeys []string `json:",optional"`
// TokenLookup is a slice in the form of "<source>:<name>" that is used
// to extract token from the request.
// Optional.
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "form:<name>"
TokenLookup []string `json:",optional"`
}

// A JWTTransConf is a jwtTrans config.
JWTTransConf struct {
Secret string
PrevSecret string
// extract a jwt from custom request header or url arguments
TokenKeys []string `json:",optional"`
// TokenLookup is a slice in the form of "<source>:<name>" that is used
// to extract token from the request.
// Optional.
// Possible values:
// - "header:<name>"
// - "query:<name>"
// - "form:<name>"
TokenLookup []string `json:",optional"`
}

// A RestConf is a http service config.
Expand Down
4 changes: 2 additions & 2 deletions rest/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
if len(fr.jwt.prevSecret) > 0 {
authOpts = append(authOpts, handler.WithPrevSecret(fr.jwt.prevSecret))
}
if len(fr.jwt.tokenKeys) > 0 {
authOpts = append(authOpts, handler.WithTokenKeys(fr.jwt.tokenKeys))
if len(fr.jwt.tokenLookups) > 0 {
authOpts = append(authOpts, handler.WithTokenLookups(fr.jwt.tokenLookups))
}

chn = chn.Append(handler.Authorize(fr.jwt.secret, authOpts...))
Expand Down
16 changes: 8 additions & 8 deletions rest/handler/authhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ var (
type (
// An AuthorizeOptions is authorize options.
AuthorizeOptions struct {
PrevSecret string
Callback UnauthorizedCallback
TokenKeys []string
PrevSecret string
Callback UnauthorizedCallback
TokenLookups []string
}

// UnauthorizedCallback defines the method of unauthorized callback.
Expand All @@ -50,8 +50,8 @@ func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.H
}

var parseOpts []token.ParseOption
if len(authOpts.TokenKeys) > 0 {
parseOpts = append(parseOpts, token.WithExtractor(authOpts.TokenKeys))
if len(authOpts.TokenLookups) > 0 {
parseOpts = append(parseOpts, token.WithExtractor(authOpts.TokenLookups))
}

parser := token.NewTokenParser(parseOpts...)
Expand Down Expand Up @@ -103,10 +103,10 @@ func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption {
}
}

// WithTokenKeys custom token key
func WithTokenKeys(tokenKeys []string) AuthorizeOption {
// WithTokenLookups used to set the source of the token
func WithTokenLookups(tokenLookups []string) AuthorizeOption {
return func(opts *AuthorizeOptions) {
opts.TokenKeys = tokenKeys
opts.TokenLookups = tokenLookups
}
}

Expand Down
4 changes: 2 additions & 2 deletions rest/handler/authhandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ func TestAuthHandler(t *testing.T) {
assert.Equal(t, "content", resp.Body.String())
}

func TestAuthHandler_WithTokenKeys(t *testing.T) {
func TestAuthHandler_WithTokenLookups(t *testing.T) {
const key = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
token, err := buildToken(key, map[string]any{
"key": "value",
}, 3600)
assert.Nil(t, err)
req.Header.Set("X-Token", token)
handler := Authorize(key, WithTokenKeys([]string{"Token", "X-Token"}))(
handler := Authorize(key, WithTokenLookups([]string{"header:X-Token"}))(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test", "test")
_, err := w.Write([]byte("content"))
Expand Down
4 changes: 2 additions & 2 deletions rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func WithJwt(jwt JWTConf) RouteOption {
validateSecret(jwt.AccessSecret)
r.jwt.enabled = true
r.jwt.secret = jwt.AccessSecret
r.jwt.tokenKeys = jwt.TokenKeys
r.jwt.tokenLookups = jwt.TokenLookup
}
}

Expand All @@ -210,7 +210,7 @@ func WithJwtTransition(jwt JWTTransConf) RouteOption {
r.jwt.enabled = true
r.jwt.secret = jwt.Secret
r.jwt.prevSecret = jwt.PrevSecret
r.jwt.tokenKeys = jwt.TokenKeys
r.jwt.tokenLookups = jwt.TokenLookup

}
}
Expand Down
37 changes: 32 additions & 5 deletions rest/token/tokenparser.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
package token

import (
"fmt"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/golang-jwt/jwt/v4/request"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/timex"
)

const claimHistoryResetDuration = time.Hour * 24
const (
claimHistoryResetDuration = time.Hour * 24

jwtLookupHeader = "header"
jwtLookupQuery = "query"
jwtLookupForm = "form"
)

type (
// ParseOption defines the method to customize a TokenParser.
Expand Down Expand Up @@ -122,12 +131,30 @@ func WithResetDuration(duration time.Duration) ParseOption {
}
}

func WithExtractor(tokenKeys []string) ParseOption {
// WithExtractor used to configure the token extraction method of the TokenParser.
func WithExtractor(tokenLookups []string) ParseOption {
return func(parser *TokenParser) {
var headerNames, argumentNames []string
for _, lookup := range tokenLookups {
parts := strings.Split(strings.TrimSpace(lookup), ":")
if len(parts) < 2 {
logx.Must(fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", lookup))
}

source := strings.TrimSpace(parts[0])
name := strings.TrimSpace(parts[1])
switch source {
case jwtLookupHeader:
headerNames = append(headerNames, name)
case jwtLookupQuery, jwtLookupForm:
argumentNames = append(argumentNames, name)
}
}

parser.extractor = request.MultiExtractor{
request.HeaderExtractor(tokenKeys),
request.ArgumentExtractor(tokenKeys),
parser.extractor,
request.HeaderExtractor(headerNames),
request.ArgumentExtractor(argumentNames),
request.AuthorizationHeaderExtractor,
}
}
}
Expand Down
36 changes: 34 additions & 2 deletions rest/token/tokenparser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package token
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -55,7 +57,7 @@ func TestTokenParser_CustomHeader(t *testing.T) {
assert.Nil(t, err)
req.Header.Set("Token", token)

parser := NewTokenParser(WithExtractor([]string{"Token"}))
parser := NewTokenParser(WithExtractor([]string{"header:Token"}))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
Expand All @@ -78,7 +80,7 @@ func TestTokenParser_URLArgument(t *testing.T) {

req := httptest.NewRequest(http.MethodGet, "http://localhost?token="+token, http.NoBody)

parser := NewTokenParser(WithExtractor([]string{"token"}))
parser := NewTokenParser(WithExtractor([]string{"query:token"}))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
Expand All @@ -91,6 +93,36 @@ func TestTokenParser_URLArgument(t *testing.T) {
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
}

func TestTokenParser_FormArgument(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
prevKey = "B63F477D-BBA3-4E52-96D3-C0034C27694A"
)
token, err := buildToken(key, map[string]any{"key": "value"}, 3600)
assert.Nil(t, err)

// create form data
form := url.Values{}
form.Add("form_token", token)

// Using httptest.NewRequest to create a fake POST request
req := httptest.NewRequest(http.MethodPost, "http://localhost", strings.NewReader(form.Encode()))
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

parser := NewTokenParser(WithExtractor([]string{"form:form_token"}))
tok, err := parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])
parser.resetTime = timex.Now() - time.Hour
tok, err = parser.ParseToken(req, key, prevKey)
assert.Nil(t, err)
assert.Equal(t, "value", tok.Claims.(jwt.MapClaims)["key"])

}

func TestTokenParser_Expired(t *testing.T) {
const (
key = "14F17379-EB8F-411B-8F12-6929002DCA76"
Expand Down
8 changes: 4 additions & 4 deletions rest/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ type (
RouteOption func(r *featuredRoutes)

jwtSetting struct {
enabled bool
secret string
prevSecret string
tokenKeys []string
enabled bool
secret string
prevSecret string
tokenLookups []string
}

signatureSetting struct {
Expand Down

0 comments on commit 5fad085

Please sign in to comment.