refactor: unify labels (#329)

* refactor: unify labels

* feat: implement path block and user block

Fixes #313

* fix: fix oauth group check logic

* chore: fix typo
This commit is contained in:
Stavros
2025-08-29 17:04:34 +03:00
committed by GitHub
parent 03d06cb0a7
commit c7c3de4f78
6 changed files with 164 additions and 114 deletions

View File

@@ -181,7 +181,7 @@ func (app *BootstrapApp) Setup() error {
Title: app.Config.Title, Title: app.Config.Title,
GenericName: app.Config.GenericName, GenericName: app.Config.GenericName,
Domain: domain, Domain: domain,
ForgotPasswordMessage: app.Config.FogotPasswordMessage, ForgotPasswordMessage: app.Config.ForgotPasswordMessage,
BackgroundImage: app.Config.BackgroundImage, BackgroundImage: app.Config.BackgroundImage,
OAuthAutoRedirect: app.Config.OAuthAutoRedirect, OAuthAutoRedirect: app.Config.OAuthAutoRedirect,
}, apiRouter) }, apiRouter)

View File

@@ -1,20 +1,19 @@
package config package config
type Claims struct { // Version information, set at build time
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups any `json:"groups"`
}
var Version = "development" var Version = "development"
var CommitHash = "n/a" var CommitHash = "n/a"
var BuildTimestamp = "n/a" var BuildTimestamp = "n/a"
// Cookie name templates
var SessionCookieName = "tinyauth-session" var SessionCookieName = "tinyauth-session"
var CSRFCookieName = "tinyauth-csrf" var CSRFCookieName = "tinyauth-csrf"
var RedirectCookieName = "tinyauth-redirect" var RedirectCookieName = "tinyauth-redirect"
// Main app config
type Config struct { type Config struct {
Port int `mapstructure:"port" validate:"required"` Port int `mapstructure:"port" validate:"required"`
Address string `validate:"required,ip4_addr" mapstructure:"address"` Address string `validate:"required,ip4_addr" mapstructure:"address"`
@@ -45,7 +44,7 @@ type Config struct {
Title string `mapstructure:"app-title"` Title string `mapstructure:"app-title"`
LoginTimeout int `mapstructure:"login-timeout"` LoginTimeout int `mapstructure:"login-timeout"`
LoginMaxRetries int `mapstructure:"login-max-retries"` LoginMaxRetries int `mapstructure:"login-max-retries"`
FogotPasswordMessage string `mapstructure:"forgot-password-message"` ForgotPasswordMessage string `mapstructure:"forgot-password-message"`
BackgroundImage string `mapstructure:"background-image" validate:"required"` BackgroundImage string `mapstructure:"background-image" validate:"required"`
LdapAddress string `mapstructure:"ldap-address"` LdapAddress string `mapstructure:"ldap-address"`
LdapBindDN string `mapstructure:"ldap-bind-dn"` LdapBindDN string `mapstructure:"ldap-bind-dn"`
@@ -57,35 +56,13 @@ type Config struct {
DatabasePath string `mapstructure:"database-path" validate:"required"` DatabasePath string `mapstructure:"database-path" validate:"required"`
} }
type OAuthLabels struct { // OAuth/OIDC config
Whitelist string
Groups string
}
type BasicLabels struct { type Claims struct {
Username string Name string `json:"name"`
Password PasswordLabels Email string `json:"email"`
} PreferredUsername string `json:"preferred_username"`
Groups any `json:"groups"`
type PasswordLabels struct {
Plain string
File string
}
type IPLabels struct {
Allow []string
Block []string
Bypass []string
}
type Labels struct {
Users string
Allowed string
Headers []string
Domain []string
Basic BasicLabels
OAuth OAuthLabels
IP IPLabels
} }
type OAuthServiceConfig struct { type OAuthServiceConfig struct {
@@ -99,6 +76,8 @@ type OAuthServiceConfig struct {
InsecureSkipVerify bool InsecureSkipVerify bool
} }
// User/session related stuff
type User struct { type User struct {
Username string Username string
Password string Password string
@@ -132,6 +111,8 @@ type UserContext struct {
TotpEnabled bool TotpEnabled bool
} }
// API responses and queries
type UnauthorizedQuery struct { type UnauthorizedQuery struct {
Username string `url:"username"` Username string `url:"username"`
Resource string `url:"resource"` Resource string `url:"resource"`
@@ -142,3 +123,54 @@ type UnauthorizedQuery struct {
type RedirectQuery struct { type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"` RedirectURI string `url:"redirect_uri"`
} }
// Labels
type Labels struct {
Apps map[string]AppLabels
}
type AppLabels struct {
Config ConfigLabels
Users UsersLabels
OAuth OAuthLabels
IP IPLabels
Response ResponseLabels
Path PathLabels
}
type ConfigLabels struct {
Domain string
}
type UsersLabels struct {
Allow string
Block string
}
type OAuthLabels struct {
Whitelist string
Groups string
}
type IPLabels struct {
Allow []string
Block []string
Bypass []string
}
type ResponseLabels struct {
Headers []string
BasicAuth BasicAuthLabels
}
type BasicAuthLabels struct {
Username string
Password string
PasswordFile string
}
type PathLabels struct {
Allow string
Block string
}

View File

@@ -89,19 +89,20 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
clientIP := c.ClientIP() clientIP := c.ClientIP()
if controller.Auth.IsBypassedIP(labels, clientIP) { if controller.Auth.IsBypassedIP(labels.IP, clientIP) {
c.Header("Authorization", c.Request.Header.Get("Authorization")) c.Header("Authorization", c.Request.Header.Get("Authorization"))
headers := utils.ParseHeaders(labels.Headers) headers := utils.ParseHeaders(labels.Response.Headers)
for key, value := range headers { for key, value := range headers {
log.Debug().Str("header", key).Msg("Setting header") log.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value) c.Header(key, value)
} }
if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile)
log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth header") if labels.Response.BasicAuth.Username != "" && basicPassword != "" {
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword)))
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
@@ -111,31 +112,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
if !controller.Auth.CheckIP(labels, clientIP) { authEnabled, err := controller.Auth.IsAuthEnabled(uri, labels.Path)
if req.Proxy == "nginx" || !isBrowser {
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0],
IP: clientIP,
})
if err != nil {
log.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode()))
return
}
authEnabled, err := controller.Auth.IsAuthEnabled(uri, labels)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to check if auth is enabled for resource") log.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
@@ -157,16 +134,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Header("Authorization", c.Request.Header.Get("Authorization")) c.Header("Authorization", c.Request.Header.Get("Authorization"))
headers := utils.ParseHeaders(labels.Headers) headers := utils.ParseHeaders(labels.Response.Headers)
for key, value := range headers { for key, value := range headers {
log.Debug().Str("header", key).Msg("Setting header") log.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value) c.Header(key, value)
} }
if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile)
log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth header") if labels.Response.BasicAuth.Username != "" && basicPassword != "" {
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword)))
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
@@ -176,6 +154,30 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
if !controller.Auth.CheckIP(labels.IP, clientIP) {
if req.Proxy == "nginx" || !isBrowser {
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0],
IP: clientIP,
})
if err != nil {
log.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL))
return
}
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode()))
return
}
var userContext config.UserContext var userContext config.UserContext
context, err := utils.GetContext(c) context, err := utils.GetContext(c)
@@ -229,7 +231,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
} }
if userContext.OAuth { if userContext.OAuth {
groupOK := controller.Auth.IsInOAuthGroup(c, userContext, labels) groupOK := controller.Auth.IsInOAuthGroup(c, userContext, labels.OAuth.Groups)
if !groupOK { if !groupOK {
log.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User OAuth groups do not match resource requirements") log.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User OAuth groups do not match resource requirements")
@@ -270,16 +272,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
headers := utils.ParseHeaders(labels.Headers) headers := utils.ParseHeaders(labels.Response.Headers)
for key, value := range headers { for key, value := range headers {
log.Debug().Str("header", key).Msg("Setting header") log.Debug().Str("header", key).Msg("Setting header")
c.Header(key, value) c.Header(key, value)
} }
if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { basicPassword := utils.GetSecret(labels.Response.BasicAuth.Password, labels.Response.BasicAuth.PasswordFile)
log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth header") if labels.Response.BasicAuth.Username != "" && basicPassword != "" {
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) log.Debug().Str("username", labels.Response.BasicAuth.Username).Msg("Setting basic auth header")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Response.BasicAuth.Username, basicPassword)))
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{

View File

@@ -283,18 +283,25 @@ func (auth *AuthService) UserAuthConfigured() bool {
return len(auth.Config.Users) > 0 || auth.LDAP != nil return len(auth.Config.Users) > 0 || auth.LDAP != nil
} }
func (auth *AuthService) IsResourceAllowed(c *gin.Context, context config.UserContext, labels config.Labels) bool { func (auth *AuthService) IsResourceAllowed(c *gin.Context, context config.UserContext, labels config.AppLabels) bool {
if context.OAuth { if context.OAuth {
log.Debug().Msg("Checking OAuth whitelist") log.Debug().Msg("Checking OAuth whitelist")
return utils.CheckFilter(labels.OAuth.Whitelist, context.Email) return utils.CheckFilter(labels.OAuth.Whitelist, context.Email)
} }
if labels.Users.Block != "" {
log.Debug().Msg("Checking blocked users")
if utils.CheckFilter(labels.Users.Block, context.Username) {
return false
}
}
log.Debug().Msg("Checking users") log.Debug().Msg("Checking users")
return utils.CheckFilter(labels.Users, context.Username) return utils.CheckFilter(labels.Users.Allow, context.Username)
} }
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, labels config.Labels) bool { func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if labels.OAuth.Groups == "" { if requiredGroups == "" {
return true return true
} }
@@ -303,11 +310,8 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte
return true return true
} }
// No need to parse since they are from the API response for _, userGroup := range strings.Split(context.OAuthGroups, ",") {
oauthGroups := strings.Split(context.OAuthGroups, ",") if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
for _, group := range oauthGroups {
if utils.CheckFilter(labels.OAuth.Groups, group) {
return true return true
} }
} }
@@ -316,19 +320,31 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte
return false return false
} }
func (auth *AuthService) IsAuthEnabled(uri string, labels config.Labels) (bool, error) { func (auth *AuthService) IsAuthEnabled(uri string, path config.PathLabels) (bool, error) {
if labels.Allowed == "" { // Check for block list
return true, nil if path.Block != "" {
regex, err := regexp.Compile(path.Block)
if err != nil {
return true, err
}
if !regex.MatchString(uri) {
return false, nil
}
} }
regex, err := regexp.Compile(labels.Allowed) // Check for allow list
if path.Allow != "" {
regex, err := regexp.Compile(path.Allow)
if err != nil { if err != nil {
return true, err return true, err
} }
if regex.MatchString(uri) { if regex.MatchString(uri) {
return false, nil return false, nil
}
} }
return true, nil return true, nil
@@ -346,8 +362,8 @@ func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
} }
} }
func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool { func (auth *AuthService) CheckIP(labels config.IPLabels, ip string) bool {
for _, blocked := range labels.IP.Block { for _, blocked := range labels.Block {
res, err := utils.FilterIP(blocked, ip) res, err := utils.FilterIP(blocked, ip)
if err != nil { if err != nil {
log.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") log.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
@@ -359,7 +375,7 @@ func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool {
} }
} }
for _, allowed := range labels.IP.Allow { for _, allowed := range labels.Allow {
res, err := utils.FilterIP(allowed, ip) res, err := utils.FilterIP(allowed, ip)
if err != nil { if err != nil {
log.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") log.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
@@ -371,7 +387,7 @@ func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool {
} }
} }
if len(labels.IP.Allow) > 0 { if len(labels.Allow) > 0 {
log.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") log.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
return false return false
} }
@@ -380,8 +396,8 @@ func (auth *AuthService) CheckIP(labels config.Labels, ip string) bool {
return true return true
} }
func (auth *AuthService) IsBypassedIP(labels config.Labels, ip string) bool { func (auth *AuthService) IsBypassedIP(labels config.IPLabels, ip string) bool {
for _, bypassed := range labels.IP.Bypass { for _, bypassed := range labels.Bypass {
res, err := utils.FilterIP(bypassed, ip) res, err := utils.FilterIP(bypassed, ip)
if err != nil { if err != nil {
log.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") log.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")

View File

@@ -6,8 +6,6 @@ import (
"tinyauth/internal/config" "tinyauth/internal/config"
"tinyauth/internal/utils" "tinyauth/internal/utils"
"slices"
container "github.com/docker/docker/api/types/container" container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client" "github.com/docker/docker/client"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@@ -57,17 +55,17 @@ func (docker *DockerService) DockerConnected() bool {
return err == nil return err == nil
} }
func (docker *DockerService) GetLabels(app string, domain string) (config.Labels, error) { func (docker *DockerService) GetLabels(app string, domain string) (config.AppLabels, error) {
isConnected := docker.DockerConnected() isConnected := docker.DockerConnected()
if !isConnected { if !isConnected {
log.Debug().Msg("Docker not connected, returning empty labels") log.Debug().Msg("Docker not connected, returning empty labels")
return config.Labels{}, nil return config.AppLabels{}, nil
} }
containers, err := docker.GetContainers() containers, err := docker.GetContainers()
if err != nil { if err != nil {
return config.Labels{}, err return config.AppLabels{}, err
} }
for _, container := range containers { for _, container := range containers {
@@ -83,18 +81,19 @@ func (docker *DockerService) GetLabels(app string, domain string) (config.Labels
continue continue
} }
// Check if the container matches the ID or domain for appName, appLabels := range labels.Apps {
if slices.Contains(labels.Domain, domain) { if appLabels.Config.Domain == domain {
log.Debug().Str("id", inspect.ID).Msg("Found matching container by domain") log.Debug().Str("id", inspect.ID).Msg("Found matching container by domain")
return labels, nil return appLabels, nil
} }
if strings.TrimPrefix(inspect.Name, "/") == app { if strings.TrimPrefix(inspect.Name, "/") == appName {
log.Debug().Str("id", inspect.ID).Msg("Found matching container by name") log.Debug().Str("id", inspect.ID).Msg("Found matching container by app name")
return labels, nil return appLabels, nil
}
} }
} }
log.Debug().Msg("No matching container found, returning empty labels") log.Debug().Msg("No matching container found, returning empty labels")
return config.Labels{}, nil return config.AppLabels{}, nil
} }

View File

@@ -11,7 +11,7 @@ import (
func GetLabels(labels map[string]string) (config.Labels, error) { func GetLabels(labels map[string]string) (config.Labels, error) {
var labelsParsed config.Labels var labelsParsed config.Labels
err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.users", "tinyauth.allowed", "tinyauth.headers", "tinyauth.domain", "tinyauth.basic", "tinyauth.oauth", "tinyauth.ip") err := parser.Decode(labels, &labelsParsed, "tinyauth", "tinyauth.apps")
if err != nil { if err != nil {
return config.Labels{}, err return config.Labels{}, err
} }