From 659d3561e0031c462ea40a67c9cc5b6e8b7c0b40 Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 22:03:06 +0300 Subject: [PATCH] refactor: use a boostrap service to bootstrap the app --- .env.example | 4 +- cmd/root.go | 346 ++++++----------------- cmd/version.go | 8 +- internal/bootstrap/bootstrap_app.go | 246 ++++++++++++++++ internal/config/config.go | 7 +- internal/service/auth_service.go | 7 +- internal/service/github_oauth_service.go | 2 + internal/service/google_oauth_service.go | 2 + internal/utils/utils.go | 20 ++ 9 files changed, 369 insertions(+), 273 deletions(-) create mode 100644 internal/bootstrap/bootstrap_app.go diff --git a/.env.example b/.env.example index 8edde7b..0f43bf0 100644 --- a/.env.example +++ b/.env.example @@ -5,7 +5,7 @@ SECRET_FILE=app_secret_file APP_URL=http://localhost:3000 USERS=your_user_password_hash USERS_FILE=users_file -COOKIE_SECURE=false +SECURE_COOKIE=false GITHUB_CLIENT_ID=github_client_id GITHUB_CLIENT_SECRET=github_client_secret GITHUB_CLIENT_SECRET_FILE=github_client_secret_file @@ -25,7 +25,7 @@ GENERIC_NAME=My OAuth SESSION_EXPIRY=7200 LOGIN_TIMEOUT=300 LOGIN_MAX_RETRIES=5 -LOG_LEVEL=0 +LOG_LEVEL=debug APP_TITLE=Tinyauth SSO FORGOT_PASSWORD_MESSAGE=Some message about resetting the password OAUTH_AUTO_REDIRECT=none diff --git a/cmd/root.go b/cmd/root.go index 8dadd5d..898c27f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,22 +1,13 @@ package cmd import ( - "errors" - "fmt" "strings" totpCmd "tinyauth/cmd/totp" userCmd "tinyauth/cmd/user" - "tinyauth/internal/auth" - "tinyauth/internal/constants" - "tinyauth/internal/controller" - "tinyauth/internal/docker" - "tinyauth/internal/ldap" - "tinyauth/internal/middleware" - "tinyauth/internal/providers" - "tinyauth/internal/types" + "tinyauth/internal/bootstrap" + "tinyauth/internal/config" "tinyauth/internal/utils" - "github.com/gin-gonic/gin" "github.com/go-playground/validator" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -24,197 +15,51 @@ import ( "github.com/spf13/viper" ) -type Middleware interface { - Middleware() gin.HandlerFunc - Init() error - Name() string -} - var rootCmd = &cobra.Command{ Use: "tinyauth", Short: "The simplest way to protect your apps with a login screen.", Long: `Tinyauth is a simple authentication middleware that adds simple username/password login or OAuth with Google, Github and any generic OAuth provider to all of your docker apps.`, Run: func(cmd *cobra.Command, args []string) { - var config types.Config - err := viper.Unmarshal(&config) - HandleError(err, "Failed to parse config") + var conf config.Config + + err := viper.Unmarshal(&conf) + if err != nil { + log.Fatal().Err(err).Msg("Failed to parse config") + } // Check if secrets have a file associated with them - config.Secret = utils.GetSecret(config.Secret, config.SecretFile) - config.GithubClientSecret = utils.GetSecret(config.GithubClientSecret, config.GithubClientSecretFile) - config.GoogleClientSecret = utils.GetSecret(config.GoogleClientSecret, config.GoogleClientSecretFile) - config.GenericClientSecret = utils.GetSecret(config.GenericClientSecret, config.GenericClientSecretFile) + conf.Secret = utils.GetSecret(conf.Secret, conf.SecretFile) + conf.GithubClientSecret = utils.GetSecret(conf.GithubClientSecret, conf.GithubClientSecretFile) + conf.GoogleClientSecret = utils.GetSecret(conf.GoogleClientSecret, conf.GoogleClientSecretFile) + conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile) validator := validator.New() - err = validator.Struct(config) - HandleError(err, "Failed to validate config") - log.Logger = log.Level(zerolog.Level(config.LogLevel)) - log.Info().Str("version", strings.TrimSpace(constants.Version)).Msg("Starting tinyauth") - - log.Info().Msg("Parsing users") - users, err := utils.GetUsers(config.Users, config.UsersFile) - HandleError(err, "Failed to parse users") - - log.Debug().Msg("Getting domain") - domain, err := utils.GetUpperDomain(config.AppURL) - HandleError(err, "Failed to get upper domain") - log.Info().Str("domain", domain).Msg("Using domain for cookie store") - - cookieId := utils.GenerateIdentifier(strings.Split(domain, ".")[0]) - sessionCookieName := fmt.Sprintf("%s-%s", constants.SessionCookieName, cookieId) - csrfCookieName := fmt.Sprintf("%s-%s", constants.CsrfCookieName, cookieId) - redirectCookieName := fmt.Sprintf("%s-%s", constants.RedirectCookieName, cookieId) - - log.Debug().Msg("Deriving HMAC and encryption secrets") - - hmacSecret, err := utils.DeriveKey(config.Secret, "hmac") - HandleError(err, "Failed to derive HMAC secret") - - encryptionSecret, err := utils.DeriveKey(config.Secret, "encryption") - HandleError(err, "Failed to derive encryption secret") - - // Split the config into service-specific sub-configs - oauthConfig := types.OAuthConfig{ - GithubClientId: config.GithubClientId, - GithubClientSecret: config.GithubClientSecret, - GoogleClientId: config.GoogleClientId, - GoogleClientSecret: config.GoogleClientSecret, - GenericClientId: config.GenericClientId, - GenericClientSecret: config.GenericClientSecret, - GenericScopes: strings.Split(config.GenericScopes, ","), - GenericAuthURL: config.GenericAuthURL, - GenericTokenURL: config.GenericTokenURL, - GenericUserURL: config.GenericUserURL, - GenericSkipSSL: config.GenericSkipSSL, - AppURL: config.AppURL, + err = validator.Struct(conf) + if err != nil { + log.Fatal().Err(err).Msg("Invalid config") } - authConfig := types.AuthConfig{ - Users: users, - OauthWhitelist: config.OAuthWhitelist, - CookieSecure: config.CookieSecure, - SessionExpiry: config.SessionExpiry, - Domain: domain, - LoginTimeout: config.LoginTimeout, - LoginMaxRetries: config.LoginMaxRetries, - SessionCookieName: sessionCookieName, - HMACSecret: hmacSecret, - EncryptionSecret: encryptionSecret, + log.Logger = log.Level(zerolog.Level(utils.GetLogLevel(conf.LogLevel))) + log.Info().Str("version", strings.TrimSpace(config.Version)).Msg("Starting tinyauth") + + // Create bootstrap app + app := bootstrap.NewBootstrapApp(conf) + + // Run + err = app.Setup() + + if err != nil { + log.Fatal().Err(err).Msg("Failed to setup app") } - var ldapService *ldap.LDAP - - if config.LdapAddress != "" { - log.Info().Msg("Using LDAP for authentication") - ldapConfig := types.LdapConfig{ - Address: config.LdapAddress, - BindDN: config.LdapBindDN, - BindPassword: config.LdapBindPassword, - BaseDN: config.LdapBaseDN, - Insecure: config.LdapInsecure, - SearchFilter: config.LdapSearchFilter, - } - ldapService, err = ldap.NewLDAP(ldapConfig) - if err != nil { - log.Error().Err(err).Msg("Failed to initialize LDAP service, disabling LDAP authentication") - ldapService = nil - } - } else { - log.Info().Msg("LDAP not configured, using local users or OAuth") - } - - // Check if we have a source of users - if len(users) == 0 && !utils.OAuthConfigured(config) && ldapService == nil { - HandleError(errors.New("err no users"), "Unable to find a source of users") - } - - // Setup the services - docker, err := docker.NewDocker() - HandleError(err, "Failed to initialize docker") - auth := auth.NewAuth(authConfig, docker, ldapService) - providers := providers.NewProviders(oauthConfig) - - // Create the engine - engine := gin.New() - - // Create the group - router := engine.Group("/api") - - // Setup the middlewares - var middlewares []Middleware - - contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ - Domain: domain, - }, auth, providers) - uiMiddleware := middleware.NewUIMiddleware() - zerologMiddleware := middleware.NewZerologMiddleware() - - middlewares = append(middlewares, contextMiddleware, uiMiddleware, zerologMiddleware) - - for _, middleware := range middlewares { - log.Debug().Str("middleware", middleware.Name()).Msg("Initializing middleware") - err := middleware.Init() - HandleError(err, fmt.Sprintf("Failed to initialize middleware %s", middleware.Name())) - router.Use(middleware.Middleware()) - } - - // Create configured providers - var configuredProviders []string - - configuredProviders = append(configuredProviders, providers.GetConfiguredProviders()...) - - if auth.UserAuthConfigured() { - configuredProviders = append(configuredProviders, "username") - } - - // Create controllers - contextController := controller.NewContextController(controller.ContextControllerConfig{ - ConfiguredProviders: configuredProviders, - DisableContinue: config.DisableContinue, - Title: config.Title, - GenericName: config.GenericName, - Domain: domain, - ForgotPasswordMessage: config.FogotPasswordMessage, - BackgroundImage: config.BackgroundImage, - OAuthAutoRedirect: config.OAuthAutoRedirect, - }, router) - contextController.SetupRoutes() - - healthController := controller.NewHealthController(router) - healthController.SetupRoutes() - - oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ - AppURL: config.AppURL, - SecureCookie: config.CookieSecure, - CSRFCookieName: csrfCookieName, - RedirectCookieName: redirectCookieName, - }, router, auth, providers) - oauthController.SetupRoutes() - - proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ - AppURL: config.AppURL, - }, router, docker, auth) - proxyController.SetupRoutes() - - userController := controller.NewUserController(controller.UserControllerConfig{ - Domain: domain, - }, router, auth) - userController.SetupRoutes() - - // Run server - engine.Run(fmt.Sprintf("%s:%d", config.Address, config.Port)) }, } func Execute() { err := rootCmd.Execute() - HandleError(err, "Failed to execute root command") -} - -func HandleError(err error, msg string) { if err != nil { - log.Fatal().Err(err).Msg(msg) + log.Fatal().Err(err).Msg("Failed to execute command") } } @@ -224,85 +69,66 @@ func init() { viper.AutomaticEnv() - rootCmd.Flags().Int("port", 3000, "Port to run the server on.") - rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.") - rootCmd.Flags().String("secret", "", "Secret to use for the cookie.") - rootCmd.Flags().String("secret-file", "", "Path to a file containing the secret.") - rootCmd.Flags().String("app-url", "", "The tinyauth URL.") - rootCmd.Flags().String("users", "", "Comma separated list of users in the format username:hash.") - rootCmd.Flags().String("users-file", "", "Path to a file containing users in the format username:hash.") - rootCmd.Flags().Bool("cookie-secure", false, "Send cookie over secure connection only.") - rootCmd.Flags().String("github-client-id", "", "Github OAuth client ID.") - rootCmd.Flags().String("github-client-secret", "", "Github OAuth client secret.") - rootCmd.Flags().String("github-client-secret-file", "", "Github OAuth client secret file.") - rootCmd.Flags().String("google-client-id", "", "Google OAuth client ID.") - rootCmd.Flags().String("google-client-secret", "", "Google OAuth client secret.") - rootCmd.Flags().String("google-client-secret-file", "", "Google OAuth client secret file.") - rootCmd.Flags().String("generic-client-id", "", "Generic OAuth client ID.") - rootCmd.Flags().String("generic-client-secret", "", "Generic OAuth client secret.") - rootCmd.Flags().String("generic-client-secret-file", "", "Generic OAuth client secret file.") - rootCmd.Flags().String("generic-scopes", "", "Generic OAuth scopes.") - rootCmd.Flags().String("generic-auth-url", "", "Generic OAuth auth URL.") - rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") - rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") - rootCmd.Flags().String("generic-name", "Generic", "Generic OAuth provider name.") - rootCmd.Flags().Bool("generic-skip-ssl", false, "Skip SSL verification for the generic OAuth provider.") - rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") - rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") - rootCmd.Flags().String("oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)") - rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.") - rootCmd.Flags().Int("login-timeout", 300, "Login timeout in seconds after max retries reached (0 to disable).") - rootCmd.Flags().Int("login-max-retries", 5, "Maximum login attempts before timeout (0 to disable).") - rootCmd.Flags().Int("log-level", 1, "Log level.") - rootCmd.Flags().String("app-title", "Tinyauth", "Title of the app.") - rootCmd.Flags().String("forgot-password-message", "", "Message to show on the forgot password page.") - rootCmd.Flags().String("background-image", "/background.jpg", "Background image URL for the login page.") - rootCmd.Flags().String("ldap-address", "", "LDAP server address (e.g. ldap://localhost:389).") - rootCmd.Flags().String("ldap-bind-dn", "", "LDAP bind DN (e.g. uid=user,dc=example,dc=com).") - rootCmd.Flags().String("ldap-bind-password", "", "LDAP bind password.") - rootCmd.Flags().String("ldap-base-dn", "", "LDAP base DN (e.g. dc=example,dc=com).") - rootCmd.Flags().Bool("ldap-insecure", false, "Skip certificate verification for the LDAP server.") - rootCmd.Flags().String("ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup.") + configOptions := []struct { + name string + defaultVal any + description string + }{ + {"port", 3000, "Port to run the server on."}, + {"address", "0.0.0.0", "Address to bind the server to."}, + {"secret", "", "Secret to use for the cookie."}, + {"secret-file", "", "Path to a file containing the secret."}, + {"app-url", "", "The Tinyauth URL."}, + {"users", "", "Comma separated list of users in the format username:hash."}, + {"users-file", "", "Path to a file containing users in the format username:hash."}, + {"cookie-secure", false, "Send cookie over secure connection only."}, + {"github-client-id", "", "Github OAuth client ID."}, + {"github-client-secret", "", "Github OAuth client secret."}, + {"github-client-secret-file", "", "Github OAuth client secret file."}, + {"google-client-id", "", "Google OAuth client ID."}, + {"google-client-secret", "", "Google OAuth client secret."}, + {"google-client-secret-file", "", "Google OAuth client secret file."}, + {"generic-client-id", "", "Generic OAuth client ID."}, + {"generic-client-secret", "", "Generic OAuth client secret."}, + {"generic-client-secret-file", "", "Generic OAuth client secret file."}, + {"generic-scopes", "", "Generic OAuth scopes."}, + {"generic-auth-url", "", "Generic OAuth auth URL."}, + {"generic-token-url", "", "Generic OAuth token URL."}, + {"generic-user-url", "", "Generic OAuth user info URL."}, + {"generic-name", "Generic", "Generic OAuth provider name."}, + {"generic-skip-ssl", false, "Skip SSL verification for the generic OAuth provider."}, + {"disable-continue", false, "Disable continue screen and redirect to app directly."}, + {"oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth."}, + {"oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)"}, + {"session-expiry", 86400, "Session (cookie) expiration time in seconds."}, + {"login-timeout", 300, "Login timeout in seconds after max retries reached (0 to disable)."}, + {"login-max-retries", 5, "Maximum login attempts before timeout (0 to disable)."}, + {"log-level", "info", "Log level."}, + {"app-title", "Tinyauth", "Title of the app."}, + {"forgot-password-message", "", "Message to show on the forgot password page."}, + {"background-image", "/background.jpg", "Background image URL for the login page."}, + {"ldap-address", "", "LDAP server address (e.g. ldap://localhost:389)."}, + {"ldap-bind-dn", "", "LDAP bind DN (e.g. uid=user,dc=example,dc=com)."}, + {"ldap-bind-password", "", "LDAP bind password."}, + {"ldap-base-dn", "", "LDAP base DN (e.g. dc=example,dc=com)."}, + {"ldap-insecure", false, "Skip certificate verification for the LDAP server."}, + {"ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup."}, + } - viper.BindEnv("port", "PORT") - viper.BindEnv("address", "ADDRESS") - viper.BindEnv("secret", "SECRET") - viper.BindEnv("secret-file", "SECRET_FILE") - viper.BindEnv("app-url", "APP_URL") - viper.BindEnv("users", "USERS") - viper.BindEnv("users-file", "USERS_FILE") - viper.BindEnv("cookie-secure", "COOKIE_SECURE") - viper.BindEnv("github-client-id", "GITHUB_CLIENT_ID") - viper.BindEnv("github-client-secret", "GITHUB_CLIENT_SECRET") - viper.BindEnv("github-client-secret-file", "GITHUB_CLIENT_SECRET_FILE") - viper.BindEnv("google-client-id", "GOOGLE_CLIENT_ID") - viper.BindEnv("google-client-secret", "GOOGLE_CLIENT_SECRET") - viper.BindEnv("google-client-secret-file", "GOOGLE_CLIENT_SECRET_FILE") - viper.BindEnv("generic-client-id", "GENERIC_CLIENT_ID") - viper.BindEnv("generic-client-secret", "GENERIC_CLIENT_SECRET") - viper.BindEnv("generic-client-secret-file", "GENERIC_CLIENT_SECRET_FILE") - viper.BindEnv("generic-scopes", "GENERIC_SCOPES") - viper.BindEnv("generic-auth-url", "GENERIC_AUTH_URL") - viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL") - viper.BindEnv("generic-user-url", "GENERIC_USER_URL") - viper.BindEnv("generic-name", "GENERIC_NAME") - viper.BindEnv("generic-skip-ssl", "GENERIC_SKIP_SSL") - viper.BindEnv("disable-continue", "DISABLE_CONTINUE") - viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST") - viper.BindEnv("oauth-auto-redirect", "OAUTH_AUTO_REDIRECT") - viper.BindEnv("session-expiry", "SESSION_EXPIRY") - viper.BindEnv("log-level", "LOG_LEVEL") - viper.BindEnv("app-title", "APP_TITLE") - viper.BindEnv("login-timeout", "LOGIN_TIMEOUT") - viper.BindEnv("login-max-retries", "LOGIN_MAX_RETRIES") - viper.BindEnv("forgot-password-message", "FORGOT_PASSWORD_MESSAGE") - viper.BindEnv("background-image", "BACKGROUND_IMAGE") - viper.BindEnv("ldap-address", "LDAP_ADDRESS") - viper.BindEnv("ldap-bind-dn", "LDAP_BIND_DN") - viper.BindEnv("ldap-bind-password", "LDAP_BIND_PASSWORD") - viper.BindEnv("ldap-base-dn", "LDAP_BASE_DN") - viper.BindEnv("ldap-insecure", "LDAP_INSECURE") - viper.BindEnv("ldap-search-filter", "LDAP_SEARCH_FILTER") + for _, opt := range configOptions { + switch v := opt.defaultVal.(type) { + case bool: + rootCmd.Flags().Bool(opt.name, v, opt.description) + case int: + rootCmd.Flags().Int(opt.name, v, opt.description) + case string: + rootCmd.Flags().String(opt.name, v, opt.description) + } + + // Create uppercase env var name + envVar := strings.ReplaceAll(strings.ToUpper(opt.name), "-", "_") + viper.BindEnv(opt.name, envVar) + } viper.BindPFlags(rootCmd.Flags()) } diff --git a/cmd/version.go b/cmd/version.go index ffbd6fc..2a1827b 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -2,7 +2,7 @@ package cmd import ( "fmt" - "tinyauth/internal/constants" + "tinyauth/internal/config" "github.com/spf13/cobra" ) @@ -12,9 +12,9 @@ var versionCmd = &cobra.Command{ Short: "Print the version number of Tinyauth", Long: `All software has versions. This is Tinyauth's`, Run: func(cmd *cobra.Command, args []string) { - fmt.Printf("Version: %s\n", constants.Version) - fmt.Printf("Commit Hash: %s\n", constants.CommitHash) - fmt.Printf("Build Timestamp: %s\n", constants.BuildTimestamp) + fmt.Printf("Version: %s\n", config.Version) + fmt.Printf("Commit Hash: %s\n", config.CommitHash) + fmt.Printf("Build Timestamp: %s\n", config.BuildTimestamp) }, } diff --git a/internal/bootstrap/bootstrap_app.go b/internal/bootstrap/bootstrap_app.go new file mode 100644 index 0000000..2dfc61d --- /dev/null +++ b/internal/bootstrap/bootstrap_app.go @@ -0,0 +1,246 @@ +package bootstrap + +import ( + "fmt" + "strings" + "tinyauth/internal/config" + "tinyauth/internal/controller" + "tinyauth/internal/middleware" + "tinyauth/internal/service" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" +) + +type Controller interface { + SetupRoutes() +} + +type Middleware interface { + Middleware() gin.HandlerFunc + Init() error + Name() string +} + +type Service interface { + Init() error +} + +type BootstrapApp struct { + Config config.Config +} + +func NewBootstrapApp(config config.Config) *BootstrapApp { + return &BootstrapApp{ + Config: config, + } +} + +func (app *BootstrapApp) Setup() error { + // Parse users + users, err := utils.GetUsers(app.Config.Users, app.Config.UsersFile) + + if err != nil { + return err + } + + // Get domain + domain, err := utils.GetUpperDomain(app.Config.AppURL) + + if err != nil { + return err + } + + // Cookie names + cookieId := utils.GenerateIdentifier(strings.Split(domain, ".")[0]) + sessionCookieName := fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId) + csrfCookieName := fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId) + redirectCookieName := fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId) + + // Secrets + encryptionSecret, err := utils.DeriveKey(app.Config.Secret, "encryption") + + if err != nil { + return err + } + + hmacSecret, err := utils.DeriveKey(app.Config.Secret, "hmac") + + if err != nil { + return err + } + + // Create configs + authConfig := service.AuthServiceConfig{ + Users: users, + OauthWhitelist: app.Config.OAuthWhitelist, + SessionExpiry: app.Config.SessionExpiry, + SecureCookie: app.Config.SecureCookie, + Domain: domain, + LoginTimeout: app.Config.LoginTimeout, + LoginMaxRetries: app.Config.LoginMaxRetries, + SessionCookieName: sessionCookieName, + HMACSecret: hmacSecret, + EncryptionSecret: encryptionSecret, + } + + // Setup services + var ldapService *service.LdapService + + if app.Config.LdapAddress != "" { + ldapConfig := service.LdapServiceConfig{ + Address: app.Config.LdapAddress, + BindDN: app.Config.LdapBindDN, + BindPassword: app.Config.LdapBindPassword, + BaseDN: app.Config.LdapBaseDN, + Insecure: app.Config.LdapInsecure, + SearchFilter: app.Config.LdapSearchFilter, + } + + ldapService = service.NewLdapService(ldapConfig) + + err := ldapService.Init() + + if err != nil { + ldapService = nil + } + } + + dockerService := service.NewDockerService() + authService := service.NewAuthService(authConfig, dockerService, ldapService) + oauthBrokerService := service.NewOAuthBrokerService(app.getOAuthBrokerConfig()) + + // Initialize services + services := []Service{ + dockerService, + authService, + oauthBrokerService, + } + + for _, svc := range services { + if svc != nil { + err := svc.Init() + if err != nil { + return err + } + } + } + + // Configured providers + var configuredProviders []string + + if authService.UserAuthConfigured() || ldapService != nil { + configuredProviders = append(configuredProviders, "username") + } + + configuredProviders = append(configuredProviders, oauthBrokerService.GetConfiguredServices()...) + + if len(configuredProviders) == 0 { + return fmt.Errorf("no authentication providers configured") + } + + // Create engine + engine := gin.New() + router := engine.Group("/api") + + // Create middlewares + var middlewares []Middleware + + contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ + Domain: domain, + }, authService, oauthBrokerService) + + uiMiddleware := middleware.NewUIMiddleware() + zerologMiddleware := middleware.NewZerologMiddleware() + + middlewares = append(middlewares, contextMiddleware, uiMiddleware, zerologMiddleware) + + for _, middleware := range middlewares { + log.Debug().Str("middleware", middleware.Name()).Msg("Initializing middleware") + err := middleware.Init() + if err != nil { + return fmt.Errorf("failed to initialize %s middleware: %w", middleware.Name(), err) + } + router.Use(middleware.Middleware()) + } + + // Create controllers + contextController := controller.NewContextController(controller.ContextControllerConfig{ + ConfiguredProviders: configuredProviders, + DisableContinue: app.Config.DisableContinue, + Title: app.Config.Title, + GenericName: app.Config.GenericName, + Domain: domain, + ForgotPasswordMessage: app.Config.FogotPasswordMessage, + BackgroundImage: app.Config.BackgroundImage, + OAuthAutoRedirect: app.Config.OAuthAutoRedirect, + }, router) + + oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ + AppURL: app.Config.AppURL, + SecureCookie: app.Config.SecureCookie, + CSRFCookieName: csrfCookieName, + RedirectCookieName: redirectCookieName, + }, router, authService, oauthBrokerService) + + proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ + AppURL: app.Config.AppURL, + }, router, dockerService, authService) + + userController := controller.NewUserController(controller.UserControllerConfig{ + Domain: domain, + }, router, authService) + + healthController := controller.NewHealthController(router) + + // Setup routes + controller := []Controller{ + contextController, + oauthController, + proxyController, + userController, + healthController, + } + + for _, ctrl := range controller { + log.Debug().Msgf("Setting up %T routes", ctrl) + ctrl.SetupRoutes() + } + + // Start server + address := fmt.Sprintf("%s:%d", app.Config.Address, app.Config.Port) + log.Info().Msgf("Starting server on %s", address) + if err := engine.Run(address); err != nil { + log.Fatal().Err(err).Msg("Failed to start server") + } + + return nil +} + +// Temporary +func (app *BootstrapApp) getOAuthBrokerConfig() map[string]config.OAuthServiceConfig { + return map[string]config.OAuthServiceConfig{ + "google": { + ClientID: app.Config.GoogleClientId, + ClientSecret: app.Config.GoogleClientSecret, + RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", app.Config.AppURL), + }, + "github": { + ClientID: app.Config.GithubClientId, + ClientSecret: app.Config.GithubClientSecret, + RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", app.Config.AppURL), + }, + "generic": { + ClientID: app.Config.GenericClientId, + ClientSecret: app.Config.GenericClientSecret, + RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", app.Config.AppURL), + Scopes: strings.Split(app.Config.GenericScopes, ","), + AuthURL: app.Config.GenericAuthURL, + TokenURL: app.Config.GenericTokenURL, + UserinfoURL: app.Config.GenericUserURL, + InsecureSkipVerify: app.Config.GenericSkipSSL, + }, + } + +} diff --git a/internal/config/config.go b/internal/config/config.go index 5584d0e..655b61a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,7 +12,7 @@ var CommitHash = "n/a" var BuildTimestamp = "n/a" var SessionCookieName = "tinyauth-session" -var CsrfCookieName = "tinyauth-csrf" +var CSRFCookieName = "tinyauth-csrf" var RedirectCookieName = "tinyauth-redirect" type Config struct { @@ -23,7 +23,7 @@ type Config struct { AppURL string `validate:"required,url" mapstructure:"app-url"` Users string `mapstructure:"users"` UsersFile string `mapstructure:"users-file"` - CookieSecure bool `mapstructure:"cookie-secure"` + SecureCookie bool `mapstructure:"secure-cookie"` GithubClientId string `mapstructure:"github-client-id"` GithubClientSecret string `mapstructure:"github-client-secret"` GithubClientSecretFile string `mapstructure:"github-client-secret-file"` @@ -43,9 +43,8 @@ type Config struct { OAuthWhitelist string `mapstructure:"oauth-whitelist"` OAuthAutoRedirect string `mapstructure:"oauth-auto-redirect" validate:"oneof=none github google generic"` SessionExpiry int `mapstructure:"session-expiry"` - LogLevel int8 `mapstructure:"log-level" validate:"min=-1,max=5"` + LogLevel string `mapstructure:"log-level" validate:"oneof=trace debug info warn error fatal panic"` Title string `mapstructure:"app-title"` - EnvFile string `mapstructure:"env-file"` LoginTimeout int `mapstructure:"login-timeout"` LoginMaxRetries int `mapstructure:"login-max-retries"` FogotPasswordMessage string `mapstructure:"forgot-password-message"` diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 46bad06..e91f98a 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -25,7 +25,7 @@ type AuthServiceConfig struct { Users []config.User OauthWhitelist string SessionExpiry int - CookieSecure bool + SecureCookie bool Domain string LoginTimeout int LoginMaxRetries int @@ -57,10 +57,11 @@ func (auth *AuthService) Init() error { store.Options = &sessions.Options{ Path: "/", MaxAge: auth.Config.SessionExpiry, - Secure: auth.Config.CookieSecure, + Secure: auth.Config.SecureCookie, HttpOnly: true, Domain: fmt.Sprintf(".%s", auth.Config.Domain), } + auth.Store = store return nil } @@ -70,7 +71,7 @@ func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { // If there was an error getting the session, it might be invalid so let's clear it and retry if err != nil { log.Error().Err(err).Msg("Invalid session, clearing cookie and retrying") - c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true) + c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.SecureCookie, true) session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName) if err != nil { log.Error().Err(err).Msg("Failed to get session") diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go index a8c1334..2f9e27f 100644 --- a/internal/service/github_oauth_service.go +++ b/internal/service/github_oauth_service.go @@ -11,6 +11,7 @@ import ( "tinyauth/internal/config" "golang.org/x/oauth2" + "golang.org/x/oauth2/endpoints" ) var GithubOAuthScopes = []string{"user:email", "read:user"} @@ -39,6 +40,7 @@ func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService ClientSecret: config.ClientSecret, RedirectURL: config.RedirectURL, Scopes: GithubOAuthScopes, + Endpoint: endpoints.GitHub, }, } } diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go index 6d9eaed..776aeca 100644 --- a/internal/service/google_oauth_service.go +++ b/internal/service/google_oauth_service.go @@ -11,6 +11,7 @@ import ( "tinyauth/internal/config" "golang.org/x/oauth2" + "golang.org/x/oauth2/endpoints" ) var GoogleOAuthScopes = []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} @@ -34,6 +35,7 @@ func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService ClientSecret: config.ClientSecret, RedirectURL: config.RedirectURL, Scopes: GoogleOAuthScopes, + Endpoint: endpoints.Google, }, } } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 67b904f..7181a26 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -18,6 +18,7 @@ import ( "golang.org/x/crypto/hkdf" "github.com/google/uuid" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) @@ -360,3 +361,22 @@ func GetContext(c *gin.Context) (config.UserContext, error) { return *userContext, nil } + +func GetLogLevel(level string) zerolog.Level { + switch strings.ToLower(level) { + case "debug": + return zerolog.DebugLevel + case "info": + return zerolog.InfoLevel + case "warn": + return zerolog.WarnLevel + case "error": + return zerolog.ErrorLevel + case "fatal": + return zerolog.FatalLevel + case "panic": + return zerolog.PanicLevel + default: + return zerolog.InfoLevel + } +}