diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 4851253b..d744fbe6 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -35,6 +35,7 @@ type Services struct { ldapService *service.LdapService oauthBrokerService *service.OAuthBrokerService oidcService *service.OIDCService + policyEngine *service.PolicyEngine } type BootstrapApp struct { diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 12a48bc0..3506a569 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -44,7 +44,7 @@ func (app *BootstrapApp) setupRouter() error { controller.NewContextController(app.log, app.config, app.runtime, apiRouter) controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter) - controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService) + controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine) controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) controller.NewResourcesController(app.config, &engine.RouterGroup) controller.NewHealthController(apiRouter) diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index ef3ee591..acd5af01 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -16,38 +16,21 @@ func (app *BootstrapApp) setupServices() error { app.services.ldapService = ldapService - useKubernetes := app.config.LabelProvider == "kubernetes" || - (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") + labelProvider, err := app.getLabelProvider() - var labelProvider service.LabelProvider - - if useKubernetes { - app.log.App.Debug().Msg("Using Kubernetes label provider") - - kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg) - - if err != nil { - return fmt.Errorf("failed to initialize kubernetes service: %w", err) - } - - app.services.kubernetesService = kubernetesService - labelProvider = kubernetesService - } else { - app.log.App.Debug().Msg("Using Docker label provider") - - dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg) - - if err != nil { - return fmt.Errorf("failed to initialize docker service: %w", err) - } - - app.services.dockerService = dockerService - labelProvider = dockerService + if err != nil { + return fmt.Errorf("failed to initialize label provider: %w", err) } - accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps) + accessControlsService := service.NewAccessControlsService(app.log, app.config, &labelProvider) app.services.accessControlService = accessControlsService + err = app.setupPolicyEngine() + + if err != nil { + return fmt.Errorf("failed to initialize policy engine: %w", err) + } + oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) app.services.oauthBrokerService = oauthBrokerService @@ -64,3 +47,79 @@ func (app *BootstrapApp) setupServices() error { return nil } + +func (app *BootstrapApp) getLabelProvider() (service.LabelProvider, error) { + switch app.config.LabelProvider { + case "none", "docker", "kubernetes", "auto": + if app.config.LabelProvider == "none" { + return nil, nil + } + + useKubernetes := app.config.LabelProvider == "kubernetes" || + (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") + + if useKubernetes { + app.log.App.Debug().Msg("Using Kubernetes label provider") + + kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg) + + if err != nil { + return nil, fmt.Errorf("failed to initialize kubernetes service: %w", err) + } + + app.services.kubernetesService = kubernetesService + return kubernetesService, nil + } + + app.log.App.Debug().Msg("Using Docker label provider") + + dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg) + + if err != nil { + return nil, fmt.Errorf("failed to initialize docker service: %w", err) + } + + if dockerService == nil { + if app.config.LabelProvider == "docker" { + app.log.App.Warn().Msg("Docker label provider selected but Docker is not available, will continue without it") + } + return nil, nil + } + + app.services.dockerService = dockerService + return dockerService, nil + default: + return nil, fmt.Errorf("invalid label provider: %s", app.config.LabelProvider) + } +} + +func (app *BootstrapApp) setupPolicyEngine() error { + policyEngine, err := service.NewPolicyEngine(app.config, app.log) + + if err != nil { + return fmt.Errorf("failed to initialize policy engine: %w", err) + } + + policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{ + Log: app.log, + }) + policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{ + Log: app.log, + }) + policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{ + Log: app.log, + }) + policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{ + Log: app.log, + }) + policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{ + Log: app.log, + Config: app.config, + }) + policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{ + Log: app.log, + }) + + app.services.policyEngine = policyEngine + return nil +} diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 95afe113..79d79525 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -3,6 +3,7 @@ package controller import ( "errors" "fmt" + "net" "net/http" "net/url" "regexp" @@ -51,10 +52,11 @@ type ProxyContext struct { } type ProxyController struct { - log *logger.Logger - runtime model.RuntimeConfig - acls *service.AccessControlsService - auth *service.AuthService + log *logger.Logger + runtime model.RuntimeConfig + acls *service.AccessControlsService + auth *service.AuthService + policyEngine *service.PolicyEngine } func NewProxyController( @@ -63,12 +65,14 @@ func NewProxyController( router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService, + policyEngine *service.PolicyEngine, ) *ProxyController { controller := &ProxyController{ - log: log, - runtime: runtime, - acls: acls, - auth: auth, + log: log, + runtime: runtime, + acls: acls, + auth: auth, + policyEngine: policyEngine, } proxyGroup := router.Group("/auth") @@ -101,7 +105,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { clientIP := c.ClientIP() - if controller.auth.IsBypassedIP(clientIP, acls) { + aclsCtx := &service.ACLContext{ + ACLs: acls, + IP: net.ParseIP(clientIP), + Path: proxyCtx.Path, + } + + if controller.policyEngine.Evaluate(service.RuleIPBypassed, aclsCtx) { controller.setHeaders(c, acls) c.JSON(200, gin.H{ "status": 200, @@ -110,15 +120,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls) - - if err != nil { - controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource") - controller.handleError(c, proxyCtx) - return - } - - if !authEnabled { + if controller.policyEngine.Evaluate(service.RuleAuthEnabled, aclsCtx) { controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication") controller.setHeaders(c, acls) c.JSON(200, gin.H{ @@ -128,7 +130,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - if !controller.auth.CheckIP(clientIP, acls) { + if !controller.policyEngine.Evaluate(service.RuleIPAllowed, aclsCtx) { queries, err := query.Values(UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], IP: clientIP, @@ -164,10 +166,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { } } - if userContext.Authenticated { - userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls) + aclsCtx.UserContext = userContext - if !userAllowed { + if userContext.Authenticated { + if !controller.policyEngine.Evaluate(service.RuleUserAllowed, aclsCtx) { controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource") queries, err := query.Values(UnauthorizedQuery{ @@ -205,9 +207,9 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { var groupOK bool if userContext.IsOAuth() { - groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls) + groupOK = controller.policyEngine.Evaluate(service.RuleOAuthGroup, aclsCtx) } else { - groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls) + groupOK = controller.policyEngine.Evaluate(service.RuleLDAPGroup, aclsCtx) } if !groupOK { diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 4a467997..fa6e73cd 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository/memory" @@ -22,33 +23,6 @@ func TestProxyController(t *testing.T) { cfg, runtime := test.CreateTestConfigs(t) - acls := map[string]model.App{ - "app_path_allow": { - Config: model.AppConfig{ - Domain: "path-allow.example.com", - }, - Path: model.AppPath{ - Allow: "/allowed", - }, - }, - "app_user_allow": { - Config: model.AppConfig{ - Domain: "user-allow.example.com", - }, - Users: model.AppUsers{ - Allow: "testuser", - }, - }, - "ip_bypass": { - Config: model.AppConfig{ - Domain: "ip-bypass.example.com", - }, - IP: model.AppIP{ - Bypass: []string{"10.10.10.10"}, - }, - }, - } - const browserUserAgent = ` Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36` @@ -384,7 +358,30 @@ func TestProxyController(t *testing.T) { broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker) - aclsService := service.NewAccessControlsService(log, nil, acls) + aclsService := service.NewAccessControlsService(log, cfg, nil) + + policyEngine, err := service.NewPolicyEngine(cfg, log) + require.NoError(t, err) + + policyEngine.RegisterRule(service.RuleUserAllowed, &service.UserAllowedRule{ + Log: log, + }) + policyEngine.RegisterRule(service.RuleOAuthGroup, &service.OAuthGroupRule{ + Log: log, + }) + policyEngine.RegisterRule(service.RuleLDAPGroup, &service.LDAPGroupRule{ + Log: log, + }) + policyEngine.RegisterRule(service.RuleAuthEnabled, &service.AuthEnabledRule{ + Log: log, + }) + policyEngine.RegisterRule(service.RuleIPAllowed, &service.IPAllowedRule{ + Log: log, + Config: cfg, + }) + policyEngine.RegisterRule(service.RuleIPBypassed, &service.IPBypassedRule{ + Log: log, + }) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -399,7 +396,7 @@ func TestProxyController(t *testing.T) { recorder := httptest.NewRecorder() - controller.NewProxyController(log, runtime, group, aclsService, authService) + controller.NewProxyController(log, runtime, group, aclsService, authService, policyEngine) test.run(t, router, recorder) }) diff --git a/internal/model/config.go b/internal/model/config.go index 0dfd2724..68904f48 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -25,6 +25,9 @@ func NewDefaultConfiguration() *Config { SessionMaxLifetime: 0, // disabled LoginTimeout: 300, // 5 minutes LoginMaxRetries: 3, + ACLs: ACLsConfig{ + Policy: "allow", + }, }, UI: UIConfig{ Title: "Tinyauth", @@ -79,7 +82,7 @@ type Config struct { UI UIConfig `description:"UI customization." yaml:"ui"` LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap"` Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"` - LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"` + LabelProvider string `description:"Label provider to use for ACLs (auto, docker, kubernetes or none to disable). auto detects the environment." yaml:"labelProvider"` Log LogConfig `description:"Logging configuration." yaml:"log"` } @@ -116,6 +119,7 @@ type AuthConfig struct { LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"` LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"` TrustedProxies []string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"` + ACLs ACLsConfig `description:"ACLs configuration." yaml:"acls"` } type UserAttributes struct { @@ -225,6 +229,10 @@ type OIDCClientConfig struct { Name string `description:"Client name in UI." yaml:"name"` } +type ACLsConfig struct { + Policy string `description:"ACL policy for allow-by-default or deny-by-default, available options are allow and deny, default is allow." yaml:"policy"` +} + // ACLs type Apps struct { diff --git a/internal/service/access_controls_rules.go b/internal/service/access_controls_rules.go new file mode 100644 index 00000000..93245c15 --- /dev/null +++ b/internal/service/access_controls_rules.go @@ -0,0 +1,249 @@ +package service + +import ( + "regexp" + "strings" + + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +type RuleName string + +const ( + RuleUserAllowed RuleName = "rule-user-allowed" + RuleOAuthGroup RuleName = "rule-oauth-group" + RuleLDAPGroup RuleName = "rule-ldap-group" + RuleAuthEnabled RuleName = "rule-auth-enabled" + RuleIPAllowed RuleName = "rule-ip-allowed" + RuleIPBypassed RuleName = "rule-ip-bypassed" +) + +type UserAllowedRule struct { + Log *logger.Logger +} + +func (rule *UserAllowedRule) Evaluate(ctx *ACLContext) Effect { + if ctx.ACLs == nil || ctx.UserContext == nil { + return EffectAbstain + } + + if ctx.UserContext.Provider == model.ProviderOAuth { + rule.Log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist") + match, err := utils.CheckFilter(ctx.ACLs.OAuth.Whitelist, ctx.UserContext.OAuth.Email) + if err != nil { + rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.OAuth.Email).Msg("Invalid entry in OAuth whitelist") + return EffectAbstain + } + if match { + rule.Log.App.Debug().Str("email", ctx.UserContext.OAuth.Email).Msg("User is in OAuth whitelist, allowing access") + return EffectAllow + } + return EffectDeny + } + + if ctx.ACLs.Users.Block != "" { + rule.Log.App.Debug().Msg("Checking users block list") + match, err := utils.CheckFilter(ctx.ACLs.Users.Block, ctx.UserContext.GetUsername()) + if err != nil { + rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.GetUsername()).Msg("Invalid entry in users block list") + return EffectAbstain + } + if match { + rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is in users block list, denying access") + return EffectDeny + } + return EffectAllow + } + + rule.Log.App.Debug().Msg("Checking users allow list") + + match, err := utils.CheckFilter(ctx.ACLs.Users.Allow, ctx.UserContext.GetUsername()) + + if err != nil { + rule.Log.App.Warn().Err(err).Str("item", ctx.UserContext.GetUsername()).Msg("Invalid entry in users allow list") + return EffectAbstain + } + + if match { + rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is in users allow list, allowing access") + return EffectAllow + } + + rule.Log.App.Debug().Str("username", ctx.UserContext.GetUsername()).Msg("User is not in users allow list, denying access") + return EffectDeny +} + +type OAuthGroupRule struct { + Log *logger.Logger +} + +func (rule *OAuthGroupRule) Evaluate(ctx *ACLContext) Effect { + if ctx.ACLs == nil || ctx.UserContext == nil { + return EffectAbstain + } + + if !ctx.UserContext.IsOAuth() { + rule.Log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") + return EffectAbstain + } + + if _, ok := model.OverrideProviders[ctx.UserContext.OAuth.ID]; ok { + rule.Log.App.Debug().Str("provider", ctx.UserContext.OAuth.ID).Msg("Provider override detected, skipping group check") + return EffectAllow + } + + for _, group := range ctx.UserContext.OAuth.Groups { + match, err := utils.CheckFilter(ctx.ACLs.OAuth.Groups, strings.TrimSpace(group)) + if err != nil { + return EffectAbstain + } + if match { + rule.Log.App.Trace().Str("group", group).Str("required", ctx.ACLs.OAuth.Groups).Msg("User group matched, allowing access") + return EffectAllow + } + } + + rule.Log.App.Debug().Msg("No groups matched") + return EffectDeny +} + +type LDAPGroupRule struct { + Log *logger.Logger +} + +func (rule *LDAPGroupRule) Evaluate(ctx *ACLContext) Effect { + if ctx == nil || ctx.UserContext == nil { + return EffectAbstain + } + + if !ctx.UserContext.IsLDAP() { + rule.Log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") + return EffectAbstain + } + + for _, group := range ctx.UserContext.LDAP.Groups { + match, err := utils.CheckFilter(ctx.ACLs.LDAP.Groups, strings.TrimSpace(group)) + if err != nil { + return EffectAbstain + } + if match { + rule.Log.App.Trace().Str("group", group).Str("required", ctx.ACLs.LDAP.Groups).Msg("User group matched, allowing access") + return EffectAllow + } + } + + rule.Log.App.Debug().Msg("No groups matched") + return EffectDeny +} + +type AuthEnabledRule struct { + Log *logger.Logger +} + +func (rule *AuthEnabledRule) Evaluate(ctx *ACLContext) Effect { + if ctx.ACLs == nil { + return EffectDeny + } + + if ctx.ACLs.Path.Block != "" { + regex, err := regexp.Compile(ctx.ACLs.Path.Block) + + if err != nil { + rule.Log.App.Error().Err(err).Msg("Failed to compile block regex") + return EffectDeny + } + + if !regex.MatchString(ctx.Path) { + return EffectAllow + } + } + + if ctx.ACLs.Path.Allow != "" { + regex, err := regexp.Compile(ctx.ACLs.Path.Allow) + + if err != nil { + rule.Log.App.Error().Err(err).Msg("Failed to compile allow regex") + return EffectDeny + } + + if regex.MatchString(ctx.Path) { + return EffectAllow + } + } + + return EffectDeny +} + +type IPAllowedRule struct { + Log *logger.Logger + Config model.Config +} + +func (rule *IPAllowedRule) Evaluate(ctx *ACLContext) Effect { + if ctx.ACLs == nil { + return EffectAbstain + } + + // Merge the global and app IP filter + blockedIps := append(ctx.ACLs.IP.Block, rule.Config.Auth.IP.Block...) + allowedIPs := append(ctx.ACLs.IP.Allow, rule.Config.Auth.IP.Allow...) + + for _, blocked := range blockedIps { + match, err := utils.CheckIPFilter(blocked, ctx.IP.String()) + if err != nil { + rule.Log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") + continue + } + if match { + rule.Log.App.Debug().Str("ip", ctx.IP.String()).Str("item", blocked).Msg("IP is in block list, denying access") + return EffectDeny + } + } + + for _, allowed := range allowedIPs { + match, err := utils.CheckIPFilter(allowed, ctx.IP.String()) + if err != nil { + rule.Log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") + continue + } + if match { + rule.Log.App.Debug().Str("ip", ctx.IP.String()).Str("item", allowed).Msg("IP is in allow list, allowing access") + return EffectAllow + } + } + + if len(allowedIPs) > 0 { + rule.Log.App.Debug().Str("ip", ctx.IP.String()).Msg("IP not in allow list, denying access") + return EffectDeny + } + + rule.Log.App.Debug().Str("ip", ctx.IP.String()).Msg("IP not in block or allow list, allowing access") + return EffectAllow +} + +type IPBypassedRule struct { + Log *logger.Logger +} + +func (rule *IPBypassedRule) Evaluate(ctx *ACLContext) Effect { + if ctx.ACLs == nil { + return EffectDeny + } + + for _, bypassed := range ctx.ACLs.IP.Bypass { + match, err := utils.CheckIPFilter(bypassed, ctx.IP.String()) + if err != nil { + rule.Log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") + continue + } + if match { + rule.Log.App.Debug().Str("ip", ctx.IP.String()).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication") + return EffectAllow + } + } + + rule.Log.App.Debug().Str("ip", ctx.IP.String()).Msg("IP not in bypass list, proceeding with authentication") + return EffectDeny +} diff --git a/internal/service/access_controls_rules_test.go b/internal/service/access_controls_rules_test.go new file mode 100644 index 00000000..16dde083 --- /dev/null +++ b/internal/service/access_controls_rules_test.go @@ -0,0 +1,732 @@ +package service + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +func TestUserAllowedRule(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + rule := &UserAllowedRule{Log: log} + + tests := []struct { + name string + ctx *ACLContext + expected Effect + }{ + { + name: "abstains when ACLs are nil", + ctx: &ACLContext{ + ACLs: nil, + UserContext: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{Username: "alice"}, + }, + }, + }, + expected: EffectAbstain, + }, + { + name: "abstains when user context is nil", + ctx: &ACLContext{ + ACLs: &model.App{ + OAuth: model.AppOAuth{Whitelist: "alice"}, + }, + UserContext: nil, + }, + expected: EffectAbstain, + }, + { + name: "allows OAuth user when email matches whitelist", + ctx: &ACLContext{ + ACLs: &model.App{ + OAuth: model.AppOAuth{Whitelist: "allowed@example.com"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ + BaseContext: model.BaseContext{ + Username: "different-username", + Email: "allowed@example.com", + }, + }, + }, + }, + expected: EffectAllow, + }, + { + name: "denies OAuth user when email does not match whitelist", + ctx: &ACLContext{ + ACLs: &model.App{ + OAuth: model.AppOAuth{Whitelist: "allowed@example.com"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ + BaseContext: model.BaseContext{Email: "denied@example.com"}, + }, + }, + }, + expected: EffectDeny, + }, + { + name: "abstains for OAuth user when whitelist filter is invalid", + ctx: &ACLContext{ + ACLs: &model.App{ + OAuth: model.AppOAuth{Whitelist: "/[/"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ + BaseContext: model.BaseContext{Email: "allowed@example.com"}, + }, + }, + }, + expected: EffectAbstain, + }, + { + name: "denies local user when username matches block list", + ctx: &ACLContext{ + ACLs: &model.App{ + Users: model.AppUsers{Block: "alice,bob"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{Username: "alice"}, + }, + }, + }, + expected: EffectDeny, + }, + { + name: "allows local user when username does not match block list", + ctx: &ACLContext{ + ACLs: &model.App{ + Users: model.AppUsers{Block: "alice,bob"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{Username: "charlie"}, + }, + }, + }, + expected: EffectAllow, + }, + { + name: "abstains when block list filter is invalid", + ctx: &ACLContext{ + ACLs: &model.App{ + Users: model.AppUsers{Block: "/[/"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{Username: "alice"}, + }, + }, + }, + expected: EffectAbstain, + }, + { + name: "allows local user when username matches allow list", + ctx: &ACLContext{ + ACLs: &model.App{ + Users: model.AppUsers{Allow: "alice,bob"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{Username: "alice"}, + }, + }, + }, + expected: EffectAllow, + }, + { + name: "denies local user when username does not match allow list", + ctx: &ACLContext{ + ACLs: &model.App{ + Users: model.AppUsers{Allow: "alice,bob"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{Username: "charlie"}, + }, + }, + }, + expected: EffectDeny, + }, + { + name: "abstains when allow list filter is invalid", + ctx: &ACLContext{ + ACLs: &model.App{ + Users: model.AppUsers{Allow: "/[/"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{Username: "alice"}, + }, + }, + }, + expected: EffectAbstain, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx)) + }) + } +} + +func TestOAuthGroupRule(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + rule := &OAuthGroupRule{Log: log} + + tests := []struct { + name string + ctx *ACLContext + expected Effect + }{ + { + name: "abstains when ACLs are nil", + ctx: &ACLContext{ + ACLs: nil, + UserContext: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ + Groups: []string{"admins"}, + }, + }, + }, + expected: EffectAbstain, + }, + { + name: "abstains when user context is nil", + ctx: &ACLContext{ + ACLs: &model.App{ + OAuth: model.AppOAuth{Whitelist: "alice"}, + }, + UserContext: nil, + }, + expected: EffectAbstain, + }, + { + name: "abstains when user is not OAuth", + ctx: &ACLContext{ + ACLs: &model.App{ + OAuth: model.AppOAuth{Groups: "admins"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{Username: "alice"}, + }, + }, + }, + expected: EffectAbstain, + }, + { + name: "allows when provider is an override provider regardless of groups", + ctx: &ACLContext{ + ACLs: &model.App{ + OAuth: model.AppOAuth{Groups: "admins"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ + ID: "google", + Groups: []string{"unrelated"}, + }, + }, + }, + expected: EffectAllow, + }, + { + name: "allows OAuth user when a group matches", + ctx: &ACLContext{ + ACLs: &model.App{ + OAuth: model.AppOAuth{Groups: "admins,users"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ + ID: "custom", + Groups: []string{"users"}, + }, + }, + }, + expected: EffectAllow, + }, + { + name: "denies OAuth user when no group matches", + ctx: &ACLContext{ + ACLs: &model.App{ + OAuth: model.AppOAuth{Groups: "admins"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ + ID: "custom", + Groups: []string{"users", "guests"}, + }, + }, + }, + expected: EffectDeny, + }, + { + name: "denies OAuth user when user has no groups", + ctx: &ACLContext{ + ACLs: &model.App{ + OAuth: model.AppOAuth{Groups: "admins"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ + ID: "custom", + Groups: nil, + }, + }, + }, + expected: EffectDeny, + }, + { + name: "abstains when groups filter is invalid", + ctx: &ACLContext{ + ACLs: &model.App{ + OAuth: model.AppOAuth{Groups: "/[/"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderOAuth, + OAuth: &model.OAuthContext{ + ID: "custom", + Groups: []string{"admins"}, + }, + }, + }, + expected: EffectAbstain, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx)) + }) + } +} + +func TestLDAPGroupRule(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + rule := &LDAPGroupRule{Log: log} + + tests := []struct { + name string + ctx *ACLContext + expected Effect + }{ + { + name: "abstains when context is nil", + ctx: nil, + expected: EffectAbstain, + }, + { + name: "abstains when user context is nil", + ctx: &ACLContext{ + ACLs: &model.App{ + OAuth: model.AppOAuth{Whitelist: "alice"}, + }, + UserContext: nil, + }, + expected: EffectAbstain, + }, + { + name: "abstains when user is not LDAP", + ctx: &ACLContext{ + ACLs: &model.App{ + LDAP: model.AppLDAP{Groups: "admins"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderLocal, + Local: &model.LocalContext{ + BaseContext: model.BaseContext{Username: "alice"}, + }, + }, + }, + expected: EffectAbstain, + }, + { + name: "allows LDAP user when a group matches", + ctx: &ACLContext{ + ACLs: &model.App{ + LDAP: model.AppLDAP{Groups: "admins,users"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderLDAP, + LDAP: &model.LDAPContext{ + Groups: []string{"users"}, + }, + }, + }, + expected: EffectAllow, + }, + { + name: "denies LDAP user when no group matches", + ctx: &ACLContext{ + ACLs: &model.App{ + LDAP: model.AppLDAP{Groups: "admins"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderLDAP, + LDAP: &model.LDAPContext{ + Groups: []string{"users", "guests"}, + }, + }, + }, + expected: EffectDeny, + }, + { + name: "denies LDAP user when user has no groups", + ctx: &ACLContext{ + ACLs: &model.App{ + LDAP: model.AppLDAP{Groups: "admins"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderLDAP, + LDAP: &model.LDAPContext{ + Groups: nil, + }, + }, + }, + expected: EffectDeny, + }, + { + name: "abstains when groups filter is invalid", + ctx: &ACLContext{ + ACLs: &model.App{ + LDAP: model.AppLDAP{Groups: "/[/"}, + }, + UserContext: &model.UserContext{ + Provider: model.ProviderLDAP, + LDAP: &model.LDAPContext{ + Groups: []string{"admins"}, + }, + }, + }, + expected: EffectAbstain, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx)) + }) + } +} + +func TestAuthEnabledRule(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + rule := &AuthEnabledRule{Log: log} + + tests := []struct { + name string + ctx *ACLContext + expected Effect + }{ + { + name: "deny when ACLs are nil", + ctx: &ACLContext{ + ACLs: nil, + Path: "/anything", + }, + expected: EffectDeny, + }, + { + name: "allows when path does not match block regex", + ctx: &ACLContext{ + ACLs: &model.App{ + Path: model.AppPath{Block: "^/admin"}, + }, + Path: "/public", + }, + expected: EffectAllow, + }, + { + name: "denies when path matches block regex and no allow regex", + ctx: &ACLContext{ + ACLs: &model.App{ + Path: model.AppPath{Block: "^/admin"}, + }, + Path: "/admin/users", + }, + expected: EffectDeny, + }, + { + name: "allows when path matches allow regex", + ctx: &ACLContext{ + ACLs: &model.App{ + Path: model.AppPath{Allow: "^/public"}, + }, + Path: "/public/index", + }, + expected: EffectAllow, + }, + { + name: "denies when path does not match allow regex", + ctx: &ACLContext{ + ACLs: &model.App{ + Path: model.AppPath{Allow: "^/public"}, + }, + Path: "/private", + }, + expected: EffectDeny, + }, + { + name: "allows when blocked path is also explicitly allowed", + ctx: &ACLContext{ + ACLs: &model.App{ + Path: model.AppPath{ + Block: "^/admin", + Allow: "^/admin/public", + }, + }, + Path: "/admin/public/page", + }, + expected: EffectAllow, + }, + { + name: "denies when block regex fails to compile", + ctx: &ACLContext{ + ACLs: &model.App{ + Path: model.AppPath{Block: "[invalid"}, + }, + Path: "/anything", + }, + expected: EffectDeny, + }, + { + name: "denies when allow regex fails to compile", + ctx: &ACLContext{ + ACLs: &model.App{ + Path: model.AppPath{Allow: "[invalid"}, + }, + Path: "/anything", + }, + expected: EffectDeny, + }, + { + name: "denies when no path rules are configured", + ctx: &ACLContext{ + ACLs: &model.App{}, + Path: "/anything", + }, + expected: EffectDeny, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx)) + }) + } +} + +func TestIPAllowedRule(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + tests := []struct { + name string + config model.Config + ctx *ACLContext + expected Effect + }{ + { + name: "abstains when ACLs are nil", + ctx: &ACLContext{ + ACLs: nil, + IP: net.ParseIP("10.0.0.1"), + }, + expected: EffectAbstain, + }, + { + name: "denies when IP matches app block list", + ctx: &ACLContext{ + ACLs: &model.App{ + IP: model.AppIP{Block: []string{"10.0.0.1"}}, + }, + IP: net.ParseIP("10.0.0.1"), + }, + expected: EffectDeny, + }, + { + name: "denies when IP matches global block list", + config: model.Config{ + Auth: model.AuthConfig{ + IP: model.IPConfig{Block: []string{"10.0.0.0/24"}}, + }, + }, + ctx: &ACLContext{ + ACLs: &model.App{}, + IP: net.ParseIP("10.0.0.5"), + }, + expected: EffectDeny, + }, + { + name: "allows when IP matches app allow list", + ctx: &ACLContext{ + ACLs: &model.App{ + IP: model.AppIP{Allow: []string{"192.168.1.0/24"}}, + }, + IP: net.ParseIP("192.168.1.10"), + }, + expected: EffectAllow, + }, + { + name: "allows when IP matches global allow list", + config: model.Config{ + Auth: model.AuthConfig{ + IP: model.IPConfig{Allow: []string{"192.168.1.10"}}, + }, + }, + ctx: &ACLContext{ + ACLs: &model.App{}, + IP: net.ParseIP("192.168.1.10"), + }, + expected: EffectAllow, + }, + { + name: "denies when allow list is set and IP does not match", + ctx: &ACLContext{ + ACLs: &model.App{ + IP: model.AppIP{Allow: []string{"192.168.1.0/24"}}, + }, + IP: net.ParseIP("10.0.0.1"), + }, + expected: EffectDeny, + }, + { + name: "allows when no block or allow lists are configured", + ctx: &ACLContext{ + ACLs: &model.App{}, + IP: net.ParseIP("10.0.0.1"), + }, + expected: EffectAllow, + }, + { + name: "block list takes precedence over allow list", + ctx: &ACLContext{ + ACLs: &model.App{ + IP: model.AppIP{ + Block: []string{"10.0.0.1"}, + Allow: []string{"10.0.0.1"}, + }, + }, + IP: net.ParseIP("10.0.0.1"), + }, + expected: EffectDeny, + }, + { + name: "skips invalid block entries and continues evaluation", + ctx: &ACLContext{ + ACLs: &model.App{ + IP: model.AppIP{ + Block: []string{"not-an-ip"}, + Allow: []string{"10.0.0.1"}, + }, + }, + IP: net.ParseIP("10.0.0.1"), + }, + expected: EffectAllow, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := &IPAllowedRule{Log: log, Config: tt.config} + assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx)) + }) + } +} + +func TestIPBypassedRule(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + rule := &IPBypassedRule{Log: log} + + tests := []struct { + name string + ctx *ACLContext + expected Effect + }{ + { + name: "deny when ACLs are nil", + ctx: &ACLContext{ + ACLs: nil, + IP: net.ParseIP("10.0.0.1"), + }, + expected: EffectDeny, + }, + { + name: "allows when IP matches bypass list", + ctx: &ACLContext{ + ACLs: &model.App{ + IP: model.AppIP{Bypass: []string{"10.0.0.0/24"}}, + }, + IP: net.ParseIP("10.0.0.5"), + }, + expected: EffectAllow, + }, + { + name: "denies when IP does not match bypass list", + ctx: &ACLContext{ + ACLs: &model.App{ + IP: model.AppIP{Bypass: []string{"10.0.0.0/24"}}, + }, + IP: net.ParseIP("192.168.1.1"), + }, + expected: EffectDeny, + }, + { + name: "denies when bypass list is empty", + ctx: &ACLContext{ + ACLs: &model.App{}, + IP: net.ParseIP("10.0.0.1"), + }, + expected: EffectDeny, + }, + { + name: "skips invalid bypass entries and allows on later match", + ctx: &ACLContext{ + ACLs: &model.App{ + IP: model.AppIP{Bypass: []string{"not-an-ip", "10.0.0.1"}}, + }, + IP: net.ParseIP("10.0.0.1"), + }, + expected: EffectAllow, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, rule.Evaluate(tt.ctx)) + }) + } +} diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index 34700ea7..64c4d6fc 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -13,51 +13,52 @@ type LabelProvider interface { type AccessControlsService struct { log *logger.Logger + config model.Config labelProvider *LabelProvider - static map[string]model.App } func NewAccessControlsService( log *logger.Logger, - labelProvider *LabelProvider, - static map[string]model.App) *AccessControlsService { + config model.Config, + labelProvider *LabelProvider) *AccessControlsService { + return &AccessControlsService{ log: log, + config: config, labelProvider: labelProvider, - static: static, } } -func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { - var appAcls *model.App - for app, config := range acls.static { +func (service *AccessControlsService) lookupStaticACLs(domain string) *model.App { + var nameMatch *model.App + + // First try to find a matching app by domain, then fallback to matching by app name (subdomain) + for app, config := range service.config.Apps { if config.Config.Domain == domain { - acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain") - appAcls = &config - break // If we find a match by domain, we can stop searching + service.log.App.Debug().Str("name", app).Msg("Found matching container by domain") + return &config } - if strings.SplitN(domain, ".", 2)[0] == app { - acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name") - appAcls = &config - break // If we find a match by app name, we can stop searching + service.log.App.Debug().Str("name", app).Msg("Found matching container by app name") + nameMatch = &config } } - return appAcls + + return nameMatch } -func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) { +func (service *AccessControlsService) GetAccessControls(domain string) (*model.App, error) { // First check in the static config - app := acls.lookupStaticACLs(domain) + app := service.lookupStaticACLs(domain) if app != nil { - acls.log.App.Debug().Msg("Using static ACLs for app") + service.log.App.Debug().Msg("Using static ACLs for app") return app, nil } // If we have a label provider configured, try to get ACLs from it - if acls.labelProvider != nil { - return (*acls.labelProvider).GetLabels(domain) + if service.labelProvider != nil && *service.labelProvider != nil { + return (*service.labelProvider).GetLabels(domain) } // no labels diff --git a/internal/service/access_controls_service_test.go b/internal/service/access_controls_service_test.go new file mode 100644 index 00000000..e3d32eb6 --- /dev/null +++ b/internal/service/access_controls_service_test.go @@ -0,0 +1,199 @@ +package service + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +type mockLabelProvider struct { + getLabelsFn func(appDomain string) (*model.App, error) + calledWith string + callCount int +} + +func (m *mockLabelProvider) GetLabels(appDomain string) (*model.App, error) { + m.calledWith = appDomain + m.callCount++ + if m.getLabelsFn != nil { + return m.getLabelsFn(appDomain) + } + return nil, nil +} + +func TestLookupStaticACLs(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + tests := []struct { + name string + apps map[string]model.App + domain string + expectNil bool + expectedDomain string + }{ + { + name: "returns nil when no apps are configured", + apps: nil, + domain: "foo.example.com", + expectNil: true, + }, + { + name: "returns nil when no app matches", + apps: map[string]model.App{ + "foo": {Config: model.AppConfig{Domain: "foo.example.com"}}, + }, + domain: "bar.example.com", + expectNil: true, + }, + { + name: "matches by exact domain", + apps: map[string]model.App{ + "foo": {Config: model.AppConfig{Domain: "foo.example.com"}}, + }, + domain: "foo.example.com", + expectedDomain: "foo.example.com", + }, + { + name: "matches by app name when domain does not match any app", + apps: map[string]model.App{ + "foo": {Config: model.AppConfig{Domain: "configured.example.com"}}, + }, + domain: "foo.example.com", + expectedDomain: "configured.example.com", + }, + { + name: "matches by app name for nested subdomains", + apps: map[string]model.App{ + "foo": {Config: model.AppConfig{Domain: "configured.example.com"}}, + }, + domain: "foo.sub.example.com", + expectedDomain: "configured.example.com", + }, + { + name: "selects the app matching by domain among multiple apps", + apps: map[string]model.App{ + "unrelated": {Config: model.AppConfig{Domain: "other.example.com"}}, + "target": {Config: model.AppConfig{Domain: "foo.example.com"}}, + }, + domain: "foo.example.com", + expectedDomain: "foo.example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := NewAccessControlsService(log, model.Config{Apps: tt.apps}, nil) + got := svc.lookupStaticACLs(tt.domain) + if tt.expectNil { + assert.Nil(t, got) + return + } + require.NotNil(t, got) + assert.Equal(t, tt.expectedDomain, got.Config.Domain) + }) + } +} + +func TestGetAccessControls(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + t.Run("returns static ACLs when domain matches", func(t *testing.T) { + config := model.Config{ + Apps: map[string]model.App{ + "foo": { + Config: model.AppConfig{Domain: "foo.example.com"}, + Users: model.AppUsers{Allow: "alice"}, + }, + }, + } + svc := NewAccessControlsService(log, config, nil) + + got, err := svc.GetAccessControls("foo.example.com") + + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, "foo.example.com", got.Config.Domain) + assert.Equal(t, "alice", got.Users.Allow) + }) + + t.Run("returns nil when no static match and no label provider", func(t *testing.T) { + svc := NewAccessControlsService(log, model.Config{}, nil) + + got, err := svc.GetAccessControls("unknown.example.com") + + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("returns nil when label provider pointer wraps a nil interface", func(t *testing.T) { + var provider LabelProvider + svc := NewAccessControlsService(log, model.Config{}, &provider) + + got, err := svc.GetAccessControls("unknown.example.com") + + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("falls back to label provider when no static match", func(t *testing.T) { + expected := &model.App{ + Config: model.AppConfig{Domain: "dynamic.example.com"}, + Users: model.AppUsers{Allow: "bob"}, + } + mock := &mockLabelProvider{ + getLabelsFn: func(appDomain string) (*model.App, error) { + return expected, nil + }, + } + var provider LabelProvider = mock + svc := NewAccessControlsService(log, model.Config{}, &provider) + + got, err := svc.GetAccessControls("dynamic.example.com") + + require.NoError(t, err) + assert.Same(t, expected, got) + assert.Equal(t, "dynamic.example.com", mock.calledWith) + assert.Equal(t, 1, mock.callCount) + }) + + t.Run("does not call label provider when static match found", func(t *testing.T) { + mock := &mockLabelProvider{} + var provider LabelProvider = mock + config := model.Config{ + Apps: map[string]model.App{ + "foo": {Config: model.AppConfig{Domain: "foo.example.com"}}, + }, + } + svc := NewAccessControlsService(log, config, &provider) + + got, err := svc.GetAccessControls("foo.example.com") + + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, "foo.example.com", got.Config.Domain) + assert.Equal(t, 0, mock.callCount) + }) + + t.Run("propagates label provider errors", func(t *testing.T) { + providerErr := errors.New("provider boom") + mock := &mockLabelProvider{ + getLabelsFn: func(appDomain string) (*model.App, error) { + return nil, providerErr + }, + } + var provider LabelProvider = mock + svc := NewAccessControlsService(log, model.Config{}, &provider) + + got, err := svc.GetAccessControls("dynamic.example.com") + + assert.Nil(t, got) + assert.ErrorIs(t, err, providerErr) + assert.Equal(t, 1, mock.callCount) + }) +} diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 4cb3af81..ba748fe1 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/http" - "regexp" "strings" "sync" "time" @@ -17,7 +16,6 @@ import ( "slices" - "github.com/gin-gonic/gin" "github.com/google/uuid" "golang.org/x/crypto/bcrypt" "golang.org/x/oauth2" @@ -285,7 +283,12 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { } func (auth *AuthService) IsEmailWhitelisted(email string) bool { - return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email) + match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email) + if err != nil { + auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern") + return false + } + return match } func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { @@ -453,171 +456,6 @@ func (auth *AuthService) LDAPAuthConfigured() bool { return auth.ldap != nil } -func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool { - if acls == nil { - return true - } - - if context.Provider == model.ProviderOAuth { - auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist") - return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email) - } - - if acls.Users.Block != "" { - auth.log.App.Debug().Msg("Checking users block list") - if utils.CheckFilter(acls.Users.Block, context.GetUsername()) { - return false - } - } - - auth.log.App.Debug().Msg("Checking users allow list") - return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) -} - -func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool { - if acls == nil { - return true - } - - if !context.IsOAuth() { - auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") - return false - } - - if _, ok := model.OverrideProviders[context.OAuth.ID]; ok { - auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check") - return true - } - - for _, userGroup := range context.OAuth.Groups { - if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) { - auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") - return true - } - } - - auth.log.App.Debug().Msg("No groups matched") - return false -} - -func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool { - if acls == nil { - return true - } - - if !context.IsLDAP() { - auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") - return false - } - - for _, userGroup := range context.LDAP.Groups { - if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) { - auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") - return true - } - } - - auth.log.App.Debug().Msg("No groups matched") - return false -} - -func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) { - if acls == nil { - return true, nil - } - - // Check for block list - if acls.Path.Block != "" { - regex, err := regexp.Compile(acls.Path.Block) - - if err != nil { - return true, err - } - - if !regex.MatchString(uri) { - return false, nil - } - } - - // Check for allow list - if acls.Path.Allow != "" { - regex, err := regexp.Compile(acls.Path.Allow) - - if err != nil { - return true, err - } - - if regex.MatchString(uri) { - return false, nil - } - } - - return true, nil -} - -func (auth *AuthService) CheckIP(ip string, acls *model.App) bool { - if acls == nil { - return true - } - - // Merge the global and app IP filter - blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...) - allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...) - - for _, blocked := range blockedIps { - res, err := utils.FilterIP(blocked, ip) - if err != nil { - auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") - continue - } - if res { - auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access") - return false - } - } - - for _, allowed := range allowedIPs { - res, err := utils.FilterIP(allowed, ip) - if err != nil { - auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") - continue - } - if res { - auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access") - return true - } - } - - if len(allowedIPs) > 0 { - auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") - return false - } - - auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default") - return true -} - -func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool { - if acls == nil { - return false - } - - for _, bypassed := range acls.IP.Bypass { - res, err := utils.FilterIP(bypassed, ip) - if err != nil { - auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") - continue - } - if res { - auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication") - return true - } - } - - auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication") - return false -} - func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) { auth.ensureOAuthSessionLimit() diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index 9d077c53..1413f04e 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -85,17 +85,23 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { return nil, err } + var nameMatch *model.App + + // First try to find a matching app by domain, then fallback to matching by app name (subdomain) for appName, appLabels := range labels.Apps { if appLabels.Config.Domain == appDomain { docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") return &appLabels, nil } - if strings.SplitN(appDomain, ".", 2)[0] == appName { docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") - return &appLabels, nil + nameMatch = &appLabels } } + + if nameMatch != nil { + return nameMatch, nil + } } docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain") diff --git a/internal/service/policy_engine.go b/internal/service/policy_engine.go new file mode 100644 index 00000000..4250d8a0 --- /dev/null +++ b/internal/service/policy_engine.go @@ -0,0 +1,110 @@ +package service + +import ( + "fmt" + "net" + + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +type Policy string + +const ( + PolicyAllow Policy = "allow" + PolicyDeny Policy = "deny" +) + +type Effect int + +const ( + EffectAbstain Effect = iota + EffectAllow + EffectDeny +) + +type Rule interface { + Evaluate(ctx *ACLContext) Effect +} + +type ACLContext struct { + ACLs *model.App + UserContext *model.UserContext + IP net.IP + Path string +} + +type PolicyEngine struct { + log *logger.Logger + rules map[RuleName]Rule + policy Policy +} + +func NewPolicyEngine(config model.Config, log *logger.Logger) (*PolicyEngine, error) { + engine := PolicyEngine{ + log: log, + rules: make(map[RuleName]Rule), + } + + switch config.Auth.ACLs.Policy { + case string(PolicyAllow): + log.App.Debug().Msg("Using 'allow' ACL policy: access to apps will be allowed by default unless explicitly blocked") + engine.policy = PolicyAllow + case string(PolicyDeny): + log.App.Debug().Msg("Using 'deny' ACL policy: access to apps will be blocked by default unless explicitly allowed") + engine.policy = PolicyDeny + default: + return nil, fmt.Errorf("invalid acl policy: %s", config.Auth.ACLs.Policy) + } + + return &engine, nil +} + +func (engine *PolicyEngine) RegisterRule(name RuleName, rule Rule) { + engine.log.App.Debug().Str("rule", string(name)).Msg("Registering ACL rule in policy engine") + engine.rules[name] = rule +} + +func (engine *PolicyEngine) evaluateRuleByName(name RuleName, ctx *ACLContext) Effect { + rule, exists := engine.rules[name] + + if !exists { + engine.log.App.Warn().Str("rule", string(name)).Msg("Rule not found in policy engine, defaulting to deny") + return EffectDeny + } + + return rule.Evaluate(ctx) +} + +func (engine *PolicyEngine) effectToAccess(effect Effect) bool { + switch effect { + case EffectAllow: + return true + case EffectDeny: + return false + default: + // If the effect is abstain, we fall back to the default policy + return engine.policy == PolicyAllow + } +} + +func (engine *PolicyEngine) Evaluate(name RuleName, ctx *ACLContext) bool { + effect := engine.evaluateRuleByName(name, ctx) + access := engine.effectToAccess(effect) + + engine.log.App.Debug(). + Str("rule", string(name)). + Int("effect", int(effect)). + Bool("access", access). + Msg("Evaluated ACL rule") + + return access +} + +func (engine *PolicyEngine) Policy() Policy { + return engine.policy +} + +func (engine *PolicyEngine) Rules() map[RuleName]Rule { + return engine.rules +} diff --git a/internal/service/policy_engine_test.go b/internal/service/policy_engine_test.go new file mode 100644 index 00000000..d1ef4796 --- /dev/null +++ b/internal/service/policy_engine_test.go @@ -0,0 +1,93 @@ +package service_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/test" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +// Create test rule +type TestRule struct{} + +func (rule *TestRule) Evaluate(ctx *service.ACLContext) service.Effect { + switch ctx.Path { + case "/allowed": + return service.EffectAllow + case "/denied": + return service.EffectDeny + default: + return service.EffectAbstain + } +} + +func TestPolicyEngine(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + cfg, _ := test.CreateTestConfigs(t) + + testRule := &TestRule{} + + // Engine should fail with invalid policy + cfg.Auth.ACLs.Policy = "invalid_policy" + _, err := service.NewPolicyEngine(cfg, log) + assert.Error(t, err) + + // Engine should initialize with 'allow' policy + cfg.Auth.ACLs.Policy = string(service.PolicyAllow) + engine, err := service.NewPolicyEngine(cfg, log) + assert.NoError(t, err) + assert.Equal(t, service.PolicyAllow, engine.Policy()) + + // Engine should initialize with 'deny' policy + cfg.Auth.ACLs.Policy = string(service.PolicyDeny) + engine, err = service.NewPolicyEngine(cfg, log) + assert.NoError(t, err) + assert.Equal(t, service.PolicyDeny, engine.Policy()) + + // Engine should allow adding rules + engine, err = service.NewPolicyEngine(cfg, log) + assert.NoError(t, err) + engine.RegisterRule("test-rule", testRule) + _, ok := engine.Rules()["test-rule"] + assert.True(t, ok) + + // Begin allow policy tests + cfg.Auth.ACLs.Policy = string(service.PolicyAllow) + engine, err = service.NewPolicyEngine(cfg, log) + assert.NoError(t, err) + engine.RegisterRule("test-rule", testRule) + + // With allow policy, if rule allows, access should be allowed + ctx := &service.ACLContext{Path: "/allowed"} + assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) + + // With allow policy, if rule denies, access should be denied + ctx.Path = "/denied" + assert.Equal(t, false, engine.Evaluate("test-rule", ctx)) + + // With allow policy, if rule abstains, access should be allowed (default) + ctx.Path = "/abstain" + assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) + + // Begin deny policy tests + cfg.Auth.ACLs.Policy = string(service.PolicyDeny) + engine, err = service.NewPolicyEngine(cfg, log) + assert.NoError(t, err) + engine.RegisterRule("test-rule", testRule) + + // With deny policy, if rule allows, access should be allowed + ctx.Path = "/allowed" + assert.Equal(t, true, engine.Evaluate("test-rule", ctx)) + + // With deny policy, if rule denies, access should be denied + ctx.Path = "/denied" + assert.Equal(t, false, engine.Evaluate("test-rule", ctx)) + + // With deny policy, if rule abstains, access should be denied (default) + ctx.Path = "/abstain" + assert.Equal(t, false, engine.Evaluate("test-rule", ctx)) +} diff --git a/internal/test/test.go b/internal/test/test.go index 73ff5d38..415591fa 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -40,6 +40,9 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { SessionExpiry: 10, LoginTimeout: 10, LoginMaxRetries: 3, + ACLs: model.ACLsConfig{ + Policy: "allow", + }, }, Database: model.DatabaseConfig{ Path: filepath.Join(tempDir, "test.db"), @@ -48,6 +51,32 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { Enabled: true, Path: filepath.Join(tempDir, "resources"), }, + Apps: map[string]model.App{ + "app_path_allow": { + Config: model.AppConfig{ + Domain: "path-allow.example.com", + }, + Path: model.AppPath{ + Allow: "/allowed", + }, + }, + "app_user_allow": { + Config: model.AppConfig{ + Domain: "user-allow.example.com", + }, + Users: model.AppUsers{ + Allow: "testuser", + }, + }, + "ip_bypass": { + Config: model.AppConfig{ + Domain: "ip-bypass.example.com", + }, + IP: model.AppIP{ + Bypass: []string{"10.10.10.10"}, + }, + }, + }, } passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) diff --git a/internal/utils/security_utils.go b/internal/utils/security_utils.go index abfdbfe8..8e8dd23b 100644 --- a/internal/utils/security_utils.go +++ b/internal/utils/security_utils.go @@ -3,7 +3,7 @@ package utils import ( "crypto/rand" "encoding/base64" - "errors" + "fmt" "net" "regexp" "strings" @@ -46,26 +46,27 @@ func EncodeBasicAuth(username string, password string) string { return base64.StdEncoding.EncodeToString([]byte(auth)) } -func FilterIP(filter string, ip string) (bool, error) { +func CheckIPFilter(filter string, ip string) (bool, error) { ipAddr := net.ParseIP(ip) if ipAddr == nil { - return false, errors.New("invalid IP address") + return false, fmt.Errorf("invalid ip address") } - filter = strings.Replace(filter, "-", "/", -1) + filter = strings.ReplaceAll(filter, "-", "/") if strings.Contains(filter, "/") { _, cidr, err := net.ParseCIDR(filter) if err != nil { - return false, err + return false, fmt.Errorf("invalid cidr notation: %w", err) } return cidr.Contains(ipAddr), nil } ipFilter := net.ParseIP(filter) + if ipFilter == nil { - return false, errors.New("invalid IP address in filter") + return false, fmt.Errorf("invalid ip address") } if ipFilter.Equal(ipAddr) { @@ -75,31 +76,29 @@ func FilterIP(filter string, ip string) (bool, error) { return false, nil } -func CheckFilter(filter string, str string) bool { +func CheckFilter(filter string, input string) (bool, error) { if len(strings.TrimSpace(filter)) == 0 { - return true + return false, fmt.Errorf("filter is empty") } if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") { re, err := regexp.Compile(filter[1 : len(filter)-1]) if err != nil { - return false + return false, fmt.Errorf("invalid regex filter: %w", err) } - if re.MatchString(strings.TrimSpace(str)) { - return true + if re.MatchString(input) { + return true, nil } } - filterSplit := strings.Split(filter, ",") - - for _, item := range filterSplit { - if strings.TrimSpace(item) == strings.TrimSpace(str) { - return true + for item := range strings.SplitSeq(filter, ",") { + if strings.TrimSpace(item) == input { + return true, nil } } - return false + return false, nil } func GenerateUUID(str string) string { diff --git a/internal/utils/security_utils_test.go b/internal/utils/security_utils_test.go index 6feac4ca..193fbca3 100644 --- a/internal/utils/security_utils_test.go +++ b/internal/utils/security_utils_test.go @@ -75,66 +75,77 @@ func TestEncodeBasicAuth(t *testing.T) { assert.Equal(t, expected, utils.EncodeBasicAuth(username, password)) } -func TestFilterIP(t *testing.T) { +func TestCheckIPFilter(t *testing.T) { // Exact match IPv4 - ok, err := utils.FilterIP("10.10.0.1", "10.10.0.1") + ok, err := utils.CheckIPFilter("10.10.0.1", "10.10.0.1") assert.NoError(t, err) assert.Equal(t, true, ok) // Non-match IPv4 - ok, err = utils.FilterIP("10.10.0.1", "10.10.0.2") + ok, err = utils.CheckIPFilter("10.10.0.1", "10.10.0.2") assert.NoError(t, err) assert.Equal(t, false, ok) // CIDR match IPv4 - ok, err = utils.FilterIP("10.10.0.0/24", "10.10.0.2") + ok, err = utils.CheckIPFilter("10.10.0.0/24", "10.10.0.2") assert.NoError(t, err) assert.Equal(t, true, ok) // CIDR match IPv4 with '-' instead of '/' - ok, err = utils.FilterIP("10.10.10.0-24", "10.10.10.5") + ok, err = utils.CheckIPFilter("10.10.10.0-24", "10.10.10.5") assert.NoError(t, err) assert.Equal(t, true, ok) // CIDR non-match IPv4 - ok, err = utils.FilterIP("10.10.0.0/24", "10.5.0.1") + ok, err = utils.CheckIPFilter("10.10.0.0/24", "10.5.0.1") assert.NoError(t, err) assert.Equal(t, false, ok) // Invalid CIDR - ok, err = utils.FilterIP("10.10.0.0/222", "10.0.0.1") - assert.ErrorContains(t, err, "invalid CIDR address") + ok, err = utils.CheckIPFilter("10.10.0.0/222", "10.0.0.1") + assert.ErrorContains(t, err, "invalid cidr notation: invalid CIDR address: 10.10.0.0/222") assert.Equal(t, false, ok) // Invalid IP in filter - ok, err = utils.FilterIP("invalid_ip", "10.5.5.5") - assert.ErrorContains(t, err, "invalid IP address in filter") + ok, err = utils.CheckIPFilter("invalid_ip", "10.5.5.5") + assert.ErrorContains(t, err, "invalid ip address") assert.Equal(t, false, ok) // Invalid IP to check - ok, err = utils.FilterIP("10.10.10.10", "invalid_ip") - assert.ErrorContains(t, err, "invalid IP address") + ok, err = utils.CheckIPFilter("10.10.10.10", "invalid_ip") + assert.ErrorContains(t, err, "invalid ip address") assert.Equal(t, false, ok) } func TestCheckFilter(t *testing.T) { // Empty filter - assert.Equal(t, true, utils.CheckFilter("", "anystring")) + _, err := utils.CheckFilter("", "anystring") + assert.ErrorContains(t, err, "filter is empty") // Exact match - assert.Equal(t, true, utils.CheckFilter("hello", "hello")) + ok, err := utils.CheckFilter("hello", "hello") + assert.NoError(t, err) + assert.Equal(t, true, ok) // Regex match - assert.Equal(t, true, utils.CheckFilter("/^h.*o$/", "hello")) + ok, err = utils.CheckFilter("/^h.*o$/", "hello") + assert.NoError(t, err) + assert.Equal(t, true, ok) // Invalid regex - assert.Equal(t, false, utils.CheckFilter("/[unclosed", "test")) + ok, err = utils.CheckFilter("/[unclosed/", "test") + assert.ErrorContains(t, err, "invalid regex") + assert.Equal(t, false, ok) // Comma-separated values - assert.Equal(t, true, utils.CheckFilter("apple, banana, cherry", "banana")) + ok, err = utils.CheckFilter("apple, banana, cherry", "banana") + assert.NoError(t, err) + assert.Equal(t, true, ok) // No match - assert.Equal(t, false, utils.CheckFilter("apple, banana, cherry", "grape")) + ok, err = utils.CheckFilter("apple, banana, cherry", "grape") + assert.NoError(t, err) + assert.Equal(t, false, ok) } func TestGenerateUUID(t *testing.T) {