refactor: rework logging and cancellation in services

This commit is contained in:
Stavros
2026-05-08 17:18:39 +03:00
parent 55b53c77bf
commit e214d6d8d4
12 changed files with 315 additions and 295 deletions
+2 -5
View File
@@ -29,10 +29,7 @@ func (app *BootstrapApp) setupRouter() error {
} }
} }
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService)
CookieDomain: app.runtime.CookieDomain,
SessionCookieName: app.runtime.SessionCookieName,
}, app.services.authService, app.services.oauthBrokerService)
err := contextMiddleware.Init() err := contextMiddleware.Init()
@@ -52,7 +49,7 @@ func (app *BootstrapApp) setupRouter() error {
engine.Use(uiMiddleware.Middleware()) engine.Use(uiMiddleware.Middleware())
zerologMiddleware := middleware.NewZerologMiddleware() zerologMiddleware := middleware.NewZerologMiddleware(app.log)
err = zerologMiddleware.Init() err = zerologMiddleware.Init()
+11 -38
View File
@@ -4,21 +4,11 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
) )
func (app *BootstrapApp) setupServices() error { func (app *BootstrapApp) setupServices() error {
ldapService := service.NewLdapService(service.LdapServiceConfig{ ldapService := service.NewLdapService(app.log, app.config, app.ctx)
Address: app.config.LDAP.Address,
BindDN: app.config.LDAP.BindDN,
BindPassword: app.config.LDAP.BindPassword,
BaseDN: app.config.LDAP.BaseDN,
Insecure: app.config.LDAP.Insecure,
SearchFilter: app.config.LDAP.SearchFilter,
AuthCert: app.config.LDAP.AuthCert,
AuthKey: app.config.LDAP.AuthKey,
})
err := ldapService.Init() err := ldapService.Init()
@@ -32,10 +22,12 @@ func (app *BootstrapApp) setupServices() error {
useKubernetes := app.config.LabelProvider == "kubernetes" || useKubernetes := app.config.LabelProvider == "kubernetes" ||
(app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "")
var labelProvider service.LabelProviderImpl
if useKubernetes { if useKubernetes {
app.log.App.Debug().Msg("Using Kubernetes label provider") app.log.App.Debug().Msg("Using Kubernetes label provider")
kubernetesService := service.NewKubernetesService() kubernetesService := service.NewKubernetesService(app.log, app.ctx)
err = kubernetesService.Init() err = kubernetesService.Init()
@@ -44,11 +36,11 @@ func (app *BootstrapApp) setupServices() error {
} }
app.services.kubernetesService = kubernetesService app.services.kubernetesService = kubernetesService
app.runtime.LabelProvider = model.LabelProviderKubernetes labelProvider = kubernetesService
} else { } else {
app.log.App.Debug().Msg("Using Docker label provider") app.log.App.Debug().Msg("Using Docker label provider")
dockerService := service.NewDockerService() dockerService := service.NewDockerService(app.log, app.ctx)
err = dockerService.Init() err = dockerService.Init()
@@ -57,10 +49,10 @@ func (app *BootstrapApp) setupServices() error {
} }
app.services.dockerService = dockerService app.services.dockerService = dockerService
app.runtime.LabelProvider = model.LabelProviderDocker labelProvider = dockerService
} }
accessControlsService := service.NewAccessControlsService(app.runtime.LabelProvider, app.config.Apps) accessControlsService := service.NewAccessControlsService(app.log, labelProvider, app.config.Apps)
err = accessControlsService.Init() err = accessControlsService.Init()
@@ -70,7 +62,7 @@ func (app *BootstrapApp) setupServices() error {
app.services.accessControlService = accessControlsService app.services.accessControlService = accessControlsService
oauthBrokerService := service.NewOAuthBrokerService(app.runtime.OAuthProviders) oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders)
err = oauthBrokerService.Init() err = oauthBrokerService.Init()
@@ -80,20 +72,7 @@ func (app *BootstrapApp) setupServices() error {
app.services.oauthBrokerService = oauthBrokerService app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(service.AuthServiceConfig{ authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.services.ldapService, app.queries, app.services.oauthBrokerService)
LocalUsers: &app.runtime.LocalUsers,
OauthWhitelist: app.runtime.OAuthWhitelist,
SessionExpiry: app.config.Auth.SessionExpiry,
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
SecureCookie: app.config.Auth.SecureCookie,
CookieDomain: app.runtime.CookieDomain,
LoginTimeout: app.config.Auth.LoginTimeout,
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.runtime.SessionCookieName,
IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL,
SubdomainsEnabled: app.config.Auth.SubdomainsEnabled,
}, app.services.ldapService, app.queries, app.services.oauthBrokerService)
err = authService.Init() err = authService.Init()
@@ -103,13 +82,7 @@ func (app *BootstrapApp) setupServices() error {
app.services.authService = authService app.services.authService = authService
oidcService := service.NewOIDCService(service.OIDCServiceConfig{ oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx)
Clients: app.config.OIDC.Clients,
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
PublicKeyPath: app.config.OIDC.PublicKeyPath,
Issuer: app.config.AppURL,
SessionExpiry: app.config.Auth.SessionExpiry,
}, app.queries)
err = oidcService.Init() err = oidcService.Init()
+1 -1
View File
@@ -375,7 +375,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
cookie, err := controller.auth.CreateSession(c, sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification")
c.JSON(500, gin.H{ c.JSON(500, gin.H{
"status": 500, "status": 500,
"message": "Internal Server Error", "message": "Internal Server Error",
-8
View File
@@ -13,7 +13,6 @@ type RuntimeConfig struct {
OAuthWhitelist []string OAuthWhitelist []string
ConfiguredProviders []Provider ConfiguredProviders []Provider
OIDCClients []OIDCClientConfig OIDCClients []OIDCClientConfig
LabelProvider LabelProvider
} }
type Provider struct { type Provider struct {
@@ -21,10 +20,3 @@ type Provider struct {
ID string `json:"id"` ID string `json:"id"`
OAuth bool `json:"oauth"` OAuth bool `json:"oauth"`
} }
type LabelProvider int
const (
LabelProviderDocker LabelProvider = iota
LabelProviderKubernetes
)
+12 -7
View File
@@ -4,7 +4,7 @@ import (
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type LabelProviderImpl interface { type LabelProviderImpl interface {
@@ -12,12 +12,17 @@ type LabelProviderImpl interface {
} }
type AccessControlsService struct { type AccessControlsService struct {
labelProvider LabelProvider log *logger.Logger
labelProvider LabelProviderImpl
static map[string]model.App static map[string]model.App
} }
func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService { func NewAccessControlsService(
log *logger.Logger,
labelProvider LabelProviderImpl,
static map[string]model.App) *AccessControlsService {
return &AccessControlsService{ return &AccessControlsService{
log: log,
labelProvider: labelProvider, labelProvider: labelProvider,
static: static, static: static,
} }
@@ -31,13 +36,13 @@ func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
var appAcls *model.App var appAcls *model.App
for app, config := range acls.static { for app, config := range acls.static {
if config.Config.Domain == domain { if config.Config.Domain == domain {
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain") acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain")
appAcls = &config appAcls = &config
break // If we find a match by domain, we can stop searching break // If we find a match by domain, we can stop searching
} }
if strings.SplitN(domain, ".", 2)[0] == app { if strings.SplitN(domain, ".", 2)[0] == app {
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name") acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name")
appAcls = &config appAcls = &config
break // If we find a match by app name, we can stop searching break // If we find a match by app name, we can stop searching
} }
@@ -50,11 +55,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App,
app := acls.lookupStaticACLs(domain) app := acls.lookupStaticACLs(domain)
if app != nil { if app != nil {
tlog.App.Debug().Msg("Using ACls from static configuration") acls.log.App.Debug().Msg("Using static ACLs for app")
return app, nil return app, nil
} }
// Fallback to label provider // Fallback to label provider
tlog.App.Debug().Msg("Falling back to label provider for ACLs") acls.log.App.Debug().Msg("Using label provider for app")
return acls.labelProvider.GetLabels(domain) return acls.labelProvider.GetLabels(domain)
} }
+87 -85
View File
@@ -14,7 +14,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices" "slices"
@@ -72,39 +72,40 @@ type Lockdown struct {
ActiveUntil time.Time ActiveUntil time.Time
} }
type AuthServiceConfig struct {
LocalUsers *[]model.LocalUser
OauthWhitelist []string
SessionExpiry int
SessionMaxLifetime int
SecureCookie bool
CookieDomain string
LoginTimeout int
LoginMaxRetries int
SessionCookieName string
IP model.IPConfig
LDAPGroupsCacheTTL int
SubdomainsEnabled bool
}
type AuthService struct { type AuthService struct {
config AuthServiceConfig log *logger.Logger
config model.Config
runtime model.RuntimeConfig
context context.Context
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
loginAttempts map[string]*LoginAttempt loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache ldapGroupsCache map[string]*LdapGroupsCache
oauthPendingSessions map[string]*OAuthPendingSession oauthPendingSessions map[string]*OAuthPendingSession
oauthMutex sync.RWMutex oauthMutex sync.RWMutex
loginMutex sync.RWMutex loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
oauthBroker *OAuthBrokerService
lockdown *Lockdown lockdown *Lockdown
lockdownCtx context.Context lockdownCtx context.Context
lockdownCancelFunc context.CancelFunc lockdownCancelFunc context.CancelFunc
} }
func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { func NewAuthService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
context context.Context,
ldap *LdapService,
queries *repository.Queries,
oauthBroker *OAuthBrokerService,
) *AuthService {
return &AuthService{ return &AuthService{
log: log,
runtime: runtime,
context: context,
config: config, config: config,
loginAttempts: make(map[string]*LoginAttempt), loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache), ldapGroupsCache: make(map[string]*LdapGroupsCache),
@@ -173,10 +174,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
} }
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
if auth.config.LocalUsers == nil { if auth.runtime.LocalUsers == nil {
return nil return nil
} }
for _, user := range *auth.config.LocalUsers { for _, user := range auth.runtime.LocalUsers {
if user.Username == username { if user.Username == username {
return &user return &user
} }
@@ -209,7 +210,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
auth.ldapGroupsMutex.Lock() auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{ auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups, Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second), Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
} }
auth.ldapGroupsMutex.Unlock() auth.ldapGroupsMutex.Unlock()
@@ -228,7 +229,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
return true, remaining return true, remaining
} }
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return false, 0 return false, 0
} }
@@ -246,7 +247,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
} }
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
return return
} }
@@ -277,14 +278,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
attempt.FailedAttempts++ attempt.FailedAttempts++
if attempt.FailedAttempts >= auth.config.LoginMaxRetries { if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second) attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts") auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
} }
} }
func (auth *AuthService) IsEmailWhitelisted(email string) bool { func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email) return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
} }
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
@@ -299,7 +300,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
if data.TotpPending { if data.TotpPending {
expiry = 3600 expiry = 3600
} else { } else {
expiry = auth.config.SessionExpiry expiry = auth.config.Auth.SessionExpiry
} }
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second) expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
@@ -325,13 +326,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.config.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: expiresAt, Expires: expiresAt,
MaxAge: int(time.Until(expiresAt).Seconds()), MaxAge: int(time.Until(expiresAt).Seconds()),
Secure: auth.config.SecureCookie, Secure: auth.config.Auth.SecureCookie,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
}, nil }, nil
@@ -348,8 +349,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
var refreshThreshold int64 var refreshThreshold int64
if auth.config.SessionExpiry <= int(time.Hour.Seconds()) { if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) {
refreshThreshold = int64(auth.config.SessionExpiry / 2) refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2)
} else { } else {
refreshThreshold = int64(time.Hour.Seconds()) refreshThreshold = int64(time.Hour.Seconds())
} }
@@ -378,13 +379,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.config.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: session.UUID, Value: session.UUID,
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: int(newExpiry - currentTime), MaxAge: int(newExpiry - currentTime),
Secure: auth.config.SecureCookie, Secure: auth.config.Auth.SecureCookie,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
}, nil }, nil
@@ -395,7 +396,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
err := auth.queries.DeleteSession(ctx, uuid) err := auth.queries.DeleteSession(ctx, uuid)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway") auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
} }
err = auth.queries.DeleteSession(ctx, uuid) err = auth.queries.DeleteSession(ctx, uuid)
@@ -405,13 +406,13 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.
} }
return &http.Cookie{ return &http.Cookie{
Name: auth.config.SessionCookieName, Name: auth.runtime.SessionCookieName,
Value: "", Value: "",
Path: "/", Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
Expires: time.Now(), Expires: time.Now(),
MaxAge: -1, MaxAge: -1,
Secure: auth.config.SecureCookie, Secure: auth.config.Auth.SecureCookie,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
}, nil }, nil
@@ -429,8 +430,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 { if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) { if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) {
err = auth.queries.DeleteSession(ctx, uuid) err = auth.queries.DeleteSession(ctx, uuid)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete expired session: %w", err) return nil, fmt.Errorf("failed to delete expired session: %w", err)
@@ -451,7 +452,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
} }
func (auth *AuthService) LocalAuthConfigured() bool { func (auth *AuthService) LocalAuthConfigured() bool {
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0 return len(auth.runtime.LocalUsers) > 0
} }
func (auth *AuthService) LDAPAuthConfigured() bool { func (auth *AuthService) LDAPAuthConfigured() bool {
@@ -464,18 +465,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext
} }
if context.Provider == model.ProviderOAuth { if context.Provider == model.ProviderOAuth {
tlog.App.Debug().Msg("Checking OAuth whitelist") auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email) return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
} }
if acls.Users.Block != "" { if acls.Users.Block != "" {
tlog.App.Debug().Msg("Checking blocked users") auth.log.App.Debug().Msg("Checking users block list")
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) { if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
return false return false
} }
} }
tlog.App.Debug().Msg("Checking users") auth.log.App.Debug().Msg("Checking users allow list")
return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
} }
@@ -485,23 +486,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
} }
if !context.IsOAuth() { if !context.IsOAuth() {
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
return false return false
} }
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok { if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check") auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
return true return true
} }
for _, userGroup := range context.OAuth.Groups { for _, userGroup := range context.OAuth.Groups {
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) { if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
return true return true
} }
} }
tlog.App.Debug().Msg("No groups matched") auth.log.App.Debug().Msg("No groups matched")
return false return false
} }
@@ -511,18 +512,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
} }
if !context.IsLDAP() { if !context.IsLDAP() {
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return false return false
} }
for _, userGroup := range context.LDAP.Groups { for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) { if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
return true return true
} }
} }
tlog.App.Debug().Msg("No groups matched") auth.log.App.Debug().Msg("No groups matched")
return false return false
} }
@@ -566,17 +567,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
} }
// Merge the global and app IP filter // Merge the global and app IP filter
blockedIps := append(auth.config.IP.Block, acls.IP.Block...) blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...) allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
for _, blocked := range blockedIps { for _, blocked := range blockedIps {
res, err := utils.FilterIP(blocked, ip) res, err := utils.FilterIP(blocked, ip)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
continue continue
} }
if res { if res {
tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access") auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
return false return false
} }
} }
@@ -584,21 +585,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
for _, allowed := range allowedIPs { for _, allowed := range allowedIPs {
res, err := utils.FilterIP(allowed, ip) res, err := utils.FilterIP(allowed, ip)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
continue continue
} }
if res { if res {
tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access") auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
return true return true
} }
} }
if len(allowedIPs) > 0 { if len(allowedIPs) > 0 {
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
return false return false
} }
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default") auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
return true return true
} }
@@ -610,16 +611,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
for _, bypassed := range acls.IP.Bypass { for _, bypassed := range acls.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip) res, err := utils.FilterIP(bypassed, ip)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
continue continue
} }
if res { if res {
tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access") auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
return true return true
} }
} }
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication") auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
return false return false
} }
@@ -726,18 +727,23 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
ticker := time.NewTicker(30 * time.Minute) ticker := time.NewTicker(30 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for {
auth.oauthMutex.Lock() select {
case <-ticker.C:
auth.oauthMutex.Lock()
now := time.Now() now := time.Now()
for sessionId, session := range auth.oauthPendingSessions { for sessionId, session := range auth.oauthPendingSessions {
if now.After(session.ExpiresAt) { if now.After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId) delete(auth.oauthPendingSessions, sessionId)
}
} }
}
auth.oauthMutex.Unlock() auth.oauthMutex.Unlock()
case <-auth.context.Done():
return
}
} }
} }
@@ -806,11 +812,11 @@ func (auth *AuthService) lockdownMode() {
auth.loginMutex.Lock() auth.loginMutex.Lock()
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.") auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
auth.lockdown = &Lockdown{ auth.lockdown = &Lockdown{
Active: true, Active: true,
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second), ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
} }
// At this point all login attemps will also expire so, // At this point all login attemps will also expire so,
@@ -827,11 +833,14 @@ func (auth *AuthService) lockdownMode() {
// Timer expired, end lockdown // Timer expired, end lockdown
case <-ctx.Done(): case <-ctx.Done():
// Context cancelled, end lockdown // Context cancelled, end lockdown
case <-auth.context.Done():
// Service is shutting down, end lockdown
} }
auth.loginMutex.Lock() auth.loginMutex.Lock()
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation") auth.log.App.Info().Msg("Exiting lockdown mode")
auth.lockdown = nil auth.lockdown = nil
auth.loginMutex.Unlock() auth.loginMutex.Unlock()
} }
@@ -845,10 +854,3 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() {
} }
auth.loginMutex.Unlock() auth.loginMutex.Unlock()
} }
func (auth *AuthService) getCookieDomain() string {
if auth.config.SubdomainsEnabled {
return "." + auth.config.CookieDomain
}
return auth.config.CookieDomain
}
+33 -14
View File
@@ -6,20 +6,28 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
container "github.com/docker/docker/api/types/container" container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client" "github.com/docker/docker/client"
) )
type DockerService struct { type DockerService struct {
client *client.Client log *logger.Logger
context context.Context client *client.Client
context context.Context
isConnected bool isConnected bool
} }
func NewDockerService() *DockerService { func NewDockerService(
return &DockerService{} log *logger.Logger,
context context.Context,
) *DockerService {
return &DockerService{
log: log,
context: context,
}
} }
func (docker *DockerService) Init() error { func (docker *DockerService) Init() error {
@@ -28,16 +36,14 @@ func (docker *DockerService) Init() error {
return err return err
} }
ctx := context.Background() client.NegotiateAPIVersion(docker.context)
client.NegotiateAPIVersion(ctx)
docker.client = client docker.client = client
docker.context = ctx
_, err = docker.client.Ping(docker.context) _, err = docker.client.Ping(docker.context)
if err != nil { if err != nil {
tlog.App.Debug().Err(err).Msg("Docker not connected") docker.log.App.Debug().Err(err).Msg("Docker not connected")
docker.isConnected = false docker.isConnected = false
docker.client = nil docker.client = nil
docker.context = nil docker.context = nil
@@ -45,7 +51,9 @@ func (docker *DockerService) Init() error {
} }
docker.isConnected = true docker.isConnected = true
tlog.App.Debug().Msg("Docker connected") docker.log.App.Debug().Msg("Docker connected successfully")
go docker.watchAndClose()
return nil return nil
} }
@@ -60,7 +68,7 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
if !docker.isConnected { if !docker.isConnected {
tlog.App.Debug().Msg("Docker not connected, returning empty labels") docker.log.App.Debug().Msg("Docker service not connected, returning empty labels")
return nil, nil return nil, nil
} }
@@ -82,17 +90,28 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
for appName, appLabels := range labels.Apps { for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == appDomain { if appLabels.Config.Domain == appDomain {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
return &appLabels, nil return &appLabels, nil
} }
if strings.SplitN(appDomain, ".", 2)[0] == appName { if strings.SplitN(appDomain, ".", 2)[0] == appName {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
return &appLabels, nil return &appLabels, nil
} }
} }
} }
tlog.App.Debug().Msg("No matching container found, returning empty labels") docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain")
return nil, nil return nil, nil
} }
func (docker *DockerService) watchAndClose() {
<-docker.context.Done()
docker.log.App.Debug().Msg("Closing Docker client")
if docker.client != nil {
err := docker.client.Close()
if err != nil {
docker.log.App.Error().Err(err).Msg("Error closing Docker client")
}
}
}
+29 -22
View File
@@ -9,7 +9,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
@@ -36,8 +36,10 @@ type ingressApp struct {
} }
type KubernetesService struct { type KubernetesService struct {
log *logger.Logger
ctx context.Context
client dynamic.Interface client dynamic.Interface
ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
started bool started bool
mu sync.RWMutex mu sync.RWMutex
@@ -46,8 +48,13 @@ type KubernetesService struct {
appNameIndex map[string]ingressAppKey appNameIndex map[string]ingressAppKey
} }
func NewKubernetesService() *KubernetesService { func NewKubernetesService(
log *logger.Logger,
context context.Context,
) *KubernetesService {
return &KubernetesService{ return &KubernetesService{
log: log,
ctx: context,
ingressApps: make(map[ingressKey][]ingressApp), ingressApps: make(map[ingressKey][]ingressApp),
domainIndex: make(map[string]ingressAppKey), domainIndex: make(map[string]ingressAppKey),
appNameIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey),
@@ -133,7 +140,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
} }
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps") labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
if err != nil { if err != nil {
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations") k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping")
k.removeIngress(namespace, name) k.removeIngress(namespace, name)
return return
} }
@@ -161,13 +168,13 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error {
list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{}) list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{})
if err != nil { if err != nil {
tlog.App.Debug().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list ingresses during resync") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync")
return err return err
} }
for i := range list.Items { for i := range list.Items {
k.updateFromItem(&list.Items[i]) k.updateFromItem(&list.Items[i])
} }
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resynced ingress cache") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete")
return nil return nil
} }
@@ -181,14 +188,14 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
return false return false
case event, ok := <-w.ResultChan(): case event, ok := <-w.ResultChan():
if !ok { if !ok {
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting in 5 seconds") k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher")
w.Stop() w.Stop()
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
return true return true
} }
item, ok := event.Object.(*unstructured.Unstructured) item, ok := event.Object.(*unstructured.Unstructured)
if !ok { if !ok {
tlog.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Failed to cast watched object") k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping")
continue continue
} }
switch event.Type { switch event.Type {
@@ -199,7 +206,7 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch.
} }
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run")
} }
} }
} }
@@ -210,29 +217,29 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) {
defer resyncTicker.Stop() defer resyncTicker.Stop()
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, retrying in 30 seconds") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry")
time.Sleep(30 * time.Second) time.Sleep(30 * time.Second)
} }
for { for {
select { select {
case <-k.ctx.Done(): case <-k.ctx.Done():
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Stopping watcher") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Context cancelled, stopping watcher")
return return
case <-resyncTicker.C: case <-resyncTicker.C:
if err := k.resyncGVR(gvr); err != nil { if err := k.resyncGVR(gvr); err != nil {
tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry")
} }
default: default:
ctx, cancel := context.WithCancel(k.ctx) ctx, cancel := context.WithCancel(k.ctx)
watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{}) watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{})
if err != nil { if err != nil {
tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry")
cancel() cancel()
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
continue continue
} }
tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully")
if !k.runWatcher(gvr, watcher, resyncTicker) { if !k.runWatcher(gvr, watcher, resyncTicker) {
cancel() cancel()
return return
@@ -257,7 +264,7 @@ func (k *KubernetesService) Init() error {
} }
k.client = client k.client = client
k.ctx, k.cancel = context.WithCancel(context.Background()) k.ctx, k.cancel = context.WithCancel(k.ctx)
gvr := schema.GroupVersionResource{ gvr := schema.GroupVersionResource{
Group: "networking.k8s.io", Group: "networking.k8s.io",
@@ -269,38 +276,38 @@ func (k *KubernetesService) Init() error {
defer accessCancel() defer accessCancel()
_, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) _, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1})
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Insufficient permissions for networking.k8s.io/v1 Ingress, Kubernetes label provider will not work") k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled")
k.started = false k.started = false
return nil return nil
} }
tlog.App.Debug().Msg("networking.k8s.io/v1 Ingress API accessible") k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher")
go k.watchGVR(gvr) go k.watchGVR(gvr)
k.started = true k.started = true
tlog.App.Info().Msg("Kubernetes label provider initialized") k.log.App.Debug().Msg("Kubernetes label provider started successfully")
return nil return nil
} }
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
if !k.started { if !k.started {
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels") k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping")
return nil, nil return nil, nil
} }
// First check cache // First check cache
app := k.getByDomain(appDomain) app := k.getByDomain(appDomain)
if app != nil { if app != nil {
tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain") k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain")
return app, nil return app, nil
} }
appName := strings.SplitN(appDomain, ".", 2)[0] appName := strings.SplitN(appDomain, ".", 2)[0]
app = k.getByAppName(appName) app = k.getByAppName(appName)
if app != nil { if app != nil {
tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name") k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name")
return app, nil return app, nil
} }
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found") k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain")
return nil, nil return nil, nil
} }
+44 -36
View File
@@ -9,31 +9,30 @@ import (
"github.com/cenkalti/backoff/v5" "github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3" ldapgo "github.com/go-ldap/ldap/v3"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
type LdapServiceConfig struct {
Address string
BindDN string
BindPassword string
BaseDN string
Insecure bool
SearchFilter string
AuthCert string
AuthKey string
}
type LdapService struct { type LdapService struct {
config LdapServiceConfig log *logger.Logger
config model.Config
context context.Context
conn *ldapgo.Conn conn *ldapgo.Conn
mutex sync.RWMutex mutex sync.RWMutex
cert *tls.Certificate cert *tls.Certificate
isConfigured bool isConfigured bool
} }
func NewLdapService(config LdapServiceConfig) *LdapService { func NewLdapService(
log *logger.Logger,
config model.Config,
context context.Context,
) *LdapService {
return &LdapService{ return &LdapService{
config: config, log: log,
config: config,
context: context,
} }
} }
@@ -57,7 +56,7 @@ func (ldap *LdapService) Unconfigure() error {
} }
func (ldap *LdapService) Init() error { func (ldap *LdapService) Init() error {
if ldap.config.Address == "" { if ldap.config.LDAP.Address == "" {
ldap.isConfigured = false ldap.isConfigured = false
return nil return nil
} }
@@ -65,13 +64,13 @@ func (ldap *LdapService) Init() error {
ldap.isConfigured = true ldap.isConfigured = true
// Check whether authentication with client certificate is possible // Check whether authentication with client certificate is possible
if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" { if ldap.config.LDAP.AuthCert != "" && ldap.config.LDAP.AuthKey != "" {
cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey) cert, err := tls.LoadX509KeyPair(ldap.config.LDAP.AuthCert, ldap.config.LDAP.AuthKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err)
} }
ldap.cert = &cert ldap.cert = &cert
tlog.App.Info().Msg("Using LDAP with mTLS authentication") ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully")
// TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify` // TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify`
/* /*
@@ -90,15 +89,24 @@ func (ldap *LdapService) Init() error {
} }
go func() { go func() {
for range time.Tick(time.Duration(5) * time.Minute) { ticker := time.NewTicker(5 * time.Minute)
err := ldap.heartbeat() defer ticker.Stop()
if err != nil {
tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed") for {
if reconnectErr := ldap.reconnect(); reconnectErr != nil { select {
tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") case <-ticker.C:
continue err := ldap.heartbeat()
if err != nil {
ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect")
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
continue
}
ldap.log.App.Info().Msg("Successfully reconnected to LDAP server")
} }
tlog.App.Info().Msg("Successfully reconnected to LDAP server") case <-ldap.context.Done():
ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat")
return
} }
} }
}() }()
@@ -120,13 +128,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
// 2. conn.StartTLS(tlsConfig) // 2. conn.StartTLS(tlsConfig)
// 3. conn.externalBind() // 3. conn.externalBind()
if ldap.cert != nil { if ldap.cert != nil {
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{*ldap.cert}, Certificates: []tls.Certificate{*ldap.cert},
})) }))
} else { } else {
conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: ldap.config.Insecure, InsecureSkipVerify: ldap.config.LDAP.Insecure,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
})) }))
} }
@@ -146,10 +154,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
func (ldap *LdapService) GetUserDN(username string) (string, error) { func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection // Escape the username to prevent LDAP injection
escapedUsername := ldapgo.EscapeFilter(username) escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername) filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername)
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN, ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
filter, filter,
[]string{"dn"}, []string{"dn"},
@@ -176,7 +184,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN) escapedUserDN := ldapgo.EscapeFilter(userDN)
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN, ldap.config.LDAP.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN), fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
[]string{"dn"}, []string{"dn"},
@@ -224,7 +232,7 @@ func (ldap *LdapService) BindService(rebind bool) error {
if ldap.cert != nil { if ldap.cert != nil {
return ldap.conn.ExternalBind() return ldap.conn.ExternalBind()
} }
return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword) return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword)
} }
func (ldap *LdapService) Bind(userDN string, password string) error { func (ldap *LdapService) Bind(userDN string, password string) error {
@@ -238,7 +246,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error {
} }
func (ldap *LdapService) heartbeat() error { func (ldap *LdapService) heartbeat() error {
tlog.App.Debug().Msg("Performing LDAP connection heartbeat") ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat")
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
"", "",
@@ -260,7 +268,7 @@ func (ldap *LdapService) heartbeat() error {
} }
func (ldap *LdapService) reconnect() error { func (ldap *LdapService) reconnect() error {
tlog.App.Info().Msg("Reconnecting to LDAP server") ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server")
exp := backoff.NewExponentialBackOff() exp := backoff.NewExponentialBackOff()
exp.InitialInterval = 500 * time.Millisecond exp.InitialInterval = 500 * time.Millisecond
+9 -4
View File
@@ -2,7 +2,7 @@ package service
import ( import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
"slices" "slices"
@@ -19,6 +19,8 @@ type OAuthServiceImpl interface {
} }
type OAuthBrokerService struct { type OAuthBrokerService struct {
log *logger.Logger
services map[string]OAuthServiceImpl services map[string]OAuthServiceImpl
configs map[string]model.OAuthServiceConfig configs map[string]model.OAuthServiceConfig
} }
@@ -28,7 +30,10 @@ var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
"google": newGoogleOAuthService, "google": newGoogleOAuthService,
} }
func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService { func NewOAuthBrokerService(
log *logger.Logger,
configs map[string]model.OAuthServiceConfig,
) *OAuthBrokerService {
return &OAuthBrokerService{ return &OAuthBrokerService{
services: make(map[string]OAuthServiceImpl), services: make(map[string]OAuthServiceImpl),
configs: configs, configs: configs,
@@ -39,10 +44,10 @@ func (broker *OAuthBrokerService) Init() error {
for name, cfg := range broker.configs { for name, cfg := range broker.configs {
if presetFunc, exists := presets[name]; exists { if presetFunc, exists := presets[name]; exists {
broker.services[name] = presetFunc(cfg) broker.services[name] = presetFunc(cfg)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset")
} else { } else {
broker.services[name] = NewOAuthService(cfg, name) broker.services[name] = NewOAuthService(cfg, name)
tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config") broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config")
} }
} }
return nil return nil
+87 -72
View File
@@ -25,7 +25,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
var ( var (
@@ -111,17 +111,13 @@ type AuthorizeRequest struct {
CodeChallengeMethod string `json:"code_challenge_method"` CodeChallengeMethod string `json:"code_challenge_method"`
} }
type OIDCServiceConfig struct {
Clients map[string]model.OIDCClientConfig
PrivateKeyPath string
PublicKeyPath string
Issuer string
SessionExpiry int
}
type OIDCService struct { type OIDCService struct {
config OIDCServiceConfig log *logger.Logger
queries *repository.Queries config model.Config
runtime model.RuntimeConfig
queries *repository.Queries
context context.Context
clients map[string]model.OIDCClientConfig clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey privateKey *rsa.PrivateKey
publicKey crypto.PublicKey publicKey crypto.PublicKey
@@ -129,10 +125,18 @@ type OIDCService struct {
isConfigured bool isConfigured bool
} }
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService { func NewOIDCService(
log *logger.Logger,
config model.Config,
runtime model.RuntimeConfig,
queries *repository.Queries,
context context.Context) *OIDCService {
return &OIDCService{ return &OIDCService{
log: log,
config: config, config: config,
runtime: runtime,
queries: queries, queries: queries,
context: context,
} }
} }
@@ -142,7 +146,7 @@ func (service *OIDCService) IsConfigured() bool {
func (service *OIDCService) Init() error { func (service *OIDCService) Init() error {
// If not configured, skip init // If not configured, skip init
if len(service.config.Clients) == 0 { if len(service.runtime.OIDCClients) == 0 {
service.isConfigured = false service.isConfigured = false
return nil return nil
} }
@@ -150,7 +154,7 @@ func (service *OIDCService) Init() error {
service.isConfigured = true service.isConfigured = true
// Ensure issuer is https // Ensure issuer is https
uissuer, err := url.Parse(service.config.Issuer) uissuer, err := url.Parse(service.runtime.AppURL)
if err != nil { if err != nil {
return err return err
@@ -163,14 +167,14 @@ func (service *OIDCService) Init() error {
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
// Create/load private and public keys // Create/load private and public keys
if strings.TrimSpace(service.config.PrivateKeyPath) == "" || if strings.TrimSpace(service.config.OIDC.PrivateKeyPath) == "" ||
strings.TrimSpace(service.config.PublicKeyPath) == "" { strings.TrimSpace(service.config.OIDC.PublicKeyPath) == "" {
return errors.New("private key path and public key path are required") return errors.New("private key path and public key path are required")
} }
var privateKey *rsa.PrivateKey var privateKey *rsa.PrivateKey
fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath) fprivateKey, err := os.ReadFile(service.config.OIDC.PrivateKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return err return err
@@ -189,8 +193,8 @@ func (service *OIDCService) Init() error {
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: der, Bytes: der,
}) })
tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") service.log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key")
err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600) err = os.WriteFile(service.config.OIDC.PrivateKeyPath, encoded, 0600)
if err != nil { if err != nil {
return err return err
} }
@@ -200,7 +204,7 @@ func (service *OIDCService) Init() error {
if block == nil { if block == nil {
return errors.New("failed to decode private key") return errors.New("failed to decode private key")
} }
tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key") service.log.App.Trace().Str("type", block.Type).Msg("Loaded private key")
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil { if err != nil {
return err return err
@@ -208,7 +212,7 @@ func (service *OIDCService) Init() error {
service.privateKey = privateKey service.privateKey = privateKey
} }
fpublicKey, err := os.ReadFile(service.config.PublicKeyPath) fpublicKey, err := os.ReadFile(service.config.OIDC.PublicKeyPath)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return err return err
@@ -224,8 +228,8 @@ func (service *OIDCService) Init() error {
Type: "RSA PUBLIC KEY", Type: "RSA PUBLIC KEY",
Bytes: der, Bytes: der,
}) })
tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") service.log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644) err = os.WriteFile(service.config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil { if err != nil {
return err return err
} }
@@ -235,7 +239,7 @@ func (service *OIDCService) Init() error {
if block == nil { if block == nil {
return errors.New("failed to decode public key") return errors.New("failed to decode public key")
} }
tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key") service.log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type { switch block.Type {
case "RSA PUBLIC KEY": case "RSA PUBLIC KEY":
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
@@ -257,7 +261,7 @@ func (service *OIDCService) Init() error {
// We will reorganize the client into a map with the client ID as the key // We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]model.OIDCClientConfig) service.clients = make(map[string]model.OIDCClientConfig)
for id, client := range service.config.Clients { for id, client := range service.config.OIDC.Clients {
client.ID = id client.ID = id
if client.Name == "" { if client.Name == "" {
client.Name = utils.Capitalize(client.ID) client.Name = utils.Capitalize(client.ID)
@@ -273,9 +277,12 @@ func (service *OIDCService) Init() error {
} }
client.ClientSecretFile = "" client.ClientSecretFile = ""
service.clients[id] = client service.clients[id] = client
tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client") service.log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration")
} }
// Start cleanup routine
go service.cleanupRoutine()
return nil return nil
} }
@@ -307,7 +314,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error
return errors.New("invalid_scope") return errors.New("invalid_scope")
} }
if !slices.Contains(SupportedScopes, scope) { if !slices.Contains(SupportedScopes, scope) {
tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored") service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope")
} }
} }
@@ -357,7 +364,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
entry.CodeChallenge = req.CodeChallenge entry.CodeChallenge = req.CodeChallenge
} else { } else {
entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge) entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge)
tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security") service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security")
} }
} }
@@ -449,7 +456,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix() createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
hasher := sha256.New() hasher := sha256.New()
@@ -529,16 +536,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID
accessToken := utils.GenerateString(32) accessToken := utils.GenerateString(32)
refreshToken := utils.GenerateString(32) refreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
// Refresh token lives double the time of an access token but can't be used to access userinfo // Refresh token lives double the time of an access token but can't be used to access userinfo
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{ tokenResponse := TokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: refreshToken, RefreshToken: refreshToken,
TokenType: "Bearer", TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry), ExpiresIn: int64(service.config.Auth.SessionExpiry),
IDToken: idToken, IDToken: idToken,
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "), Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
} }
@@ -598,14 +605,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
accessToken := utils.GenerateString(32) accessToken := utils.GenerateString(32)
newRefreshToken := utils.GenerateString(32) newRefreshToken := utils.GenerateString(32)
tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix()
tokenResponse := TokenResponse{ tokenResponse := TokenResponse{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: newRefreshToken, RefreshToken: newRefreshToken,
TokenType: "Bearer", TokenType: "Bearer",
ExpiresIn: int64(service.config.SessionExpiry), ExpiresIn: int64(service.config.Auth.SessionExpiry),
IDToken: idToken, IDToken: idToken,
Scope: strings.ReplaceAll(entry.Scope, ",", " "), Scope: strings.ReplaceAll(entry.Scope, ",", " "),
} }
@@ -748,56 +755,64 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er
} }
// Cleanup routine - Resource heavy due to the linked tables // Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) Cleanup() { func (service *OIDCService) cleanupRoutine() {
// We need a context for the routine
ctx := context.Background()
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for {
currentTime := time.Now().Unix() select {
case <-ticker.C:
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
// For the OIDC tokens, if they are expired we delete the userinfo and codes currentTime := time.Now().Unix()
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
if err != nil { // For the OIDC tokens, if they are expired we delete the userinfo and codes
tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens") expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{
} TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
for _, expiredToken := range expiredTokens { })
err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete old session")
}
}
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
continue
}
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
} }
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(ctx, expiredCode.Sub) err := service.DeleteOldSession(service.context, expiredToken.Sub)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session") service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
} }
} }
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
continue
}
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
}
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(service.context, expiredCode.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
}
}
}
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-service.context.Done():
service.log.App.Debug().Msg("OIDC cleanup routine context cancelled, stopping")
return
} }
} }
} }
-3
View File
@@ -7,8 +7,6 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/weppos/publicsuffix-go/publicsuffix" "github.com/weppos/publicsuffix-go/publicsuffix"
) )
@@ -28,7 +26,6 @@ func GetCookieDomain(u string) (string, error) {
parts := strings.Split(host, ".") parts := strings.Split(host, ".")
if len(parts) == 2 { if len(parts) == 2 {
tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host)
return host, nil return host, nil
} }