From 983000d94e94cc51da1213fa0da82f9380294cc3 Mon Sep 17 00:00:00 2001 From: Lennart Fleischmann <67686424+lfleischmann@users.noreply.github.com> Date: Thu, 13 Feb 2025 12:44:36 +0100 Subject: [PATCH] feat(ee): saml idp initiated sso --- backend/ee/saml/handler.go | 235 +++++++++++++++--- backend/ee/saml/service.go | 11 + backend/ee/saml/state.go | 6 +- backend/ee/saml/utils/response.go | 76 ++++++ backend/flow_api/flow/flows.go | 19 ++ .../flow/shared/action_exchange_token.go | 26 +- backend/flow_api/handler.go | 5 + backend/go.mod | 4 +- backend/handler/public_router.go | 4 + backend/persistence/identity_persister.go | 2 +- ...eate_saml_idp_initiated_requests.down.fizz | 1 + ...create_saml_idp_initiated_requests.up.fizz | 9 + .../models/saml_idp_initiated_request.go | 43 ++++ backend/persistence/persister.go | 10 + .../saml_idp_inititated_request_persister.go | 48 ++++ .../elements/src/contexts/AppProvider.tsx | 13 +- .../src/lib/flow-api/types/state-handling.ts | 6 +- 17 files changed, 469 insertions(+), 49 deletions(-) create mode 100644 backend/ee/saml/utils/response.go create mode 100644 backend/persistence/migrations/20250210095906_create_saml_idp_initiated_requests.down.fizz create mode 100644 backend/persistence/migrations/20250210095906_create_saml_idp_initiated_requests.up.fizz create mode 100644 backend/persistence/models/saml_idp_initiated_request.go create mode 100644 backend/persistence/saml_idp_inititated_request_persister.go diff --git a/backend/ee/saml/handler.go b/backend/ee/saml/handler.go index 285b90a97..0795af401 100644 --- a/backend/ee/saml/handler.go +++ b/backend/ee/saml/handler.go @@ -9,6 +9,7 @@ import ( auditlog "github.com/teamhanko/hanko/backend/audit_log" "github.com/teamhanko/hanko/backend/ee/saml/dto" "github.com/teamhanko/hanko/backend/ee/saml/provider" + samlUtils "github.com/teamhanko/hanko/backend/ee/saml/utils" "github.com/teamhanko/hanko/backend/persistence/models" "github.com/teamhanko/hanko/backend/session" "github.com/teamhanko/hanko/backend/thirdparty" @@ -16,6 +17,7 @@ import ( "net/http" "net/url" "strings" + "time" ) type Handler struct { @@ -97,48 +99,132 @@ func (handler *Handler) Auth(c echo.Context) error { return c.Redirect(http.StatusTemporaryRedirect, redirectUrl) } -func (handler *Handler) CallbackPost(c echo.Context) error { - state, samlError := VerifyState(handler.samlService.Config(), handler.samlService.Persister().GetSamlStatePersister(), c.FormValue("RelayState")) - if samlError != nil { +func (handler *Handler) callbackPostIdPInitiated(c echo.Context, samlResponse string) error { + // ignore URL parse error because config validation already ensures it is a parseable URL + redirectTo, _ := url.Parse(handler.samlService.Config().Saml.DefaultRedirectUrl) + + // We need to already parse the response to be able to extract information (a response's ID, Issuer, InResponseTo + // nodes/values) to ensure protection against replaying IDP initiated responses as well as using service provider + // issued responses as IDP initiated responses, even though we later also use the gosaml2 library to parse (and then + // also validate) the response _again_. The reason is that the gosaml2 library does not make this information + // easily/publicly accessible through its API. + parsedSamlResponseDocument, _, err := samlUtils.ParseSamlResponse(samlResponse) + if err != nil { return handler.redirectError( c, - thirdparty.ErrorInvalidRequest(samlError.Error()).WithCause(samlError), - handler.samlService.Config().Saml.DefaultRedirectUrl, + thirdparty.ErrorInvalidRequest("could not parse saml response").WithCause(err), + redirectTo.String(), ) } - if strings.TrimSpace(state.RedirectTo) == "" { - state.RedirectTo = handler.samlService.Config().Saml.DefaultRedirectUrl + responseElement := parsedSamlResponseDocument.FindElement("/Response") + if responseElement == nil { + return handler.redirectError( + c, + thirdparty.ErrorInvalidRequest("invalid saml response: no response node present"), + redirectTo.String(), + ) } - redirectTo, samlError := url.Parse(state.RedirectTo) - if samlError != nil { + issuerElement := parsedSamlResponseDocument.FindElement("/Response/Issuer") + if issuerElement == nil || issuerElement.Text() == "" { + return handler.redirectError( + c, + thirdparty.ErrorInvalidRequest("invalid saml response: no issuer node present"), + redirectTo.String(), + ) + } + + issuer := issuerElement.Text() + + serviceProvider, err := handler.samlService.GetProviderByIssuer(issuer) + if err != nil { return handler.redirectError( c, - thirdparty.ErrorServer("unable to parse redirect url").WithCause(samlError), - handler.samlService.Config().Saml.DefaultRedirectUrl, + thirdparty.ErrorInvalidRequest( + fmt.Sprintf("could not get provider for issuer %s", issuer)). + WithCause(err), + redirectTo.String(), ) } - foundProvider, samlError := handler.samlService.GetProviderByDomain(state.Provider) - if samlError != nil { + // We need to check whether this is an unsolicited request, otherwise SP initiated responses could + // be used as IDP initiated responses. + if responseElement.SelectAttr("InResponseTo") != nil { return handler.redirectError( c, - thirdparty.ErrorServer("unable to find provider by domain").WithCause(samlError), + thirdparty.ErrorInvalidRequest("saml request is not unsolicited"), redirectTo.String(), ) } - assertionInfo, samlError := handler.parseSamlResponse(foundProvider, c.FormValue("SAMLResponse")) - if samlError != nil { + assertionInfo, err := handler.getAssertionInfo(serviceProvider, samlResponse) + if err != nil { + return handler.redirectError( + c, + thirdparty.ErrorInvalidRequest("could not get assertion info").WithCause(err), + redirectTo.String(), + ) + } + + samlResponseIDAttr := responseElement.SelectAttr("ID") + if samlResponseIDAttr == nil { + return handler.redirectError( + c, + thirdparty.ErrorInvalidRequest("invalid saml response: no ID for response present"), + redirectTo.String(), + ) + } + + samlResponseID := samlResponseIDAttr.Value + + samlIDPInitiatedRequestPersister := handler.samlService.Persister().GetSamlIDPInitiatedRequestPersister() + + // We use the SAML response's ID to prevent replay attacks by persisting every IDP initiated request and + // checking whether an IDP initiated request already exists for this request. + existingSamlIDPInitiatedRequest, err := samlIDPInitiatedRequestPersister.GetByResponseIDAndIssuer(samlResponseID, issuer) + if existingSamlIDPInitiatedRequest != nil { + return handler.redirectError( + c, + thirdparty.ErrorInvalidRequest("attempting to replay unsolicited saml request"), + redirectTo.String(), + ) + } + + // We assume only one assertion, and we assume it is present because we already validated it using the gosaml2 + // library (which also consumes only one/the first assertion). We also assume assertion conditions are present + // because validation assures it is not nil (or else it returns an error). + expiresAtString := assertionInfo.Assertions[0].Conditions.NotOnOrAfter + + expiresAt, err := time.Parse(time.RFC3339, expiresAtString) + if err != nil { return handler.redirectError( c, - thirdparty.ErrorServer("unable to parse saml response").WithCause(samlError), + thirdparty.ErrorServer("could not parse saml assertion conditions' NotOnOrAfter value").WithCause(err), redirectTo.String(), ) } - redirectUrl, samlError := handler.linkAccount(c, redirectTo, state, foundProvider, assertionInfo) + // If no request exists we create a new IDP initiated request model and persist it. + samlIDPInitiatedRequest, err := models.NewSamlIDPInitiatedRequest(samlResponseID, issuer, expiresAt) + if err != nil { + return handler.redirectError( + c, + thirdparty.ErrorServer("could not instantiate saml idp initiated request model").WithCause(err), + redirectTo.String(), + ) + } + + err = samlIDPInitiatedRequestPersister.Create(*samlIDPInitiatedRequest) + if err != nil { + return handler.redirectError( + c, + thirdparty.ErrorServer("could not persist saml idp initiated request"), + redirectTo.String(), + ) + } + + redirectUrl, samlError := handler.linkAccount(c, redirectTo, true, serviceProvider, assertionInfo) if samlError != nil { return handler.redirectError( c, @@ -147,19 +233,94 @@ func (handler *Handler) CallbackPost(c echo.Context) error { ) } + // Add hint to the redirect URL that this is an IDP initiated request so that a token exchange can + // eventually be performed through the dedicated flow API handler. + values := redirectUrl.Query() + values.Add("saml_hint", "idp_initiated") + redirectUrl.RawQuery = values.Encode() + return c.Redirect(http.StatusFound, redirectUrl.String()) } -func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state *State, provider provider.ServiceProvider, assertionInfo *saml2.AssertionInfo) (*url.URL, error) { +func (handler *Handler) CallbackPost(c echo.Context) error { + relayState := c.FormValue("RelayState") + samlResponse := c.FormValue("SAMLResponse") + + if handler.isIDPInitiated(relayState) { + return handler.callbackPostIdPInitiated(c, samlResponse) + } else { + state, err := VerifyState( + handler.samlService.Config(), + handler.samlService.Persister().GetSamlStatePersister(), + strings.TrimPrefix(relayState, statePrefixServiceProviderInitiated), + ) + + if err != nil { + return handler.redirectError( + c, + thirdparty.ErrorInvalidRequest(err.Error()).WithCause(err), + handler.samlService.Config().Saml.DefaultRedirectUrl, + ) + } + + if strings.TrimSpace(state.RedirectTo) == "" { + state.RedirectTo = handler.samlService.Config().Saml.DefaultRedirectUrl + } + + redirectTo, err := url.Parse(state.RedirectTo) + if err != nil { + return handler.redirectError( + c, + thirdparty.ErrorServer("unable to parse redirect url").WithCause(err), + handler.samlService.Config().Saml.DefaultRedirectUrl, + ) + } + + foundProvider, err := handler.samlService.GetProviderByDomain(state.Provider) + if err != nil { + return handler.redirectError( + c, + thirdparty.ErrorServer("unable to find provider by domain").WithCause(err), + redirectTo.String(), + ) + } + + assertionInfo, err := handler.getAssertionInfo(foundProvider, samlResponse) + if err != nil { + return handler.redirectError( + c, + thirdparty.ErrorServer("unable to parse saml response").WithCause(err), + redirectTo.String(), + ) + } + + redirectUrl, err := handler.linkAccount(c, redirectTo, state.IsFlow, foundProvider, assertionInfo) + if err != nil { + return handler.redirectError( + c, + err, + redirectTo.String(), + ) + } + + return c.Redirect(http.StatusFound, redirectUrl.String()) + } +} + +func (handler *Handler) isIDPInitiated(relayState string) bool { + return !strings.HasPrefix(relayState, statePrefixServiceProviderInitiated) +} + +func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, isFlow bool, provider provider.ServiceProvider, assertionInfo *saml2.AssertionInfo) (*url.URL, error) { var accountLinkingResult *thirdparty.AccountLinkingResult - var samlError error - samlError = handler.samlService.Persister().Transaction(func(tx *pop.Connection) error { + var err error + err = handler.samlService.Persister().Transaction(func(tx *pop.Connection) error { userdata := provider.GetUserData(assertionInfo) identityProviderIssuer := assertionInfo.Assertions[0].Issuer samlDomain := provider.GetDomain() - linkResult, samlErrorTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, identityProviderIssuer.Value, true, &samlDomain, state.IsFlow) - if samlErrorTx != nil { - return samlErrorTx + linkResult, errTx := thirdparty.LinkAccount(tx, handler.samlService.Config(), handler.samlService.Persister(), userdata, identityProviderIssuer.Value, true, &samlDomain, isFlow) + if errTx != nil { + return errTx } accountLinkingResult = linkResult @@ -167,18 +328,18 @@ func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state * emailModel := linkResult.User.Emails.GetEmailByAddress(userdata.Metadata.Email) identityModel := emailModel.Identities.GetIdentity(identityProviderIssuer.Value, userdata.Metadata.Subject) - token, tokenError := models.NewToken( + token, errTx := models.NewToken( linkResult.User.ID, models.TokenWithIdentityID(identityModel.ID), - models.TokenForFlowAPI(state.IsFlow), + models.TokenForFlowAPI(isFlow), models.TokenUserCreated(linkResult.UserCreated)) - if tokenError != nil { - return thirdparty.ErrorServer("could not create token").WithCause(tokenError) + if errTx != nil { + return thirdparty.ErrorServer("could not create token").WithCause(errTx) } - tokenError = handler.samlService.Persister().GetTokenPersisterWithConnection(tx).Create(*token) - if tokenError != nil { - return thirdparty.ErrorServer("could not save token to db").WithCause(tokenError) + errTx = handler.samlService.Persister().GetTokenPersisterWithConnection(tx).Create(*token) + if errTx != nil { + return thirdparty.ErrorServer("could not save token to db").WithCause(errTx) } query := redirectTo.Query() @@ -188,20 +349,20 @@ func (handler *Handler) linkAccount(c echo.Context, redirectTo *url.URL, state * return nil }) - if samlError != nil { - return nil, samlError + if err != nil { + return nil, err } - samlError = handler.auditLogger.Create(c, accountLinkingResult.Type, accountLinkingResult.User, nil) + err = handler.auditLogger.Create(c, accountLinkingResult.Type, accountLinkingResult.User, nil) - if samlError != nil { - return nil, samlError + if err != nil { + return nil, err } return redirectTo, nil } -func (handler *Handler) parseSamlResponse(provider provider.ServiceProvider, samlResponse string) (*saml2.AssertionInfo, error) { +func (handler *Handler) getAssertionInfo(provider provider.ServiceProvider, samlResponse string) (*saml2.AssertionInfo, error) { assertionInfo, err := provider.GetService().RetrieveAssertionInfo(samlResponse) if err != nil { return nil, thirdparty.ErrorServer("unable to parse SAML response").WithCause(err) diff --git a/backend/ee/saml/service.go b/backend/ee/saml/service.go index fe1b70784..45c693916 100644 --- a/backend/ee/saml/service.go +++ b/backend/ee/saml/service.go @@ -15,6 +15,7 @@ type Service interface { Persister() persistence.Persister Providers() []provider.ServiceProvider GetProviderByDomain(domain string) (provider.ServiceProvider, error) + GetProviderByIssuer(issuer string) (provider.ServiceProvider, error) GetAuthUrl(provider provider.ServiceProvider, redirectTo string, isFlow bool) (string, error) } @@ -83,6 +84,16 @@ func (s *defaultService) GetProviderByDomain(domain string) (provider.ServicePro return nil, fmt.Errorf("unknown provider for domain %s", domain) } +func (s *defaultService) GetProviderByIssuer(issuer string) (provider.ServiceProvider, error) { + for _, availableProvider := range s.providers { + if availableProvider.GetService().IdentityProviderIssuer == issuer { + return availableProvider, nil + } + } + + return nil, fmt.Errorf("unknown provider for issuer %s", issuer) +} + func (s *defaultService) GetAuthUrl(provider provider.ServiceProvider, redirectTo string, isFlow bool) (string, error) { if ok := samlUtils.IsAllowedRedirect(s.config.Saml, redirectTo); !ok { return "", thirdparty.ErrorInvalidRequest(fmt.Sprintf("redirect to '%s' not allowed", redirectTo)) diff --git a/backend/ee/saml/state.go b/backend/ee/saml/state.go index 9999553f4..56add9712 100644 --- a/backend/ee/saml/state.go +++ b/backend/ee/saml/state.go @@ -22,6 +22,8 @@ type State struct { IsFlow bool `json:"is_flow"` } +const statePrefixServiceProviderInitiated = "hanko_spi_" + func GenerateStateForFlowAPI(isFlow bool) func(*State) { return func(state *State) { state.IsFlow = isFlow @@ -77,7 +79,9 @@ func GenerateState(config *config.Config, persister persistence.SamlStatePersist return nil, fmt.Errorf("could not save state to db: %w", err) } - return []byte(encryptedState), nil + // Add prefix to distinguish between SP initiated and IDP initiated requests in callback handler. + result := fmt.Sprintf("%s%s", statePrefixServiceProviderInitiated, encryptedState) + return []byte(result), nil } func VerifyState(config *config.Config, persister persistence.SamlStatePersister, state string) (*State, error) { diff --git a/backend/ee/saml/utils/response.go b/backend/ee/saml/utils/response.go new file mode 100644 index 000000000..4c01eefea --- /dev/null +++ b/backend/ee/saml/utils/response.go @@ -0,0 +1,76 @@ +package utils + +import ( + "bytes" + "compress/flate" + "encoding/base64" + "fmt" + "github.com/beevik/etree" + rtvalidator "github.com/mattermost/xml-roundtrip-validator" + "io" +) + +const ( + defaultMaxDecompressedResponseSize = 5 * 1024 * 1024 +) + +func maybeDeflate(data []byte, maxSize int64, decoder func([]byte) error) error { + err := decoder(data) + if err == nil { + return nil + } + + // Default to 5MB max size + if maxSize == 0 { + maxSize = defaultMaxDecompressedResponseSize + } + + lr := io.LimitReader(flate.NewReader(bytes.NewReader(data)), maxSize+1) + + deflated, err := io.ReadAll(lr) + if err != nil { + return err + } + + if int64(len(deflated)) > maxSize { + return fmt.Errorf("deflated response exceeds maximum size of %d bytes", maxSize) + } + + return decoder(deflated) +} + +func ParseSamlResponse(samlResponse string) (*etree.Document, *etree.Element, error) { + raw, err := base64.StdEncoding.DecodeString(samlResponse) + if err != nil { + return nil, nil, fmt.Errorf("could not decode saml response: %w", err) + } + + return parseResponseXml(raw) +} + +func parseResponseXml(xml []byte) (*etree.Document, *etree.Element, error) { + var doc *etree.Document + var rawXML []byte + + err := maybeDeflate(xml, defaultMaxDecompressedResponseSize, func(xml []byte) error { + doc = etree.NewDocument() + rawXML = xml + return doc.ReadFromBytes(xml) + }) + if err != nil { + return nil, nil, err + } + + el := doc.Root() + if el == nil { + return nil, nil, fmt.Errorf("unable to parse response") + } + + // Examine the response for attempts to exploit weaknesses in Go's encoding/xml + err = rtvalidator.Validate(bytes.NewReader(rawXML)) + if err != nil { + return nil, nil, err + } + + return doc, el, nil +} diff --git a/backend/flow_api/flow/flows.go b/backend/flow_api/flow/flows.go index c4d1c3201..2d09f767b 100644 --- a/backend/flow_api/flow/flows.go +++ b/backend/flow_api/flow/flows.go @@ -215,3 +215,22 @@ func NewProfileFlow(debug bool) flowpilot.Flow { Debug(debug). MustBuild() } + +func NewTokenExchangeFlow(debug bool) flowpilot.Flow { + return flowpilot.NewFlow("token_exchange"). + State(shared.StateThirdParty, + shared.ExchangeToken{}). + State(shared.StateSuccess). + BeforeState(shared.StateSuccess, + shared.IssueSession{}, + shared.GetUserData{}). + SubFlows( + CredentialUsageSubFlow, + UserDetailsSubFlow). + AfterState(shared.StatePasscodeConfirmation, + shared.EmailPersistVerifiedStatus{}). + InitialState(shared.StateThirdParty). + ErrorState(shared.StateError). + Debug(debug). + MustBuild() +} diff --git a/backend/flow_api/flow/shared/action_exchange_token.go b/backend/flow_api/flow/shared/action_exchange_token.go index 1031d6556..8fd99b618 100644 --- a/backend/flow_api/flow/shared/action_exchange_token.go +++ b/backend/flow_api/flow/shared/action_exchange_token.go @@ -81,16 +81,30 @@ func (a ExchangeToken) Execute(c flowpilot.ExecutionContext) error { return fmt.Errorf("failed to delete token from db: %w", err) } - onboardingStates, err := a.determineOnboardingStates(c, identity, tokenModel.UserCreated) + isSaml := identity.SamlIdentity != nil + + var onboardingStates []flowpilot.StateName + if isSaml { + samlProvider, err := deps.SamlService.GetProviderByIssuer(identity.ProviderID) + if err != nil { + return fmt.Errorf("could not fetch saml provider for identity: %w", err) + } + mustDoEmailVerification := !samlProvider.GetConfig().SkipEmailVerification && identity.Email != nil && !identity.Email.Verified + onboardingStates, err = a.determineOnboardingStates(c, identity, tokenModel.UserCreated, mustDoEmailVerification) + } else { + mustDoEmailVerification := deps.Cfg.Email.RequireVerification && identity.Email != nil && !identity.Email.Verified + onboardingStates, err = a.determineOnboardingStates(c, identity, tokenModel.UserCreated, mustDoEmailVerification) + } + if err != nil { - return fmt.Errorf("failed to determine onboarding stattes: %w", err) + return fmt.Errorf("failed to determine onboarding states: %w", err) } - if err := c.Stash().Set(StashPathLoginMethod, "third_party"); err != nil { + if err = c.Stash().Set(StashPathLoginMethod, "third_party"); err != nil { return fmt.Errorf("failed to set login_method to the stash: %w", err) } - if err := c.Stash().Set(StashPathThirdPartyProvider, identity.ProviderID); err != nil { + if err = c.Stash().Set(StashPathThirdPartyProvider, identity.ProviderID); err != nil { return fmt.Errorf("failed to set third_party_provider to the stash: %w", err) } @@ -99,11 +113,11 @@ func (a ExchangeToken) Execute(c flowpilot.ExecutionContext) error { return c.Continue(onboardingStates...) } -func (a ExchangeToken) determineOnboardingStates(c flowpilot.ExecutionContext, identity *models.Identity, userCreated bool) ([]flowpilot.StateName, error) { +func (a ExchangeToken) determineOnboardingStates(c flowpilot.ExecutionContext, identity *models.Identity, userCreated bool, mustDoEmailVerification bool) ([]flowpilot.StateName, error) { deps := a.GetDeps(c) result := make([]flowpilot.StateName, 0) - if deps.Cfg.Email.RequireVerification && identity.Email != nil && !identity.Email.Verified { + if mustDoEmailVerification { if err := c.Stash().Set(StashPathEmail, identity.Email.Address); err != nil { return nil, fmt.Errorf("failed to stash email: %w", err) } diff --git a/backend/flow_api/handler.go b/backend/flow_api/handler.go index 2c8df28d4..dccfef29d 100644 --- a/backend/flow_api/handler.go +++ b/backend/flow_api/handler.go @@ -62,6 +62,11 @@ func (h *FlowPilotHandler) ProfileFlowHandler(c echo.Context) error { return h.executeFlow(c, profileFlow) } +func (h *FlowPilotHandler) TokenExchangeFlowHandler(c echo.Context) error { + samlIdPInitiatedLoginFlow := flow.NewTokenExchangeFlow(h.Cfg.Debug) + return h.executeFlow(c, samlIdPInitiatedLoginFlow) +} + func (h *FlowPilotHandler) validateSession(c echo.Context) error { lookup := fmt.Sprintf("header:Authorization:Bearer,cookie:%s", h.Cfg.Session.Cookie.GetName()) extractors, err := echojwt.CreateExtractors(lookup) diff --git a/backend/go.mod b/backend/go.mod index 76de34d4a..e955b749e 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -3,6 +3,7 @@ module github.com/teamhanko/hanko/backend go 1.20 require ( + github.com/beevik/etree v1.1.0 github.com/brianvoe/gofakeit/v6 v6.28.0 github.com/coreos/go-oidc/v3 v3.9.0 github.com/fatih/structs v1.1.0 @@ -28,6 +29,7 @@ require ( github.com/labstack/gommon v0.4.2 github.com/lestrrat-go/jwx/v2 v2.1.0 github.com/lib/pq v1.10.9 + github.com/mattermost/xml-roundtrip-validator v0.1.0 github.com/mileusna/useragent v1.3.5 github.com/mitchellh/mapstructure v1.5.0 github.com/nicksnyder/go-i18n/v2 v2.4.0 @@ -63,7 +65,6 @@ require ( github.com/andybalholm/brotli v1.0.5 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect - github.com/beevik/etree v1.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/buger/jsonparser v1.1.1 // indirect @@ -123,7 +124,6 @@ require ( github.com/lestrrat-go/option v1.0.1 // indirect github.com/luna-duclos/instrumentedsql v1.1.3 // indirect github.com/mailru/easyjson v0.7.7 // indirect - github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect diff --git a/backend/handler/public_router.go b/backend/handler/public_router.go index cb2120ba5..fb4e9310b 100644 --- a/backend/handler/public_router.go +++ b/backend/handler/public_router.go @@ -86,6 +86,10 @@ func NewPublicRouter(cfg *config.Config, persister persistence.Persister, promet e.POST("/login", flowAPIHandler.LoginFlowHandler, webhookMiddleware) e.POST("/profile", flowAPIHandler.ProfileFlowHandler, webhookMiddleware) + if cfg.Saml.Enabled { + e.POST("/token_exchange", flowAPIHandler.TokenExchangeFlowHandler, webhookMiddleware) + } + e.HideBanner = true g := e.Group("") diff --git a/backend/persistence/identity_persister.go b/backend/persistence/identity_persister.go index 1a7e65c4b..0dba5dbf7 100644 --- a/backend/persistence/identity_persister.go +++ b/backend/persistence/identity_persister.go @@ -23,7 +23,7 @@ type identityPersister struct { func (p identityPersister) GetByID(identityID uuid.UUID) (*models.Identity, error) { identity := &models.Identity{} - if err := p.db.EagerPreload("Email", "Email.User", "Email.User.Username").Find(identity, identityID); err != nil { + if err := p.db.EagerPreload("Email", "Email.User", "Email.User.Username", "SamlIdentity").Find(identity, identityID); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } diff --git a/backend/persistence/migrations/20250210095906_create_saml_idp_initiated_requests.down.fizz b/backend/persistence/migrations/20250210095906_create_saml_idp_initiated_requests.down.fizz new file mode 100644 index 000000000..a3b211d62 --- /dev/null +++ b/backend/persistence/migrations/20250210095906_create_saml_idp_initiated_requests.down.fizz @@ -0,0 +1 @@ +drop_table("saml_idp_initiated_requests") diff --git a/backend/persistence/migrations/20250210095906_create_saml_idp_initiated_requests.up.fizz b/backend/persistence/migrations/20250210095906_create_saml_idp_initiated_requests.up.fizz new file mode 100644 index 000000000..4a9176727 --- /dev/null +++ b/backend/persistence/migrations/20250210095906_create_saml_idp_initiated_requests.up.fizz @@ -0,0 +1,9 @@ +create_table("saml_idp_initiated_requests") { + t.Column("id", "uuid", {primary: true}) + t.Column("response_id", "string", { "null": false }) + t.Column("issuer", "string", { "null": false }) + t.Column("expires_at", "timestamp", { "null": false }) + t.Column("created_at", "timestamp", { "null": false }) + t.DisableTimestamps() + t.Index(["response_id", "issuer"], {"unique": true}) +} diff --git a/backend/persistence/models/saml_idp_initiated_request.go b/backend/persistence/models/saml_idp_initiated_request.go new file mode 100644 index 000000000..98840550c --- /dev/null +++ b/backend/persistence/models/saml_idp_initiated_request.go @@ -0,0 +1,43 @@ +package models + +import ( + "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/validate/v3" + "github.com/gobuffalo/validate/v3/validators" + "github.com/gofrs/uuid" + "time" +) + +type SamlIDPInitiatedRequest struct { + ID uuid.UUID `db:"id"` + ResponseID string `db:"response_id"` + Issuer string `db:"issuer"` + ExpiresAt time.Time `db:"expires_at"` + CreatedAt time.Time `db:"created_at"` +} + +func NewSamlIDPInitiatedRequest(responseID, issuer string, expiresAt time.Time) (*SamlIDPInitiatedRequest, error) { + id, _ := uuid.NewV4() + + return &SamlIDPInitiatedRequest{ + ID: id, + ResponseID: responseID, + Issuer: issuer, + ExpiresAt: expiresAt, + CreatedAt: time.Now().UTC(), + }, nil +} + +func (samlIDPInitiatedRequest SamlIDPInitiatedRequest) TableName() string { + return "saml_idp_initiated_requests" +} + +func (r *SamlIDPInitiatedRequest) Validate(tx *pop.Connection) (*validate.Errors, error) { + return validate.Validate( + &validators.UUIDIsPresent{Name: "ID", Field: r.ID}, + &validators.StringIsPresent{Name: "ResponseID", Field: r.ResponseID}, + &validators.StringIsPresent{Name: "Issuer", Field: r.Issuer}, + &validators.TimeIsPresent{Name: "ExpiresAt", Field: r.ExpiresAt}, + &validators.TimeIsPresent{Name: "CreatedAt", Field: r.CreatedAt}, + ), nil +} diff --git a/backend/persistence/persister.go b/backend/persistence/persister.go index 7159a42c1..03b87b27e 100644 --- a/backend/persistence/persister.go +++ b/backend/persistence/persister.go @@ -39,6 +39,8 @@ type Persister interface { GetSamlStatePersisterWithConnection(tx *pop.Connection) SamlStatePersister GetSamlIdentityPersister() SamlIdentityPersister GetSamlIdentityPersisterWithConnection(tx *pop.Connection) SamlIdentityPersister + GetSamlIDPInitiatedRequestPersister() SamlIDPInitiatedRequestPersister + GetSamlIDPInitiatedRequestPersisterWithConnection(tx *pop.Connection) SamlIDPInitiatedRequestPersister GetTokenPersister() TokenPersister GetTokenPersisterWithConnection(tx *pop.Connection) TokenPersister GetUserPersister() UserPersister @@ -288,6 +290,14 @@ func (p *persister) GetSamlIdentityPersisterWithConnection(tx *pop.Connection) S return NewSamlIdentityPersister(tx) } +func (p *persister) GetSamlIDPInitiatedRequestPersister() SamlIDPInitiatedRequestPersister { + return NewSamlIDPInitiatedRequestPersister(p.DB) +} + +func (p *persister) GetSamlIDPInitiatedRequestPersisterWithConnection(tx *pop.Connection) SamlIDPInitiatedRequestPersister { + return NewSamlIDPInitiatedRequestPersister(tx) +} + func (p *persister) GetWebhookPersister(tx *pop.Connection) WebhookPersister { if tx != nil { return NewWebhookPersister(tx) diff --git a/backend/persistence/saml_idp_inititated_request_persister.go b/backend/persistence/saml_idp_inititated_request_persister.go new file mode 100644 index 000000000..07aaa2d61 --- /dev/null +++ b/backend/persistence/saml_idp_inititated_request_persister.go @@ -0,0 +1,48 @@ +package persistence + +import ( + "database/sql" + "errors" + "fmt" + "github.com/gobuffalo/pop/v6" + "github.com/teamhanko/hanko/backend/persistence/models" +) + +type SamlIDPInitiatedRequestPersister interface { + Create(samlIDPInitiatedRequest models.SamlIDPInitiatedRequest) error + GetByResponseIDAndIssuer(responseID, entityID string) (*models.SamlIDPInitiatedRequest, error) +} + +type samlIDPInitiatedRequestPersister struct { + db *pop.Connection +} + +func (p samlIDPInitiatedRequestPersister) GetByResponseIDAndIssuer(responseID, entityID string) (*models.SamlIDPInitiatedRequest, error) { + samlIDPInitiatedRequest := models.SamlIDPInitiatedRequest{} + query := p.db.Where("response_id = ? AND idp_entity_id = ?", responseID, entityID) + err := query.First(&samlIDPInitiatedRequest) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get credential: %w", err) + } + return &samlIDPInitiatedRequest, nil +} + +func NewSamlIDPInitiatedRequestPersister(db *pop.Connection) SamlIDPInitiatedRequestPersister { + return &samlIDPInitiatedRequestPersister{db: db} +} + +func (p samlIDPInitiatedRequestPersister) Create(samlIDPInitiatedRequest models.SamlIDPInitiatedRequest) error { + vErr, err := p.db.ValidateAndCreate(&samlIDPInitiatedRequest) + if err != nil { + return fmt.Errorf("failed to store saml idp initiated request: %w", err) + } + + if vErr != nil && vErr.HasAny() { + return fmt.Errorf("saml idp initated request object validation failed: %w", vErr) + } + + return nil +} diff --git a/frontend/elements/src/contexts/AppProvider.tsx b/frontend/elements/src/contexts/AppProvider.tsx index f05d6240a..7897b39f3 100644 --- a/frontend/elements/src/contexts/AppProvider.tsx +++ b/frontend/elements/src/contexts/AppProvider.tsx @@ -464,6 +464,7 @@ const AppProvider = ({ .run(); searchParams.delete("hanko_token"); + searchParams.delete("saml_hint"); history.replaceState( null, @@ -524,7 +525,17 @@ const AppProvider = ({ "hanko_token", ); const cachedState = localStorage.getItem(localStorageCacheStateKey); - if (cachedState && cachedState.length > 0 && token && token.length > 0) { + const samlHint = new URLSearchParams(window.location.search).get( + "saml_hint", + ); + if (samlHint === "idp_initiated") { + await hanko.flow.init("/token_exchange", { ...stateHandler }); + } else if ( + cachedState && + cachedState.length > 0 && + token && + token.length > 0 + ) { await hanko.flow.fromString( localStorage.getItem(localStorageCacheStateKey), { ...stateHandler }, diff --git a/frontend/frontend-sdk/src/lib/flow-api/types/state-handling.ts b/frontend/frontend-sdk/src/lib/flow-api/types/state-handling.ts index c7ade7ae0..ef4e6d399 100644 --- a/frontend/frontend-sdk/src/lib/flow-api/types/state-handling.ts +++ b/frontend/frontend-sdk/src/lib/flow-api/types/state-handling.ts @@ -121,7 +121,11 @@ export interface Payloads { readonly webauthn_credential_verification: OnboardingVerifyPasskeyAttestationPayload; } -export type FlowPath = "/login" | "/registration" | "/profile"; +export type FlowPath = + | "/login" + | "/registration" + | "/profile" + | "/token_exchange"; export type FetchNextState = ( // eslint-disable-next-line no-unused-vars