diff --git a/v2/CHANGELOG.md b/v2/CHANGELOG.md index 27cdb674..af55f645 100644 --- a/v2/CHANGELOG.md +++ b/v2/CHANGELOG.md @@ -4,6 +4,7 @@ - Add tasks endpoints to v2 - Add missing endpoints from collections to v2 - Add missing endpoints from query to v2 +- Add SSO auth token implementation ## [2.1.3](https://github.com/arangodb/go-driver/tree/v2.1.3) (2025-02-21) - Switch to Go 1.22.11 diff --git a/v2/connection/auth_jwt_impl.go b/v2/connection/auth_jwt_impl.go index ec72c4de..3ce4c173 100644 --- a/v2/connection/auth_jwt_impl.go +++ b/v2/connection/auth_jwt_impl.go @@ -24,11 +24,20 @@ package connection import ( "context" + "encoding/base64" + "encoding/json" + "fmt" + "log" "net/http" + "strings" + "time" ) func NewJWTAuthWrapper(username, password string) Wrapper { - return WrapAuthentication(func(ctx context.Context, conn Connection) (authentication Authentication, err error) { + var token string + var expiry time.Time + + refresh := func(ctx context.Context, conn Connection) error { url := NewUrl("_open", "auth") var data jwtOpenResponse @@ -40,15 +49,31 @@ func NewJWTAuthWrapper(username, password string) Wrapper { resp, err := CallPost(ctx, conn, url, &data, j) if err != nil { - return nil, err + return err + } + if resp.Code() != http.StatusOK { + return NewError(resp.Code(), "unexpected code") + } + + token = data.Token + expiry, err = parseJWTExpiry(token) + if err != nil { + // Log for visibility but don't break functionality + log.Printf("failed to parse JWT expiry: %v", err) + expiry = time.Now().Add(1 * time.Minute) // fallback, so it will refresh immediately next time } + return nil + } - switch resp.Code() { - case http.StatusOK: - return NewHeaderAuth("Authorization", "bearer %s", data.Token), nil - default: - return nil, NewError(resp.Code(), "unexpected code") + return WrapAuthentication(func(ctx context.Context, conn Connection) (Authentication, error) { + // First time fetch + if token == "" || time.Now().After(expiry) { + if err := refresh(ctx, conn); err != nil { + return nil, err + } } + + return NewHeaderAuth("Authorization", "bearer %s", token), nil }) } @@ -59,5 +84,77 @@ type jwtOpenRequest struct { type jwtOpenResponse struct { Token string `json:"jwt"` + ExpiresIn int `json:"expires_in,omitempty"` MustChangePassword bool `json:"must_change_password,omitempty"` } + +func parseJWTExpiry(token string) (time.Time, error) { + parts := strings.Split(token, ".") + if len(parts) < 2 { + return time.Time{}, fmt.Errorf("invalid JWT format") + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return time.Time{}, err + } + + var claims struct { + Exp int64 `json:"exp"` + } + if err := json.Unmarshal(payload, &claims); err != nil { + return time.Time{}, err + } + + return time.Unix(claims.Exp, 0), nil +} + +func NewSSOAuthWrapper(initialToken string) Wrapper { + var token = initialToken + var expiry time.Time + // setToken updates the current JWT and its expiry time. + // If expiry parsing fails, we log the error and fall back to a short 1-minute lifetime. + // This ensures the token will be refreshed soon without breaking functionality. + setToken := func(newToken string) { + token = newToken + expiryTime, err := parseJWTExpiry(newToken) + if err != nil { + // Log for visibility but don't break functionality + log.Printf("failed to parse JWT expiry: %v", err) + expiry = time.Now().Add(1 * time.Minute) // fallback, so it will refresh immediately next time + } else { + expiry = expiryTime + } + } + + // If we already have a token (from an SSO login), parse expiry now + if token != "" { + setToken(token) + } + + return WrapAuthentication(func(ctx context.Context, conn Connection) (Authentication, error) { + // No token yet or expired — let caller know they must login via SSO + if token == "" || time.Now().After(expiry) { + // Try a call to _open/auth just to see if server sends 307 + url := NewUrl("_open", "auth") + var data jwtOpenResponse + // Intentionally passing nil: in SSO mode, /_open/auth expects no body + resp, err := CallPost(ctx, conn, url, &data, nil) + if err != nil { + return nil, err + } + + switch resp.Code() { + case http.StatusOK: + setToken(data.Token) + case http.StatusTemporaryRedirect: + loc := resp.Header("Location") + return nil, fmt.Errorf("SSO redirect: please authenticate via browser at %s", loc) + default: + return nil, NewError(resp.Code(), "unexpected code") + } + } + + return NewHeaderAuth("Authorization", "bearer %s", token), nil + }) +}