From 0a3e7bf265ccb4b76a407248ef295d60cab9f537 Mon Sep 17 00:00:00 2001 From: Stavros Date: Tue, 26 May 2026 00:07:46 +0300 Subject: [PATCH] fix: use policy engine in oauth whitelist check (#904) --- internal/bootstrap/service_bootstrap.go | 2 +- internal/controller/proxy_controller_test.go | 3 +- internal/controller/user_controller_test.go | 5 ++- .../middleware/context_middleware_test.go | 5 ++- internal/service/auth_service.go | 42 ++++++++++++------- internal/service/policy_engine.go | 4 ++ internal/utils/security_utils.go | 7 +++- 7 files changed, 48 insertions(+), 20 deletions(-) diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 7474ec27..151cdeb8 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -42,7 +42,7 @@ func (app *BootstrapApp) setupServices() error { oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) app.services.oauthBrokerService = oauthBrokerService - authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService) + authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine) app.services.authService = authService oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg) diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index ef004f8f..cc2226a1 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -357,7 +357,6 @@ func TestProxyController(t *testing.T) { ctx := context.TODO() broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) - authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil) aclsService := service.NewAccessControlsService(log, cfg, nil) policyEngine, err := service.NewPolicyEngine(cfg, log) @@ -383,6 +382,8 @@ func TestProxyController(t *testing.T) { Log: log, }) + authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine) + for _, test := range tests { t.Run(test.description, func(t *testing.T) { router := gin.Default() diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 39b343c0..7527d752 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -414,8 +414,11 @@ func TestUserController(t *testing.T) { ctx := context.TODO() wg := &sync.WaitGroup{} + policyEngine, err := service.NewPolicyEngine(cfg, log) + require.NoError(t, err) + broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) - authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil) + authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine) beforeEach := func() { // Clear failed login attempts before each test diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index b672684f..b31231fa 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -254,8 +254,11 @@ func TestContextMiddleware(t *testing.T) { store := memory.New() + policyEngine, err := service.NewPolicyEngine(cfg, log) + require.NoError(t, err) + broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) - authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil) + authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine) contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil) diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 5af7aa87..387eb7f0 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -75,10 +75,11 @@ type AuthService struct { runtime model.RuntimeConfig context context.Context - ldap *LdapService - queries repository.Store - oauthBroker *OAuthBrokerService - tailscale *TailscaleService + ldap *LdapService + queries repository.Store + oauthBroker *OAuthBrokerService + tailscale *TailscaleService + policyEngine *PolicyEngine loginAttempts map[string]*LoginAttempt ldapGroupsCache map[string]*LdapGroupsCache @@ -101,6 +102,7 @@ func NewAuthService( queries repository.Store, oauthBroker *OAuthBrokerService, tailscale *TailscaleService, + policy *PolicyEngine, ) *AuthService { service := &AuthService{ log: log, @@ -114,6 +116,7 @@ func NewAuthService( queries: queries, oauthBroker: oauthBroker, tailscale: tailscale, + policyEngine: policy, } wg.Go(service.CleanupOAuthSessionsRoutine) @@ -285,18 +288,27 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { } } +// We could also directly access the policyEngine.effectToAccess but +// I believe it's better to use the exported functions instead func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool { - whitelist := auth.runtime.OAuthWhitelist - if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 { - whitelist = providerConfig.Whitelist - } - - match, err := utils.CheckFilter(strings.Join(whitelist, ","), email) - if err != nil { - auth.log.App.Warn().Err(err).Str("provider", provider).Str("email", email).Msg("Invalid email filter pattern") - return false - } - return match + return auth.policyEngine.EvaluateFunc(func() Effect { + whitelist := auth.runtime.OAuthWhitelist + if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 { + whitelist = providerConfig.Whitelist + } + match, err := utils.CheckFilter(strings.Join(whitelist, ","), email) + if err != nil { + if err == utils.ErrFilterEmpty { + return EffectAbstain + } + auth.log.App.Error().Err(err).Str("email", email).Msg("Failed to evaluate email whitelist filter, defaulting to deny") + return EffectDeny + } + if match { + return EffectAllow + } + return EffectDeny + }) } func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { diff --git a/internal/service/policy_engine.go b/internal/service/policy_engine.go index 4250d8a0..7f301da6 100644 --- a/internal/service/policy_engine.go +++ b/internal/service/policy_engine.go @@ -108,3 +108,7 @@ func (engine *PolicyEngine) Policy() Policy { func (engine *PolicyEngine) Rules() map[RuleName]Rule { return engine.rules } + +func (engine *PolicyEngine) EvaluateFunc(f func() Effect) bool { + return engine.effectToAccess(f()) +} diff --git a/internal/utils/security_utils.go b/internal/utils/security_utils.go index 8e8dd23b..71b59d41 100644 --- a/internal/utils/security_utils.go +++ b/internal/utils/security_utils.go @@ -3,6 +3,7 @@ package utils import ( "crypto/rand" "encoding/base64" + "errors" "fmt" "net" "regexp" @@ -11,6 +12,10 @@ import ( "github.com/google/uuid" ) +var ( + ErrFilterEmpty = errors.New("filter is empty") +) + func GetSecret(conf string, file string) string { if conf == "" && file == "" { return "" @@ -78,7 +83,7 @@ func CheckIPFilter(filter string, ip string) (bool, error) { func CheckFilter(filter string, input string) (bool, error) { if len(strings.TrimSpace(filter)) == 0 { - return false, fmt.Errorf("filter is empty") + return false, ErrFilterEmpty } if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {