feat: add support for oauth whitelist file (#817)

This commit is contained in:
djedditt
2026-04-29 02:53:56 +02:00
parent d51e3efe32
commit 6b5a6bd982
7 changed files with 84 additions and 27 deletions
+2
View File
@@ -91,6 +91,8 @@ TINYAUTH_APPS_name_LDAP_GROUPS=
# Comma-separated list of allowed OAuth domains. # Comma-separated list of allowed OAuth domains.
TINYAUTH_OAUTH_WHITELIST= TINYAUTH_OAUTH_WHITELIST=
# Path to the OAuth whitelist file.
TINYAUTH_OAUTH_WHITELISTFILE=
# The OAuth provider to use for automatic redirection. # The OAuth provider to use for automatic redirection.
TINYAUTH_OAUTH_AUTOREDIRECT= TINYAUTH_OAUTH_AUTOREDIRECT=
# OAuth client ID. # OAuth client ID.
+8
View File
@@ -30,6 +30,7 @@ type BootstrapApp struct {
redirectCookieName string redirectCookieName string
oauthSessionCookieName string oauthSessionCookieName string
users []config.User users []config.User
oauthWhitelist []string
oauthProviders map[string]config.OAuthServiceConfig oauthProviders map[string]config.OAuthServiceConfig
configuredProviders []controller.Provider configuredProviders []controller.Provider
oidcClients []config.OIDCClientConfig oidcClients []config.OIDCClientConfig
@@ -71,6 +72,13 @@ func (app *BootstrapApp) Setup() error {
app.context.users = users app.context.users = users
oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile)
if err != nil {
return err
}
app.context.oauthWhitelist = oauthWhitelist
// Setup OAuth providers // Setup OAuth providers
app.context.oauthProviders = app.config.OAuth.Providers app.context.oauthProviders = app.config.OAuth.Providers
+1 -1
View File
@@ -70,7 +70,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
authService := service.NewAuthService(service.AuthServiceConfig{ authService := service.NewAuthService(service.AuthServiceConfig{
Users: app.context.users, Users: app.context.users,
OauthWhitelist: app.config.OAuth.Whitelist, OauthWhitelist: app.context.oauthWhitelist,
SessionExpiry: app.config.Auth.SessionExpiry, SessionExpiry: app.config.Auth.SessionExpiry,
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
SecureCookie: app.config.Auth.SecureCookie, SecureCookie: app.config.Auth.SecureCookie,
+1
View File
@@ -159,6 +159,7 @@ type IPConfig struct {
type OAuthConfig struct { type OAuthConfig struct {
Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"` Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"`
AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"` AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
} }
+38
View File
@@ -28,3 +28,41 @@ func CoalesceToString(value any) string {
return "" return ""
} }
} }
func ParseNonEmptyLines(contents string) []string {
lines := make([]string, 0)
for line := range strings.SplitSeq(contents, "\n") {
lineTrimmed := strings.TrimSpace(line)
if lineTrimmed == "" {
continue
}
lines = append(lines, lineTrimmed)
}
return lines
}
func GetStringList(valuesCfg []string, valuesPath string) ([]string, error) {
values := make([]string, 0, len(valuesCfg))
for _, value := range valuesCfg {
valueTrimmed := strings.TrimSpace(value)
if valueTrimmed == "" {
continue
}
values = append(values, valueTrimmed)
}
if valuesPath == "" {
return values, nil
}
contents, err := ReadFile(valuesPath)
if err != nil {
return []string{}, err
}
values = append(values, ParseNonEmptyLines(contents)...)
return values, nil
}
+31
View File
@@ -1,6 +1,7 @@
package utils_test package utils_test
import ( import (
"os"
"testing" "testing"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -57,3 +58,33 @@ func TestCompileUserEmail(t *testing.T) {
// Test with invalid email // Test with invalid email
assert.Equal(t, "user@example.com", utils.CompileUserEmail("user", "example.com")) assert.Equal(t, "user@example.com", utils.CompileUserEmail("user", "example.com"))
} }
func TestParseNonEmptyLines(t *testing.T) {
lines := utils.ParseNonEmptyLines(" first@example.com \n\n second@example.com \n \n")
assert.DeepEqual(t, []string{"first@example.com", "second@example.com"}, lines)
}
func TestGetStringList(t *testing.T) {
file, err := os.Create("/tmp/tinyauth_list_test_file")
assert.NilError(t, err)
_, err = file.WriteString(" third@example.com \n\n fourth@example.com \n")
assert.NilError(t, err)
err = file.Close()
assert.NilError(t, err)
defer os.Remove("/tmp/tinyauth_list_test_file")
values, err := utils.GetStringList([]string{" first@example.com ", "", "second@example.com"}, "/tmp/tinyauth_list_test_file")
assert.NilError(t, err)
assert.DeepEqual(t, []string{"first@example.com", "second@example.com", "third@example.com", "fourth@example.com"}, values)
values, err = utils.GetStringList(nil, "")
assert.NilError(t, err)
assert.DeepEqual(t, []string{}, values)
values, err = utils.GetStringList(nil, "/tmp/non_existing_list_file")
assert.ErrorContains(t, err, "no such file or directory")
assert.DeepEqual(t, []string{}, values)
}
+3 -26
View File
@@ -34,32 +34,9 @@ func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttribut
} }
func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]config.UserAttributes) ([]config.User, error) { func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]config.UserAttributes) ([]config.User, error) {
var usersStr []string usersStr, err := GetStringList(usersCfg, usersPath)
if err != nil {
if len(usersCfg) == 0 && usersPath == "" { return []config.User{}, err
return []config.User{}, nil
}
if len(usersCfg) > 0 {
usersStr = append(usersStr, usersCfg...)
}
if usersPath != "" {
contents, err := ReadFile(usersPath)
if err != nil {
return []config.User{}, err
}
lines := strings.SplitSeq(contents, "\n")
for line := range lines {
lineTrimmed := strings.TrimSpace(line)
if lineTrimmed == "" {
continue
}
usersStr = append(usersStr, lineTrimmed)
}
} }
return ParseUsers(usersStr, userAttributes) return ParseUsers(usersStr, userAttributes)