From 3e5757cfc90f70c0e00f422ed254e919fde946e8 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 30 May 2026 15:04:53 +0300 Subject: [PATCH] fix: fix race conditions --- internal/service/auth_service.go | 86 ++++++++++------- internal/service/cache_store.go | 154 +++++++++++++++++-------------- 2 files changed, 139 insertions(+), 101 deletions(-) diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index fefeb00c..1034ed1e 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -128,7 +128,7 @@ func NewAuthService( service.caches.oauth.Sweep() service.caches.login.Sweep() service.caches.ldap.Sweep() - case <-service.ctx.Done(): + case <-ctx.Done(): return } } @@ -231,8 +231,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { } func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { - if auth.lockdown.active { - remaining := int(time.Until(auth.lockdown.until).Seconds()) + if locked, remaining := auth.IsInLockdown(); locked { return true, remaining } @@ -259,42 +258,48 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { } if auth.caches.login.Size() >= MaxLoginAttemptRecords { - if auth.lockdown.active { + if locked, _ := auth.IsInLockdown(); locked { return } go auth.lockdownMode() return } - ok := auth.caches.login.Mutate(identifier, func(la LoginAttempt) (LoginAttempt, bool) { - la.LastAttempt = time.Now() - if success { - la.FailedAttempts = 0 - la.LockedUntil = time.Time{} - return la, false - } - la.FailedAttempts++ - if la.FailedAttempts >= auth.config.Auth.LoginMaxRetries { - la.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) - auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", la.FailedAttempts).Msg("Account locked due to too many failed login attempts") - } - return la, true - }) + auth.caches.login.WithLock(func(actions CacheStoreActions[LoginAttempt]) { + entry, ok := actions.Get(identifier) - if !ok { - // No existing record, create a new one - attempt := LoginAttempt{ - LastAttempt: time.Now(), + if !ok { + attempt := LoginAttempt{ + LastAttempt: time.Now(), + } + if !success { + attempt.FailedAttempts = 1 + if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { + attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) + auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") + } + } + // match current tinyauth behavior which doesn't expire rate limits + actions.Set(identifier, attempt, 0) + return } - if !success { - attempt.FailedAttempts = 1 - if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { - attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) - auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") + + entry.LastAttempt = time.Now() + + if success { + entry.FailedAttempts = 0 + entry.LockedUntil = time.Time{} + } else { + entry.FailedAttempts++ + + if entry.FailedAttempts >= auth.config.Auth.LoginMaxRetries { + entry.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) + auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", entry.FailedAttempts).Msg("Account locked due to too many failed login attempts") } } - auth.caches.login.Set(identifier, attempt, 0) // match current tinyauth behavior which doesn't expire rate limits - } + + actions.Set(identifier, entry, 0) + }) } // We could also directly access the policyEngine.effectToAccess but @@ -551,10 +556,10 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { } func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { - session, err := auth.GetOAuthPendingSession(sessionId) + session, ok := auth.caches.oauth.Get(sessionId) - if err != nil { - return nil, err + if !ok { + return nil, fmt.Errorf("oauth session not found: %s", sessionId) } token, err := (*session.Service).GetToken(code, session.Verifier) @@ -565,7 +570,12 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T session.Token = token - auth.caches.oauth.Set(sessionId, *session, time.Minute*10) + // ttl 0 means keep current expiration + ok = auth.caches.oauth.Update(sessionId, session, 0) + + if !ok { + return nil, fmt.Errorf("failed to update oauth session with token: %s", sessionId) + } return token, nil } @@ -659,6 +669,16 @@ func (auth *AuthService) lockdownMode() { auth.lockdown.mu.Unlock() } +func (auth *AuthService) IsInLockdown() (bool, int) { + auth.lockdown.mu.RLock() + defer auth.lockdown.mu.RUnlock() + if auth.lockdown.active { + remaining := int(time.Until(auth.lockdown.until).Seconds()) + return true, remaining + } + return false, 0 +} + // mostly a testing function, not useful for anything else func (auth *AuthService) ClearLoginAttempts() { auth.caches.login.Clear() diff --git a/internal/service/cache_store.go b/internal/service/cache_store.go index 939ebcef..77ac1630 100644 --- a/internal/service/cache_store.go +++ b/internal/service/cache_store.go @@ -1,10 +1,18 @@ package service import ( + "slices" "sync" "time" ) +type CacheStoreActions[T any] struct { + Set func(key string, value T, ttl time.Duration) + Get func(key string) (T, bool) + Delete func(key string) + Update func(key string, value T, ttl time.Duration) bool +} + type cacheEntry[T any] struct { value T expiresAt *time.Time @@ -12,6 +20,7 @@ type cacheEntry[T any] struct { type CacheStore[T any] struct { cache map[string]cacheEntry[T] + order []string mu sync.RWMutex maxSize int } @@ -19,14 +28,57 @@ type CacheStore[T any] struct { func NewCacheStore[T any](maxSize int) *CacheStore[T] { return &CacheStore[T]{ cache: make(map[string]cacheEntry[T]), + order: make([]string, 0), maxSize: maxSize, } } -func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) { +// With lock allows performing multiple operations on the cache store atomically. +// The provided mutate function receives a set of actions (Set, Get, Delete) that +// can be used to manipulate the cache store within the locked context. +func (cs *CacheStore[T]) WithLock(mutate func(actions CacheStoreActions[T])) { cs.mu.Lock() defer cs.mu.Unlock() + actions := CacheStoreActions[T]{ + Set: cs.setCallback, + Get: cs.getCallback, + Delete: cs.deleteCallback, + Update: cs.updateCallback, + } + mutate(actions) +} +func (cs *CacheStore[T]) updateCallback(key string, value T, ttl time.Duration) bool { + if currentEntry, exists := cs.cache[key]; exists { + if currentEntry.expiresAt != nil && time.Now().After(*currentEntry.expiresAt) { + return false + } + + entry := cacheEntry[T]{ + value: value, + expiresAt: currentEntry.expiresAt, + } + + if ttl > 0 { + expiration := time.Now().Add(ttl) + entry.expiresAt = &expiration + } + + cs.cache[key] = entry + + return true + } + + return false +} + +func (cs *CacheStore[T]) Update(key string, value T, ttl time.Duration) bool { + cs.mu.Lock() + defer cs.mu.Unlock() + return cs.updateCallback(key, value, ttl) +} + +func (cs *CacheStore[T]) setCallback(key string, value T, ttl time.Duration) { if cs.maxSize > 0 { if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize { cs.evictOne() @@ -44,12 +96,17 @@ func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) { value: value, expiresAt: expiresAt, } + + cs.order = append(cs.order, key) } -func (cs *CacheStore[T]) Get(key string) (T, bool) { - cs.mu.RLock() - defer cs.mu.RUnlock() +func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.setCallback(key, value, ttl) +} +func (cs *CacheStore[T]) getCallback(key string) (T, bool) { entry, exists := cs.cache[key] if !exists { @@ -65,79 +122,31 @@ func (cs *CacheStore[T]) Get(key string) (T, bool) { return entry.value, true } +func (cs *CacheStore[T]) Get(key string) (T, bool) { + cs.mu.RLock() + defer cs.mu.RUnlock() + return cs.getCallback(key) +} + +func (cs *CacheStore[T]) deleteCallback(key string) { + delete(cs.cache, key) + keyIdx := slices.Index(cs.order, key) + if keyIdx != -1 { + cs.order = append(cs.order[:keyIdx], cs.order[keyIdx+1:]...) + } +} + func (cs *CacheStore[T]) Delete(key string) { cs.mu.Lock() defer cs.mu.Unlock() - delete(cs.cache, key) -} - -func (cs *CacheStore[T]) Mutate(key string, mutator func(T) (T, bool)) bool { - cs.mu.Lock() - defer cs.mu.Unlock() - - entry, exists := cs.cache[key] - - if !exists { - return false - } - - if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { - delete(cs.cache, key) - return false - } - - newValue, shouldKeep := mutator(entry.value) - - if !shouldKeep { - delete(cs.cache, key) - return true - } - - cs.cache[key] = cacheEntry[T]{ - value: newValue, - expiresAt: entry.expiresAt, - } - - return true -} - -func (cs *CacheStore[T]) MutateWithTTL(key string, mutator func(T) (T, time.Duration, bool)) bool { - cs.mu.Lock() - defer cs.mu.Unlock() - - entry, exists := cs.cache[key] - - if !exists { - return false - } - - if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { - delete(cs.cache, key) - return false - } - - newValue, ttl, shouldKeep := mutator(entry.value) - - if !shouldKeep { - delete(cs.cache, key) - return true - } - - expiresAt := time.Now().Add(ttl) - - cs.cache[key] = cacheEntry[T]{ - value: newValue, - expiresAt: &expiresAt, - } - - return true + cs.deleteCallback(key) } func (cs *CacheStore[T]) Sweep() { cs.mu.Lock() for key, entry := range cs.cache { if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { - delete(cs.cache, key) + cs.deleteCallback(key) } } cs.mu.Unlock() @@ -158,9 +167,17 @@ func (cs *CacheStore[T]) evictOne() bool { } } + // If we found an oldest key, evict it else we delete the first key in the order list if oldestKey != "" { delete(cs.cache, oldestKey) return true + } else { + if len(cs.order) > 0 { + firstKey := cs.order[0] + cs.order = cs.order[1:] + delete(cs.cache, firstKey) + return true + } } return false @@ -176,4 +193,5 @@ func (cs *CacheStore[T]) Clear() { cs.mu.Lock() defer cs.mu.Unlock() cs.cache = make(map[string]cacheEntry[T]) + cs.order = make([]string, 0) }