diff --git a/frontend/src/pages/login-page.tsx b/frontend/src/pages/login-page.tsx index fe033b5..facb9da 100644 --- a/frontend/src/pages/login-page.tsx +++ b/frontend/src/pages/login-page.tsx @@ -76,10 +76,14 @@ export const LoginPage = () => { isPending: oauthIsPending, variables: oauthVariables, } = useMutation({ - mutationFn: (provider: string) => - axios.get( - `/api/oauth/url/${provider}${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`, - ), + mutationFn: (provider: string) => { + const params = isOidc + ? `?${compiledOIDCParams}` + : props.redirect_uri + ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` + : ""; + return axios.get(`/api/oauth/url/${provider}${params}`); + }, mutationKey: ["oauth"], onSuccess: (data) => { toast.info(t("loginOauthSuccessTitle"), { diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 6be0bc5..aa11613 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -62,7 +62,29 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - sessionId, session, err := controller.auth.NewOAuthSession(req.Provider) + var reqParams service.OAuthURLParams + + err = c.BindQuery(&reqParams) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to bind query parameters") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + if !controller.isOidcRequest(reqParams) { + isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain) + + if !isRedirectSafe { + tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring") + reqParams.RedirectURI = "" + } + } + + sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams) if err != nil { tlog.App.Error().Err(err).Msg("Failed to create OAuth session") @@ -85,20 +107,6 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { } c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) - c.SetCookie(controller.config.CSRFCookieName, session.State, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) - - redirectURI := c.Query("redirect_uri") - isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) - - if !isRedirectSafe { - tlog.App.Warn().Str("redirect_uri", redirectURI).Msg("Unsafe redirect URI detected, ignoring") - redirectURI = "" - } - - if redirectURI != "" && isRedirectSafe { - tlog.App.Debug().Msg("Setting redirect URI cookie") - c.SetCookie(controller.config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) - } c.JSON(200, gin.H{ "status": 200, @@ -129,19 +137,23 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { } c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) - defer controller.auth.EndOAuthSession(sessionIdCookie) - state := c.Query("state") - csrfCookie, err := c.Cookie(controller.config.CSRFCookieName) + oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) - if err != nil || state != csrfCookie { - tlog.App.Warn().Err(err).Msg("CSRF token mismatch or cookie missing") - c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } - c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + defer controller.auth.EndOAuthSession(sessionIdCookie) + + state := c.Query("state") + if state != oauthPendingSession.State { + tlog.App.Warn().Err(err).Msg("CSRF token mismatch") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } code := c.Query("code") _, err = controller.auth.GetOAuthToken(sessionIdCookie, code) @@ -198,7 +210,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { username = strings.Replace(user.Email, "@", "_", 1) } - service, err := controller.auth.GetOAuthService(sessionIdCookie) + svc, err := controller.auth.GetOAuthService(sessionIdCookie) if err != nil { tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session") @@ -206,8 +218,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - if service.ID() != req.Provider { - tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", service.ID(), req.Provider) + if svc.ID() != req.Provider { + tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } @@ -216,9 +228,9 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { Username: username, Name: name, Email: user.Email, - Provider: service.ID(), + Provider: svc.ID(), OAuthGroups: utils.CoalesceToString(user.Groups), - OAuthName: service.Name(), + OAuthName: svc.Name(), OAuthSub: user.Sub, } @@ -234,24 +246,39 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) - redirectURI, err := c.Cookie(controller.config.RedirectCookieName) - - if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) { - tlog.App.Debug().Msg("No redirect URI cookie found, redirecting to app root") - c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL) + if controller.isOidcRequest(oauthPendingSession.CallbackParams) { + tlog.App.Debug().Msg("OIDC request, redirecting to authorize page") + queries, err := query.Values(oauthPendingSession.CallbackParams) + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode())) return } - queries, err := query.Values(config.RedirectQuery{ - RedirectURI: redirectURI, - }) + if oauthPendingSession.CallbackParams.RedirectURI != "" { + queries, err := query.Values(config.RedirectQuery{ + RedirectURI: oauthPendingSession.CallbackParams.RedirectURI, + }) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode())) return } - c.SetCookie(controller.config.RedirectCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL) +} + +func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool { + return params.Scope != "" && + params.ResponseType != "" && + params.ClientID != "" && + params.RedirectURI != "" } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 6540fe8..807d39c 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -28,12 +28,26 @@ const MaxOAuthPendingSessions = 256 const OAuthCleanupCount = 16 const MaxLoginAttemptRecords = 256 +// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all +// parameters and pass them to the authorize page if needed +type OAuthURLParams struct { + Scope string `form:"scope" url:"scope"` + ResponseType string `form:"response_type" url:"response_type"` + ClientID string `form:"client_id" url:"client_id"` + RedirectURI string `form:"redirect_uri" url:"redirect_uri"` + State string `form:"state" url:"state"` + Nonce string `form:"nonce" url:"nonce"` + CodeChallenge string `form:"code_challenge" url:"code_challenge"` + CodeChallengeMethod string `form:"code_challenge_method" url:"code_challenge_method"` +} + type OAuthPendingSession struct { - State string - Verifier string - Token *oauth2.Token - Service *OAuthServiceImpl - ExpiresAt time.Time + State string + Verifier string + Token *oauth2.Token + Service *OAuthServiceImpl + ExpiresAt time.Time + CallbackParams OAuthURLParams } type LdapGroupsCache struct { @@ -598,7 +612,7 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool { return false } -func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) { +func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) { auth.ensureOAuthSessionLimit() service, ok := auth.oauthBroker.GetService(serviceName) @@ -617,10 +631,11 @@ func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendi verifier := service.NewRandom() session := OAuthPendingSession{ - State: state, - Verifier: verifier, - Service: &service, - ExpiresAt: time.Now().Add(1 * time.Hour), + State: state, + Verifier: verifier, + Service: &service, + ExpiresAt: time.Now().Add(1 * time.Hour), + CallbackParams: params, } auth.oauthMutex.Lock() @@ -631,7 +646,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendi } func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { - session, err := auth.getOAuthPendingSession(sessionId) + session, err := auth.GetOAuthPendingSession(sessionId) if err != nil { return "", err @@ -641,7 +656,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { } func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { - session, err := auth.getOAuthPendingSession(sessionId) + session, err := auth.GetOAuthPendingSession(sessionId) if err != nil { return nil, err @@ -661,7 +676,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T } func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) { - session, err := auth.getOAuthPendingSession(sessionId) + session, err := auth.GetOAuthPendingSession(sessionId) if err != nil { return config.Claims{}, err @@ -681,7 +696,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, erro } func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) { - session, err := auth.getOAuthPendingSession(sessionId) + session, err := auth.GetOAuthPendingSession(sessionId) if err != nil { return nil, err @@ -715,7 +730,7 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() { } } -func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) { +func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) { auth.ensureOAuthSessionLimit() auth.oauthMutex.RLock()