From 592c221b2dcc8cde5410a8e822df49aa0ad32ac1 Mon Sep 17 00:00:00 2001 From: Stavros Date: Thu, 7 May 2026 22:31:51 +0300 Subject: [PATCH] refactor: use one struct for context handling and cancellation --- cmd/tinyauth/tinyauth.go | 6 - internal/bootstrap/app_bootstrap.go | 380 +++++++++++++------- internal/bootstrap/db_bootstrap.go | 21 +- internal/bootstrap/router_bootstrap.go | 33 +- internal/bootstrap/service_bootstrap.go | 84 ++--- internal/service/access_controls_service.go | 9 +- internal/utils/logger/logger.go | 4 +- 7 files changed, 327 insertions(+), 210 deletions(-) diff --git a/cmd/tinyauth/tinyauth.go b/cmd/tinyauth/tinyauth.go index f5bbb19f..b6293718 100644 --- a/cmd/tinyauth/tinyauth.go +++ b/cmd/tinyauth/tinyauth.go @@ -7,7 +7,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/loaders" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/rs/zerolog/log" "github.com/tinyauthapp/paerser/cli" @@ -109,11 +108,6 @@ func main() { } func runCmd(cfg model.Config) error { - logger := tlog.NewLogger(cfg.Log) - logger.Init() - - tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth") - app := bootstrap.NewBootstrapApp(cfg) err := app.Setup() diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 5b342c48..5b10b192 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -3,98 +3,137 @@ package bootstrap import ( "bytes" "context" + "database/sql" "encoding/json" + "errors" "fmt" + "net" "net/http" "net/url" "os" + "os/signal" "sort" "strings" + "syscall" "time" + "github.com/gin-gonic/gin" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -type BootstrapApp struct { - config model.Config - context struct { - appUrl string - uuid string - cookieDomain string - sessionCookieName string - csrfCookieName string - redirectCookieName string - oauthSessionCookieName string - localUsers *[]model.LocalUser - oauthProviders map[string]model.OAuthServiceConfig - oauthWhitelist []string - configuredProviders []controller.Provider - oidcClients []model.OIDCClientConfig - } - services Services +type Services struct { + accessControlService *service.AccessControlsService + authService *service.AuthService + dockerService *service.DockerService + kubernetesService *service.KubernetesService + ldapService *service.LdapService + oauthBrokerService *service.OAuthBrokerService + oidcService *service.OIDCService } -func NewBootstrapApp(config model.Config) *BootstrapApp { - return &BootstrapApp{ +type RuntimeConfig struct { + appUrl string + uuid string + cookieDomain string + sessionCookieName string + csrfCookieName string + redirectCookieName string + oauthSessionCookieName string + localUsers []model.LocalUser + oauthProviders map[string]model.OAuthServiceConfig + oauthWhitelist []string + configuredProviders []controller.Provider + oidcClients []model.OIDCClientConfig + labelProvider service.LabelProvider +} + +type App struct { + config model.Config + runtime RuntimeConfig + services Services + log *logger.Logger + ctx context.Context + cancel context.CancelFunc + queries *repository.Queries + router *gin.Engine + db *sql.DB +} + +func NewBootstrapApp(config model.Config) *App { + return &App{ config: config, } } -func (app *BootstrapApp) Setup() error { +func (app *App) Setup() error { + // create context + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + app.ctx = ctx + app.cancel = cancel + + // setup logger + log := logger.NewLogger().WithConfig(app.config.Log) + log.Init() + app.log = log + // get app url if app.config.AppURL == "" { - return fmt.Errorf("app URL cannot be empty, perhaps config loading failed") + return errors.New("app url cannot be empty, perhaps config loading failed") } appUrl, err := url.Parse(app.config.AppURL) if err != nil { - return err + return fmt.Errorf("failed to parse app url: %w", err) } - app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host + app.runtime.appUrl = appUrl.Scheme + "://" + appUrl.Host // validate session config if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { - return fmt.Errorf("session max lifetime cannot be less than session expiry") + return errors.New("session max lifetime cannot be less than session expiry") } - // Parse users + // parse users users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes) if err != nil { - return err + return fmt.Errorf("failed to load users: %w", err) } - app.context.localUsers = users + app.runtime.localUsers = *users + // load oauth whitelist oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile) + if err != nil { - return err + return fmt.Errorf("failed to load oauth whitelist: %w", err) } - app.context.oauthWhitelist = oauthWhitelist + app.runtime.oauthWhitelist = oauthWhitelist - // Setup OAuth providers - app.context.oauthProviders = app.config.OAuth.Providers + // Setup oauth providers + app.runtime.oauthProviders = app.config.OAuth.Providers - for name, provider := range app.context.oauthProviders { + for id, provider := range app.runtime.oauthProviders { secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile) provider.ClientSecret = secret provider.ClientSecretFile = "" if provider.RedirectURL == "" { - provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name + provider.RedirectURL = app.runtime.appUrl + "/api/oauth/callback/" + id } - app.context.oauthProviders[name] = provider + app.runtime.oauthProviders[id] = provider } - for id, provider := range app.context.oauthProviders { + // set presets for built-in providers + for id, provider := range app.runtime.oauthProviders { if provider.Name == "" { if name, ok := model.OverrideProviders[id]; ok { provider.Name = name @@ -102,70 +141,63 @@ func (app *BootstrapApp) Setup() error { provider.Name = utils.Capitalize(id) } } - app.context.oauthProviders[id] = provider + app.runtime.oauthProviders[id] = provider } - // Setup OIDC clients + // setup oidc clients for id, client := range app.config.OIDC.Clients { client.ID = id - app.context.oidcClients = append(app.context.oidcClients, client) + app.runtime.oidcClients = append(app.runtime.oidcClients, client) } - // Get cookie domain + // cookie domain cookieDomainResolver := utils.GetCookieDomain + if !app.config.Auth.SubdomainsEnabled { - tlog.App.Info().Msg("Subdomains disabled, automatic authentication for proxied apps will not work") + app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains") cookieDomainResolver = utils.GetStandaloneCookieDomain } - cookieDomain, err := cookieDomainResolver(app.context.appUrl) + cookieDomain, err := cookieDomainResolver(app.runtime.appUrl) if err != nil { - return err + return fmt.Errorf("failed to get cookie domain: %w", err) } - app.context.cookieDomain = cookieDomain + app.runtime.cookieDomain = cookieDomain - // Cookie names - app.context.uuid = utils.GenerateUUID(appUrl.Hostname()) - cookieId := strings.Split(app.context.uuid, "-")[0] - app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) - app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) - app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) - app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) + // cookie names + app.runtime.uuid = utils.GenerateUUID(appUrl.Hostname()) - // Dumps - tlog.App.Trace().Interface("config", app.config).Msg("Config dump") - tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump") - tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump") - tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain") - tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name") - tlog.App.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name") - tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name") + cookieId := strings.Split(app.runtime.uuid, "-")[0] // first 8 characters of the uuid should be good enough - // Database - db, err := app.SetupDatabase(app.config.Database.Path) + app.runtime.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) + app.runtime.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) + app.runtime.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) + app.runtime.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) + + // database + err = app.SetupDatabase() if err != nil { return fmt.Errorf("failed to setup database: %w", err) } - // Queries - queries := repository.New(db) + // queries + queries := repository.New(app.db) + app.queries = queries - // Services - services, err := app.initServices(queries) + // services + err = app.setupServices() if err != nil { return fmt.Errorf("failed to initialize services: %w", err) } - app.services = services - - // Configured providers + // configured providers configuredProviders := make([]controller.Provider, 0) - for id, provider := range app.context.oauthProviders { + for id, provider := range app.runtime.oauthProviders { configuredProviders = append(configuredProviders, controller.Provider{ Name: provider.Name, ID: id, @@ -177,7 +209,7 @@ func (app *BootstrapApp) Setup() error { return configuredProviders[i].Name < configuredProviders[j].Name }) - if services.authService.LocalAuthConfigured() { + if app.services.authService.LocalAuthConfigured() { configuredProviders = append(configuredProviders, controller.Provider{ Name: "Local", ID: "local", @@ -185,7 +217,7 @@ func (app *BootstrapApp) Setup() error { }) } - if services.authService.LDAPAuthConfigured() { + if app.services.authService.LDAPAuthConfigured() { configuredProviders = append(configuredProviders, controller.Provider{ Name: "LDAP", ID: "ldap", @@ -193,77 +225,150 @@ func (app *BootstrapApp) Setup() error { }) } - tlog.App.Debug().Interface("providers", configuredProviders).Msg("Authentication providers") - if len(configuredProviders) == 0 { - return fmt.Errorf("no authentication providers configured") + return errors.New("no authentication providers configured") } - app.context.configuredProviders = configuredProviders + for _, provider := range app.runtime.configuredProviders { + app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider") + } - // Setup router - router, err := app.setupRouter() + app.runtime.configuredProviders = configuredProviders + + // setup router + err = app.setupRouter() if err != nil { return fmt.Errorf("failed to setup routes: %w", err) } - // Start db cleanup routine - tlog.App.Debug().Msg("Starting database cleanup routine") - go app.dbCleanupRoutine(queries) + // start db cleanup routine + app.log.App.Debug().Msg("Starting database cleanup routine") + go app.dbCleanupRoutine() - // If analytics are not disabled, start heartbeat + // if analytics are not disabled, start heartbeat if app.config.Analytics.Enabled { - tlog.App.Debug().Msg("Starting heartbeat routine") + app.log.App.Debug().Msg("Starting heartbeat routine") go app.heartbeatRoutine() } - // If we have an socket path, bind to it - if app.config.Server.SocketPath != "" { - if _, err := os.Stat(app.config.Server.SocketPath); err == nil { - tlog.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath) - err := os.Remove(app.config.Server.SocketPath) - if err != nil { - return fmt.Errorf("failed to remove existing socket file: %w", err) - } - } + // create err channel to listen for server errors + errChan := make(chan error, 1) - tlog.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) - if err := router.RunUnix(app.config.Server.SocketPath); err != nil { - tlog.App.Fatal().Err(err).Msg("Failed to start server") - } + // serve unix + go func() { + errChan <- app.serveUnix() + }() + // serve to http + go func() { + errChan <- app.serveHTTP() + }() + + // monitor cancellation and server errors + select { + case <-app.ctx.Done(): + app.log.App.Debug().Msg("Shutting down application") return nil - } - - // Start server - address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) - tlog.App.Info().Msgf("Starting server on %s", address) - if err := router.Run(address); err != nil { - tlog.App.Fatal().Err(err).Msg("Failed to start server") + case err := <-errChan: + if err != nil { + return fmt.Errorf("server error: %w", err) + } } return nil } -func (app *BootstrapApp) heartbeatRoutine() { +func (app *App) serveHTTP() error { + address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) + + app.log.App.Info().Msgf("Starting server on %s", address) + + server := &http.Server{ + Addr: address, + Handler: app.router.Handler(), + } + + go func() { + <-app.ctx.Done() + app.log.App.Debug().Msg("Shutting down server") + server.Close() + }() + + err := server.ListenAndServe() + + if err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("failed to start server: %w", err) + } + + return nil +} + +func (app *App) serveUnix() error { + if app.config.Server.SocketPath == "" { + return nil + } + + _, err := os.Stat(app.config.Server.SocketPath) + + if err == nil { + app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath) + err := os.Remove(app.config.Server.SocketPath) + + if err != nil { + return fmt.Errorf("failed to remove existing socket file: %w", err) + } + } + + app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) + + listener, err := net.Listen("unix", app.config.Server.SocketPath) + + if err != nil { + return fmt.Errorf("failed to create unix socket listner: %w", err) + } + + defer listener.Close() + defer os.Remove(app.config.Server.SocketPath) + + go func() { + <-app.ctx.Done() + app.log.App.Debug().Msg("Shutting down server") + listener.Close() + os.Remove(app.config.Server.SocketPath) + }() + + server := &http.Server{ + Handler: app.router.Handler(), + } + + err = server.Serve(listener) + + if err != nil && !errors.Is(err, net.ErrClosed) { + return fmt.Errorf("failed to start server: %w", err) + } + + return nil +} + +func (app *App) heartbeatRoutine() { ticker := time.NewTicker(time.Duration(12) * time.Hour) defer ticker.Stop() - type heartbeat struct { + type Heartbeat struct { UUID string `json:"uuid"` Version string `json:"version"` } - var body heartbeat + var body Heartbeat - body.UUID = app.context.uuid + body.UUID = app.runtime.uuid body.Version = model.Version bodyJson, err := json.Marshal(body) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to marshal heartbeat body") + app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start") return } @@ -273,43 +378,58 @@ func (app *BootstrapApp) heartbeatRoutine() { heartbeatURL := model.APIServer + "/v1/instances/heartbeat" - for range ticker.C { - tlog.App.Debug().Msg("Sending heartbeat") + for { + select { + case <-ticker.C: + app.log.App.Debug().Msg("Sending heartbeat") - req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) + req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create heartbeat request") - continue - } + if err != nil { + app.log.App.Error().Err(err).Msg("Failed to create heartbeat request") + continue + } - req.Header.Add("Content-Type", "application/json") + req.Header.Add("Content-Type", "application/json") - res, err := client.Do(req) + res, err := client.Do(req) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to send heartbeat") - continue - } + if err != nil { + app.log.App.Error().Err(err).Msg("Failed to send heartbeat") + continue + } - res.Body.Close() + res.Body.Close() - if res.StatusCode != 200 && res.StatusCode != 201 { - tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") + if res.StatusCode != 200 && res.StatusCode != 201 { + app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") + } + case <-app.ctx.Done(): + app.log.App.Debug().Msg("Stopping heartbeat routine") + ticker.Stop() + return } } } -func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) { +func (app *App) dbCleanupRoutine() { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() - ctx := context.Background() - for range ticker.C { - tlog.App.Debug().Msg("Cleaning up old database sessions") - err := queries.DeleteExpiredSessions(ctx, time.Now().Unix()) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to clean up old database sessions") + for { + select { + case <-ticker.C: + app.log.App.Debug().Msg("Running database cleanup") + + err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix()) + + if err != nil { + app.log.App.Error().Err(err).Msg("Failed to delete expired sessions") + } + case <-app.ctx.Done(): + app.log.App.Debug().Msg("Stopping database cleanup routine") + ticker.Stop() + return } } } diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 3f48f793..5ef5c9dc 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -14,17 +14,17 @@ import ( _ "modernc.org/sqlite" ) -func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { - dir := filepath.Dir(databasePath) +func (app *App) SetupDatabase() error { + dir := filepath.Dir(app.config.Database.Path) if err := os.MkdirAll(dir, 0750); err != nil { - return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err) + return fmt.Errorf("failed to create database directory %s: %w", dir, err) } - db, err := sql.Open("sqlite", databasePath) + db, err := sql.Open("sqlite", app.config.Database.Path) if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) + return fmt.Errorf("failed to open database: %w", err) } // Limit to 1 connection to sequence writes, this may need to be revisited in the future @@ -34,24 +34,25 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { migrations, err := iofs.New(assets.Migrations, "migrations") if err != nil { - return nil, fmt.Errorf("failed to create migrations: %w", err) + return fmt.Errorf("failed to create migrations: %w", err) } target, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { - return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err) + return fmt.Errorf("failed to create sqlite3 instance: %w", err) } migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target) if err != nil { - return nil, fmt.Errorf("failed to create migrator: %w", err) + return fmt.Errorf("failed to create migrator: %w", err) } if err := migrator.Up(); err != nil && err != migrate.ErrNoChange { - return nil, fmt.Errorf("failed to migrate database: %w", err) + return fmt.Errorf("failed to migrate database: %w", err) } - return db, nil + app.db = db + return nil } diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index a746be79..7310fa43 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -13,7 +13,7 @@ import ( var DEV_MODES = []string{"main", "test", "development"} -func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { +func (app *App) setupRouter() error { if !slices.Contains(DEV_MODES, model.Version) { gin.SetMode(gin.ReleaseMode) } @@ -25,19 +25,19 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies) if err != nil { - return nil, fmt.Errorf("failed to set trusted proxies: %w", err) + return fmt.Errorf("failed to set trusted proxies: %w", err) } } contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ - CookieDomain: app.context.cookieDomain, - SessionCookieName: app.context.sessionCookieName, + CookieDomain: app.runtime.cookieDomain, + SessionCookieName: app.runtime.sessionCookieName, }, app.services.authService, app.services.oauthBrokerService) err := contextMiddleware.Init() if err != nil { - return nil, fmt.Errorf("failed to initialize context middleware: %w", err) + return fmt.Errorf("failed to initialize context middleware: %w", err) } engine.Use(contextMiddleware.Middleware()) @@ -47,7 +47,7 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { err = uiMiddleware.Init() if err != nil { - return nil, fmt.Errorf("failed to initialize UI middleware: %w", err) + return fmt.Errorf("failed to initialize UI middleware: %w", err) } engine.Use(uiMiddleware.Middleware()) @@ -57,7 +57,7 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { err = zerologMiddleware.Init() if err != nil { - return nil, fmt.Errorf("failed to initialize zerolog middleware: %w", err) + return fmt.Errorf("failed to initialize zerolog middleware: %w", err) } engine.Use(zerologMiddleware.Middleware()) @@ -65,10 +65,10 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { apiRouter := engine.Group("/api") contextController := controller.NewContextController(controller.ContextControllerConfig{ - Providers: app.context.configuredProviders, + Providers: app.runtime.configuredProviders, Title: app.config.UI.Title, AppURL: app.config.AppURL, - CookieDomain: app.context.cookieDomain, + CookieDomain: app.runtime.cookieDomain, ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage, BackgroundImage: app.config.UI.BackgroundImage, OAuthAutoRedirect: app.config.OAuth.AutoRedirect, @@ -80,10 +80,10 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ AppURL: app.config.AppURL, SecureCookie: app.config.Auth.SecureCookie, - CSRFCookieName: app.context.csrfCookieName, - RedirectCookieName: app.context.redirectCookieName, - CookieDomain: app.context.cookieDomain, - OAuthSessionCookieName: app.context.oauthSessionCookieName, + CSRFCookieName: app.runtime.csrfCookieName, + RedirectCookieName: app.runtime.redirectCookieName, + CookieDomain: app.runtime.cookieDomain, + OAuthSessionCookieName: app.runtime.oauthSessionCookieName, SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, }, apiRouter, app.services.authService) @@ -100,8 +100,8 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { proxyController.SetupRoutes() userController := controller.NewUserController(controller.UserControllerConfig{ - CookieDomain: app.context.cookieDomain, - SessionCookieName: app.context.sessionCookieName, + CookieDomain: app.runtime.cookieDomain, + SessionCookieName: app.runtime.sessionCookieName, }, apiRouter, app.services.authService) userController.SetupRoutes() @@ -121,5 +121,6 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { wellknownController.SetupRoutes() - return engine, nil + app.router = engine + return nil } diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 09485bd0..b3261180 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -1,26 +1,14 @@ package bootstrap import ( + "fmt" "os" - "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) -type Services struct { - accessControlService *service.AccessControlsService - authService *service.AuthService - dockerService *service.DockerService - kubernetesService *service.KubernetesService - ldapService *service.LdapService - oauthBrokerService *service.OAuthBrokerService - oidcService *service.OIDCService -} - -func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) { - services := Services{} - +func (app *App) setupServices() error { ldapService := service.NewLdapService(service.LdapServiceConfig{ Address: app.config.LDAP.Address, BindDN: app.config.LDAP.BindDN, @@ -35,81 +23,85 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er err := ldapService.Init() if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it") + app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") ldapService.Unconfigure() } - services.ldapService = ldapService - - var labelProvider service.LabelProvider - var dockerService *service.DockerService - var kubernetesService *service.KubernetesService + app.services.ldapService = ldapService useKubernetes := app.config.LabelProvider == "kubernetes" || (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") if useKubernetes { - tlog.App.Debug().Msg("Using Kubernetes label provider") - kubernetesService = service.NewKubernetesService() + app.log.App.Debug().Msg("Using Kubernetes label provider") + + kubernetesService := service.NewKubernetesService() + err = kubernetesService.Init() + if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize kubernetes service: %w", err) } - services.kubernetesService = kubernetesService - labelProvider = kubernetesService + + app.services.kubernetesService = kubernetesService + app.runtime.labelProvider = service.LabelProviderKubernetes } else { tlog.App.Debug().Msg("Using Docker label provider") - dockerService = service.NewDockerService() + + dockerService := service.NewDockerService() + err = dockerService.Init() + if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize docker service: %w", err) } - services.dockerService = dockerService - labelProvider = dockerService + + app.services.dockerService = dockerService + app.runtime.labelProvider = service.LabelProviderDocker } - accessControlsService := service.NewAccessControlsService(labelProvider, app.config.Apps) + accessControlsService := service.NewAccessControlsService(app.runtime.labelProvider, app.config.Apps) err = accessControlsService.Init() if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize access controls service: %w", err) } - services.accessControlService = accessControlsService + app.services.accessControlService = accessControlsService - oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders) + oauthBrokerService := service.NewOAuthBrokerService(app.runtime.oauthProviders) err = oauthBrokerService.Init() if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize oauth broker service: %w", err) } - services.oauthBrokerService = oauthBrokerService + app.services.oauthBrokerService = oauthBrokerService authService := service.NewAuthService(service.AuthServiceConfig{ - LocalUsers: app.context.localUsers, - OauthWhitelist: app.context.oauthWhitelist, + LocalUsers: &app.runtime.localUsers, + OauthWhitelist: app.runtime.oauthWhitelist, SessionExpiry: app.config.Auth.SessionExpiry, SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, SecureCookie: app.config.Auth.SecureCookie, - CookieDomain: app.context.cookieDomain, + CookieDomain: app.runtime.cookieDomain, LoginTimeout: app.config.Auth.LoginTimeout, LoginMaxRetries: app.config.Auth.LoginMaxRetries, - SessionCookieName: app.context.sessionCookieName, + SessionCookieName: app.runtime.sessionCookieName, IP: app.config.Auth.IP, LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL, SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, - }, services.ldapService, queries, services.oauthBrokerService) + }, app.services.ldapService, app.queries, app.services.oauthBrokerService) err = authService.Init() if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize auth service: %w", err) } - services.authService = authService + app.services.authService = authService oidcService := service.NewOIDCService(service.OIDCServiceConfig{ Clients: app.config.OIDC.Clients, @@ -117,15 +109,15 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er PublicKeyPath: app.config.OIDC.PublicKeyPath, Issuer: app.config.AppURL, SessionExpiry: app.config.Auth.SessionExpiry, - }, queries) + }, app.queries) err = oidcService.Init() if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize oidc service: %w", err) } - services.oidcService = oidcService + app.services.oidcService = oidcService - return services, nil + return nil } diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index fd57bf39..c16c5a25 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -7,7 +7,14 @@ import ( "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) -type LabelProvider interface { +type LabelProvider int + +const ( + LabelProviderDocker LabelProvider = iota + LabelProviderKubernetes +) + +type LabelProviderImpl interface { GetLabels(appDomain string) (*model.App, error) } diff --git a/internal/utils/logger/logger.go b/internal/utils/logger/logger.go index 18d319fb..d85af79e 100644 --- a/internal/utils/logger/logger.go +++ b/internal/utils/logger/logger.go @@ -77,7 +77,6 @@ func (l *Logger) WithWriter(writer io.Writer) *Logger { func (l *Logger) Init() { base := log.With(). Timestamp(). - Caller(). Logger(). Level(l.parseLogLevel(l.config.Level)).Output(l.writer) @@ -114,6 +113,9 @@ func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerol if cfg.Level != "" { sub = sub.Level(l.parseLogLevel(cfg.Level)) } + if sub.GetLevel() == zerolog.DebugLevel { + sub = sub.With().Caller().Logger() + } return sub }