mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-03-27 00:47:55 +00:00
fix: prevent ddos attacks in oauth rate limit
This commit is contained in:
@@ -17,9 +17,13 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const MaxOAuthPendingSessions = 256
|
||||
const OAuthCleanupCount = 16
|
||||
|
||||
type OAuthPendingSession struct {
|
||||
State string
|
||||
Verifier string
|
||||
@@ -570,6 +574,8 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
|
||||
}
|
||||
|
||||
func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) {
|
||||
auth.ensureOAuthSessionLimit()
|
||||
|
||||
service, ok := auth.oauthBroker.GetService(serviceName)
|
||||
|
||||
if !ok {
|
||||
@@ -685,6 +691,8 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
||||
}
|
||||
|
||||
func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
|
||||
auth.ensureOAuthSessionLimit()
|
||||
|
||||
auth.oauthMutex.RLock()
|
||||
session, exists := auth.oauthPendingSessions[sessionId]
|
||||
auth.oauthMutex.RUnlock()
|
||||
@@ -702,3 +710,39 @@ func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPending
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) ensureOAuthSessionLimit() {
|
||||
auth.oauthMutex.Lock()
|
||||
defer auth.oauthMutex.Unlock()
|
||||
|
||||
if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions {
|
||||
|
||||
cleanupIds := make([]string, 0, OAuthCleanupCount)
|
||||
|
||||
for range OAuthCleanupCount {
|
||||
oldestId := ""
|
||||
oldestTime := int64(0)
|
||||
|
||||
for id, session := range auth.oauthPendingSessions {
|
||||
if oldestTime == 0 {
|
||||
oldestId = id
|
||||
oldestTime = session.ExpiresAt.Unix()
|
||||
continue
|
||||
}
|
||||
if slices.Contains(cleanupIds, id) {
|
||||
continue
|
||||
}
|
||||
if session.ExpiresAt.Unix() < oldestTime {
|
||||
oldestId = id
|
||||
oldestTime = session.ExpiresAt.Unix()
|
||||
}
|
||||
}
|
||||
|
||||
cleanupIds = append(cleanupIds, oldestId)
|
||||
}
|
||||
|
||||
for _, id := range cleanupIds {
|
||||
delete(auth.oauthPendingSessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user