mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-03 18:10:16 +00:00
fix: fix race conditions
This commit is contained in:
@@ -128,7 +128,7 @@ func NewAuthService(
|
|||||||
service.caches.oauth.Sweep()
|
service.caches.oauth.Sweep()
|
||||||
service.caches.login.Sweep()
|
service.caches.login.Sweep()
|
||||||
service.caches.ldap.Sweep()
|
service.caches.ldap.Sweep()
|
||||||
case <-service.ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -231,8 +231,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
||||||
if auth.lockdown.active {
|
if locked, remaining := auth.IsInLockdown(); locked {
|
||||||
remaining := int(time.Until(auth.lockdown.until).Seconds())
|
|
||||||
return true, remaining
|
return true, remaining
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,30 +258,17 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
|
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
|
||||||
if auth.lockdown.active {
|
if locked, _ := auth.IsInLockdown(); locked {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go auth.lockdownMode()
|
go auth.lockdownMode()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ok := auth.caches.login.Mutate(identifier, func(la LoginAttempt) (LoginAttempt, bool) {
|
auth.caches.login.WithLock(func(actions CacheStoreActions[LoginAttempt]) {
|
||||||
la.LastAttempt = time.Now()
|
entry, ok := actions.Get(identifier)
|
||||||
if success {
|
|
||||||
la.FailedAttempts = 0
|
|
||||||
la.LockedUntil = time.Time{}
|
|
||||||
return la, false
|
|
||||||
}
|
|
||||||
la.FailedAttempts++
|
|
||||||
if la.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
|
||||||
la.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
|
||||||
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", la.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
|
||||||
}
|
|
||||||
return la, true
|
|
||||||
})
|
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
// No existing record, create a new one
|
|
||||||
attempt := LoginAttempt{
|
attempt := LoginAttempt{
|
||||||
LastAttempt: time.Now(),
|
LastAttempt: time.Now(),
|
||||||
}
|
}
|
||||||
@@ -293,8 +279,27 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auth.caches.login.Set(identifier, attempt, 0) // match current tinyauth behavior which doesn't expire rate limits
|
// match current tinyauth behavior which doesn't expire rate limits
|
||||||
|
actions.Set(identifier, attempt, 0)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
entry.LastAttempt = time.Now()
|
||||||
|
|
||||||
|
if success {
|
||||||
|
entry.FailedAttempts = 0
|
||||||
|
entry.LockedUntil = time.Time{}
|
||||||
|
} else {
|
||||||
|
entry.FailedAttempts++
|
||||||
|
|
||||||
|
if entry.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
||||||
|
entry.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
||||||
|
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", entry.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
actions.Set(identifier, entry, 0)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// We could also directly access the policyEngine.effectToAccess but
|
// We could also directly access the policyEngine.effectToAccess but
|
||||||
@@ -551,10 +556,10 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
||||||
session, err := auth.GetOAuthPendingSession(sessionId)
|
session, ok := auth.caches.oauth.Get(sessionId)
|
||||||
|
|
||||||
if err != nil {
|
if !ok {
|
||||||
return nil, err
|
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := (*session.Service).GetToken(code, session.Verifier)
|
token, err := (*session.Service).GetToken(code, session.Verifier)
|
||||||
@@ -565,7 +570,12 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
|||||||
|
|
||||||
session.Token = token
|
session.Token = token
|
||||||
|
|
||||||
auth.caches.oauth.Set(sessionId, *session, time.Minute*10)
|
// ttl 0 means keep current expiration
|
||||||
|
ok = auth.caches.oauth.Update(sessionId, session, 0)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("failed to update oauth session with token: %s", sessionId)
|
||||||
|
}
|
||||||
|
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
@@ -659,6 +669,16 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
auth.lockdown.mu.Unlock()
|
auth.lockdown.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *AuthService) IsInLockdown() (bool, int) {
|
||||||
|
auth.lockdown.mu.RLock()
|
||||||
|
defer auth.lockdown.mu.RUnlock()
|
||||||
|
if auth.lockdown.active {
|
||||||
|
remaining := int(time.Until(auth.lockdown.until).Seconds())
|
||||||
|
return true, remaining
|
||||||
|
}
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
// mostly a testing function, not useful for anything else
|
// mostly a testing function, not useful for anything else
|
||||||
func (auth *AuthService) ClearLoginAttempts() {
|
func (auth *AuthService) ClearLoginAttempts() {
|
||||||
auth.caches.login.Clear()
|
auth.caches.login.Clear()
|
||||||
|
|||||||
@@ -1,10 +1,18 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type CacheStoreActions[T any] struct {
|
||||||
|
Set func(key string, value T, ttl time.Duration)
|
||||||
|
Get func(key string) (T, bool)
|
||||||
|
Delete func(key string)
|
||||||
|
Update func(key string, value T, ttl time.Duration) bool
|
||||||
|
}
|
||||||
|
|
||||||
type cacheEntry[T any] struct {
|
type cacheEntry[T any] struct {
|
||||||
value T
|
value T
|
||||||
expiresAt *time.Time
|
expiresAt *time.Time
|
||||||
@@ -12,6 +20,7 @@ type cacheEntry[T any] struct {
|
|||||||
|
|
||||||
type CacheStore[T any] struct {
|
type CacheStore[T any] struct {
|
||||||
cache map[string]cacheEntry[T]
|
cache map[string]cacheEntry[T]
|
||||||
|
order []string
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
maxSize int
|
maxSize int
|
||||||
}
|
}
|
||||||
@@ -19,14 +28,57 @@ type CacheStore[T any] struct {
|
|||||||
func NewCacheStore[T any](maxSize int) *CacheStore[T] {
|
func NewCacheStore[T any](maxSize int) *CacheStore[T] {
|
||||||
return &CacheStore[T]{
|
return &CacheStore[T]{
|
||||||
cache: make(map[string]cacheEntry[T]),
|
cache: make(map[string]cacheEntry[T]),
|
||||||
|
order: make([]string, 0),
|
||||||
maxSize: maxSize,
|
maxSize: maxSize,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) {
|
// With lock allows performing multiple operations on the cache store atomically.
|
||||||
|
// The provided mutate function receives a set of actions (Set, Get, Delete) that
|
||||||
|
// can be used to manipulate the cache store within the locked context.
|
||||||
|
func (cs *CacheStore[T]) WithLock(mutate func(actions CacheStoreActions[T])) {
|
||||||
cs.mu.Lock()
|
cs.mu.Lock()
|
||||||
defer cs.mu.Unlock()
|
defer cs.mu.Unlock()
|
||||||
|
actions := CacheStoreActions[T]{
|
||||||
|
Set: cs.setCallback,
|
||||||
|
Get: cs.getCallback,
|
||||||
|
Delete: cs.deleteCallback,
|
||||||
|
Update: cs.updateCallback,
|
||||||
|
}
|
||||||
|
mutate(actions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) updateCallback(key string, value T, ttl time.Duration) bool {
|
||||||
|
if currentEntry, exists := cs.cache[key]; exists {
|
||||||
|
if currentEntry.expiresAt != nil && time.Now().After(*currentEntry.expiresAt) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := cacheEntry[T]{
|
||||||
|
value: value,
|
||||||
|
expiresAt: currentEntry.expiresAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ttl > 0 {
|
||||||
|
expiration := time.Now().Add(ttl)
|
||||||
|
entry.expiresAt = &expiration
|
||||||
|
}
|
||||||
|
|
||||||
|
cs.cache[key] = entry
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) Update(key string, value T, ttl time.Duration) bool {
|
||||||
|
cs.mu.Lock()
|
||||||
|
defer cs.mu.Unlock()
|
||||||
|
return cs.updateCallback(key, value, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) setCallback(key string, value T, ttl time.Duration) {
|
||||||
if cs.maxSize > 0 {
|
if cs.maxSize > 0 {
|
||||||
if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize {
|
if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize {
|
||||||
cs.evictOne()
|
cs.evictOne()
|
||||||
@@ -44,12 +96,17 @@ func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) {
|
|||||||
value: value,
|
value: value,
|
||||||
expiresAt: expiresAt,
|
expiresAt: expiresAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cs.order = append(cs.order, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CacheStore[T]) Get(key string) (T, bool) {
|
func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) {
|
||||||
cs.mu.RLock()
|
cs.mu.Lock()
|
||||||
defer cs.mu.RUnlock()
|
defer cs.mu.Unlock()
|
||||||
|
cs.setCallback(key, value, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) getCallback(key string) (T, bool) {
|
||||||
entry, exists := cs.cache[key]
|
entry, exists := cs.cache[key]
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
@@ -65,79 +122,31 @@ func (cs *CacheStore[T]) Get(key string) (T, bool) {
|
|||||||
return entry.value, true
|
return entry.value, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) Get(key string) (T, bool) {
|
||||||
|
cs.mu.RLock()
|
||||||
|
defer cs.mu.RUnlock()
|
||||||
|
return cs.getCallback(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) deleteCallback(key string) {
|
||||||
|
delete(cs.cache, key)
|
||||||
|
keyIdx := slices.Index(cs.order, key)
|
||||||
|
if keyIdx != -1 {
|
||||||
|
cs.order = append(cs.order[:keyIdx], cs.order[keyIdx+1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (cs *CacheStore[T]) Delete(key string) {
|
func (cs *CacheStore[T]) Delete(key string) {
|
||||||
cs.mu.Lock()
|
cs.mu.Lock()
|
||||||
defer cs.mu.Unlock()
|
defer cs.mu.Unlock()
|
||||||
delete(cs.cache, key)
|
cs.deleteCallback(key)
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) Mutate(key string, mutator func(T) (T, bool)) bool {
|
|
||||||
cs.mu.Lock()
|
|
||||||
defer cs.mu.Unlock()
|
|
||||||
|
|
||||||
entry, exists := cs.cache[key]
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
|
|
||||||
delete(cs.cache, key)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
newValue, shouldKeep := mutator(entry.value)
|
|
||||||
|
|
||||||
if !shouldKeep {
|
|
||||||
delete(cs.cache, key)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
cs.cache[key] = cacheEntry[T]{
|
|
||||||
value: newValue,
|
|
||||||
expiresAt: entry.expiresAt,
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CacheStore[T]) MutateWithTTL(key string, mutator func(T) (T, time.Duration, bool)) bool {
|
|
||||||
cs.mu.Lock()
|
|
||||||
defer cs.mu.Unlock()
|
|
||||||
|
|
||||||
entry, exists := cs.cache[key]
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
|
|
||||||
delete(cs.cache, key)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
newValue, ttl, shouldKeep := mutator(entry.value)
|
|
||||||
|
|
||||||
if !shouldKeep {
|
|
||||||
delete(cs.cache, key)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
expiresAt := time.Now().Add(ttl)
|
|
||||||
|
|
||||||
cs.cache[key] = cacheEntry[T]{
|
|
||||||
value: newValue,
|
|
||||||
expiresAt: &expiresAt,
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CacheStore[T]) Sweep() {
|
func (cs *CacheStore[T]) Sweep() {
|
||||||
cs.mu.Lock()
|
cs.mu.Lock()
|
||||||
for key, entry := range cs.cache {
|
for key, entry := range cs.cache {
|
||||||
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
|
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
|
||||||
delete(cs.cache, key)
|
cs.deleteCallback(key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cs.mu.Unlock()
|
cs.mu.Unlock()
|
||||||
@@ -158,9 +167,17 @@ func (cs *CacheStore[T]) evictOne() bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we found an oldest key, evict it else we delete the first key in the order list
|
||||||
if oldestKey != "" {
|
if oldestKey != "" {
|
||||||
delete(cs.cache, oldestKey)
|
delete(cs.cache, oldestKey)
|
||||||
return true
|
return true
|
||||||
|
} else {
|
||||||
|
if len(cs.order) > 0 {
|
||||||
|
firstKey := cs.order[0]
|
||||||
|
cs.order = cs.order[1:]
|
||||||
|
delete(cs.cache, firstKey)
|
||||||
|
return true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
@@ -176,4 +193,5 @@ func (cs *CacheStore[T]) Clear() {
|
|||||||
cs.mu.Lock()
|
cs.mu.Lock()
|
||||||
defer cs.mu.Unlock()
|
defer cs.mu.Unlock()
|
||||||
cs.cache = make(map[string]cacheEntry[T])
|
cs.cache = make(map[string]cacheEntry[T])
|
||||||
|
cs.order = make([]string, 0)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user