refactor: rework logging and cancellation in services

This commit is contained in:
Stavros
2026-05-08 17:18:39 +03:00
parent 55b53c77bf
commit e214d6d8d4
12 changed files with 315 additions and 295 deletions
+87 -85
View File
@@ -14,7 +14,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices"
@@ -72,39 +72,40 @@ type Lockdown struct {
ActiveUntil time.Time
}
type AuthServiceConfig struct {
LocalUsers *[]model.LocalUser
OauthWhitelist []string
SessionExpiry int
SessionMaxLifetime int
SecureCookie bool
CookieDomain string
LoginTimeout int
LoginMaxRetries int
SessionCookieName string
IP model.IPConfig
LDAPGroupsCacheTTL int
SubdomainsEnabled bool
}
type AuthService struct {
config AuthServiceConfig
log *logger.Logger
config model.Config
runtime model.RuntimeConfig
context context.Context
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
lockdown *Lockdown
lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc
}
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
func NewAuthService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
context context.Context,
ldap *LdapService,
queries *repository.Queries,
oauthBroker *OAuthBrokerService,
) *AuthService {
return &AuthService{
log: log,
runtime: runtime,
context: context,
config: config,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
@@ -173,10 +174,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
}
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
if auth.config.LocalUsers == nil {
if auth.runtime.LocalUsers == nil {
return nil
}
for _, user := range *auth.config.LocalUsers {
for _, user := range auth.runtime.LocalUsers {
if user.Username == username {
return &user
}
@@ -209,7 +210,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
}
auth.ldapGroupsMutex.Unlock()
@@ -228,7 +229,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return true, remaining
}
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return false, 0
}
@@ -246,7 +247,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
}
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return
}
@@ -277,14 +278,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
attempt.FailedAttempts++
if attempt.FailedAttempts >= auth.config.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second)
tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts")
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
}
}
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
}
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
@@ -299,7 +300,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
if data.TotpPending {
expiry = 3600
} else {
expiry = auth.config.SessionExpiry
expiry = auth.config.Auth.SessionExpiry
}
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
@@ -325,13 +326,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
}
return &http.Cookie{
Name: auth.config.SessionCookieName,
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.SecureCookie,
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
@@ -348,8 +349,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
var refreshThreshold int64
if auth.config.SessionExpiry <= int(time.Hour.Seconds()) {
refreshThreshold = int64(auth.config.SessionExpiry / 2)
if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) {
refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2)
} else {
refreshThreshold = int64(time.Hour.Seconds())
}
@@ -378,13 +379,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
}
return &http.Cookie{
Name: auth.config.SessionCookieName,
Name: auth.runtime.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime),
Secure: auth.config.SecureCookie,
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
@@ -395,7 +396,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
err := auth.queries.DeleteSession(ctx, uuid)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
}
err = auth.queries.DeleteSession(ctx, uuid)
@@ -405,13 +406,13 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
}
return &http.Cookie{
Name: auth.config.SessionCookieName,
Name: auth.runtime.SessionCookieName,
Value: "",
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now(),
MaxAge: -1,
Secure: auth.config.SecureCookie,
Secure: auth.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
@@ -429,8 +430,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
currentTime := time.Now().Unix()
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) {
err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
return nil, fmt.Errorf("failed to delete expired session: %w", err)
@@ -451,7 +452,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
}
func (auth *AuthService) LocalAuthConfigured() bool {
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
return len(auth.runtime.LocalUsers) > 0
}
func (auth *AuthService) LDAPAuthConfigured() bool {
@@ -464,18 +465,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
}
if context.Provider == model.ProviderOAuth {
tlog.App.Debug().Msg("Checking OAuth whitelist")
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
}
if acls.Users.Block != "" {
tlog.App.Debug().Msg("Checking blocked users")
auth.log.App.Debug().Msg("Checking users block list")
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
return false
}
}
tlog.App.Debug().Msg("Checking users")
auth.log.App.Debug().Msg("Checking users allow list")
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
}
@@ -485,23 +486,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
}
if !context.IsOAuth() {
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
return false
}
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
return true
}
for _, userGroup := range context.OAuth.Groups {
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
auth.log.App.Debug().Msg("No groups matched")
return false
}
@@ -511,18 +512,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
}
if !context.IsLDAP() {
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return false
}
for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
auth.log.App.Debug().Msg("No groups matched")
return false
}
@@ -566,17 +567,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
}
// Merge the global and app IP filter
blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
for _, blocked := range blockedIps {
res, err := utils.FilterIP(blocked, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
continue
}
if res {
tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access")
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
return false
}
}
@@ -584,21 +585,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
for _, allowed := range allowedIPs {
res, err := utils.FilterIP(allowed, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
continue
}
if res {
tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access")
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
return true
}
}
if len(allowedIPs) > 0 {
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
return false
}
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
return true
}
@@ -610,16 +611,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
for _, bypassed := range acls.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip)
if err != nil {
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
continue
}
if res {
tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access")
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
return true
}
}
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
return false
}
@@ -726,18 +727,23 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop()
for range ticker.C {
auth.oauthMutex.Lock()
for {
select {
case <-ticker.C:
auth.oauthMutex.Lock()
now := time.Now()
now := time.Now()
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
}
}
}
auth.oauthMutex.Unlock()
auth.oauthMutex.Unlock()
case <-auth.context.Done():
return
}
}
}
@@ -806,11 +812,11 @@ func (auth *AuthService) lockdownMode() {
auth.loginMutex.Lock()
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown = &Lockdown{
Active: true,
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
}
// At this point all login attemps will also expire so,
@@ -827,11 +833,14 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown
case <-ctx.Done():
// Context cancelled, end lockdown
case <-auth.context.Done():
// Service is shutting down, end lockdown
}
auth.loginMutex.Lock()
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
auth.log.App.Info().Msg("Exiting lockdown mode")
auth.lockdown = nil
auth.loginMutex.Unlock()
}
@@ -845,10 +854,3 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() {
}
auth.loginMutex.Unlock()
}
func (auth *AuthService) getCookieDomain() string {
if auth.config.SubdomainsEnabled {
return "." + auth.config.CookieDomain
}
return auth.config.CookieDomain
}