diff --git a/cmd/tinyauth/create_user.go b/cmd/tinyauth/create_user.go index ef5fe266..d7e9f97e 100644 --- a/cmd/tinyauth/create_user.go +++ b/cmd/tinyauth/create_user.go @@ -6,8 +6,8 @@ import ( "strings" "charm.land/huh/v2" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/paerser/cli" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "golang.org/x/crypto/bcrypt" ) @@ -40,7 +40,8 @@ func createUserCmd() *cli.Command { Configuration: tCfg, Resources: loaders, Run: func(_ []string) error { - tlog.NewSimpleLogger().Init() + log := logger.NewLogger().WithSimpleConfig() + log.Init() if tCfg.Interactive { form := huh.NewForm( @@ -73,7 +74,7 @@ func createUserCmd() *cli.Command { return errors.New("username and password cannot be empty") } - tlog.App.Info().Str("username", tCfg.Username).Msg("Creating user") + log.App.Info().Str("username", tCfg.Username).Msg("Creating user") passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost) if err != nil { @@ -86,7 +87,7 @@ func createUserCmd() *cli.Command { passwdStr = strings.ReplaceAll(passwdStr, "$", "$$") } - tlog.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created") + log.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created") return nil }, diff --git a/cmd/tinyauth/generate_totp.go b/cmd/tinyauth/generate_totp.go index 8819922e..8492f87b 100644 --- a/cmd/tinyauth/generate_totp.go +++ b/cmd/tinyauth/generate_totp.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "charm.land/huh/v2" "github.com/mdp/qrterminal/v3" @@ -40,7 +40,8 @@ func generateTotpCmd() *cli.Command { Configuration: tCfg, Resources: loaders, Run: func(_ []string) error { - tlog.NewSimpleLogger().Init() + log := logger.NewLogger().WithSimpleConfig() + log.Init() if tCfg.Interactive { form := huh.NewForm( @@ -88,9 +89,9 @@ func generateTotpCmd() *cli.Command { secret := key.Secret() - tlog.App.Info().Str("secret", secret).Msg("Generated TOTP secret") + log.App.Info().Str("secret", secret).Msg("Generated TOTP secret") - tlog.App.Info().Msg("Generated QR code") + log.App.Info().Msg("Generated QR code") config := qrterminal.Config{ Level: qrterminal.L, @@ -109,7 +110,7 @@ func generateTotpCmd() *cli.Command { user.Password = strings.ReplaceAll(user.Password, "$", "$$") } - tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") + log.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") return nil }, diff --git a/cmd/tinyauth/healthcheck.go b/cmd/tinyauth/healthcheck.go index 649a68c7..921479a5 100644 --- a/cmd/tinyauth/healthcheck.go +++ b/cmd/tinyauth/healthcheck.go @@ -9,8 +9,8 @@ import ( "os" "time" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/paerser/cli" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) type healthzResponse struct { @@ -26,7 +26,8 @@ func healthcheckCmd() *cli.Command { Resources: nil, AllowArg: true, Run: func(args []string) error { - tlog.NewSimpleLogger().Init() + log := logger.NewLogger().WithSimpleConfig() + log.Init() srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS") if srvAddr == "" { @@ -48,7 +49,7 @@ func healthcheckCmd() *cli.Command { return errors.New("Could not determine app URL") } - tlog.App.Info().Str("app_url", appUrl).Msg("Performing health check") + log.App.Info().Str("app_url", appUrl).Msg("Performing health check") client := http.Client{ Timeout: 30 * time.Second, @@ -86,7 +87,7 @@ func healthcheckCmd() *cli.Command { return fmt.Errorf("failed to decode response: %w", err) } - tlog.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy") + log.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy") return nil }, diff --git a/cmd/tinyauth/tinyauth.go b/cmd/tinyauth/tinyauth.go index f5bbb19f..b6293718 100644 --- a/cmd/tinyauth/tinyauth.go +++ b/cmd/tinyauth/tinyauth.go @@ -7,7 +7,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/loaders" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/rs/zerolog/log" "github.com/tinyauthapp/paerser/cli" @@ -109,11 +108,6 @@ func main() { } func runCmd(cfg model.Config) error { - logger := tlog.NewLogger(cfg.Log) - logger.Init() - - tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth") - app := bootstrap.NewBootstrapApp(cfg) err := app.Setup() diff --git a/cmd/tinyauth/verify_user.go b/cmd/tinyauth/verify_user.go index 5058b606..b0347f6f 100644 --- a/cmd/tinyauth/verify_user.go +++ b/cmd/tinyauth/verify_user.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "charm.land/huh/v2" "github.com/pquerna/otp/totp" @@ -44,7 +44,8 @@ func verifyUserCmd() *cli.Command { Configuration: tCfg, Resources: loaders, Run: func(_ []string) error { - tlog.NewSimpleLogger().Init() + log := logger.NewLogger().WithSimpleConfig() + log.Init() if tCfg.Interactive { form := huh.NewForm( @@ -97,9 +98,9 @@ func verifyUserCmd() *cli.Command { if user.TOTPSecret == "" { if tCfg.Totp != "" { - tlog.App.Warn().Msg("User does not have TOTP secret") + log.App.Warn().Msg("User does not have TOTP secret") } - tlog.App.Info().Msg("User verified") + log.App.Info().Msg("User verified") return nil } @@ -109,7 +110,7 @@ func verifyUserCmd() *cli.Command { return fmt.Errorf("TOTP code incorrect") } - tlog.App.Info().Msg("User verified") + log.App.Info().Msg("User verified") return nil }, diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 5b342c48..3f491fa1 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -3,39 +3,50 @@ package bootstrap import ( "bytes" "context" + "database/sql" "encoding/json" + "errors" "fmt" + "net" "net/http" "net/url" "os" + "os/signal" "sort" "strings" + "sync" + "syscall" "time" - "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/gin-gonic/gin" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) +type Services struct { + accessControlService *service.AccessControlsService + authService *service.AuthService + dockerService *service.DockerService + kubernetesService *service.KubernetesService + ldapService *service.LdapService + oauthBrokerService *service.OAuthBrokerService + oidcService *service.OIDCService +} + type BootstrapApp struct { - config model.Config - context struct { - appUrl string - uuid string - cookieDomain string - sessionCookieName string - csrfCookieName string - redirectCookieName string - oauthSessionCookieName string - localUsers *[]model.LocalUser - oauthProviders map[string]model.OAuthServiceConfig - oauthWhitelist []string - configuredProviders []controller.Provider - oidcClients []model.OIDCClientConfig - } + config model.Config + runtime model.RuntimeConfig services Services + log *logger.Logger + ctx context.Context + cancel context.CancelFunc + queries *repository.Queries + router *gin.Engine + db *sql.DB + wg sync.WaitGroup } func NewBootstrapApp(config model.Config) *BootstrapApp { @@ -45,56 +56,69 @@ func NewBootstrapApp(config model.Config) *BootstrapApp { } func (app *BootstrapApp) Setup() error { + // create context + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + app.ctx = ctx + app.cancel = cancel + + // setup logger + log := logger.NewLogger().WithConfig(app.config.Log) + log.Init() + app.log = log + // get app url if app.config.AppURL == "" { - return fmt.Errorf("app URL cannot be empty, perhaps config loading failed") + return errors.New("app url cannot be empty, perhaps config loading failed") } appUrl, err := url.Parse(app.config.AppURL) if err != nil { - return err + return fmt.Errorf("failed to parse app url: %w", err) } - app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host + app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host // validate session config if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { - return fmt.Errorf("session max lifetime cannot be less than session expiry") + return errors.New("session max lifetime cannot be less than session expiry") } - // Parse users + // parse users users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes) if err != nil { - return err + return fmt.Errorf("failed to load users: %w", err) } - app.context.localUsers = users + app.runtime.LocalUsers = *users + // load oauth whitelist oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile) + if err != nil { - return err + return fmt.Errorf("failed to load oauth whitelist: %w", err) } - app.context.oauthWhitelist = oauthWhitelist + app.runtime.OAuthWhitelist = oauthWhitelist - // Setup OAuth providers - app.context.oauthProviders = app.config.OAuth.Providers + // setup oauth providers + app.runtime.OAuthProviders = app.config.OAuth.Providers - for name, provider := range app.context.oauthProviders { + for id, provider := range app.runtime.OAuthProviders { secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile) provider.ClientSecret = secret provider.ClientSecretFile = "" if provider.RedirectURL == "" { - provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name + provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id } - app.context.oauthProviders[name] = provider + app.runtime.OAuthProviders[id] = provider } - for id, provider := range app.context.oauthProviders { + // set presets for built-in providers + for id, provider := range app.runtime.OAuthProviders { if provider.Name == "" { if name, ok := model.OverrideProviders[id]; ok { provider.Name = name @@ -102,71 +126,72 @@ func (app *BootstrapApp) Setup() error { provider.Name = utils.Capitalize(id) } } - app.context.oauthProviders[id] = provider + app.runtime.OAuthProviders[id] = provider } - // Setup OIDC clients + // setup oidc clients for id, client := range app.config.OIDC.Clients { client.ID = id - app.context.oidcClients = append(app.context.oidcClients, client) + app.runtime.OIDCClients = append(app.runtime.OIDCClients, client) } - // Get cookie domain + // cookie domain cookieDomainResolver := utils.GetCookieDomain + if !app.config.Auth.SubdomainsEnabled { - tlog.App.Info().Msg("Subdomains disabled, automatic authentication for proxied apps will not work") + app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains") cookieDomainResolver = utils.GetStandaloneCookieDomain } - cookieDomain, err := cookieDomainResolver(app.context.appUrl) + cookieDomain, err := cookieDomainResolver(app.runtime.AppURL) if err != nil { - return err + return fmt.Errorf("failed to get cookie domain: %w", err) } - app.context.cookieDomain = cookieDomain + app.runtime.CookieDomain = cookieDomain - // Cookie names - app.context.uuid = utils.GenerateUUID(appUrl.Hostname()) - cookieId := strings.Split(app.context.uuid, "-")[0] - app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) - app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) - app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) - app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) + // cookie names + app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname()) - // Dumps - tlog.App.Trace().Interface("config", app.config).Msg("Config dump") - tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump") - tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump") - tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain") - tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name") - tlog.App.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name") - tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name") + cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough - // Database - db, err := app.SetupDatabase(app.config.Database.Path) + app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) + app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) + app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) + app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) + + // database + err = app.SetupDatabase() if err != nil { return fmt.Errorf("failed to setup database: %w", err) } - // Queries - queries := repository.New(db) + // after this point, we start initializing dependencies so it's a good time to setup a defer + // to ensure that resources are cleaned up properly in case of an error during initialization + defer func() { + app.cancel() + app.wg.Wait() + app.db.Close() + }() - // Services - services, err := app.initServices(queries) + // queries + queries := repository.New(app.db) + app.queries = queries + + // services + err = app.setupServices() if err != nil { return fmt.Errorf("failed to initialize services: %w", err) } - app.services = services + // configured providers + configuredProviders := make([]model.Provider, 0) - // Configured providers - configuredProviders := make([]controller.Provider, 0) - - for id, provider := range app.context.oauthProviders { - configuredProviders = append(configuredProviders, controller.Provider{ + for id, provider := range app.runtime.OAuthProviders { + configuredProviders = append(configuredProviders, model.Provider{ Name: provider.Name, ID: id, OAuth: true, @@ -177,70 +202,171 @@ func (app *BootstrapApp) Setup() error { return configuredProviders[i].Name < configuredProviders[j].Name }) - if services.authService.LocalAuthConfigured() { - configuredProviders = append(configuredProviders, controller.Provider{ + if app.services.authService.LocalAuthConfigured() { + configuredProviders = append(configuredProviders, model.Provider{ Name: "Local", ID: "local", OAuth: false, }) } - if services.authService.LDAPAuthConfigured() { - configuredProviders = append(configuredProviders, controller.Provider{ + if app.services.authService.LDAPAuthConfigured() { + configuredProviders = append(configuredProviders, model.Provider{ Name: "LDAP", ID: "ldap", OAuth: false, }) } - tlog.App.Debug().Interface("providers", configuredProviders).Msg("Authentication providers") - if len(configuredProviders) == 0 { - return fmt.Errorf("no authentication providers configured") + return errors.New("no authentication providers configured") } - app.context.configuredProviders = configuredProviders + for _, provider := range configuredProviders { + app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider") + } - // Setup router - router, err := app.setupRouter() + app.runtime.ConfiguredProviders = configuredProviders + + // setup router + err = app.setupRouter() if err != nil { return fmt.Errorf("failed to setup routes: %w", err) } - // Start db cleanup routine - tlog.App.Debug().Msg("Starting database cleanup routine") - go app.dbCleanupRoutine(queries) + // start db cleanup routine + app.log.App.Debug().Msg("Starting database cleanup routine") + app.wg.Go(app.dbCleanupRoutine) - // If analytics are not disabled, start heartbeat + // if analytics are not disabled, start heartbeat if app.config.Analytics.Enabled { - tlog.App.Debug().Msg("Starting heartbeat routine") - go app.heartbeatRoutine() + app.log.App.Debug().Msg("Starting heartbeat routine") + app.wg.Go(app.heartbeatRoutine) } - // If we have an socket path, bind to it - if app.config.Server.SocketPath != "" { - if _, err := os.Stat(app.config.Server.SocketPath); err == nil { - tlog.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath) - err := os.Remove(app.config.Server.SocketPath) + // create err channel to listen for server errors + errChanLen := 0 + + runUnix := app.config.Server.SocketPath != "" + runHTTP := app.config.Server.SocketPath == "" || app.config.Server.ConcurrentListenersEnabled + + if runUnix { + errChanLen++ + } + + if runHTTP { + errChanLen++ + } + + errChan := make(chan error, errChanLen) + + if app.config.Server.ConcurrentListenersEnabled { + app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners") + } + + // serve unix + if runUnix { + app.wg.Go(func() { + if err := app.serveUnix(); err != nil { + errChan <- err + } + }) + } + + // serve to http + if runHTTP { + app.wg.Go(func() { + if err := app.serveHTTP(); err != nil { + errChan <- err + } + }) + } + + // monitor cancellation and server errors + for { + select { + case <-app.ctx.Done(): + app.log.App.Info().Msg("Oh, it's time for me to go, bye!") + return nil + case err := <-errChan: if err != nil { - return fmt.Errorf("failed to remove existing socket file: %w", err) + return fmt.Errorf("server error: %w", err) } } + } +} - tlog.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) - if err := router.RunUnix(app.config.Server.SocketPath); err != nil { - tlog.App.Fatal().Err(err).Msg("Failed to start server") - } +func (app *BootstrapApp) serveHTTP() error { + address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) + app.log.App.Info().Msgf("Starting server on %s", address) + + server := &http.Server{ + Addr: address, + Handler: app.router.Handler(), + } + + go func() { + <-app.ctx.Done() + app.log.App.Debug().Msg("Shutting down http listener") + server.Shutdown(app.ctx) + }() + + err := server.ListenAndServe() + + if err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("failed to start http listener: %w", err) + } + + return nil +} + +func (app *BootstrapApp) serveUnix() error { + if app.config.Server.SocketPath == "" { return nil } - // Start server - address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) - tlog.App.Info().Msgf("Starting server on %s", address) - if err := router.Run(address); err != nil { - tlog.App.Fatal().Err(err).Msg("Failed to start server") + _, err := os.Stat(app.config.Server.SocketPath) + + if err == nil { + app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath) + err := os.Remove(app.config.Server.SocketPath) + + if err != nil { + return fmt.Errorf("failed to remove existing socket file: %w", err) + } + } + + app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) + + listener, err := net.Listen("unix", app.config.Server.SocketPath) + + if err != nil { + return fmt.Errorf("failed to create unix socket listener: %w", err) + } + + server := &http.Server{ + Handler: app.router.Handler(), + } + + shutdown := func() { + server.Shutdown(app.ctx) + listener.Close() + os.Remove(app.config.Server.SocketPath) + } + + go func() { + <-app.ctx.Done() + app.log.App.Debug().Msg("Shutting down unix socket listener") + shutdown() + }() + + err = server.Serve(listener) + + if err != nil && !errors.Is(err, http.ErrServerClosed) { + shutdown() + return fmt.Errorf("failed to start unix socket listener: %w", err) } return nil @@ -250,20 +376,20 @@ func (app *BootstrapApp) heartbeatRoutine() { ticker := time.NewTicker(time.Duration(12) * time.Hour) defer ticker.Stop() - type heartbeat struct { + type Heartbeat struct { UUID string `json:"uuid"` Version string `json:"version"` } - var body heartbeat + var body Heartbeat - body.UUID = app.context.uuid + body.UUID = app.runtime.UUID body.Version = model.Version bodyJson, err := json.Marshal(body) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to marshal heartbeat body") + app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start") return } @@ -273,43 +399,60 @@ func (app *BootstrapApp) heartbeatRoutine() { heartbeatURL := model.APIServer + "/v1/instances/heartbeat" - for range ticker.C { - tlog.App.Debug().Msg("Sending heartbeat") + for { + select { + case <-ticker.C: + app.log.App.Debug().Msg("Sending heartbeat") - req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) + req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create heartbeat request") - continue - } + if err != nil { + app.log.App.Error().Err(err).Msg("Failed to create heartbeat request") + continue + } - req.Header.Add("Content-Type", "application/json") + req.Header.Add("Content-Type", "application/json") - res, err := client.Do(req) + res, err := client.Do(req) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to send heartbeat") - continue - } + if err != nil { + app.log.App.Error().Err(err).Msg("Failed to send heartbeat") + continue + } - res.Body.Close() + res.Body.Close() - if res.StatusCode != 200 && res.StatusCode != 201 { - tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") + if res.StatusCode != 200 && res.StatusCode != 201 { + app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") + } + case <-app.ctx.Done(): + app.log.App.Debug().Msg("Stopping heartbeat routine") + ticker.Stop() + return } } } -func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) { +func (app *BootstrapApp) dbCleanupRoutine() { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() - ctx := context.Background() - for range ticker.C { - tlog.App.Debug().Msg("Cleaning up old database sessions") - err := queries.DeleteExpiredSessions(ctx, time.Now().Unix()) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to clean up old database sessions") + for { + select { + case <-ticker.C: + app.log.App.Debug().Msg("Running database cleanup") + + err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix()) + + if err != nil { + app.log.App.Error().Err(err).Msg("Failed to delete expired sessions") + } + + app.log.App.Debug().Msg("Database cleanup completed") + case <-app.ctx.Done(): + app.log.App.Debug().Msg("Stopping database cleanup routine") + ticker.Stop() + return } } } diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 3f48f793..d8572c4c 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -14,19 +14,26 @@ import ( _ "modernc.org/sqlite" ) -func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { - dir := filepath.Dir(databasePath) +func (app *BootstrapApp) SetupDatabase() error { + dir := filepath.Dir(app.config.Database.Path) if err := os.MkdirAll(dir, 0750); err != nil { - return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err) + return fmt.Errorf("failed to create database directory %s: %w", dir, err) } - db, err := sql.Open("sqlite", databasePath) + db, err := sql.Open("sqlite", app.config.Database.Path) if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) + return fmt.Errorf("failed to open database: %w", err) } + // Close the database if there is an error during migration + defer func() { + if err != nil { + db.Close() + } + }() + // Limit to 1 connection to sequence writes, this may need to be revisited in the future // if the sqlite connection starts being a bottleneck db.SetMaxOpenConns(1) @@ -34,24 +41,29 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { migrations, err := iofs.New(assets.Migrations, "migrations") if err != nil { - return nil, fmt.Errorf("failed to create migrations: %w", err) + return fmt.Errorf("failed to create migrations: %w", err) } target, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { - return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err) + return fmt.Errorf("failed to create sqlite3 instance: %w", err) } migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target) if err != nil { - return nil, fmt.Errorf("failed to create migrator: %w", err) + return fmt.Errorf("failed to create migrator: %w", err) } if err := migrator.Up(); err != nil && err != migrate.ErrNoChange { - return nil, fmt.Errorf("failed to migrate database: %w", err) + return fmt.Errorf("failed to migrate database: %w", err) } - return db, nil + app.db = db + return nil +} + +func (app *BootstrapApp) GetDB() *sql.DB { + return app.db } diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index a746be79..12a48bc0 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -2,21 +2,16 @@ package bootstrap import ( "fmt" - "slices" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/middleware" - "github.com/tinyauthapp/tinyauth/internal/model" "github.com/gin-gonic/gin" ) -var DEV_MODES = []string{"main", "test", "development"} - -func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { - if !slices.Contains(DEV_MODES, model.Version) { - gin.SetMode(gin.ReleaseMode) - } +func (app *BootstrapApp) setupRouter() error { + // we don't want gin debug mode + gin.SetMode(gin.ReleaseMode) engine := gin.New() engine.Use(gin.Recovery()) @@ -25,101 +20,36 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies) if err != nil { - return nil, fmt.Errorf("failed to set trusted proxies: %w", err) + return fmt.Errorf("failed to set trusted proxies: %w", err) } } - contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ - CookieDomain: app.context.cookieDomain, - SessionCookieName: app.context.sessionCookieName, - }, app.services.authService, app.services.oauthBrokerService) - - err := contextMiddleware.Init() - - if err != nil { - return nil, fmt.Errorf("failed to initialize context middleware: %w", err) - } - + contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService) engine.Use(contextMiddleware.Middleware()) - uiMiddleware := middleware.NewUIMiddleware() - - err = uiMiddleware.Init() + uiMiddleware, err := middleware.NewUIMiddleware() if err != nil { - return nil, fmt.Errorf("failed to initialize UI middleware: %w", err) + return fmt.Errorf("failed to initialize UI middleware: %w", err) } engine.Use(uiMiddleware.Middleware()) - zerologMiddleware := middleware.NewZerologMiddleware() - - err = zerologMiddleware.Init() - - if err != nil { - return nil, fmt.Errorf("failed to initialize zerolog middleware: %w", err) - } + zerologMiddleware := middleware.NewZerologMiddleware(app.log) engine.Use(zerologMiddleware.Middleware()) apiRouter := engine.Group("/api") - contextController := controller.NewContextController(controller.ContextControllerConfig{ - Providers: app.context.configuredProviders, - Title: app.config.UI.Title, - AppURL: app.config.AppURL, - CookieDomain: app.context.cookieDomain, - ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage, - BackgroundImage: app.config.UI.BackgroundImage, - OAuthAutoRedirect: app.config.OAuth.AutoRedirect, - WarningsEnabled: app.config.UI.WarningsEnabled, - }, apiRouter) + controller.NewContextController(app.log, app.config, app.runtime, apiRouter) + controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) + controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter) + controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService) + controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) + controller.NewResourcesController(app.config, &engine.RouterGroup) + controller.NewHealthController(apiRouter) + controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup) - contextController.SetupRoutes() - - oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ - AppURL: app.config.AppURL, - SecureCookie: app.config.Auth.SecureCookie, - CSRFCookieName: app.context.csrfCookieName, - RedirectCookieName: app.context.redirectCookieName, - CookieDomain: app.context.cookieDomain, - OAuthSessionCookieName: app.context.oauthSessionCookieName, - SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, - }, apiRouter, app.services.authService) - - oauthController.SetupRoutes() - - oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter) - - oidcController.SetupRoutes() - - proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ - AppURL: app.config.AppURL, - }, apiRouter, app.services.accessControlService, app.services.authService) - - proxyController.SetupRoutes() - - userController := controller.NewUserController(controller.UserControllerConfig{ - CookieDomain: app.context.cookieDomain, - SessionCookieName: app.context.sessionCookieName, - }, apiRouter, app.services.authService) - - userController.SetupRoutes() - - resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{ - Path: app.config.Resources.Path, - Enabled: app.config.Resources.Enabled, - }, &engine.RouterGroup) - - resourcesController.SetupRoutes() - - healthController := controller.NewHealthController(apiRouter) - - healthController.SetupRoutes() - - wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine) - - wellknownController.SetupRoutes() - - return engine, nil + app.router = engine + return nil } diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 09485bd0..ef3ee591 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -1,131 +1,66 @@ package bootstrap import ( + "fmt" "os" - "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) -type Services struct { - accessControlService *service.AccessControlsService - authService *service.AuthService - dockerService *service.DockerService - kubernetesService *service.KubernetesService - ldapService *service.LdapService - oauthBrokerService *service.OAuthBrokerService - oidcService *service.OIDCService -} - -func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) { - services := Services{} - - ldapService := service.NewLdapService(service.LdapServiceConfig{ - Address: app.config.LDAP.Address, - BindDN: app.config.LDAP.BindDN, - BindPassword: app.config.LDAP.BindPassword, - BaseDN: app.config.LDAP.BaseDN, - Insecure: app.config.LDAP.Insecure, - SearchFilter: app.config.LDAP.SearchFilter, - AuthCert: app.config.LDAP.AuthCert, - AuthKey: app.config.LDAP.AuthKey, - }) - - err := ldapService.Init() +func (app *BootstrapApp) setupServices() error { + ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it") - ldapService.Unconfigure() + app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") } - services.ldapService = ldapService - - var labelProvider service.LabelProvider - var dockerService *service.DockerService - var kubernetesService *service.KubernetesService + app.services.ldapService = ldapService useKubernetes := app.config.LabelProvider == "kubernetes" || (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") + var labelProvider service.LabelProvider + if useKubernetes { - tlog.App.Debug().Msg("Using Kubernetes label provider") - kubernetesService = service.NewKubernetesService() - err = kubernetesService.Init() + app.log.App.Debug().Msg("Using Kubernetes label provider") + + kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg) + if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize kubernetes service: %w", err) } - services.kubernetesService = kubernetesService + + app.services.kubernetesService = kubernetesService labelProvider = kubernetesService } else { - tlog.App.Debug().Msg("Using Docker label provider") - dockerService = service.NewDockerService() - err = dockerService.Init() + app.log.App.Debug().Msg("Using Docker label provider") + + dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg) + if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize docker service: %w", err) } - services.dockerService = dockerService + + app.services.dockerService = dockerService labelProvider = dockerService } - accessControlsService := service.NewAccessControlsService(labelProvider, app.config.Apps) + accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps) + app.services.accessControlService = accessControlsService - err = accessControlsService.Init() + oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) + app.services.oauthBrokerService = oauthBrokerService + + authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService) + app.services.authService = authService + + oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg) if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize oidc service: %w", err) } - services.accessControlService = accessControlsService + app.services.oidcService = oidcService - oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders) - - err = oauthBrokerService.Init() - - if err != nil { - return Services{}, err - } - - services.oauthBrokerService = oauthBrokerService - - authService := service.NewAuthService(service.AuthServiceConfig{ - LocalUsers: app.context.localUsers, - OauthWhitelist: app.context.oauthWhitelist, - SessionExpiry: app.config.Auth.SessionExpiry, - SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, - SecureCookie: app.config.Auth.SecureCookie, - CookieDomain: app.context.cookieDomain, - LoginTimeout: app.config.Auth.LoginTimeout, - LoginMaxRetries: app.config.Auth.LoginMaxRetries, - SessionCookieName: app.context.sessionCookieName, - IP: app.config.Auth.IP, - LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL, - SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, - }, services.ldapService, queries, services.oauthBrokerService) - - err = authService.Init() - - if err != nil { - return Services{}, err - } - - services.authService = authService - - oidcService := service.NewOIDCService(service.OIDCServiceConfig{ - Clients: app.config.OIDC.Clients, - PrivateKeyPath: app.config.OIDC.PrivateKeyPath, - PublicKeyPath: app.config.OIDC.PublicKeyPath, - Issuer: app.config.AppURL, - SessionExpiry: app.config.Auth.SessionExpiry, - }, queries) - - err = oidcService.Init() - - if err != nil { - return Services{}, err - } - - services.oidcService = oidcService - - return services, nil + return nil } diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index f939ba55..8d9f5fa2 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -5,7 +5,7 @@ import ( "net/url" "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" ) @@ -24,62 +24,52 @@ type UserContextResponse struct { } type AppContextResponse struct { - Status int `json:"status"` - Message string `json:"message"` - Providers []Provider `json:"providers"` - Title string `json:"title"` - AppURL string `json:"appUrl"` - CookieDomain string `json:"cookieDomain"` - ForgotPasswordMessage string `json:"forgotPasswordMessage"` - BackgroundImage string `json:"backgroundImage"` - OAuthAutoRedirect string `json:"oauthAutoRedirect"` - WarningsEnabled bool `json:"warningsEnabled"` -} - -type Provider struct { - Name string `json:"name"` - ID string `json:"id"` - OAuth bool `json:"oauth"` -} - -type ContextControllerConfig struct { - Providers []Provider - Title string - AppURL string - CookieDomain string - ForgotPasswordMessage string - BackgroundImage string - OAuthAutoRedirect string - WarningsEnabled bool + Status int `json:"status"` + Message string `json:"message"` + Providers []model.Provider `json:"providers"` + Title string `json:"title"` + AppURL string `json:"appUrl"` + CookieDomain string `json:"cookieDomain"` + ForgotPasswordMessage string `json:"forgotPasswordMessage"` + BackgroundImage string `json:"backgroundImage"` + OAuthAutoRedirect string `json:"oauthAutoRedirect"` + WarningsEnabled bool `json:"warningsEnabled"` } type ContextController struct { - config ContextControllerConfig - router *gin.RouterGroup + log *logger.Logger + config model.Config + runtime model.RuntimeConfig } -func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController { - if !config.WarningsEnabled { - tlog.App.Warn().Msg("UI warnings are disabled. This may expose users to security risks. Proceed with caution.") +func NewContextController( + log *logger.Logger, + config model.Config, + runtimeConfig model.RuntimeConfig, + router *gin.RouterGroup, +) *ContextController { + controller := &ContextController{ + log: log, + config: config, + runtime: runtimeConfig, } - return &ContextController{ - config: config, - router: router, + if !config.UI.WarningsEnabled { + log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.") } -} -func (controller *ContextController) SetupRoutes() { - contextGroup := controller.router.Group("/context") + contextGroup := router.Group("/context") contextGroup.GET("/user", controller.userContextHandler) contextGroup.GET("/app", controller.appContextHandler) + + return controller } func (controller *ContextController) userContextHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err != nil { - tlog.App.Debug().Err(err).Msg("No user context found in request") + controller.log.App.Error().Err(err).Msg("Failed to create user context from request") c.JSON(200, UserContextResponse{ Status: 401, Message: "Unauthorized", @@ -105,9 +95,10 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { } func (controller *ContextController) appContextHandler(c *gin.Context) { - appUrl, err := url.Parse(controller.config.AppURL) + appUrl, err := url.Parse(controller.runtime.AppURL) + if err != nil { - tlog.App.Error().Err(err).Msg("Failed to parse app URL") + controller.log.App.Error().Err(err).Msg("Failed to parse app URL") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -118,13 +109,13 @@ func (controller *ContextController) appContextHandler(c *gin.Context) { c.JSON(200, AppContextResponse{ Status: 200, Message: "Success", - Providers: controller.config.Providers, - Title: controller.config.Title, + Providers: controller.runtime.ConfiguredProviders, + Title: controller.config.UI.Title, AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), - CookieDomain: controller.config.CookieDomain, - ForgotPasswordMessage: controller.config.ForgotPasswordMessage, - BackgroundImage: controller.config.BackgroundImage, - OAuthAutoRedirect: controller.config.OAuthAutoRedirect, - WarningsEnabled: controller.config.WarningsEnabled, + CookieDomain: controller.runtime.CookieDomain, + ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage, + BackgroundImage: controller.config.UI.BackgroundImage, + OAuthAutoRedirect: controller.config.OAuth.AutoRedirect, + WarningsEnabled: controller.config.UI.WarningsEnabled, }) } diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go index 12a8e22b..177f4744 100644 --- a/internal/controller/context_controller_test.go +++ b/internal/controller/context_controller_test.go @@ -8,30 +8,19 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestContextController(t *testing.T) { - tlog.NewTestLogger().Init() - controllerConfig := controller.ContextControllerConfig{ - Providers: []controller.Provider{ - { - Name: "Local", - ID: "local", - OAuth: false, - }, - }, - Title: "Tinyauth", - AppURL: "https://tinyauth.example.com", - CookieDomain: "example.com", - ForgotPasswordMessage: "foo", - BackgroundImage: "/background.jpg", - OAuthAutoRedirect: "none", - WarningsEnabled: true, - } + log := logger.NewLogger().WithTestConfig() + log.Init() + + cfg, runtime := test.CreateTestConfigs(t) tests := []struct { description string @@ -47,17 +36,17 @@ func TestContextController(t *testing.T) { expectedAppContextResponse := controller.AppContextResponse{ Status: 200, Message: "Success", - Providers: controllerConfig.Providers, - Title: controllerConfig.Title, - AppURL: controllerConfig.AppURL, - CookieDomain: controllerConfig.CookieDomain, - ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage, - BackgroundImage: controllerConfig.BackgroundImage, - OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect, - WarningsEnabled: controllerConfig.WarningsEnabled, + Providers: runtime.ConfiguredProviders, + Title: cfg.UI.Title, + AppURL: runtime.AppURL, + CookieDomain: runtime.CookieDomain, + ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage, + BackgroundImage: cfg.UI.BackgroundImage, + OAuthAutoRedirect: cfg.OAuth.AutoRedirect, + WarningsEnabled: cfg.UI.WarningsEnabled, } bytes, err := json.Marshal(expectedAppContextResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -71,7 +60,7 @@ func TestContextController(t *testing.T) { Message: "Unauthorized", } bytes, err := json.Marshal(expectedUserContextResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -86,7 +75,7 @@ func TestContextController(t *testing.T) { BaseContext: model.BaseContext{ Username: "johndoe", Name: "John Doe", - Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), + Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain), }, }, }) @@ -100,11 +89,11 @@ func TestContextController(t *testing.T) { IsLoggedIn: true, Username: "johndoe", Name: "John Doe", - Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), + Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain), Provider: "local", } bytes, err := json.Marshal(expectedUserContextResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -121,13 +110,12 @@ func TestContextController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - contextController := controller.NewContextController(controllerConfig, group) - contextController.SetupRoutes() + controller.NewContextController(log, cfg, runtime, group) recorder := httptest.NewRecorder() request, err := http.NewRequest("GET", test.path, nil) - assert.NoError(t, err) + require.NoError(t, err) router.ServeHTTP(recorder, request) diff --git a/internal/controller/health_controller.go b/internal/controller/health_controller.go index 1b9adbf9..8e84e62b 100644 --- a/internal/controller/health_controller.go +++ b/internal/controller/health_controller.go @@ -3,18 +3,15 @@ package controller import "github.com/gin-gonic/gin" type HealthController struct { - router *gin.RouterGroup } func NewHealthController(router *gin.RouterGroup) *HealthController { - return &HealthController{ - router: router, - } -} + controller := &HealthController{} -func (controller *HealthController) SetupRoutes() { - controller.router.GET("/healthz", controller.healthHandler) - controller.router.HEAD("/healthz", controller.healthHandler) + router.GET("/healthz", controller.healthHandler) + router.HEAD("/healthz", controller.healthHandler) + + return controller } func (controller *HealthController) healthHandler(c *gin.Context) { diff --git a/internal/controller/health_controller_test.go b/internal/controller/health_controller_test.go index d1bed3b6..7576d518 100644 --- a/internal/controller/health_controller_test.go +++ b/internal/controller/health_controller_test.go @@ -7,13 +7,12 @@ import ( "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/controller" ) func TestHealthController(t *testing.T) { - tlog.NewTestLogger().Init() tests := []struct { description string path string @@ -30,7 +29,7 @@ func TestHealthController(t *testing.T) { "message": "Healthy", } bytes, err := json.Marshal(expectedHealthResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -44,7 +43,7 @@ func TestHealthController(t *testing.T) { "message": "Healthy", } bytes, err := json.Marshal(expectedHealthResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -56,13 +55,12 @@ func TestHealthController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - healthController := controller.NewHealthController(group) - healthController.SetupRoutes() + controller.NewHealthController(group) recorder := httptest.NewRecorder() request, err := http.NewRequest(test.method, test.path, nil) - assert.NoError(t, err) + require.NoError(t, err) router.ServeHTTP(recorder, request) diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 7f6d6ce0..1aec73ae 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -6,10 +6,11 @@ import ( "strings" "time" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" @@ -19,34 +20,32 @@ type OAuthRequest struct { Provider string `uri:"provider" binding:"required"` } -type OAuthControllerConfig struct { - CSRFCookieName string - OAuthSessionCookieName string - RedirectCookieName string - SecureCookie bool - AppURL string - CookieDomain string - SubdomainsEnabled bool -} - type OAuthController struct { - config OAuthControllerConfig - router *gin.RouterGroup - auth *service.AuthService + log *logger.Logger + config model.Config + runtime model.RuntimeConfig + auth *service.AuthService } -func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController { - return &OAuthController{ - config: config, - router: router, - auth: auth, +func NewOAuthController( + log *logger.Logger, + config model.Config, + runtimeConfig model.RuntimeConfig, + router *gin.RouterGroup, + auth *service.AuthService, +) *OAuthController { + controller := &OAuthController{ + log: log, + config: config, + runtime: runtimeConfig, + auth: auth, } -} -func (controller *OAuthController) SetupRoutes() { - oauthGroup := controller.router.Group("/oauth") + oauthGroup := router.Group("/oauth") oauthGroup.GET("/url/:provider", controller.oauthURLHandler) oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) + + return controller } func (controller *OAuthController) oauthURLHandler(c *gin.Context) { @@ -54,7 +53,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { err := c.BindUri(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind URI") + controller.log.App.Error().Err(err).Msg("Failed to bind URI") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -67,7 +66,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { err = c.BindQuery(&reqParams) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind query parameters") + controller.log.App.Error().Err(err).Msg("Failed to bind query parameters") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -76,10 +75,10 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { } if !controller.isOidcRequest(reqParams) { - isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain) + isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain) if !isRedirectSafe { - tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring") + controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring") reqParams.RedirectURI = "" } } @@ -87,7 +86,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create OAuth session") + controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -98,7 +97,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { authUrl, err := controller.auth.GetOAuthURL(sessionId) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get OAuth URL") + controller.log.App.Error().Err(err).Msg("Failed to get OAuth URL for session") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -106,7 +105,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.SecureCookie, true) + c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) c.JSON(200, gin.H{ "status": 200, @@ -120,7 +119,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { err := c.BindUri(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind URI") + controller.log.App.Error().Err(err).Msg("Failed to bind URI") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -128,21 +127,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName) + sessionIdCookie, err := c.Cookie(controller.runtime.OAuthSessionCookieName) if err != nil { - tlog.App.Warn().Err(err).Msg("OAuth session cookie missing") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.SecureCookie, true) + c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } @@ -150,8 +149,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { state := c.Query("state") if state != oauthPendingSession.State { - tlog.App.Warn().Err(err).Msg("CSRF token mismatch") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Warn().Msg("OAuth state mismatch") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } @@ -159,68 +158,80 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { _, err = controller.auth.GetOAuthToken(sessionIdCookie, code) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to exchange code for token") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to exchange code for token") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie) + if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) + return + } + + if user == nil { + controller.log.App.Warn().Msg("OAuth provider did not return user info") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) + return + } + if user.Email == "" { - tlog.App.Error().Msg("OAuth provider did not return an email") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Warn().Msg("OAuth provider did not return an email") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } if !controller.auth.IsEmailWhitelisted(user.Email) { - tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted") - tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted") + controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access") + controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted") queries, err := query.Values(UnauthorizedQuery{ Username: user.Email, }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())) return } var name string if strings.TrimSpace(user.Name) != "" { - tlog.App.Debug().Msg("Using name from OAuth provider") + controller.log.App.Debug().Msg("Using name from OAuth provider") name = user.Name } else { - tlog.App.Debug().Msg("No name from OAuth provider, using pseudo name") + controller.log.App.Debug().Msg("No name from OAuth provider, generating from email") name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) } var username string if strings.TrimSpace(user.PreferredUsername) != "" { - tlog.App.Debug().Msg("Using preferred username from OAuth provider") + controller.log.App.Debug().Msg("Using preferred username from OAuth provider") username = user.PreferredUsername } else { - tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username") + controller.log.App.Debug().Msg("No preferred username from OAuth provider, generating from email") username = strings.Replace(user.Email, "@", "_", 1) } svc, err := controller.auth.GetOAuthService(sessionIdCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } if svc.ID() != req.Provider { - tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider) - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID()) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } @@ -234,29 +245,29 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { OAuthSub: user.Sub, } - tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") + controller.log.App.Debug().Msg("Creating session cookie for user") cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to create session cookie") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } http.SetCookie(c.Writer, cookie) - tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) + controller.log.AuditLoginSuccess(sessionCookie.Username, sessionCookie.Provider, c.ClientIP()) if controller.isOidcRequest(oauthPendingSession.CallbackParams) { - tlog.App.Debug().Msg("OIDC request, redirecting to authorize page") + controller.log.App.Debug().Msg("OIDC request detected, redirecting to authorization endpoint with callback params") queries, err := query.Values(oauthPendingSession.CallbackParams) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode())) return } @@ -266,16 +277,16 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.log.App.Error().Err(err).Msg("Failed to encode redirect query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode())) return } - c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL) + c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL) } func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool { @@ -286,8 +297,8 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) } func (controller *OAuthController) getCookieDomain() string { - if controller.config.SubdomainsEnabled { - return "." + controller.config.CookieDomain + if controller.config.Auth.SubdomainsEnabled { + return "." + controller.runtime.CookieDomain } - return controller.config.CookieDomain + return controller.runtime.CookieDomain } diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 5e3f75f5..142f0b40 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -13,15 +13,13 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -type OIDCControllerConfig struct{} - type OIDCController struct { - config OIDCControllerConfig - router *gin.RouterGroup - oidc *service.OIDCService + log *logger.Logger + oidc *service.OIDCService + runtime model.RuntimeConfig } type AuthorizeCallback struct { @@ -58,29 +56,42 @@ type ClientCredentials struct { ClientSecret string } -func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController { - return &OIDCController{ - config: config, - oidc: oidcService, - router: router, +func NewOIDCController( + log *logger.Logger, + oidcService *service.OIDCService, + runtimeConfig model.RuntimeConfig, + router *gin.RouterGroup) *OIDCController { + controller := &OIDCController{ + log: log, + oidc: oidcService, + runtime: runtimeConfig, } -} -func (controller *OIDCController) SetupRoutes() { - oidcGroup := controller.router.Group("/oidc") + oidcGroup := router.Group("/oidc") oidcGroup.GET("/clients/:id", controller.GetClientInfo) oidcGroup.POST("/authorize", controller.Authorize) oidcGroup.POST("/token", controller.Token) oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.POST("/userinfo", controller.Userinfo) + + return controller } func (controller *OIDCController) GetClientInfo(c *gin.Context) { + if controller.oidc == nil { + controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured") + c.JSON(500, gin.H{ + "status": 500, + "message": "OIDC not configured", + }) + return + } + var req ClientRequest err := c.BindUri(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind URI") + controller.log.App.Error().Err(err).Msg("Failed to bind URI") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -91,7 +102,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) { client, ok := controller.oidc.GetClient(req.ClientID) if !ok { - tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found") + controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found") c.JSON(404, gin.H{ "status": 404, "message": "Client not found", @@ -107,7 +118,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) { } func (controller *OIDCController) Authorize(c *gin.Context) { - if !controller.oidc.IsConfigured() { + if controller.oidc == nil { controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "") return } @@ -142,7 +153,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { err = controller.oidc.ValidateAuthorizeParams(req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to validate authorize params") + controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params") if err.Error() != "invalid_request_uri" { controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State) return @@ -174,7 +185,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { err = controller.oidc.StoreUserinfo(c, sub, *userContext, req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to insert user info into database") + controller.log.App.Error().Err(err).Msg("Failed to store user info") controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State) return } @@ -197,10 +208,10 @@ func (controller *OIDCController) Authorize(c *gin.Context) { } func (controller *OIDCController) Token(c *gin.Context) { - if !controller.oidc.IsConfigured() { - tlog.App.Warn().Msg("OIDC not configured") - c.JSON(404, gin.H{ - "error": "not_found", + if controller.oidc == nil { + controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured") + c.JSON(500, gin.H{ + "error": "server_error", }) return } @@ -209,7 +220,7 @@ func (controller *OIDCController) Token(c *gin.Context) { err := c.Bind(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind token request") + controller.log.App.Warn().Err(err).Msg("Failed to bind token request") c.JSON(400, gin.H{ "error": "invalid_request", }) @@ -218,7 +229,7 @@ func (controller *OIDCController) Token(c *gin.Context) { err = controller.oidc.ValidateGrantType(req.GrantType) if err != nil { - tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type") + controller.log.App.Warn().Err(err).Msg("Invalid grant type") c.JSON(400, gin.H{ "error": err.Error(), }) @@ -233,12 +244,12 @@ func (controller *OIDCController) Token(c *gin.Context) { // If it fails, we try basic auth if creds.ClientID == "" || creds.ClientSecret == "" { - tlog.App.Debug().Msg("Tried form values and they are empty, trying basic auth") + controller.log.App.Debug().Msg("Client credentials not found in form, trying basic auth") clientId, clientSecret, ok := c.Request.BasicAuth() if !ok { - tlog.App.Error().Msg("Missing authorization header") + controller.log.App.Warn().Msg("Client credentials not found in basic auth") c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`) c.JSON(400, gin.H{ "error": "invalid_client", @@ -255,7 +266,7 @@ func (controller *OIDCController) Token(c *gin.Context) { client, ok := controller.oidc.GetClient(creds.ClientID) if !ok { - tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Client not found") + controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Client not found") c.JSON(400, gin.H{ "error": "invalid_client", }) @@ -263,7 +274,7 @@ func (controller *OIDCController) Token(c *gin.Context) { } if client.ClientSecret != creds.ClientSecret { - tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Invalid client secret") + controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Invalid client secret") c.JSON(400, gin.H{ "error": "invalid_client", }) @@ -277,30 +288,30 @@ func (controller *OIDCController) Token(c *gin.Context) { entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) if err != nil { if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil { - tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash") + controller.log.App.Error().Err(err).Msg("Failed to delete code") } if errors.Is(err, service.ErrCodeNotFound) { - tlog.App.Warn().Msg("Code not found") + controller.log.App.Warn().Msg("Code not found") c.JSON(400, gin.H{ "error": "invalid_grant", }) return } if errors.Is(err, service.ErrCodeExpired) { - tlog.App.Warn().Msg("Code expired") + controller.log.App.Warn().Msg("Code expired") c.JSON(400, gin.H{ "error": "invalid_grant", }) return } if errors.Is(err, service.ErrInvalidClient) { - tlog.App.Warn().Msg("Invalid client ID") + controller.log.App.Warn().Msg("Code does not belong to client") c.JSON(400, gin.H{ "error": "invalid_client", }) return } - tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry") + controller.log.App.Error().Err(err).Msg("Failed to get code entry") c.JSON(400, gin.H{ "error": "server_error", }) @@ -308,7 +319,7 @@ func (controller *OIDCController) Token(c *gin.Context) { } if entry.RedirectURI != req.RedirectURI { - tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch") + controller.log.App.Warn().Msg("Redirect URI does not match") c.JSON(400, gin.H{ "error": "invalid_grant", }) @@ -318,7 +329,7 @@ func (controller *OIDCController) Token(c *gin.Context) { ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) if !ok { - tlog.App.Warn().Msg("PKCE validation failed") + controller.log.App.Warn().Msg("PKCE validation failed") c.JSON(400, gin.H{ "error": "invalid_grant", }) @@ -328,7 +339,7 @@ func (controller *OIDCController) Token(c *gin.Context) { tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to generate access token") + controller.log.App.Error().Err(err).Msg("Failed to generate access token") c.JSON(400, gin.H{ "error": "server_error", }) @@ -341,7 +352,7 @@ func (controller *OIDCController) Token(c *gin.Context) { if err != nil { if errors.Is(err, service.ErrTokenExpired) { - tlog.App.Error().Err(err).Msg("Refresh token expired") + controller.log.App.Warn().Msg("Refresh token expired") c.JSON(400, gin.H{ "error": "invalid_grant", }) @@ -349,14 +360,14 @@ func (controller *OIDCController) Token(c *gin.Context) { } if errors.Is(err, service.ErrInvalidClient) { - tlog.App.Error().Err(err).Msg("Invalid client") + controller.log.App.Warn().Msg("Refresh token does not belong to client") c.JSON(400, gin.H{ "error": "invalid_grant", }) return } - tlog.App.Error().Err(err).Msg("Failed to refresh access token") + controller.log.App.Error().Err(err).Msg("Failed to refresh access token") c.JSON(400, gin.H{ "error": "server_error", }) @@ -373,10 +384,10 @@ func (controller *OIDCController) Token(c *gin.Context) { } func (controller *OIDCController) Userinfo(c *gin.Context) { - if !controller.oidc.IsConfigured() { - tlog.App.Warn().Msg("OIDC not configured") - c.JSON(404, gin.H{ - "error": "not_found", + if controller.oidc == nil { + controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured") + c.JSON(500, gin.H{ + "error": "server_error", }) return } @@ -387,7 +398,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { if authorization != "" { tokenType, bearerToken, ok := strings.Cut(authorization, " ") if !ok { - tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header") + controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid authorization header") c.JSON(401, gin.H{ "error": "invalid_request", }) @@ -395,7 +406,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { } if strings.ToLower(tokenType) != "bearer" { - tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") + controller.log.App.Warn().Msg("OIDC userinfo accessed with non-bearer token") c.JSON(401, gin.H{ "error": "invalid_request", }) @@ -405,7 +416,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { token = bearerToken } else if c.Request.Method == http.MethodPost { if c.ContentType() != "application/x-www-form-urlencoded" { - tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type") + controller.log.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type") c.JSON(400, gin.H{ "error": "invalid_request", }) @@ -413,14 +424,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { } token = c.PostForm("access_token") if token == "" { - tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body") + controller.log.App.Warn().Msg("OIDC userinfo POST accessed without access_token") c.JSON(401, gin.H{ "error": "invalid_request", }) return } } else { - tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header") + controller.log.App.Warn().Msg("OIDC userinfo accessed without authorization header or POST body") c.JSON(401, gin.H{ "error": "invalid_request", }) @@ -431,14 +442,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { if err != nil { if errors.Is(err, service.ErrTokenNotFound) { - tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token") + controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid token") c.JSON(401, gin.H{ "error": "invalid_grant", }) return } - tlog.App.Err(err).Msg("Failed to get token entry") + controller.log.App.Error().Err(err).Msg("Failed to get access token") c.JSON(401, gin.H{ "error": "server_error", }) @@ -447,7 +458,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { // If we don't have the openid scope, return an error if !slices.Contains(strings.Split(entry.Scope, ","), "openid") { - tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope") + controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope") c.JSON(401, gin.H{ "error": "invalid_scope", }) @@ -457,7 +468,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { user, err := controller.oidc.GetUserinfo(c, entry.Sub) if err != nil { - tlog.App.Err(err).Msg("Failed to get user entry") + controller.log.App.Error().Err(err).Msg("Failed to get user info") c.JSON(401, gin.H{ "error": "server_error", }) @@ -468,7 +479,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { } func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) { - tlog.App.Error().Err(err).Msg(reason) + controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error") if callback != "" { errorQueries := CallbackError{ @@ -508,8 +519,16 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas return } + redirectUrl := "" + + if controller.oidc != nil { + redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()) + } else { + redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode()) + } + c.JSON(200, gin.H{ "status": 200, - "redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()), + "redirect_uri": redirectUrl, }) } diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 150540fc..9ece2073 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -1,13 +1,14 @@ package controller_test import ( + "context" "crypto/sha256" "encoding/base64" "encoding/json" "net/http/httptest" "net/url" - "path" "strings" + "sync" "testing" "github.com/gin-gonic/gin" @@ -19,29 +20,15 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestOIDCController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + log := logger.NewLogger().WithTestConfig() + log.Init() - oidcServiceCfg := service.OIDCServiceConfig{ - Clients: map[string]model.OIDCClientConfig{ - "test": { - ClientID: "some-client-id", - ClientSecret: "some-client-secret", - TrustedRedirectURIs: []string{"https://test.example.com/callback"}, - Name: "Test Client", - }, - }, - PrivateKeyPath: path.Join(tempDir, "key.pem"), - PublicKeyPath: path.Join(tempDir, "key.pub"), - Issuer: "https://tinyauth.example.com", - SessionExpiry: 500, - } - - controllerCfg := controller.OIDCControllerConfig{} + cfg, runtime := test.CreateTestConfigs(t) simpleCtx := func(c *gin.Context) { c.Set("context", &model.UserContext{ @@ -103,7 +90,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid") }, @@ -123,7 +110,7 @@ func TestOIDCController(t *testing.T) { Nonce: "some-nonce", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -131,7 +118,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state") }, @@ -151,7 +138,7 @@ func TestOIDCController(t *testing.T) { Nonce: "some-nonce", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -160,11 +147,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") @@ -183,7 +170,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -191,7 +178,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, res["error"], "unsupported_grant_type") }, @@ -206,7 +193,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -244,7 +231,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -267,11 +254,11 @@ func TestOIDCController(t *testing.T) { var authorizeRes map[string]any err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := authorizeRes["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() code := queryParams.Get("code") @@ -283,7 +270,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -306,7 +293,7 @@ func TestOIDCController(t *testing.T) { var tokenRes map[string]any err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) - assert.NoError(t, err) + require.NoError(t, err) _, ok := tokenRes["refresh_token"] assert.True(t, ok, "Expected refresh token in response") @@ -320,7 +307,7 @@ func TestOIDCController(t *testing.T) { ClientSecret: "some-client-secret", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -332,7 +319,7 @@ func TestOIDCController(t *testing.T) { assert.Equal(t, 200, recorder.Code) var refreshRes map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes) - assert.NoError(t, err) + require.NoError(t, err) _, ok = refreshRes["access_token"] assert.True(t, ok, "Expected access token in refresh response") @@ -353,11 +340,11 @@ func TestOIDCController(t *testing.T) { var authorizeRes map[string]any err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := authorizeRes["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() code := queryParams.Get("code") @@ -369,7 +356,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -389,7 +376,7 @@ func TestOIDCController(t *testing.T) { var secondRes map[string]any err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_grant", secondRes["error"]) }, @@ -417,7 +404,7 @@ func TestOIDCController(t *testing.T) { var tokenRes map[string]any err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) - assert.NoError(t, err) + require.NoError(t, err) accessToken := tokenRes["access_token"].(string) assert.NotEmpty(t, accessToken) @@ -429,7 +416,7 @@ func TestOIDCController(t *testing.T) { var userInfoRes map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) - assert.NoError(t, err) + require.NoError(t, err) _, ok := userInfoRes["sub"] assert.True(t, ok, "Expected sub claim in userinfo response") @@ -449,7 +436,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -464,7 +451,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -479,7 +466,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -494,7 +481,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_grant", res["error"]) }, }, @@ -509,7 +496,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -524,7 +511,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -541,7 +528,7 @@ func TestOIDCController(t *testing.T) { var tokenRes map[string]any err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) - assert.NoError(t, err) + require.NoError(t, err) accessToken := tokenRes["access_token"].(string) assert.NotEmpty(t, accessToken) @@ -555,7 +542,7 @@ func TestOIDCController(t *testing.T) { var userInfoRes map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) - assert.NoError(t, err) + require.NoError(t, err) _, ok := userInfoRes["sub"] assert.True(t, ok, "Expected sub claim in userinfo response") @@ -579,7 +566,7 @@ func TestOIDCController(t *testing.T) { CodeChallengeMethod: "", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -588,11 +575,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") @@ -609,7 +596,7 @@ func TestOIDCController(t *testing.T) { CodeVerifier: "some-challenge", } reqBodyEncoded, err := query.Values(tokenReqBody) - assert.NoError(t, err) + require.NoError(t, err) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -640,7 +627,7 @@ func TestOIDCController(t *testing.T) { CodeChallengeMethod: "S256", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -649,11 +636,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") @@ -670,7 +657,7 @@ func TestOIDCController(t *testing.T) { CodeVerifier: "some-challenge", } reqBodyEncoded, err := query.Values(tokenReqBody) - assert.NoError(t, err) + require.NoError(t, err) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -701,7 +688,7 @@ func TestOIDCController(t *testing.T) { CodeChallengeMethod: "S256", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -710,11 +697,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") @@ -731,7 +718,7 @@ func TestOIDCController(t *testing.T) { CodeVerifier: "some-challenge-1", } reqBodyEncoded, err := query.Values(tokenReqBody) - assert.NoError(t, err) + require.NoError(t, err) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -762,7 +749,7 @@ func TestOIDCController(t *testing.T) { CodeChallengeMethod: "foo", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -771,11 +758,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() error := queryParams.Get("error") @@ -794,11 +781,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() code := queryParams.Get("code") @@ -810,7 +797,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -821,7 +808,7 @@ func TestOIDCController(t *testing.T) { assert.Equal(t, 200, recorder.Code) err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) accessToken := res["access_token"].(string) assert.NotEmpty(t, accessToken) @@ -846,20 +833,22 @@ func TestOIDCController(t *testing.T) { assert.Equal(t, 401, recorder.Code) err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_grant", res["error"]) }, }, } - app := bootstrap.NewBootstrapApp(model.Config{}) + app := bootstrap.NewBootstrapApp(cfg) - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + err := app.SetupDatabase() require.NoError(t, err) - queries := repository.New(db) - oidcService := service.NewOIDCService(oidcServiceCfg, queries) - err = oidcService.Init() + queries := repository.New(app.GetDB()) + + wg := &sync.WaitGroup{} + + oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, context.TODO(), wg) require.NoError(t, err) for _, test := range tests { @@ -873,8 +862,7 @@ func TestOIDCController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - oidcController := controller.NewOIDCController(controllerCfg, oidcService, group) - oidcController.SetupRoutes() + controller.NewOIDCController(log, oidcService, runtime, group) recorder := httptest.NewRecorder() @@ -883,7 +871,6 @@ func TestOIDCController(t *testing.T) { } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 7cd01969..40969b83 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -11,7 +11,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" @@ -50,29 +50,31 @@ type ProxyContext struct { ProxyType ProxyType } -type ProxyControllerConfig struct { - AppURL string -} - type ProxyController struct { - config ProxyControllerConfig - router *gin.RouterGroup - acls *service.AccessControlsService - auth *service.AuthService + log *logger.Logger + runtime model.RuntimeConfig + acls *service.AccessControlsService + auth *service.AuthService } -func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService) *ProxyController { - return &ProxyController{ - config: config, - router: router, - acls: acls, - auth: auth, +func NewProxyController( + log *logger.Logger, + runtime model.RuntimeConfig, + router *gin.RouterGroup, + acls *service.AccessControlsService, + auth *service.AuthService, +) *ProxyController { + controller := &ProxyController{ + log: log, + runtime: runtime, + acls: acls, + auth: auth, } -} -func (controller *ProxyController) SetupRoutes() { - proxyGroup := controller.router.Group("/auth") + proxyGroup := router.Group("/auth") proxyGroup.Any("/:proxy", controller.proxyHandler) + + return controller } func (controller *ProxyController) proxyHandler(c *gin.Context) { @@ -80,7 +82,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { proxyCtx, err := controller.getProxyContext(c) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to get proxy context") + controller.log.App.Error().Err(err).Msg("Failed to get proxy context from request") c.JSON(400, gin.H{ "status": 400, "message": "Bad request", @@ -88,19 +90,15 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context") - // Get acls acls, err := controller.acls.GetAccessControls(proxyCtx.Host) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get access controls for resource") + controller.log.App.Error().Err(err).Msg("Failed to get ACLs for resource") controller.handleError(c, proxyCtx) return } - tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource") - clientIP := c.ClientIP() if controller.auth.IsBypassedIP(clientIP, acls) { @@ -115,13 +113,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource") + controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource") controller.handleError(c, proxyCtx) return } if !authEnabled { - tlog.App.Debug().Msg("Authentication disabled for resource, allowing access") + controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication") controller.setHeaders(c, acls) c.JSON(200, gin.H{ "status": 200, @@ -137,12 +135,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") + controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.handleError(c, proxyCtx) return } - redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) + redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -160,26 +158,24 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { userContext, err := new(model.UserContext).NewFromGin(c) if err != nil { - tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated") + controller.log.App.Debug().Err(err).Msg("Failed to create user context from request, treating as unauthenticated") userContext = &model.UserContext{ Authenticated: false, } } - tlog.App.Trace().Interface("context", userContext).Msg("User context from request") - if userContext.Authenticated { userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls) if !userAllowed { - tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource") + controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource") queries, err := query.Values(UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") + controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.handleError(c, proxyCtx) return } @@ -190,7 +186,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { queries.Set("username", userContext.GetUsername()) } - redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) + redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -215,7 +211,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { } if !groupOK { - tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements") + controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not in the required group to access resource") queries, err := query.Values(UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], @@ -223,7 +219,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") + controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.handleError(c, proxyCtx) return } @@ -234,7 +230,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { queries.Set("username", userContext.GetUsername()) } - redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) + redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -277,12 +273,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") + controller.log.App.Error().Err(err).Msg("Failed to encode redirect query") controller.handleError(c, proxyCtx) return } - redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode()) + redirectURL := fmt.Sprintf("%s/login?%s", controller.runtime.AppURL, queries.Encode()) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -306,20 +302,19 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) { headers := utils.ParseHeaders(acls.Response.Headers) for key, value := range headers { - tlog.App.Debug().Str("header", key).Msg("Setting header") c.Header(key, value) } basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile) if acls.Response.BasicAuth.Username != "" && basicPassword != "" { - tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header") + controller.log.App.Debug().Msg("Setting basic auth header for response") c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword))) } } func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) { - redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL) + redirectURL := fmt.Sprintf("%s/error", controller.runtime.AppURL) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -520,7 +515,7 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext return ProxyContext{}, err } - tlog.App.Debug().Msgf("Proxy: %v", req.Proxy) + controller.log.App.Debug().Msgf("Determined proxy type: %v", proxy) authModules := controller.determineAuthModules(proxy) @@ -531,13 +526,13 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext var ctx ProxyContext for _, module := range authModules { - tlog.App.Debug().Msgf("Trying auth module: %v", module) + controller.log.App.Debug().Msgf("Trying to get context from auth module %v", module) ctx, err = controller.getContextFromAuthModule(c, module) if err == nil { - tlog.App.Debug().Msgf("Auth module %v succeeded", module) + controller.log.App.Debug().Msgf("Successfully got context from auth module %v", module) break } - tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module) + controller.log.App.Debug().Msgf("Failed to get context from auth module %v: %v", module, err) } if err != nil { @@ -549,9 +544,9 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext isBrowser := BrowserUserAgentRegex.MatchString(userAgent) if isBrowser { - tlog.App.Debug().Msg("Request identified as coming from a browser") + controller.log.App.Debug().Msg("Request identified as coming from a browser client") } else { - tlog.App.Debug().Msg("Request identified as coming from a non-browser client") + controller.log.App.Debug().Msg("Request identified as coming from a non-browser client") } ctx.IsBrowser = isBrowser diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 7b2e3202..12c3c9f1 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -1,8 +1,9 @@ package controller_test import ( + "context" "net/http/httptest" - "path" + "sync" "testing" "github.com/gin-gonic/gin" @@ -13,35 +14,15 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestProxyController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + log := logger.NewLogger().WithTestConfig() + log.Init() - authServiceCfg := service.AuthServiceConfig{ - LocalUsers: &[]model.LocalUser{ - { - Username: "testuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - }, - { - Username: "totpuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", - }, - }, - SessionExpiry: 10, // 10 seconds, useful for testing - CookieDomain: "example.com", - LoginTimeout: 10, // 10 seconds, useful for testing - LoginMaxRetries: 3, - SessionCookieName: "tinyauth-session", - } - - controllerCfg := controller.ProxyControllerConfig{ - AppURL: "https://tinyauth.example.com", - } + cfg, runtime := test.CreateTestConfigs(t) acls := map[string]model.App{ "app_path_allow": { @@ -398,32 +379,19 @@ func TestProxyController(t *testing.T) { }, } - oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) + app := bootstrap.NewBootstrapApp(cfg) - app := bootstrap.NewBootstrapApp(model.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + err := app.SetupDatabase() require.NoError(t, err) - queries := repository.New(db) + queries := repository.New(app.GetDB()) - docker := service.NewDockerService() - err = docker.Init() - require.NoError(t, err) + wg := &sync.WaitGroup{} + ctx := context.TODO() - ldap := service.NewLdapService(service.LdapServiceConfig{}) - err = ldap.Init() - require.NoError(t, err) - - broker := service.NewOAuthBrokerService(oauthBrokerCfgs) - err = broker.Init() - require.NoError(t, err) - - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) - err = authService.Init() - require.NoError(t, err) - - aclsService := service.NewAccessControlsService(docker, acls) + broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) + authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker) + aclsService := service.NewAccessControlsService(log, nil, acls) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -438,15 +406,13 @@ func TestProxyController(t *testing.T) { recorder := httptest.NewRecorder() - proxyController := controller.NewProxyController(controllerCfg, group, aclsService, authService) - proxyController.SetupRoutes() + controller.NewProxyController(log, runtime, group, aclsService, authService) test.run(t, router, recorder) }) } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/controller/resources_controller.go b/internal/controller/resources_controller.go index 98d3b23c..54af733d 100644 --- a/internal/controller/resources_controller.go +++ b/internal/controller/resources_controller.go @@ -4,42 +4,39 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/tinyauthapp/tinyauth/internal/model" ) -type ResourcesControllerConfig struct { - Path string - Enabled bool -} - type ResourcesController struct { - config ResourcesControllerConfig - router *gin.RouterGroup + config model.Config fileServer http.Handler } -func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController { - fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Path))) +func NewResourcesController( + config model.Config, + router *gin.RouterGroup, +) *ResourcesController { + fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path))) - return &ResourcesController{ + controller := &ResourcesController{ config: config, - router: router, fileServer: fileServer, } -} -func (controller *ResourcesController) SetupRoutes() { - controller.router.GET("/resources/*resource", controller.resourcesHandler) + router.GET("/resources/*resource", controller.resourcesHandler) + + return controller } func (controller *ResourcesController) resourcesHandler(c *gin.Context) { - if controller.config.Path == "" { + if controller.config.Resources.Path == "" { c.JSON(404, gin.H{ "status": 404, "message": "Resources not found", }) return } - if !controller.config.Enabled { + if !controller.config.Resources.Enabled { c.JSON(403, gin.H{ "status": 403, "message": "Resources are disabled", diff --git a/internal/controller/resources_controller_test.go b/internal/controller/resources_controller_test.go index a1996be3..68ce463d 100644 --- a/internal/controller/resources_controller_test.go +++ b/internal/controller/resources_controller_test.go @@ -3,26 +3,20 @@ package controller_test import ( "net/http/httptest" "os" - "path" + "path/filepath" "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/test" ) func TestResourcesController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + cfg, _ := test.CreateTestConfigs(t) - resourcesControllerCfg := controller.ResourcesControllerConfig{ - Path: path.Join(tempDir, "resources"), - Enabled: true, - } - - err := os.Mkdir(resourcesControllerCfg.Path, 0777) + err := os.MkdirAll(cfg.Resources.Path, 0777) require.NoError(t, err) type testCase struct { @@ -61,11 +55,11 @@ func TestResourcesController(t *testing.T) { }, } - testFilePath := resourcesControllerCfg.Path + "/testfile.txt" + testFilePath := cfg.Resources.Path + "/testfile.txt" err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777) require.NoError(t, err) - testFilePathParent := tempDir + "/somefile.txt" + testFilePathParent := filepath.Dir(cfg.Resources.Path) + "/somefile.txt" err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777) require.NoError(t, err) @@ -75,8 +69,7 @@ func TestResourcesController(t *testing.T) { group := router.Group("/") gin.SetMode(gin.TestMode) - resourcesController := controller.NewResourcesController(resourcesControllerCfg, group) - resourcesController.SetupRoutes() + controller.NewResourcesController(cfg, group) recorder := httptest.NewRecorder() test.run(t, router, recorder) diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index cb6d5e6f..45a876bf 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -10,7 +10,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" @@ -25,30 +25,30 @@ type TotpRequest struct { Code string `json:"code"` } -type UserControllerConfig struct { - CookieDomain string - SessionCookieName string -} - type UserController struct { - config UserControllerConfig - router *gin.RouterGroup - auth *service.AuthService + log *logger.Logger + runtime model.RuntimeConfig + auth *service.AuthService } -func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController { - return &UserController{ - config: config, - router: router, - auth: auth, +func NewUserController( + log *logger.Logger, + runtimeConfig model.RuntimeConfig, + router *gin.RouterGroup, + auth *service.AuthService, +) *UserController { + controller := &UserController{ + log: log, + runtime: runtimeConfig, + auth: auth, } -} -func (controller *UserController) SetupRoutes() { - userGroup := controller.router.Group("/user") + userGroup := router.Group("/user") userGroup.POST("/login", controller.loginHandler) userGroup.POST("/logout", controller.logoutHandler) userGroup.POST("/totp", controller.totpHandler) + + return controller } func (controller *UserController) loginHandler(c *gin.Context) { @@ -56,7 +56,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { err := c.ShouldBindJSON(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind JSON") + controller.log.App.Error().Err(err).Msg("Failed to bind JSON") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -64,13 +64,13 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } - tlog.App.Debug().Str("username", req.Username).Msg("Login attempt") + controller.log.App.Debug().Str("username", req.Username).Msg("Login attempt") isLocked, remaining := controller.auth.IsAccountLocked(req.Username) if isLocked { - tlog.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts") - tlog.AuditLoginFailure(c, req.Username, "username", "account locked") + controller.log.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts") + controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "account locked") c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.JSON(429, gin.H{ @@ -84,16 +84,16 @@ func (controller *UserController) loginHandler(c *gin.Context) { if err != nil { if errors.Is(err, service.ErrUserNotFound) { - tlog.App.Warn().Str("username", req.Username).Msg("User not found") + controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt") controller.auth.RecordLoginAttempt(req.Username, false) - tlog.AuditLoginFailure(c, req.Username, "username", "user not found") + controller.log.AuditLoginFailure(req.Username, "unknown", c.ClientIP(), "user not found") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", }) return } - tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user") + controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user during login attempt") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -102,9 +102,13 @@ func (controller *UserController) loginHandler(c *gin.Context) { } if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil { - tlog.App.Warn().Err(err).Str("username", req.Username).Msg("Failed to verify password") + controller.log.App.Warn().Str("username", req.Username).Msg("Invalid password during login attempt") controller.auth.RecordLoginAttempt(req.Username, false) - tlog.AuditLoginFailure(c, req.Username, "username", "invalid password") + if search.Type == model.UserLocal { + controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "invalid password") + } else { + controller.log.AuditLoginFailure(req.Username, "ldap", c.ClientIP(), "invalid password") + } c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -118,7 +122,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { localUser = controller.auth.GetLocalUser(req.Username) if localUser == nil { - tlog.App.Warn().Str("username", req.Username).Msg("User disappeared during login") + controller.log.App.Error().Str("username", req.Username).Msg("Local user not found after successful password verification") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -127,7 +131,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { } if localUser.TOTPSecret != "" { - tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") + controller.log.App.Debug().Str("username", req.Username).Msg("TOTP required for user, creating pending TOTP session") name := localUser.Attributes.Name if name == "" { @@ -136,7 +140,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { email := localUser.Attributes.Email if email == "" { - email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain) + email = utils.CompileUserEmail(localUser.Username, controller.runtime.CookieDomain) } cookie, err := controller.auth.CreateSession(c, repository.Session{ @@ -148,7 +152,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") + controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -170,7 +174,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { sessionCookie := repository.Session{ Username: req.Username, Name: utils.Capitalize(req.Username), - Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain), + Email: utils.CompileUserEmail(req.Username, controller.runtime.CookieDomain), Provider: "local", } @@ -187,12 +191,10 @@ func (controller *UserController) loginHandler(c *gin.Context) { sessionCookie.Provider = "ldap" } - tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") - cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") + controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -202,8 +204,13 @@ func (controller *UserController) loginHandler(c *gin.Context) { http.SetCookie(c.Writer, cookie) - tlog.App.Info().Str("username", req.Username).Msg("Login successful") - tlog.AuditLoginSuccess(c, req.Username, "username") + controller.log.App.Info().Str("username", req.Username).Msg("Login successful") + + if search.Type == model.UserLocal { + controller.log.AuditLoginSuccess(req.Username, "local", c.ClientIP()) + } else { + controller.log.AuditLoginSuccess(req.Username, "ldap", c.ClientIP()) + } controller.auth.RecordLoginAttempt(req.Username, true) @@ -214,20 +221,20 @@ func (controller *UserController) loginHandler(c *gin.Context) { } func (controller *UserController) logoutHandler(c *gin.Context) { - tlog.App.Debug().Msg("Logout request received") + controller.log.App.Debug().Msg("Logout attempt") - uuid, err := c.Cookie(controller.config.SessionCookieName) + uuid, err := c.Cookie(controller.runtime.SessionCookieName) if err != nil { if errors.Is(err, http.ErrNoCookie) { - tlog.App.Warn().Msg("No session cookie found on logout request") + controller.log.App.Warn().Msg("Logout attempt without session cookie, treating as successful logout") c.JSON(200, gin.H{ "status": 200, "message": "Logout successful", }) return } - tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout") + controller.log.App.Error().Err(err).Msg("Error retrieving session cookie on logout") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -238,7 +245,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) { cookie, err := controller.auth.DeleteSession(c, uuid) if err != nil { - tlog.App.Error().Err(err).Msg("Error deleting session on logout") + controller.log.App.Error().Err(err).Msg("Error deleting session on logout") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -249,10 +256,10 @@ func (controller *UserController) logoutHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err == nil { - tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID()) + controller.log.AuditLogout(context.GetUsername(), context.GetProviderID(), c.ClientIP()) } else { - tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username") - tlog.AuditLogout(c, "unknown", "unknown") + controller.log.App.Warn().Err(err).Msg("Failed to get user context during logout, logging audit with unknown user") + controller.log.AuditLogout("unknown", "unknown", c.ClientIP()) } http.SetCookie(c.Writer, cookie) @@ -268,7 +275,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { err := c.ShouldBindJSON(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind JSON") + controller.log.App.Error().Err(err).Msg("Failed to bind JSON for TOTP verification") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -279,7 +286,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get user context") + controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -288,7 +295,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { } if !context.TOTPPending() { - tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session") + controller.log.App.Warn().Str("username", context.GetUsername()).Msg("TOTP verification attempt without pending TOTP session") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -296,12 +303,13 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt") + controller.log.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt") isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername()) if isLocked { - tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts") + controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts") + controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "account locked") c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.JSON(429, gin.H{ @@ -314,7 +322,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { user := controller.auth.GetLocalUser(context.GetUsername()) if user == nil { - tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler") + controller.log.App.Error().Str("username", context.GetUsername()).Msg("Local user not found during TOTP verification") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -325,9 +333,9 @@ func (controller *UserController) totpHandler(c *gin.Context) { ok := totp.Validate(req.Code, user.TOTPSecret) if !ok { - tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code") + controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code during verification attempt") controller.auth.RecordLoginAttempt(context.GetUsername(), false) - tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code") + controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "invalid TOTP code") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -335,15 +343,15 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - uuid, err := c.Cookie(controller.config.SessionCookieName) + uuid, err := c.Cookie(controller.runtime.SessionCookieName) if err == nil { _, err = controller.auth.DeleteSession(c, uuid) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete pending TOTP session") + controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification") } } else { - tlog.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, proceeding without deleting it") + controller.log.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, cannot delete it") } controller.auth.RecordLoginAttempt(context.GetUsername(), true) @@ -351,7 +359,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { sessionCookie := repository.Session{ Username: user.Username, Name: utils.Capitalize(user.Username), - Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain), + Email: utils.CompileUserEmail(user.Username, controller.runtime.CookieDomain), Provider: "local", } @@ -362,12 +370,10 @@ func (controller *UserController) totpHandler(c *gin.Context) { sessionCookie.Email = user.Attributes.Email } - tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") - cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") + controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -377,8 +383,8 @@ func (controller *UserController) totpHandler(c *gin.Context) { http.SetCookie(c.Writer, cookie) - tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful") - tlog.AuditLoginSuccess(c, context.GetUsername(), "totp") + controller.log.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful, login complete") + controller.log.AuditLoginSuccess(context.GetUsername(), "local", c.ClientIP()) c.JSON(200, gin.H{ "status": 200, diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 4863c16e..10858175 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -5,8 +5,8 @@ import ( "encoding/json" "net/http" "net/http/httptest" - "path" "strings" + "sync" "testing" "time" @@ -19,53 +19,15 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestUserController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + log := logger.NewLogger().WithTestConfig() + log.Init() - authServiceCfg := service.AuthServiceConfig{ - LocalUsers: &[]model.LocalUser{ - { - Username: "testuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - }, - { - Username: "totpuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", - }, - { - Username: "attruser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - Attributes: model.UserAttributes{ - Name: "Alice Smith", - Email: "alice@example.com", - }, - }, - { - Username: "attrtotpuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", - Attributes: model.UserAttributes{ - Name: "Bob Jones", - Email: "bob@example.com", - }, - }, - }, - SessionExpiry: 10, // 10 seconds, useful for testing - CookieDomain: "example.com", - LoginTimeout: 10, // 10 seconds, useful for testing - LoginMaxRetries: 3, - SessionCookieName: "tinyauth-session", - } - - userControllerCfg := controller.UserControllerConfig{ - CookieDomain: "example.com", - SessionCookieName: "tinyauth-session", - } + cfg, runtime := test.CreateTestConfigs(t) totpCtx := func(c *gin.Context) { c.Set("context", &model.UserContext{ @@ -111,14 +73,12 @@ func TestUserController(t *testing.T) { }) } - oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) + app := bootstrap.NewBootstrapApp(cfg) - app := bootstrap.NewBootstrapApp(model.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + err := app.SetupDatabase() require.NoError(t, err) - queries := repository.New(db) + queries := repository.New(app.GetDB()) type testCase struct { description string @@ -136,7 +96,7 @@ func TestUserController(t *testing.T) { Password: "password", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req.Header.Set("Content-Type", "application/json") @@ -144,7 +104,7 @@ func TestUserController(t *testing.T) { router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) - assert.Len(t, recorder.Result().Cookies(), 1) + require.Len(t, recorder.Result().Cookies(), 1) cookie := recorder.Result().Cookies()[0] assert.Equal(t, "tinyauth-session", cookie.Name) @@ -164,7 +124,7 @@ func TestUserController(t *testing.T) { Password: "wrongpassword", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req.Header.Set("Content-Type", "application/json") @@ -185,7 +145,7 @@ func TestUserController(t *testing.T) { Password: "wrongpassword", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) for range 3 { recorder := httptest.NewRecorder() @@ -220,7 +180,7 @@ func TestUserController(t *testing.T) { Password: "password", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req.Header.Set("Content-Type", "application/json") @@ -231,12 +191,12 @@ func TestUserController(t *testing.T) { decodedBody := make(map[string]any) err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, decodedBody["totpPending"], true) // should set the session cookie - assert.Len(t, recorder.Result().Cookies(), 1) + require.Len(t, recorder.Result().Cookies(), 1) cookie := recorder.Result().Cookies()[0] assert.Equal(t, "tinyauth-session", cookie.Name) assert.True(t, cookie.HttpOnly) @@ -257,7 +217,7 @@ func TestUserController(t *testing.T) { Password: "password", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req.Header.Set("Content-Type", "application/json") @@ -266,7 +226,7 @@ func TestUserController(t *testing.T) { assert.Equal(t, 200, recorder.Code) cookies := recorder.Result().Cookies() - assert.Len(t, cookies, 1) + require.Len(t, cookies, 1) cookie := cookies[0] assert.Equal(t, "tinyauth-session", cookie.Name) @@ -280,7 +240,7 @@ func TestUserController(t *testing.T) { assert.Equal(t, 200, recorder.Code) cookies = recorder.Result().Cookies() - assert.Len(t, cookies, 1) + require.Len(t, cookies, 1) cookie = cookies[0] assert.Equal(t, "tinyauth-session", cookie.Name) @@ -307,14 +267,14 @@ func TestUserController(t *testing.T) { require.NoError(t, err) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) - assert.NoError(t, err) + require.NoError(t, err) totpReq := controller.TotpRequest{ Code: code, } totpReqBody, err := json.Marshal(totpReq) - assert.NoError(t, err) + require.NoError(t, err) recorder = httptest.NewRecorder() req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) @@ -329,7 +289,7 @@ func TestUserController(t *testing.T) { router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) - assert.Len(t, recorder.Result().Cookies(), 1) + require.Len(t, recorder.Result().Cookies(), 1) // should set a new session cookie with totp pending removed totpCookie := recorder.Result().Cookies()[0] @@ -352,7 +312,7 @@ func TestUserController(t *testing.T) { } totpReqBody, err := json.Marshal(totpReq) - assert.NoError(t, err) + require.NoError(t, err) recorder = httptest.NewRecorder() req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) @@ -456,21 +416,11 @@ func TestUserController(t *testing.T) { }, } - docker := service.NewDockerService() - err = docker.Init() - require.NoError(t, err) + ctx := context.TODO() + wg := &sync.WaitGroup{} - ldap := service.NewLdapService(service.LdapServiceConfig{}) - err = ldap.Init() - require.NoError(t, err) - - broker := service.NewOAuthBrokerService(oauthBrokerCfgs) - err = broker.Init() - require.NoError(t, err) - - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) - err = authService.Init() - require.NoError(t, err) + broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) + authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker) beforeEach := func() { // Clear failed login attempts before each test @@ -489,8 +439,7 @@ func TestUserController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - userController := controller.NewUserController(userControllerCfg, group, authService) - userController.SetupRoutes() + controller.NewUserController(log, runtime, group, authService) recorder := httptest.NewRecorder() @@ -499,7 +448,6 @@ func TestUserController(t *testing.T) { } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/controller/well_known_controller.go b/internal/controller/well_known_controller.go index f31a9ed7..8c71d890 100644 --- a/internal/controller/well_known_controller.go +++ b/internal/controller/well_known_controller.go @@ -26,28 +26,30 @@ type OpenIDConnectConfiguration struct { RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"` } -type WellKnownControllerConfig struct{} - type WellKnownController struct { - config WellKnownControllerConfig - engine *gin.Engine - oidc *service.OIDCService + oidc *service.OIDCService } -func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController { - return &WellKnownController{ - config: config, - oidc: oidc, - engine: engine, +func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController { + controller := &WellKnownController{ + oidc: oidc, } -} -func (controller *WellKnownController) SetupRoutes() { - controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) - controller.engine.GET("/.well-known/jwks.json", controller.JWKS) + router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) + router.GET("/.well-known/jwks.json", controller.JWKS) + + return controller } func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) { + if controller.oidc == nil { + c.JSON(500, gin.H{ + "status": 500, + "message": "OIDC service not configured", + }) + return + } + issuer := controller.oidc.GetIssuer() c.JSON(200, OpenIDConnectConfiguration{ Issuer: issuer, @@ -69,11 +71,19 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context } func (controller *WellKnownController) JWKS(c *gin.Context) { + if controller.oidc == nil { + c.JSON(500, gin.H{ + "status": 500, + "message": "OIDC service not configured", + }) + return + } + jwks, err := controller.oidc.GetJWK() if err != nil { c.JSON(500, gin.H{ - "status": "500", + "status": 500, "message": "failed to get JWK", }) return diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index 7dcf2bdc..e2323da2 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -1,10 +1,11 @@ package controller_test import ( + "context" "encoding/json" "fmt" "net/http/httptest" - "path" + "sync" "testing" "github.com/gin-gonic/gin" @@ -12,30 +13,17 @@ import ( "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestWellKnownController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + log := logger.NewLogger().WithTestConfig() + log.Init() - oidcServiceCfg := service.OIDCServiceConfig{ - Clients: map[string]model.OIDCClientConfig{ - "test": { - ClientID: "some-client-id", - ClientSecret: "some-client-secret", - TrustedRedirectURIs: []string{"https://test.example.com/callback"}, - Name: "Test Client", - }, - }, - PrivateKeyPath: path.Join(tempDir, "key.pem"), - PublicKeyPath: path.Join(tempDir, "key.pub"), - Issuer: "https://tinyauth.example.com", - SessionExpiry: 500, - } + cfg, runtime := test.CreateTestConfigs(t) type testCase struct { description string @@ -56,11 +44,11 @@ func TestWellKnownController(t *testing.T) { assert.NoError(t, err) expected := controller.OpenIDConnectConfiguration{ - Issuer: oidcServiceCfg.Issuer, - AuthorizationEndpoint: fmt.Sprintf("%s/authorize", oidcServiceCfg.Issuer), - TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", oidcServiceCfg.Issuer), - UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", oidcServiceCfg.Issuer), - JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", oidcServiceCfg.Issuer), + Issuer: runtime.AppURL, + AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL), + TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL), + UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", runtime.AppURL), + JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", runtime.AppURL), ScopesSupported: service.SupportedScopes, ResponseTypesSupported: service.SupportedResponseTypes, GrantTypesSupported: service.SupportedGrantTypes, @@ -101,15 +89,17 @@ func TestWellKnownController(t *testing.T) { }, } - app := bootstrap.NewBootstrapApp(model.Config{}) + ctx := context.TODO() + wg := &sync.WaitGroup{} - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + app := bootstrap.NewBootstrapApp(cfg) + + err := app.SetupDatabase() require.NoError(t, err) - queries := repository.New(db) + queries := repository.New(app.GetDB()) - oidcService := service.NewOIDCService(oidcServiceCfg, queries) - err = oidcService.Init() + oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg) require.NoError(t, err) for _, test := range tests { @@ -119,15 +109,13 @@ func TestWellKnownController(t *testing.T) { recorder := httptest.NewRecorder() - wellKnownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, oidcService, router) - wellKnownController.SetupRoutes() + controller.NewWellKnownController(oidcService, &router.RouterGroup) test.run(t, router, recorder) }) } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 88e96462..6e6bbe56 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -10,7 +10,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" ) @@ -35,29 +35,27 @@ var ( } ) -type ContextMiddlewareConfig struct { - CookieDomain string - SessionCookieName string -} - type ContextMiddleware struct { - config ContextMiddlewareConfig - auth *service.AuthService - broker *service.OAuthBrokerService + log *logger.Logger + runtime model.RuntimeConfig + auth *service.AuthService + broker *service.OAuthBrokerService } -func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware { +func NewContextMiddleware( + log *logger.Logger, + runtime model.RuntimeConfig, + auth *service.AuthService, + broker *service.OAuthBrokerService, +) *ContextMiddleware { return &ContextMiddleware{ - config: config, - auth: auth, - broker: broker, + log: log, + runtime: runtime, + auth: auth, + broker: broker, } } -func (m *ContextMiddleware) Init() error { - return nil -} - func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) { @@ -65,7 +63,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return } - uuid, err := c.Cookie(m.config.SessionCookieName) + uuid, err := c.Cookie(m.runtime.SessionCookieName) if err == nil { userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid) @@ -75,12 +73,12 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { http.SetCookie(c.Writer, cookie) } - tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername()) + m.log.App.Debug().Msgf("Authenticated user %s via session cookie", userContext.GetUsername()) c.Set("context", userContext) c.Next() return } else { - tlog.App.Error().Msgf("Error authenticating session cookie: %v", err) + m.log.App.Debug().Msgf("Error authenticating session cookie: %v", err) } } @@ -90,7 +88,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { userContext, headers, err := m.basicAuth(username, password) if err != nil { - tlog.App.Error().Msgf("Error authenticating basic auth: %v", err) + m.log.App.Error().Msgf("Error authenticating basic auth: %v", err) c.Next() return } @@ -141,7 +139,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model } if userContext.Local.Attributes.Email == "" { - userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain) + userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.runtime.CookieDomain) } case model.ProviderLDAP: search, err := m.auth.SearchUser(userContext.LDAP.Username) @@ -162,7 +160,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model userContext.LDAP.Groups = user.Groups userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username) - userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain) + userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.runtime.CookieDomain) case model.ProviderOAuth: _, exists := m.broker.GetService(userContext.OAuth.ID) @@ -191,7 +189,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model. locked, remaining := m.auth.IsAccountLocked(username) if locked { - tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining) + m.log.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining) headers["x-tinyauth-lock-locked"] = "true" headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339) return nil, headers, nil @@ -224,7 +222,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model. BaseContext: model.BaseContext{ Username: user.Username, Name: utils.Capitalize(user.Username), - Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain), + Email: utils.CompileUserEmail(user.Username, m.runtime.CookieDomain), }, Attributes: user.Attributes, } @@ -240,7 +238,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model. BaseContext: model.BaseContext{ Username: username, Name: utils.Capitalize(username), - Email: utils.CompileUserEmail(username, m.config.CookieDomain), + Email: utils.CompileUserEmail(username, m.runtime.CookieDomain), }, Groups: user.Groups, } diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index 5dfde3b4..03f9f553 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -5,7 +5,7 @@ import ( "encoding/base64" "net/http" "net/http/httptest" - "path" + "sync" "testing" "time" @@ -17,36 +17,15 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestContextMiddleware(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + log := logger.NewLogger().WithTestConfig() + log.Init() - authServiceCfg := service.AuthServiceConfig{ - LocalUsers: &[]model.LocalUser{ - { - Username: "testuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - }, - { - Username: "totpuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", - }, - }, - SessionExpiry: 10, // 10 seconds, useful for testing - CookieDomain: "example.com", - LoginTimeout: 10, // 10 seconds, useful for testing - LoginMaxRetries: 3, - SessionCookieName: "tinyauth-session", - } - - middlewareCfg := middleware.ContextMiddlewareConfig{ - CookieDomain: "example.com", - SessionCookieName: "tinyauth-session", - } + cfg, runtime := test.CreateTestConfigs(t) basicAuthHeader := func(username, password string) string { return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) @@ -270,30 +249,20 @@ func TestContextMiddleware(t *testing.T) { }, } - oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) + ctx := context.TODO() + wg := &sync.WaitGroup{} - app := bootstrap.NewBootstrapApp(model.Config{}) + app := bootstrap.NewBootstrapApp(cfg) - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + err := app.SetupDatabase() require.NoError(t, err) - queries := repository.New(db) + queries := repository.New(app.GetDB()) - ldap := service.NewLdapService(service.LdapServiceConfig{}) - err = ldap.Init() - require.NoError(t, err) + broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) + authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker) - broker := service.NewOAuthBrokerService(oauthBrokerCfgs) - err = broker.Init() - require.NoError(t, err) - - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) - err = authService.Init() - require.NoError(t, err) - - contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker) - err = contextMiddleware.Init() - require.NoError(t, err) + contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker) for _, test := range tests { authService.ClearRateLimitsTestingOnly() @@ -322,7 +291,6 @@ func TestContextMiddleware(t *testing.T) { } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go index 96553b07..2b8d6b8a 100644 --- a/internal/middleware/ui_middleware.go +++ b/internal/middleware/ui_middleware.go @@ -9,7 +9,6 @@ import ( "time" "github.com/tinyauthapp/tinyauth/internal/assets" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/gin-gonic/gin" ) @@ -19,29 +18,25 @@ type UIMiddleware struct { uiFileServer http.Handler } -func NewUIMiddleware() *UIMiddleware { - return &UIMiddleware{} -} +func NewUIMiddleware() (*UIMiddleware, error) { + m := &UIMiddleware{} -func (m *UIMiddleware) Init() error { ui, err := fs.Sub(assets.FrontendAssets, "dist") if err != nil { - return err + return nil, fmt.Errorf("failed to load ui assets: %w", err) } m.uiFs = ui m.uiFileServer = http.FileServerFS(ui) - return nil + return m, nil } func (m *UIMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { path := strings.TrimPrefix(c.Request.URL.Path, "/") - tlog.App.Debug().Str("path", path).Msg("path") - switch strings.SplitN(path, "/", 2)[0] { case "api", "resources", ".well-known": c.Next() diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go index d75e3a72..9870a70a 100644 --- a/internal/middleware/zerolog_middleware.go +++ b/internal/middleware/zerolog_middleware.go @@ -5,7 +5,7 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) // See context middleware for explanation of why we have to do this @@ -17,14 +17,14 @@ var ( } ) -type ZerologMiddleware struct{} - -func NewZerologMiddleware() *ZerologMiddleware { - return &ZerologMiddleware{} +type ZerologMiddleware struct { + log *logger.Logger } -func (m *ZerologMiddleware) Init() error { - return nil +func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware { + return &ZerologMiddleware{ + log: log, + } } func (m *ZerologMiddleware) logPath(path string) bool { @@ -50,7 +50,7 @@ func (m *ZerologMiddleware) Middleware() gin.HandlerFunc { latency := time.Since(tStart).String() - subLogger := tlog.HTTP.With().Str("method", method). + subLogger := m.log.HTTP.With().Str("method", method). Str("path", path). Str("address", address). Str("client_ip", clientIP). diff --git a/internal/model/config.go b/internal/model/config.go index 95870e3d..f5376af2 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -14,8 +14,9 @@ func NewDefaultConfiguration() *Config { Path: "./resources", }, Server: ServerConfig{ - Port: 3000, - Address: "0.0.0.0", + Port: 3000, + Address: "0.0.0.0", + ConcurrentListenersEnabled: false, }, Auth: AuthConfig{ SubdomainsEnabled: true, @@ -95,9 +96,10 @@ type ResourcesConfig struct { } type ServerConfig struct { - Port int `description:"The port on which the server listens." yaml:"port"` - Address string `description:"The address on which the server listens." yaml:"address"` - SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` + Port int `description:"The port on which the server listens." yaml:"port"` + Address string `description:"The address on which the server listens." yaml:"address"` + SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` + ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"` } type AuthConfig struct { @@ -147,10 +149,10 @@ type IPConfig struct { } type OAuthConfig struct { - Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"` - WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"` - AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"` - Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` + Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"` + WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"` + AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"` + Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` } type OIDCConfig struct { diff --git a/internal/model/context.go b/internal/model/context.go index 7384ebe8..b9e31bef 100644 --- a/internal/model/context.go +++ b/internal/model/context.go @@ -8,6 +8,10 @@ import ( "github.com/tinyauthapp/tinyauth/internal/repository" ) +var ( + ErrUserContextNotFound = errors.New("user context not found") +) + type ProviderType int const ( @@ -74,7 +78,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) { userContextValue, exists := ginctx.Get("context") if !exists { - return nil, errors.New("failed to get user context") + return nil, ErrUserContextNotFound } userContext, ok := userContextValue.(*UserContext) @@ -117,7 +121,7 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, Email: session.Email, }, } - // By default we assume an unkown name which is oauth + // By default we assume an unknown name which is oauth default: c.Provider = ProviderOAuth c.OAuth = &OAuthContext{ diff --git a/internal/model/context_test.go b/internal/model/context_test.go index 733805a7..79bc97b0 100644 --- a/internal/model/context_test.go +++ b/internal/model/context_test.go @@ -238,7 +238,7 @@ func TestContext(t *testing.T) { _, err := c.NewFromGin(newGinCtx(nil, false)) return err.Error() }, - expected: "failed to get user context", + expected: model.ErrUserContextNotFound.Error(), }, { description: "NewFromGin returns error when context value has wrong type", diff --git a/internal/model/runtime.go b/internal/model/runtime.go new file mode 100644 index 00000000..9bd81770 --- /dev/null +++ b/internal/model/runtime.go @@ -0,0 +1,22 @@ +package model + +type RuntimeConfig struct { + AppURL string + UUID string + CookieDomain string + SessionCookieName string + CSRFCookieName string + RedirectCookieName string + OAuthSessionCookieName string + LocalUsers []LocalUser + OAuthProviders map[string]OAuthServiceConfig + OAuthWhitelist []string + ConfiguredProviders []Provider + OIDCClients []OIDCClientConfig +} + +type Provider struct { + Name string `json:"name"` + ID string `json:"id"` + OAuth bool `json:"oauth"` +} diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index fd57bf39..34700ea7 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -4,7 +4,7 @@ import ( "strings" "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) type LabelProvider interface { @@ -12,32 +12,33 @@ type LabelProvider interface { } type AccessControlsService struct { - labelProvider LabelProvider + log *logger.Logger + labelProvider *LabelProvider static map[string]model.App } -func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService { +func NewAccessControlsService( + log *logger.Logger, + labelProvider *LabelProvider, + static map[string]model.App) *AccessControlsService { return &AccessControlsService{ + log: log, labelProvider: labelProvider, static: static, } } -func (acls *AccessControlsService) Init() error { - return nil // No initialization needed -} - func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { var appAcls *model.App for app, config := range acls.static { if config.Config.Domain == domain { - tlog.App.Debug().Str("name", app).Msg("Found matching container by domain") + acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain") appAcls = &config break // If we find a match by domain, we can stop searching } if strings.SplitN(domain, ".", 2)[0] == app { - tlog.App.Debug().Str("name", app).Msg("Found matching container by app name") + acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name") appAcls = &config break // If we find a match by app name, we can stop searching } @@ -50,11 +51,15 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, app := acls.lookupStaticACLs(domain) if app != nil { - tlog.App.Debug().Msg("Using ACls from static configuration") + acls.log.App.Debug().Msg("Using static ACLs for app") return app, nil } - // Fallback to label provider - tlog.App.Debug().Msg("Falling back to label provider for ACLs") - return acls.labelProvider.GetLabels(domain) + // If we have a label provider configured, try to get ACLs from it + if acls.labelProvider != nil { + return (*acls.labelProvider).GetLabels(domain) + } + + // no labels + return nil, nil } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 16c53fe0..a9139bb3 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -14,7 +14,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "slices" @@ -72,39 +72,41 @@ type Lockdown struct { ActiveUntil time.Time } -type AuthServiceConfig struct { - LocalUsers *[]model.LocalUser - OauthWhitelist []string - SessionExpiry int - SessionMaxLifetime int - SecureCookie bool - CookieDomain string - LoginTimeout int - LoginMaxRetries int - SessionCookieName string - IP model.IPConfig - LDAPGroupsCacheTTL int - SubdomainsEnabled bool -} - type AuthService struct { - config AuthServiceConfig + log *logger.Logger + config model.Config + runtime model.RuntimeConfig + context context.Context + + ldap *LdapService + queries *repository.Queries + oauthBroker *OAuthBrokerService + loginAttempts map[string]*LoginAttempt ldapGroupsCache map[string]*LdapGroupsCache oauthPendingSessions map[string]*OAuthPendingSession oauthMutex sync.RWMutex loginMutex sync.RWMutex ldapGroupsMutex sync.RWMutex - ldap *LdapService - queries *repository.Queries - oauthBroker *OAuthBrokerService lockdown *Lockdown lockdownCtx context.Context lockdownCancelFunc context.CancelFunc } -func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { - return &AuthService{ +func NewAuthService( + log *logger.Logger, + config model.Config, + runtime model.RuntimeConfig, + ctx context.Context, + wg *sync.WaitGroup, + ldap *LdapService, + queries *repository.Queries, + oauthBroker *OAuthBrokerService, +) *AuthService { + service := &AuthService{ + log: log, + runtime: runtime, + context: ctx, config: config, loginAttempts: make(map[string]*LoginAttempt), ldapGroupsCache: make(map[string]*LdapGroupsCache), @@ -113,11 +115,10 @@ func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *reposi queries: queries, oauthBroker: oauthBroker, } -} -func (auth *AuthService) Init() error { - go auth.CleanupOAuthSessionsRoutine() - return nil + wg.Go(service.CleanupOAuthSessionsRoutine) + + return service } func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) { @@ -128,7 +129,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) }, nil } - if auth.ldap.IsConfigured() { + if auth.ldap != nil { userDN, err := auth.ldap.GetUserDN(username) if err != nil { @@ -153,7 +154,7 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str } return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) case model.UserLDAP: - if auth.ldap.IsConfigured() { + if auth.ldap != nil { err := auth.ldap.Bind(search.Username, password) if err != nil { return fmt.Errorf("failed to bind to ldap user: %w", err) @@ -173,10 +174,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str } func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { - if auth.config.LocalUsers == nil { + if auth.runtime.LocalUsers == nil { return nil } - for _, user := range *auth.config.LocalUsers { + for _, user := range auth.runtime.LocalUsers { if user.Username == username { return &user } @@ -185,7 +186,7 @@ func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { } func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { - if !auth.ldap.IsConfigured() { + if auth.ldap == nil { return nil, errors.New("ldap service not configured") } @@ -209,7 +210,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { auth.ldapGroupsMutex.Lock() auth.ldapGroupsCache[userDN] = &LdapGroupsCache{ Groups: groups, - Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second), + Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second), } auth.ldapGroupsMutex.Unlock() @@ -228,7 +229,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { return true, remaining } - if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { + if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 { return false, 0 } @@ -246,7 +247,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { } func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { - if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { + if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 { return } @@ -277,14 +278,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { attempt.FailedAttempts++ - if attempt.FailedAttempts >= auth.config.LoginMaxRetries { - attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second) - tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts") + if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { + attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) + auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") } } func (auth *AuthService) IsEmailWhitelisted(email string) bool { - return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email) + return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email) } func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { @@ -299,7 +300,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess if data.TotpPending { expiry = 3600 } else { - expiry = auth.config.SessionExpiry + expiry = auth.config.Auth.SessionExpiry } expiresAt := time.Now().Add(time.Duration(expiry) * time.Second) @@ -325,13 +326,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess } return &http.Cookie{ - Name: auth.config.SessionCookieName, + Name: auth.runtime.SessionCookieName, Value: session.UUID, Path: "/", - Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Expires: expiresAt, MaxAge: int(time.Until(expiresAt).Seconds()), - Secure: auth.config.SecureCookie, + Secure: auth.config.Auth.SecureCookie, HttpOnly: true, SameSite: http.SameSiteLaxMode, }, nil @@ -348,8 +349,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http var refreshThreshold int64 - if auth.config.SessionExpiry <= int(time.Hour.Seconds()) { - refreshThreshold = int64(auth.config.SessionExpiry / 2) + if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) { + refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2) } else { refreshThreshold = int64(time.Hour.Seconds()) } @@ -378,13 +379,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http } return &http.Cookie{ - Name: auth.config.SessionCookieName, + Name: auth.runtime.SessionCookieName, Value: session.UUID, Path: "/", - Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), MaxAge: int(newExpiry - currentTime), - Secure: auth.config.SecureCookie, + Secure: auth.config.Auth.SecureCookie, HttpOnly: true, SameSite: http.SameSiteLaxMode, }, nil @@ -395,23 +396,17 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http. err := auth.queries.DeleteSession(ctx, uuid) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway") - } - - err = auth.queries.DeleteSession(ctx, uuid) - - if err != nil { - return nil, err + auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") } return &http.Cookie{ - Name: auth.config.SessionCookieName, + Name: auth.runtime.SessionCookieName, Value: "", Path: "/", - Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Expires: time.Now(), MaxAge: -1, - Secure: auth.config.SecureCookie, + Secure: auth.config.Auth.SecureCookie, HttpOnly: true, SameSite: http.SameSiteLaxMode, }, nil @@ -429,8 +424,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito currentTime := time.Now().Unix() - if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 { - if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) { + if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 { + if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) { err = auth.queries.DeleteSession(ctx, uuid) if err != nil { return nil, fmt.Errorf("failed to delete expired session: %w", err) @@ -451,11 +446,11 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito } func (auth *AuthService) LocalAuthConfigured() bool { - return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0 + return len(auth.runtime.LocalUsers) > 0 } func (auth *AuthService) LDAPAuthConfigured() bool { - return auth.ldap.IsConfigured() + return auth.ldap != nil } func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool { @@ -464,18 +459,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext } if context.Provider == model.ProviderOAuth { - tlog.App.Debug().Msg("Checking OAuth whitelist") + auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist") return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email) } if acls.Users.Block != "" { - tlog.App.Debug().Msg("Checking blocked users") + auth.log.App.Debug().Msg("Checking users block list") if utils.CheckFilter(acls.Users.Block, context.GetUsername()) { return false } } - tlog.App.Debug().Msg("Checking users") + auth.log.App.Debug().Msg("Checking users allow list") return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) } @@ -485,23 +480,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex } if !context.IsOAuth() { - tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") + auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") return false } if _, ok := model.OverrideProviders[context.OAuth.ID]; ok { - tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check") + auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check") return true } for _, userGroup := range context.OAuth.Groups { if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) { - tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") + auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") return true } } - tlog.App.Debug().Msg("No groups matched") + auth.log.App.Debug().Msg("No groups matched") return false } @@ -511,18 +506,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext } if !context.IsLDAP() { - tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") + auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") return false } for _, userGroup := range context.LDAP.Groups { if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) { - tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") + auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") return true } } - tlog.App.Debug().Msg("No groups matched") + auth.log.App.Debug().Msg("No groups matched") return false } @@ -566,17 +561,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool { } // Merge the global and app IP filter - blockedIps := append(auth.config.IP.Block, acls.IP.Block...) - allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...) + blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...) + allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...) for _, blocked := range blockedIps { res, err := utils.FilterIP(blocked, ip) if err != nil { - tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") + auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") continue } if res { - tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access") + auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access") return false } } @@ -584,21 +579,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool { for _, allowed := range allowedIPs { res, err := utils.FilterIP(allowed, ip) if err != nil { - tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") + auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") continue } if res { - tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access") + auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access") return true } } if len(allowedIPs) > 0 { - tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") + auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") return false } - tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default") + auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default") return true } @@ -610,16 +605,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool { for _, bypassed := range acls.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) if err != nil { - tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") + auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") continue } if res { - tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access") + auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication") return true } } - tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication") + auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication") return false } @@ -723,21 +718,32 @@ func (auth *AuthService) EndOAuthSession(sessionId string) { } func (auth *AuthService) CleanupOAuthSessionsRoutine() { + auth.log.App.Debug().Msg("Starting OAuth session cleanup routine") + ticker := time.NewTicker(30 * time.Minute) defer ticker.Stop() - for range ticker.C { - auth.oauthMutex.Lock() + for { + select { + case <-ticker.C: + auth.log.App.Debug().Msg("Running OAuth session cleanup") - now := time.Now() + auth.oauthMutex.Lock() - for sessionId, session := range auth.oauthPendingSessions { - if now.After(session.ExpiresAt) { - delete(auth.oauthPendingSessions, sessionId) + now := time.Now() + + for sessionId, session := range auth.oauthPendingSessions { + if now.After(session.ExpiresAt) { + delete(auth.oauthPendingSessions, sessionId) + } } - } - auth.oauthMutex.Unlock() + auth.oauthMutex.Unlock() + auth.log.App.Debug().Msg("OAuth session cleanup completed") + case <-auth.context.Done(): + auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine") + return + } } } @@ -806,11 +812,11 @@ func (auth *AuthService) lockdownMode() { auth.loginMutex.Lock() - tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.") + auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") auth.lockdown = &Lockdown{ Active: true, - ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second), + ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second), } // At this point all login attemps will also expire so, @@ -827,11 +833,14 @@ func (auth *AuthService) lockdownMode() { // Timer expired, end lockdown case <-ctx.Done(): // Context cancelled, end lockdown + case <-auth.context.Done(): + // Service is shutting down, end lockdown } auth.loginMutex.Lock() - tlog.App.Info().Msg("Lockdown period ended, resuming normal operation") + auth.log.App.Info().Msg("Exiting lockdown mode") + auth.lockdown = nil auth.loginMutex.Unlock() } @@ -845,10 +854,3 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() { } auth.loginMutex.Unlock() } - -func (auth *AuthService) getCookieDomain() string { - if auth.config.SubdomainsEnabled { - return "." + auth.config.CookieDomain - } - return auth.config.CookieDomain -} diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index c5f95dd4..9d077c53 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -3,51 +3,56 @@ package service import ( "context" "strings" + "sync" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" container "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" ) type DockerService struct { - client *client.Client - context context.Context + log *logger.Logger + client *client.Client + context context.Context + isConnected bool } -func NewDockerService() *DockerService { - return &DockerService{} -} +func NewDockerService( + log *logger.Logger, + ctx context.Context, + wg *sync.WaitGroup, +) (*DockerService, error) { -func (docker *DockerService) Init() error { client, err := client.NewClientWithOpts(client.FromEnv) if err != nil { - return err + return nil, err } - ctx := context.Background() client.NegotiateAPIVersion(ctx) - docker.client = client - docker.context = ctx - - _, err = docker.client.Ping(docker.context) + _, err = client.Ping(ctx) if err != nil { - tlog.App.Debug().Err(err).Msg("Docker not connected") - docker.isConnected = false - docker.client = nil - docker.context = nil - return nil + log.App.Debug().Err(err).Msg("Docker not connected") + return nil, nil } - docker.isConnected = true - tlog.App.Debug().Msg("Docker connected") + service := &DockerService{ + log: log, + client: client, + context: ctx, + } - return nil + service.isConnected = true + service.log.App.Debug().Msg("Docker connected successfully") + + wg.Go(service.watchAndClose) + + return service, nil } func (docker *DockerService) getContainers() ([]container.Summary, error) { @@ -60,7 +65,7 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { if !docker.isConnected { - tlog.App.Debug().Msg("Docker not connected, returning empty labels") + docker.log.App.Debug().Msg("Docker service not connected, returning empty labels") return nil, nil } @@ -82,17 +87,28 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { for appName, appLabels := range labels.Apps { if appLabels.Config.Domain == appDomain { - tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") + docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") return &appLabels, nil } if strings.SplitN(appDomain, ".", 2)[0] == appName { - tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") + docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") return &appLabels, nil } } } - tlog.App.Debug().Msg("No matching container found, returning empty labels") + docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain") return nil, nil } + +func (docker *DockerService) watchAndClose() { + <-docker.context.Done() + docker.log.App.Debug().Msg("Closing Docker client") + if docker.client != nil { + err := docker.client.Close() + if err != nil { + docker.log.App.Error().Err(err).Msg("Error closing Docker client") + } + } +} diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go index 9c5ad427..8976cb54 100644 --- a/internal/service/kubernetes_service.go +++ b/internal/service/kubernetes_service.go @@ -9,7 +9,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -36,9 +36,10 @@ type ingressApp struct { } type KubernetesService struct { + log *logger.Logger + ctx context.Context + client dynamic.Interface - ctx context.Context - cancel context.CancelFunc started bool mu sync.RWMutex ingressApps map[ingressKey][]ingressApp @@ -46,12 +47,55 @@ type KubernetesService struct { appNameIndex map[string]ingressAppKey } -func NewKubernetesService() *KubernetesService { - return &KubernetesService{ +func NewKubernetesService( + log *logger.Logger, + ctx context.Context, + wg *sync.WaitGroup, +) (*KubernetesService, error) { + cfg, err := rest.InClusterConfig() + if err != nil { + return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err) + } + + client, err := dynamic.NewForConfig(cfg) + if err != nil { + return nil, fmt.Errorf("failed to create kubernetes client: %w", err) + } + + gvr := schema.GroupVersionResource{ + Group: "networking.k8s.io", + Version: "v1", + Resource: "ingresses", + } + + accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second) + defer accessCancel() + + _, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) + if err != nil { + log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") + return nil, fmt.Errorf("failed to access ingress api: %w", err) + } + + log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") + + service := &KubernetesService{ + log: log, + ctx: ctx, + client: client, ingressApps: make(map[ingressKey][]ingressApp), domainIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey), } + + wg.Go(func() { + service.watchGVR(gvr) + }) + + service.started = true + log.App.Debug().Msg("Kubernetes label provider started successfully") + + return service, nil } func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) { @@ -133,7 +177,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) { } labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps") if err != nil { - tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations") + k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping") k.removeIngress(namespace, name) return } @@ -161,13 +205,13 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error { list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{}) if err != nil { - tlog.App.Debug().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list ingresses during resync") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync") return err } for i := range list.Items { k.updateFromItem(&list.Items[i]) } - tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resynced ingress cache") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete") return nil } @@ -181,14 +225,14 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch. return false case event, ok := <-w.ResultChan(): if !ok { - tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting in 5 seconds") + k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher") w.Stop() time.Sleep(5 * time.Second) return true } item, ok := event.Object.(*unstructured.Unstructured) if !ok { - tlog.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Failed to cast watched object") + k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping") continue } switch event.Type { @@ -199,7 +243,7 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch. } case <-resyncTicker.C: if err := k.resyncGVR(gvr); err != nil { - tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run") } } } @@ -210,29 +254,29 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) { defer resyncTicker.Stop() if err := k.resyncGVR(gvr); err != nil { - tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, retrying in 30 seconds") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry") time.Sleep(30 * time.Second) } for { select { case <-k.ctx.Done(): - tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Stopping watcher") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher") return case <-resyncTicker.C: if err := k.resyncGVR(gvr); err != nil { - tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry") } default: ctx, cancel := context.WithCancel(k.ctx) watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{}) if err != nil { - tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry") cancel() time.Sleep(10 * time.Second) continue } - tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully") if !k.runWatcher(gvr, watcher, resyncTicker) { cancel() return @@ -242,65 +286,25 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) { } } -func (k *KubernetesService) Init() error { - var cfg *rest.Config - var err error - - cfg, err = rest.InClusterConfig() - if err != nil { - return fmt.Errorf("failed to get in-cluster Kubernetes config: %w", err) - } - - client, err := dynamic.NewForConfig(cfg) - if err != nil { - return fmt.Errorf("failed to create Kubernetes client: %w", err) - } - - k.client = client - k.ctx, k.cancel = context.WithCancel(context.Background()) - - gvr := schema.GroupVersionResource{ - Group: "networking.k8s.io", - Version: "v1", - Resource: "ingresses", - } - - accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second) - defer accessCancel() - _, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) - if err != nil { - tlog.App.Warn().Err(err).Msg("Insufficient permissions for networking.k8s.io/v1 Ingress, Kubernetes label provider will not work") - k.started = false - return nil - } - - tlog.App.Debug().Msg("networking.k8s.io/v1 Ingress API accessible") - go k.watchGVR(gvr) - - k.started = true - tlog.App.Info().Msg("Kubernetes label provider initialized") - return nil -} - func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { if !k.started { - tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels") + k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping") return nil, nil } // First check cache app := k.getByDomain(appDomain) if app != nil { - tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain") + k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain") return app, nil } appName := strings.SplitN(appDomain, ".", 2)[0] app = k.getByAppName(appName) if app != nil { - tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name") + k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name") return app, nil } - tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found") + k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain") return nil, nil } diff --git a/internal/service/kubernetes_service_test.go b/internal/service/kubernetes_service_test.go index c7b39ead..702fe0f8 100644 --- a/internal/service/kubernetes_service_test.go +++ b/internal/service/kubernetes_service_test.go @@ -8,9 +8,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestKubernetesService(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + type testCase struct { description string run func(t *testing.T, svc *KubernetesService) @@ -179,6 +183,7 @@ func TestKubernetesService(t *testing.T) { ingressApps: make(map[ingressKey][]ingressApp), domainIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey), + log: log, } test.run(t, svc) }) diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index 0963ebf5..9c031206 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -9,69 +9,47 @@ import ( "github.com/cenkalti/backoff/v5" ldapgo "github.com/go-ldap/ldap/v3" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -type LdapServiceConfig struct { - Address string - BindDN string - BindPassword string - BaseDN string - Insecure bool - SearchFilter string - AuthCert string - AuthKey string -} - type LdapService struct { - config LdapServiceConfig - conn *ldapgo.Conn - mutex sync.RWMutex - cert *tls.Certificate - isConfigured bool + log *logger.Logger + config model.Config + context context.Context + + conn *ldapgo.Conn + mutex sync.RWMutex + cert *tls.Certificate } -func NewLdapService(config LdapServiceConfig) *LdapService { - return &LdapService{ - config: config, - } -} - -func (ldap *LdapService) IsConfigured() bool { - return ldap.isConfigured -} - -func (ldap *LdapService) Unconfigure() error { - if !ldap.isConfigured { - return nil +func NewLdapService( + log *logger.Logger, + config model.Config, + ctx context.Context, + wg *sync.WaitGroup, +) (*LdapService, error) { + if config.LDAP.Address == "" { + return nil, nil } - if ldap.conn != nil { - if err := ldap.conn.Close(); err != nil { - return fmt.Errorf("failed to close LDAP connection: %w", err) - } + ldap := &LdapService{ + log: log, + config: config, + context: ctx, } - ldap.isConfigured = false - return nil -} - -func (ldap *LdapService) Init() error { - if ldap.config.Address == "" { - ldap.isConfigured = false - return nil - } - - ldap.isConfigured = true - // Check whether authentication with client certificate is possible - if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" { - cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey) + if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" { + cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey) + if err != nil { - return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) + return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) } + + log.App.Info().Msg("LDAP mTLS authentication configured successfully") + ldap.cert = &cert - tlog.App.Info().Msg("Using LDAP with mTLS authentication") // TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify` /* @@ -84,26 +62,39 @@ func (ldap *LdapService) Init() error { } */ } + _, err := ldap.connect() + if err != nil { - return fmt.Errorf("failed to connect to LDAP server: %w", err) + return nil, fmt.Errorf("failed to connect to ldap server: %w", err) } - go func() { - for range time.Tick(time.Duration(5) * time.Minute) { - err := ldap.heartbeat() - if err != nil { - tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed") - if reconnectErr := ldap.reconnect(); reconnectErr != nil { - tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") - continue + wg.Go(func() { + ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") + + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + err := ldap.heartbeat() + if err != nil { + ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect") + if reconnectErr := ldap.reconnect(); reconnectErr != nil { + ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") + continue + } + ldap.log.App.Info().Msg("Successfully reconnected to LDAP server") } - tlog.App.Info().Msg("Successfully reconnected to LDAP server") + case <-ldap.context.Done(): + ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat") + return } } - }() + }) - return nil + return ldap, nil } func (ldap *LdapService) connect() (*ldapgo.Conn, error) { @@ -120,13 +111,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) { // 2. conn.StartTLS(tlsConfig) // 3. conn.externalBind() if ldap.cert != nil { - conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ + conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{ MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{*ldap.cert}, })) } else { - conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ - InsecureSkipVerify: ldap.config.Insecure, + conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{ + InsecureSkipVerify: ldap.config.LDAP.Insecure, MinVersion: tls.VersionTLS12, })) } @@ -146,10 +137,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) { func (ldap *LdapService) GetUserDN(username string) (string, error) { // Escape the username to prevent LDAP injection escapedUsername := ldapgo.EscapeFilter(username) - filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername) + filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername) searchRequest := ldapgo.NewSearchRequest( - ldap.config.BaseDN, + ldap.config.LDAP.BaseDN, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, filter, []string{"dn"}, @@ -176,7 +167,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) { escapedUserDN := ldapgo.EscapeFilter(userDN) searchRequest := ldapgo.NewSearchRequest( - ldap.config.BaseDN, + ldap.config.LDAP.BaseDN, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN), []string{"dn"}, @@ -224,7 +215,7 @@ func (ldap *LdapService) BindService(rebind bool) error { if ldap.cert != nil { return ldap.conn.ExternalBind() } - return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword) + return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword) } func (ldap *LdapService) Bind(userDN string, password string) error { @@ -238,7 +229,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error { } func (ldap *LdapService) heartbeat() error { - tlog.App.Debug().Msg("Performing LDAP connection heartbeat") + ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat") searchRequest := ldapgo.NewSearchRequest( "", @@ -260,7 +251,7 @@ func (ldap *LdapService) heartbeat() error { } func (ldap *LdapService) reconnect() error { - tlog.App.Info().Msg("Reconnecting to LDAP server") + ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server") exp := backoff.NewExponentialBackOff() exp.InitialInterval = 500 * time.Millisecond diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 15823c47..fdb5e1e0 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -1,8 +1,10 @@ package service import ( + "context" + "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "slices" @@ -19,33 +21,39 @@ type OAuthServiceImpl interface { } type OAuthBrokerService struct { + log *logger.Logger + services map[string]OAuthServiceImpl configs map[string]model.OAuthServiceConfig } -var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{ +var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{ "github": newGitHubOAuthService, "google": newGoogleOAuthService, } -func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService { - return &OAuthBrokerService{ +func NewOAuthBrokerService( + log *logger.Logger, + configs map[string]model.OAuthServiceConfig, + ctx context.Context, +) *OAuthBrokerService { + service := &OAuthBrokerService{ + log: log, services: make(map[string]OAuthServiceImpl), configs: configs, } -} -func (broker *OAuthBrokerService) Init() error { - for name, cfg := range broker.configs { + for name, cfg := range configs { if presetFunc, exists := presets[name]; exists { - broker.services[name] = presetFunc(cfg) - tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") + service.services[name] = presetFunc(cfg, ctx) + service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") } else { - broker.services[name] = NewOAuthService(cfg, name) - tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config") + service.services[name] = NewOAuthService(cfg, name, ctx) + service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") } } - return nil + + return service } func (broker *OAuthBrokerService) GetConfiguredServices() []string { diff --git a/internal/service/oauth_presets.go b/internal/service/oauth_presets.go index ef21fa60..d620d54d 100644 --- a/internal/service/oauth_presets.go +++ b/internal/service/oauth_presets.go @@ -1,23 +1,25 @@ package service import ( + "context" + "github.com/tinyauthapp/tinyauth/internal/model" "golang.org/x/oauth2/endpoints" ) -func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService { +func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService { scopes := []string{"openid", "email", "profile"} config.Scopes = scopes config.AuthURL = endpoints.Google.AuthURL config.TokenURL = endpoints.Google.TokenURL config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo" - return NewOAuthService(config, "google") + return NewOAuthService(config, "google", ctx) } -func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService { +func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService { scopes := []string{"read:user", "user:email"} config.Scopes = scopes config.AuthURL = endpoints.GitHub.AuthURL config.TokenURL = endpoints.GitHub.TokenURL - return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor) + return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor) } diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go index 11b0be9c..0def3143 100644 --- a/internal/service/oauth_service.go +++ b/internal/service/oauth_service.go @@ -20,7 +20,7 @@ type OAuthService struct { id string } -func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { +func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService { httpClient := &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ @@ -29,8 +29,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { }, }, } - ctx := context.Background() - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient) return &OAuthService{ serviceCfg: config, @@ -44,7 +43,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { TokenURL: config.TokenURL, }, }, - ctx: ctx, + ctx: vctx, userinfoExtractor: defaultExtractor, id: id, } diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 1e1c1986..92216451 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -16,6 +16,7 @@ import ( "net/url" "os" "strings" + "sync" "time" "slices" @@ -25,7 +26,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) var ( @@ -111,172 +112,173 @@ type AuthorizeRequest struct { CodeChallengeMethod string `json:"code_challenge_method"` } -type OIDCServiceConfig struct { - Clients map[string]model.OIDCClientConfig - PrivateKeyPath string - PublicKeyPath string - Issuer string - SessionExpiry int -} - type OIDCService struct { - config OIDCServiceConfig - queries *repository.Queries - clients map[string]model.OIDCClientConfig - privateKey *rsa.PrivateKey - publicKey crypto.PublicKey - issuer string - isConfigured bool + log *logger.Logger + config model.Config + runtime model.RuntimeConfig + queries *repository.Queries + context context.Context + + clients map[string]model.OIDCClientConfig + privateKey *rsa.PrivateKey + publicKey crypto.PublicKey + issuer string } -func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService { - return &OIDCService{ - config: config, - queries: queries, - } -} - -func (service *OIDCService) IsConfigured() bool { - return service.isConfigured -} - -func (service *OIDCService) Init() error { +func NewOIDCService( + log *logger.Logger, + config model.Config, + runtime model.RuntimeConfig, + queries *repository.Queries, + ctx context.Context, + wg *sync.WaitGroup) (*OIDCService, error) { // If not configured, skip init - if len(service.config.Clients) == 0 { - service.isConfigured = false - return nil + if len(runtime.OIDCClients) == 0 { + return nil, nil } - service.isConfigured = true - // Ensure issuer is https - uissuer, err := url.Parse(service.config.Issuer) + uissuer, err := url.Parse(runtime.AppURL) if err != nil { - return err + return nil, fmt.Errorf("failed to parse app url: %w", err) } if uissuer.Scheme != "https" { - return errors.New("issuer must be https") + return nil, errors.New("issuer must be https") } - service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) + issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) // Create/load private and public keys - if strings.TrimSpace(service.config.PrivateKeyPath) == "" || - strings.TrimSpace(service.config.PublicKeyPath) == "" { - return errors.New("private key path and public key path are required") + if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" || + strings.TrimSpace(config.OIDC.PublicKeyPath) == "" { + return nil, errors.New("private key path and public key path are required") } var privateKey *rsa.PrivateKey - fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath) + fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { - return err + return nil, err } if errors.Is(err, os.ErrNotExist) { privateKey, err = rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return err + return nil, fmt.Errorf("failed to generate private key: %w", err) } der := x509.MarshalPKCS1PrivateKey(privateKey) if der == nil { - return errors.New("failed to marshal private key") + return nil, errors.New("failed to marshal private key") } encoded := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: der, }) - tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") - err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600) + log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") + err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600) if err != nil { - return err + return nil, fmt.Errorf("failed to write private key to file: %w", err) } - service.privateKey = privateKey } else { block, _ := pem.Decode(fprivateKey) if block == nil { - return errors.New("failed to decode private key") + return nil, errors.New("failed to decode private key") } - tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key") + log.App.Trace().Str("type", block.Type).Msg("Loaded private key") privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("failed to parse private key: %w", err) } - service.privateKey = privateKey } - fpublicKey, err := os.ReadFile(service.config.PublicKeyPath) + var publicKey crypto.PublicKey + + fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { - return err + return nil, fmt.Errorf("failed to read public key: %w", err) } if errors.Is(err, os.ErrNotExist) { - publicKey := service.privateKey.Public() + publicKey = privateKey.Public() der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey)) if der == nil { - return errors.New("failed to marshal public key") + return nil, errors.New("failed to marshal public key") } encoded := pem.EncodeToMemory(&pem.Block{ Type: "RSA PUBLIC KEY", Bytes: der, }) - tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") - err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644) + log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") + err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644) if err != nil { - return err + return nil, err } - service.publicKey = publicKey } else { block, _ := pem.Decode(fpublicKey) if block == nil { - return errors.New("failed to decode public key") + return nil, errors.New("failed to decode public key") } - tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key") + log.App.Trace().Str("type", block.Type).Msg("Loaded public key") switch block.Type { case "RSA PUBLIC KEY": - publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) + publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("failed to parse public key: %w", err) } - service.publicKey = publicKey case "PUBLIC KEY": - publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) + publicKey, err = x509.ParsePKIXPublicKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("failed to parse public key: %w", err) } - service.publicKey = publicKey.(crypto.PublicKey) default: - return fmt.Errorf("unsupported public key type: %s", block.Type) + return nil, fmt.Errorf("unsupported public key type: %s", block.Type) } } // We will reorganize the client into a map with the client ID as the key - service.clients = make(map[string]model.OIDCClientConfig) + clients := make(map[string]model.OIDCClientConfig) - for id, client := range service.config.Clients { + for id, client := range config.OIDC.Clients { client.ID = id if client.Name == "" { client.Name = utils.Capitalize(client.ID) } - service.clients[client.ClientID] = client + clients[client.ClientID] = client } // Load the client secrets from files if they exist - for id, client := range service.clients { + for id, client := range clients { secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile) if secret != "" { client.ClientSecret = secret } client.ClientSecretFile = "" - service.clients[id] = client - tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client") + clients[id] = client + log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") } - return nil + // Initialize the service + service := &OIDCService{ + log: log, + config: config, + runtime: runtime, + queries: queries, + context: ctx, + + clients: clients, + privateKey: privateKey, + publicKey: publicKey, + issuer: issuer, + } + + // Start cleanup routine + wg.Go(service.cleanupRoutine) + + return service, nil } func (service *OIDCService) GetIssuer() string { @@ -307,7 +309,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error return errors.New("invalid_scope") } if !slices.Contains(SupportedScopes, scope) { - tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored") + service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope") } } @@ -357,7 +359,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r entry.CodeChallenge = req.CodeChallenge } else { entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge) - tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security") + service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security") } } @@ -449,7 +451,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { createdAt := time.Now().Unix() - expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() + expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() hasher := sha256.New() @@ -529,16 +531,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID accessToken := utils.GenerateString(32) refreshToken := utils.GenerateString(32) - tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() + tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() // Refresh token lives double the time of an access token but can't be used to access userinfo - refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() + refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix() tokenResponse := TokenResponse{ AccessToken: accessToken, RefreshToken: refreshToken, TokenType: "Bearer", - ExpiresIn: int64(service.config.SessionExpiry), + ExpiresIn: int64(service.config.Auth.SessionExpiry), IDToken: idToken, Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "), } @@ -598,14 +600,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri accessToken := utils.GenerateString(32) newRefreshToken := utils.GenerateString(32) - tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() - refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() + tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() + refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix() tokenResponse := TokenResponse{ AccessToken: accessToken, RefreshToken: newRefreshToken, TokenType: "Bearer", - ExpiresIn: int64(service.config.SessionExpiry), + ExpiresIn: int64(service.config.Auth.SessionExpiry), IDToken: idToken, Scope: strings.ReplaceAll(entry.Scope, ",", " "), } @@ -748,56 +750,62 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er } // Cleanup routine - Resource heavy due to the linked tables -func (service *OIDCService) Cleanup() { - // We need a context for the routine - ctx := context.Background() - +func (service *OIDCService) cleanupRoutine() { + service.log.App.Debug().Msg("Starting OIDC cleanup routine") ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() - for range ticker.C { - currentTime := time.Now().Unix() + for { + select { + case <-ticker.C: + service.log.App.Debug().Msg("Performing OIDC cleanup routine") - // For the OIDC tokens, if they are expired we delete the userinfo and codes - expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ - TokenExpiresAt: currentTime, - RefreshTokenExpiresAt: currentTime, - }) + currentTime := time.Now().Unix() - if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens") - } + // For the OIDC tokens, if they are expired we delete the userinfo and codes + expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{ + TokenExpiresAt: currentTime, + RefreshTokenExpiresAt: currentTime, + }) - for _, expiredToken := range expiredTokens { - err := service.DeleteOldSession(ctx, expiredToken.Sub) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete old session") + service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens") } - } - // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything - expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) + for _, expiredToken := range expiredTokens { + err := service.DeleteOldSession(service.context, expiredToken.Sub) + if err != nil { + service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") + } + } - if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete expired codes") - } - - for _, expiredCode := range expiredCodes { - token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) + // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything + expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + service.log.App.Warn().Err(err).Msg("Failed to delete expired codes") + } + + for _, expiredCode := range expiredCodes { + token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub) + + if err != nil { + service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code") continue } - tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub") - } - if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { - err := service.DeleteOldSession(ctx, expiredCode.Sub) - if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete session") + if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { + err := service.DeleteOldSession(service.context, expiredCode.Sub) + if err != nil { + service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code") + } } } + + service.log.App.Debug().Msg("Finished OIDC cleanup routine") + case <-service.context.Done(): + service.log.App.Debug().Msg("Stopping OIDC cleanup routine") + return } } } diff --git a/internal/service/oidc_service_test.go b/internal/service/oidc_service_test.go index 394df4be..bc24c9be 100644 --- a/internal/service/oidc_service_test.go +++ b/internal/service/oidc_service_test.go @@ -1,7 +1,9 @@ package service_test import ( + "context" "encoding/json" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -10,6 +12,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func newTestUser() repository.OidcUserinfo { @@ -48,13 +51,29 @@ func newTestUser() repository.OidcUserinfo { func TestCompileUserinfo(t *testing.T) { dir := t.TempDir() - svc := service.NewOIDCService(service.OIDCServiceConfig{ - PrivateKeyPath: dir + "/key.pem", - PublicKeyPath: dir + "/key.pub", - Issuer: "https://tinyauth.example.com", - SessionExpiry: 3600, - }, nil) - require.NoError(t, svc.Init()) + + cfg := model.Config{ + OIDC: model.OIDCConfig{ + PrivateKeyPath: dir + "/key.pem", + PublicKeyPath: dir + "/key.pub", + }, + Auth: model.AuthConfig{ + SessionExpiry: 3600, + }, + } + + runtime := model.RuntimeConfig{ + AppURL: "https://tinyauth.example.com", + } + + log := logger.NewLogger().WithTestConfig() + log.Init() + + ctx := context.TODO() + wg := &sync.WaitGroup{} + + svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg) + require.NoError(t, err) type testCase struct { description string diff --git a/internal/test/test.go b/internal/test/test.go new file mode 100644 index 00000000..73ff5d38 --- /dev/null +++ b/internal/test/test.go @@ -0,0 +1,106 @@ +package test + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "golang.org/x/crypto/bcrypt" +) + +var TestingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK" + +func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { + tempDir := t.TempDir() + + config := model.Config{ + UI: model.UIConfig{ + Title: "Tinyauth Test", + ForgotPasswordMessage: "foo", + BackgroundImage: "/background.jpg", + WarningsEnabled: true, + }, + OAuth: model.OAuthConfig{ + AutoRedirect: "none", + }, + OIDC: model.OIDCConfig{ + Clients: map[string]model.OIDCClientConfig{ + "test": { + ClientID: "some-client-id", + ClientSecret: "some-client-secret", + TrustedRedirectURIs: []string{"https://test.example.com/callback"}, + Name: "Test Client", + }, + }, + PrivateKeyPath: filepath.Join(tempDir, "key.pem"), + PublicKeyPath: filepath.Join(tempDir, "key.pub"), + }, + Auth: model.AuthConfig{ + SessionExpiry: 10, + LoginTimeout: 10, + LoginMaxRetries: 3, + }, + Database: model.DatabaseConfig{ + Path: filepath.Join(tempDir, "test.db"), + }, + Resources: model.ResourcesConfig{ + Enabled: true, + Path: filepath.Join(tempDir, "resources"), + }, + } + + passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) + require.NoError(t, err) + + runtime := model.RuntimeConfig{ + ConfiguredProviders: []model.Provider{ + { + Name: "Local", + ID: "local", + OAuth: false, + }, + }, + LocalUsers: []model.LocalUser{ + { + Username: "testuser", + Password: string(passwd), + }, + { + Username: "totpuser", + Password: string(passwd), + TOTPSecret: TestingTOTPSecret, + }, + { + Username: "attruser", + Password: string(passwd), + Attributes: model.UserAttributes{ + Name: "Alice Smith", + Email: "alice@example.com", + }, + }, + { + Username: "attrtotpuser", + Password: string(passwd), + TOTPSecret: TestingTOTPSecret, + Attributes: model.UserAttributes{ + Name: "Bob Jones", + Email: "bob@example.com", + }, + }, + }, + CookieDomain: "example.com", + AppURL: "https://tinyauth.example.com", + SessionCookieName: "tinyauth-session", + OIDCClients: func() []model.OIDCClientConfig { + var clients []model.OIDCClientConfig + for id, client := range config.OIDC.Clients { + client.ID = id + clients = append(clients, client) + } + return clients + }(), + } + + return config, runtime +} diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index d021c083..6413755b 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -7,8 +7,6 @@ import ( "net/url" "strings" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - "github.com/weppos/publicsuffix-go/publicsuffix" ) @@ -28,7 +26,6 @@ func GetCookieDomain(u string) (string, error) { parts := strings.Split(host, ".") if len(parts) == 2 { - tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host) return host, nil } diff --git a/internal/utils/logger/logger.go b/internal/utils/logger/logger.go new file mode 100644 index 00000000..af6b55ea --- /dev/null +++ b/internal/utils/logger/logger.go @@ -0,0 +1,160 @@ +package logger + +import ( + "io" + "os" + "strings" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/tinyauthapp/tinyauth/internal/model" +) + +type Logger struct { + HTTP zerolog.Logger + App zerolog.Logger + config model.LogConfig + base zerolog.Logger + audit zerolog.Logger + writer io.Writer +} + +func NewLogger() *Logger { + return &Logger{ + writer: os.Stderr, + config: model.LogConfig{ + Level: "error", + Json: true, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{ + Enabled: true, + }, + App: model.LogStreamConfig{ + Enabled: true, + }, + // No reason to enable audit by default since it will be suppressed by the log level + }, + }, + } +} + +func (l *Logger) WithConfig(cfg model.LogConfig) *Logger { + l.config = cfg + return l +} + +func (l *Logger) WithSimpleConfig() *Logger { + l.config = model.LogConfig{ + Level: "info", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + return l +} + +func (l *Logger) WithTestConfig() *Logger { + l.config = model.LogConfig{ + Level: "trace", + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: true}, + }, + } + return l +} + +func (l *Logger) WithWriter(writer io.Writer) *Logger { + l.writer = writer + return l +} + +func (l *Logger) Init() { + base := log.With(). + Timestamp(). + Logger(). + Level(l.parseLogLevel(l.config.Level)).Output(l.writer) + + if !l.config.Json { + base = base.Output(zerolog.ConsoleWriter{ + Out: l.writer, + TimeFormat: time.RFC3339, + }) + } + + if base.GetLevel() == zerolog.TraceLevel || base.GetLevel() == zerolog.DebugLevel { + base = base.With().Caller().Logger() + } + + l.base = base + l.audit = l.createLogger("audit", l.config.Streams.Audit) + l.HTTP = l.createLogger("http", l.config.Streams.HTTP) + l.App = l.createLogger("app", l.config.Streams.App) +} + +func (l *Logger) parseLogLevel(level string) zerolog.Level { + if level == "" { + return zerolog.InfoLevel + } + parsed, err := zerolog.ParseLevel(strings.ToLower(level)) + if err != nil { + log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to error") + parsed = zerolog.ErrorLevel + } + return parsed +} + +func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerolog.Logger { + if !cfg.Enabled { + return zerolog.Nop() + } + sub := l.base.With().Str("stream", component).Logger() + if cfg.Level != "" { + sub = sub.Level(l.parseLogLevel(cfg.Level)) + } + return sub +} + +func (l *Logger) AuditLoginSuccess(username, provider, ip string) { + l.audit.Info(). + CallerSkipFrame(1). + Str("event", "login"). + Str("result", "success"). + Str("username", username). + Str("provider", provider). + Str("ip", ip). + Send() +} + +func (l *Logger) AuditLoginFailure(username, provider, ip, reason string) { + l.audit.Warn(). + CallerSkipFrame(1). + Str("event", "login"). + Str("result", "failure"). + Str("username", username). + Str("provider", provider). + Str("ip", ip). + Str("reason", reason). + Send() +} + +func (l *Logger) AuditLogout(username, provider, ip string) { + l.audit.Info(). + CallerSkipFrame(1). + Str("event", "logout"). + Str("result", "success"). + Str("username", username). + Str("provider", provider). + Str("ip", ip). + Send() +} + +// Used for testing +func (l *Logger) GetConfig() model.LogConfig { + return l.config +} diff --git a/internal/utils/logger/logger_test.go b/internal/utils/logger/logger_test.go new file mode 100644 index 00000000..167e2337 --- /dev/null +++ b/internal/utils/logger/logger_test.go @@ -0,0 +1,173 @@ +package logger_test + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +func TestLogger(t *testing.T) { + type testCase struct { + description string + run func(t *testing.T) + } + + tests := []testCase{ + { + description: "Should create a simple logger with the expected config", + run: func(t *testing.T) { + l := logger.NewLogger().WithSimpleConfig() + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, model.LogConfig{ + Level: "info", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + }) + }, + }, + { + description: "Should create a test logger with the expected config", + run: func(t *testing.T) { + l := logger.NewLogger().WithTestConfig() + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, model.LogConfig{ + Level: "trace", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: true}, + }, + }) + }, + }, + { + description: "Should create a logger with a custom config", + run: func(t *testing.T) { + customCfg := model.LogConfig{ + Level: "debug", + Json: true, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: false}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + + l := logger.NewLogger().WithConfig(customCfg) + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, customCfg) + }, + }, + { + description: "Default logger should use error type and log json", + run: func(t *testing.T) { + buf := bytes.Buffer{} + + l := logger.NewLogger().WithWriter(&buf) + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, model.LogConfig{ + Level: "error", + Json: true, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + }) + + l.App.Error().Msg("test") + + var entry map[string]any + err := json.Unmarshal(buf.Bytes(), &entry) + require.NoError(t, err) + + assert.Equal(t, "test", entry["message"]) + assert.Equal(t, "app", entry["stream"]) + assert.Equal(t, "error", entry["level"]) + assert.NotEmpty(t, entry["time"]) + }, + }, + { + description: "Should default to error level if an invalid level is provided", + run: func(t *testing.T) { + buf := bytes.Buffer{} + + customCfg := model.LogConfig{ + Level: "invalid", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + + l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf) + l.Init() + + assert.Equal(t, zerolog.ErrorLevel, l.App.GetLevel()) + assert.Equal(t, zerolog.ErrorLevel, l.HTTP.GetLevel()) + + // should not get logged + l.AuditLoginFailure("test", "test", "test", "test") + + assert.Empty(t, buf.String()) + }, + }, + { + description: "Should use nop logger for disabled streams", + run: func(t *testing.T) { + buf := bytes.Buffer{} + + customCfg := model.LogConfig{ + Level: "info", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: false}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + + l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf) + l.Init() + + assert.Equal(t, zerolog.Disabled, l.HTTP.GetLevel()) + + l.App.Info().Msg("test") + + l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop") + + assert.NotEmpty(t, buf.String()) + assert.NotContains(t, buf.String(), "test_nop") + }, + }, + } + + for _, test := range tests { + t.Run(test.description, test.run) + } +} diff --git a/internal/utils/tlog/log_audit.go b/internal/utils/tlog/log_audit.go deleted file mode 100644 index 115d41fe..00000000 --- a/internal/utils/tlog/log_audit.go +++ /dev/null @@ -1,39 +0,0 @@ -package tlog - -import "github.com/gin-gonic/gin" - -// functions here use CallerSkipFrame to ensure correct caller info is logged - -func AuditLoginSuccess(c *gin.Context, username, provider string) { - Audit.Info(). - CallerSkipFrame(1). - Str("event", "login"). - Str("result", "success"). - Str("username", username). - Str("provider", provider). - Str("ip", c.ClientIP()). - Send() -} - -func AuditLoginFailure(c *gin.Context, username, provider string, reason string) { - Audit.Warn(). - CallerSkipFrame(1). - Str("event", "login"). - Str("result", "failure"). - Str("username", username). - Str("provider", provider). - Str("ip", c.ClientIP()). - Str("reason", reason). - Send() -} - -func AuditLogout(c *gin.Context, username, provider string) { - Audit.Info(). - CallerSkipFrame(1). - Str("event", "logout"). - Str("result", "success"). - Str("username", username). - Str("provider", provider). - Str("ip", c.ClientIP()). - Send() -} diff --git a/internal/utils/tlog/log_wrapper.go b/internal/utils/tlog/log_wrapper.go deleted file mode 100644 index ffdfcf91..00000000 --- a/internal/utils/tlog/log_wrapper.go +++ /dev/null @@ -1,97 +0,0 @@ -package tlog - -import ( - "os" - "strings" - "time" - - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" - "github.com/tinyauthapp/tinyauth/internal/model" -) - -type Logger struct { - Audit zerolog.Logger - HTTP zerolog.Logger - App zerolog.Logger -} - -var ( - Audit zerolog.Logger - HTTP zerolog.Logger - App zerolog.Logger -) - -func NewLogger(cfg model.LogConfig) *Logger { - baseLogger := log.With(). - Timestamp(). - Caller(). - Logger(). - Level(parseLogLevel(cfg.Level)) - - if !cfg.Json { - baseLogger = baseLogger.Output(zerolog.ConsoleWriter{ - Out: os.Stderr, - TimeFormat: time.RFC3339, - }) - } - - return &Logger{ - Audit: createLogger("audit", cfg.Streams.Audit, baseLogger), - HTTP: createLogger("http", cfg.Streams.HTTP, baseLogger), - App: createLogger("app", cfg.Streams.App, baseLogger), - } -} - -func NewSimpleLogger() *Logger { - return NewLogger(model.LogConfig{ - Level: "info", - Json: false, - Streams: model.LogStreams{ - HTTP: model.LogStreamConfig{Enabled: true}, - App: model.LogStreamConfig{Enabled: true}, - Audit: model.LogStreamConfig{Enabled: false}, - }, - }) -} - -func NewTestLogger() *Logger { - return NewLogger(model.LogConfig{ - Level: "trace", - Streams: model.LogStreams{ - HTTP: model.LogStreamConfig{Enabled: true}, - App: model.LogStreamConfig{Enabled: true}, - Audit: model.LogStreamConfig{Enabled: true}, - }, - }) -} - -func (l *Logger) Init() { - Audit = l.Audit - HTTP = l.HTTP - App = l.App -} - -func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger { - if !streamCfg.Enabled { - return zerolog.Nop() - } - subLogger := baseLogger.With().Str("log_stream", component).Logger() - // override level if specified, otherwise use base level - if streamCfg.Level != "" { - subLogger = subLogger.Level(parseLogLevel(streamCfg.Level)) - } - return subLogger -} - -func parseLogLevel(level string) zerolog.Level { - if level == "" { - return zerolog.InfoLevel - } - parsedLevel, err := zerolog.ParseLevel(strings.ToLower(level)) - if err != nil { - log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to info") - parsedLevel = zerolog.InfoLevel - } - return parsedLevel -} diff --git a/internal/utils/tlog/log_wrapper_test.go b/internal/utils/tlog/log_wrapper_test.go deleted file mode 100644 index 41609f53..00000000 --- a/internal/utils/tlog/log_wrapper_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package tlog_test - -import ( - "bytes" - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - - "github.com/rs/zerolog" -) - -func TestNewLogger(t *testing.T) { - cfg := model.LogConfig{ - Level: "debug", - Json: true, - Streams: model.LogStreams{ - HTTP: model.LogStreamConfig{Enabled: true, Level: "info"}, - App: model.LogStreamConfig{Enabled: true, Level: ""}, - Audit: model.LogStreamConfig{Enabled: false, Level: ""}, - }, - } - - logger := tlog.NewLogger(cfg) - - assert.NotNil(t, logger) - assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel()) - assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel()) - assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel()) -} - -func TestNewSimpleLogger(t *testing.T) { - logger := tlog.NewSimpleLogger() - assert.NotNil(t, logger) - assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel()) - assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel()) - assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel()) -} - -func TestLoggerInit(t *testing.T) { - logger := tlog.NewSimpleLogger() - logger.Init() - - assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel()) -} - -func TestLoggerWithDisabledStreams(t *testing.T) { - cfg := model.LogConfig{ - Level: "info", - Json: false, - Streams: model.LogStreams{ - HTTP: model.LogStreamConfig{Enabled: false}, - App: model.LogStreamConfig{Enabled: false}, - Audit: model.LogStreamConfig{Enabled: false}, - }, - } - - logger := tlog.NewLogger(cfg) - - assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel()) - assert.Equal(t, zerolog.Disabled, logger.App.GetLevel()) - assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel()) -} - -func TestLogStreamField(t *testing.T) { - var buf bytes.Buffer - - cfg := model.LogConfig{ - Level: "info", - Json: true, - Streams: model.LogStreams{ - HTTP: model.LogStreamConfig{Enabled: true}, - App: model.LogStreamConfig{Enabled: true}, - Audit: model.LogStreamConfig{Enabled: true}, - }, - } - - logger := tlog.NewLogger(cfg) - - // Override output for HTTP logger to capture output - logger.HTTP = logger.HTTP.Output(&buf) - - logger.HTTP.Info().Msg("test message") - - var logEntry map[string]interface{} - err := json.Unmarshal(buf.Bytes(), &logEntry) - assert.NoError(t, err) - - assert.Equal(t, "http", logEntry["log_stream"]) - assert.Equal(t, "test message", logEntry["message"]) -}