From e4db2d7f6cc0bd3be0bc5543a2b68b4d770d9bea Mon Sep 17 00:00:00 2001 From: Michal Krzyz Date: Tue, 5 Nov 2024 10:53:01 +0100 Subject: [PATCH 1/2] feat(authN): Redesign JWT token auth #372 Redesign JWT token authentication middleware to support additional/alternative authentication method --- README.md | 2 + docker-compose.yaml | 2 - internal/api/graphql/access/auth.go | 57 +++++++-- internal/api/graphql/access/no_auth.go | 21 ---- internal/api/graphql/access/test/util.go | 20 +-- internal/api/graphql/access/token_auth.go | 116 ------------------ .../api/graphql/access/token_auth_method.go | 109 ++++++++++++++++ ...auth_test.go => token_auth_method_test.go} | 22 ++-- internal/api/graphql/server.go | 2 +- .../database/mariadb/test/database_manager.go | 1 - internal/e2e/token_auth_test.go | 29 +++-- internal/util/config.go | 6 +- tools/token_generator/main.go | 2 +- 13 files changed, 195 insertions(+), 194 deletions(-) delete mode 100644 internal/api/graphql/access/no_auth.go delete mode 100644 internal/api/graphql/access/token_auth.go create mode 100644 internal/api/graphql/access/token_auth_method.go rename internal/api/graphql/access/{token_auth_test.go => token_auth_method_test.go} (76%) diff --git a/README.md b/README.md index dd538cb1..06df3475 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,8 @@ LOCAL_TEST_DB=true SEED_MODE=false ``` +To enable JWT token authentication, define `AUTH_TOKEN_SECRET` environment variable. Those variable is read by application on startup to start token validation middleware. + ### Docker The `docker-compose.yml` file defines two profiles: `db` for the `heureka-db` service and `heureka` for the `heureka-app` service. diff --git a/docker-compose.yaml b/docker-compose.yaml index b6e388c2..637c2c67 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -45,8 +45,6 @@ services: DB_NAME: ${DB_NAME} DB_SCHEMA: /app_sqlschema/schema.sql SEED_MODE: ${SEED_MODE} - AUTH_TYPE: token - AUTH_TOKEN_SECRET: xxx volumes: - ./internal/database/mariadb/init/schema.sql:/app_sqlschema/schema.sql depends_on: diff --git a/internal/api/graphql/access/auth.go b/internal/api/graphql/access/auth.go index 4df956d1..faaf6a6a 100644 --- a/internal/api/graphql/access/auth.go +++ b/internal/api/graphql/access/auth.go @@ -4,7 +4,9 @@ package access import ( - "strings" + "fmt" + "net/http" + "reflect" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" @@ -17,23 +19,52 @@ type Logger interface { Warn(...interface{}) } -type Auth interface { - GetMiddleware() gin.HandlerFunc +func NewAuth(cfg *util.Config) *Auth { + l := newLogger() + auth := Auth{logger: l} + auth.AppendInstance(NewTokenAuthMethod(l, cfg)) + //TODO: auth.AppendInstance(NewOidcAuthMethod(l, cfg)) + return &auth } -func NewAuth(cfg *util.Config) Auth { - l := newLogger() +type Auth struct { + chain []AuthMethod + logger Logger +} - authType := strings.ToLower(cfg.AuthType) - if authType == "token" { - return NewTokenAuth(l, cfg) - } else if authType == "none" { - return NewNoAuth() - } +type AuthMethod interface { + Verify(*gin.Context) error +} - l.Warn("AUTH_TYPE is not set, assuming 'none' authorization method") +func (a *Auth) GetMiddleware() gin.HandlerFunc { + return func(authCtx *gin.Context) { + if len(a.chain) > 0 { + var retMsg string + for _, auth := range a.chain { + if err := auth.Verify(authCtx); err == nil { + authCtx.Next() + return + } else { + if retMsg != "" { + retMsg = fmt.Sprintf("%s, ", retMsg) + } + retMsg = fmt.Sprintf("%s%s", retMsg, err) + } + } + a.logger.Error("Unauthorized access: %s", retMsg) + authCtx.JSON(http.StatusUnauthorized, gin.H{"error": retMsg}) + authCtx.Abort() + return + } + authCtx.Next() + return + } +} - return NewNoAuth() +func (a *Auth) AppendInstance(am AuthMethod) { + if !reflect.ValueOf(am).IsNil() { + a.chain = append(a.chain, am) + } } func newLogger() Logger { diff --git a/internal/api/graphql/access/no_auth.go b/internal/api/graphql/access/no_auth.go deleted file mode 100644 index e0cc7bea..00000000 --- a/internal/api/graphql/access/no_auth.go +++ /dev/null @@ -1,21 +0,0 @@ -// SPDX-FileCopyrightText: 2024 SAP SE or an SAP affiliate company and Greenhouse contributors -// SPDX-License-Identifier: Apache-2.0 - -package access - -import ( - "github.com/gin-gonic/gin" -) - -type NoAuth struct { -} - -func NewNoAuth() *NoAuth { - return &NoAuth{} -} - -func (no *NoAuth) GetMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - c.Next() - } -} diff --git a/internal/api/graphql/access/test/util.go b/internal/api/graphql/access/test/util.go index 51ad46ae..334e8dc9 100644 --- a/internal/api/graphql/access/test/util.go +++ b/internal/api/graphql/access/test/util.go @@ -19,7 +19,7 @@ import ( ) const ( - testUsername = "testUser" + testClientName = "testClientName" ) func SendGetRequest(url string, headers map[string]string) *http.Response { @@ -53,7 +53,7 @@ type Jwt struct { signingMethod jwt.SigningMethod signKey interface{} expiresAt *jwt.NumericDate - username string + name string } func NewJwt(secret string) *Jwt { @@ -64,8 +64,8 @@ func NewRsaJwt(privKey *rsa.PrivateKey) *Jwt { return &Jwt{signKey: privKey, signingMethod: jwt.SigningMethodRS256} } -func (j *Jwt) WithUsername(username string) *Jwt { - j.username = username +func (j *Jwt) WithName(name string) *Jwt { + j.name = name return j } @@ -81,7 +81,7 @@ func (j *Jwt) String() string { ExpiresAt: j.expiresAt, IssuedAt: jwt.NewNumericDate(time.Now()), Issuer: "heureka", - Subject: j.username, + Subject: j.name, }, } token := jwt.NewWithClaims(j.signingMethod, claims) @@ -92,15 +92,15 @@ func (j *Jwt) String() string { } func GenerateJwt(jwtSecret string, expiresIn time.Duration) string { - return NewJwt(jwtSecret).WithExpiresAt(time.Now().Add(expiresIn)).WithUsername(testUsername).String() + return NewJwt(jwtSecret).WithExpiresAt(time.Now().Add(expiresIn)).WithName(testClientName).String() } -func GenerateJwtWithUsername(jwtSecret string, expiresIn time.Duration, username string) string { - return NewJwt(jwtSecret).WithExpiresAt(time.Now().Add(expiresIn)).WithUsername(username).String() +func GenerateJwtWithName(jwtSecret string, expiresIn time.Duration, name string) string { + return NewJwt(jwtSecret).WithExpiresAt(time.Now().Add(expiresIn)).WithName(name).String() } func GenerateInvalidJwt(jwtSecret string) string { - return NewJwt(jwtSecret).WithUsername(testUsername).String() + return NewJwt(jwtSecret).WithName(testClientName).String() } func GenerateRsaPrivateKey() *rsa.PrivateKey { @@ -110,5 +110,5 @@ func GenerateRsaPrivateKey() *rsa.PrivateKey { } func GenerateJwtWithInvalidSigningMethod(jwtSecret string, expiresIn time.Duration) string { - return NewRsaJwt(GenerateRsaPrivateKey()).WithExpiresAt(time.Now().Add(expiresIn)).WithUsername(testUsername).String() + return NewRsaJwt(GenerateRsaPrivateKey()).WithExpiresAt(time.Now().Add(expiresIn)).WithName(testClientName).String() } diff --git a/internal/api/graphql/access/token_auth.go b/internal/api/graphql/access/token_auth.go deleted file mode 100644 index d661c598..00000000 --- a/internal/api/graphql/access/token_auth.go +++ /dev/null @@ -1,116 +0,0 @@ -// SPDX-FileCopyrightText: 2024 SAP SE or an SAP affiliate company and Greenhouse contributors -// SPDX-License-Identifier: Apache-2.0 - -package access - -import ( - "context" - "fmt" - "net/http" - "time" - - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" - - "github.com/cloudoperators/heureka/internal/util" -) - -const ( - ginContextKey ginContextKeyType = "GinContextKey" - usernameKey string = "username" -) - -type ginContextKeyType string - -type TokenAuth struct { - logger Logger - secret []byte -} - -func NewTokenAuth(l Logger, cfg *util.Config) *TokenAuth { - return &TokenAuth{logger: l, secret: []byte(cfg.AuthTokenSecret)} -} - -type TokenClaims struct { - Version string `json:"version"` - jwt.RegisteredClaims -} - -func (ta *TokenAuth) GetMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - tokenString := c.GetHeader("Authorization") - - if tokenString == "" { - ta.logger.Error("Trying to use API without authorization header") - c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is required"}) - c.Abort() - return - } - - token, claims, err := ta.parseFromString(tokenString) - if err != nil { - ta.logger.Error("JWT parsing error: ", err.Error()) - c.JSON(http.StatusUnauthorized, gin.H{"error": "Token parsing error"}) - c.Abort() - return - } else if !token.Valid || claims.ExpiresAt == nil { - ta.logger.Error("Invalid token") - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) - c.Abort() - return - } else if claims.ExpiresAt.Before(time.Now()) { - ta.logger.Warn("Expired token") - c.JSON(http.StatusUnauthorized, gin.H{"error": "Token expired"}) - c.Abort() - return - } - - c.Set(usernameKey, claims.RegisteredClaims.Subject) - ctx := context.WithValue(c.Request.Context(), ginContextKey, c) - c.Request = c.Request.WithContext(ctx) - c.Next() - } -} - -func (ta *TokenAuth) parseFromString(tokenString string) (*jwt.Token, *TokenClaims, error) { - claims := &TokenClaims{} - token, err := jwt.ParseWithClaims(tokenString, claims, ta.parse) - return token, claims, err -} - -func (ta *TokenAuth) parse(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("Invalid JWT parse method") - } - return ta.secret, nil -} - -func UsernameFromContext(ctx context.Context) (string, error) { - gc, err := ginContextFromContext(ctx) - if err != nil { - return "", err - } - - u, ok := gc.Get(usernameKey) - if !ok { - return "", fmt.Errorf("could not find username in gin.Context") - } - us, ok := u.(string) - if !ok { - return "", fmt.Errorf("invalid username type") - } - return us, nil -} - -func ginContextFromContext(ctx context.Context) (*gin.Context, error) { - ginContext := ctx.Value(ginContextKey) - if ginContext == nil { - return nil, fmt.Errorf("could not retrieve gin.Context") - } - - gc, ok := ginContext.(*gin.Context) - if !ok { - return nil, fmt.Errorf("gin.Context has wrong type") - } - return gc, nil -} diff --git a/internal/api/graphql/access/token_auth_method.go b/internal/api/graphql/access/token_auth_method.go new file mode 100644 index 00000000..b87943f6 --- /dev/null +++ b/internal/api/graphql/access/token_auth_method.go @@ -0,0 +1,109 @@ +// SPDX-FileCopyrightText: 2024 SAP SE or an SAP affiliate company and Greenhouse contributors +// SPDX-License-Identifier: Apache-2.0 + +package access + +import ( + "context" + "fmt" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + + "github.com/cloudoperators/heureka/internal/util" +) + +type ginContextKeyType string + +const ( + ginContextKey ginContextKeyType = "GinContextKey" + scannerNameKey string = "scannername" + tokenAuthHeader string = "X-Service-Authorization" +) + +func NewTokenAuthMethod(l Logger, cfg *util.Config) *TokenAuthMethod { + if cfg.AuthTokenSecret != "" { + return &TokenAuthMethod{logger: l, secret: []byte(cfg.AuthTokenSecret)} + } + return nil +} + +type TokenClaims struct { + Version string `json:"version"` + jwt.RegisteredClaims +} + +type TokenAuthMethod struct { + logger Logger + secret []byte +} + +func (tam TokenAuthMethod) Verify(c *gin.Context) error { + verifyError := func(s string) error { + return fmt.Errorf("TokenAuthMethod(%s)", s) + } + + tokenString := c.GetHeader(tokenAuthHeader) + if tokenString == "" { + return verifyError("No authorization header") + } + token, claims, err := tam.parseFromString(tokenString) + if err != nil { + tam.logger.Error("JWT parsing error: ", err) + return verifyError("Token parsing error") + } else if !token.Valid || claims.ExpiresAt == nil { + tam.logger.Error("Invalid token") + return verifyError("Invalid token") + } else if claims.ExpiresAt.Before(time.Now()) { + tam.logger.Warn("Expired token") + return verifyError("Token expired") + } + c.Set(scannerNameKey, claims.RegisteredClaims.Subject) + ctx := context.WithValue(c.Request.Context(), ginContextKey, c) + c.Request = c.Request.WithContext(ctx) + return nil +} + +func (tam TokenAuthMethod) parseFromString(tokenString string) (*jwt.Token, *TokenClaims, error) { + claims := &TokenClaims{} + token, err := jwt.ParseWithClaims(tokenString, claims, tam.parse) + return token, claims, err +} + +func (tam *TokenAuthMethod) parse(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("Invalid JWT parse method") + } + return tam.secret, nil +} + +func ScannerNameFromContext(ctx context.Context) (string, error) { + gc, err := ginContextFromContext(ctx) + if err != nil { + return "", err + } + + s, ok := gc.Get(scannerNameKey) + if !ok { + return "", fmt.Errorf("could not find scanner name in gin.Context") + } + ss, ok := s.(string) + if !ok { + return "", fmt.Errorf("invalid scanner name type") + } + return ss, nil +} + +func ginContextFromContext(ctx context.Context) (*gin.Context, error) { + ginContext := ctx.Value(ginContextKey) + if ginContext == nil { + return nil, fmt.Errorf("could not retrieve gin.Context") + } + + gc, ok := ginContext.(*gin.Context) + if !ok { + return nil, fmt.Errorf("gin.Context has wrong type") + } + return gc, nil +} diff --git a/internal/api/graphql/access/token_auth_test.go b/internal/api/graphql/access/token_auth_method_test.go similarity index 76% rename from internal/api/graphql/access/token_auth_test.go rename to internal/api/graphql/access/token_auth_method_test.go index a81deed2..fe7eb73b 100644 --- a/internal/api/graphql/access/token_auth_test.go +++ b/internal/api/graphql/access/token_auth_method_test.go @@ -24,7 +24,7 @@ import ( const ( testEndpoint = "/testendpoint" - testUsername = "testAccessUser" + testScannerName = "testAccessScanner" authTokenSecret = "xxx" ) @@ -46,7 +46,7 @@ type server struct { func (s *server) startInBackground(port string) { s.lastRequestCtx = context.TODO() - auth := access.NewTokenAuth(&noLogLogger{}, &util.Config{AuthTokenSecret: authTokenSecret}) + auth := access.NewAuth(&util.Config{AuthTokenSecret: authTokenSecret}) r := gin.Default() r.Use(auth.GetMiddleware()) r.GET(testEndpoint, func(c *gin.Context) { @@ -92,27 +92,27 @@ var _ = Describe("Pass token data via context when using token auth middleware", testServer.stop() }) - When("User access api through token auth middleware with valid token", func() { + When("Scanner access api through token auth middleware with valid token", func() { BeforeEach(func() { - token := GenerateJwtWithUsername(authTokenSecret, 1*time.Hour, testUsername) - resp := SendGetRequest(url, map[string]string{"Authorization": token}) + token := GenerateJwtWithName(authTokenSecret, 1*time.Hour, testScannerName) + resp := SendGetRequest(url, map[string]string{"X-Service-Authorization": token}) Expect(resp.StatusCode).To(Equal(200)) }) - It("Should be able to access user name from request context", func() { - username, err := access.UsernameFromContext(testServer.context()) + It("Should be able to access scanner name from request context", func() { + name, err := access.ScannerNameFromContext(testServer.context()) Expect(err).To(BeNil()) - Expect(username).To(BeEquivalentTo(testUsername)) + Expect(name).To(BeEquivalentTo(testScannerName)) }) }) - When("User access api through token auth middleware with invalid token", func() { + When("Scanner access api through token auth middleware with invalid token", func() { BeforeEach(func() { token := GenerateInvalidJwt(authTokenSecret) - resp := SendGetRequest(url, map[string]string{"Authorization": token}) + resp := SendGetRequest(url, map[string]string{"X-Service-Authorization": token}) Expect(resp.StatusCode).To(Equal(401)) }) It("Should not store gin context in request context", func() { - _, err := access.UsernameFromContext(testServer.context()) + _, err := access.ScannerNameFromContext(testServer.context()) Expect(err).ShouldNot(BeNil()) }) }) diff --git a/internal/api/graphql/server.go b/internal/api/graphql/server.go index e180544b..b8b82b6a 100644 --- a/internal/api/graphql/server.go +++ b/internal/api/graphql/server.go @@ -18,7 +18,7 @@ type GraphQLAPI struct { Server *handler.Server App app.Heureka - auth access.Auth + auth *access.Auth } func NewGraphQLAPI(a app.Heureka, cfg util.Config) *GraphQLAPI { diff --git a/internal/database/mariadb/test/database_manager.go b/internal/database/mariadb/test/database_manager.go index 8006d2e9..d1e30765 100644 --- a/internal/database/mariadb/test/database_manager.go +++ b/internal/database/mariadb/test/database_manager.go @@ -149,7 +149,6 @@ func (dbm *LocalTestDataBaseManager) NewTestSchema() *mariadb.SqlDatabase { dbm.Schemas = append(dbm.Schemas, schemaName) dbm.CurrentSchema = schemaName dbm.Config.DBName = schemaName - dbm.Config.AuthType = "none" err := dbm.dbClient.SetupSchema(dbm.Config.Config) if err != nil { diff --git a/internal/e2e/token_auth_test.go b/internal/e2e/token_auth_test.go index ab2a8ce6..33e55311 100644 --- a/internal/e2e/token_auth_test.go +++ b/internal/e2e/token_auth_test.go @@ -29,7 +29,6 @@ var _ = Describe("Getting access via API", Label("e2e", "TokenAuthorization"), f cfg = dbm.DbConfig() cfg.Port = util2.GetRandomFreePort() - cfg.AuthType = "token" cfg.AuthTokenSecret = "xxx" s = server.NewServer(cfg) @@ -45,54 +44,54 @@ var _ = Describe("Getting access via API", Label("e2e", "TokenAuthorization"), f When("trying to access query resource with valid token", func() { It("respond with 200", func() { token := GenerateJwt(cfg.AuthTokenSecret, 1*time.Hour) - resp := SendGetRequest(queryUrl, map[string]string{"Authorization": token}) + resp := SendGetRequest(queryUrl, map[string]string{"X-Service-Authorization": token}) Expect(resp.StatusCode).To(Equal(200)) }) }) - When("trying to access query resource without 'Authorization' header", func() { + When("trying to access query resource without 'X-Service-Authorization' header", func() { It("respond with 401", func() { resp := SendGetRequest(queryUrl, nil) Expect(resp.StatusCode).To(Equal(401)) - ExpectErrorMessage(resp, "Authorization header is required") + ExpectErrorMessage(resp, "TokenAuthMethod(No authorization header)") }) }) - When("trying to access query resource with invalid 'Authorization' header", func() { + When("trying to access query resource with invalid 'X-Service-Authorization' header", func() { It("respond with 401", func() { - resp := SendGetRequest(queryUrl, map[string]string{"Authorization": "invalidHeader"}) + resp := SendGetRequest(queryUrl, map[string]string{"X-Service-Authorization": "invalidHeader"}) Expect(resp.StatusCode).To(Equal(401)) - ExpectErrorMessage(resp, "Token parsing error") + ExpectErrorMessage(resp, "TokenAuthMethod(Token parsing error)") }) }) When("trying to access query resource with expired token", func() { It("respond with 401", func() { token := GenerateJwt(cfg.AuthTokenSecret, -1*time.Hour) - resp := SendGetRequest(queryUrl, map[string]string{"Authorization": token}) + resp := SendGetRequest(queryUrl, map[string]string{"X-Service-Authorization": token}) Expect(resp.StatusCode).To(Equal(401)) - ExpectErrorMessage(resp, "Token parsing error") + ExpectErrorMessage(resp, "TokenAuthMethod(Token parsing error)") }) }) When("trying to access query resource with token created using invalid secret", func() { It("respond with 401", func() { token := GenerateJwt("invalidSecret", 1*time.Hour) - resp := SendGetRequest(queryUrl, map[string]string{"Authorization": token}) + resp := SendGetRequest(queryUrl, map[string]string{"X-Service-Authorization": token}) Expect(resp.StatusCode).To(Equal(401)) - ExpectErrorMessage(resp, "Token parsing error") + ExpectErrorMessage(resp, "TokenAuthMethod(Token parsing error)") }) }) When("trying to access query resource with token created using invalid signing method", func() { It("respond with 401", func() { token := GenerateJwtWithInvalidSigningMethod(cfg.AuthTokenSecret, 1*time.Hour) - resp := SendGetRequest(queryUrl, map[string]string{"Authorization": token}) + resp := SendGetRequest(queryUrl, map[string]string{"X-Service-Authorization": token}) Expect(resp.StatusCode).To(Equal(401)) - ExpectErrorMessage(resp, "Token parsing error") + ExpectErrorMessage(resp, "TokenAuthMethod(Token parsing error)") }) }) When("trying to access query resource with invalid token", func() { It("respond with 401", func() { token := GenerateInvalidJwt(cfg.AuthTokenSecret) - resp := SendGetRequest(queryUrl, map[string]string{"Authorization": token}) + resp := SendGetRequest(queryUrl, map[string]string{"X-Service-Authorization": token}) Expect(resp.StatusCode).To(Equal(401)) - ExpectErrorMessage(resp, "Invalid token") + ExpectErrorMessage(resp, "TokenAuthMethod(Invalid token)") }) }) }) diff --git a/internal/util/config.go b/internal/util/config.go index 1f2e8a42..84ed4ce3 100644 --- a/internal/util/config.go +++ b/internal/util/config.go @@ -32,9 +32,9 @@ type Config struct { //Environment string `envconfig:"ENVIRONMENT" required:"true" json:"environment"` //// https://pkg.go.dev/github.com/robfig/cron#hdr-Predefined_schedules //DiscoverySchedule string `envconfig:"DISOVERY_SCHEDULE" default:"0 0 0 * * *" json:"discoverySchedule"` - SeedMode bool `envconfig:"SEED_MODE" required:"false" default:"false" json:"seedMode"` - AuthType string `envconfig:"AUTH_TYPE" required:"false" json:"-" default:"none"` - AuthTokenSecret string `envconfig:"AUTH_TOKEN_SECRET" required:"false" json:"-"` + SeedMode bool `envconfig:"SEED_MODE" required:"false" default:"false" json:"seedMode"` + AuthTokenSecret string `envconfig:"AUTH_TOKEN_SECRET" required:"false" json:"-"` + //TODO: add: AuthOidcUrl string `envconfig:"AUTH_OIDC_URL" required:"false" json:"-"` DefaultIssuePriority int64 `envconfig:"DEFAULT_ISSUE_PRIORITY" default:"100" json:"defaultIssuePriority"` DefaultRepositoryName string `envconfig:"DEFAULT_REPOSITORY_NAME" default:"nvd" json:"defaultRepositoryName"` } diff --git a/tools/token_generator/main.go b/tools/token_generator/main.go index d7cc30d2..eeee024e 100644 --- a/tools/token_generator/main.go +++ b/tools/token_generator/main.go @@ -26,7 +26,7 @@ func GenerateJWT(jwtSecret []byte, expireIn time.Duration) (string, error) { ExpiresAt: jwt.NewNumericDate(time.Now().Add(expireIn)), IssuedAt: jwt.NewNumericDate(time.Now()), Issuer: "heureka", - Subject: "testUser", + Subject: "testclient", }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) From f36bc07e2f0dc002872fc71c13b782a2f8ef2fb5 Mon Sep 17 00:00:00 2001 From: Michal Krzyz Date: Fri, 22 Nov 2024 14:47:25 +0100 Subject: [PATCH 2/2] claimsfeat(authN): Redesign JWT token auth #372 Split verify function for JWT token auth to make it more readable --- .../api/graphql/access/token_auth_method.go | 68 ++++++++++++++----- internal/e2e/token_auth_test.go | 2 +- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/internal/api/graphql/access/token_auth_method.go b/internal/api/graphql/access/token_auth_method.go index b87943f6..87685dcd 100644 --- a/internal/api/graphql/access/token_auth_method.go +++ b/internal/api/graphql/access/token_auth_method.go @@ -40,35 +40,69 @@ type TokenAuthMethod struct { } func (tam TokenAuthMethod) Verify(c *gin.Context) error { - verifyError := func(s string) error { - return fmt.Errorf("TokenAuthMethod(%s)", s) + + tokenString, err := getTokenFromHeader(c) + if err != nil { + return err + } + + claims, err := tam.verifyTokenAndGetClaimsFromTokenString(tokenString) + if err != nil { + return err + } + + err = tam.verifyTokenExpiration(claims) + if err != nil { + return err } + scannerNameToContext(c, claims) + + return nil +} + +func verifyError(s string) error { + return fmt.Errorf("TokenAuthMethod(%s)", s) +} + +func getTokenFromHeader(c *gin.Context) (string, error) { + var err error tokenString := c.GetHeader(tokenAuthHeader) if tokenString == "" { - return verifyError("No authorization header") + err = verifyError("No authorization header") } - token, claims, err := tam.parseFromString(tokenString) + return tokenString, err +} + +func (tam TokenAuthMethod) verifyTokenAndGetClaimsFromTokenString(tokenString string) (*TokenClaims, error) { + claims := &TokenClaims{} + token, err := jwt.ParseWithClaims(tokenString, claims, tam.parse) if err != nil { tam.logger.Error("JWT parsing error: ", err) - return verifyError("Token parsing error") - } else if !token.Valid || claims.ExpiresAt == nil { + err = verifyError("Token parsing error") + } else if !token.Valid { tam.logger.Error("Invalid token") - return verifyError("Invalid token") - } else if claims.ExpiresAt.Before(time.Now()) { + err = verifyError("Invalid token") + } + return claims, err +} + +func (tam TokenAuthMethod) verifyTokenExpiration(tc *TokenClaims) error { + var err error + if tc.ExpiresAt == nil { + tam.logger.Error("Missing ExpiresAt in token claims") + err = verifyError("Missing ExpiresAt in token claims") + } else if tc.ExpiresAt.Before(time.Now()) { tam.logger.Warn("Expired token") - return verifyError("Token expired") + err = verifyError("Expired token") } - c.Set(scannerNameKey, claims.RegisteredClaims.Subject) - ctx := context.WithValue(c.Request.Context(), ginContextKey, c) - c.Request = c.Request.WithContext(ctx) - return nil + return err } -func (tam TokenAuthMethod) parseFromString(tokenString string) (*jwt.Token, *TokenClaims, error) { - claims := &TokenClaims{} - token, err := jwt.ParseWithClaims(tokenString, claims, tam.parse) - return token, claims, err +func scannerNameToContext(c *gin.Context, tc *TokenClaims) { + c.Set(scannerNameKey, tc.RegisteredClaims.Subject) + ctx := context.WithValue(c.Request.Context(), ginContextKey, c) + c.Request = c.Request.WithContext(ctx) } func (tam *TokenAuthMethod) parse(token *jwt.Token) (interface{}, error) { diff --git a/internal/e2e/token_auth_test.go b/internal/e2e/token_auth_test.go index 33e55311..344bc140 100644 --- a/internal/e2e/token_auth_test.go +++ b/internal/e2e/token_auth_test.go @@ -91,7 +91,7 @@ var _ = Describe("Getting access via API", Label("e2e", "TokenAuthorization"), f token := GenerateInvalidJwt(cfg.AuthTokenSecret) resp := SendGetRequest(queryUrl, map[string]string{"X-Service-Authorization": token}) Expect(resp.StatusCode).To(Equal(401)) - ExpectErrorMessage(resp, "TokenAuthMethod(Invalid token)") + ExpectErrorMessage(resp, "TokenAuthMethod(Missing ExpiresAt in token claims)") }) }) })