diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index be116c5..d5dfc39 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -129,6 +129,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { } c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + defer controller.auth.EndOAuthSession(sessionIdCookie) state := c.Query("state") csrfCookie, err := c.Cookie(controller.config.CSRFCookieName) @@ -227,9 +228,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) - // Clear OAuth session - controller.auth.EndOAuthSession(sessionIdCookie) - redirectURI, err := c.Cookie(controller.config.RedirectCookieName) if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) { diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 05a5548..53c879d 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -17,9 +17,13 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" "golang.org/x/crypto/bcrypt" + "golang.org/x/exp/slices" "golang.org/x/oauth2" ) +const MaxOAuthPendingSessions = 256 +const OAuthCleanupCount = 16 + type OAuthPendingSession struct { State string Verifier string @@ -570,6 +574,8 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool { } func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) { + auth.ensureOAuthSessionLimit() + service, ok := auth.oauthBroker.GetService(serviceName) if !ok { @@ -685,6 +691,8 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() { } func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) { + auth.ensureOAuthSessionLimit() + auth.oauthMutex.RLock() session, exists := auth.oauthPendingSessions[sessionId] auth.oauthMutex.RUnlock() @@ -702,3 +710,39 @@ func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPending return session, nil } + +func (auth *AuthService) ensureOAuthSessionLimit() { + auth.oauthMutex.Lock() + defer auth.oauthMutex.Unlock() + + if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions { + + cleanupIds := make([]string, 0, OAuthCleanupCount) + + for range OAuthCleanupCount { + oldestId := "" + oldestTime := int64(0) + + for id, session := range auth.oauthPendingSessions { + if oldestTime == 0 { + oldestId = id + oldestTime = session.ExpiresAt.Unix() + continue + } + if slices.Contains(cleanupIds, id) { + continue + } + if session.ExpiresAt.Unix() < oldestTime { + oldestId = id + oldestTime = session.ExpiresAt.Unix() + } + } + + cleanupIds = append(cleanupIds, oldestId) + } + + for _, id := range cleanupIds { + delete(auth.oauthPendingSessions, id) + } + } +}