mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-25 21:50:16 +00:00
fix: use policy engine in oauth whitelist check (#904)
This commit is contained in:
@@ -42,7 +42,7 @@ func (app *BootstrapApp) setupServices() error {
|
||||
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
||||
app.services.oauthBrokerService = oauthBrokerService
|
||||
|
||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService)
|
||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
|
||||
app.services.authService = authService
|
||||
|
||||
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
|
||||
|
||||
@@ -357,7 +357,6 @@ func TestProxyController(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
|
||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
|
||||
aclsService := service.NewAccessControlsService(log, cfg, nil)
|
||||
|
||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||
@@ -383,6 +382,8 @@ func TestProxyController(t *testing.T) {
|
||||
Log: log,
|
||||
})
|
||||
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
router := gin.Default()
|
||||
|
||||
@@ -414,8 +414,11 @@ func TestUserController(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
|
||||
|
||||
beforeEach := func() {
|
||||
// Clear failed login attempts before each test
|
||||
|
||||
@@ -254,8 +254,11 @@ func TestContextMiddleware(t *testing.T) {
|
||||
|
||||
store := memory.New()
|
||||
|
||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||
require.NoError(t, err)
|
||||
|
||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
|
||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
|
||||
|
||||
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
||||
|
||||
|
||||
@@ -75,10 +75,11 @@ type AuthService struct {
|
||||
runtime model.RuntimeConfig
|
||||
context context.Context
|
||||
|
||||
ldap *LdapService
|
||||
queries repository.Store
|
||||
oauthBroker *OAuthBrokerService
|
||||
tailscale *TailscaleService
|
||||
ldap *LdapService
|
||||
queries repository.Store
|
||||
oauthBroker *OAuthBrokerService
|
||||
tailscale *TailscaleService
|
||||
policyEngine *PolicyEngine
|
||||
|
||||
loginAttempts map[string]*LoginAttempt
|
||||
ldapGroupsCache map[string]*LdapGroupsCache
|
||||
@@ -101,6 +102,7 @@ func NewAuthService(
|
||||
queries repository.Store,
|
||||
oauthBroker *OAuthBrokerService,
|
||||
tailscale *TailscaleService,
|
||||
policy *PolicyEngine,
|
||||
) *AuthService {
|
||||
service := &AuthService{
|
||||
log: log,
|
||||
@@ -114,6 +116,7 @@ func NewAuthService(
|
||||
queries: queries,
|
||||
oauthBroker: oauthBroker,
|
||||
tailscale: tailscale,
|
||||
policyEngine: policy,
|
||||
}
|
||||
|
||||
wg.Go(service.CleanupOAuthSessionsRoutine)
|
||||
@@ -285,18 +288,27 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// We could also directly access the policyEngine.effectToAccess but
|
||||
// I believe it's better to use the exported functions instead
|
||||
func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool {
|
||||
whitelist := auth.runtime.OAuthWhitelist
|
||||
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
|
||||
whitelist = providerConfig.Whitelist
|
||||
}
|
||||
|
||||
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Str("provider", provider).Str("email", email).Msg("Invalid email filter pattern")
|
||||
return false
|
||||
}
|
||||
return match
|
||||
return auth.policyEngine.EvaluateFunc(func() Effect {
|
||||
whitelist := auth.runtime.OAuthWhitelist
|
||||
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
|
||||
whitelist = providerConfig.Whitelist
|
||||
}
|
||||
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
|
||||
if err != nil {
|
||||
if err == utils.ErrFilterEmpty {
|
||||
return EffectAbstain
|
||||
}
|
||||
auth.log.App.Error().Err(err).Str("email", email).Msg("Failed to evaluate email whitelist filter, defaulting to deny")
|
||||
return EffectDeny
|
||||
}
|
||||
if match {
|
||||
return EffectAllow
|
||||
}
|
||||
return EffectDeny
|
||||
})
|
||||
}
|
||||
|
||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||
|
||||
@@ -108,3 +108,7 @@ func (engine *PolicyEngine) Policy() Policy {
|
||||
func (engine *PolicyEngine) Rules() map[RuleName]Rule {
|
||||
return engine.rules
|
||||
}
|
||||
|
||||
func (engine *PolicyEngine) EvaluateFunc(f func() Effect) bool {
|
||||
return engine.effectToAccess(f())
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package utils
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
@@ -11,6 +12,10 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrFilterEmpty = errors.New("filter is empty")
|
||||
)
|
||||
|
||||
func GetSecret(conf string, file string) string {
|
||||
if conf == "" && file == "" {
|
||||
return ""
|
||||
@@ -78,7 +83,7 @@ func CheckIPFilter(filter string, ip string) (bool, error) {
|
||||
|
||||
func CheckFilter(filter string, input string) (bool, error) {
|
||||
if len(strings.TrimSpace(filter)) == 0 {
|
||||
return false, fmt.Errorf("filter is empty")
|
||||
return false, ErrFilterEmpty
|
||||
}
|
||||
|
||||
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
|
||||
|
||||
Reference in New Issue
Block a user