mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-09 05:48:11 +00:00
refactor: rework logging and cancellation in services
This commit is contained in:
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
|
||||
"slices"
|
||||
|
||||
@@ -72,39 +72,40 @@ type Lockdown struct {
|
||||
ActiveUntil time.Time
|
||||
}
|
||||
|
||||
type AuthServiceConfig struct {
|
||||
LocalUsers *[]model.LocalUser
|
||||
OauthWhitelist []string
|
||||
SessionExpiry int
|
||||
SessionMaxLifetime int
|
||||
SecureCookie bool
|
||||
CookieDomain string
|
||||
LoginTimeout int
|
||||
LoginMaxRetries int
|
||||
SessionCookieName string
|
||||
IP model.IPConfig
|
||||
LDAPGroupsCacheTTL int
|
||||
SubdomainsEnabled bool
|
||||
}
|
||||
|
||||
type AuthService struct {
|
||||
config AuthServiceConfig
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
context context.Context
|
||||
|
||||
ldap *LdapService
|
||||
queries *repository.Queries
|
||||
oauthBroker *OAuthBrokerService
|
||||
|
||||
loginAttempts map[string]*LoginAttempt
|
||||
ldapGroupsCache map[string]*LdapGroupsCache
|
||||
oauthPendingSessions map[string]*OAuthPendingSession
|
||||
oauthMutex sync.RWMutex
|
||||
loginMutex sync.RWMutex
|
||||
ldapGroupsMutex sync.RWMutex
|
||||
ldap *LdapService
|
||||
queries *repository.Queries
|
||||
oauthBroker *OAuthBrokerService
|
||||
lockdown *Lockdown
|
||||
lockdownCtx context.Context
|
||||
lockdownCancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
|
||||
func NewAuthService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
context context.Context,
|
||||
ldap *LdapService,
|
||||
queries *repository.Queries,
|
||||
oauthBroker *OAuthBrokerService,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
log: log,
|
||||
runtime: runtime,
|
||||
context: context,
|
||||
config: config,
|
||||
loginAttempts: make(map[string]*LoginAttempt),
|
||||
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
||||
@@ -173,10 +174,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
||||
if auth.config.LocalUsers == nil {
|
||||
if auth.runtime.LocalUsers == nil {
|
||||
return nil
|
||||
}
|
||||
for _, user := range *auth.config.LocalUsers {
|
||||
for _, user := range auth.runtime.LocalUsers {
|
||||
if user.Username == username {
|
||||
return &user
|
||||
}
|
||||
@@ -209,7 +210,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
||||
auth.ldapGroupsMutex.Lock()
|
||||
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
|
||||
Groups: groups,
|
||||
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
|
||||
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
|
||||
}
|
||||
auth.ldapGroupsMutex.Unlock()
|
||||
|
||||
@@ -228,7 +229,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
||||
return true, remaining
|
||||
}
|
||||
|
||||
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
||||
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
|
||||
return false, 0
|
||||
}
|
||||
|
||||
@@ -246,7 +247,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
||||
}
|
||||
|
||||
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
||||
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -277,14 +278,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
|
||||
attempt.FailedAttempts++
|
||||
|
||||
if attempt.FailedAttempts >= auth.config.LoginMaxRetries {
|
||||
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second)
|
||||
tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts")
|
||||
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
||||
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
||||
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
||||
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
|
||||
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
||||
}
|
||||
|
||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||
@@ -299,7 +300,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
if data.TotpPending {
|
||||
expiry = 3600
|
||||
} else {
|
||||
expiry = auth.config.SessionExpiry
|
||||
expiry = auth.config.Auth.SessionExpiry
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
||||
@@ -325,13 +326,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.config.SessionCookieName,
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||
Expires: expiresAt,
|
||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||
Secure: auth.config.SecureCookie,
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
@@ -348,8 +349,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
||||
|
||||
var refreshThreshold int64
|
||||
|
||||
if auth.config.SessionExpiry <= int(time.Hour.Seconds()) {
|
||||
refreshThreshold = int64(auth.config.SessionExpiry / 2)
|
||||
if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) {
|
||||
refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2)
|
||||
} else {
|
||||
refreshThreshold = int64(time.Hour.Seconds())
|
||||
}
|
||||
@@ -378,13 +379,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.config.SessionCookieName,
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||
MaxAge: int(newExpiry - currentTime),
|
||||
Secure: auth.config.SecureCookie,
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
@@ -395,7 +396,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
|
||||
err := auth.queries.DeleteSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
|
||||
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
|
||||
}
|
||||
|
||||
err = auth.queries.DeleteSession(ctx, uuid)
|
||||
@@ -405,13 +406,13 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: auth.config.SessionCookieName,
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||
Expires: time.Now(),
|
||||
MaxAge: -1,
|
||||
Secure: auth.config.SecureCookie,
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
@@ -429,8 +430,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
||||
|
||||
currentTime := time.Now().Unix()
|
||||
|
||||
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
||||
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
|
||||
if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
||||
if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) {
|
||||
err = auth.queries.DeleteSession(ctx, uuid)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
||||
@@ -451,7 +452,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
||||
}
|
||||
|
||||
func (auth *AuthService) LocalAuthConfigured() bool {
|
||||
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
|
||||
return len(auth.runtime.LocalUsers) > 0
|
||||
}
|
||||
|
||||
func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||
@@ -464,18 +465,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
|
||||
}
|
||||
|
||||
if context.Provider == model.ProviderOAuth {
|
||||
tlog.App.Debug().Msg("Checking OAuth whitelist")
|
||||
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
|
||||
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
||||
}
|
||||
|
||||
if acls.Users.Block != "" {
|
||||
tlog.App.Debug().Msg("Checking blocked users")
|
||||
auth.log.App.Debug().Msg("Checking users block list")
|
||||
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
tlog.App.Debug().Msg("Checking users")
|
||||
auth.log.App.Debug().Msg("Checking users allow list")
|
||||
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
||||
}
|
||||
|
||||
@@ -485,23 +486,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
|
||||
}
|
||||
|
||||
if !context.IsOAuth() {
|
||||
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
||||
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
||||
return false
|
||||
}
|
||||
|
||||
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
||||
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
|
||||
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
|
||||
return true
|
||||
}
|
||||
|
||||
for _, userGroup := range context.OAuth.Groups {
|
||||
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
||||
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
tlog.App.Debug().Msg("No groups matched")
|
||||
auth.log.App.Debug().Msg("No groups matched")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -511,18 +512,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
|
||||
}
|
||||
|
||||
if !context.IsLDAP() {
|
||||
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
||||
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
||||
return false
|
||||
}
|
||||
|
||||
for _, userGroup := range context.LDAP.Groups {
|
||||
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
||||
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
tlog.App.Debug().Msg("No groups matched")
|
||||
auth.log.App.Debug().Msg("No groups matched")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -566,17 +567,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
||||
}
|
||||
|
||||
// Merge the global and app IP filter
|
||||
blockedIps := append(auth.config.IP.Block, acls.IP.Block...)
|
||||
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...)
|
||||
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
|
||||
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
|
||||
|
||||
for _, blocked := range blockedIps {
|
||||
res, err := utils.FilterIP(blocked, ip)
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||
continue
|
||||
}
|
||||
if res {
|
||||
tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access")
|
||||
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -584,21 +585,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
||||
for _, allowed := range allowedIPs {
|
||||
res, err := utils.FilterIP(allowed, ip)
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||
continue
|
||||
}
|
||||
if res {
|
||||
tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access")
|
||||
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedIPs) > 0 {
|
||||
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
||||
return false
|
||||
}
|
||||
|
||||
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
|
||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -610,16 +611,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
|
||||
for _, bypassed := range acls.IP.Bypass {
|
||||
res, err := utils.FilterIP(bypassed, ip)
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||
continue
|
||||
}
|
||||
if res {
|
||||
tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access")
|
||||
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
|
||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -726,18 +727,23 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
||||
ticker := time.NewTicker(30 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
auth.oauthMutex.Lock()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
auth.oauthMutex.Lock()
|
||||
|
||||
now := time.Now()
|
||||
now := time.Now()
|
||||
|
||||
for sessionId, session := range auth.oauthPendingSessions {
|
||||
if now.After(session.ExpiresAt) {
|
||||
delete(auth.oauthPendingSessions, sessionId)
|
||||
for sessionId, session := range auth.oauthPendingSessions {
|
||||
if now.After(session.ExpiresAt) {
|
||||
delete(auth.oauthPendingSessions, sessionId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auth.oauthMutex.Unlock()
|
||||
auth.oauthMutex.Unlock()
|
||||
case <-auth.context.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -806,11 +812,11 @@ func (auth *AuthService) lockdownMode() {
|
||||
|
||||
auth.loginMutex.Lock()
|
||||
|
||||
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
|
||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||
|
||||
auth.lockdown = &Lockdown{
|
||||
Active: true,
|
||||
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
|
||||
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
|
||||
}
|
||||
|
||||
// At this point all login attemps will also expire so,
|
||||
@@ -827,11 +833,14 @@ func (auth *AuthService) lockdownMode() {
|
||||
// Timer expired, end lockdown
|
||||
case <-ctx.Done():
|
||||
// Context cancelled, end lockdown
|
||||
case <-auth.context.Done():
|
||||
// Service is shutting down, end lockdown
|
||||
}
|
||||
|
||||
auth.loginMutex.Lock()
|
||||
|
||||
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
|
||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||
|
||||
auth.lockdown = nil
|
||||
auth.loginMutex.Unlock()
|
||||
}
|
||||
@@ -845,10 +854,3 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() {
|
||||
}
|
||||
auth.loginMutex.Unlock()
|
||||
}
|
||||
|
||||
func (auth *AuthService) getCookieDomain() string {
|
||||
if auth.config.SubdomainsEnabled {
|
||||
return "." + auth.config.CookieDomain
|
||||
}
|
||||
return auth.config.CookieDomain
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user