From dac844595d463ca12a65ca8419feadf10edbe0e3 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 31 May 2026 18:55:06 +0300 Subject: [PATCH] refactor: use new cache store in services (#912) --- Makefile | 9 + internal/controller/user_controller_test.go | 2 +- .../middleware/context_middleware_test.go | 2 +- internal/service/auth_service.go | 321 ++++++--------- internal/service/cache_store.go | 197 +++++++++ internal/service/cache_store_test.go | 383 ++++++++++++++++++ 6 files changed, 721 insertions(+), 193 deletions(-) create mode 100644 internal/service/cache_store.go create mode 100644 internal/service/cache_store_test.go diff --git a/Makefile b/Makefile index 375b2def..b50b7bec 100644 --- a/Makefile +++ b/Makefile @@ -62,6 +62,15 @@ binary-linux-arm64: test: go test -v ./... +# Go vet +.PHONY: vet +vet: + go vet ./... + +# Go race +test-race: + go test -race ./... + # Development dev: docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index a23cd403..f3c0bed2 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -422,7 +422,7 @@ func TestUserController(t *testing.T) { beforeEach := func() { // Clear failed login attempts before each test - authService.ClearRateLimitsTestingOnly() + authService.ClearLoginAttempts() } for _, test := range tests { diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index 203e6858..50ededdb 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -263,7 +263,7 @@ func TestContextMiddleware(t *testing.T) { contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil) for _, test := range tests { - authService.ClearRateLimitsTestingOnly() + authService.ClearLoginAttempts() t.Run(test.description, func(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 62a3d4d8..1034ed1e 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -15,8 +15,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/logger" - "slices" - "github.com/google/uuid" "golang.org/x/crypto/bcrypt" "golang.org/x/oauth2" @@ -54,27 +52,17 @@ type OAuthPendingSession struct { CallbackParams OAuthURLParams } -type LdapGroupsCache struct { - Groups []string - Expires time.Time -} - type LoginAttempt struct { FailedAttempts int LastAttempt time.Time LockedUntil time.Time } -type Lockdown struct { - Active bool - ActiveUntil time.Time -} - type AuthService struct { log *logger.Logger config model.Config runtime model.RuntimeConfig - context context.Context + ctx context.Context ldap *LdapService queries repository.Store @@ -82,15 +70,19 @@ type AuthService struct { tailscale *TailscaleService policyEngine *PolicyEngine - loginAttempts map[string]*LoginAttempt - ldapGroupsCache map[string]*LdapGroupsCache - oauthPendingSessions map[string]*OAuthPendingSession - oauthMutex sync.RWMutex - loginMutex sync.RWMutex - ldapGroupsMutex sync.RWMutex - lockdown *Lockdown - lockdownCtx context.Context - lockdownCancelFunc context.CancelFunc + lockdown struct { + active bool + until time.Time + ctx context.Context + cancelFunc context.CancelFunc + mu sync.RWMutex + } + + caches struct { + login *CacheStore[LoginAttempt] + oauth *CacheStore[OAuthPendingSession] + ldap *CacheStore[[]string] + } } func NewAuthService( @@ -106,21 +98,41 @@ func NewAuthService( policy *PolicyEngine, ) *AuthService { service := &AuthService{ - log: log, - runtime: runtime, - context: ctx, - config: config, - loginAttempts: make(map[string]*LoginAttempt), - ldapGroupsCache: make(map[string]*LdapGroupsCache), - oauthPendingSessions: make(map[string]*OAuthPendingSession), - ldap: ldap, - queries: queries, - oauthBroker: oauthBroker, - tailscale: tailscale, - policyEngine: policy, + log: log, + runtime: runtime, + ctx: ctx, + config: config, + ldap: ldap, + queries: queries, + oauthBroker: oauthBroker, + tailscale: tailscale, + policyEngine: policy, } - dg.Go(service.cleanupOAuthSessions, ding.RingMinor) + // caches setup + oauthCache := NewCacheStore[OAuthPendingSession](256) + loginCache := NewCacheStore[LoginAttempt](1024) + ldapCache := NewCacheStore[[]string](1024) + + service.caches.oauth = oauthCache + service.caches.login = loginCache + service.caches.ldap = ldapCache + + dg.Go(func(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + service.caches.oauth.Sweep() + service.caches.login.Sweep() + service.caches.ldap.Sweep() + case <-ctx.Done(): + return + } + } + }, ding.RingMinor) return service } @@ -195,14 +207,12 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { return nil, errors.New("ldap service not configured") } - auth.ldapGroupsMutex.RLock() - entry, exists := auth.ldapGroupsCache[userDN] - auth.ldapGroupsMutex.RUnlock() + entry, exists := auth.caches.ldap.Get(userDN) - if exists && time.Now().Before(entry.Expires) { + if exists { return &model.LDAPUser{ DN: userDN, - Groups: entry.Groups, + Groups: entry, }, nil } @@ -212,12 +222,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { return nil, fmt.Errorf("failed to get ldap groups: %w", err) } - auth.ldapGroupsMutex.Lock() - auth.ldapGroupsCache[userDN] = &LdapGroupsCache{ - Groups: groups, - Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second), - } - auth.ldapGroupsMutex.Unlock() + auth.caches.ldap.Set(userDN, groups, time.Duration(auth.config.LDAP.GroupCacheTTL)*time.Second) return &model.LDAPUser{ DN: userDN, @@ -226,11 +231,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { } func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { - auth.loginMutex.RLock() - defer auth.loginMutex.RUnlock() - - if auth.lockdown != nil && auth.lockdown.Active { - remaining := int(time.Until(auth.lockdown.ActiveUntil).Seconds()) + if locked, remaining := auth.IsInLockdown(); locked { return true, remaining } @@ -238,7 +239,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { return false, 0 } - attempt, exists := auth.loginAttempts[identifier] + attempt, exists := auth.caches.login.Get(identifier) if !exists { return false, 0 } @@ -256,37 +257,49 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { return } - auth.loginMutex.Lock() - defer auth.loginMutex.Unlock() - - if len(auth.loginAttempts) >= MaxLoginAttemptRecords { - if auth.lockdown != nil && auth.lockdown.Active { + if auth.caches.login.Size() >= MaxLoginAttemptRecords { + if locked, _ := auth.IsInLockdown(); locked { return } go auth.lockdownMode() return } - attempt, exists := auth.loginAttempts[identifier] - if !exists { - attempt = &LoginAttempt{} - auth.loginAttempts[identifier] = attempt - } + auth.caches.login.WithLock(func(actions CacheStoreActions[LoginAttempt]) { + entry, ok := actions.Get(identifier) - attempt.LastAttempt = time.Now() + if !ok { + attempt := LoginAttempt{ + LastAttempt: time.Now(), + } + if !success { + attempt.FailedAttempts = 1 + if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { + attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) + auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") + } + } + // match current tinyauth behavior which doesn't expire rate limits + actions.Set(identifier, attempt, 0) + return + } - if success { - attempt.FailedAttempts = 0 - attempt.LockedUntil = time.Time{} // Reset lock time - return - } + entry.LastAttempt = time.Now() - attempt.FailedAttempts++ + if success { + entry.FailedAttempts = 0 + entry.LockedUntil = time.Time{} + } else { + entry.FailedAttempts++ - if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { - attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) - auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") - } + if entry.FailedAttempts >= auth.config.Auth.LoginMaxRetries { + entry.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) + auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", entry.FailedAttempts).Msg("Account locked due to too many failed login attempts") + } + } + + actions.Set(identifier, entry, 0) + }) } // We could also directly access the policyEngine.effectToAccess but @@ -504,8 +517,6 @@ func (auth *AuthService) LDAPAuthConfigured() bool { } func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) { - auth.ensureOAuthSessionLimit() - service, ok := auth.oauthBroker.GetService(serviceName) if !ok { @@ -529,9 +540,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara CallbackParams: params, } - auth.oauthMutex.Lock() - auth.oauthPendingSessions[sessionId.String()] = &session - auth.oauthMutex.Unlock() + auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10) return sessionId.String(), session, nil } @@ -547,10 +556,10 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { } func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { - session, err := auth.GetOAuthPendingSession(sessionId) + session, ok := auth.caches.oauth.Get(sessionId) - if err != nil { - return nil, err + if !ok { + return nil, fmt.Errorf("oauth session not found: %s", sessionId) } token, err := (*session.Service).GetToken(code, session.Verifier) @@ -559,9 +568,14 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T return nil, fmt.Errorf("failed to exchange code for token: %w", err) } - auth.oauthMutex.Lock() session.Token = token - auth.oauthMutex.Unlock() + + // ttl 0 means keep current expiration + ok = auth.caches.oauth.Update(sessionId, session, 0) + + if !ok { + return nil, fmt.Errorf("failed to update oauth session with token: %s", sessionId) + } return token, nil } @@ -597,123 +611,39 @@ func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, er } func (auth *AuthService) EndOAuthSession(sessionId string) { - auth.oauthMutex.Lock() - delete(auth.oauthPendingSessions, sessionId) - auth.oauthMutex.Unlock() -} - -func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) { - auth.log.App.Debug().Msg("Starting OAuth session cleanup routine") - - ticker := time.NewTicker(30 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - auth.log.App.Debug().Msg("Running OAuth session cleanup") - - auth.oauthMutex.Lock() - - now := time.Now() - - for sessionId, session := range auth.oauthPendingSessions { - if now.After(session.ExpiresAt) { - delete(auth.oauthPendingSessions, sessionId) - } - } - - auth.oauthMutex.Unlock() - auth.log.App.Debug().Msg("OAuth session cleanup completed") - case <-ctx.Done(): - auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine") - return - } - } + auth.caches.oauth.Delete(sessionId) } func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) { - auth.ensureOAuthSessionLimit() - - auth.oauthMutex.RLock() - session, exists := auth.oauthPendingSessions[sessionId] - auth.oauthMutex.RUnlock() + session, exists := auth.caches.oauth.Get(sessionId) if !exists { return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId) } - if time.Now().After(session.ExpiresAt) { - auth.oauthMutex.Lock() - delete(auth.oauthPendingSessions, sessionId) - auth.oauthMutex.Unlock() - return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId) - } - - return session, nil -} - -func (auth *AuthService) ensureOAuthSessionLimit() { - auth.oauthMutex.Lock() - defer auth.oauthMutex.Unlock() - - if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions { - return - } - - type entry struct { - id string - expiresAt int64 - } - - entries := make([]entry, 0, len(auth.oauthPendingSessions)) - for id, session := range auth.oauthPendingSessions { - entries = append(entries, entry{id, session.ExpiresAt.Unix()}) - } - - slices.SortFunc(entries, func(a, b entry) int { - if a.expiresAt < b.expiresAt { - return -1 - } - if a.expiresAt > b.expiresAt { - return 1 - } - return 0 - }) - - for _, e := range entries[:OAuthCleanupCount] { - delete(auth.oauthPendingSessions, e.id) - } + return &session, nil } func (auth *AuthService) lockdownMode() { - ctx, cancel := context.WithCancel(context.Background()) + auth.lockdown.mu.Lock() - auth.loginMutex.Lock() - - if auth.lockdown != nil && auth.lockdown.Active { - auth.loginMutex.Unlock() - cancel() + if auth.lockdown.active { + auth.lockdown.mu.Unlock() return } - auth.lockdownCtx = ctx - auth.lockdownCancelFunc = cancel + ctx, cancel := context.WithCancel(context.Background()) auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") - auth.lockdown = &Lockdown{ - Active: true, - ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second), - } + auth.lockdown.active = true + auth.lockdown.ctx = ctx + auth.lockdown.cancelFunc = cancel + auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) - // At this point all login attemps will also expire so, - // we might as well clear them to free up memory - auth.loginAttempts = make(map[string]*LoginAttempt) + timer := time.NewTimer(time.Until(auth.lockdown.until)) - timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil)) - - auth.loginMutex.Unlock() + auth.lockdown.mu.Unlock() defer cancel() defer timer.Stop() @@ -723,24 +653,33 @@ func (auth *AuthService) lockdownMode() { // Timer expired, end lockdown case <-ctx.Done(): // Context cancelled, end lockdown - case <-auth.context.Done(): + case <-auth.ctx.Done(): // Service is shutting down, end lockdown } - auth.loginMutex.Lock() + auth.lockdown.mu.Lock() auth.log.App.Info().Msg("Exiting lockdown mode") - auth.lockdown = nil - auth.loginMutex.Unlock() + auth.lockdown.active = false + auth.lockdown.until = time.Time{} + auth.lockdown.ctx = nil + auth.lockdown.cancelFunc = nil + + auth.lockdown.mu.Unlock() } -// Function only used for testing - do not use in prod! -func (auth *AuthService) ClearRateLimitsTestingOnly() { - auth.loginMutex.Lock() - auth.loginAttempts = make(map[string]*LoginAttempt) - if auth.lockdown != nil { - auth.lockdownCancelFunc() +func (auth *AuthService) IsInLockdown() (bool, int) { + auth.lockdown.mu.RLock() + defer auth.lockdown.mu.RUnlock() + if auth.lockdown.active { + remaining := int(time.Until(auth.lockdown.until).Seconds()) + return true, remaining } - auth.loginMutex.Unlock() + return false, 0 +} + +// mostly a testing function, not useful for anything else +func (auth *AuthService) ClearLoginAttempts() { + auth.caches.login.Clear() } diff --git a/internal/service/cache_store.go b/internal/service/cache_store.go new file mode 100644 index 00000000..9dbe057d --- /dev/null +++ b/internal/service/cache_store.go @@ -0,0 +1,197 @@ +package service + +import ( + "slices" + "sync" + "time" +) + +type CacheStoreActions[T any] struct { + Set func(key string, value T, ttl time.Duration) + Get func(key string) (T, bool) + Delete func(key string) + Update func(key string, value T, ttl time.Duration) bool +} + +type cacheEntry[T any] struct { + value T + expiresAt *time.Time +} + +type CacheStore[T any] struct { + cache map[string]cacheEntry[T] + order []string + mu sync.RWMutex + maxSize int +} + +func NewCacheStore[T any](maxSize int) *CacheStore[T] { + return &CacheStore[T]{ + cache: make(map[string]cacheEntry[T]), + order: make([]string, 0), + maxSize: maxSize, + } +} + +// With lock allows performing multiple operations on the cache store atomically. +// The provided mutate function receives a set of actions (Set, Get, Delete) that +// can be used to manipulate the cache store within the locked context. +func (cs *CacheStore[T]) WithLock(mutate func(actions CacheStoreActions[T])) { + cs.mu.Lock() + defer cs.mu.Unlock() + actions := CacheStoreActions[T]{ + Set: cs.setCallback, + Get: cs.getCallback, + Delete: cs.deleteCallback, + Update: cs.updateCallback, + } + mutate(actions) +} + +func (cs *CacheStore[T]) updateCallback(key string, value T, ttl time.Duration) bool { + if currentEntry, exists := cs.cache[key]; exists { + if currentEntry.expiresAt != nil && time.Now().After(*currentEntry.expiresAt) { + return false + } + + entry := cacheEntry[T]{ + value: value, + expiresAt: currentEntry.expiresAt, + } + + if ttl > 0 { + expiration := time.Now().Add(ttl) + entry.expiresAt = &expiration + } + + cs.cache[key] = entry + + return true + } + + return false +} + +func (cs *CacheStore[T]) Update(key string, value T, ttl time.Duration) bool { + cs.mu.Lock() + defer cs.mu.Unlock() + return cs.updateCallback(key, value, ttl) +} + +func (cs *CacheStore[T]) setCallback(key string, value T, ttl time.Duration) { + if cs.maxSize > 0 { + if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize { + cs.evictOne() + } + } + + var expiresAt *time.Time + + if ttl > 0 { + expiration := time.Now().Add(ttl) + expiresAt = &expiration + } + + cs.cache[key] = cacheEntry[T]{ + value: value, + expiresAt: expiresAt, + } + + if !slices.Contains(cs.order, key) { + cs.order = append(cs.order, key) + } +} + +func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.setCallback(key, value, ttl) +} + +func (cs *CacheStore[T]) getCallback(key string) (T, bool) { + entry, exists := cs.cache[key] + + if !exists { + var zero T + return zero, false + } + + if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { + var zero T + return zero, false + } + + return entry.value, true +} + +func (cs *CacheStore[T]) Get(key string) (T, bool) { + cs.mu.RLock() + defer cs.mu.RUnlock() + return cs.getCallback(key) +} + +func (cs *CacheStore[T]) deleteCallback(key string) { + delete(cs.cache, key) + keyIdx := slices.Index(cs.order, key) + if keyIdx != -1 { + cs.order = append(cs.order[:keyIdx], cs.order[keyIdx+1:]...) + } +} + +func (cs *CacheStore[T]) Delete(key string) { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.deleteCallback(key) +} + +func (cs *CacheStore[T]) Sweep() { + cs.mu.Lock() + for key, entry := range cs.cache { + if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) { + cs.deleteCallback(key) + } + } + cs.mu.Unlock() +} + +func (cs *CacheStore[T]) evictOne() bool { + now := time.Now() + var oldestKey string + var oldestExp *time.Time + + for k, e := range cs.cache { + if e.expiresAt != nil && now.After(*e.expiresAt) { + cs.deleteCallback(k) + return true + } + if e.expiresAt != nil && (oldestExp == nil || e.expiresAt.Before(*oldestExp)) { + oldestKey, oldestExp = k, e.expiresAt + } + } + + // If we found an oldest key, evict it else we delete the first key in the order list + if oldestKey != "" { + cs.deleteCallback(oldestKey) + return true + } else { + if len(cs.order) > 0 { + cs.deleteCallback(cs.order[0]) + return true + } + } + + return false +} + +func (cs *CacheStore[T]) Size() int { + cs.mu.RLock() + defer cs.mu.RUnlock() + return len(cs.cache) +} + +func (cs *CacheStore[T]) Clear() { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.cache = make(map[string]cacheEntry[T]) + cs.order = make([]string, 0) +} diff --git a/internal/service/cache_store_test.go b/internal/service/cache_store_test.go new file mode 100644 index 00000000..8908017f --- /dev/null +++ b/internal/service/cache_store_test.go @@ -0,0 +1,383 @@ +package service + +import ( + "strconv" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCacheStoreGet(t *testing.T) { + tests := []struct { + name string + setup func(cs *CacheStore[string]) + wantValue string + wantOk bool + }{ + { + name: "returns a stored value", + setup: func(cs *CacheStore[string]) { cs.Set("key", "value", 0) }, + wantValue: "value", + wantOk: true, + }, + { + name: "reports a missing key", + setup: func(cs *CacheStore[string]) {}, + wantOk: false, + }, + { + name: "returns the latest value after an overwrite", + setup: func(cs *CacheStore[string]) { + cs.Set("key", "first", 0) + cs.Set("key", "second", 0) + }, + wantValue: "second", + wantOk: true, + }, + { + name: "returns a non-expired entry", + setup: func(cs *CacheStore[string]) { cs.Set("key", "value", time.Minute) }, + wantValue: "value", + wantOk: true, + }, + { + name: "treats an expired entry as missing", + setup: func(cs *CacheStore[string]) { + cs.Set("key", "value", 10*time.Millisecond) + time.Sleep(20 * time.Millisecond) + }, + wantOk: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs := NewCacheStore[string](0) + tt.setup(cs) + + value, ok := cs.Get("key") + assert.Equal(t, tt.wantOk, ok) + if tt.wantOk { + assert.Equal(t, tt.wantValue, value) + } + }) + } +} + +func TestCacheStoreUpdate(t *testing.T) { + tests := []struct { + name string + setup func(cs *CacheStore[string]) + ttl time.Duration + wantOk bool + afterWait time.Duration + wantPresent bool + wantValue string + }{ + { + name: "updates an existing entry", + setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 0) }, + ttl: 0, + wantOk: true, + wantPresent: true, + wantValue: "new", + }, + { + name: "does not create a missing entry", + setup: func(cs *CacheStore[string]) {}, + ttl: 0, + wantOk: false, + wantPresent: false, + }, + { + name: "preserves the existing expiry when ttl is zero", + setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 30*time.Millisecond) }, + ttl: 0, + wantOk: true, + afterWait: 40 * time.Millisecond, + wantPresent: false, + }, + { + name: "refreshes the expiry when ttl is provided", + setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 10*time.Millisecond) }, + ttl: time.Minute, + wantOk: true, + afterWait: 20 * time.Millisecond, + wantPresent: true, + wantValue: "new", + }, + { + name: "does not update an expired entry", + setup: func(cs *CacheStore[string]) { + cs.Set("key", "old", 10*time.Millisecond) + time.Sleep(20 * time.Millisecond) + }, + ttl: time.Minute, + wantOk: false, + wantPresent: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs := NewCacheStore[string](0) + tt.setup(cs) + + ok := cs.Update("key", "new", tt.ttl) + assert.Equal(t, tt.wantOk, ok) + + time.Sleep(tt.afterWait) + + value, present := cs.Get("key") + assert.Equal(t, tt.wantPresent, present) + if tt.wantPresent { + assert.Equal(t, tt.wantValue, value) + } + }) + } +} + +func TestCacheStoreDelete(t *testing.T) { + tests := []struct { + name string + setup func(cs *CacheStore[string]) + key string + wantSize int + }{ + { + name: "removes an existing key", + setup: func(cs *CacheStore[string]) { + cs.Set("a", "1", 0) + cs.Set("b", "2", 0) + }, + key: "a", + wantSize: 1, + }, + { + name: "is a no-op for a missing key", + setup: func(cs *CacheStore[string]) { cs.Set("a", "1", 0) }, + key: "missing", + wantSize: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs := NewCacheStore[string](0) + tt.setup(cs) + + cs.Delete(tt.key) + + _, ok := cs.Get(tt.key) + assert.False(t, ok) + assert.Equal(t, tt.wantSize, cs.Size()) + }) + } +} + +func TestCacheStoreSweep(t *testing.T) { + tests := []struct { + name string + setup func(cs *CacheStore[string]) + present []string + absent []string + wantSize int + }{ + { + name: "removes expired entries and keeps the rest", + setup: func(cs *CacheStore[string]) { + cs.Set("permanent", "value", 0) + cs.Set("expired", "value", 10*time.Millisecond) + time.Sleep(20 * time.Millisecond) + }, + present: []string{"permanent"}, + absent: []string{"expired"}, + wantSize: 1, + }, + { + name: "keeps all live entries", + setup: func(cs *CacheStore[string]) { + cs.Set("a", "value", 0) + cs.Set("b", "value", time.Minute) + }, + present: []string{"a", "b"}, + wantSize: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs := NewCacheStore[string](0) + tt.setup(cs) + + cs.Sweep() + + for _, key := range tt.present { + _, ok := cs.Get(key) + assert.True(t, ok) + } + for _, key := range tt.absent { + _, ok := cs.Get(key) + assert.False(t, ok) + } + assert.Equal(t, tt.wantSize, cs.Size()) + }) + } +} + +func TestCacheStoreEviction(t *testing.T) { + // Every case uses a cache with maxSize 2; the final Set in setup is the + // insertion that overflows the cache and triggers an eviction. + tests := []struct { + name string + setup func(cs *CacheStore[string]) + present []string + absent []string + wantSize int + }{ + { + name: "evicts an already expired entry first", + setup: func(cs *CacheStore[string]) { + cs.Set("expired", "value", 10*time.Millisecond) + cs.Set("fresh", "value", time.Minute) + time.Sleep(20 * time.Millisecond) + cs.Set("new", "value", time.Minute) + }, + present: []string{"fresh", "new"}, + absent: []string{"expired"}, + wantSize: 2, + }, + { + name: "evicts the entry expiring soonest", + setup: func(cs *CacheStore[string]) { + cs.Set("soon", "value", 50*time.Millisecond) + cs.Set("later", "value", time.Hour) + cs.Set("new", "value", time.Hour) + }, + present: []string{"later", "new"}, + absent: []string{"soon"}, + wantSize: 2, + }, + { + name: "evicts the oldest inserted entry when none have a ttl", + setup: func(cs *CacheStore[string]) { + cs.Set("first", "value", 0) + cs.Set("second", "value", 0) + cs.Set("third", "value", 0) + }, + present: []string{"second", "third"}, + absent: []string{"first"}, + wantSize: 2, + }, + { + name: "overwriting an existing key does not trigger eviction", + setup: func(cs *CacheStore[string]) { + cs.Set("a", "1", 0) + cs.Set("b", "2", 0) + cs.Set("a", "updated", 0) + }, + present: []string{"a", "b"}, + wantSize: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cs := NewCacheStore[string](2) + tt.setup(cs) + + for _, key := range tt.present { + _, ok := cs.Get(key) + assert.True(t, ok) + } + for _, key := range tt.absent { + _, ok := cs.Get(key) + assert.False(t, ok) + } + assert.Equal(t, tt.wantSize, cs.Size()) + }) + } +} + +func TestCacheStoreSizeAndClear(t *testing.T) { + cs := NewCacheStore[string](0) + assert.Equal(t, 0, cs.Size()) + + cs.Set("a", "1", 0) + cs.Set("b", "2", 0) + assert.Equal(t, 2, cs.Size()) + + cs.Clear() + assert.Equal(t, 0, cs.Size()) + + _, ok := cs.Get("a") + assert.False(t, ok) +} + +func TestCacheStoreWithLock(t *testing.T) { + cs := NewCacheStore[int](0) + cs.Set("counter", 1, 0) + + // All four actions run atomically under a single lock. + cs.WithLock(func(actions CacheStoreActions[int]) { + current, ok := actions.Get("counter") + assert.True(t, ok) + + actions.Set("counter", current+1, 0) + actions.Set("other", 100, 0) + actions.Delete("counter") + + updated := actions.Update("other", 200, 0) + assert.True(t, updated) + }) + + _, ok := cs.Get("counter") + assert.False(t, ok) + + value, ok := cs.Get("other") + assert.True(t, ok) + assert.Equal(t, 200, value) +} + +// TestCacheStoreConcurrency exercises every locking path concurrently so the +// race detector (go test -race) can flag unsynchronised access. +func TestCacheStoreConcurrency(t *testing.T) { + cs := NewCacheStore[int](64) + + const goroutines = 16 + const iterations = 200 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for g := range goroutines { + go func(g int) { + defer wg.Done() + for i := range iterations { + key := strconv.Itoa((g*iterations + i) % 32) + switch i % 6 { + case 0: + cs.Set(key, i, time.Minute) + case 1: + cs.Get(key) + case 2: + cs.Update(key, i, time.Minute) + case 3: + cs.Delete(key) + case 4: + cs.Size() + case 5: + cs.WithLock(func(actions CacheStoreActions[int]) { + if v, ok := actions.Get(key); ok { + actions.Set(key, v+1, time.Minute) + } + }) + } + } + }(g) + } + + wg.Wait() +}