From 92a1613e592cab6c6a32d77e4b6c7594cb4e054f Mon Sep 17 00:00:00 2001 From: Jakub Bednar Date: Fri, 19 Jan 2024 09:18:35 +0100 Subject: [PATCH] feat(oauth): add possibility to specify OAuthLogoutEndpoint for logout from OAuth Identity provider --- oauth2/code_exchange_test.go | 8 ++++---- oauth2/mux.go | 16 ++++++++++------ oauth2/mux_test.go | 2 +- server/server.go | 11 ++++++----- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/oauth2/code_exchange_test.go b/oauth2/code_exchange_test.go index 230786bda1..2c5b06141f 100644 --- a/oauth2/code_exchange_test.go +++ b/oauth2/code_exchange_test.go @@ -32,7 +32,7 @@ func Test_CodeExchangeCSRF_AuthCodeURL(t *testing.T) { ProviderURL: "http://localhost:1234", Orgs: "", } - authMux := NewAuthMux(mp, auth, mt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "hello", nil, nil) + authMux := NewAuthMux(mp, auth, mt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "hello", nil, nil, "") // create AuthCodeURL with code exchange without PKCE codeExchange := NewCodeExchange(false, "") @@ -95,7 +95,7 @@ func Test_CodeExchangeCSRF_ExchangeCodeForToken(t *testing.T) { ProviderURL: authServer.URL, Orgs: "", } - authMux := NewAuthMux(mp, auth, auth.Tokens, "", clog.New(clog.ParseLevel("debug")), useidtoken, "hi", nil, nil) + authMux := NewAuthMux(mp, auth, auth.Tokens, "", clog.New(clog.ParseLevel("debug")), useidtoken, "hi", nil, nil, "") // create AuthCodeURL using CodeExchange with PKCE codeExchange := simpleTokenExchange @@ -136,7 +136,7 @@ func Test_CodeExchangePKCE_AuthCodeURL(t *testing.T) { ProviderURL: "http://localhost:1234", Orgs: "", } - authMux := NewAuthMux(mp, auth, mt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "hi", nil, nil) + authMux := NewAuthMux(mp, auth, mt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "hi", nil, nil, "") // create AuthCodeURL using CodeExchange with PKCE codeExchange := NewCodeExchange(true, "secret") @@ -213,7 +213,7 @@ func Test_CodeExchangePKCE_ExchangeCodeForToken(t *testing.T) { ProviderURL: authServer.URL, Orgs: "", } - authMux := NewAuthMux(mp, auth, jwt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "hi", nil, nil) + authMux := NewAuthMux(mp, auth, jwt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "hi", nil, nil, "") // create AuthCodeURL using CodeExchange with PKCE codeExchange := CodeExchangePKCE{Secret: secret} diff --git a/oauth2/mux.go b/oauth2/mux.go index 11c873e3db..5837e3d906 100644 --- a/oauth2/mux.go +++ b/oauth2/mux.go @@ -15,20 +15,24 @@ var _ Mux = &AuthMux{} const TenMinutes = 10 * time.Minute // NewAuthMux constructs a Mux handler that checks a cookie against the authenticator -func NewAuthMux(p Provider, a Authenticator, t Tokenizer, - basepath string, l chronograf.Logger, - UseIDToken bool, LoginHint string, - client *http.Client, codeExchange CodeExchange, -) *AuthMux { +func NewAuthMux(p Provider, a Authenticator, t Tokenizer, basepath string, l chronograf.Logger, UseIDToken bool, LoginHint string, client *http.Client, codeExchange CodeExchange, logoutCallback string) *AuthMux { if codeExchange == nil { codeExchange = simpleTokenExchange } + + var afterLogoutURL string + if logoutCallback != "" { + afterLogoutURL = logoutCallback + } else { + afterLogoutURL = path.Join(basepath, "/") + } + mux := &AuthMux{ Provider: p, Auth: a, Tokens: t, SuccessURL: path.Join(basepath, "/landing"), - AfterLogoutURL: path.Join(basepath, "/"), + AfterLogoutURL: afterLogoutURL, FailureURL: path.Join(basepath, "/login"), Now: DefaultNowTime, Logger: l, diff --git a/oauth2/mux_test.go b/oauth2/mux_test.go index 2d25a4248f..51d439053f 100644 --- a/oauth2/mux_test.go +++ b/oauth2/mux_test.go @@ -53,7 +53,7 @@ func setupMuxTest(response interface{}, selector func(*AuthMux) http.Handler) (* useidtoken := false - jm := NewAuthMux(mp, auth, mt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "", nil, nil) + jm := NewAuthMux(mp, auth, mt, "", clog.New(clog.ParseLevel("debug")), useidtoken, "", nil, nil, "") ts := httptest.NewServer(selector(jm)) jar, _ := cookiejar.New(nil) hc := http.Client{ diff --git a/server/server.go b/server/server.go index 7c404c979c..153898f502 100644 --- a/server/server.go +++ b/server/server.go @@ -115,6 +115,7 @@ type Server struct { GenericInsecure bool `long:"generic-insecure" description:"Whether or not to verify auth-url's tls certificates." env:"GENERIC_INSECURE"` GenericRootCA flags.Filename `long:"generic-root-ca" description:"File location of root ca cert for generic oauth tls verification." env:"GENERIC_ROOT_CA"` OAuthNoPKCE bool `long:"oauth-no-pkce" description:"Disables OAuth PKCE." env:"OAUTH_NO_PKCE"` + OAuthLogoutEndpoint string `long:"oauth-logout-endpoint" description:"OAuth endpoint to call for logout from OAuth Identity provider." env:"OAUTH_LOGOUT_ENDPOINT"` Auth0Domain string `long:"auth0-domain" description:"Subdomain of auth0.com used for Auth0 OAuth2 authentication" env:"AUTH0_DOMAIN"` Auth0ClientID string `long:"auth0-client-id" description:"Auth0 Client ID for OAuth2 support" env:"AUTH0_CLIENT_ID"` @@ -343,7 +344,7 @@ func (s *Server) githubOAuth(logger chronograf.Logger, auth oauth2.Authenticator Logger: logger, } jwt := oauth2.NewJWT(s.TokenSecret, s.JwksURL) - ghMux := oauth2.NewAuthMux(&gh, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange()) + ghMux := oauth2.NewAuthMux(&gh, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange(), s.OAuthLogoutEndpoint) return &gh, ghMux, s.UseGithub } @@ -357,7 +358,7 @@ func (s *Server) googleOAuth(logger chronograf.Logger, auth oauth2.Authenticator Logger: logger, } jwt := oauth2.NewJWT(s.TokenSecret, s.JwksURL) - goMux := oauth2.NewAuthMux(&google, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange()) + goMux := oauth2.NewAuthMux(&google, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange(), s.OAuthLogoutEndpoint) return &google, goMux, s.UseGoogle } @@ -369,7 +370,7 @@ func (s *Server) herokuOAuth(logger chronograf.Logger, auth oauth2.Authenticator Logger: logger, } jwt := oauth2.NewJWT(s.TokenSecret, s.JwksURL) - hMux := oauth2.NewAuthMux(&heroku, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange()) + hMux := oauth2.NewAuthMux(&heroku, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange(), s.OAuthLogoutEndpoint) return &heroku, hMux, s.UseHeroku } @@ -388,7 +389,7 @@ func (s *Server) genericOAuth(logger chronograf.Logger, auth oauth2.Authenticato Logger: logger, } jwt := oauth2.NewJWT(s.TokenSecret, s.JwksURL) - genMux := oauth2.NewAuthMux(&gen, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange()) + genMux := oauth2.NewAuthMux(&gen, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange(), s.OAuthLogoutEndpoint) return &gen, genMux, s.UseGenericOAuth2 } @@ -404,7 +405,7 @@ func (s *Server) auth0OAuth(logger chronograf.Logger, auth oauth2.Authenticator) auth0, err := oauth2.NewAuth0(s.Auth0Domain, s.Auth0ClientID, s.Auth0ClientSecret, redirectURL.String(), s.Auth0Organizations, logger) jwt := oauth2.NewJWT(s.TokenSecret, s.JwksURL) - genMux := oauth2.NewAuthMux(&auth0, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange()) + genMux := oauth2.NewAuthMux(&auth0, auth, jwt, s.Basepath, logger, s.UseIDToken, s.LoginHint, &s.oauthClient, s.createCodeExchange(), s.OAuthLogoutEndpoint) if err != nil { logger.Error("Error parsing Auth0 domain: err:", err)