diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index 7cea62f..d285da3 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -46,19 +46,19 @@ type ContextControllerConfig struct { } type ContextController struct { - Config ContextControllerConfig - Router *gin.RouterGroup + config ContextControllerConfig + router *gin.RouterGroup } func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController { return &ContextController{ - Config: config, - Router: router, + config: config, + router: router, } } func (controller *ContextController) SetupRoutes() { - contextGroup := controller.Router.Group("/context") + contextGroup := controller.router.Group("/context") contextGroup.GET("/user", controller.userContextHandler) contextGroup.GET("/app", controller.appContextHandler) } @@ -91,18 +91,18 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { } func (controller *ContextController) appContextHandler(c *gin.Context) { - appUrl, _ := url.Parse(controller.Config.AppURL) // no need to check error, validated on startup + appUrl, _ := url.Parse(controller.config.AppURL) // no need to check error, validated on startup c.JSON(200, AppContextResponse{ Status: 200, Message: "Success", - ConfiguredProviders: controller.Config.ConfiguredProviders, - Title: controller.Config.Title, - GenericName: controller.Config.GenericName, + ConfiguredProviders: controller.config.ConfiguredProviders, + Title: controller.config.Title, + GenericName: controller.config.GenericName, AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), - RootDomain: controller.Config.RootDomain, - ForgotPasswordMessage: controller.Config.ForgotPasswordMessage, - BackgroundImage: controller.Config.BackgroundImage, - OAuthAutoRedirect: controller.Config.OAuthAutoRedirect, + RootDomain: controller.config.RootDomain, + ForgotPasswordMessage: controller.config.ForgotPasswordMessage, + BackgroundImage: controller.config.BackgroundImage, + OAuthAutoRedirect: controller.config.OAuthAutoRedirect, }) } diff --git a/internal/controller/health_controller.go b/internal/controller/health_controller.go index 842b3d3..8f0aa42 100644 --- a/internal/controller/health_controller.go +++ b/internal/controller/health_controller.go @@ -3,18 +3,18 @@ package controller import "github.com/gin-gonic/gin" type HealthController struct { - Router *gin.RouterGroup + router *gin.RouterGroup } func NewHealthController(router *gin.RouterGroup) *HealthController { return &HealthController{ - Router: router, + router: router, } } func (controller *HealthController) SetupRoutes() { - controller.Router.GET("/health", controller.healthHandler) - controller.Router.HEAD("/health", controller.healthHandler) + controller.router.GET("/health", controller.healthHandler) + controller.router.HEAD("/health", controller.healthHandler) } func (controller *HealthController) healthHandler(c *gin.Context) { diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index cfac656..23d00de 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -27,23 +27,23 @@ type OAuthControllerConfig struct { } type OAuthController struct { - Config OAuthControllerConfig - Router *gin.RouterGroup - Auth *service.AuthService - Broker *service.OAuthBrokerService + config OAuthControllerConfig + router *gin.RouterGroup + auth *service.AuthService + broker *service.OAuthBrokerService } func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController { return &OAuthController{ - Config: config, - Router: router, - Auth: auth, - Broker: broker, + config: config, + router: router, + auth: auth, + broker: broker, } } func (controller *OAuthController) SetupRoutes() { - oauthGroup := controller.Router.Group("/oauth") + oauthGroup := controller.router.Group("/oauth") oauthGroup.GET("/url/:provider", controller.oauthURLHandler) oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) } @@ -61,7 +61,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - service, exists := controller.Broker.GetService(req.Provider) + service, exists := controller.broker.GetService(req.Provider) if !exists { log.Warn().Msgf("OAuth provider not found: %s", req.Provider) @@ -74,13 +74,13 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { state := service.GenerateState() authURL := service.GetAuthURL(state) - c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.RootDomain), controller.Config.SecureCookie, true) + c.SetCookie(controller.config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.RootDomain), controller.config.SecureCookie, true) redirectURI := c.Query("redirect_uri") - if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.Config.RootDomain) { + if redirectURI != "" && utils.IsRedirectSafe(redirectURI, controller.config.RootDomain) { log.Debug().Msg("Setting redirect URI cookie") - c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.Config.RootDomain), controller.Config.SecureCookie, true) + c.SetCookie(controller.config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.RootDomain), controller.config.SecureCookie, true) } c.JSON(200, gin.H{ @@ -104,58 +104,58 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { } state := c.Query("state") - csrfCookie, err := c.Cookie(controller.Config.CSRFCookieName) + csrfCookie, err := c.Cookie(controller.config.CSRFCookieName) if err != nil || state != csrfCookie { log.Warn().Err(err).Msg("CSRF token mismatch or cookie missing") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } - c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.RootDomain), controller.Config.SecureCookie, true) + c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.RootDomain), controller.config.SecureCookie, true) code := c.Query("code") - service, exists := controller.Broker.GetService(req.Provider) + service, exists := controller.broker.GetService(req.Provider) if !exists { log.Warn().Msgf("OAuth provider not found: %s", req.Provider) - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } err = service.VerifyCode(code) if err != nil { log.Error().Err(err).Msg("Failed to verify OAuth code") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } - user, err := controller.Broker.GetUser(req.Provider) + user, err := controller.broker.GetUser(req.Provider) if err != nil { log.Error().Err(err).Msg("Failed to get user from OAuth provider") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } if user.Email == "" { log.Error().Msg("OAuth provider did not return an email") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } - if !controller.Auth.IsEmailWhitelisted(user.Email) { + if !controller.auth.IsEmailWhitelisted(user.Email) { queries, err := query.Values(config.UnauthorizedQuery{ Username: user.Email, }) if err != nil { log.Error().Err(err).Msg("Failed to encode unauthorized query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) return } @@ -169,29 +169,35 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) } - var usename string + var username string if user.PreferredUsername != "" { log.Debug().Msg("Using preferred username from OAuth provider") - usename = user.PreferredUsername + username = user.PreferredUsername } else { log.Debug().Msg("No preferred username from OAuth provider, using pseudo username") - usename = strings.Replace(user.Email, "@", "_", -1) + username = strings.Replace(user.Email, "@", "_", -1) } - controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ - Username: usename, + err = controller.auth.CreateSessionCookie(c, &config.SessionCookie{ + Username: username, Name: name, Email: user.Email, Provider: req.Provider, OAuthGroups: utils.CoalesceToString(user.Groups), }) - redirectURI, err := c.Cookie(controller.Config.RedirectCookieName) + if err != nil { + log.Error().Err(err).Msg("Failed to create session cookie") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } - if err != nil || !utils.IsRedirectSafe(redirectURI, controller.Config.RootDomain) { + redirectURI, err := c.Cookie(controller.config.RedirectCookieName) + + if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.RootDomain) { log.Debug().Msg("No redirect URI cookie found, redirecting to app root") - c.Redirect(http.StatusTemporaryRedirect, controller.Config.AppURL) + c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL) return } @@ -201,10 +207,10 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { if err != nil { log.Error().Err(err).Msg("Failed to encode redirect URI query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } - c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.Config.RootDomain), controller.Config.SecureCookie, true) - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.Config.AppURL, queries.Encode())) + c.SetCookie(controller.config.RedirectCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.RootDomain), controller.config.SecureCookie, true) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode())) } diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 6ad10ec..fd25076 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -22,23 +22,23 @@ type ProxyControllerConfig struct { } type ProxyController struct { - Config ProxyControllerConfig - Router *gin.RouterGroup - Docker *service.DockerService - Auth *service.AuthService + config ProxyControllerConfig + router *gin.RouterGroup + docker *service.DockerService + auth *service.AuthService } func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *service.DockerService, auth *service.AuthService) *ProxyController { return &ProxyController{ - Config: config, - Router: router, - Docker: docker, - Auth: auth, + config: config, + router: router, + docker: docker, + auth: auth, } } func (controller *ProxyController) SetupRoutes() { - proxyGroup := controller.Router.Group("/auth") + proxyGroup := controller.router.Group("/auth") proxyGroup.GET("/:proxy", controller.proxyHandler) } @@ -67,44 +67,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { proto := c.Request.Header.Get("X-Forwarded-Proto") host := c.Request.Header.Get("X-Forwarded-Host") - hostWithoutPort := strings.Split(host, ":")[0] - id := strings.Split(hostWithoutPort, ".")[0] - - labels, err := controller.Docker.GetLabels(id, hostWithoutPort) + labels, err := controller.docker.GetLabels(host) if err != nil { log.Error().Err(err).Msg("Failed to get labels from Docker") - - if req.Proxy == "nginx" || !isBrowser { - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + controller.handleError(c, req, isBrowser) return } clientIP := c.ClientIP() - if controller.Auth.IsBypassedIP(labels.IP, clientIP) { - c.Header("Authorization", c.Request.Header.Get("Authorization")) - - headers := utils.ParseHeaders(labels.Response.Headers) - - for key, value := range headers { - log.Debug().Str("header", key).Msg("Setting header") - c.Header(key, value) - } - - basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile) - if labels.Response.BasicAuth.Username != "" && basicPassword != "" { - log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword))) - } - + if controller.auth.IsBypassedIP(labels.IP, clientIP) { + controller.setHeaders(c, labels) c.JSON(200, gin.H{ "status": 200, "message": "Authenticated", @@ -112,41 +86,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - authEnabled, err := controller.Auth.IsAuthEnabled(uri, labels.Path) + authEnabled, err := controller.auth.IsAuthEnabled(uri, labels.Path) if err != nil { log.Error().Err(err).Msg("Failed to check if auth is enabled for resource") - - if req.Proxy == "nginx" || !isBrowser { - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + controller.handleError(c, req, isBrowser) return } if !authEnabled { log.Debug().Msg("Authentication disabled for resource, allowing access") - - c.Header("Authorization", c.Request.Header.Get("Authorization")) - - headers := utils.ParseHeaders(labels.Response.Headers) - - for key, value := range headers { - log.Debug().Str("header", key).Msg("Setting header") - c.Header(key, value) - } - - basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile) - if labels.Response.BasicAuth.Username != "" && basicPassword != "" { - log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword))) - } - + controller.setHeaders(c, labels) c.JSON(200, gin.H{ "status": 200, "message": "Authenticated", @@ -154,7 +104,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - if !controller.Auth.CheckIP(labels.IP, clientIP) { + if !controller.auth.CheckIP(labels.IP, clientIP) { if req.Proxy == "nginx" || !isBrowser { c.JSON(401, gin.H{ "status": 401, @@ -170,11 +120,11 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { if err != nil { log.Error().Err(err).Msg("Failed to encode unauthorized query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) return } @@ -197,7 +147,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { } if userContext.IsLoggedIn { - appAllowed := controller.Auth.IsResourceAllowed(c, userContext, labels) + appAllowed := controller.auth.IsResourceAllowed(c, userContext, labels) if !appAllowed { log.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User not allowed to access resource") @@ -214,24 +164,24 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { Resource: strings.Split(host, ".")[0], }) + if err != nil { + log.Error().Err(err).Msg("Failed to encode unauthorized query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + if userContext.OAuth { queries.Set("username", userContext.Email) } else { queries.Set("username", userContext.Username) } - if err != nil { - log.Error().Err(err).Msg("Failed to encode unauthorized query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) return } if userContext.OAuth { - groupOK := controller.Auth.IsInOAuthGroup(c, userContext, labels.OAuth.Groups) + groupOK := controller.auth.IsInOAuthGroup(c, userContext, labels.OAuth.Groups) if !groupOK { log.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User OAuth groups do not match resource requirements") @@ -249,41 +199,29 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { GroupErr: true, }) + if err != nil { + log.Error().Err(err).Msg("Failed to encode unauthorized query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + if userContext.OAuth { queries.Set("username", userContext.Email) } else { queries.Set("username", userContext.Username) } - if err != nil { - log.Error().Err(err).Msg("Failed to encode unauthorized query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) return } } - c.Header("Authorization", c.Request.Header.Get("Authorization")) c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) - headers := utils.ParseHeaders(labels.Response.Headers) - - for key, value := range headers { - log.Debug().Str("header", key).Msg("Setting header") - c.Header(key, value) - } - - basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile) - if labels.Response.BasicAuth.Username != "" && basicPassword != "" { - log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword))) - } + controller.setHeaders(c, labels) c.JSON(200, gin.H{ "status": 200, @@ -306,9 +244,39 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { if err != nil { log.Error().Err(err).Msg("Failed to encode redirect URI query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.Config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode())) +} + +func (controller *ProxyController) setHeaders(c *gin.Context, labels config.AppLabels) { + c.Header("Authorization", c.Request.Header.Get("Authorization")) + + headers := utils.ParseHeaders(labels.Response.Headers) + + for key, value := range headers { + log.Debug().Str("header", key).Msg("Setting header") + c.Header(key, value) + } + + basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile) + + if labels.Response.BasicAuth.Username != "" && basicPassword != "" { + log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header") + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword))) + } +} + +func (controller *ProxyController) handleError(c *gin.Context, req Proxy, isBrowser bool) { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) } diff --git a/internal/controller/resources_controller.go b/internal/controller/resources_controller.go index 56bae87..92384e7 100644 --- a/internal/controller/resources_controller.go +++ b/internal/controller/resources_controller.go @@ -11,32 +11,32 @@ type ResourcesControllerConfig struct { } type ResourcesController struct { - Config ResourcesControllerConfig - Router *gin.RouterGroup - FileServer http.Handler + config ResourcesControllerConfig + router *gin.RouterGroup + fileServer http.Handler } func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController { fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.ResourcesDir))) return &ResourcesController{ - Config: config, - Router: router, - FileServer: fileServer, + config: config, + router: router, + fileServer: fileServer, } } func (controller *ResourcesController) SetupRoutes() { - controller.Router.GET("/resources/*resource", controller.resourcesHandler) + controller.router.GET("/resources/*resource", controller.resourcesHandler) } func (controller *ResourcesController) resourcesHandler(c *gin.Context) { - if controller.Config.ResourcesDir == "" { + if controller.config.ResourcesDir == "" { c.JSON(404, gin.H{ "status": 404, "message": "Resources not found", }) return } - controller.FileServer.ServeHTTP(c.Writer, c.Request) + controller.fileServer.ServeHTTP(c.Writer, c.Request) } diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index f3b7b51..7b48652 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -26,21 +26,21 @@ type UserControllerConfig struct { } type UserController struct { - Config UserControllerConfig - Router *gin.RouterGroup - Auth *service.AuthService + config UserControllerConfig + router *gin.RouterGroup + auth *service.AuthService } func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController { return &UserController{ - Config: config, - Router: router, - Auth: auth, + config: config, + router: router, + auth: auth, } } func (controller *UserController) SetupRoutes() { - userGroup := controller.Router.Group("/user") + userGroup := controller.router.Group("/user") userGroup.POST("/login", controller.loginHandler) userGroup.POST("/logout", controller.logoutHandler) userGroup.POST("/totp", controller.totpHandler) @@ -69,7 +69,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { log.Debug().Str("username", req.Username).Str("ip", clientIP).Msg("Login attempt") - isLocked, remainingTime := controller.Auth.IsAccountLocked(rateIdentifier) + isLocked, remainingTime := controller.auth.IsAccountLocked(rateIdentifier) if isLocked { log.Warn().Str("username", req.Username).Str("ip", clientIP).Msg("Account is locked due to too many failed login attempts") @@ -80,11 +80,11 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } - userSearch := controller.Auth.SearchUser(req.Username) + userSearch := controller.auth.SearchUser(req.Username) - if userSearch.Type == "" { + if userSearch.Type == "unknown" { log.Warn().Str("username", req.Username).Str("ip", clientIP).Msg("User not found") - controller.Auth.RecordLoginAttempt(rateIdentifier, false) + controller.auth.RecordLoginAttempt(rateIdentifier, false) c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -92,9 +92,9 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } - if !controller.Auth.VerifyUser(userSearch, req.Password) { + if !controller.auth.VerifyUser(userSearch, req.Password) { log.Warn().Str("username", req.Username).Str("ip", clientIP).Msg("Invalid password") - controller.Auth.RecordLoginAttempt(rateIdentifier, false) + controller.auth.RecordLoginAttempt(rateIdentifier, false) c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -104,18 +104,18 @@ func (controller *UserController) loginHandler(c *gin.Context) { log.Info().Str("username", req.Username).Str("ip", clientIP).Msg("Login successful") - controller.Auth.RecordLoginAttempt(rateIdentifier, true) + controller.auth.RecordLoginAttempt(rateIdentifier, true) if userSearch.Type == "local" { - user := controller.Auth.GetLocalUser(userSearch.Username) + user := controller.auth.GetLocalUser(userSearch.Username) if user.TotpSecret != "" { log.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") - err := controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + err := controller.auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Username, Name: utils.Capitalize(req.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.RootDomain), + Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.RootDomain), Provider: "username", TotpPending: true, }) @@ -138,10 +138,10 @@ func (controller *UserController) loginHandler(c *gin.Context) { } } - err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + err = controller.auth.CreateSessionCookie(c, &config.SessionCookie{ Username: req.Username, Name: utils.Capitalize(req.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.RootDomain), + Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.RootDomain), Provider: "username", }) @@ -163,7 +163,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { func (controller *UserController) logoutHandler(c *gin.Context) { log.Debug().Msg("Logout request received") - controller.Auth.DeleteSessionCookie(c) + controller.auth.DeleteSessionCookie(c) c.JSON(200, gin.H{ "status": 200, @@ -214,24 +214,24 @@ func (controller *UserController) totpHandler(c *gin.Context) { log.Debug().Str("username", context.Username).Str("ip", clientIP).Msg("TOTP verification attempt") - isLocked, remainingTime := controller.Auth.IsAccountLocked(rateIdentifier) + isLocked, remainingTime := controller.auth.IsAccountLocked(rateIdentifier) if isLocked { log.Warn().Str("username", context.Username).Str("ip", clientIP).Msg("Account is locked due to too many failed TOTP attempts") c.JSON(429, gin.H{ "status": 429, - "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), + "message": fmt.Sprintf("Too many failed TOTP attempts. Try again in %d seconds", remainingTime), }) return } - user := controller.Auth.GetLocalUser(context.Username) + user := controller.auth.GetLocalUser(context.Username) ok := totp.Validate(req.Code, user.TotpSecret) if !ok { log.Warn().Str("username", context.Username).Str("ip", clientIP).Msg("Invalid TOTP code") - controller.Auth.RecordLoginAttempt(rateIdentifier, false) + controller.auth.RecordLoginAttempt(rateIdentifier, false) c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -241,12 +241,12 @@ func (controller *UserController) totpHandler(c *gin.Context) { log.Info().Str("username", context.Username).Str("ip", clientIP).Msg("TOTP verification successful") - controller.Auth.RecordLoginAttempt(rateIdentifier, true) + controller.auth.RecordLoginAttempt(rateIdentifier, true) - err = controller.Auth.CreateSessionCookie(c, &config.SessionCookie{ + err = controller.auth.CreateSessionCookie(c, &config.SessionCookie{ Username: user.Username, Name: utils.Capitalize(user.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.RootDomain), + Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.config.RootDomain), Provider: "username", }) diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index bca0400..cbf9412 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -16,16 +16,16 @@ type ContextMiddlewareConfig struct { } type ContextMiddleware struct { - Config ContextMiddlewareConfig - Auth *service.AuthService - Broker *service.OAuthBrokerService + config ContextMiddlewareConfig + auth *service.AuthService + broker *service.OAuthBrokerService } func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware { return &ContextMiddleware{ - Config: config, - Auth: auth, - Broker: broker, + config: config, + auth: auth, + broker: broker, } } @@ -35,7 +35,7 @@ func (m *ContextMiddleware) Init() error { func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { - cookie, err := m.Auth.GetSessionCookie(c) + cookie, err := m.auth.GetSessionCookie(c) if err != nil { log.Debug().Err(err).Msg("No valid session cookie found") @@ -57,11 +57,11 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { switch cookie.Provider { case "username": - userSearch := m.Auth.SearchUser(cookie.Username) + userSearch := m.auth.SearchUser(cookie.Username) - if userSearch.Type == "unknown" { + if userSearch.Type == "unknown" || userSearch.Type == "error" { log.Debug().Msg("User from session cookie not found") - m.Auth.DeleteSessionCookie(c) + m.auth.DeleteSessionCookie(c) goto basic } @@ -75,17 +75,17 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { c.Next() return default: - _, exists := m.Broker.GetService(cookie.Provider) + _, exists := m.broker.GetService(cookie.Provider) if !exists { log.Debug().Msg("OAuth provider from session cookie not found") - m.Auth.DeleteSessionCookie(c) + m.auth.DeleteSessionCookie(c) goto basic } - if !m.Auth.IsEmailWhitelisted(cookie.Email) { + if !m.auth.IsEmailWhitelisted(cookie.Email) { log.Debug().Msg("Email from session cookie not whitelisted") - m.Auth.DeleteSessionCookie(c) + m.auth.DeleteSessionCookie(c) goto basic } @@ -103,7 +103,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { } basic: - basic := m.Auth.GetBasicAuth(c) + basic := m.auth.GetBasicAuth(c) if basic == nil { log.Debug().Msg("No basic auth provided") @@ -111,15 +111,15 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return } - userSearch := m.Auth.SearchUser(basic.Username) + userSearch := m.auth.SearchUser(basic.Username) - if userSearch.Type == "unknown" { + if userSearch.Type == "unknown" || userSearch.Type == "error" { log.Debug().Msg("User from basic auth not found") c.Next() return } - if !m.Auth.VerifyUser(userSearch, basic.Password) { + if !m.auth.VerifyUser(userSearch, basic.Password) { log.Debug().Msg("Invalid password for basic auth user") c.Next() return @@ -129,12 +129,12 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { case "local": log.Debug().Msg("Basic auth user is local") - user := m.Auth.GetLocalUser(basic.Username) + user := m.auth.GetLocalUser(basic.Username) c.Set("context", &config.UserContext{ Username: user.Username, Name: utils.Capitalize(user.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.Config.RootDomain), + Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.config.RootDomain), Provider: "basic", IsLoggedIn: true, TotpEnabled: user.TotpSecret != "", @@ -146,7 +146,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { c.Set("context", &config.UserContext{ Username: basic.Username, Name: utils.Capitalize(basic.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.Config.RootDomain), + Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.RootDomain), Provider: "basic", IsLoggedIn: true, }) diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go index dcfaa35..ff028a1 100644 --- a/internal/middleware/ui_middleware.go +++ b/internal/middleware/ui_middleware.go @@ -11,8 +11,8 @@ import ( ) type UIMiddleware struct { - UIFS fs.FS - UIFileServer http.Handler + uiFs fs.FS + uiFileServer http.Handler } func NewUIMiddleware() *UIMiddleware { @@ -26,8 +26,8 @@ func (m *UIMiddleware) Init() error { return err } - m.UIFS = ui - m.UIFileServer = http.FileServer(http.FS(ui)) + m.uiFs = ui + m.uiFileServer = http.FileServer(http.FS(ui)) return nil } @@ -42,13 +42,13 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc { c.Next() return default: - _, err := fs.Stat(m.UIFS, strings.TrimPrefix(c.Request.URL.Path, "/")) + _, err := fs.Stat(m.uiFs, strings.TrimPrefix(c.Request.URL.Path, "/")) if os.IsNotExist(err) { c.Request.URL.Path = "/" } - m.UIFileServer.ServeHTTP(c.Writer, c.Request) + m.uiFileServer.ServeHTTP(c.Writer, c.Request) c.Abort() return } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index f028149..cb14a7e 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -35,21 +35,21 @@ type AuthServiceConfig struct { } type AuthService struct { - Config AuthServiceConfig - Docker *DockerService - LoginAttempts map[string]*LoginAttempt - LoginMutex sync.RWMutex - LDAP *LdapService - Database *gorm.DB + config AuthServiceConfig + docker *DockerService + loginAttempts map[string]*LoginAttempt + loginMutex sync.RWMutex + ldap *LdapService + database *gorm.DB } func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, database *gorm.DB) *AuthService { return &AuthService{ - Config: config, - Docker: docker, - LoginAttempts: make(map[string]*LoginAttempt), - LDAP: ldap, - Database: database, + config: config, + docker: docker, + loginAttempts: make(map[string]*LoginAttempt), + ldap: ldap, + database: database, } } @@ -65,12 +65,14 @@ func (auth *AuthService) SearchUser(username string) config.UserSearch { } } - if auth.LDAP != nil { - userDN, err := auth.LDAP.Search(username) + if auth.ldap != nil { + userDN, err := auth.ldap.Search(username) if err != nil { log.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP") - return config.UserSearch{} + return config.UserSearch{ + Type: "error", + } } return config.UserSearch{ @@ -90,14 +92,14 @@ func (auth *AuthService) VerifyUser(search config.UserSearch, password string) b user := auth.GetLocalUser(search.Username) return auth.CheckPassword(user, password) case "ldap": - if auth.LDAP != nil { - err := auth.LDAP.Bind(search.Username, password) + if auth.ldap != nil { + err := auth.ldap.Bind(search.Username, password) if err != nil { log.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP") return false } - err = auth.LDAP.Bind(auth.LDAP.Config.BindDN, auth.LDAP.Config.BindPassword) + err = auth.ldap.Bind(auth.ldap.Config.BindDN, auth.ldap.Config.BindPassword) if err != nil { log.Error().Err(err).Msg("Failed to rebind with service account after user authentication") return false @@ -115,7 +117,7 @@ func (auth *AuthService) VerifyUser(search config.UserSearch, password string) b } func (auth *AuthService) GetLocalUser(username string) config.User { - for _, user := range auth.Config.Users { + for _, user := range auth.config.Users { if user.Username == username { return user } @@ -130,14 +132,14 @@ func (auth *AuthService) CheckPassword(user config.User, password string) bool { } func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { - auth.LoginMutex.RLock() - defer auth.LoginMutex.RUnlock() + auth.loginMutex.RLock() + defer auth.loginMutex.RUnlock() - if auth.Config.LoginMaxRetries <= 0 || auth.Config.LoginTimeout <= 0 { + if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { return false, 0 } - attempt, exists := auth.LoginAttempts[identifier] + attempt, exists := auth.loginAttempts[identifier] if !exists { return false, 0 } @@ -151,17 +153,17 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { } func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { - if auth.Config.LoginMaxRetries <= 0 || auth.Config.LoginTimeout <= 0 { + if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { return } - auth.LoginMutex.Lock() - defer auth.LoginMutex.Unlock() + auth.loginMutex.Lock() + defer auth.loginMutex.Unlock() - attempt, exists := auth.LoginAttempts[identifier] + attempt, exists := auth.loginAttempts[identifier] if !exists { attempt = &LoginAttempt{} - auth.LoginAttempts[identifier] = attempt + auth.loginAttempts[identifier] = attempt } attempt.LastAttempt = time.Now() @@ -174,14 +176,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { attempt.FailedAttempts++ - if attempt.FailedAttempts >= auth.Config.LoginMaxRetries { - attempt.LockedUntil = time.Now().Add(time.Duration(auth.Config.LoginTimeout) * time.Second) - log.Warn().Str("identifier", identifier).Int("timeout", auth.Config.LoginTimeout).Msg("Account locked due to too many failed login attempts") + if attempt.FailedAttempts >= auth.config.LoginMaxRetries { + attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second) + log.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts") } } func (auth *AuthService) IsEmailWhitelisted(email string) bool { - return utils.CheckFilter(auth.Config.OauthWhitelist, email) + return utils.CheckFilter(auth.config.OauthWhitelist, email) } func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error { @@ -196,7 +198,7 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.Sessio if data.TotpPending { expiry = 3600 } else { - expiry = auth.Config.SessionExpiry + expiry = auth.config.SessionExpiry } session := model.Session{ @@ -210,37 +212,37 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.Sessio Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(), } - err = auth.Database.Create(&session).Error + err = auth.database.Create(&session).Error if err != nil { return err } - c.SetCookie(auth.Config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.Config.RootDomain), auth.Config.SecureCookie, true) + c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.RootDomain), auth.config.SecureCookie, true) return nil } func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { - cookie, err := c.Cookie(auth.Config.SessionCookieName) + cookie, err := c.Cookie(auth.config.SessionCookieName) if err != nil { return err } - res := auth.Database.Unscoped().Where("uuid = ?", cookie).Delete(&model.Session{}) + res := auth.database.Unscoped().Where("uuid = ?", cookie).Delete(&model.Session{}) if res.Error != nil { return res.Error } - c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.RootDomain), auth.Config.SecureCookie, true) + c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.RootDomain), auth.config.SecureCookie, true) return nil } func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, error) { - cookie, err := c.Cookie(auth.Config.SessionCookieName) + cookie, err := c.Cookie(auth.config.SessionCookieName) if err != nil { return config.SessionCookie{}, err @@ -248,7 +250,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, var session model.Session - res := auth.Database.Unscoped().Where("uuid = ?", cookie).First(&session) + res := auth.database.Unscoped().Where("uuid = ?", cookie).First(&session) if res.Error != nil { return config.SessionCookie{}, res.Error @@ -261,7 +263,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, currentTime := time.Now().Unix() if currentTime > session.Expiry { - res := auth.Database.Unscoped().Where("uuid = ?", session.UUID).Delete(&model.Session{}) + res := auth.database.Unscoped().Where("uuid = ?", session.UUID).Delete(&model.Session{}) if res.Error != nil { log.Error().Err(res.Error).Msg("Failed to delete expired session") } @@ -280,7 +282,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, } func (auth *AuthService) UserAuthConfigured() bool { - return len(auth.Config.Users) > 0 || auth.LDAP != nil + return len(auth.config.Users) > 0 || auth.ldap != nil } func (auth *AuthService) IsResourceAllowed(c *gin.Context, context config.UserContext, labels config.AppLabels) bool { diff --git a/internal/service/database_service.go b/internal/service/database_service.go index 858ba4c..eb75b9f 100644 --- a/internal/service/database_service.go +++ b/internal/service/database_service.go @@ -16,18 +16,18 @@ type DatabaseServiceConfig struct { } type DatabaseService struct { - Config DatabaseServiceConfig - Database *gorm.DB + config DatabaseServiceConfig + database *gorm.DB } func NewDatabaseService(config DatabaseServiceConfig) *DatabaseService { return &DatabaseService{ - Config: config, + config: config, } } func (ds *DatabaseService) Init() error { - gormDB, err := gorm.Open(sqlite.Open(ds.Config.DatabasePath), &gorm.Config{}) + gormDB, err := gorm.Open(sqlite.Open(ds.config.DatabasePath), &gorm.Config{}) if err != nil { return err @@ -47,7 +47,7 @@ func (ds *DatabaseService) Init() error { return err } - ds.Database = gormDB + ds.database = gormDB return nil } @@ -74,5 +74,5 @@ func (ds *DatabaseService) migrateDatabase(sqlDB *sql.DB) error { } func (ds *DatabaseService) GetDatabase() *gorm.DB { - return ds.Database + return ds.database } diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index e078a7e..f4ce236 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -12,8 +12,8 @@ import ( ) type DockerService struct { - Client *client.Client - Context context.Context + client *client.Client + context context.Context } func NewDockerService() *DockerService { @@ -29,13 +29,13 @@ func (docker *DockerService) Init() error { ctx := context.Background() client.NegotiateAPIVersion(ctx) - docker.Client = client - docker.Context = ctx + docker.client = client + docker.context = ctx return nil } func (docker *DockerService) GetContainers() ([]container.Summary, error) { - containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) + containers, err := docker.client.ContainerList(docker.context, container.ListOptions{}) if err != nil { return nil, err } @@ -43,7 +43,7 @@ func (docker *DockerService) GetContainers() ([]container.Summary, error) { } func (docker *DockerService) InspectContainer(containerId string) (container.InspectResponse, error) { - inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) + inspect, err := docker.client.ContainerInspect(docker.context, containerId) if err != nil { return container.InspectResponse{}, err } @@ -51,11 +51,11 @@ func (docker *DockerService) InspectContainer(containerId string) (container.Ins } func (docker *DockerService) DockerConnected() bool { - _, err := docker.Client.Ping(docker.Context) + _, err := docker.client.Ping(docker.context) return err == nil } -func (docker *DockerService) GetLabels(app string, domain string) (config.AppLabels, error) { +func (docker *DockerService) GetLabels(appDomain string) (config.AppLabels, error) { isConnected := docker.DockerConnected() if !isConnected { @@ -68,21 +68,21 @@ func (docker *DockerService) GetLabels(app string, domain string) (config.AppLab return config.AppLabels{}, err } - for _, container := range containers { - inspect, err := docker.InspectContainer(container.ID) + for _, ctr := range containers { + inspect, err := docker.InspectContainer(ctr.ID) if err != nil { - log.Warn().Str("id", container.ID).Err(err).Msg("Error inspecting container, skipping") + log.Warn().Str("id", ctr.ID).Err(err).Msg("Error inspecting container, skipping") continue } labels, err := utils.GetLabels(inspect.Config.Labels) if err != nil { - log.Warn().Str("id", container.ID).Err(err).Msg("Error getting container labels, skipping") + log.Warn().Str("id", ctr.ID).Err(err).Msg("Error getting container labels, skipping") continue } for appName, appLabels := range labels.Apps { - if appLabels.Config.Domain == domain { + if appLabels.Config.Domain == appDomain { log.Debug().Str("id", inspect.ID).Msg("Found matching container by domain") return appLabels, nil } diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go index c16384d..72c2357 100644 --- a/internal/service/generic_oauth_service.go +++ b/internal/service/generic_oauth_service.go @@ -16,17 +16,17 @@ import ( ) type GenericOAuthService struct { - Config oauth2.Config - Context context.Context - Token *oauth2.Token - Verifier string - InsecureSkipVerify bool - UserinfoURL string + config oauth2.Config + context context.Context + token *oauth2.Token + verifier string + insecureSkipVerify bool + userinfoUrl string } func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { return &GenericOAuthService{ - Config: oauth2.Config{ + config: oauth2.Config{ ClientID: config.ClientID, ClientSecret: config.ClientSecret, RedirectURL: config.RedirectURL, @@ -36,15 +36,15 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi TokenURL: config.TokenURL, }, }, - InsecureSkipVerify: config.InsecureSkipVerify, - UserinfoURL: config.UserinfoURL, + insecureSkipVerify: config.InsecureSkipVerify, + userinfoUrl: config.UserinfoURL, } } func (generic *GenericOAuthService) Init() error { transport := &http.Transport{ TLSClientConfig: &tls.Config{ - InsecureSkipVerify: generic.InsecureSkipVerify, + InsecureSkipVerify: generic.insecureSkipVerify, MinVersion: tls.VersionTLS12, }, } @@ -58,8 +58,8 @@ func (generic *GenericOAuthService) Init() error { ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) verifier := oauth2.GenerateVerifier() - generic.Context = ctx - generic.Verifier = verifier + generic.context = ctx + generic.verifier = verifier return nil } @@ -74,26 +74,26 @@ func (generic *GenericOAuthService) GenerateState() string { } func (generic *GenericOAuthService) GetAuthURL(state string) string { - return generic.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.Verifier)) + return generic.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.verifier)) } func (generic *GenericOAuthService) VerifyCode(code string) error { - token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier)) + token, err := generic.config.Exchange(generic.context, code, oauth2.VerifierOption(generic.verifier)) if err != nil { return err } - generic.Token = token + generic.token = token return nil } func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { var user config.Claims - client := generic.Config.Client(generic.Context, generic.Token) + client := generic.config.Client(generic.context, generic.token) - res, err := client.Get(generic.UserinfoURL) + res, err := client.Get(generic.userinfoUrl) if err != nil { return user, err } diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go index 7f8466b..26d73b1 100644 --- a/internal/service/github_oauth_service.go +++ b/internal/service/github_oauth_service.go @@ -29,15 +29,15 @@ type GithubUserInfoResponse struct { } type GithubOAuthService struct { - Config oauth2.Config - Context context.Context - Token *oauth2.Token - Verifier string + config oauth2.Config + context context.Context + token *oauth2.Token + verifier string } func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { return &GithubOAuthService{ - Config: oauth2.Config{ + config: oauth2.Config{ ClientID: config.ClientID, ClientSecret: config.ClientSecret, RedirectURL: config.RedirectURL, @@ -53,8 +53,8 @@ func (github *GithubOAuthService) Init() error { ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) verifier := oauth2.GenerateVerifier() - github.Context = ctx - github.Verifier = verifier + github.context = ctx + github.verifier = verifier return nil } @@ -69,24 +69,24 @@ func (github *GithubOAuthService) GenerateState() string { } func (github *GithubOAuthService) GetAuthURL(state string) string { - return github.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.Verifier)) + return github.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.verifier)) } func (github *GithubOAuthService) VerifyCode(code string) error { - token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier)) + token, err := github.config.Exchange(github.context, code, oauth2.VerifierOption(github.verifier)) if err != nil { return err } - github.Token = token + github.token = token return nil } func (github *GithubOAuthService) Userinfo() (config.Claims, error) { var user config.Claims - client := github.Config.Client(github.Context, github.Token) + client := github.config.Client(github.context, github.token) req, err := http.NewRequest("GET", "https://api.github.com/user", nil) if err != nil { diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go index 1605a85..0f8c7eb 100644 --- a/internal/service/google_oauth_service.go +++ b/internal/service/google_oauth_service.go @@ -24,15 +24,15 @@ type GoogleUserInfoResponse struct { } type GoogleOAuthService struct { - Config oauth2.Config - Context context.Context - Token *oauth2.Token - Verifier string + config oauth2.Config + context context.Context + token *oauth2.Token + verifier string } func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { return &GoogleOAuthService{ - Config: oauth2.Config{ + config: oauth2.Config{ ClientID: config.ClientID, ClientSecret: config.ClientSecret, RedirectURL: config.RedirectURL, @@ -48,8 +48,8 @@ func (google *GoogleOAuthService) Init() error { ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) verifier := oauth2.GenerateVerifier() - google.Context = ctx - google.Verifier = verifier + google.context = ctx + google.verifier = verifier return nil } @@ -64,24 +64,24 @@ func (oauth *GoogleOAuthService) GenerateState() string { } func (google *GoogleOAuthService) GetAuthURL(state string) string { - return google.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.Verifier)) + return google.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.verifier)) } func (google *GoogleOAuthService) VerifyCode(code string) error { - token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier)) + token, err := google.config.Exchange(google.context, code, oauth2.VerifierOption(google.verifier)) if err != nil { return err } - google.Token = token + google.token = token return nil } func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { var user config.Claims - client := google.Config.Client(google.Context, google.Token) + client := google.config.Client(google.context, google.token) res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") if err != nil { diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index b3a1d86..5734c63 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -22,9 +22,9 @@ type LdapServiceConfig struct { } type LdapService struct { - Config LdapServiceConfig - Conn *ldapgo.Conn - Mutex sync.RWMutex + Config LdapServiceConfig // exported so as the auth service can use it + conn *ldapgo.Conn + mutex sync.RWMutex } func NewLdapService(config LdapServiceConfig) *LdapService { @@ -57,7 +57,8 @@ func (ldap *LdapService) Init() error { } func (ldap *LdapService) connect() (*ldapgo.Conn, error) { - ldap.Mutex.Lock() + ldap.mutex.Lock() + defer ldap.mutex.Unlock() conn, err := ldapgo.DialURL(ldap.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ InsecureSkipVerify: ldap.Config.Insecure, @@ -72,10 +73,8 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) { return nil, err } - ldap.Mutex.Unlock() - // Set and return the connection - ldap.Conn = conn + ldap.conn = conn return conn, nil } @@ -92,12 +91,13 @@ func (ldap *LdapService) Search(username string) (string, error) { nil, ) - ldap.Mutex.Lock() - searchResult, err := ldap.Conn.Search(searchRequest) + ldap.mutex.Lock() + defer ldap.mutex.Unlock() + + searchResult, err := ldap.conn.Search(searchRequest) if err != nil { return "", err } - ldap.Mutex.Unlock() if len(searchResult.Entries) != 1 { return "", fmt.Errorf("multiple or no entries found for user %s", username) @@ -108,12 +108,12 @@ func (ldap *LdapService) Search(username string) (string, error) { } func (ldap *LdapService) Bind(userDN string, password string) error { - ldap.Mutex.Lock() - err := ldap.Conn.Bind(userDN, password) + ldap.mutex.Lock() + defer ldap.mutex.Unlock() + err := ldap.conn.Bind(userDN, password) if err != nil { return err } - ldap.Mutex.Unlock() return nil } @@ -128,12 +128,12 @@ func (ldap *LdapService) heartbeat() error { nil, ) - ldap.Mutex.Lock() - _, err := ldap.Conn.Search(searchRequest) + ldap.mutex.Lock() + defer ldap.mutex.Unlock() + _, err := ldap.conn.Search(searchRequest) if err != nil { return err } - ldap.Mutex.Unlock() // No error means the connection is alive return nil @@ -149,7 +149,7 @@ func (ldap *LdapService) reconnect() error { exp.Reset() operation := func() (*ldapgo.Conn, error) { - ldap.Conn.Close() + ldap.conn.Close() conn, err := ldap.connect() if err != nil { return nil, err diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 6b5b1e6..f9df4f8 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -5,6 +5,7 @@ import ( "tinyauth/internal/config" "github.com/rs/zerolog/log" + "golang.org/x/exp/slices" ) type OAuthService interface { @@ -16,59 +17,60 @@ type OAuthService interface { } type OAuthBrokerService struct { - Services map[string]OAuthService - Configs map[string]config.OAuthServiceConfig + services map[string]OAuthService + configs map[string]config.OAuthServiceConfig } func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService { return &OAuthBrokerService{ - Services: make(map[string]OAuthService), - Configs: configs, + services: make(map[string]OAuthService), + configs: configs, } } func (broker *OAuthBrokerService) Init() error { - for name, cfg := range broker.Configs { + for name, cfg := range broker.configs { switch name { case "github": service := NewGithubOAuthService(cfg) - broker.Services[name] = service + broker.services[name] = service case "google": service := NewGoogleOAuthService(cfg) - broker.Services[name] = service + broker.services[name] = service default: service := NewGenericOAuthService(cfg) - broker.Services[name] = service + broker.services[name] = service } } - for name, service := range broker.Services { + for name, service := range broker.services { err := service.Init() if err != nil { - log.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name) + log.Error().Err(err).Msgf("Failed to initialize OAuth service: %T", name) return err } - log.Info().Msgf("Initialized OAuth service: %s", name) + log.Info().Msgf("Initialized OAuth service: %T", name) } return nil } func (broker *OAuthBrokerService) GetConfiguredServices() []string { - services := make([]string, 0, len(broker.Services)) - for name := range broker.Services { + services := make([]string, 0, len(broker.services)) + for name := range broker.services { services = append(services, name) } + slices.Sort(services) return services } func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) { - service, exists := broker.Services[name] + service, exists := broker.services[name] return service, exists } func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) { - oauthService, exists := broker.Services[service] + oauthService, exists := broker.services[service] if !exists { return config.Claims{}, errors.New("oauth service not found") }