diff --git a/.env.example b/.env.example index cd29653c..5fd3ae19 100644 --- a/.env.example +++ b/.env.example @@ -101,6 +101,10 @@ TINYAUTH_OAUTH_PROVIDERS_name_CLIENTID= TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRET= # Path to the file containing the OAuth client secret. TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRETFILE= +# Comma-separated list of allowed OAuth domains for this provider. +TINYAUTH_OAUTH_PROVIDERS_name_WHITELIST= +# Path to the OAuth whitelist file for this provider. +TINYAUTH_OAUTH_PROVIDERS_name_WHITELISTFILE= # OAuth scopes. TINYAUTH_OAUTH_PROVIDERS_name_SCOPES= # OAuth redirect URL. diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 92b049ef..0bdd2214 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -117,6 +117,13 @@ func (app *BootstrapApp) Setup() error { app.runtime.OAuthProviders = app.config.OAuth.Providers for id, provider := range app.runtime.OAuthProviders { + providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile) + if err != nil { + return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err) + } + + provider.Whitelist = providerWhitelist + secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile) provider.ClientSecret = secret provider.ClientSecretFile = "" diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index e72c09fd..18bed57c 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -183,9 +183,23 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - if !controller.auth.IsEmailWhitelisted(user.Email) { + svc, err := controller.auth.GetOAuthService(sessionIdCookie) + + if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) + return + } + + if svc.ID() != req.Provider { + controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID()) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) + return + } + + if !controller.auth.IsEmailWhitelisted(svc.ID(), user.Email) { controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access") - controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted") + controller.log.AuditLoginFailure(user.Email, svc.ID(), c.ClientIP(), "email not whitelisted") queries, err := query.Values(UnauthorizedQuery{ Username: user.Email, @@ -226,20 +240,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { username = strings.Replace(user.Email, "@", "_", 1) } - svc, err := controller.auth.GetOAuthService(sessionIdCookie) - - if err != nil { - controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) - return - } - - if svc.ID() != req.Provider { - controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID()) - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) - return - } - sessionCookie := repository.Session{ Username: username, Name: name, diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index a75582a7..a7223525 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -205,7 +205,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID) } - if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) { + if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) { m.auth.DeleteSession(ctx, uuid) return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) } diff --git a/internal/model/config.go b/internal/model/config.go index 5963e431..07c9a4f5 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -226,6 +226,8 @@ type OAuthServiceConfig struct { ClientID string `description:"OAuth client ID." yaml:"clientId"` ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"` ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"` + Whitelist []string `description:"Comma-separated list of allowed OAuth domains for this provider." yaml:"whitelist"` + WhitelistFile string `description:"Path to the OAuth whitelist file for this provider." yaml:"whitelistFile"` Scopes []string `description:"OAuth scopes." yaml:"scopes"` RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"` AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"` diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 76fdafbd..5af7aa87 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -285,10 +285,15 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { } } -func (auth *AuthService) IsEmailWhitelisted(email string) bool { - match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email) +func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool { + whitelist := auth.runtime.OAuthWhitelist + if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 { + whitelist = providerConfig.Whitelist + } + + match, err := utils.CheckFilter(strings.Join(whitelist, ","), email) if err != nil { - auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern") + auth.log.App.Warn().Err(err).Str("provider", provider).Str("email", email).Msg("Invalid email filter pattern") return false } return match diff --git a/internal/service/auth_service_test.go b/internal/service/auth_service_test.go new file mode 100644 index 00000000..3000adcc --- /dev/null +++ b/internal/service/auth_service_test.go @@ -0,0 +1,39 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + + auth := &AuthService{ + log: log, + runtime: model.RuntimeConfig{ + OAuthWhitelist: []string{"global@example.com"}, + OAuthProviders: map[string]model.OAuthServiceConfig{ + "github": { + Whitelist: []string{"github@example.com"}, + }, + "pocketid": { + Whitelist: []string{"pocket@example.com"}, + }, + "gitlab": { + Whitelist: []string{}, + }, + }, + }, + } + + assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com")) + assert.False(t, auth.IsEmailWhitelisted("github", "pocket@example.com")) + assert.True(t, auth.IsEmailWhitelisted("pocketid", "pocket@example.com")) + assert.True(t, auth.IsEmailWhitelisted("google", "global@example.com")) + assert.True(t, auth.IsEmailWhitelisted("gitlab", "global@example.com")) + assert.False(t, auth.IsEmailWhitelisted("gitlab", "unknown@example.com")) +}