diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 9414bc0..23d00de 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -169,24 +169,30 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) } - var usename string + var username string if user.PreferredUsername != "" { log.Debug().Msg("Using preferred username from OAuth provider") - usename = user.PreferredUsername + username = user.PreferredUsername } else { log.Debug().Msg("No preferred username from OAuth provider, using pseudo username") - usename = strings.Replace(user.Email, "@", "_", -1) + username = strings.Replace(user.Email, "@", "_", -1) } - controller.auth.CreateSessionCookie(c, &config.SessionCookie{ - Username: usename, + err = controller.auth.CreateSessionCookie(c, &config.SessionCookie{ + Username: username, Name: name, Email: user.Email, Provider: req.Provider, OAuthGroups: utils.CoalesceToString(user.Groups), }) + if err != nil { + log.Error().Err(err).Msg("Failed to create session cookie") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + redirectURI, err := c.Cookie(controller.config.RedirectCookieName) if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.RootDomain) { diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 5cf182f..7be9743 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -67,23 +67,11 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { proto := c.Request.Header.Get("X-Forwarded-Proto") host := c.Request.Header.Get("X-Forwarded-Host") - hostWithoutPort := strings.Split(host, ":")[0] - id := strings.Split(hostWithoutPort, ".")[0] - - labels, err := controller.docker.GetLabels(id, hostWithoutPort) + labels, err := controller.docker.GetLabels(host) if err != nil { log.Error().Err(err).Msg("Failed to get labels from Docker") - - if req.Proxy == "nginx" || !isBrowser { - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.handleError(c, req, isBrowser) return } @@ -91,20 +79,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { if controller.auth.IsBypassedIP(labels.IP, clientIP) { c.Header("Authorization", c.Request.Header.Get("Authorization")) - - headers := utils.ParseHeaders(labels.Response.Headers) - - for key, value := range headers { - log.Debug().Str("header", key).Msg("Setting header") - c.Header(key, value) - } - - basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile) - if labels.Response.BasicAuth.Username != "" && basicPassword != "" { - log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword))) - } - + controller.setHeaders(c, labels) c.JSON(200, gin.H{ "status": 200, "message": "Authenticated", @@ -116,37 +91,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { if err != nil { log.Error().Err(err).Msg("Failed to check if auth is enabled for resource") - - if req.Proxy == "nginx" || !isBrowser { - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + controller.handleError(c, req, isBrowser) return } if !authEnabled { log.Debug().Msg("Authentication disabled for resource, allowing access") - - c.Header("Authorization", c.Request.Header.Get("Authorization")) - - headers := utils.ParseHeaders(labels.Response.Headers) - - for key, value := range headers { - log.Debug().Str("header", key).Msg("Setting header") - c.Header(key, value) - } - - basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile) - if labels.Response.BasicAuth.Username != "" && basicPassword != "" { - log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword))) - } - + controller.setHeaders(c, labels) c.JSON(200, gin.H{ "status": 200, "message": "Authenticated", @@ -272,18 +223,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) - headers := utils.ParseHeaders(labels.Response.Headers) - - for key, value := range headers { - log.Debug().Str("header", key).Msg("Setting header") - c.Header(key, value) - } - - basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile) - if labels.Response.BasicAuth.Username != "" && basicPassword != "" { - log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword))) - } + controller.setHeaders(c, labels) c.JSON(200, gin.H{ "status": 200, @@ -312,3 +252,33 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode())) } + +func (controller *ProxyController) setHeaders(c *gin.Context, labels config.AppLabels) { + c.Header("Authorization", c.Request.Header.Get("Authorization")) + + headers := utils.ParseHeaders(labels.Response.Headers) + + for key, value := range headers { + log.Debug().Str("header", key).Msg("Setting header") + c.Header(key, value) + } + + basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile) + + if labels.Response.BasicAuth.Username != "" && basicPassword != "" { + log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header") + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword))) + } +} + +func (controller *ProxyController) handleError(c *gin.Context, req Proxy, isBrowser bool) { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) +} diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index 617f6a7..7b48652 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -82,7 +82,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { userSearch := controller.auth.SearchUser(req.Username) - if userSearch.Type == "" { + if userSearch.Type == "unknown" { log.Warn().Str("username", req.Username).Str("ip", clientIP).Msg("User not found") controller.auth.RecordLoginAttempt(rateIdentifier, false) c.JSON(401, gin.H{ @@ -220,7 +220,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { log.Warn().Str("username", context.Username).Str("ip", clientIP).Msg("Account is locked due to too many failed TOTP attempts") c.JSON(429, gin.H{ "status": 429, - "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), + "message": fmt.Sprintf("Too many failed TOTP attempts. Try again in %d seconds", remainingTime), }) return } diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index f5b1bbd..cbf9412 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -59,7 +59,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { case "username": userSearch := m.auth.SearchUser(cookie.Username) - if userSearch.Type == "unknown" { + if userSearch.Type == "unknown" || userSearch.Type == "error" { log.Debug().Msg("User from session cookie not found") m.auth.DeleteSessionCookie(c) goto basic @@ -113,7 +113,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { userSearch := m.auth.SearchUser(basic.Username) - if userSearch.Type == "unknown" { + if userSearch.Type == "unknown" || userSearch.Type == "error" { log.Debug().Msg("User from basic auth not found") c.Next() return diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 713fc63..cb14a7e 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -71,7 +71,7 @@ func (auth *AuthService) SearchUser(username string) config.UserSearch { if err != nil { log.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP") return config.UserSearch{ - Type: "unknown", + Type: "error", } } diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index 762070e..f4ce236 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -55,7 +55,7 @@ func (docker *DockerService) DockerConnected() bool { return err == nil } -func (docker *DockerService) GetLabels(app string, domain string) (config.AppLabels, error) { +func (docker *DockerService) GetLabels(appDomain string) (config.AppLabels, error) { isConnected := docker.DockerConnected() if !isConnected { @@ -68,21 +68,21 @@ func (docker *DockerService) GetLabels(app string, domain string) (config.AppLab return config.AppLabels{}, err } - for _, container := range containers { - inspect, err := docker.InspectContainer(container.ID) + for _, ctr := range containers { + inspect, err := docker.InspectContainer(ctr.ID) if err != nil { - log.Warn().Str("id", container.ID).Err(err).Msg("Error inspecting container, skipping") + log.Warn().Str("id", ctr.ID).Err(err).Msg("Error inspecting container, skipping") continue } labels, err := utils.GetLabels(inspect.Config.Labels) if err != nil { - log.Warn().Str("id", container.ID).Err(err).Msg("Error getting container labels, skipping") + log.Warn().Str("id", ctr.ID).Err(err).Msg("Error getting container labels, skipping") continue } for appName, appLabels := range labels.Apps { - if appLabels.Config.Domain == domain { + if appLabels.Config.Domain == appDomain { log.Debug().Str("id", inspect.ID).Msg("Found matching container by domain") return appLabels, nil } diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index 3f3781d..5734c63 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -92,11 +92,12 @@ func (ldap *LdapService) Search(username string) (string, error) { ) ldap.mutex.Lock() + defer ldap.mutex.Unlock() + searchResult, err := ldap.conn.Search(searchRequest) if err != nil { return "", err } - ldap.mutex.Unlock() if len(searchResult.Entries) != 1 { return "", fmt.Errorf("multiple or no entries found for user %s", username) @@ -128,11 +129,11 @@ func (ldap *LdapService) heartbeat() error { ) ldap.mutex.Lock() + defer ldap.mutex.Unlock() _, err := ldap.conn.Search(searchRequest) if err != nil { return err } - ldap.mutex.Unlock() // No error means the connection is alive return nil diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 5bc0b63..301dd4e 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -5,6 +5,7 @@ import ( "tinyauth/internal/config" "github.com/rs/zerolog/log" + "golang.org/x/exp/slices" ) type OAuthService interface { @@ -59,6 +60,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string { for name := range broker.services { services = append(services, name) } + slices.Sort(services) return services }