diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 2250fb19..47d3461e 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -29,10 +29,7 @@ func (app *BootstrapApp) setupRouter() error { } } - contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ - CookieDomain: app.runtime.CookieDomain, - SessionCookieName: app.runtime.SessionCookieName, - }, app.services.authService, app.services.oauthBrokerService) + contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService) err := contextMiddleware.Init() @@ -52,7 +49,7 @@ func (app *BootstrapApp) setupRouter() error { engine.Use(uiMiddleware.Middleware()) - zerologMiddleware := middleware.NewZerologMiddleware() + zerologMiddleware := middleware.NewZerologMiddleware(app.log) err = zerologMiddleware.Init() diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 9f44540d..6d79d801 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -4,21 +4,11 @@ import ( "fmt" "os" - "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" ) func (app *BootstrapApp) setupServices() error { - ldapService := service.NewLdapService(service.LdapServiceConfig{ - Address: app.config.LDAP.Address, - BindDN: app.config.LDAP.BindDN, - BindPassword: app.config.LDAP.BindPassword, - BaseDN: app.config.LDAP.BaseDN, - Insecure: app.config.LDAP.Insecure, - SearchFilter: app.config.LDAP.SearchFilter, - AuthCert: app.config.LDAP.AuthCert, - AuthKey: app.config.LDAP.AuthKey, - }) + ldapService := service.NewLdapService(app.log, app.config, app.ctx) err := ldapService.Init() @@ -32,10 +22,12 @@ func (app *BootstrapApp) setupServices() error { useKubernetes := app.config.LabelProvider == "kubernetes" || (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") + var labelProvider service.LabelProviderImpl + if useKubernetes { app.log.App.Debug().Msg("Using Kubernetes label provider") - kubernetesService := service.NewKubernetesService() + kubernetesService := service.NewKubernetesService(app.log, app.ctx) err = kubernetesService.Init() @@ -44,11 +36,11 @@ func (app *BootstrapApp) setupServices() error { } app.services.kubernetesService = kubernetesService - app.runtime.LabelProvider = model.LabelProviderKubernetes + labelProvider = kubernetesService } else { app.log.App.Debug().Msg("Using Docker label provider") - dockerService := service.NewDockerService() + dockerService := service.NewDockerService(app.log, app.ctx) err = dockerService.Init() @@ -57,10 +49,10 @@ func (app *BootstrapApp) setupServices() error { } app.services.dockerService = dockerService - app.runtime.LabelProvider = model.LabelProviderDocker + labelProvider = dockerService } - accessControlsService := service.NewAccessControlsService(app.runtime.LabelProvider, app.config.Apps) + accessControlsService := service.NewAccessControlsService(app.log, labelProvider, app.config.Apps) err = accessControlsService.Init() @@ -70,7 +62,7 @@ func (app *BootstrapApp) setupServices() error { app.services.accessControlService = accessControlsService - oauthBrokerService := service.NewOAuthBrokerService(app.runtime.OAuthProviders) + oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders) err = oauthBrokerService.Init() @@ -80,20 +72,7 @@ func (app *BootstrapApp) setupServices() error { app.services.oauthBrokerService = oauthBrokerService - authService := service.NewAuthService(service.AuthServiceConfig{ - LocalUsers: &app.runtime.LocalUsers, - OauthWhitelist: app.runtime.OAuthWhitelist, - SessionExpiry: app.config.Auth.SessionExpiry, - SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, - SecureCookie: app.config.Auth.SecureCookie, - CookieDomain: app.runtime.CookieDomain, - LoginTimeout: app.config.Auth.LoginTimeout, - LoginMaxRetries: app.config.Auth.LoginMaxRetries, - SessionCookieName: app.runtime.SessionCookieName, - IP: app.config.Auth.IP, - LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL, - SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, - }, app.services.ldapService, app.queries, app.services.oauthBrokerService) + authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.services.ldapService, app.queries, app.services.oauthBrokerService) err = authService.Init() @@ -103,13 +82,7 @@ func (app *BootstrapApp) setupServices() error { app.services.authService = authService - oidcService := service.NewOIDCService(service.OIDCServiceConfig{ - Clients: app.config.OIDC.Clients, - PrivateKeyPath: app.config.OIDC.PrivateKeyPath, - PublicKeyPath: app.config.OIDC.PublicKeyPath, - Issuer: app.config.AppURL, - SessionExpiry: app.config.Auth.SessionExpiry, - }, app.queries) + oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx) err = oidcService.Init() diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index a7a1f948..b405bb03 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -375,7 +375,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") + controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", diff --git a/internal/model/runtime.go b/internal/model/runtime.go index 72eab370..9bd81770 100644 --- a/internal/model/runtime.go +++ b/internal/model/runtime.go @@ -13,7 +13,6 @@ type RuntimeConfig struct { OAuthWhitelist []string ConfiguredProviders []Provider OIDCClients []OIDCClientConfig - LabelProvider LabelProvider } type Provider struct { @@ -21,10 +20,3 @@ type Provider struct { ID string `json:"id"` OAuth bool `json:"oauth"` } - -type LabelProvider int - -const ( - LabelProviderDocker LabelProvider = iota - LabelProviderKubernetes -) diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index d31ae6b7..9bfe834d 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -4,7 +4,7 @@ import ( "strings" "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) type LabelProviderImpl interface { @@ -12,12 +12,17 @@ type LabelProviderImpl interface { } type AccessControlsService struct { - labelProvider LabelProvider + log *logger.Logger + labelProvider LabelProviderImpl static map[string]model.App } -func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService { +func NewAccessControlsService( + log *logger.Logger, + labelProvider LabelProviderImpl, + static map[string]model.App) *AccessControlsService { return &AccessControlsService{ + log: log, labelProvider: labelProvider, static: static, } @@ -31,13 +36,13 @@ func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { var appAcls *model.App for app, config := range acls.static { if config.Config.Domain == domain { - tlog.App.Debug().Str("name", app).Msg("Found matching container by domain") + acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain") appAcls = &config break // If we find a match by domain, we can stop searching } if strings.SplitN(domain, ".", 2)[0] == app { - tlog.App.Debug().Str("name", app).Msg("Found matching container by app name") + acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name") appAcls = &config break // If we find a match by app name, we can stop searching } @@ -50,11 +55,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, app := acls.lookupStaticACLs(domain) if app != nil { - tlog.App.Debug().Msg("Using ACls from static configuration") + acls.log.App.Debug().Msg("Using static ACLs for app") return app, nil } // Fallback to label provider - tlog.App.Debug().Msg("Falling back to label provider for ACLs") + acls.log.App.Debug().Msg("Using label provider for app") return acls.labelProvider.GetLabels(domain) } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 16c53fe0..8b891c34 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -14,7 +14,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "slices" @@ -72,39 +72,40 @@ type Lockdown struct { ActiveUntil time.Time } -type AuthServiceConfig struct { - LocalUsers *[]model.LocalUser - OauthWhitelist []string - SessionExpiry int - SessionMaxLifetime int - SecureCookie bool - CookieDomain string - LoginTimeout int - LoginMaxRetries int - SessionCookieName string - IP model.IPConfig - LDAPGroupsCacheTTL int - SubdomainsEnabled bool -} - type AuthService struct { - config AuthServiceConfig + log *logger.Logger + config model.Config + runtime model.RuntimeConfig + context context.Context + + ldap *LdapService + queries *repository.Queries + oauthBroker *OAuthBrokerService + loginAttempts map[string]*LoginAttempt ldapGroupsCache map[string]*LdapGroupsCache oauthPendingSessions map[string]*OAuthPendingSession oauthMutex sync.RWMutex loginMutex sync.RWMutex ldapGroupsMutex sync.RWMutex - ldap *LdapService - queries *repository.Queries - oauthBroker *OAuthBrokerService lockdown *Lockdown lockdownCtx context.Context lockdownCancelFunc context.CancelFunc } -func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { +func NewAuthService( + log *logger.Logger, + config model.Config, + runtime model.RuntimeConfig, + context context.Context, + ldap *LdapService, + queries *repository.Queries, + oauthBroker *OAuthBrokerService, +) *AuthService { return &AuthService{ + log: log, + runtime: runtime, + context: context, config: config, loginAttempts: make(map[string]*LoginAttempt), ldapGroupsCache: make(map[string]*LdapGroupsCache), @@ -173,10 +174,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str } func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { - if auth.config.LocalUsers == nil { + if auth.runtime.LocalUsers == nil { return nil } - for _, user := range *auth.config.LocalUsers { + for _, user := range auth.runtime.LocalUsers { if user.Username == username { return &user } @@ -209,7 +210,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { auth.ldapGroupsMutex.Lock() auth.ldapGroupsCache[userDN] = &LdapGroupsCache{ Groups: groups, - Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second), + Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second), } auth.ldapGroupsMutex.Unlock() @@ -228,7 +229,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { return true, remaining } - if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { + if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 { return false, 0 } @@ -246,7 +247,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { } func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { - if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { + if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 { return } @@ -277,14 +278,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { attempt.FailedAttempts++ - if attempt.FailedAttempts >= auth.config.LoginMaxRetries { - attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second) - tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts") + if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { + attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) + auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") } } func (auth *AuthService) IsEmailWhitelisted(email string) bool { - return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email) + return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email) } func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { @@ -299,7 +300,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess if data.TotpPending { expiry = 3600 } else { - expiry = auth.config.SessionExpiry + expiry = auth.config.Auth.SessionExpiry } expiresAt := time.Now().Add(time.Duration(expiry) * time.Second) @@ -325,13 +326,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess } return &http.Cookie{ - Name: auth.config.SessionCookieName, + Name: auth.runtime.SessionCookieName, Value: session.UUID, Path: "/", - Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Expires: expiresAt, MaxAge: int(time.Until(expiresAt).Seconds()), - Secure: auth.config.SecureCookie, + Secure: auth.config.Auth.SecureCookie, HttpOnly: true, SameSite: http.SameSiteLaxMode, }, nil @@ -348,8 +349,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http var refreshThreshold int64 - if auth.config.SessionExpiry <= int(time.Hour.Seconds()) { - refreshThreshold = int64(auth.config.SessionExpiry / 2) + if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) { + refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2) } else { refreshThreshold = int64(time.Hour.Seconds()) } @@ -378,13 +379,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http } return &http.Cookie{ - Name: auth.config.SessionCookieName, + Name: auth.runtime.SessionCookieName, Value: session.UUID, Path: "/", - Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), MaxAge: int(newExpiry - currentTime), - Secure: auth.config.SecureCookie, + Secure: auth.config.Auth.SecureCookie, HttpOnly: true, SameSite: http.SameSiteLaxMode, }, nil @@ -395,7 +396,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http. err := auth.queries.DeleteSession(ctx, uuid) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway") + auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") } err = auth.queries.DeleteSession(ctx, uuid) @@ -405,13 +406,13 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http. } return &http.Cookie{ - Name: auth.config.SessionCookieName, + Name: auth.runtime.SessionCookieName, Value: "", Path: "/", - Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Expires: time.Now(), MaxAge: -1, - Secure: auth.config.SecureCookie, + Secure: auth.config.Auth.SecureCookie, HttpOnly: true, SameSite: http.SameSiteLaxMode, }, nil @@ -429,8 +430,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito currentTime := time.Now().Unix() - if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 { - if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) { + if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 { + if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) { err = auth.queries.DeleteSession(ctx, uuid) if err != nil { return nil, fmt.Errorf("failed to delete expired session: %w", err) @@ -451,7 +452,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito } func (auth *AuthService) LocalAuthConfigured() bool { - return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0 + return len(auth.runtime.LocalUsers) > 0 } func (auth *AuthService) LDAPAuthConfigured() bool { @@ -464,18 +465,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext } if context.Provider == model.ProviderOAuth { - tlog.App.Debug().Msg("Checking OAuth whitelist") + auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist") return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email) } if acls.Users.Block != "" { - tlog.App.Debug().Msg("Checking blocked users") + auth.log.App.Debug().Msg("Checking users block list") if utils.CheckFilter(acls.Users.Block, context.GetUsername()) { return false } } - tlog.App.Debug().Msg("Checking users") + auth.log.App.Debug().Msg("Checking users allow list") return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) } @@ -485,23 +486,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex } if !context.IsOAuth() { - tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") + auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") return false } if _, ok := model.OverrideProviders[context.OAuth.ID]; ok { - tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check") + auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check") return true } for _, userGroup := range context.OAuth.Groups { if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) { - tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") + auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") return true } } - tlog.App.Debug().Msg("No groups matched") + auth.log.App.Debug().Msg("No groups matched") return false } @@ -511,18 +512,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext } if !context.IsLDAP() { - tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") + auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") return false } for _, userGroup := range context.LDAP.Groups { if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) { - tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") + auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") return true } } - tlog.App.Debug().Msg("No groups matched") + auth.log.App.Debug().Msg("No groups matched") return false } @@ -566,17 +567,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool { } // Merge the global and app IP filter - blockedIps := append(auth.config.IP.Block, acls.IP.Block...) - allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...) + blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...) + allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...) for _, blocked := range blockedIps { res, err := utils.FilterIP(blocked, ip) if err != nil { - tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") + auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") continue } if res { - tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access") + auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access") return false } } @@ -584,21 +585,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool { for _, allowed := range allowedIPs { res, err := utils.FilterIP(allowed, ip) if err != nil { - tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") + auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") continue } if res { - tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access") + auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access") return true } } if len(allowedIPs) > 0 { - tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") + auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") return false } - tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default") + auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default") return true } @@ -610,16 +611,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool { for _, bypassed := range acls.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) if err != nil { - tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") + auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") continue } if res { - tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access") + auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication") return true } } - tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication") + auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication") return false } @@ -726,18 +727,23 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() { ticker := time.NewTicker(30 * time.Minute) defer ticker.Stop() - for range ticker.C { - auth.oauthMutex.Lock() + for { + select { + case <-ticker.C: + auth.oauthMutex.Lock() - now := time.Now() + now := time.Now() - for sessionId, session := range auth.oauthPendingSessions { - if now.After(session.ExpiresAt) { - delete(auth.oauthPendingSessions, sessionId) + for sessionId, session := range auth.oauthPendingSessions { + if now.After(session.ExpiresAt) { + delete(auth.oauthPendingSessions, sessionId) + } } - } - auth.oauthMutex.Unlock() + auth.oauthMutex.Unlock() + case <-auth.context.Done(): + return + } } } @@ -806,11 +812,11 @@ func (auth *AuthService) lockdownMode() { auth.loginMutex.Lock() - tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.") + auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") auth.lockdown = &Lockdown{ Active: true, - ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second), + ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second), } // At this point all login attemps will also expire so, @@ -827,11 +833,14 @@ func (auth *AuthService) lockdownMode() { // Timer expired, end lockdown case <-ctx.Done(): // Context cancelled, end lockdown + case <-auth.context.Done(): + // Service is shutting down, end lockdown } auth.loginMutex.Lock() - tlog.App.Info().Msg("Lockdown period ended, resuming normal operation") + auth.log.App.Info().Msg("Exiting lockdown mode") + auth.lockdown = nil auth.loginMutex.Unlock() } @@ -845,10 +854,3 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() { } auth.loginMutex.Unlock() } - -func (auth *AuthService) getCookieDomain() string { - if auth.config.SubdomainsEnabled { - return "." + auth.config.CookieDomain - } - return auth.config.CookieDomain -} diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index c5f95dd4..763e26fb 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -6,20 +6,28 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" container "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" ) type DockerService struct { - client *client.Client - context context.Context + log *logger.Logger + client *client.Client + context context.Context + isConnected bool } -func NewDockerService() *DockerService { - return &DockerService{} +func NewDockerService( + log *logger.Logger, + context context.Context, +) *DockerService { + return &DockerService{ + log: log, + context: context, + } } func (docker *DockerService) Init() error { @@ -28,16 +36,14 @@ func (docker *DockerService) Init() error { return err } - ctx := context.Background() - client.NegotiateAPIVersion(ctx) + client.NegotiateAPIVersion(docker.context) docker.client = client - docker.context = ctx _, err = docker.client.Ping(docker.context) if err != nil { - tlog.App.Debug().Err(err).Msg("Docker not connected") + docker.log.App.Debug().Err(err).Msg("Docker not connected") docker.isConnected = false docker.client = nil docker.context = nil @@ -45,7 +51,9 @@ func (docker *DockerService) Init() error { } docker.isConnected = true - tlog.App.Debug().Msg("Docker connected") + docker.log.App.Debug().Msg("Docker connected successfully") + + go docker.watchAndClose() return nil } @@ -60,7 +68,7 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { if !docker.isConnected { - tlog.App.Debug().Msg("Docker not connected, returning empty labels") + docker.log.App.Debug().Msg("Docker service not connected, returning empty labels") return nil, nil } @@ -82,17 +90,28 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { for appName, appLabels := range labels.Apps { if appLabels.Config.Domain == appDomain { - tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") + docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") return &appLabels, nil } if strings.SplitN(appDomain, ".", 2)[0] == appName { - tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") + docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") return &appLabels, nil } } } - tlog.App.Debug().Msg("No matching container found, returning empty labels") + docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain") return nil, nil } + +func (docker *DockerService) watchAndClose() { + <-docker.context.Done() + docker.log.App.Debug().Msg("Closing Docker client") + if docker.client != nil { + err := docker.client.Close() + if err != nil { + docker.log.App.Error().Err(err).Msg("Error closing Docker client") + } + } +} diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go index 9c5ad427..acba24e4 100644 --- a/internal/service/kubernetes_service.go +++ b/internal/service/kubernetes_service.go @@ -9,7 +9,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -36,8 +36,10 @@ type ingressApp struct { } type KubernetesService struct { + log *logger.Logger + ctx context.Context + client dynamic.Interface - ctx context.Context cancel context.CancelFunc started bool mu sync.RWMutex @@ -46,8 +48,13 @@ type KubernetesService struct { appNameIndex map[string]ingressAppKey } -func NewKubernetesService() *KubernetesService { +func NewKubernetesService( + log *logger.Logger, + context context.Context, +) *KubernetesService { return &KubernetesService{ + log: log, + ctx: context, ingressApps: make(map[ingressKey][]ingressApp), domainIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey), @@ -133,7 +140,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) { } labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps") if err != nil { - tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations") + k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping") k.removeIngress(namespace, name) return } @@ -161,13 +168,13 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error { list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{}) if err != nil { - tlog.App.Debug().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list ingresses during resync") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync") return err } for i := range list.Items { k.updateFromItem(&list.Items[i]) } - tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resynced ingress cache") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete") return nil } @@ -181,14 +188,14 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch. return false case event, ok := <-w.ResultChan(): if !ok { - tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting in 5 seconds") + k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher") w.Stop() time.Sleep(5 * time.Second) return true } item, ok := event.Object.(*unstructured.Unstructured) if !ok { - tlog.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Failed to cast watched object") + k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping") continue } switch event.Type { @@ -199,7 +206,7 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch. } case <-resyncTicker.C: if err := k.resyncGVR(gvr); err != nil { - tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run") } } } @@ -210,29 +217,29 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) { defer resyncTicker.Stop() if err := k.resyncGVR(gvr); err != nil { - tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, retrying in 30 seconds") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry") time.Sleep(30 * time.Second) } for { select { case <-k.ctx.Done(): - tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Stopping watcher") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Context cancelled, stopping watcher") return case <-resyncTicker.C: if err := k.resyncGVR(gvr); err != nil { - tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry") } default: ctx, cancel := context.WithCancel(k.ctx) watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{}) if err != nil { - tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry") cancel() time.Sleep(10 * time.Second) continue } - tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully") if !k.runWatcher(gvr, watcher, resyncTicker) { cancel() return @@ -257,7 +264,7 @@ func (k *KubernetesService) Init() error { } k.client = client - k.ctx, k.cancel = context.WithCancel(context.Background()) + k.ctx, k.cancel = context.WithCancel(k.ctx) gvr := schema.GroupVersionResource{ Group: "networking.k8s.io", @@ -269,38 +276,38 @@ func (k *KubernetesService) Init() error { defer accessCancel() _, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) if err != nil { - tlog.App.Warn().Err(err).Msg("Insufficient permissions for networking.k8s.io/v1 Ingress, Kubernetes label provider will not work") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") k.started = false return nil } - tlog.App.Debug().Msg("networking.k8s.io/v1 Ingress API accessible") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") go k.watchGVR(gvr) k.started = true - tlog.App.Info().Msg("Kubernetes label provider initialized") + k.log.App.Debug().Msg("Kubernetes label provider started successfully") return nil } func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { if !k.started { - tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels") + k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping") return nil, nil } // First check cache app := k.getByDomain(appDomain) if app != nil { - tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain") + k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain") return app, nil } appName := strings.SplitN(appDomain, ".", 2)[0] app = k.getByAppName(appName) if app != nil { - tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name") + k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name") return app, nil } - tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found") + k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain") return nil, nil } diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index 0963ebf5..d356cc75 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -9,31 +9,30 @@ import ( "github.com/cenkalti/backoff/v5" ldapgo "github.com/go-ldap/ldap/v3" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -type LdapServiceConfig struct { - Address string - BindDN string - BindPassword string - BaseDN string - Insecure bool - SearchFilter string - AuthCert string - AuthKey string -} - type LdapService struct { - config LdapServiceConfig + log *logger.Logger + config model.Config + context context.Context + conn *ldapgo.Conn mutex sync.RWMutex cert *tls.Certificate isConfigured bool } -func NewLdapService(config LdapServiceConfig) *LdapService { +func NewLdapService( + log *logger.Logger, + config model.Config, + context context.Context, +) *LdapService { return &LdapService{ - config: config, + log: log, + config: config, + context: context, } } @@ -57,7 +56,7 @@ func (ldap *LdapService) Unconfigure() error { } func (ldap *LdapService) Init() error { - if ldap.config.Address == "" { + if ldap.config.LDAP.Address == "" { ldap.isConfigured = false return nil } @@ -65,13 +64,13 @@ func (ldap *LdapService) Init() error { ldap.isConfigured = true // Check whether authentication with client certificate is possible - if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" { - cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey) + if ldap.config.LDAP.AuthCert != "" && ldap.config.LDAP.AuthKey != "" { + cert, err := tls.LoadX509KeyPair(ldap.config.LDAP.AuthCert, ldap.config.LDAP.AuthKey) if err != nil { return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) } ldap.cert = &cert - tlog.App.Info().Msg("Using LDAP with mTLS authentication") + ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully") // TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify` /* @@ -90,15 +89,24 @@ func (ldap *LdapService) Init() error { } go func() { - for range time.Tick(time.Duration(5) * time.Minute) { - err := ldap.heartbeat() - if err != nil { - tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed") - if reconnectErr := ldap.reconnect(); reconnectErr != nil { - tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") - continue + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + err := ldap.heartbeat() + if err != nil { + ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect") + if reconnectErr := ldap.reconnect(); reconnectErr != nil { + ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") + continue + } + ldap.log.App.Info().Msg("Successfully reconnected to LDAP server") } - tlog.App.Info().Msg("Successfully reconnected to LDAP server") + case <-ldap.context.Done(): + ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat") + return } } }() @@ -120,13 +128,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) { // 2. conn.StartTLS(tlsConfig) // 3. conn.externalBind() if ldap.cert != nil { - conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ + conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{ MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{*ldap.cert}, })) } else { - conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ - InsecureSkipVerify: ldap.config.Insecure, + conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{ + InsecureSkipVerify: ldap.config.LDAP.Insecure, MinVersion: tls.VersionTLS12, })) } @@ -146,10 +154,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) { func (ldap *LdapService) GetUserDN(username string) (string, error) { // Escape the username to prevent LDAP injection escapedUsername := ldapgo.EscapeFilter(username) - filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername) + filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername) searchRequest := ldapgo.NewSearchRequest( - ldap.config.BaseDN, + ldap.config.LDAP.BaseDN, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, filter, []string{"dn"}, @@ -176,7 +184,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) { escapedUserDN := ldapgo.EscapeFilter(userDN) searchRequest := ldapgo.NewSearchRequest( - ldap.config.BaseDN, + ldap.config.LDAP.BaseDN, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN), []string{"dn"}, @@ -224,7 +232,7 @@ func (ldap *LdapService) BindService(rebind bool) error { if ldap.cert != nil { return ldap.conn.ExternalBind() } - return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword) + return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword) } func (ldap *LdapService) Bind(userDN string, password string) error { @@ -238,7 +246,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error { } func (ldap *LdapService) heartbeat() error { - tlog.App.Debug().Msg("Performing LDAP connection heartbeat") + ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat") searchRequest := ldapgo.NewSearchRequest( "", @@ -260,7 +268,7 @@ func (ldap *LdapService) heartbeat() error { } func (ldap *LdapService) reconnect() error { - tlog.App.Info().Msg("Reconnecting to LDAP server") + ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server") exp := backoff.NewExponentialBackOff() exp.InitialInterval = 500 * time.Millisecond diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 15823c47..c3bfec9c 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -2,7 +2,7 @@ package service import ( "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "slices" @@ -19,6 +19,8 @@ type OAuthServiceImpl interface { } type OAuthBrokerService struct { + log *logger.Logger + services map[string]OAuthServiceImpl configs map[string]model.OAuthServiceConfig } @@ -28,7 +30,10 @@ var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{ "google": newGoogleOAuthService, } -func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService { +func NewOAuthBrokerService( + log *logger.Logger, + configs map[string]model.OAuthServiceConfig, +) *OAuthBrokerService { return &OAuthBrokerService{ services: make(map[string]OAuthServiceImpl), configs: configs, @@ -39,10 +44,10 @@ func (broker *OAuthBrokerService) Init() error { for name, cfg := range broker.configs { if presetFunc, exists := presets[name]; exists { broker.services[name] = presetFunc(cfg) - tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") + broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") } else { broker.services[name] = NewOAuthService(cfg, name) - tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config") + broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") } } return nil diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 1e1c1986..da69eb96 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -25,7 +25,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) var ( @@ -111,17 +111,13 @@ type AuthorizeRequest struct { CodeChallengeMethod string `json:"code_challenge_method"` } -type OIDCServiceConfig struct { - Clients map[string]model.OIDCClientConfig - PrivateKeyPath string - PublicKeyPath string - Issuer string - SessionExpiry int -} - type OIDCService struct { - config OIDCServiceConfig - queries *repository.Queries + log *logger.Logger + config model.Config + runtime model.RuntimeConfig + queries *repository.Queries + context context.Context + clients map[string]model.OIDCClientConfig privateKey *rsa.PrivateKey publicKey crypto.PublicKey @@ -129,10 +125,18 @@ type OIDCService struct { isConfigured bool } -func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService { +func NewOIDCService( + log *logger.Logger, + config model.Config, + runtime model.RuntimeConfig, + queries *repository.Queries, + context context.Context) *OIDCService { return &OIDCService{ + log: log, config: config, + runtime: runtime, queries: queries, + context: context, } } @@ -142,7 +146,7 @@ func (service *OIDCService) IsConfigured() bool { func (service *OIDCService) Init() error { // If not configured, skip init - if len(service.config.Clients) == 0 { + if len(service.runtime.OIDCClients) == 0 { service.isConfigured = false return nil } @@ -150,7 +154,7 @@ func (service *OIDCService) Init() error { service.isConfigured = true // Ensure issuer is https - uissuer, err := url.Parse(service.config.Issuer) + uissuer, err := url.Parse(service.runtime.AppURL) if err != nil { return err @@ -163,14 +167,14 @@ func (service *OIDCService) Init() error { service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) // Create/load private and public keys - if strings.TrimSpace(service.config.PrivateKeyPath) == "" || - strings.TrimSpace(service.config.PublicKeyPath) == "" { + if strings.TrimSpace(service.config.OIDC.PrivateKeyPath) == "" || + strings.TrimSpace(service.config.OIDC.PublicKeyPath) == "" { return errors.New("private key path and public key path are required") } var privateKey *rsa.PrivateKey - fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath) + fprivateKey, err := os.ReadFile(service.config.OIDC.PrivateKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return err @@ -189,8 +193,8 @@ func (service *OIDCService) Init() error { Type: "RSA PRIVATE KEY", Bytes: der, }) - tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") - err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600) + service.log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") + err = os.WriteFile(service.config.OIDC.PrivateKeyPath, encoded, 0600) if err != nil { return err } @@ -200,7 +204,7 @@ func (service *OIDCService) Init() error { if block == nil { return errors.New("failed to decode private key") } - tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key") + service.log.App.Trace().Str("type", block.Type).Msg("Loaded private key") privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return err @@ -208,7 +212,7 @@ func (service *OIDCService) Init() error { service.privateKey = privateKey } - fpublicKey, err := os.ReadFile(service.config.PublicKeyPath) + fpublicKey, err := os.ReadFile(service.config.OIDC.PublicKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return err @@ -224,8 +228,8 @@ func (service *OIDCService) Init() error { Type: "RSA PUBLIC KEY", Bytes: der, }) - tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") - err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644) + service.log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") + err = os.WriteFile(service.config.OIDC.PublicKeyPath, encoded, 0644) if err != nil { return err } @@ -235,7 +239,7 @@ func (service *OIDCService) Init() error { if block == nil { return errors.New("failed to decode public key") } - tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key") + service.log.App.Trace().Str("type", block.Type).Msg("Loaded public key") switch block.Type { case "RSA PUBLIC KEY": publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) @@ -257,7 +261,7 @@ func (service *OIDCService) Init() error { // We will reorganize the client into a map with the client ID as the key service.clients = make(map[string]model.OIDCClientConfig) - for id, client := range service.config.Clients { + for id, client := range service.config.OIDC.Clients { client.ID = id if client.Name == "" { client.Name = utils.Capitalize(client.ID) @@ -273,9 +277,12 @@ func (service *OIDCService) Init() error { } client.ClientSecretFile = "" service.clients[id] = client - tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client") + service.log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") } + // Start cleanup routine + go service.cleanupRoutine() + return nil } @@ -307,7 +314,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error return errors.New("invalid_scope") } if !slices.Contains(SupportedScopes, scope) { - tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored") + service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope") } } @@ -357,7 +364,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r entry.CodeChallenge = req.CodeChallenge } else { entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge) - tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security") + service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security") } } @@ -449,7 +456,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { createdAt := time.Now().Unix() - expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() + expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() hasher := sha256.New() @@ -529,16 +536,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID accessToken := utils.GenerateString(32) refreshToken := utils.GenerateString(32) - tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() + tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() // Refresh token lives double the time of an access token but can't be used to access userinfo - refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() + refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix() tokenResponse := TokenResponse{ AccessToken: accessToken, RefreshToken: refreshToken, TokenType: "Bearer", - ExpiresIn: int64(service.config.SessionExpiry), + ExpiresIn: int64(service.config.Auth.SessionExpiry), IDToken: idToken, Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "), } @@ -598,14 +605,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri accessToken := utils.GenerateString(32) newRefreshToken := utils.GenerateString(32) - tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() - refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() + tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() + refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix() tokenResponse := TokenResponse{ AccessToken: accessToken, RefreshToken: newRefreshToken, TokenType: "Bearer", - ExpiresIn: int64(service.config.SessionExpiry), + ExpiresIn: int64(service.config.Auth.SessionExpiry), IDToken: idToken, Scope: strings.ReplaceAll(entry.Scope, ",", " "), } @@ -748,56 +755,64 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er } // Cleanup routine - Resource heavy due to the linked tables -func (service *OIDCService) Cleanup() { - // We need a context for the routine - ctx := context.Background() +func (service *OIDCService) cleanupRoutine() { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() - for range ticker.C { - currentTime := time.Now().Unix() + for { + select { + case <-ticker.C: + service.log.App.Debug().Msg("Starting OIDC cleanup routine") - // For the OIDC tokens, if they are expired we delete the userinfo and codes - expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ - TokenExpiresAt: currentTime, - RefreshTokenExpiresAt: currentTime, - }) + currentTime := time.Now().Unix() - if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens") - } - - for _, expiredToken := range expiredTokens { - err := service.DeleteOldSession(ctx, expiredToken.Sub) - if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete old session") - } - } - - // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything - expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) - - if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete expired codes") - } - - for _, expiredCode := range expiredCodes { - token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) + // For the OIDC tokens, if they are expired we delete the userinfo and codes + expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{ + TokenExpiresAt: currentTime, + RefreshTokenExpiresAt: currentTime, + }) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - continue - } - tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub") + service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens") } - if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { - err := service.DeleteOldSession(ctx, expiredCode.Sub) + for _, expiredToken := range expiredTokens { + err := service.DeleteOldSession(service.context, expiredToken.Sub) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete session") + service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") } } + + // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything + expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime) + + if err != nil { + service.log.App.Warn().Err(err).Msg("Failed to delete expired codes") + } + + for _, expiredCode := range expiredCodes { + token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub) + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + continue + } + service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code") + } + + if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { + err := service.DeleteOldSession(service.context, expiredCode.Sub) + if err != nil { + service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code") + } + } + } + + service.log.App.Debug().Msg("Finished OIDC cleanup routine") + case <-service.context.Done(): + service.log.App.Debug().Msg("OIDC cleanup routine context cancelled, stopping") + return } } } diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index d021c083..6413755b 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -7,8 +7,6 @@ import ( "net/url" "strings" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - "github.com/weppos/publicsuffix-go/publicsuffix" ) @@ -28,7 +26,6 @@ func GetCookieDomain(u string) (string, error) { parts := strings.Split(host, ".") if len(parts) == 2 { - tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host) return host, nil }