feat: coderabbit suggestions

This commit is contained in:
Stavros
2025-09-02 01:11:14 +03:00
parent 00ed365f66
commit 3feb5d3930
8 changed files with 63 additions and 84 deletions

View File

@@ -169,24 +169,30 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
}
var usename string
var username string
if user.PreferredUsername != "" {
log.Debug().Msg("Using preferred username from OAuth provider")
usename = user.PreferredUsername
username = user.PreferredUsername
} else {
log.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
usename = strings.Replace(user.Email, "@", "_", -1)
username = strings.Replace(user.Email, "@", "_", -1)
}
controller.auth.CreateSessionCookie(c, &config.SessionCookie{
Username: usename,
err = controller.auth.CreateSessionCookie(c, &config.SessionCookie{
Username: username,
Name: name,
Email: user.Email,
Provider: req.Provider,
OAuthGroups: utils.CoalesceToString(user.Groups),
})
if err != nil {
log.Error().Err(err).Msg("Failed to create session cookie")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
redirectURI, err := c.Cookie(controller.config.RedirectCookieName)
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.RootDomain) {

View File

@@ -67,23 +67,11 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
proto := c.Request.Header.Get("X-Forwarded-Proto")
host := c.Request.Header.Get("X-Forwarded-Host")
hostWithoutPort := strings.Split(host, ":")[0]
id := strings.Split(hostWithoutPort, ".")[0]
labels, err := controller.docker.GetLabels(id, hostWithoutPort)
labels, err := controller.docker.GetLabels(host)
if err != nil {
log.Error().Err(err).Msg("Failed to get labels from Docker")
if req.Proxy == "nginx" || !isBrowser {
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
controller.handleError(c, req, isBrowser)
return
}
@@ -91,20 +79,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
if controller.auth.IsBypassedIP(labels.IP, clientIP) {
c.Header("Authorization", c.Request.Header.Get("Authorization"))
headers := utils.ParseHeaders(labels.Response.Headers)
for key, value := range headers {
log.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value)
}
basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile)
if labels.Response.BasicAuth.Username != "" && basicPassword != "" {
log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword)))
}
controller.setHeaders(c, labels)
c.JSON(200, gin.H{
"status": 200,
"message": "Authenticated",
@@ -116,37 +91,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
if err != nil {
log.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
if req.Proxy == "nginx" || !isBrowser {
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
controller.handleError(c, req, isBrowser)
return
}
if !authEnabled {
log.Debug().Msg("Authentication disabled for resource, allowing access")
c.Header("Authorization", c.Request.Header.Get("Authorization"))
headers := utils.ParseHeaders(labels.Response.Headers)
for key, value := range headers {
log.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value)
}
basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile)
if labels.Response.BasicAuth.Username != "" && basicPassword != "" {
log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword)))
}
controller.setHeaders(c, labels)
c.JSON(200, gin.H{
"status": 200,
"message": "Authenticated",
@@ -272,18 +223,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
headers := utils.ParseHeaders(labels.Response.Headers)
for key, value := range headers {
log.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value)
}
basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile)
if labels.Response.BasicAuth.Username != "" && basicPassword != "" {
log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword)))
}
controller.setHeaders(c, labels)
c.JSON(200, gin.H{
"status": 200,
@@ -312,3 +252,33 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode()))
}
func (controller *ProxyController) setHeaders(c *gin.Context, labels config.AppLabels) {
c.Header("Authorization", c.Request.Header.Get("Authorization"))
headers := utils.ParseHeaders(labels.Response.Headers)
for key, value := range headers {
log.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value)
}
basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile)
if labels.Response.BasicAuth.Username != "" && basicPassword != "" {
log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword)))
}
}
func (controller *ProxyController) handleError(c *gin.Context, req Proxy, isBrowser bool) {
if req.Proxy == "nginx" || !isBrowser {
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
}

View File

@@ -82,7 +82,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
userSearch := controller.auth.SearchUser(req.Username)
if userSearch.Type == "" {
if userSearch.Type == "unknown" {
log.Warn().Str("username", req.Username).Str("ip", clientIP).Msg("User not found")
controller.auth.RecordLoginAttempt(rateIdentifier, false)
c.JSON(401, gin.H{
@@ -220,7 +220,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
log.Warn().Str("username", context.Username).Str("ip", clientIP).Msg("Account is locked due to too many failed TOTP attempts")
c.JSON(429, gin.H{
"status": 429,
"message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime),
"message": fmt.Sprintf("Too many failed TOTP attempts. Try again in %d seconds", remainingTime),
})
return
}

View File

@@ -59,7 +59,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
case "username":
userSearch := m.auth.SearchUser(cookie.Username)
if userSearch.Type == "unknown" {
if userSearch.Type == "unknown" || userSearch.Type == "error" {
log.Debug().Msg("User from session cookie not found")
m.auth.DeleteSessionCookie(c)
goto basic
@@ -113,7 +113,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
userSearch := m.auth.SearchUser(basic.Username)
if userSearch.Type == "unknown" {
if userSearch.Type == "unknown" || userSearch.Type == "error" {
log.Debug().Msg("User from basic auth not found")
c.Next()
return

View File

@@ -71,7 +71,7 @@ func (auth *AuthService) SearchUser(username string) config.UserSearch {
if err != nil {
log.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
return config.UserSearch{
Type: "unknown",
Type: "error",
}
}

View File

@@ -55,7 +55,7 @@ func (docker *DockerService) DockerConnected() bool {
return err == nil
}
func (docker *DockerService) GetLabels(app string, domain string) (config.AppLabels, error) {
func (docker *DockerService) GetLabels(appDomain string) (config.AppLabels, error) {
isConnected := docker.DockerConnected()
if !isConnected {
@@ -68,21 +68,21 @@ func (docker *DockerService) GetLabels(app string, domain string) (config.AppLab
return config.AppLabels{}, err
}
for _, container := range containers {
inspect, err := docker.InspectContainer(container.ID)
for _, ctr := range containers {
inspect, err := docker.InspectContainer(ctr.ID)
if err != nil {
log.Warn().Str("id", container.ID).Err(err).Msg("Error inspecting container, skipping")
log.Warn().Str("id", ctr.ID).Err(err).Msg("Error inspecting container, skipping")
continue
}
labels, err := utils.GetLabels(inspect.Config.Labels)
if err != nil {
log.Warn().Str("id", container.ID).Err(err).Msg("Error getting container labels, skipping")
log.Warn().Str("id", ctr.ID).Err(err).Msg("Error getting container labels, skipping")
continue
}
for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == domain {
if appLabels.Config.Domain == appDomain {
log.Debug().Str("id", inspect.ID).Msg("Found matching container by domain")
return appLabels, nil
}

View File

@@ -92,11 +92,12 @@ func (ldap *LdapService) Search(username string) (string, error) {
)
ldap.mutex.Lock()
defer ldap.mutex.Unlock()
searchResult, err := ldap.conn.Search(searchRequest)
if err != nil {
return "", err
}
ldap.mutex.Unlock()
if len(searchResult.Entries) != 1 {
return "", fmt.Errorf("multiple or no entries found for user %s", username)
@@ -128,11 +129,11 @@ func (ldap *LdapService) heartbeat() error {
)
ldap.mutex.Lock()
defer ldap.mutex.Unlock()
_, err := ldap.conn.Search(searchRequest)
if err != nil {
return err
}
ldap.mutex.Unlock()
// No error means the connection is alive
return nil

View File

@@ -5,6 +5,7 @@ import (
"tinyauth/internal/config"
"github.com/rs/zerolog/log"
"golang.org/x/exp/slices"
)
type OAuthService interface {
@@ -59,6 +60,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string {
for name := range broker.services {
services = append(services, name)
}
slices.Sort(services)
return services
}