Compare commits

...

3 Commits

Author SHA1 Message Date
dependabot[bot] 0138834ce7 chore(deps): bump the minor-patch group across 1 directory with 3 updates
Bumps the minor-patch group with 3 updates in the / directory: [golang.org/x/tools](https://github.com/golang/tools), [modernc.org/sqlite](https://gitlab.com/cznic/sqlite) and [tailscale.com](https://github.com/tailscale/tailscale).


Updates `golang.org/x/tools` from 0.44.0 to 0.45.0
- [Release notes](https://github.com/golang/tools/releases)
- [Commits](https://github.com/golang/tools/compare/v0.44.0...v0.45.0)

Updates `modernc.org/sqlite` from 1.50.1 to 1.51.0
- [Changelog](https://gitlab.com/cznic/sqlite/blob/master/CHANGELOG.md)
- [Commits](https://gitlab.com/cznic/sqlite/compare/v1.50.1...v1.51.0)

Updates `tailscale.com` from 1.98.3 to 1.98.5
- [Release notes](https://github.com/tailscale/tailscale/releases)
- [Commits](https://github.com/tailscale/tailscale/compare/v1.98.3...v1.98.5)

---
updated-dependencies:
- dependency-name: golang.org/x/tools
  dependency-version: 0.45.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-patch
- dependency-name: modernc.org/sqlite
  dependency-version: 1.51.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: minor-patch
- dependency-name: tailscale.com
  dependency-version: 1.98.4
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: minor-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-06-02 01:30:51 +00:00
Stavros dac844595d refactor: use new cache store in services (#912) 2026-05-31 18:55:06 +03:00
Stavros 940ba6dff7 fix: don't allow tagged devices in tailscale integration 2026-05-31 12:42:00 +03:00
11 changed files with 743 additions and 215 deletions
+9
View File
@@ -62,6 +62,15 @@ binary-linux-arm64:
test:
go test -v ./...
# Go vet
.PHONY: vet
vet:
go vet ./...
# Go race
test-race:
go test -race ./...
# Development
dev:
docker compose -f $(DEV_COMPOSE) up --force-recreate --pull=always --remove-orphans --build
+4 -4
View File
@@ -22,11 +22,11 @@ require (
github.com/weppos/publicsuffix-go v0.50.3
golang.org/x/crypto v0.52.0
golang.org/x/oauth2 v0.36.0
golang.org/x/tools v0.44.0
golang.org/x/tools v0.45.0
k8s.io/apimachinery v0.36.1
k8s.io/client-go v0.36.1
modernc.org/sqlite v1.50.1
tailscale.com v1.98.3
modernc.org/sqlite v1.51.0
tailscale.com v1.98.5
)
require (
@@ -156,7 +156,7 @@ require (
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.35.0 // indirect
golang.org/x/mod v0.36.0 // indirect
golang.org/x/net v0.54.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.45.0 // indirect
+8 -8
View File
@@ -501,8 +501,8 @@ golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk=
golang.org/x/image v0.27.0 h1:C8gA4oWU/tKkdCfYT6T2u4faJu3MeNS5O8UPWlPF61w=
golang.org/x/image v0.27.0/go.mod h1:xbdrClrAUway1MUTEZDq9mz/UpRwYAkFFNUslZtcB+g=
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
@@ -520,8 +520,8 @@ golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c=
golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI=
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
@@ -587,8 +587,8 @@ modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
modernc.org/opt v0.2.0/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.50.1 h1:l+cQvn0sd0zJJtfygGHuQJ5AjlrwXmWPw4KP3ZMwr9w=
modernc.org/sqlite v1.50.1/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
modernc.org/sqlite v1.51.0 h1:aH/MMSoayAIhozZ7uJbVTT9QO/VhzBf0J9tymmmuC/U=
modernc.org/sqlite v1.51.0/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
@@ -605,5 +605,5 @@ sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs=
sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4=
software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k=
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
tailscale.com v1.98.3 h1:caAbG4UfkKfKPE6b1fj5t4ep5qrwEis5AJu91ruvePw=
tailscale.com v1.98.3/go.mod h1:U23ZwbZlKJMNU7CScy+lCVVlece/S5n09q0nyudncBI=
tailscale.com v1.98.5 h1:2XSwZqp1UnW4aavT8qw3Thh/SJu3YssRG9tNnJNIx04=
tailscale.com v1.98.5/go.mod h1:U23ZwbZlKJMNU7CScy+lCVVlece/S5n09q0nyudncBI=
+1 -1
View File
@@ -422,7 +422,7 @@ func TestUserController(t *testing.T) {
beforeEach := func() {
// Clear failed login attempts before each test
authService.ClearRateLimitsTestingOnly()
authService.ClearLoginAttempts()
}
for _, test := range tests {
@@ -326,11 +326,6 @@ func (m *ContextMiddleware) tailscaleWhois(ctx context.Context, ip string) (*mod
Name: whois.DisplayName,
},
UserID: whois.UserID,
Tags: whois.Tags,
}
if !strings.ContainsAny(uctx.Email, "@") {
uctx.Email = utils.CompileUserEmail(uctx.Email+"-tailscale", m.runtime.CookieDomain)
}
return &uctx, nil
@@ -263,7 +263,7 @@ func TestContextMiddleware(t *testing.T) {
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
for _, test := range tests {
authService.ClearRateLimitsTestingOnly()
authService.ClearLoginAttempts()
t.Run(test.description, func(t *testing.T) {
gin.SetMode(gin.TestMode)
-2
View File
@@ -59,8 +59,6 @@ type LDAPContext struct {
type TailscaleContext struct {
BaseContext
UserID string
// for future use
Tags []string
}
func (c *UserContext) IsAuthenticated() bool {
+130 -191
View File
@@ -15,8 +15,6 @@ import (
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
@@ -54,27 +52,17 @@ type OAuthPendingSession struct {
CallbackParams OAuthURLParams
}
type LdapGroupsCache struct {
Groups []string
Expires time.Time
}
type LoginAttempt struct {
FailedAttempts int
LastAttempt time.Time
LockedUntil time.Time
}
type Lockdown struct {
Active bool
ActiveUntil time.Time
}
type AuthService struct {
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
context context.Context
ctx context.Context
ldap *LdapService
queries repository.Store
@@ -82,15 +70,19 @@ type AuthService struct {
tailscale *TailscaleService
policyEngine *PolicyEngine
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
lockdown *Lockdown
lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc
lockdown struct {
active bool
until time.Time
ctx context.Context
cancelFunc context.CancelFunc
mu sync.RWMutex
}
caches struct {
login *CacheStore[LoginAttempt]
oauth *CacheStore[OAuthPendingSession]
ldap *CacheStore[[]string]
}
}
func NewAuthService(
@@ -106,21 +98,41 @@ func NewAuthService(
policy *PolicyEngine,
) *AuthService {
service := &AuthService{
log: log,
runtime: runtime,
context: ctx,
config: config,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
oauthPendingSessions: make(map[string]*OAuthPendingSession),
ldap: ldap,
queries: queries,
oauthBroker: oauthBroker,
tailscale: tailscale,
policyEngine: policy,
log: log,
runtime: runtime,
ctx: ctx,
config: config,
ldap: ldap,
queries: queries,
oauthBroker: oauthBroker,
tailscale: tailscale,
policyEngine: policy,
}
dg.Go(service.cleanupOAuthSessions, ding.RingMinor)
// caches setup
oauthCache := NewCacheStore[OAuthPendingSession](256)
loginCache := NewCacheStore[LoginAttempt](1024)
ldapCache := NewCacheStore[[]string](1024)
service.caches.oauth = oauthCache
service.caches.login = loginCache
service.caches.ldap = ldapCache
dg.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
service.caches.oauth.Sweep()
service.caches.login.Sweep()
service.caches.ldap.Sweep()
case <-ctx.Done():
return
}
}
}, ding.RingMinor)
return service
}
@@ -195,14 +207,12 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
return nil, errors.New("ldap service not configured")
}
auth.ldapGroupsMutex.RLock()
entry, exists := auth.ldapGroupsCache[userDN]
auth.ldapGroupsMutex.RUnlock()
entry, exists := auth.caches.ldap.Get(userDN)
if exists && time.Now().Before(entry.Expires) {
if exists {
return &model.LDAPUser{
DN: userDN,
Groups: entry.Groups,
Groups: entry,
}, nil
}
@@ -212,12 +222,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
return nil, fmt.Errorf("failed to get ldap groups: %w", err)
}
auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
}
auth.ldapGroupsMutex.Unlock()
auth.caches.ldap.Set(userDN, groups, time.Duration(auth.config.LDAP.GroupCacheTTL)*time.Second)
return &model.LDAPUser{
DN: userDN,
@@ -226,11 +231,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
}
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
auth.loginMutex.RLock()
defer auth.loginMutex.RUnlock()
if auth.lockdown != nil && auth.lockdown.Active {
remaining := int(time.Until(auth.lockdown.ActiveUntil).Seconds())
if locked, remaining := auth.IsInLockdown(); locked {
return true, remaining
}
@@ -238,7 +239,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return false, 0
}
attempt, exists := auth.loginAttempts[identifier]
attempt, exists := auth.caches.login.Get(identifier)
if !exists {
return false, 0
}
@@ -256,37 +257,49 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
return
}
auth.loginMutex.Lock()
defer auth.loginMutex.Unlock()
if len(auth.loginAttempts) >= MaxLoginAttemptRecords {
if auth.lockdown != nil && auth.lockdown.Active {
if auth.caches.login.Size() >= MaxLoginAttemptRecords {
if locked, _ := auth.IsInLockdown(); locked {
return
}
go auth.lockdownMode()
return
}
attempt, exists := auth.loginAttempts[identifier]
if !exists {
attempt = &LoginAttempt{}
auth.loginAttempts[identifier] = attempt
}
auth.caches.login.WithLock(func(actions CacheStoreActions[LoginAttempt]) {
entry, ok := actions.Get(identifier)
attempt.LastAttempt = time.Now()
if !ok {
attempt := LoginAttempt{
LastAttempt: time.Now(),
}
if !success {
attempt.FailedAttempts = 1
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
}
// match current tinyauth behavior which doesn't expire rate limits
actions.Set(identifier, attempt, 0)
return
}
if success {
attempt.FailedAttempts = 0
attempt.LockedUntil = time.Time{} // Reset lock time
return
}
entry.LastAttempt = time.Now()
attempt.FailedAttempts++
if success {
entry.FailedAttempts = 0
entry.LockedUntil = time.Time{}
} else {
entry.FailedAttempts++
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
if entry.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
entry.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", entry.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
}
actions.Set(identifier, entry, 0)
})
}
// We could also directly access the policyEngine.effectToAccess but
@@ -504,8 +517,6 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
}
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit()
service, ok := auth.oauthBroker.GetService(serviceName)
if !ok {
@@ -529,9 +540,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
CallbackParams: params,
}
auth.oauthMutex.Lock()
auth.oauthPendingSessions[sessionId.String()] = &session
auth.oauthMutex.Unlock()
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
return sessionId.String(), session, nil
}
@@ -547,10 +556,10 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
}
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
session, err := auth.GetOAuthPendingSession(sessionId)
session, ok := auth.caches.oauth.Get(sessionId)
if err != nil {
return nil, err
if !ok {
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
}
token, err := (*session.Service).GetToken(code, session.Verifier)
@@ -559,9 +568,14 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
auth.oauthMutex.Lock()
session.Token = token
auth.oauthMutex.Unlock()
// ttl 0 means keep current expiration
ok = auth.caches.oauth.Update(sessionId, session, 0)
if !ok {
return nil, fmt.Errorf("failed to update oauth session with token: %s", sessionId)
}
return token, nil
}
@@ -597,123 +611,39 @@ func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, er
}
func (auth *AuthService) EndOAuthSession(sessionId string) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
}
func (auth *AuthService) cleanupOAuthSessions(ctx context.Context) {
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
auth.log.App.Debug().Msg("Running OAuth session cleanup")
auth.oauthMutex.Lock()
now := time.Now()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
}
auth.oauthMutex.Unlock()
auth.log.App.Debug().Msg("OAuth session cleanup completed")
case <-ctx.Done():
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
return
}
}
auth.caches.oauth.Delete(sessionId)
}
func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
auth.ensureOAuthSessionLimit()
auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
session, exists := auth.caches.oauth.Get(sessionId)
if !exists {
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
}
if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId)
}
return session, nil
}
func (auth *AuthService) ensureOAuthSessionLimit() {
auth.oauthMutex.Lock()
defer auth.oauthMutex.Unlock()
if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions {
return
}
type entry struct {
id string
expiresAt int64
}
entries := make([]entry, 0, len(auth.oauthPendingSessions))
for id, session := range auth.oauthPendingSessions {
entries = append(entries, entry{id, session.ExpiresAt.Unix()})
}
slices.SortFunc(entries, func(a, b entry) int {
if a.expiresAt < b.expiresAt {
return -1
}
if a.expiresAt > b.expiresAt {
return 1
}
return 0
})
for _, e := range entries[:OAuthCleanupCount] {
delete(auth.oauthPendingSessions, e.id)
}
return &session, nil
}
func (auth *AuthService) lockdownMode() {
ctx, cancel := context.WithCancel(context.Background())
auth.lockdown.mu.Lock()
auth.loginMutex.Lock()
if auth.lockdown != nil && auth.lockdown.Active {
auth.loginMutex.Unlock()
cancel()
if auth.lockdown.active {
auth.lockdown.mu.Unlock()
return
}
auth.lockdownCtx = ctx
auth.lockdownCancelFunc = cancel
ctx, cancel := context.WithCancel(context.Background())
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown = &Lockdown{
Active: true,
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
}
auth.lockdown.active = true
auth.lockdown.ctx = ctx
auth.lockdown.cancelFunc = cancel
auth.lockdown.until = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
// At this point all login attemps will also expire so,
// we might as well clear them to free up memory
auth.loginAttempts = make(map[string]*LoginAttempt)
timer := time.NewTimer(time.Until(auth.lockdown.until))
timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
auth.loginMutex.Unlock()
auth.lockdown.mu.Unlock()
defer cancel()
defer timer.Stop()
@@ -723,24 +653,33 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown
case <-ctx.Done():
// Context cancelled, end lockdown
case <-auth.context.Done():
case <-auth.ctx.Done():
// Service is shutting down, end lockdown
}
auth.loginMutex.Lock()
auth.lockdown.mu.Lock()
auth.log.App.Info().Msg("Exiting lockdown mode")
auth.lockdown = nil
auth.loginMutex.Unlock()
auth.lockdown.active = false
auth.lockdown.until = time.Time{}
auth.lockdown.ctx = nil
auth.lockdown.cancelFunc = nil
auth.lockdown.mu.Unlock()
}
// Function only used for testing - do not use in prod!
func (auth *AuthService) ClearRateLimitsTestingOnly() {
auth.loginMutex.Lock()
auth.loginAttempts = make(map[string]*LoginAttempt)
if auth.lockdown != nil {
auth.lockdownCancelFunc()
func (auth *AuthService) IsInLockdown() (bool, int) {
auth.lockdown.mu.RLock()
defer auth.lockdown.mu.RUnlock()
if auth.lockdown.active {
remaining := int(time.Until(auth.lockdown.until).Seconds())
return true, remaining
}
auth.loginMutex.Unlock()
return false, 0
}
// mostly a testing function, not useful for anything else
func (auth *AuthService) ClearLoginAttempts() {
auth.caches.login.Clear()
}
+197
View File
@@ -0,0 +1,197 @@
package service
import (
"slices"
"sync"
"time"
)
type CacheStoreActions[T any] struct {
Set func(key string, value T, ttl time.Duration)
Get func(key string) (T, bool)
Delete func(key string)
Update func(key string, value T, ttl time.Duration) bool
}
type cacheEntry[T any] struct {
value T
expiresAt *time.Time
}
type CacheStore[T any] struct {
cache map[string]cacheEntry[T]
order []string
mu sync.RWMutex
maxSize int
}
func NewCacheStore[T any](maxSize int) *CacheStore[T] {
return &CacheStore[T]{
cache: make(map[string]cacheEntry[T]),
order: make([]string, 0),
maxSize: maxSize,
}
}
// With lock allows performing multiple operations on the cache store atomically.
// The provided mutate function receives a set of actions (Set, Get, Delete) that
// can be used to manipulate the cache store within the locked context.
func (cs *CacheStore[T]) WithLock(mutate func(actions CacheStoreActions[T])) {
cs.mu.Lock()
defer cs.mu.Unlock()
actions := CacheStoreActions[T]{
Set: cs.setCallback,
Get: cs.getCallback,
Delete: cs.deleteCallback,
Update: cs.updateCallback,
}
mutate(actions)
}
func (cs *CacheStore[T]) updateCallback(key string, value T, ttl time.Duration) bool {
if currentEntry, exists := cs.cache[key]; exists {
if currentEntry.expiresAt != nil && time.Now().After(*currentEntry.expiresAt) {
return false
}
entry := cacheEntry[T]{
value: value,
expiresAt: currentEntry.expiresAt,
}
if ttl > 0 {
expiration := time.Now().Add(ttl)
entry.expiresAt = &expiration
}
cs.cache[key] = entry
return true
}
return false
}
func (cs *CacheStore[T]) Update(key string, value T, ttl time.Duration) bool {
cs.mu.Lock()
defer cs.mu.Unlock()
return cs.updateCallback(key, value, ttl)
}
func (cs *CacheStore[T]) setCallback(key string, value T, ttl time.Duration) {
if cs.maxSize > 0 {
if _, exists := cs.cache[key]; !exists && len(cs.cache) >= cs.maxSize {
cs.evictOne()
}
}
var expiresAt *time.Time
if ttl > 0 {
expiration := time.Now().Add(ttl)
expiresAt = &expiration
}
cs.cache[key] = cacheEntry[T]{
value: value,
expiresAt: expiresAt,
}
if !slices.Contains(cs.order, key) {
cs.order = append(cs.order, key)
}
}
func (cs *CacheStore[T]) Set(key string, value T, ttl time.Duration) {
cs.mu.Lock()
defer cs.mu.Unlock()
cs.setCallback(key, value, ttl)
}
func (cs *CacheStore[T]) getCallback(key string) (T, bool) {
entry, exists := cs.cache[key]
if !exists {
var zero T
return zero, false
}
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
var zero T
return zero, false
}
return entry.value, true
}
func (cs *CacheStore[T]) Get(key string) (T, bool) {
cs.mu.RLock()
defer cs.mu.RUnlock()
return cs.getCallback(key)
}
func (cs *CacheStore[T]) deleteCallback(key string) {
delete(cs.cache, key)
keyIdx := slices.Index(cs.order, key)
if keyIdx != -1 {
cs.order = append(cs.order[:keyIdx], cs.order[keyIdx+1:]...)
}
}
func (cs *CacheStore[T]) Delete(key string) {
cs.mu.Lock()
defer cs.mu.Unlock()
cs.deleteCallback(key)
}
func (cs *CacheStore[T]) Sweep() {
cs.mu.Lock()
for key, entry := range cs.cache {
if entry.expiresAt != nil && time.Now().After(*entry.expiresAt) {
cs.deleteCallback(key)
}
}
cs.mu.Unlock()
}
func (cs *CacheStore[T]) evictOne() bool {
now := time.Now()
var oldestKey string
var oldestExp *time.Time
for k, e := range cs.cache {
if e.expiresAt != nil && now.After(*e.expiresAt) {
cs.deleteCallback(k)
return true
}
if e.expiresAt != nil && (oldestExp == nil || e.expiresAt.Before(*oldestExp)) {
oldestKey, oldestExp = k, e.expiresAt
}
}
// If we found an oldest key, evict it else we delete the first key in the order list
if oldestKey != "" {
cs.deleteCallback(oldestKey)
return true
} else {
if len(cs.order) > 0 {
cs.deleteCallback(cs.order[0])
return true
}
}
return false
}
func (cs *CacheStore[T]) Size() int {
cs.mu.RLock()
defer cs.mu.RUnlock()
return len(cs.cache)
}
func (cs *CacheStore[T]) Clear() {
cs.mu.Lock()
defer cs.mu.Unlock()
cs.cache = make(map[string]cacheEntry[T])
cs.order = make([]string, 0)
}
+383
View File
@@ -0,0 +1,383 @@
package service
import (
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCacheStoreGet(t *testing.T) {
tests := []struct {
name string
setup func(cs *CacheStore[string])
wantValue string
wantOk bool
}{
{
name: "returns a stored value",
setup: func(cs *CacheStore[string]) { cs.Set("key", "value", 0) },
wantValue: "value",
wantOk: true,
},
{
name: "reports a missing key",
setup: func(cs *CacheStore[string]) {},
wantOk: false,
},
{
name: "returns the latest value after an overwrite",
setup: func(cs *CacheStore[string]) {
cs.Set("key", "first", 0)
cs.Set("key", "second", 0)
},
wantValue: "second",
wantOk: true,
},
{
name: "returns a non-expired entry",
setup: func(cs *CacheStore[string]) { cs.Set("key", "value", time.Minute) },
wantValue: "value",
wantOk: true,
},
{
name: "treats an expired entry as missing",
setup: func(cs *CacheStore[string]) {
cs.Set("key", "value", 10*time.Millisecond)
time.Sleep(20 * time.Millisecond)
},
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](0)
tt.setup(cs)
value, ok := cs.Get("key")
assert.Equal(t, tt.wantOk, ok)
if tt.wantOk {
assert.Equal(t, tt.wantValue, value)
}
})
}
}
func TestCacheStoreUpdate(t *testing.T) {
tests := []struct {
name string
setup func(cs *CacheStore[string])
ttl time.Duration
wantOk bool
afterWait time.Duration
wantPresent bool
wantValue string
}{
{
name: "updates an existing entry",
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 0) },
ttl: 0,
wantOk: true,
wantPresent: true,
wantValue: "new",
},
{
name: "does not create a missing entry",
setup: func(cs *CacheStore[string]) {},
ttl: 0,
wantOk: false,
wantPresent: false,
},
{
name: "preserves the existing expiry when ttl is zero",
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 30*time.Millisecond) },
ttl: 0,
wantOk: true,
afterWait: 40 * time.Millisecond,
wantPresent: false,
},
{
name: "refreshes the expiry when ttl is provided",
setup: func(cs *CacheStore[string]) { cs.Set("key", "old", 10*time.Millisecond) },
ttl: time.Minute,
wantOk: true,
afterWait: 20 * time.Millisecond,
wantPresent: true,
wantValue: "new",
},
{
name: "does not update an expired entry",
setup: func(cs *CacheStore[string]) {
cs.Set("key", "old", 10*time.Millisecond)
time.Sleep(20 * time.Millisecond)
},
ttl: time.Minute,
wantOk: false,
wantPresent: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](0)
tt.setup(cs)
ok := cs.Update("key", "new", tt.ttl)
assert.Equal(t, tt.wantOk, ok)
time.Sleep(tt.afterWait)
value, present := cs.Get("key")
assert.Equal(t, tt.wantPresent, present)
if tt.wantPresent {
assert.Equal(t, tt.wantValue, value)
}
})
}
}
func TestCacheStoreDelete(t *testing.T) {
tests := []struct {
name string
setup func(cs *CacheStore[string])
key string
wantSize int
}{
{
name: "removes an existing key",
setup: func(cs *CacheStore[string]) {
cs.Set("a", "1", 0)
cs.Set("b", "2", 0)
},
key: "a",
wantSize: 1,
},
{
name: "is a no-op for a missing key",
setup: func(cs *CacheStore[string]) { cs.Set("a", "1", 0) },
key: "missing",
wantSize: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](0)
tt.setup(cs)
cs.Delete(tt.key)
_, ok := cs.Get(tt.key)
assert.False(t, ok)
assert.Equal(t, tt.wantSize, cs.Size())
})
}
}
func TestCacheStoreSweep(t *testing.T) {
tests := []struct {
name string
setup func(cs *CacheStore[string])
present []string
absent []string
wantSize int
}{
{
name: "removes expired entries and keeps the rest",
setup: func(cs *CacheStore[string]) {
cs.Set("permanent", "value", 0)
cs.Set("expired", "value", 10*time.Millisecond)
time.Sleep(20 * time.Millisecond)
},
present: []string{"permanent"},
absent: []string{"expired"},
wantSize: 1,
},
{
name: "keeps all live entries",
setup: func(cs *CacheStore[string]) {
cs.Set("a", "value", 0)
cs.Set("b", "value", time.Minute)
},
present: []string{"a", "b"},
wantSize: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](0)
tt.setup(cs)
cs.Sweep()
for _, key := range tt.present {
_, ok := cs.Get(key)
assert.True(t, ok)
}
for _, key := range tt.absent {
_, ok := cs.Get(key)
assert.False(t, ok)
}
assert.Equal(t, tt.wantSize, cs.Size())
})
}
}
func TestCacheStoreEviction(t *testing.T) {
// Every case uses a cache with maxSize 2; the final Set in setup is the
// insertion that overflows the cache and triggers an eviction.
tests := []struct {
name string
setup func(cs *CacheStore[string])
present []string
absent []string
wantSize int
}{
{
name: "evicts an already expired entry first",
setup: func(cs *CacheStore[string]) {
cs.Set("expired", "value", 10*time.Millisecond)
cs.Set("fresh", "value", time.Minute)
time.Sleep(20 * time.Millisecond)
cs.Set("new", "value", time.Minute)
},
present: []string{"fresh", "new"},
absent: []string{"expired"},
wantSize: 2,
},
{
name: "evicts the entry expiring soonest",
setup: func(cs *CacheStore[string]) {
cs.Set("soon", "value", 50*time.Millisecond)
cs.Set("later", "value", time.Hour)
cs.Set("new", "value", time.Hour)
},
present: []string{"later", "new"},
absent: []string{"soon"},
wantSize: 2,
},
{
name: "evicts the oldest inserted entry when none have a ttl",
setup: func(cs *CacheStore[string]) {
cs.Set("first", "value", 0)
cs.Set("second", "value", 0)
cs.Set("third", "value", 0)
},
present: []string{"second", "third"},
absent: []string{"first"},
wantSize: 2,
},
{
name: "overwriting an existing key does not trigger eviction",
setup: func(cs *CacheStore[string]) {
cs.Set("a", "1", 0)
cs.Set("b", "2", 0)
cs.Set("a", "updated", 0)
},
present: []string{"a", "b"},
wantSize: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := NewCacheStore[string](2)
tt.setup(cs)
for _, key := range tt.present {
_, ok := cs.Get(key)
assert.True(t, ok)
}
for _, key := range tt.absent {
_, ok := cs.Get(key)
assert.False(t, ok)
}
assert.Equal(t, tt.wantSize, cs.Size())
})
}
}
func TestCacheStoreSizeAndClear(t *testing.T) {
cs := NewCacheStore[string](0)
assert.Equal(t, 0, cs.Size())
cs.Set("a", "1", 0)
cs.Set("b", "2", 0)
assert.Equal(t, 2, cs.Size())
cs.Clear()
assert.Equal(t, 0, cs.Size())
_, ok := cs.Get("a")
assert.False(t, ok)
}
func TestCacheStoreWithLock(t *testing.T) {
cs := NewCacheStore[int](0)
cs.Set("counter", 1, 0)
// All four actions run atomically under a single lock.
cs.WithLock(func(actions CacheStoreActions[int]) {
current, ok := actions.Get("counter")
assert.True(t, ok)
actions.Set("counter", current+1, 0)
actions.Set("other", 100, 0)
actions.Delete("counter")
updated := actions.Update("other", 200, 0)
assert.True(t, updated)
})
_, ok := cs.Get("counter")
assert.False(t, ok)
value, ok := cs.Get("other")
assert.True(t, ok)
assert.Equal(t, 200, value)
}
// TestCacheStoreConcurrency exercises every locking path concurrently so the
// race detector (go test -race) can flag unsynchronised access.
func TestCacheStoreConcurrency(t *testing.T) {
cs := NewCacheStore[int](64)
const goroutines = 16
const iterations = 200
var wg sync.WaitGroup
wg.Add(goroutines)
for g := range goroutines {
go func(g int) {
defer wg.Done()
for i := range iterations {
key := strconv.Itoa((g*iterations + i) % 32)
switch i % 6 {
case 0:
cs.Set(key, i, time.Minute)
case 1:
cs.Get(key)
case 2:
cs.Update(key, i, time.Minute)
case 3:
cs.Delete(key)
case 4:
cs.Size()
case 5:
cs.WithLock(func(actions CacheStoreActions[int]) {
if v, ok := actions.Get(key); ok {
actions.Set(key, v+1, time.Minute)
}
})
}
}
}(g)
}
wg.Wait()
}
+10 -3
View File
@@ -21,7 +21,6 @@ type TailscaleWhoisResponse struct {
LoginName string
DisplayName string
NodeName string
Tags []string
}
type TailscaleService struct {
@@ -115,14 +114,22 @@ func (ts *TailscaleService) Whois(ctx context.Context, addr string) (*TailscaleW
return nil, fmt.Errorf("failed to get client whois: %w", err)
}
if who.Node.IsTagged() {
ts.log.App.Debug().Msgf("Skipping whois for tagged node %s", who.Node.Name)
return nil, nil
}
uid := strings.TrimPrefix(who.UserProfile.ID.String(), "userid:")
res := TailscaleWhoisResponse{
UserID: who.UserProfile.ID.String(),
UserID: uid,
LoginName: who.UserProfile.LoginName,
DisplayName: who.UserProfile.DisplayName,
NodeName: strings.TrimSuffix(who.Node.Name, "."),
Tags: who.Node.Tags,
}
ts.log.App.Debug().Interface("res", res).Msg("tailscale")
return &res, nil
}