fix: use policy engine in oauth whitelist check (#904)

This commit is contained in:
Stavros
2026-05-26 00:07:46 +03:00
committed by GitHub
parent c3461131f5
commit 0a3e7bf265
7 changed files with 48 additions and 20 deletions
+1 -1
View File
@@ -42,7 +42,7 @@ func (app *BootstrapApp) setupServices() error {
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
app.services.oauthBrokerService = oauthBrokerService app.services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService) authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
app.services.authService = authService app.services.authService = authService
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg) oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
+2 -1
View File
@@ -357,7 +357,6 @@ func TestProxyController(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
aclsService := service.NewAccessControlsService(log, cfg, nil) aclsService := service.NewAccessControlsService(log, cfg, nil)
policyEngine, err := service.NewPolicyEngine(cfg, log) policyEngine, err := service.NewPolicyEngine(cfg, log)
@@ -383,6 +382,8 @@ func TestProxyController(t *testing.T) {
Log: log, Log: log,
}) })
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
for _, test := range tests { for _, test := range tests {
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
router := gin.Default() router := gin.Default()
+4 -1
View File
@@ -414,8 +414,11 @@ func TestUserController(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
policyEngine, err := service.NewPolicyEngine(cfg, log)
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil) authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
beforeEach := func() { beforeEach := func() {
// Clear failed login attempts before each test // Clear failed login attempts before each test
@@ -254,8 +254,11 @@ func TestContextMiddleware(t *testing.T) {
store := memory.New() store := memory.New()
policyEngine, err := service.NewPolicyEngine(cfg, log)
require.NoError(t, err)
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil) authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil) contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
+27 -15
View File
@@ -75,10 +75,11 @@ type AuthService struct {
runtime model.RuntimeConfig runtime model.RuntimeConfig
context context.Context context context.Context
ldap *LdapService ldap *LdapService
queries repository.Store queries repository.Store
oauthBroker *OAuthBrokerService oauthBroker *OAuthBrokerService
tailscale *TailscaleService tailscale *TailscaleService
policyEngine *PolicyEngine
loginAttempts map[string]*LoginAttempt loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache ldapGroupsCache map[string]*LdapGroupsCache
@@ -101,6 +102,7 @@ func NewAuthService(
queries repository.Store, queries repository.Store,
oauthBroker *OAuthBrokerService, oauthBroker *OAuthBrokerService,
tailscale *TailscaleService, tailscale *TailscaleService,
policy *PolicyEngine,
) *AuthService { ) *AuthService {
service := &AuthService{ service := &AuthService{
log: log, log: log,
@@ -114,6 +116,7 @@ func NewAuthService(
queries: queries, queries: queries,
oauthBroker: oauthBroker, oauthBroker: oauthBroker,
tailscale: tailscale, tailscale: tailscale,
policyEngine: policy,
} }
wg.Go(service.CleanupOAuthSessionsRoutine) wg.Go(service.CleanupOAuthSessionsRoutine)
@@ -285,18 +288,27 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
} }
} }
// We could also directly access the policyEngine.effectToAccess but
// I believe it's better to use the exported functions instead
func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool { func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool {
whitelist := auth.runtime.OAuthWhitelist return auth.policyEngine.EvaluateFunc(func() Effect {
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 { whitelist := auth.runtime.OAuthWhitelist
whitelist = providerConfig.Whitelist if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
} whitelist = providerConfig.Whitelist
}
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email) match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
if err != nil { if err != nil {
auth.log.App.Warn().Err(err).Str("provider", provider).Str("email", email).Msg("Invalid email filter pattern") if err == utils.ErrFilterEmpty {
return false return EffectAbstain
} }
return match auth.log.App.Error().Err(err).Str("email", email).Msg("Failed to evaluate email whitelist filter, defaulting to deny")
return EffectDeny
}
if match {
return EffectAllow
}
return EffectDeny
})
} }
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
+4
View File
@@ -108,3 +108,7 @@ func (engine *PolicyEngine) Policy() Policy {
func (engine *PolicyEngine) Rules() map[RuleName]Rule { func (engine *PolicyEngine) Rules() map[RuleName]Rule {
return engine.rules return engine.rules
} }
func (engine *PolicyEngine) EvaluateFunc(f func() Effect) bool {
return engine.effectToAccess(f())
}
+6 -1
View File
@@ -3,6 +3,7 @@ package utils
import ( import (
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"net" "net"
"regexp" "regexp"
@@ -11,6 +12,10 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
var (
ErrFilterEmpty = errors.New("filter is empty")
)
func GetSecret(conf string, file string) string { func GetSecret(conf string, file string) string {
if conf == "" && file == "" { if conf == "" && file == "" {
return "" return ""
@@ -78,7 +83,7 @@ func CheckIPFilter(filter string, ip string) (bool, error) {
func CheckFilter(filter string, input string) (bool, error) { func CheckFilter(filter string, input string) (bool, error) {
if len(strings.TrimSpace(filter)) == 0 { if len(strings.TrimSpace(filter)) == 0 {
return false, fmt.Errorf("filter is empty") return false, ErrFilterEmpty
} }
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") { if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {