From eb0f0b3f4c7254713a155ac699fb7989a7a13a32 Mon Sep 17 00:00:00 2001 From: Michal Krzyz Date: Fri, 22 Nov 2024 14:47:25 +0100 Subject: [PATCH] feat(authN): Redesign JWT token auth #372 Split verify function for JWT token auth to make it more readable --- .../api/graphql/access/token_auth_method.go | 68 ++++++++++++++----- internal/e2e/token_auth_test.go | 2 +- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/internal/api/graphql/access/token_auth_method.go b/internal/api/graphql/access/token_auth_method.go index b87943f6..b3adaa94 100644 --- a/internal/api/graphql/access/token_auth_method.go +++ b/internal/api/graphql/access/token_auth_method.go @@ -40,35 +40,69 @@ type TokenAuthMethod struct { } func (tam TokenAuthMethod) Verify(c *gin.Context) error { - verifyError := func(s string) error { - return fmt.Errorf("TokenAuthMethod(%s)", s) + + tokenString, err := getTokenFromHeader(c) + if err != nil { + return err + } + + claims, err := tam.verifyTokenAndGetClaimsFromTokenString(tokenString) + if err != nil { + return err + } + + err = tam.verifyTokenExpiration(claims) + if err != nil { + return err } + scannerNameToContext(c, claims) + + return nil +} + +func verifyError(s string) error { + return fmt.Errorf("TokenAuthMethod(%s)", s) +} + +func getTokenFromHeader(c *gin.Context) (string, error) { + var err error tokenString := c.GetHeader(tokenAuthHeader) if tokenString == "" { - return verifyError("No authorization header") + err = verifyError("No authorization header") } - token, claims, err := tam.parseFromString(tokenString) + return tokenString, err +} + +func (tam TokenAuthMethod) verifyTokenAndGetClaimsFromTokenString(tokenString string) (*TokenClaims, error) { + claims := &TokenClaims{} + token, err := jwt.ParseWithClaims(tokenString, claims, tam.parse) if err != nil { tam.logger.Error("JWT parsing error: ", err) - return verifyError("Token parsing error") - } else if !token.Valid || claims.ExpiresAt == nil { + err = verifyError("Token parsing error") + } else if !token.Valid { tam.logger.Error("Invalid token") - return verifyError("Invalid token") - } else if claims.ExpiresAt.Before(time.Now()) { + err = verifyError("Invalid token") + } else if claims.ExpiresAt == nil { + tam.logger.Error("Missing ExpiresAt in token claims") + err = verifyError("Missing ExpiresAt in token claims") + } + return claims, err +} + +func (tam TokenAuthMethod) verifyTokenExpiration(tc *TokenClaims) error { + var err error + if tc.ExpiresAt.Before(time.Now()) { tam.logger.Warn("Expired token") - return verifyError("Token expired") + err = verifyError("Expired token") } - c.Set(scannerNameKey, claims.RegisteredClaims.Subject) - ctx := context.WithValue(c.Request.Context(), ginContextKey, c) - c.Request = c.Request.WithContext(ctx) - return nil + return err } -func (tam TokenAuthMethod) parseFromString(tokenString string) (*jwt.Token, *TokenClaims, error) { - claims := &TokenClaims{} - token, err := jwt.ParseWithClaims(tokenString, claims, tam.parse) - return token, claims, err +func scannerNameToContext(c *gin.Context, tc *TokenClaims) { + c.Set(scannerNameKey, tc.RegisteredClaims.Subject) + ctx := context.WithValue(c.Request.Context(), ginContextKey, c) + c.Request = c.Request.WithContext(ctx) } func (tam *TokenAuthMethod) parse(token *jwt.Token) (interface{}, error) { diff --git a/internal/e2e/token_auth_test.go b/internal/e2e/token_auth_test.go index 33e55311..344bc140 100644 --- a/internal/e2e/token_auth_test.go +++ b/internal/e2e/token_auth_test.go @@ -91,7 +91,7 @@ var _ = Describe("Getting access via API", Label("e2e", "TokenAuthorization"), f token := GenerateInvalidJwt(cfg.AuthTokenSecret) resp := SendGetRequest(queryUrl, map[string]string{"X-Service-Authorization": token}) Expect(resp.StatusCode).To(Equal(401)) - ExpectErrorMessage(resp, "TokenAuthMethod(Invalid token)") + ExpectErrorMessage(resp, "TokenAuthMethod(Missing ExpiresAt in token claims)") }) }) })