mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-03-22 22:47:52 +00:00
fix: prevent ddos attacks in oauth rate limit
This commit is contained in:
@@ -129,6 +129,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
|
||||
defer controller.auth.EndOAuthSession(sessionIdCookie)
|
||||
|
||||
state := c.Query("state")
|
||||
csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)
|
||||
@@ -227,9 +228,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
|
||||
|
||||
// Clear OAuth session
|
||||
controller.auth.EndOAuthSession(sessionIdCookie)
|
||||
|
||||
redirectURI, err := c.Cookie(controller.config.RedirectCookieName)
|
||||
|
||||
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) {
|
||||
|
||||
@@ -17,9 +17,13 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const MaxOAuthPendingSessions = 256
|
||||
const OAuthCleanupCount = 16
|
||||
|
||||
type OAuthPendingSession struct {
|
||||
State string
|
||||
Verifier string
|
||||
@@ -570,6 +574,8 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
|
||||
}
|
||||
|
||||
func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) {
|
||||
auth.ensureOAuthSessionLimit()
|
||||
|
||||
service, ok := auth.oauthBroker.GetService(serviceName)
|
||||
|
||||
if !ok {
|
||||
@@ -685,6 +691,8 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
||||
}
|
||||
|
||||
func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
|
||||
auth.ensureOAuthSessionLimit()
|
||||
|
||||
auth.oauthMutex.RLock()
|
||||
session, exists := auth.oauthPendingSessions[sessionId]
|
||||
auth.oauthMutex.RUnlock()
|
||||
@@ -702,3 +710,39 @@ func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPending
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) ensureOAuthSessionLimit() {
|
||||
auth.oauthMutex.Lock()
|
||||
defer auth.oauthMutex.Unlock()
|
||||
|
||||
if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions {
|
||||
|
||||
cleanupIds := make([]string, 0, OAuthCleanupCount)
|
||||
|
||||
for range OAuthCleanupCount {
|
||||
oldestId := ""
|
||||
oldestTime := int64(0)
|
||||
|
||||
for id, session := range auth.oauthPendingSessions {
|
||||
if oldestTime == 0 {
|
||||
oldestId = id
|
||||
oldestTime = session.ExpiresAt.Unix()
|
||||
continue
|
||||
}
|
||||
if slices.Contains(cleanupIds, id) {
|
||||
continue
|
||||
}
|
||||
if session.ExpiresAt.Unix() < oldestTime {
|
||||
oldestId = id
|
||||
oldestTime = session.ExpiresAt.Unix()
|
||||
}
|
||||
}
|
||||
|
||||
cleanupIds = append(cleanupIds, oldestId)
|
||||
}
|
||||
|
||||
for _, id := range cleanupIds {
|
||||
delete(auth.oauthPendingSessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user