mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-02 17:40:14 +00:00
refactor: use new cache store in services (#912)
This commit is contained in:
@@ -62,6 +62,15 @@ binary-linux-arm64:
|
|||||||
test:
|
test:
|
||||||
go test -v ./...
|
go test -v ./...
|
||||||
|
|
||||||
|
# Go vet
|
||||||
|
.PHONY: vet
|
||||||
|
vet:
|
||||||
|
go vet ./...
|
||||||
|
|
||||||
|
# Go race
|
||||||
|
test-race:
|
||||||
|
go test -race ./...
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
dev:
|
dev:
|
||||||
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
|
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
|
||||||
|
|||||||
@@ -422,7 +422,7 @@ func TestUserController(t *testing.T) {
|
|||||||
|
|
||||||
beforeEach := func() {
|
beforeEach := func() {
|
||||||
// Clear failed login attempts before each test
|
// Clear failed login attempts before each test
|
||||||
authService.ClearRateLimitsTestingOnly()
|
authService.ClearLoginAttempts()
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
|||||||
@@ -263,7 +263,7 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
authService.ClearRateLimitsTestingOnly()
|
authService.ClearLoginAttempts()
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
+130
-191
@@ -15,8 +15,6 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
|
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@@ -54,27 +52,17 @@ type OAuthPendingSession struct {
|
|||||||
CallbackParams OAuthURLParams
|
CallbackParams OAuthURLParams
|
||||||
}
|
}
|
||||||
|
|
||||||
type LdapGroupsCache struct {
|
|
||||||
Groups []string
|
|
||||||
Expires time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type LoginAttempt struct {
|
type LoginAttempt struct {
|
||||||
FailedAttempts int
|
FailedAttempts int
|
||||||
LastAttempt time.Time
|
LastAttempt time.Time
|
||||||
LockedUntil time.Time
|
LockedUntil time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type Lockdown struct {
|
|
||||||
Active bool
|
|
||||||
ActiveUntil time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthService struct {
|
type AuthService struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
config model.Config
|
config model.Config
|
||||||
runtime model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
context context.Context
|
ctx context.Context
|
||||||
|
|
||||||
ldap *LdapService
|
ldap *LdapService
|
||||||
queries repository.Store
|
queries repository.Store
|
||||||
@@ -82,15 +70,19 @@ type AuthService struct {
|
|||||||
tailscale *TailscaleService
|
tailscale *TailscaleService
|
||||||
policyEngine *PolicyEngine
|
policyEngine *PolicyEngine
|
||||||
|
|
||||||
loginAttempts map[string]*LoginAttempt
|
lockdown struct {
|
||||||
ldapGroupsCache map[string]*LdapGroupsCache
|
active bool
|
||||||
oauthPendingSessions map[string]*OAuthPendingSession
|
until time.Time
|
||||||
oauthMutex sync.RWMutex
|
ctx context.Context
|
||||||
loginMutex sync.RWMutex
|
cancelFunc context.CancelFunc
|
||||||
ldapGroupsMutex sync.RWMutex
|
mu sync.RWMutex
|
||||||
lockdown *Lockdown
|
}
|
||||||
lockdownCtx context.Context
|
|
||||||
lockdownCancelFunc context.CancelFunc
|
caches struct {
|
||||||
|
login *CacheStore[LoginAttempt]
|
||||||
|
oauth *CacheStore[OAuthPendingSession]
|
||||||
|
ldap *CacheStore[[]string]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthService(
|
func NewAuthService(
|
||||||
@@ -106,21 +98,41 @@ func NewAuthService(
|
|||||||
policy *PolicyEngine,
|
policy *PolicyEngine,
|
||||||
) *AuthService {
|
) *AuthService {
|
||||||
service := &AuthService{
|
service := &AuthService{
|
||||||
log: log,
|
log: log,
|
||||||
runtime: runtime,
|
runtime: runtime,
|
||||||
context: ctx,
|
ctx: ctx,
|
||||||
config: config,
|
config: config,
|
||||||
loginAttempts: make(map[string]*LoginAttempt),
|
ldap: ldap,
|
||||||
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
queries: queries,
|
||||||
oauthPendingSessions: make(map[string]*OAuthPendingSession),
|
oauthBroker: oauthBroker,
|
||||||
ldap: ldap,
|
tailscale: tailscale,
|
||||||
queries: queries,
|
policyEngine: policy,
|
||||||
oauthBroker: oauthBroker,
|
|
||||||
tailscale: tailscale,
|
|
||||||
policyEngine: policy,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dg.Go(service.cleanupOAuthSessions, ding.RingMinor)
|
// caches setup
|
||||||
|
oauthCache := NewCacheStore[OAuthPendingSession](256)
|
||||||
|
loginCache := NewCacheStore[LoginAttempt](1024)
|
||||||
|
ldapCache := NewCacheStore[[]string](1024)
|
||||||
|
|
||||||
|
service.caches.oauth = oauthCache
|
||||||
|
service.caches.login = loginCache
|
||||||
|
service.caches.ldap = ldapCache
|
||||||
|
|
||||||
|
dg.Go(func(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
service.caches.oauth.Sweep()
|
||||||
|
service.caches.login.Sweep()
|
||||||
|
service.caches.ldap.Sweep()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, ding.RingMinor)
|
||||||
|
|
||||||
return service
|
return service
|
||||||
}
|
}
|
||||||
@@ -195,14 +207,12 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
|||||||
return nil, errors.New("ldap service not configured")
|
return nil, errors.New("ldap service not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.ldapGroupsMutex.RLock()
|
entry, exists := auth.caches.ldap.Get(userDN)
|
||||||
entry, exists := auth.ldapGroupsCache[userDN]
|
|
||||||
auth.ldapGroupsMutex.RUnlock()
|
|
||||||
|
|
||||||
if exists && time.Now().Before(entry.Expires) {
|
if exists {
|
||||||
return &model.LDAPUser{
|
return &model.LDAPUser{
|
||||||
DN: userDN,
|
DN: userDN,
|
||||||
Groups: entry.Groups,
|
Groups: entry,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,12 +222,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
|||||||
return nil, fmt.Errorf("failed to get ldap groups: %w", err)
|
return nil, fmt.Errorf("failed to get ldap groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.ldapGroupsMutex.Lock()
|
auth.caches.ldap.Set(userDN, groups, time.Duration(auth.config.LDAP.GroupCacheTTL)*time.Second)
|
||||||
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
|
|
||||||
Groups: groups,
|
|
||||||
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
|
|
||||||
}
|
|
||||||
auth.ldapGroupsMutex.Unlock()
|
|
||||||
|
|
||||||
return &model.LDAPUser{
|
return &model.LDAPUser{
|
||||||
DN: userDN,
|
DN: userDN,
|
||||||
@@ -226,11 +231,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
||||||
auth.loginMutex.RLock()
|
if locked, remaining := auth.IsInLockdown(); locked {
|
||||||
defer auth.loginMutex.RUnlock()
|
|
||||||
|
|
||||||
if auth.lockdown != nil && auth.lockdown.Active {
|
|
||||||
remaining := int(time.Until(auth.lockdown.ActiveUntil).Seconds())
|
|
||||||
return true, remaining
|
return true, remaining
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -238,7 +239,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
|||||||
return false, 0
|
return false, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
attempt, exists := auth.loginAttempts[identifier]
|
attempt, exists := auth.caches.login.Get(identifier)
|
||||||
if !exists {
|
if !exists {
|
||||||
return false, 0
|
return false, 0
|
||||||
}
|
}
|
||||||
@@ -256,37 +257,49 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.loginMutex.Lock()
|
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
|
||||||
defer auth.loginMutex.Unlock()
|
if locked, _ := auth.IsInLockdown(); locked {
|
||||||
|
|
||||||
if len(auth.loginAttempts) >= MaxLoginAttemptRecords {
|
|
||||||
if auth.lockdown != nil && auth.lockdown.Active {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go auth.lockdownMode()
|
go auth.lockdownMode()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
attempt, exists := auth.loginAttempts[identifier]
|
auth.caches.login.WithLock(func(actions CacheStoreActions[LoginAttempt]) {
|
||||||
if !exists {
|
entry, ok := actions.Get(identifier)
|
||||||
attempt = &LoginAttempt{}
|
|
||||||
auth.loginAttempts[identifier] = attempt
|
|
||||||
}
|
|
||||||
|
|
||||||
attempt.LastAttempt = time.Now()
|
if !ok {
|
||||||
|
attempt := LoginAttempt{
|
||||||
|
LastAttempt: time.Now(),
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
attempt.FailedAttempts = 1
|
||||||
|
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
||||||
|
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
||||||
|
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// match current tinyauth behavior which doesn't expire rate limits
|
||||||
|
actions.Set(identifier, attempt, 0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if success {
|
entry.LastAttempt = time.Now()
|
||||||
attempt.FailedAttempts = 0
|
|
||||||
attempt.LockedUntil = time.Time{} // Reset lock time
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
attempt.FailedAttempts++
|
if success {
|
||||||
|
entry.FailedAttempts = 0
|
||||||
|
entry.LockedUntil = time.Time{}
|
||||||
|
} else {
|
||||||
|
entry.FailedAttempts++
|
||||||
|
|
||||||
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
if entry.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
||||||
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
entry.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
||||||
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", entry.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
actions.Set(identifier, entry, 0)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// We could also directly access the policyEngine.effectToAccess but
|
// We could also directly access the policyEngine.effectToAccess but
|
||||||
@@ -504,8 +517,6 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
|
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
|
||||||
auth.ensureOAuthSessionLimit()
|
|
||||||
|
|
||||||
service, ok := auth.oauthBroker.GetService(serviceName)
|
service, ok := auth.oauthBroker.GetService(serviceName)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -529,9 +540,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
|
|||||||
CallbackParams: params,
|
CallbackParams: params,
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.oauthMutex.Lock()
|
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
|
||||||
auth.oauthPendingSessions[sessionId.String()] = &session
|
|
||||||
auth.oauthMutex.Unlock()
|
|
||||||
|
|
||||||
return sessionId.String(), session, nil
|
return sessionId.String(), session, nil
|
||||||
}
|
}
|
||||||
@@ -547,10 +556,10 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
||||||
session, err := auth.GetOAuthPendingSession(sessionId)
|
session, ok := auth.caches.oauth.Get(sessionId)
|
||||||
|
|
||||||
if err != nil {
|
if !ok {
|
||||||
return nil, err
|
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := (*session.Service).GetToken(code, session.Verifier)
|
token, err := (*session.Service).GetToken(code, session.Verifier)
|
||||||
@@ -559,9 +568,14 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
|||||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.oauthMutex.Lock()
|
|
||||||
session.Token = token
|
session.Token = token
|
||||||
auth.oauthMutex.Unlock()
|
|
||||||
|
// ttl 0 means keep current expiration
|
||||||
|
ok = auth.caches.oauth.Update(sessionId, session, 0)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("failed to update oauth session with token: %s", sessionId)
|
||||||
|
}
|
||||||
|
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
@@ -597,123 +611,39 @@ func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
||||||
auth.oauthMutex.Lock()
|
auth.caches.oauth.Delete(sessionId)
|
||||||
delete(auth.oauthPendingSessions, sessionId)
|
|
||||||
auth.oauthMutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) {
|
|
||||||
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
|
|
||||||
|
|
||||||
ticker := time.NewTicker(30 * time.Minute)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
auth.log.App.Debug().Msg("Running OAuth session cleanup")
|
|
||||||
|
|
||||||
auth.oauthMutex.Lock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
for sessionId, session := range auth.oauthPendingSessions {
|
|
||||||
if now.After(session.ExpiresAt) {
|
|
||||||
delete(auth.oauthPendingSessions, sessionId)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auth.oauthMutex.Unlock()
|
|
||||||
auth.log.App.Debug().Msg("OAuth session cleanup completed")
|
|
||||||
case <-ctx.Done():
|
|
||||||
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
|
func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
|
||||||
auth.ensureOAuthSessionLimit()
|
session, exists := auth.caches.oauth.Get(sessionId)
|
||||||
|
|
||||||
auth.oauthMutex.RLock()
|
|
||||||
session, exists := auth.oauthPendingSessions[sessionId]
|
|
||||||
auth.oauthMutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
if !exists {
|
||||||
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
|
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().After(session.ExpiresAt) {
|
return &session, nil
|
||||||
auth.oauthMutex.Lock()
|
|
||||||
delete(auth.oauthPendingSessions, sessionId)
|
|
||||||
auth.oauthMutex.Unlock()
|
|
||||||
return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId)
|
|
||||||
}
|
|
||||||
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) ensureOAuthSessionLimit() {
|
|
||||||
auth.oauthMutex.Lock()
|
|
||||||
defer auth.oauthMutex.Unlock()
|
|
||||||
|
|
||||||
if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type entry struct {
|
|
||||||
id string
|
|
||||||
expiresAt int64
|
|
||||||
}
|
|
||||||
|
|
||||||
entries := make([]entry, 0, len(auth.oauthPendingSessions))
|
|
||||||
for id, session := range auth.oauthPendingSessions {
|
|
||||||
entries = append(entries, entry{id, session.ExpiresAt.Unix()})
|
|
||||||
}
|
|
||||||
|
|
||||||
slices.SortFunc(entries, func(a, b entry) int {
|
|
||||||
if a.expiresAt < b.expiresAt {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
if a.expiresAt > b.expiresAt {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, e := range entries[:OAuthCleanupCount] {
|
|
||||||
delete(auth.oauthPendingSessions, e.id)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) lockdownMode() {
|
func (auth *AuthService) lockdownMode() {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
auth.lockdown.mu.Lock()
|
||||||
|
|
||||||
auth.loginMutex.Lock()
|
if auth.lockdown.active {
|
||||||
|
auth.lockdown.mu.Unlock()
|
||||||
if auth.lockdown != nil && auth.lockdown.Active {
|
|
||||||
auth.loginMutex.Unlock()
|
|
||||||
cancel()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.lockdownCtx = ctx
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
auth.lockdownCancelFunc = cancel
|
|
||||||
|
|
||||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||||
|
|
||||||
auth.lockdown = &Lockdown{
|
auth.lockdown.active = true
|
||||||
Active: true,
|
auth.lockdown.ctx = ctx
|
||||||
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
|
auth.lockdown.cancelFunc = cancel
|
||||||
}
|
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
||||||
|
|
||||||
// At this point all login attemps will also expire so,
|
timer := time.NewTimer(time.Until(auth.lockdown.until))
|
||||||
// we might as well clear them to free up memory
|
|
||||||
auth.loginAttempts = make(map[string]*LoginAttempt)
|
|
||||||
|
|
||||||
timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
|
auth.lockdown.mu.Unlock()
|
||||||
|
|
||||||
auth.loginMutex.Unlock()
|
|
||||||
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
defer timer.Stop()
|
defer timer.Stop()
|
||||||
@@ -723,24 +653,33 @@ func (auth *AuthService) lockdownMode() {
|
|||||||
// Timer expired, end lockdown
|
// Timer expired, end lockdown
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
// Context cancelled, end lockdown
|
// Context cancelled, end lockdown
|
||||||
case <-auth.context.Done():
|
case <-auth.ctx.Done():
|
||||||
// Service is shutting down, end lockdown
|
// Service is shutting down, end lockdown
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.loginMutex.Lock()
|
auth.lockdown.mu.Lock()
|
||||||
|
|
||||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||||
|
|
||||||
auth.lockdown = nil
|
auth.lockdown.active = false
|
||||||
auth.loginMutex.Unlock()
|
auth.lockdown.until = time.Time{}
|
||||||
|
auth.lockdown.ctx = nil
|
||||||
|
auth.lockdown.cancelFunc = nil
|
||||||
|
|
||||||
|
auth.lockdown.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function only used for testing - do not use in prod!
|
func (auth *AuthService) IsInLockdown() (bool, int) {
|
||||||
func (auth *AuthService) ClearRateLimitsTestingOnly() {
|
auth.lockdown.mu.RLock()
|
||||||
auth.loginMutex.Lock()
|
defer auth.lockdown.mu.RUnlock()
|
||||||
auth.loginAttempts = make(map[string]*LoginAttempt)
|
if auth.lockdown.active {
|
||||||
if auth.lockdown != nil {
|
remaining := int(time.Until(auth.lockdown.until).Seconds())
|
||||||
auth.lockdownCancelFunc()
|
return true, remaining
|
||||||
}
|
}
|
||||||
auth.loginMutex.Unlock()
|
return false, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// mostly a testing function, not useful for anything else
|
||||||
|
func (auth *AuthService) ClearLoginAttempts() {
|
||||||
|
auth.caches.login.Clear()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,197 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CacheStoreActions[T any] struct {
|
||||||
|
Set func(key string, value T, ttl time.Duration)
|
||||||
|
Get func(key string) (T, bool)
|
||||||
|
Delete func(key string)
|
||||||
|
Update func(key string, value T, ttl time.Duration) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheEntry[T any] struct {
|
||||||
|
value T
|
||||||
|
expiresAt *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type CacheStore[T any] struct {
|
||||||
|
cache map[string]cacheEntry[T]
|
||||||
|
order []string
|
||||||
|
mu sync.RWMutex
|
||||||
|
maxSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCacheStore[T any](maxSize int) *CacheStore[T] {
|
||||||
|
return &CacheStore[T]{
|
||||||
|
cache: make(map[string]cacheEntry[T]),
|
||||||
|
order: make([]string, 0),
|
||||||
|
maxSize: maxSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// With lock allows performing multiple operations on the cache store atomically.
|
||||||
|
// The provided mutate function receives a set of actions (Set, Get, Delete) that
|
||||||
|
// can be used to manipulate the cache store within the locked context.
|
||||||
|
func (cs *CacheStore[T]) WithLock(mutate func(actions CacheStoreActions[T])) {
|
||||||
|
cs.mu.Lock()
|
||||||
|
defer cs.mu.Unlock()
|
||||||
|
actions := CacheStoreActions[T]{
|
||||||
|
Set: cs.setCallback,
|
||||||
|
Get: cs.getCallback,
|
||||||
|
Delete: cs.deleteCallback,
|
||||||
|
Update: cs.updateCallback,
|
||||||
|
}
|
||||||
|
mutate(actions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) updateCallback(key string, value T, ttl time.Duration) bool {
|
||||||
|
if currentEntry, exists := cs.cache[key]; exists {
|
||||||
|
if currentEntry.expiresAt != nil && time.Now().After(*currentEntry.expiresAt) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := cacheEntry[T]{
|
||||||
|
value: value,
|
||||||
|
expiresAt: currentEntry.expiresAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ttl > 0 {
|
||||||
|
expiration := time.Now().Add(ttl)
|
||||||
|
entry.expiresAt = &expiration
|
||||||
|
}
|
||||||
|
|
||||||
|
cs.cache[key] = entry
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) Update(key string, value T, ttl time.Duration) bool {
|
||||||
|
cs.mu.Lock()
|
||||||
|
defer cs.mu.Unlock()
|
||||||
|
return cs.updateCallback(key, value, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) setCallback(key string, value T, ttl time.Duration) {
|
||||||
|
if cs.maxSize > 0 {
|
||||||
|
if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize {
|
||||||
|
cs.evictOne()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var expiresAt *time.Time
|
||||||
|
|
||||||
|
if ttl > 0 {
|
||||||
|
expiration := time.Now().Add(ttl)
|
||||||
|
expiresAt = &expiration
|
||||||
|
}
|
||||||
|
|
||||||
|
cs.cache[key] = cacheEntry[T]{
|
||||||
|
value: value,
|
||||||
|
expiresAt: expiresAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Contains(cs.order, key) {
|
||||||
|
cs.order = append(cs.order, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) {
|
||||||
|
cs.mu.Lock()
|
||||||
|
defer cs.mu.Unlock()
|
||||||
|
cs.setCallback(key, value, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) getCallback(key string) (T, bool) {
|
||||||
|
entry, exists := cs.cache[key]
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
var zero T
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
|
||||||
|
var zero T
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.value, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) Get(key string) (T, bool) {
|
||||||
|
cs.mu.RLock()
|
||||||
|
defer cs.mu.RUnlock()
|
||||||
|
return cs.getCallback(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) deleteCallback(key string) {
|
||||||
|
delete(cs.cache, key)
|
||||||
|
keyIdx := slices.Index(cs.order, key)
|
||||||
|
if keyIdx != -1 {
|
||||||
|
cs.order = append(cs.order[:keyIdx], cs.order[keyIdx+1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) Delete(key string) {
|
||||||
|
cs.mu.Lock()
|
||||||
|
defer cs.mu.Unlock()
|
||||||
|
cs.deleteCallback(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) Sweep() {
|
||||||
|
cs.mu.Lock()
|
||||||
|
for key, entry := range cs.cache {
|
||||||
|
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
|
||||||
|
cs.deleteCallback(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cs.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) evictOne() bool {
|
||||||
|
now := time.Now()
|
||||||
|
var oldestKey string
|
||||||
|
var oldestExp *time.Time
|
||||||
|
|
||||||
|
for k, e := range cs.cache {
|
||||||
|
if e.expiresAt != nil && now.After(*e.expiresAt) {
|
||||||
|
cs.deleteCallback(k)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if e.expiresAt != nil && (oldestExp == nil || e.expiresAt.Before(*oldestExp)) {
|
||||||
|
oldestKey, oldestExp = k, e.expiresAt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we found an oldest key, evict it else we delete the first key in the order list
|
||||||
|
if oldestKey != "" {
|
||||||
|
cs.deleteCallback(oldestKey)
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
if len(cs.order) > 0 {
|
||||||
|
cs.deleteCallback(cs.order[0])
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) Size() int {
|
||||||
|
cs.mu.RLock()
|
||||||
|
defer cs.mu.RUnlock()
|
||||||
|
return len(cs.cache)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cs *CacheStore[T]) Clear() {
|
||||||
|
cs.mu.Lock()
|
||||||
|
defer cs.mu.Unlock()
|
||||||
|
cs.cache = make(map[string]cacheEntry[T])
|
||||||
|
cs.order = make([]string, 0)
|
||||||
|
}
|
||||||
@@ -0,0 +1,383 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCacheStoreGet(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setup func(cs *CacheStore[string])
|
||||||
|
wantValue string
|
||||||
|
wantOk bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns a stored value",
|
||||||
|
setup: func(cs *CacheStore[string]) { cs.Set("key", "value", 0) },
|
||||||
|
wantValue: "value",
|
||||||
|
wantOk: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reports a missing key",
|
||||||
|
setup: func(cs *CacheStore[string]) {},
|
||||||
|
wantOk: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns the latest value after an overwrite",
|
||||||
|
setup: func(cs *CacheStore[string]) {
|
||||||
|
cs.Set("key", "first", 0)
|
||||||
|
cs.Set("key", "second", 0)
|
||||||
|
},
|
||||||
|
wantValue: "second",
|
||||||
|
wantOk: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns a non-expired entry",
|
||||||
|
setup: func(cs *CacheStore[string]) { cs.Set("key", "value", time.Minute) },
|
||||||
|
wantValue: "value",
|
||||||
|
wantOk: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "treats an expired entry as missing",
|
||||||
|
setup: func(cs *CacheStore[string]) {
|
||||||
|
cs.Set("key", "value", 10*time.Millisecond)
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
},
|
||||||
|
wantOk: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cs := NewCacheStore[string](0)
|
||||||
|
tt.setup(cs)
|
||||||
|
|
||||||
|
value, ok := cs.Get("key")
|
||||||
|
assert.Equal(t, tt.wantOk, ok)
|
||||||
|
if tt.wantOk {
|
||||||
|
assert.Equal(t, tt.wantValue, value)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheStoreUpdate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setup func(cs *CacheStore[string])
|
||||||
|
ttl time.Duration
|
||||||
|
wantOk bool
|
||||||
|
afterWait time.Duration
|
||||||
|
wantPresent bool
|
||||||
|
wantValue string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "updates an existing entry",
|
||||||
|
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 0) },
|
||||||
|
ttl: 0,
|
||||||
|
wantOk: true,
|
||||||
|
wantPresent: true,
|
||||||
|
wantValue: "new",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "does not create a missing entry",
|
||||||
|
setup: func(cs *CacheStore[string]) {},
|
||||||
|
ttl: 0,
|
||||||
|
wantOk: false,
|
||||||
|
wantPresent: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "preserves the existing expiry when ttl is zero",
|
||||||
|
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 30*time.Millisecond) },
|
||||||
|
ttl: 0,
|
||||||
|
wantOk: true,
|
||||||
|
afterWait: 40 * time.Millisecond,
|
||||||
|
wantPresent: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "refreshes the expiry when ttl is provided",
|
||||||
|
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 10*time.Millisecond) },
|
||||||
|
ttl: time.Minute,
|
||||||
|
wantOk: true,
|
||||||
|
afterWait: 20 * time.Millisecond,
|
||||||
|
wantPresent: true,
|
||||||
|
wantValue: "new",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "does not update an expired entry",
|
||||||
|
setup: func(cs *CacheStore[string]) {
|
||||||
|
cs.Set("key", "old", 10*time.Millisecond)
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
},
|
||||||
|
ttl: time.Minute,
|
||||||
|
wantOk: false,
|
||||||
|
wantPresent: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cs := NewCacheStore[string](0)
|
||||||
|
tt.setup(cs)
|
||||||
|
|
||||||
|
ok := cs.Update("key", "new", tt.ttl)
|
||||||
|
assert.Equal(t, tt.wantOk, ok)
|
||||||
|
|
||||||
|
time.Sleep(tt.afterWait)
|
||||||
|
|
||||||
|
value, present := cs.Get("key")
|
||||||
|
assert.Equal(t, tt.wantPresent, present)
|
||||||
|
if tt.wantPresent {
|
||||||
|
assert.Equal(t, tt.wantValue, value)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheStoreDelete(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setup func(cs *CacheStore[string])
|
||||||
|
key string
|
||||||
|
wantSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "removes an existing key",
|
||||||
|
setup: func(cs *CacheStore[string]) {
|
||||||
|
cs.Set("a", "1", 0)
|
||||||
|
cs.Set("b", "2", 0)
|
||||||
|
},
|
||||||
|
key: "a",
|
||||||
|
wantSize: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "is a no-op for a missing key",
|
||||||
|
setup: func(cs *CacheStore[string]) { cs.Set("a", "1", 0) },
|
||||||
|
key: "missing",
|
||||||
|
wantSize: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cs := NewCacheStore[string](0)
|
||||||
|
tt.setup(cs)
|
||||||
|
|
||||||
|
cs.Delete(tt.key)
|
||||||
|
|
||||||
|
_, ok := cs.Get(tt.key)
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.Equal(t, tt.wantSize, cs.Size())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheStoreSweep(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setup func(cs *CacheStore[string])
|
||||||
|
present []string
|
||||||
|
absent []string
|
||||||
|
wantSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "removes expired entries and keeps the rest",
|
||||||
|
setup: func(cs *CacheStore[string]) {
|
||||||
|
cs.Set("permanent", "value", 0)
|
||||||
|
cs.Set("expired", "value", 10*time.Millisecond)
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
},
|
||||||
|
present: []string{"permanent"},
|
||||||
|
absent: []string{"expired"},
|
||||||
|
wantSize: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "keeps all live entries",
|
||||||
|
setup: func(cs *CacheStore[string]) {
|
||||||
|
cs.Set("a", "value", 0)
|
||||||
|
cs.Set("b", "value", time.Minute)
|
||||||
|
},
|
||||||
|
present: []string{"a", "b"},
|
||||||
|
wantSize: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cs := NewCacheStore[string](0)
|
||||||
|
tt.setup(cs)
|
||||||
|
|
||||||
|
cs.Sweep()
|
||||||
|
|
||||||
|
for _, key := range tt.present {
|
||||||
|
_, ok := cs.Get(key)
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
for _, key := range tt.absent {
|
||||||
|
_, ok := cs.Get(key)
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.wantSize, cs.Size())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheStoreEviction(t *testing.T) {
|
||||||
|
// Every case uses a cache with maxSize 2; the final Set in setup is the
|
||||||
|
// insertion that overflows the cache and triggers an eviction.
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setup func(cs *CacheStore[string])
|
||||||
|
present []string
|
||||||
|
absent []string
|
||||||
|
wantSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "evicts an already expired entry first",
|
||||||
|
setup: func(cs *CacheStore[string]) {
|
||||||
|
cs.Set("expired", "value", 10*time.Millisecond)
|
||||||
|
cs.Set("fresh", "value", time.Minute)
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
cs.Set("new", "value", time.Minute)
|
||||||
|
},
|
||||||
|
present: []string{"fresh", "new"},
|
||||||
|
absent: []string{"expired"},
|
||||||
|
wantSize: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "evicts the entry expiring soonest",
|
||||||
|
setup: func(cs *CacheStore[string]) {
|
||||||
|
cs.Set("soon", "value", 50*time.Millisecond)
|
||||||
|
cs.Set("later", "value", time.Hour)
|
||||||
|
cs.Set("new", "value", time.Hour)
|
||||||
|
},
|
||||||
|
present: []string{"later", "new"},
|
||||||
|
absent: []string{"soon"},
|
||||||
|
wantSize: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "evicts the oldest inserted entry when none have a ttl",
|
||||||
|
setup: func(cs *CacheStore[string]) {
|
||||||
|
cs.Set("first", "value", 0)
|
||||||
|
cs.Set("second", "value", 0)
|
||||||
|
cs.Set("third", "value", 0)
|
||||||
|
},
|
||||||
|
present: []string{"second", "third"},
|
||||||
|
absent: []string{"first"},
|
||||||
|
wantSize: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "overwriting an existing key does not trigger eviction",
|
||||||
|
setup: func(cs *CacheStore[string]) {
|
||||||
|
cs.Set("a", "1", 0)
|
||||||
|
cs.Set("b", "2", 0)
|
||||||
|
cs.Set("a", "updated", 0)
|
||||||
|
},
|
||||||
|
present: []string{"a", "b"},
|
||||||
|
wantSize: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cs := NewCacheStore[string](2)
|
||||||
|
tt.setup(cs)
|
||||||
|
|
||||||
|
for _, key := range tt.present {
|
||||||
|
_, ok := cs.Get(key)
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
for _, key := range tt.absent {
|
||||||
|
_, ok := cs.Get(key)
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.wantSize, cs.Size())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheStoreSizeAndClear(t *testing.T) {
|
||||||
|
cs := NewCacheStore[string](0)
|
||||||
|
assert.Equal(t, 0, cs.Size())
|
||||||
|
|
||||||
|
cs.Set("a", "1", 0)
|
||||||
|
cs.Set("b", "2", 0)
|
||||||
|
assert.Equal(t, 2, cs.Size())
|
||||||
|
|
||||||
|
cs.Clear()
|
||||||
|
assert.Equal(t, 0, cs.Size())
|
||||||
|
|
||||||
|
_, ok := cs.Get("a")
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheStoreWithLock(t *testing.T) {
|
||||||
|
cs := NewCacheStore[int](0)
|
||||||
|
cs.Set("counter", 1, 0)
|
||||||
|
|
||||||
|
// All four actions run atomically under a single lock.
|
||||||
|
cs.WithLock(func(actions CacheStoreActions[int]) {
|
||||||
|
current, ok := actions.Get("counter")
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
actions.Set("counter", current+1, 0)
|
||||||
|
actions.Set("other", 100, 0)
|
||||||
|
actions.Delete("counter")
|
||||||
|
|
||||||
|
updated := actions.Update("other", 200, 0)
|
||||||
|
assert.True(t, updated)
|
||||||
|
})
|
||||||
|
|
||||||
|
_, ok := cs.Get("counter")
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
value, ok := cs.Get("other")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, 200, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCacheStoreConcurrency exercises every locking path concurrently so the
|
||||||
|
// race detector (go test -race) can flag unsynchronised access.
|
||||||
|
func TestCacheStoreConcurrency(t *testing.T) {
|
||||||
|
cs := NewCacheStore[int](64)
|
||||||
|
|
||||||
|
const goroutines = 16
|
||||||
|
const iterations = 200
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(goroutines)
|
||||||
|
|
||||||
|
for g := range goroutines {
|
||||||
|
go func(g int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for i := range iterations {
|
||||||
|
key := strconv.Itoa((g*iterations + i) % 32)
|
||||||
|
switch i % 6 {
|
||||||
|
case 0:
|
||||||
|
cs.Set(key, i, time.Minute)
|
||||||
|
case 1:
|
||||||
|
cs.Get(key)
|
||||||
|
case 2:
|
||||||
|
cs.Update(key, i, time.Minute)
|
||||||
|
case 3:
|
||||||
|
cs.Delete(key)
|
||||||
|
case 4:
|
||||||
|
cs.Size()
|
||||||
|
case 5:
|
||||||
|
cs.WithLock(func(actions CacheStoreActions[int]) {
|
||||||
|
if v, ok := actions.Get(key); ok {
|
||||||
|
actions.Set(key, v+1, time.Minute)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(g)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user