fix: prevent ddos attacks in oauth rate limit

This commit is contained in:
Stavros
2026-03-22 20:43:33 +02:00
parent db73c56dfe
commit 7ae16d6bdc
2 changed files with 45 additions and 3 deletions

View File

@@ -129,6 +129,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
} }
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
defer controller.auth.EndOAuthSession(sessionIdCookie)
state := c.Query("state") state := c.Query("state")
csrfCookie, err := c.Cookie(controller.config.CSRFCookieName) csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)
@@ -227,9 +228,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
// Clear OAuth session
controller.auth.EndOAuthSession(sessionIdCookie)
redirectURI, err := c.Cookie(controller.config.RedirectCookieName) redirectURI, err := c.Cookie(controller.config.RedirectCookieName)
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) { if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) {

View File

@@ -17,9 +17,13 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/exp/slices"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16
type OAuthPendingSession struct { type OAuthPendingSession struct {
State string State string
Verifier string Verifier string
@@ -570,6 +574,8 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
} }
func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) { func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit()
service, ok := auth.oauthBroker.GetService(serviceName) service, ok := auth.oauthBroker.GetService(serviceName)
if !ok { if !ok {
@@ -685,6 +691,8 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
} }
func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) { func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit()
auth.oauthMutex.RLock() auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId] session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock() auth.oauthMutex.RUnlock()
@@ -702,3 +710,39 @@ func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPending
return session, nil return session, nil
} }
func (auth *AuthService) ensureOAuthSessionLimit() {
auth.oauthMutex.Lock()
defer auth.oauthMutex.Unlock()
if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions {
cleanupIds := make([]string, 0, OAuthCleanupCount)
for range OAuthCleanupCount {
oldestId := ""
oldestTime := int64(0)
for id, session := range auth.oauthPendingSessions {
if oldestTime == 0 {
oldestId = id
oldestTime = session.ExpiresAt.Unix()
continue
}
if slices.Contains(cleanupIds, id) {
continue
}
if session.ExpiresAt.Unix() < oldestTime {
oldestId = id
oldestTime = session.ExpiresAt.Unix()
}
}
cleanupIds = append(cleanupIds, oldestId)
}
for _, id := range cleanupIds {
delete(auth.oauthPendingSessions, id)
}
}
}