fix: use policy engine in oauth whitelist check (#904)

This commit is contained in:
Stavros
2026-05-26 00:07:46 +03:00
committed by GitHub
parent c3461131f5
commit 0a3e7bf265
7 changed files with 48 additions and 20 deletions
+1 -1
View File
@@ -42,7 +42,7 @@ func (app *BootstrapApp) setupServices() error {
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService)
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
app.services.authService = authService
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
+2 -1
View File
@@ -357,7 +357,6 @@ func TestProxyController(t *testing.T) {
ctx := context.TODO()
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
aclsService := service.NewAccessControlsService(log, cfg, nil)
policyEngine, err := service.NewPolicyEngine(cfg, log)
@@ -383,6 +382,8 @@ func TestProxyController(t *testing.T) {
Log: log,
})
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
router := gin.Default()
+4 -1
View File
@@ -414,8 +414,11 @@ func TestUserController(t *testing.T) {
ctx := context.TODO()
wg := &sync.WaitGroup{}
policyEngine, err := service.NewPolicyEngine(cfg, log)
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
beforeEach := func() {
// Clear failed login attempts before each test
@@ -254,8 +254,11 @@ func TestContextMiddleware(t *testing.T) {
store := memory.New()
policyEngine, err := service.NewPolicyEngine(cfg, log)
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
+27 -15
View File
@@ -75,10 +75,11 @@ type AuthService struct {
runtime model.RuntimeConfig
context context.Context
ldap *LdapService
queries repository.Store
oauthBroker *OAuthBrokerService
tailscale *TailscaleService
ldap *LdapService
queries repository.Store
oauthBroker *OAuthBrokerService
tailscale *TailscaleService
policyEngine *PolicyEngine
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
@@ -101,6 +102,7 @@ func NewAuthService(
queries repository.Store,
oauthBroker *OAuthBrokerService,
tailscale *TailscaleService,
policy *PolicyEngine,
) *AuthService {
service := &AuthService{
log: log,
@@ -114,6 +116,7 @@ func NewAuthService(
queries: queries,
oauthBroker: oauthBroker,
tailscale: tailscale,
policyEngine: policy,
}
wg.Go(service.CleanupOAuthSessionsRoutine)
@@ -285,18 +288,27 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
}
}
// We could also directly access the policyEngine.effectToAccess but
// I believe it's better to use the exported functions instead
func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool {
whitelist := auth.runtime.OAuthWhitelist
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
whitelist = providerConfig.Whitelist
}
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
if err != nil {
auth.log.App.Warn().Err(err).Str("provider", provider).Str("email", email).Msg("Invalid email filter pattern")
return false
}
return match
return auth.policyEngine.EvaluateFunc(func() Effect {
whitelist := auth.runtime.OAuthWhitelist
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
whitelist = providerConfig.Whitelist
}
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
if err != nil {
if err == utils.ErrFilterEmpty {
return EffectAbstain
}
auth.log.App.Error().Err(err).Str("email", email).Msg("Failed to evaluate email whitelist filter, defaulting to deny")
return EffectDeny
}
if match {
return EffectAllow
}
return EffectDeny
})
}
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
+4
View File
@@ -108,3 +108,7 @@ func (engine *PolicyEngine) Policy() Policy {
func (engine *PolicyEngine) Rules() map[RuleName]Rule {
return engine.rules
}
func (engine *PolicyEngine) EvaluateFunc(f func() Effect) bool {
return engine.effectToAccess(f())
}
+6 -1
View File
@@ -3,6 +3,7 @@ package utils
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net"
"regexp"
@@ -11,6 +12,10 @@ import (
"github.com/google/uuid"
)
var (
ErrFilterEmpty = errors.New("filter is empty")
)
func GetSecret(conf string, file string) string {
if conf == "" && file == "" {
return ""
@@ -78,7 +83,7 @@ func CheckIPFilter(filter string, ip string) (bool, error) {
func CheckFilter(filter string, input string) (bool, error) {
if len(strings.TrimSpace(filter)) == 0 {
return false, fmt.Errorf("filter is empty")
return false, ErrFilterEmpty
}
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {