mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-26 14:10:13 +00:00
fix: use policy engine in oauth whitelist check (#904)
This commit is contained in:
@@ -42,7 +42,7 @@ func (app *BootstrapApp) setupServices() error {
|
|||||||
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
|
||||||
app.services.oauthBrokerService = oauthBrokerService
|
app.services.oauthBrokerService = oauthBrokerService
|
||||||
|
|
||||||
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService)
|
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
|
||||||
app.services.authService = authService
|
app.services.authService = authService
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
|
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg)
|
||||||
|
|||||||
@@ -357,7 +357,6 @@ func TestProxyController(t *testing.T) {
|
|||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
|
|
||||||
aclsService := service.NewAccessControlsService(log, cfg, nil)
|
aclsService := service.NewAccessControlsService(log, cfg, nil)
|
||||||
|
|
||||||
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||||
@@ -383,6 +382,8 @@ func TestProxyController(t *testing.T) {
|
|||||||
Log: log,
|
Log: log,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
router := gin.Default()
|
router := gin.Default()
|
||||||
|
|||||||
@@ -414,8 +414,11 @@ func TestUserController(t *testing.T) {
|
|||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
|
|
||||||
|
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
|
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
|
||||||
|
|
||||||
beforeEach := func() {
|
beforeEach := func() {
|
||||||
// Clear failed login attempts before each test
|
// Clear failed login attempts before each test
|
||||||
|
|||||||
@@ -254,8 +254,11 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
store := memory.New()
|
store := memory.New()
|
||||||
|
|
||||||
|
policyEngine, err := service.NewPolicyEngine(cfg, log)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx)
|
||||||
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil)
|
authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, store, broker, nil, policyEngine)
|
||||||
|
|
||||||
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker, nil)
|
||||||
|
|
||||||
|
|||||||
@@ -75,10 +75,11 @@ type AuthService struct {
|
|||||||
runtime model.RuntimeConfig
|
runtime model.RuntimeConfig
|
||||||
context context.Context
|
context context.Context
|
||||||
|
|
||||||
ldap *LdapService
|
ldap *LdapService
|
||||||
queries repository.Store
|
queries repository.Store
|
||||||
oauthBroker *OAuthBrokerService
|
oauthBroker *OAuthBrokerService
|
||||||
tailscale *TailscaleService
|
tailscale *TailscaleService
|
||||||
|
policyEngine *PolicyEngine
|
||||||
|
|
||||||
loginAttempts map[string]*LoginAttempt
|
loginAttempts map[string]*LoginAttempt
|
||||||
ldapGroupsCache map[string]*LdapGroupsCache
|
ldapGroupsCache map[string]*LdapGroupsCache
|
||||||
@@ -101,6 +102,7 @@ func NewAuthService(
|
|||||||
queries repository.Store,
|
queries repository.Store,
|
||||||
oauthBroker *OAuthBrokerService,
|
oauthBroker *OAuthBrokerService,
|
||||||
tailscale *TailscaleService,
|
tailscale *TailscaleService,
|
||||||
|
policy *PolicyEngine,
|
||||||
) *AuthService {
|
) *AuthService {
|
||||||
service := &AuthService{
|
service := &AuthService{
|
||||||
log: log,
|
log: log,
|
||||||
@@ -114,6 +116,7 @@ func NewAuthService(
|
|||||||
queries: queries,
|
queries: queries,
|
||||||
oauthBroker: oauthBroker,
|
oauthBroker: oauthBroker,
|
||||||
tailscale: tailscale,
|
tailscale: tailscale,
|
||||||
|
policyEngine: policy,
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Go(service.CleanupOAuthSessionsRoutine)
|
wg.Go(service.CleanupOAuthSessionsRoutine)
|
||||||
@@ -285,18 +288,27 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We could also directly access the policyEngine.effectToAccess but
|
||||||
|
// I believe it's better to use the exported functions instead
|
||||||
func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool {
|
func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool {
|
||||||
whitelist := auth.runtime.OAuthWhitelist
|
return auth.policyEngine.EvaluateFunc(func() Effect {
|
||||||
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
|
whitelist := auth.runtime.OAuthWhitelist
|
||||||
whitelist = providerConfig.Whitelist
|
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
|
||||||
}
|
whitelist = providerConfig.Whitelist
|
||||||
|
}
|
||||||
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
|
match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
auth.log.App.Warn().Err(err).Str("provider", provider).Str("email", email).Msg("Invalid email filter pattern")
|
if err == utils.ErrFilterEmpty {
|
||||||
return false
|
return EffectAbstain
|
||||||
}
|
}
|
||||||
return match
|
auth.log.App.Error().Err(err).Str("email", email).Msg("Failed to evaluate email whitelist filter, defaulting to deny")
|
||||||
|
return EffectDeny
|
||||||
|
}
|
||||||
|
if match {
|
||||||
|
return EffectAllow
|
||||||
|
}
|
||||||
|
return EffectDeny
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||||
|
|||||||
@@ -108,3 +108,7 @@ func (engine *PolicyEngine) Policy() Policy {
|
|||||||
func (engine *PolicyEngine) Rules() map[RuleName]Rule {
|
func (engine *PolicyEngine) Rules() map[RuleName]Rule {
|
||||||
return engine.rules
|
return engine.rules
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (engine *PolicyEngine) EvaluateFunc(f func() Effect) bool {
|
||||||
|
return engine.effectToAccess(f())
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package utils
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -11,6 +12,10 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrFilterEmpty = errors.New("filter is empty")
|
||||||
|
)
|
||||||
|
|
||||||
func GetSecret(conf string, file string) string {
|
func GetSecret(conf string, file string) string {
|
||||||
if conf == "" && file == "" {
|
if conf == "" && file == "" {
|
||||||
return ""
|
return ""
|
||||||
@@ -78,7 +83,7 @@ func CheckIPFilter(filter string, ip string) (bool, error) {
|
|||||||
|
|
||||||
func CheckFilter(filter string, input string) (bool, error) {
|
func CheckFilter(filter string, input string) (bool, error) {
|
||||||
if len(strings.TrimSpace(filter)) == 0 {
|
if len(strings.TrimSpace(filter)) == 0 {
|
||||||
return false, fmt.Errorf("filter is empty")
|
return false, ErrFilterEmpty
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
|
if strings.HasPrefix(filter, "/") && strings.HasSuffix(filter, "/") {
|
||||||
|
|||||||
Reference in New Issue
Block a user