refactor: use new cache store in auth service

This commit is contained in:
Stavros
2026-05-29 23:33:35 +03:00
parent faee58ca8e
commit ed94490efd
4 changed files with 291 additions and 193 deletions
+1 -1
View File
@@ -422,7 +422,7 @@ func TestUserController(t *testing.T) {
beforeEach := func() { beforeEach := func() {
// Clear failed login attempts before each test // Clear failed login attempts before each test
authService.ClearRateLimitsTestingOnly() authService.ClearLoginAttempts()
} }
for _, test := range tests { for _, test := range tests {
@@ -263,7 +263,7 @@ func TestContextMiddleware(t *testing.T) {
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil) contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
for _, test := range tests { for _, test := range tests {
authService.ClearRateLimitsTestingOnly() authService.ClearLoginAttempts()
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
+97 -178
View File
@@ -15,8 +15,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@@ -54,27 +52,17 @@ type OAuthPendingSession struct {
CallbackParams OAuthURLParams CallbackParams OAuthURLParams
} }
type LdapGroupsCache struct {
Groups []string
Expires time.Time
}
type LoginAttempt struct { type LoginAttempt struct {
FailedAttempts int FailedAttempts int
LastAttempt time.Time LastAttempt time.Time
LockedUntil time.Time LockedUntil time.Time
} }
type Lockdown struct {
Active bool
ActiveUntil time.Time
}
type AuthService struct { type AuthService struct {
log *logger.Logger log *logger.Logger
config model.Config config model.Config
runtime model.RuntimeConfig runtime model.RuntimeConfig
context context.Context ctx context.Context
ldap *LdapService ldap *LdapService
queries repository.Store queries repository.Store
@@ -82,15 +70,19 @@ type AuthService struct {
tailscale *TailscaleService tailscale *TailscaleService
policyEngine *PolicyEngine policyEngine *PolicyEngine
loginAttempts map[string]*LoginAttempt lockdown struct {
ldapGroupsCache map[string]*LdapGroupsCache active bool
oauthPendingSessions map[string]*OAuthPendingSession until time.Time
oauthMutex sync.RWMutex ctx context.Context
loginMutex sync.RWMutex cancelFunc context.CancelFunc
ldapGroupsMutex sync.RWMutex mu sync.RWMutex
lockdown *Lockdown }
lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc caches struct {
login *CacheStore[LoginAttempt]
oauth *CacheStore[OAuthPendingSession]
ldap *CacheStore[[]string]
}
} }
func NewAuthService( func NewAuthService(
@@ -108,11 +100,8 @@ func NewAuthService(
service := &AuthService{ service := &AuthService{
log: log, log: log,
runtime: runtime, runtime: runtime,
context: ctx, ctx: ctx,
config: config, config: config,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
oauthPendingSessions: make(map[string]*OAuthPendingSession),
ldap: ldap, ldap: ldap,
queries: queries, queries: queries,
oauthBroker: oauthBroker, oauthBroker: oauthBroker,
@@ -120,7 +109,30 @@ func NewAuthService(
policyEngine: policy, policyEngine: policy,
} }
dg.Go(service.cleanupOAuthSessions, ding.RingMinor) // caches setup
oauthCache := NewCacheStore[OAuthPendingSession](256)
loginCache := NewCacheStore[LoginAttempt](1024)
ldapCache := NewCacheStore[[]string](1024)
service.caches.oauth = oauthCache
service.caches.login = loginCache
service.caches.ldap = ldapCache
dg.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
service.caches.oauth.Sweep()
service.caches.login.Sweep()
service.caches.ldap.Sweep()
case <-service.ctx.Done():
return
}
}
}, ding.RingMinor)
return service return service
} }
@@ -195,14 +207,12 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
return nil, errors.New("ldap service not configured") return nil, errors.New("ldap service not configured")
} }
auth.ldapGroupsMutex.RLock() entry, exists := auth.caches.ldap.Get(userDN)
entry, exists := auth.ldapGroupsCache[userDN]
auth.ldapGroupsMutex.RUnlock()
if exists && time.Now().Before(entry.Expires) { if exists {
return &model.LDAPUser{ return &model.LDAPUser{
DN: userDN, DN: userDN,
Groups: entry.Groups, Groups: entry,
}, nil }, nil
} }
@@ -212,12 +222,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
return nil, fmt.Errorf("failed to get ldap groups: %w", err) return nil, fmt.Errorf("failed to get ldap groups: %w", err)
} }
auth.ldapGroupsMutex.Lock() auth.caches.ldap.Set(userDN, groups, time.Duration(auth.config.LDAP.GroupCacheTTL)*time.Second)
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
}
auth.ldapGroupsMutex.Unlock()
return &model.LDAPUser{ return &model.LDAPUser{
DN: userDN, DN: userDN,
@@ -226,11 +231,8 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
} }
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
auth.loginMutex.RLock() if auth.lockdown.active {
defer auth.loginMutex.RUnlock() remaining := int(time.Until(auth.lockdown.until).Seconds())
if auth.lockdown != nil && auth.lockdown.Active {
remaining := int(time.Until(auth.lockdown.ActiveUntil).Seconds())
return true, remaining return true, remaining
} }
@@ -238,7 +240,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return false, 0 return false, 0
} }
attempt, exists := auth.loginAttempts[identifier] attempt, exists := auth.caches.login.Get(identifier)
if !exists { if !exists {
return false, 0 return false, 0
} }
@@ -256,37 +258,43 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
return return
} }
auth.loginMutex.Lock() if auth.caches.login.Size() >= MaxLoginAttemptRecords {
defer auth.loginMutex.Unlock() if auth.lockdown.active {
if len(auth.loginAttempts) >= MaxLoginAttemptRecords {
if auth.lockdown != nil && auth.lockdown.Active {
return return
} }
go auth.lockdownMode() go auth.lockdownMode()
return return
} }
attempt, exists := auth.loginAttempts[identifier] ok := auth.caches.login.Mutate(identifier, func(la LoginAttempt) (LoginAttempt, bool) {
if !exists { la.LastAttempt = time.Now()
attempt = &LoginAttempt{}
auth.loginAttempts[identifier] = attempt
}
attempt.LastAttempt = time.Now()
if success { if success {
attempt.FailedAttempts = 0 la.FailedAttempts = 0
attempt.LockedUntil = time.Time{} // Reset lock time la.LockedUntil = time.Time{}
return return la, false
} }
la.FailedAttempts++
if la.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
la.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", la.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
return la, true
})
attempt.FailedAttempts++ if !ok {
// No existing record, create a new one
attempt := LoginAttempt{
LastAttempt: time.Now(),
}
if !success {
attempt.FailedAttempts = 1
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
} }
}
auth.caches.login.Set(identifier, attempt, 0) // match current tinyauth behavior which doesn't expire rate limits
}
} }
// We could also directly access the policyEngine.effectToAccess but // We could also directly access the policyEngine.effectToAccess but
@@ -504,8 +512,6 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
} }
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) { func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit()
service, ok := auth.oauthBroker.GetService(serviceName) service, ok := auth.oauthBroker.GetService(serviceName)
if !ok { if !ok {
@@ -529,9 +535,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
CallbackParams: params, CallbackParams: params,
} }
auth.oauthMutex.Lock() auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
auth.oauthPendingSessions[sessionId.String()] = &session
auth.oauthMutex.Unlock()
return sessionId.String(), session, nil return sessionId.String(), session, nil
} }
@@ -559,9 +563,9 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return nil, fmt.Errorf("failed to exchange code for token: %w", err) return nil, fmt.Errorf("failed to exchange code for token: %w", err)
} }
auth.oauthMutex.Lock()
session.Token = token session.Token = token
auth.oauthMutex.Unlock()
auth.caches.oauth.Set(sessionId, *session, time.Minute*10)
return token, nil return token, nil
} }
@@ -597,123 +601,39 @@ func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, er
} }
func (auth *AuthService) EndOAuthSession(sessionId string) { func (auth *AuthService) EndOAuthSession(sessionId string) {
auth.oauthMutex.Lock() auth.caches.oauth.Delete(sessionId)
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
}
func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
auth.log.App.Debug().Msg("Running OAuth session cleanup")
auth.oauthMutex.Lock()
now := time.Now()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
}
auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-ctx.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return
}
}
} }
func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) { func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit() session, exists := auth.caches.oauth.Get(sessionId)
auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
if !exists { if !exists {
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId) return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
} }
if time.Now().After(session.ExpiresAt) { return &session, nil
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId)
}
return session, nil
}
func (auth *AuthService) ensureOAuthSessionLimit() {
auth.oauthMutex.Lock()
defer auth.oauthMutex.Unlock()
if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions {
return
}
type entry struct {
id string
expiresAt int64
}
entries := make([]entry, 0, len(auth.oauthPendingSessions))
for id, session := range auth.oauthPendingSessions {
entries = append(entries, entry{id, session.ExpiresAt.Unix()})
}
slices.SortFunc(entries, func(a, b entry) int {
if a.expiresAt < b.expiresAt {
return -1
}
if a.expiresAt > b.expiresAt {
return 1
}
return 0
})
for _, e := range entries[:OAuthCleanupCount] {
delete(auth.oauthPendingSessions, e.id)
}
} }
func (auth *AuthService) lockdownMode() { func (auth *AuthService) lockdownMode() {
ctx, cancel := context.WithCancel(context.Background()) auth.lockdown.mu.Lock()
auth.loginMutex.Lock() if auth.lockdown.active {
auth.lockdown.mu.Unlock()
if auth.lockdown != nil && auth.lockdown.Active {
auth.loginMutex.Unlock()
cancel()
return return
} }
auth.lockdownCtx = ctx ctx, cancel := context.WithCancel(context.Background())
auth.lockdownCancelFunc = cancel
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown = &Lockdown{ auth.lockdown.active = true
Active: true, auth.lockdown.ctx = ctx
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second), auth.lockdown.cancelFunc = cancel
} auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
// At this point all login attemps will also expire so, timer := time.NewTimer(time.Until(auth.lockdown.until))
// we might as well clear them to free up memory
auth.loginAttempts = make(map[string]*LoginAttempt)
timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil)) auth.lockdown.mu.Unlock()
auth.loginMutex.Unlock()
defer cancel() defer cancel()
defer timer.Stop() defer timer.Stop()
@@ -723,24 +643,23 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown // Timer expired, end lockdown
case <-ctx.Done(): case <-ctx.Done():
// Context cancelled, end lockdown // Context cancelled, end lockdown
case <-auth.context.Done(): case <-auth.ctx.Done():
// Service is shutting down, end lockdown // Service is shutting down, end lockdown
} }
auth.loginMutex.Lock() auth.lockdown.mu.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode") auth.log.App.Info().Msg("Exiting lockdown mode")
auth.lockdown = nil auth.lockdown.active = false
auth.loginMutex.Unlock() auth.lockdown.until = time.Time{}
auth.lockdown.ctx = nil
auth.lockdown.cancelFunc = nil
auth.lockdown.mu.Unlock()
} }
// Function only used for testing - do not use in prod! // mostly a testing function, not useful for anything else
func (auth *AuthService) ClearRateLimitsTestingOnly() { func (auth *AuthService) ClearLoginAttempts() {
auth.loginMutex.Lock() auth.caches.login.Clear()
auth.loginAttempts = make(map[string]*LoginAttempt)
if auth.lockdown != nil {
auth.lockdownCancelFunc()
}
auth.loginMutex.Unlock()
} }
+179
View File
@@ -0,0 +1,179 @@
package service
import (
"sync"
"time"
)
type cacheEntry[T any] struct {
value T
expiresAt *time.Time
}
type CacheStore[T any] struct {
cache map[string]cacheEntry[T]
mu sync.RWMutex
maxSize int
}
func NewCacheStore[T any](maxSize int) *CacheStore[T] {
return &CacheStore[T]{
cache: make(map[string]cacheEntry[T]),
maxSize: maxSize,
}
}
func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) {
cs.mu.Lock()
defer cs.mu.Unlock()
if cs.maxSize > 0 {
if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize {
cs.evictOne()
}
}
var expiresAt *time.Time
if ttl > 0 {
expiration := time.Now().Add(ttl)
expiresAt = &expiration
}
cs.cache[key] = cacheEntry[T]{
value: value,
expiresAt: expiresAt,
}
}
func (cs *CacheStore[T]) Get(key string) (T, bool) {
cs.mu.RLock()
defer cs.mu.RUnlock()
entry, exists := cs.cache[key]
if !exists {
var zero T
return zero, false
}
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
var zero T
return zero, false
}
return entry.value, true
}
func (cs *CacheStore[T]) Delete(key string) {
cs.mu.Lock()
defer cs.mu.Unlock()
delete(cs.cache, key)
}
func (cs *CacheStore[T]) Mutate(key string, mutator func(T) (T, bool)) bool {
cs.mu.Lock()
defer cs.mu.Unlock()
entry, exists := cs.cache[key]
if !exists {
return false
}
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
delete(cs.cache, key)
return false
}
newValue, shouldKeep := mutator(entry.value)
if !shouldKeep {
delete(cs.cache, key)
return true
}
cs.cache[key] = cacheEntry[T]{
value: newValue,
expiresAt: entry.expiresAt,
}
return true
}
func (cs *CacheStore[T]) MutateWithTTL(key string, mutator func(T) (T, time.Duration, bool)) bool {
cs.mu.Lock()
defer cs.mu.Unlock()
entry, exists := cs.cache[key]
if !exists {
return false
}
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
delete(cs.cache, key)
return false
}
newValue, ttl, shouldKeep := mutator(entry.value)
if !shouldKeep {
delete(cs.cache, key)
return true
}
expiresAt := time.Now().Add(ttl)
cs.cache[key] = cacheEntry[T]{
value: newValue,
expiresAt: &expiresAt,
}
return true
}
func (cs *CacheStore[T]) Sweep() {
cs.mu.Lock()
for key, entry := range cs.cache {
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
delete(cs.cache, key)
}
}
cs.mu.Unlock()
}
func (cs *CacheStore[T]) evictOne() bool {
now := time.Now()
var oldestKey string
var oldestExp *time.Time
for k, e := range cs.cache {
if e.expiresAt != nil && now.After(*e.expiresAt) {
delete(cs.cache, k)
return true
}
if e.expiresAt != nil && (oldestExp == nil || e.expiresAt.Before(*oldestExp)) {
oldestKey, oldestExp = k, e.expiresAt
}
}
if oldestKey != "" {
delete(cs.cache, oldestKey)
return true
}
return false
}
func (cs *CacheStore[T]) Size() int {
cs.mu.RLock()
defer cs.mu.RUnlock()
return len(cs.cache)
}
func (cs *CacheStore[T]) Clear() {
cs.mu.Lock()
defer cs.mu.Unlock()
cs.cache = make(map[string]cacheEntry[T])
}