From 598abc5fe1cf08ece529a203ceb41d164ac9a028 Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 29 Aug 2025 13:52:47 +0300 Subject: [PATCH] refactor: unify labels --- internal/config/config.go | 100 ++++++++++++++++-------- internal/controller/proxy_controller.go | 81 ++++++++++--------- internal/service/auth_service.go | 32 ++++---- internal/service/docker_service.go | 27 +++---- internal/utils/label_utils.go | 2 +- 5 files changed, 138 insertions(+), 104 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index e053f65..7f7129d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,20 +1,19 @@ package config -type Claims struct { - Name string `json:"name"` - Email string `json:"email"` - PreferredUsername string `json:"preferred_username"` - Groups any `json:"groups"` -} +// Version information, set at build time var Version = "development" var CommitHash = "n/a" var BuildTimestamp = "n/a" +// Cookie name templates + var SessionCookieName = "tinyauth-session" var CSRFCookieName = "tinyauth-csrf" var RedirectCookieName = "tinyauth-redirect" +// Main app config + type Config struct { Port int `mapstructure:"port" validate:"required"` Address string `validate:"required,ip4_addr" mapstructure:"address"` @@ -57,35 +56,13 @@ type Config struct { DatabasePath string `mapstructure:"database-path" validate:"required"` } -type OAuthLabels struct { - Whitelist string - Groups string -} +// OAuth/OIDC config -type BasicLabels struct { - Username string - Password PasswordLabels -} - -type PasswordLabels struct { - Plain string - File string -} - -type IPLabels struct { - Allow []string - Block []string - Bypass []string -} - -type Labels struct { - Users string - Allowed string - Headers []string - Domain []string - Basic BasicLabels - OAuth OAuthLabels - IP IPLabels +type Claims struct { + Name string `json:"name"` + Email string `json:"email"` + PreferredUsername string `json:"preferred_username"` + Groups any `json:"groups"` } type OAuthServiceConfig struct { @@ -99,6 +76,8 @@ type OAuthServiceConfig struct { InsecureSkipVerify bool } +// User/session related stuff + type User struct { Username string Password string @@ -132,6 +111,8 @@ type UserContext struct { TotpEnabled bool } +// API responses and queries + type UnauthorizedQuery struct { Username string `url:"username"` Resource string `url:"resource"` @@ -142,3 +123,54 @@ type UnauthorizedQuery struct { type RedirectQuery struct { RedirectURI string `url:"redirect_uri"` } + +// Labels + +type Labels struct { + Apps map[string]AppLabels +} + +type AppLabels struct { + Config ConfigLabels + Users UsersLabels + OAuth OAuthLabels + IP IPLabels + Response ResponseLabels + Path PathLabels +} + +type ConfigLabels struct { + Domain string +} + +type UsersLabels struct { + Allow string + Block string +} + +type OAuthLabels struct { + Whitelist string + Groups string +} + +type IPLabels struct { + Allow []string + Block []string + Bypass []string +} + +type ResponseLabels struct { + Headers []string + BasicAuth BasicAuthLabels +} + +type BasicAuthLabels struct { + Username string + Password string + PasswordFile string +} + +type PathLabels struct { + Allow string + Block string +} diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 6e207e8..8fec341 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -89,19 +89,20 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { clientIP := c.ClientIP() - if controller.Auth.IsBypassedIP(labels, clientIP) { + if controller.Auth.IsBypassedIP(labels.IP, clientIP) { c.Header("Authorization", c.Request.Header.Get("Authorization")) - headers := utils.ParseHeaders(labels.Headers) + headers := utils.ParseHeaders(labels.Response.Headers) for key, value := range headers { log.Debug().Str("header", key).Msg("Setting header") c.Header(key, value) } - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth header") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile) + if labels.Response.BasicAuth.Username != "" && basicPassword != "" { + log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header") + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword))) } c.JSON(200, gin.H{ @@ -111,31 +112,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - if !controller.Auth.CheckIP(labels, clientIP) { - if req.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - queries, err := query.Values(config.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - IP: clientIP, - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to encode unauthorized query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) - return - } - - authEnabled, err := controller.Auth.IsAuthEnabled(uri, labels) + authEnabled, err := controller.Auth.IsAuthEnabled(uri, labels.Path.Allow) if err != nil { log.Error().Err(err).Msg("Failed to check if auth is enabled for resource") @@ -157,16 +134,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { c.Header("Authorization", c.Request.Header.Get("Authorization")) - headers := utils.ParseHeaders(labels.Headers) + headers := utils.ParseHeaders(labels.Response.Headers) for key, value := range headers { log.Debug().Str("header", key).Msg("Setting header") c.Header(key, value) } - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth header") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile) + if labels.Response.BasicAuth.Username != "" && basicPassword != "" { + log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header") + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword))) } c.JSON(200, gin.H{ @@ -176,6 +154,30 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } + if !controller.Auth.CheckIP(labels.IP, clientIP) { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + queries, err := query.Values(config.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + IP: clientIP, + }) + + if err != nil { + log.Error().Err(err).Msg("Failed to encode unauthorized query") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + var userContext config.UserContext context, err := utils.GetContext(c) @@ -229,7 +231,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { } if userContext.OAuth { - groupOK := controller.Auth.IsInOAuthGroup(c, userContext, labels) + groupOK := controller.Auth.IsInOAuthGroup(c, userContext, labels.OAuth.Groups) if !groupOK { log.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User OAuth groups do not match resource requirements") @@ -270,16 +272,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) - headers := utils.ParseHeaders(labels.Headers) + headers := utils.ParseHeaders(labels.Response.Headers) for key, value := range headers { log.Debug().Str("header", key).Msg("Setting header") c.Header(key, value) } - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth header") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile) + if labels.Response.BasicAuth.Username != "" && basicPassword != "" { + log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header") + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword))) } c.JSON(200, gin.H{ diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index f55961c..7024214 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -283,18 +283,18 @@ func (auth *AuthService) UserAuthConfigured() bool { return len(auth.Config.Users) > 0 || auth.LDAP != nil } -func (auth *AuthService) IsResourceAllowed(c *gin.Context, context config.UserContext, labels config.Labels) bool { +func (auth *AuthService) IsResourceAllowed(c *gin.Context, context config.UserContext, labels config.AppLabels) bool { if context.OAuth { log.Debug().Msg("Checking OAuth whitelist") return utils.CheckFilter(labels.OAuth.Whitelist, context.Email) } log.Debug().Msg("Checking users") - return utils.CheckFilter(labels.Users, context.Username) + return utils.CheckFilter(labels.Users.Allow, context.Username) } -func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, labels config.Labels) bool { - if labels.OAuth.Groups == "" { +func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, groups string) bool { + if groups == "" { return true } @@ -304,10 +304,10 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte } // No need to parse since they are from the API response - oauthGroups := strings.Split(context.OAuthGroups, ",") + groupsSplit := strings.Split(groups, ",") - for _, group := range oauthGroups { - if utils.CheckFilter(labels.OAuth.Groups, group) { + for _, group := range groupsSplit { + if utils.CheckFilter(groups, group) { return true } } @@ -316,12 +316,12 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte return false } -func (auth *AuthService) IsAuthEnabled(uri string, labels config.Labels) (bool, error) { - if labels.Allowed == "" { +func (auth *AuthService) IsAuthEnabled(uri string, pathAllow string) (bool, error) { + if pathAllow == "" { return true, nil } - regex, err := regexp.Compile(labels.Allowed) + regex, err := regexp.Compile(pathAllow) if err != nil { return true, err @@ -346,8 +346,8 @@ func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User { } } -func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool { - for _, blocked := range labels.IP.Block { +func (auth *AuthService) CheckIP(labels config.IPLabels, ip string) bool { + for _, blocked := range labels.Block { res, err := utils.FilterIP(blocked, ip) if err != nil { log.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") @@ -359,7 +359,7 @@ func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool { } } - for _, allowed := range labels.IP.Allow { + for _, allowed := range labels.Allow { res, err := utils.FilterIP(allowed, ip) if err != nil { log.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") @@ -371,7 +371,7 @@ func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool { } } - if len(labels.IP.Allow) > 0 { + if len(labels.Allow) > 0 { log.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") return false } @@ -380,8 +380,8 @@ func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool { return true } -func (auth *AuthService) IsBypassedIP(labels config.Labels, ip string) bool { - for _, bypassed := range labels.IP.Bypass { +func (auth *AuthService) IsBypassedIP(labels config.IPLabels, ip string) bool { + for _, bypassed := range labels.Bypass { res, err := utils.FilterIP(bypassed, ip) if err != nil { log.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index 41eb07c..e078a7e 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -6,8 +6,6 @@ import ( "tinyauth/internal/config" "tinyauth/internal/utils" - "slices" - container "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" "github.com/rs/zerolog/log" @@ -57,17 +55,17 @@ func (docker *DockerService) DockerConnected() bool { return err == nil } -func (docker *DockerService) GetLabels(app string, domain string) (config.Labels, error) { +func (docker *DockerService) GetLabels(app string, domain string) (config.AppLabels, error) { isConnected := docker.DockerConnected() if !isConnected { log.Debug().Msg("Docker not connected, returning empty labels") - return config.Labels{}, nil + return config.AppLabels{}, nil } containers, err := docker.GetContainers() if err != nil { - return config.Labels{}, err + return config.AppLabels{}, err } for _, container := range containers { @@ -83,18 +81,19 @@ func (docker *DockerService) GetLabels(app string, domain string) (config.Labels continue } - // Check if the container matches the ID or domain - if slices.Contains(labels.Domain, domain) { - log.Debug().Str("id", inspect.ID).Msg("Found matching container by domain") - return labels, nil - } + for appName, appLabels := range labels.Apps { + if appLabels.Config.Domain == domain { + log.Debug().Str("id", inspect.ID).Msg("Found matching container by domain") + return appLabels, nil + } - if strings.TrimPrefix(inspect.Name, "/") == app { - log.Debug().Str("id", inspect.ID).Msg("Found matching container by name") - return labels, nil + if strings.TrimPrefix(inspect.Name, "/") == appName { + log.Debug().Str("id", inspect.ID).Msg("Found matching container by app name") + return appLabels, nil + } } } log.Debug().Msg("No matching container found, returning empty labels") - return config.Labels{}, nil + return config.AppLabels{}, nil } diff --git a/internal/utils/label_utils.go b/internal/utils/label_utils.go index f10092d..5e423f7 100644 --- a/internal/utils/label_utils.go +++ b/internal/utils/label_utils.go @@ -11,7 +11,7 @@ import ( func GetLabels(labels map[string]string) (config.Labels, error) { var labelsParsed config.Labels - err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") + err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.apps") if err != nil { return config.Labels{}, err }