mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-12-23 08:32:30 +00:00
feat: unified config (#533)
* chore: add yaml config ref * feat: add initial implementation of a traefik like cli * refactor: remove dependency on traefik * chore: update example env * refactor: update build * chore: remove unused code * fix: fix translations not loading * feat: add experimental config file support * chore: mod tidy * fix: review comments * refactor: move tinyauth to separate package * chore: add quotes to all env variables * chore: resolve go mod and sum conflicts * chore: go mod tidy * fix: review comments
This commit is contained in:
@@ -43,7 +43,7 @@ func NewBootstrapApp(config config.Config) *BootstrapApp {
|
||||
|
||||
func (app *BootstrapApp) Setup() error {
|
||||
// Parse users
|
||||
users, err := utils.GetUsers(app.config.Users, app.config.UsersFile)
|
||||
users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -51,14 +51,35 @@ func (app *BootstrapApp) Setup() error {
|
||||
|
||||
app.context.users = users
|
||||
|
||||
// Get OAuth configs
|
||||
oauthProviders, err := utils.GetOAuthProvidersConfig(os.Environ(), os.Args, app.config.AppURL)
|
||||
// Setup OAuth providers
|
||||
app.context.oauthProviders = app.config.OAuth.Providers
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
for name, provider := range app.context.oauthProviders {
|
||||
secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
||||
provider.ClientSecret = secret
|
||||
provider.ClientSecretFile = ""
|
||||
app.context.oauthProviders[name] = provider
|
||||
}
|
||||
|
||||
app.context.oauthProviders = oauthProviders
|
||||
for id := range config.OverrideProviders {
|
||||
if provider, exists := app.context.oauthProviders[id]; exists {
|
||||
if provider.RedirectURL == "" {
|
||||
provider.RedirectURL = app.config.AppURL + "/api/oauth/callback/" + id
|
||||
app.context.oauthProviders[id] = provider
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for id, provider := range app.context.oauthProviders {
|
||||
if provider.Name == "" {
|
||||
if name, ok := config.OverrideProviders[id]; ok {
|
||||
provider.Name = name
|
||||
} else {
|
||||
provider.Name = utils.Capitalize(id)
|
||||
}
|
||||
}
|
||||
app.context.oauthProviders[id] = provider
|
||||
}
|
||||
|
||||
// Get cookie domain
|
||||
cookieDomain, err := utils.GetCookieDomain(app.config.AppURL)
|
||||
@@ -98,7 +119,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
// Configured providers
|
||||
configuredProviders := make([]controller.Provider, 0)
|
||||
|
||||
for id, provider := range oauthProviders {
|
||||
for id, provider := range app.context.oauthProviders {
|
||||
configuredProviders = append(configuredProviders, controller.Provider{
|
||||
Name: provider.Name,
|
||||
ID: id,
|
||||
@@ -144,17 +165,17 @@ func (app *BootstrapApp) Setup() error {
|
||||
}
|
||||
|
||||
// If we have an socket path, bind to it
|
||||
if app.config.SocketPath != "" {
|
||||
if _, err := os.Stat(app.config.SocketPath); err == nil {
|
||||
log.Info().Msgf("Removing existing socket file %s", app.config.SocketPath)
|
||||
err := os.Remove(app.config.SocketPath)
|
||||
if app.config.Server.SocketPath != "" {
|
||||
if _, err := os.Stat(app.config.Server.SocketPath); err == nil {
|
||||
log.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath)
|
||||
err := os.Remove(app.config.Server.SocketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove existing socket file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Msgf("Starting server on unix socket %s", app.config.SocketPath)
|
||||
if err := router.RunUnix(app.config.SocketPath); err != nil {
|
||||
log.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath)
|
||||
if err := router.RunUnix(app.config.Server.SocketPath); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to start server")
|
||||
}
|
||||
|
||||
@@ -162,7 +183,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
}
|
||||
|
||||
// Start server
|
||||
address := fmt.Sprintf("%s:%d", app.config.Address, app.config.Port)
|
||||
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)
|
||||
log.Info().Msgf("Starting server on %s", address)
|
||||
if err := router.Run(address); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to start server")
|
||||
@@ -193,7 +214,7 @@ func (app *BootstrapApp) heartbeat() {
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(10) * time.Second, // The server should never take more than 10 seconds to respond
|
||||
Timeout: 30 * time.Second, // The server should never take more than 30 seconds to respond
|
||||
}
|
||||
|
||||
heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"
|
||||
|
||||
@@ -13,8 +13,8 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
||||
engine := gin.New()
|
||||
engine.Use(gin.Recovery())
|
||||
|
||||
if len(app.config.TrustedProxies) > 0 {
|
||||
err := engine.SetTrustedProxies(strings.Split(app.config.TrustedProxies, ","))
|
||||
if len(app.config.Server.TrustedProxies) > 0 {
|
||||
err := engine.SetTrustedProxies(strings.Split(app.config.Server.TrustedProxies, ","))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set trusted proxies: %w", err)
|
||||
@@ -57,12 +57,12 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
||||
|
||||
contextController := controller.NewContextController(controller.ContextControllerConfig{
|
||||
Providers: app.context.configuredProviders,
|
||||
Title: app.config.Title,
|
||||
Title: app.config.UI.Title,
|
||||
AppURL: app.config.AppURL,
|
||||
CookieDomain: app.context.cookieDomain,
|
||||
ForgotPasswordMessage: app.config.ForgotPasswordMessage,
|
||||
BackgroundImage: app.config.BackgroundImage,
|
||||
OAuthAutoRedirect: app.config.OAuthAutoRedirect,
|
||||
ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage,
|
||||
BackgroundImage: app.config.UI.BackgroundImage,
|
||||
OAuthAutoRedirect: app.config.OAuth.AutoRedirect,
|
||||
DisableUIWarnings: app.config.DisableUIWarnings,
|
||||
}, apiRouter)
|
||||
|
||||
@@ -70,7 +70,7 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
||||
|
||||
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
|
||||
AppURL: app.config.AppURL,
|
||||
SecureCookie: app.config.SecureCookie,
|
||||
SecureCookie: app.config.Auth.SecureCookie,
|
||||
CSRFCookieName: app.context.csrfCookieName,
|
||||
RedirectCookieName: app.context.redirectCookieName,
|
||||
CookieDomain: app.context.cookieDomain,
|
||||
|
||||
@@ -31,12 +31,12 @@ func (app *BootstrapApp) initServices() (Services, error) {
|
||||
services.databaseService = databaseService
|
||||
|
||||
ldapService := service.NewLdapService(service.LdapServiceConfig{
|
||||
Address: app.config.LdapAddress,
|
||||
BindDN: app.config.LdapBindDN,
|
||||
BindPassword: app.config.LdapBindPassword,
|
||||
BaseDN: app.config.LdapBaseDN,
|
||||
Insecure: app.config.LdapInsecure,
|
||||
SearchFilter: app.config.LdapSearchFilter,
|
||||
Address: app.config.Ldap.Address,
|
||||
BindDN: app.config.Ldap.BindDN,
|
||||
BindPassword: app.config.Ldap.BindPassword,
|
||||
BaseDN: app.config.Ldap.BaseDN,
|
||||
Insecure: app.config.Ldap.Insecure,
|
||||
SearchFilter: app.config.Ldap.SearchFilter,
|
||||
})
|
||||
|
||||
err = ldapService.Init()
|
||||
@@ -69,12 +69,12 @@ func (app *BootstrapApp) initServices() (Services, error) {
|
||||
|
||||
authService := service.NewAuthService(service.AuthServiceConfig{
|
||||
Users: app.context.users,
|
||||
OauthWhitelist: app.config.OAuthWhitelist,
|
||||
SessionExpiry: app.config.SessionExpiry,
|
||||
SecureCookie: app.config.SecureCookie,
|
||||
OauthWhitelist: app.config.OAuth.Whitelist,
|
||||
SessionExpiry: app.config.Auth.SessionExpiry,
|
||||
SecureCookie: app.config.Auth.SecureCookie,
|
||||
CookieDomain: app.context.cookieDomain,
|
||||
LoginTimeout: app.config.LoginTimeout,
|
||||
LoginMaxRetries: app.config.LoginMaxRetries,
|
||||
LoginTimeout: app.config.Auth.LoginTimeout,
|
||||
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
|
||||
SessionCookieName: app.context.sessionCookieName,
|
||||
}, dockerService, ldapService, databaseService.GetDatabase())
|
||||
|
||||
|
||||
@@ -15,36 +15,67 @@ var RedirectCookieName = "tinyauth-redirect"
|
||||
// Main app config
|
||||
|
||||
type Config struct {
|
||||
Port int `mapstructure:"port" validate:"required"`
|
||||
Address string `validate:"required,ip4_addr" mapstructure:"address"`
|
||||
AppURL string `validate:"required,url" mapstructure:"app-url"`
|
||||
Users string `mapstructure:"users"`
|
||||
UsersFile string `mapstructure:"users-file"`
|
||||
SecureCookie bool `mapstructure:"secure-cookie"`
|
||||
OAuthWhitelist string `mapstructure:"oauth-whitelist"`
|
||||
OAuthAutoRedirect string `mapstructure:"oauth-auto-redirect"`
|
||||
SessionExpiry int `mapstructure:"session-expiry"`
|
||||
LogLevel string `mapstructure:"log-level" validate:"oneof=trace debug info warn error fatal panic"`
|
||||
Title string `mapstructure:"app-title"`
|
||||
LoginTimeout int `mapstructure:"login-timeout"`
|
||||
LoginMaxRetries int `mapstructure:"login-max-retries"`
|
||||
ForgotPasswordMessage string `mapstructure:"forgot-password-message"`
|
||||
BackgroundImage string `mapstructure:"background-image" validate:"required"`
|
||||
LdapAddress string `mapstructure:"ldap-address"`
|
||||
LdapBindDN string `mapstructure:"ldap-bind-dn"`
|
||||
LdapBindPassword string `mapstructure:"ldap-bind-password"`
|
||||
LdapBaseDN string `mapstructure:"ldap-base-dn"`
|
||||
LdapInsecure bool `mapstructure:"ldap-insecure"`
|
||||
LdapSearchFilter string `mapstructure:"ldap-search-filter"`
|
||||
ResourcesDir string `mapstructure:"resources-dir"`
|
||||
DatabasePath string `mapstructure:"database-path"`
|
||||
TrustedProxies string `mapstructure:"trusted-proxies"`
|
||||
DisableAnalytics bool `mapstructure:"disable-analytics"`
|
||||
DisableResources bool `mapstructure:"disable-resources"`
|
||||
DisableUIWarnings bool `mapstructure:"disable-ui-warnings"`
|
||||
SocketPath string `mapstructure:"socket-path"`
|
||||
AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"`
|
||||
LogLevel string `description:"Log level (trace, debug, info, warn, error)." yaml:"logLevel"`
|
||||
ResourcesDir string `description:"The directory where resources are stored." yaml:"resourcesDir"`
|
||||
DatabasePath string `description:"The path to the database file." yaml:"databasePath"`
|
||||
DisableAnalytics bool `description:"Disable analytics." yaml:"disableAnalytics"`
|
||||
DisableResources bool `description:"Disable resources server." yaml:"disableResources"`
|
||||
DisableUIWarnings bool `description:"Disable UI warnings." yaml:"disableUIWarnings"`
|
||||
LogJSON bool `description:"Enable JSON formatted logs." yaml:"logJSON"`
|
||||
Server ServerConfig `description:"Server configuration." yaml:"server"`
|
||||
Auth AuthConfig `description:"Authentication configuration." yaml:"auth"`
|
||||
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"`
|
||||
UI UIConfig `description:"UI customization." yaml:"ui"`
|
||||
Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"`
|
||||
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Port int `description:"The port on which the server listens." yaml:"port"`
|
||||
Address string `description:"The address on which the server listens." yaml:"address"`
|
||||
SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"`
|
||||
TrustedProxies string `description:"Comma-separated list of trusted proxy addresses." yaml:"trustedProxies"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
Users string `description:"Comma-separated list of users (username:hashed_password)." yaml:"users"`
|
||||
UsersFile string `description:"Path to the users file." yaml:"usersFile"`
|
||||
SecureCookie bool `description:"Enable secure cookies." yaml:"secureCookie"`
|
||||
SessionExpiry int `description:"Session expiry time in seconds." yaml:"sessionExpiry"`
|
||||
LoginTimeout int `description:"Login timeout in seconds." yaml:"loginTimeout"`
|
||||
LoginMaxRetries int `description:"Maximum login retries." yaml:"loginMaxRetries"`
|
||||
}
|
||||
|
||||
type OAuthConfig struct {
|
||||
Whitelist string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"`
|
||||
AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"`
|
||||
Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"`
|
||||
}
|
||||
|
||||
type UIConfig struct {
|
||||
Title string `description:"The title of the UI." yaml:"title"`
|
||||
ForgotPasswordMessage string `description:"Message displayed on the forgot password page." yaml:"forgotPasswordMessage"`
|
||||
BackgroundImage string `description:"Path to the background image." yaml:"backgroundImage"`
|
||||
}
|
||||
|
||||
type LdapConfig struct {
|
||||
Address string `description:"LDAP server address." yaml:"address"`
|
||||
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
||||
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
||||
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
|
||||
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
|
||||
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
|
||||
}
|
||||
|
||||
type ExperimentalConfig struct {
|
||||
ConfigFile string `description:"Path to config file." yaml:"-"`
|
||||
}
|
||||
|
||||
// Config loader options
|
||||
|
||||
const DefaultNamePrefix = "TINYAUTH_"
|
||||
|
||||
// OAuth/OIDC config
|
||||
|
||||
type Claims struct {
|
||||
@@ -55,16 +86,16 @@ type Claims struct {
|
||||
}
|
||||
|
||||
type OAuthServiceConfig struct {
|
||||
ClientID string `field:"client-id"`
|
||||
ClientSecret string
|
||||
ClientSecretFile string
|
||||
Scopes []string
|
||||
RedirectURL string `field:"redirect-url"`
|
||||
AuthURL string `field:"auth-url"`
|
||||
TokenURL string `field:"token-url"`
|
||||
UserinfoURL string `field:"user-info-url"`
|
||||
InsecureSkipVerify bool
|
||||
Name string
|
||||
ClientID string `description:"OAuth client ID."`
|
||||
ClientSecret string `description:"OAuth client secret."`
|
||||
ClientSecretFile string `description:"Path to the file containing the OAuth client secret."`
|
||||
Scopes []string `description:"OAuth scopes."`
|
||||
RedirectURL string `description:"OAuth redirect URL."`
|
||||
AuthURL string `description:"OAuth authorization URL."`
|
||||
TokenURL string `description:"OAuth token URL."`
|
||||
UserinfoURL string `description:"OAuth userinfo URL."`
|
||||
Insecure bool `description:"Allow insecure OAuth connections."`
|
||||
Name string `description:"Provider name in UI."`
|
||||
}
|
||||
|
||||
var OverrideProviders = map[string]string{
|
||||
|
||||
@@ -82,7 +82,7 @@ func (docker *DockerService) GetLabels(appDomain string) (config.App, error) {
|
||||
return config.App{}, err
|
||||
}
|
||||
|
||||
labels, err := decoders.DecodeLabels(inspect.Config.Labels)
|
||||
labels, err := decoders.DecodeLabels[config.Apps](inspect.Config.Labels, "apps")
|
||||
if err != nil {
|
||||
return config.App{}, err
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi
|
||||
TokenURL: config.TokenURL,
|
||||
},
|
||||
},
|
||||
insecureSkipVerify: config.InsecureSkipVerify,
|
||||
insecureSkipVerify: config.Insecure,
|
||||
userinfoUrl: config.UserinfoURL,
|
||||
name: config.Name,
|
||||
}
|
||||
@@ -54,6 +54,7 @@ func (generic *GenericOAuthService) Init() error {
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -50,7 +50,9 @@ func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService
|
||||
}
|
||||
|
||||
func (github *GithubOAuthService) Init() error {
|
||||
httpClient := &http.Client{}
|
||||
httpClient := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
github.context = ctx
|
||||
|
||||
@@ -45,7 +45,9 @@ func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService
|
||||
}
|
||||
|
||||
func (google *GoogleOAuthService) Init() error {
|
||||
httpClient := &http.Client{}
|
||||
httpClient := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
google.context = ctx
|
||||
|
||||
@@ -6,12 +6,8 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/utils/decoders"
|
||||
|
||||
"maps"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||
)
|
||||
|
||||
@@ -104,119 +100,3 @@ func IsRedirectSafe(redirectURL string, domain string) bool {
|
||||
|
||||
return hostname == domain
|
||||
}
|
||||
|
||||
func GetLogLevel(level string) zerolog.Level {
|
||||
switch strings.ToLower(level) {
|
||||
case "trace":
|
||||
return zerolog.TraceLevel
|
||||
case "debug":
|
||||
return zerolog.DebugLevel
|
||||
case "info":
|
||||
return zerolog.InfoLevel
|
||||
case "warn":
|
||||
return zerolog.WarnLevel
|
||||
case "error":
|
||||
return zerolog.ErrorLevel
|
||||
case "fatal":
|
||||
return zerolog.FatalLevel
|
||||
case "panic":
|
||||
return zerolog.PanicLevel
|
||||
default:
|
||||
return zerolog.InfoLevel
|
||||
}
|
||||
}
|
||||
|
||||
func GetOAuthProvidersConfig(env []string, args []string, appUrl string) (map[string]config.OAuthServiceConfig, error) {
|
||||
providers := make(map[string]config.OAuthServiceConfig)
|
||||
|
||||
// Get from environment variables
|
||||
envMap := make(map[string]string)
|
||||
|
||||
for _, e := range env {
|
||||
pair := strings.SplitN(e, "=", 2)
|
||||
if len(pair) == 2 {
|
||||
envMap[pair[0]] = pair[1]
|
||||
}
|
||||
}
|
||||
|
||||
envProviders, err := decoders.DecodeEnv[config.Providers, config.OAuthServiceConfig](envMap, "providers")
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maps.Copy(providers, envProviders.Providers)
|
||||
|
||||
// Get from flags
|
||||
flagsMap := make(map[string]string)
|
||||
|
||||
for _, arg := range args[1:] {
|
||||
if strings.HasPrefix(arg, "--") {
|
||||
pair := strings.SplitN(arg[2:], "=", 2)
|
||||
if len(pair) == 2 {
|
||||
flagsMap[pair[0]] = pair[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
flagProviders, err := decoders.DecodeFlags[config.Providers, config.OAuthServiceConfig](flagsMap, "providers")
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maps.Copy(providers, flagProviders.Providers)
|
||||
|
||||
// For every provider get correct secret from file if set
|
||||
for name, provider := range providers {
|
||||
secret := GetSecret(provider.ClientSecret, provider.ClientSecretFile)
|
||||
provider.ClientSecret = secret
|
||||
provider.ClientSecretFile = ""
|
||||
providers[name] = provider
|
||||
}
|
||||
|
||||
// If we have google/github providers and no redirect URL then set a default
|
||||
for id := range config.OverrideProviders {
|
||||
if provider, exists := providers[id]; exists {
|
||||
if provider.RedirectURL == "" {
|
||||
provider.RedirectURL = appUrl + "/api/oauth/callback/" + id
|
||||
providers[id] = provider
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set names
|
||||
for id, provider := range providers {
|
||||
if provider.Name == "" {
|
||||
if name, ok := config.OverrideProviders[id]; ok {
|
||||
provider.Name = name
|
||||
} else {
|
||||
provider.Name = Capitalize(id)
|
||||
}
|
||||
}
|
||||
providers[id] = provider
|
||||
}
|
||||
|
||||
// Return combined providers
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func ShoudLogJSON(environ []string, args []string) bool {
|
||||
for _, e := range environ {
|
||||
pair := strings.SplitN(e, "=", 2)
|
||||
if len(pair) == 2 && pair[0] == "LOG_JSON" && strings.ToLower(pair[1]) == "true" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, arg := range args[1:] {
|
||||
if strings.HasPrefix(arg, "--log-json=") {
|
||||
value := strings.SplitN(arg, "=", 2)[1]
|
||||
if strings.ToLower(value) == "true" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/utils"
|
||||
@@ -206,93 +205,3 @@ func TestIsRedirectSafe(t *testing.T) {
|
||||
result = utils.IsRedirectSafe(redirectURL, domain)
|
||||
assert.Equal(t, false, result)
|
||||
}
|
||||
|
||||
func TestGetOAuthProvidersConfig(t *testing.T) {
|
||||
env := []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET=client1-secret"}
|
||||
args := []string{"/tinyauth/tinyauth", "--providers-client2-client-id=client2-id", "--providers-client2-client-secret=client2-secret"}
|
||||
|
||||
expected := map[string]config.OAuthServiceConfig{
|
||||
"client1": {
|
||||
ClientID: "client1-id",
|
||||
ClientSecret: "client1-secret",
|
||||
Name: "Client1",
|
||||
},
|
||||
"client2": {
|
||||
ClientID: "client2-id",
|
||||
ClientSecret: "client2-secret",
|
||||
Name: "Client2",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := utils.GetOAuthProvidersConfig(env, args, "")
|
||||
assert.NilError(t, err)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
|
||||
// Case with no providers
|
||||
env = []string{}
|
||||
args = []string{"/tinyauth/tinyauth"}
|
||||
expected = map[string]config.OAuthServiceConfig{}
|
||||
|
||||
result, err = utils.GetOAuthProvidersConfig(env, args, "")
|
||||
assert.NilError(t, err)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
|
||||
// Case with secret from file
|
||||
file, err := os.Create("/tmp/tinyauth_test_file")
|
||||
assert.NilError(t, err)
|
||||
|
||||
_, err = file.WriteString("file content\n")
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = file.Close()
|
||||
assert.NilError(t, err)
|
||||
defer os.Remove("/tmp/tinyauth_test_file")
|
||||
|
||||
env = []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET_FILE=/tmp/tinyauth_test_file"}
|
||||
args = []string{"/tinyauth/tinyauth"}
|
||||
expected = map[string]config.OAuthServiceConfig{
|
||||
"client1": {
|
||||
ClientID: "client1-id",
|
||||
ClientSecret: "file content",
|
||||
Name: "Client1",
|
||||
},
|
||||
}
|
||||
|
||||
result, err = utils.GetOAuthProvidersConfig(env, args, "")
|
||||
assert.NilError(t, err)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
|
||||
// Case with google provider and no redirect URL
|
||||
env = []string{"PROVIDERS_GOOGLE_CLIENT_ID=google-id", "PROVIDERS_GOOGLE_CLIENT_SECRET=google-secret"}
|
||||
args = []string{"/tinyauth/tinyauth"}
|
||||
expected = map[string]config.OAuthServiceConfig{
|
||||
"google": {
|
||||
ClientID: "google-id",
|
||||
ClientSecret: "google-secret",
|
||||
RedirectURL: "http://app.url/api/oauth/callback/google",
|
||||
Name: "Google",
|
||||
},
|
||||
}
|
||||
|
||||
result, err = utils.GetOAuthProvidersConfig(env, args, "http://app.url")
|
||||
assert.NilError(t, err)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
}
|
||||
|
||||
func TestShoudLogJSON(t *testing.T) {
|
||||
// Test with no env or args
|
||||
result := utils.ShoudLogJSON([]string{"FOO=bar"}, []string{"tinyauth", "--foo-bar=baz"})
|
||||
assert.Equal(t, false, result)
|
||||
|
||||
// Test with env variable set
|
||||
result = utils.ShoudLogJSON([]string{"LOG_JSON=true"}, []string{"tinyauth", "--foo-bar=baz"})
|
||||
assert.Equal(t, true, result)
|
||||
|
||||
// Test with flag set
|
||||
result = utils.ShoudLogJSON([]string{"FOO=bar"}, []string{"tinyauth", "--log-json=true"})
|
||||
assert.Equal(t, true, result)
|
||||
|
||||
// Test with both env and flag set to false
|
||||
result = utils.ShoudLogJSON([]string{"LOG_JSON=false"}, []string{"tinyauth", "--log-json=false"})
|
||||
assert.Equal(t, false, result)
|
||||
}
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
package decoders
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/stoewer/go-strcase"
|
||||
)
|
||||
|
||||
func normalizeKeys[T any](input map[string]string, root string, sep string) map[string]string {
|
||||
knownKeys := getKnownKeys[T]()
|
||||
normalized := make(map[string]string)
|
||||
|
||||
for k, v := range input {
|
||||
parts := []string{"tinyauth"}
|
||||
|
||||
key := strings.ToLower(k)
|
||||
key = strings.ReplaceAll(key, sep, "-")
|
||||
|
||||
if !strings.HasPrefix(key, root+"-") {
|
||||
continue
|
||||
}
|
||||
|
||||
suffix := ""
|
||||
|
||||
for _, known := range knownKeys {
|
||||
if strings.HasSuffix(key, known) {
|
||||
suffix = known
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if suffix == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts = append(parts, root)
|
||||
|
||||
id := strings.TrimPrefix(key, root+"-")
|
||||
id = strings.TrimSuffix(id, "-"+suffix)
|
||||
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts = append(parts, id)
|
||||
parts = append(parts, suffix)
|
||||
|
||||
final := ""
|
||||
|
||||
for i, part := range parts {
|
||||
if i > 0 {
|
||||
final += "."
|
||||
}
|
||||
final += strcase.LowerCamelCase(part)
|
||||
}
|
||||
|
||||
normalized[final] = v
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
func getKnownKeys[T any]() []string {
|
||||
var keys []string
|
||||
var t T
|
||||
|
||||
v := reflect.ValueOf(t)
|
||||
typeOfT := v.Type()
|
||||
|
||||
for field := range typeOfT.NumField() {
|
||||
if typeOfT.Field(field).Tag.Get("field") != "" {
|
||||
keys = append(keys, typeOfT.Field(field).Tag.Get("field"))
|
||||
continue
|
||||
}
|
||||
keys = append(keys, strcase.KebabCase(typeOfT.Field(field).Name))
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package decoders
|
||||
|
||||
import (
|
||||
"github.com/traefik/paerser/parser"
|
||||
)
|
||||
|
||||
func DecodeEnv[T any, C any](env map[string]string, subName string) (T, error) {
|
||||
var result T
|
||||
|
||||
normalized := normalizeKeys[C](env, subName, "_")
|
||||
|
||||
err := parser.Decode(normalized, &result, "tinyauth", "tinyauth."+subName)
|
||||
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package decoders_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/utils/decoders"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestDecodeEnv(t *testing.T) {
|
||||
// Setup
|
||||
env := map[string]string{
|
||||
"PROVIDERS_GOOGLE_CLIENT_ID": "google-client-id",
|
||||
"PROVIDERS_GOOGLE_CLIENT_SECRET": "google-client-secret",
|
||||
"PROVIDERS_MY_GITHUB_CLIENT_ID": "github-client-id",
|
||||
"PROVIDERS_MY_GITHUB_CLIENT_SECRET": "github-client-secret",
|
||||
}
|
||||
|
||||
expected := config.Providers{
|
||||
Providers: map[string]config.OAuthServiceConfig{
|
||||
"google": {
|
||||
ClientID: "google-client-id",
|
||||
ClientSecret: "google-client-secret",
|
||||
},
|
||||
"myGithub": {
|
||||
ClientID: "github-client-id",
|
||||
ClientSecret: "github-client-secret",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Execute
|
||||
result, err := decoders.DecodeEnv[config.Providers, config.OAuthServiceConfig](env, "providers")
|
||||
assert.NilError(t, err)
|
||||
assert.DeepEqual(t, result, expected)
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package decoders
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/traefik/paerser/parser"
|
||||
)
|
||||
|
||||
func DecodeFlags[T any, C any](flags map[string]string, subName string) (T, error) {
|
||||
var result T
|
||||
|
||||
filtered := filterFlags(flags)
|
||||
normalized := normalizeKeys[C](filtered, subName, "_")
|
||||
|
||||
err := parser.Decode(normalized, &result, "tinyauth", "tinyauth."+subName)
|
||||
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func filterFlags(flags map[string]string) map[string]string {
|
||||
filtered := make(map[string]string)
|
||||
for k, v := range flags {
|
||||
filtered[strings.TrimPrefix(k, "--")] = v
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package decoders_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/utils/decoders"
|
||||
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestDecodeFlags(t *testing.T) {
|
||||
// Setup
|
||||
flags := map[string]string{
|
||||
"--providers-google-client-id": "google-client-id",
|
||||
"--providers-google-client-secret": "google-client-secret",
|
||||
"--providers-my-github-client-id": "github-client-id",
|
||||
"--providers-my-github-client-secret": "github-client-secret",
|
||||
}
|
||||
|
||||
expected := config.Providers{
|
||||
Providers: map[string]config.OAuthServiceConfig{
|
||||
"google": {
|
||||
ClientID: "google-client-id",
|
||||
ClientSecret: "google-client-secret",
|
||||
},
|
||||
"myGithub": {
|
||||
ClientID: "github-client-id",
|
||||
ClientSecret: "github-client-secret",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Execute
|
||||
result, err := decoders.DecodeFlags[config.Providers, config.OAuthServiceConfig](flags, "providers")
|
||||
assert.NilError(t, err)
|
||||
assert.DeepEqual(t, result, expected)
|
||||
}
|
||||
@@ -1,19 +1,17 @@
|
||||
package decoders
|
||||
|
||||
import (
|
||||
"tinyauth/internal/config"
|
||||
|
||||
"github.com/traefik/paerser/parser"
|
||||
)
|
||||
|
||||
func DecodeLabels(labels map[string]string) (config.Apps, error) {
|
||||
var appLabels config.Apps
|
||||
func DecodeLabels[T any](labels map[string]string, root string) (T, error) {
|
||||
var labelsDecoded T
|
||||
|
||||
err := parser.Decode(labels, &appLabels, "tinyauth", "tinyauth.apps")
|
||||
err := parser.Decode(labels, &labelsDecoded, "tinyauth", "tinyauth."+root)
|
||||
|
||||
if err != nil {
|
||||
return config.Apps{}, err
|
||||
return labelsDecoded, err
|
||||
}
|
||||
|
||||
return appLabels, nil
|
||||
return labelsDecoded, nil
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func TestDecodeLabels(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test
|
||||
result, err := decoders.DecodeLabels(test)
|
||||
result, err := decoders.DecodeLabels[config.Apps](test, "apps")
|
||||
assert.NilError(t, err)
|
||||
assert.DeepEqual(t, expected, result)
|
||||
}
|
||||
|
||||
25
internal/utils/loaders/loader_env.go
Normal file
25
internal/utils/loaders/loader_env.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package loaders
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"tinyauth/internal/config"
|
||||
|
||||
"github.com/traefik/paerser/cli"
|
||||
"github.com/traefik/paerser/env"
|
||||
)
|
||||
|
||||
type EnvLoader struct{}
|
||||
|
||||
func (e *EnvLoader) Load(_ []string, cmd *cli.Command) (bool, error) {
|
||||
vars := env.FindPrefixedEnvVars(os.Environ(), config.DefaultNamePrefix, cmd.Configuration)
|
||||
if len(vars) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := env.Decode(vars, config.DefaultNamePrefix, cmd.Configuration); err != nil {
|
||||
return false, fmt.Errorf("failed to decode configuration from environment variables: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
35
internal/utils/loaders/loader_file.go
Normal file
35
internal/utils/loaders/loader_file.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package loaders
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/traefik/paerser/cli"
|
||||
"github.com/traefik/paerser/file"
|
||||
"github.com/traefik/paerser/flag"
|
||||
)
|
||||
|
||||
type FileLoader struct{}
|
||||
|
||||
func (f *FileLoader) Load(args []string, cmd *cli.Command) (bool, error) {
|
||||
flags, err := flag.Parse(args, cmd.Configuration)
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// I guess we are using traefik as the root name
|
||||
configFileFlag := "traefik.experimental.configFile"
|
||||
|
||||
if _, ok := flags[configFileFlag]; !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
log.Warn().Msg("Using experimental file config loader, this feature is experimental and may change or be removed in future releases")
|
||||
|
||||
err = file.Decode(flags[configFileFlag], cmd.Configuration)
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
22
internal/utils/loaders/loader_flag.go
Normal file
22
internal/utils/loaders/loader_flag.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package loaders
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/traefik/paerser/cli"
|
||||
"github.com/traefik/paerser/flag"
|
||||
)
|
||||
|
||||
type FlagLoader struct{}
|
||||
|
||||
func (*FlagLoader) Load(args []string, cmd *cli.Command) (bool, error) {
|
||||
if len(args) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := flag.Decode(args, cmd.Configuration); err != nil {
|
||||
return false, fmt.Errorf("failed to decode configuration from flags: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
Reference in New Issue
Block a user