This commit is contained in:
Stavros
2025-11-04 17:37:40 +02:00
parent 5f7e89c330
commit 1ad862d86c
13 changed files with 155 additions and 230 deletions

View File

@@ -62,6 +62,13 @@ func (app *BootstrapApp) Setup() error {
return err return err
} }
// Get access controls
acls, err := utils.GetACLS(os.Environ(), os.Args)
if err != nil {
return err
}
// Get cookie domain // Get cookie domain
cookieDomain, err := utils.GetCookieDomain(app.config.AppURL) cookieDomain, err := utils.GetCookieDomain(app.config.AppURL)
@@ -139,7 +146,7 @@ func (app *BootstrapApp) Setup() error {
// Create services // Create services
dockerService := service.NewDockerService() dockerService := service.NewDockerService()
aclsService := service.NewAccessControlsService(dockerService) aclsService := service.NewAccessControlsService(dockerService, acls)
authService := service.NewAuthService(authConfig, dockerService, ldapService, database) authService := service.NewAuthService(authConfig, dockerService, ldapService, database)
oauthBrokerService := service.NewOAuthBrokerService(oauthProviders) oauthBrokerService := service.NewOAuthBrokerService(oauthProviders)

View File

@@ -40,7 +40,7 @@ func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.En
assert.NilError(t, dockerService.Init()) assert.NilError(t, dockerService.Init())
// Access controls // Access controls
accessControlsService := service.NewAccessControlsService(dockerService) accessControlsService := service.NewAccessControlsService(dockerService, config.Apps{})
assert.NilError(t, accessControlsService.Init()) assert.NilError(t, accessControlsService.Init())

View File

@@ -1,82 +1,34 @@
package service package service
import ( import (
"os"
"strings" "strings"
"tinyauth/internal/config" "tinyauth/internal/config"
"tinyauth/internal/utils/decoders"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type AccessControlsService struct { type AccessControlsService struct {
docker *DockerService docker *DockerService
envACLs config.Apps nonDocker config.Apps
} }
func NewAccessControlsService(docker *DockerService) *AccessControlsService { func NewAccessControlsService(docker *DockerService, nonDocker config.Apps) *AccessControlsService {
return &AccessControlsService{ return &AccessControlsService{
docker: docker, docker: docker,
nonDocker: nonDocker,
} }
} }
func (acls *AccessControlsService) Init() error { func (acls *AccessControlsService) Init() error {
acls.envACLs = config.Apps{}
env := os.Environ()
appEnvVars := []string{}
for _, e := range env {
if strings.HasPrefix(e, "TINYAUTH_APPS_") {
appEnvVars = append(appEnvVars, e)
}
}
err := acls.loadEnvACLs(appEnvVars)
if err != nil {
return err
}
return nil return nil
} }
func (acls *AccessControlsService) loadEnvACLs(appEnvVars []string) error { func (acls *AccessControlsService) lookupNonDockerACLs(appDomain string) *config.App {
if len(appEnvVars) == 0 { if len(acls.nonDocker.Apps) == 0 {
return nil return nil
} }
envAcls := map[string]string{} for appName, appACLs := range acls.nonDocker.Apps {
for _, e := range appEnvVars {
parts := strings.SplitN(e, "=", 2)
if len(parts) != 2 {
continue
}
// Normalize key, this should use the same normalization logic as in utils/decoders/decoders.go
key := parts[0]
key = strings.ToLower(key)
key = strings.ReplaceAll(key, "_", ".")
value := parts[1]
envAcls[key] = value
}
apps, err := decoders.DecodeLabels(envAcls)
if err != nil {
return err
}
acls.envACLs = apps
return nil
}
func (acls *AccessControlsService) lookupEnvACLs(appDomain string) *config.App {
if len(acls.envACLs.Apps) == 0 {
return nil
}
for appName, appACLs := range acls.envACLs.Apps {
if appACLs.Config.Domain == appDomain { if appACLs.Config.Domain == appDomain {
return &appACLs return &appACLs
} }
@@ -90,11 +42,11 @@ func (acls *AccessControlsService) lookupEnvACLs(appDomain string) *config.App {
} }
func (acls *AccessControlsService) GetAccessControls(appDomain string) (config.App, error) { func (acls *AccessControlsService) GetAccessControls(appDomain string) (config.App, error) {
// First check environment variables // First check non-docker apps
envACLs := acls.lookupEnvACLs(appDomain) envACLs := acls.lookupNonDockerACLs(appDomain)
if envACLs != nil { if envACLs != nil {
log.Debug().Str("domain", appDomain).Msg("Found matching access controls in environment variables") log.Debug().Str("domain", appDomain).Msg("Found matching access controls in environment variables or flags")
return *envACLs, nil return *envACLs, nil
} }

View File

@@ -82,7 +82,7 @@ func (docker *DockerService) GetLabels(appDomain string) (config.App, error) {
return config.App{}, err return config.App{}, err
} }
labels, err := decoders.DecodeLabels(inspect.Config.Labels) labels, err := decoders.DecodeLabels[config.Apps](inspect.Config.Labels)
if err != nil { if err != nil {
return config.App{}, err return config.App{}, err
} }

View File

@@ -134,20 +134,11 @@ func GetLogLevel(level string) zerolog.Level {
} }
} }
func GetOAuthProvidersConfig(env []string, args []string, appUrl string) (map[string]config.OAuthServiceConfig, error) { func GetOAuthProvidersConfig(environ []string, args []string, appUrl string) (map[string]config.OAuthServiceConfig, error) {
providers := make(map[string]config.OAuthServiceConfig) providers := make(map[string]config.OAuthServiceConfig)
// Get from environment variables // Get from environment variables
envMap := make(map[string]string) envProviders, err := decoders.DecodeEnv[config.Providers](environ)
for _, e := range env {
pair := strings.SplitN(e, "=", 2)
if len(pair) == 2 {
envMap[pair[0]] = pair[1]
}
}
envProviders, err := decoders.DecodeEnv[config.Providers, config.OAuthServiceConfig](envMap, "providers")
if err != nil { if err != nil {
return nil, err return nil, err
@@ -155,25 +146,14 @@ func GetOAuthProvidersConfig(env []string, args []string, appUrl string) (map[st
maps.Copy(providers, envProviders.Providers) maps.Copy(providers, envProviders.Providers)
// Get from flags // Get from args
flagsMap := make(map[string]string) argProviders, err := decoders.DecodeFlags[config.Providers](args)
for _, arg := range args[1:] {
if strings.HasPrefix(arg, "--") {
pair := strings.SplitN(arg[2:], "=", 2)
if len(pair) == 2 {
flagsMap[pair[0]] = pair[1]
}
}
}
flagProviders, err := decoders.DecodeFlags[config.Providers, config.OAuthServiceConfig](flagsMap, "providers")
if err != nil { if err != nil {
return nil, err return nil, err
} }
maps.Copy(providers, flagProviders.Providers) maps.Copy(providers, argProviders.Providers)
// For every provider get correct secret from file if set // For every provider get correct secret from file if set
for name, provider := range providers { for name, provider := range providers {
@@ -208,3 +188,28 @@ func GetOAuthProvidersConfig(env []string, args []string, appUrl string) (map[st
// Return combined providers // Return combined providers
return providers, nil return providers, nil
} }
func GetACLS(environ []string, args []string) (config.Apps, error) {
acls := config.Apps{}
acls.Apps = make(map[string]config.App)
// Get from environment variables
envACLs, err := decoders.DecodeEnv[config.Apps](environ)
if err != nil {
return config.Apps{}, err
}
maps.Copy(acls.Apps, envACLs.Apps)
// Get from args
argACLs, err := decoders.DecodeFlags[config.Apps](args)
if err != nil {
return config.Apps{}, err
}
maps.Copy(acls.Apps, argACLs.Apps)
return acls, nil
}

View File

@@ -238,8 +238,8 @@ func TestIsRedirectSafeMultiLevel(t *testing.T) {
} }
func TestGetOAuthProvidersConfig(t *testing.T) { func TestGetOAuthProvidersConfig(t *testing.T) {
env := []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET=client1-secret"} env := []string{"PROVIDERS_CLIENT1_CLIENTID=client1-id", "PROVIDERS_CLIENT1_CLIENTSECRET=client1-secret"}
args := []string{"/tinyauth/tinyauth", "--providers-client2-client-id=client2-id", "--providers-client2-client-secret=client2-secret"} args := []string{"/tinyauth/tinyauth", "--providers-client2-clientid=client2-id", "--providers-client2-clientsecret=client2-secret"}
expected := map[string]config.OAuthServiceConfig{ expected := map[string]config.OAuthServiceConfig{
"client1": { "client1": {
@@ -278,7 +278,7 @@ func TestGetOAuthProvidersConfig(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
defer os.Remove("/tmp/tinyauth_test_file") defer os.Remove("/tmp/tinyauth_test_file")
env = []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET_FILE=/tmp/tinyauth_test_file"} env = []string{"PROVIDERS_CLIENT1_CLIENTID=client1-id", "PROVIDERS_CLIENT1_CLIENTSECRETFILE=/tmp/tinyauth_test_file"}
args = []string{"/tinyauth/tinyauth"} args = []string{"/tinyauth/tinyauth"}
expected = map[string]config.OAuthServiceConfig{ expected = map[string]config.OAuthServiceConfig{
"client1": { "client1": {
@@ -293,7 +293,7 @@ func TestGetOAuthProvidersConfig(t *testing.T) {
assert.DeepEqual(t, expected, result) assert.DeepEqual(t, expected, result)
// Case with google provider and no redirect URL // Case with google provider and no redirect URL
env = []string{"PROVIDERS_GOOGLE_CLIENT_ID=google-id", "PROVIDERS_GOOGLE_CLIENT_SECRET=google-secret"} env = []string{"PROVIDERS_GOOGLE_CLIENTID=google-id", "PROVIDERS_GOOGLE_CLIENTSECRET=google-secret"}
args = []string{"/tinyauth/tinyauth"} args = []string{"/tinyauth/tinyauth"}
expected = map[string]config.OAuthServiceConfig{ expected = map[string]config.OAuthServiceConfig{
"google": { "google": {
@@ -308,3 +308,39 @@ func TestGetOAuthProvidersConfig(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
assert.DeepEqual(t, expected, result) assert.DeepEqual(t, expected, result)
} }
func TestGetACLS(t *testing.T) {
// Setup
env := []string{"TINYAUTH_APPS_APP1_CONFIG_DOMAIN=app1.com", "TINYAUTH_APPS_APP2_CONFIG_DOMAIN=app2.com"}
args := []string{"--apps-app3-config-domain=app3.com", "--apps-app4-config-domain=app4.com"}
expected := config.Apps{
Apps: map[string]config.App{
"app1": {
Config: config.AppConfig{
Domain: "app1.com",
},
},
"app2": {
Config: config.AppConfig{
Domain: "app2.com",
},
},
"app3": {
Config: config.AppConfig{
Domain: "app3.com",
},
},
"app4": {
Config: config.AppConfig{
Domain: "app4.com",
},
},
},
}
// Test
result, err := utils.GetACLS(env, args)
assert.NilError(t, err)
assert.DeepEqual(t, expected, result)
}

View File

@@ -1,76 +0,0 @@
package decoders
import (
"reflect"
"strings"
"github.com/stoewer/go-strcase"
)
func normalizeKeys[T any](input map[string]string, root string, sep string) map[string]string {
knownKeys := getKnownKeys[T]()
normalized := make(map[string]string)
for k, v := range input {
parts := []string{"tinyauth"}
key := strings.ToLower(k)
key = strings.ReplaceAll(key, sep, "-")
suffix := ""
for _, known := range knownKeys {
if strings.HasSuffix(key, known) {
suffix = known
break
}
}
if suffix == "" {
continue
}
parts = append(parts, root)
id := strings.TrimPrefix(key, root+"-")
id = strings.TrimSuffix(id, "-"+suffix)
if id == "" {
continue
}
parts = append(parts, id)
parts = append(parts, suffix)
final := ""
for i, part := range parts {
if i > 0 {
final += "."
}
final += strcase.LowerCamelCase(part)
}
normalized[final] = v
}
return normalized
}
func getKnownKeys[T any]() []string {
var keys []string
var t T
v := reflect.ValueOf(t)
typeOfT := v.Type()
for field := range typeOfT.NumField() {
if typeOfT.Field(field).Tag.Get("field") != "" {
keys = append(keys, typeOfT.Field(field).Tag.Get("field"))
continue
}
keys = append(keys, strcase.KebabCase(typeOfT.Field(field).Name))
}
return keys
}

View File

@@ -1,19 +1,17 @@
package decoders package decoders
import ( import (
"github.com/traefik/paerser/parser" "github.com/traefik/paerser/env"
) )
func DecodeEnv[T any, C any](env map[string]string, subName string) (T, error) { func DecodeEnv[T any](environ []string) (T, error) {
var result T var target T
normalized := normalizeKeys[C](env, subName, "_") err := env.Decode(environ, "TINYAUTH_", &target)
err := parser.Decode(normalized, &result, "tinyauth", "tinyauth."+subName)
if err != nil { if err != nil {
return result, err return target, err
} }
return result, nil return target, nil
} }

View File

@@ -10,11 +10,11 @@ import (
func TestDecodeEnv(t *testing.T) { func TestDecodeEnv(t *testing.T) {
// Setup // Setup
env := map[string]string{ env := []string{
"PROVIDERS_GOOGLE_CLIENT_ID": "google-client-id", "TINYAUTH_PROVIDERS_GOOGLE_CLIENTID=google-client-id",
"PROVIDERS_GOOGLE_CLIENT_SECRET": "google-client-secret", "TINYAUTH_PROVIDERS_GOOGLE_CLIENTSECRET=google-client-secret",
"PROVIDERS_MY_GITHUB_CLIENT_ID": "github-client-id", "TINYAUTH_PROVIDERS_GITHUB_CLIENTID=github-client-id",
"PROVIDERS_MY_GITHUB_CLIENT_SECRET": "github-client-secret", "TINYAUTH_PROVIDERS_GITHUB_CLIENTSECRET=github-client-secret",
} }
expected := config.Providers{ expected := config.Providers{
@@ -23,7 +23,7 @@ func TestDecodeEnv(t *testing.T) {
ClientID: "google-client-id", ClientID: "google-client-id",
ClientSecret: "google-client-secret", ClientSecret: "google-client-secret",
}, },
"myGithub": { "github": {
ClientID: "github-client-id", ClientID: "github-client-id",
ClientSecret: "github-client-secret", ClientSecret: "github-client-secret",
}, },
@@ -31,7 +31,7 @@ func TestDecodeEnv(t *testing.T) {
} }
// Execute // Execute
result, err := decoders.DecodeEnv[config.Providers, config.OAuthServiceConfig](env, "providers") result, err := decoders.DecodeEnv[config.Providers](env)
assert.NilError(t, err) assert.NilError(t, err)
assert.DeepEqual(t, result, expected) assert.DeepEqual(t, result, expected)
} }

View File

@@ -3,28 +3,32 @@ package decoders
import ( import (
"strings" "strings"
"github.com/traefik/paerser/parser" "github.com/traefik/paerser/flag"
) )
func DecodeFlags[T any, C any](flags map[string]string, subName string) (T, error) { func DecodeFlags[T any](args []string) (T, error) {
var result T var target T
var formatted = []string{}
filtered := filterFlags(flags) for _, arg := range args {
normalized := normalizeKeys[C](filtered, subName, "_") argFmt := strings.TrimPrefix(arg, "--")
argParts := strings.SplitN(argFmt, "=", 2)
err := parser.Decode(normalized, &result, "tinyauth", "tinyauth."+subName) if len(argParts) != 2 {
continue
}
key := argParts[0]
value := argParts[1]
formatted = append(formatted, "--"+strings.ReplaceAll(key, "-", ".")+"="+value)
}
err := flag.Decode(formatted, &target)
if err != nil { if err != nil {
return result, err return target, err
} }
return result, nil return target, nil
}
func filterFlags(flags map[string]string) map[string]string {
filtered := make(map[string]string)
for k, v := range flags {
filtered[strings.TrimPrefix(k, "--")] = v
}
return filtered
} }

View File

@@ -10,11 +10,11 @@ import (
func TestDecodeFlags(t *testing.T) { func TestDecodeFlags(t *testing.T) {
// Setup // Setup
flags := map[string]string{ args := []string{
"--providers-google-client-id": "google-client-id", "--providers-google-clientid=google-client-id",
"--providers-google-client-secret": "google-client-secret", "--providers-google-clientsecret=google-client-secret",
"--providers-my-github-client-id": "github-client-id", "--providers-github-clientid=github-client-id",
"--providers-my-github-client-secret": "github-client-secret", "--providers-github-clientsecret=github-client-secret",
} }
expected := config.Providers{ expected := config.Providers{
@@ -23,7 +23,7 @@ func TestDecodeFlags(t *testing.T) {
ClientID: "google-client-id", ClientID: "google-client-id",
ClientSecret: "google-client-secret", ClientSecret: "google-client-secret",
}, },
"myGithub": { "github": {
ClientID: "github-client-id", ClientID: "github-client-id",
ClientSecret: "github-client-secret", ClientSecret: "github-client-secret",
}, },
@@ -31,7 +31,7 @@ func TestDecodeFlags(t *testing.T) {
} }
// Execute // Execute
result, err := decoders.DecodeFlags[config.Providers, config.OAuthServiceConfig](flags, "providers") result, err := decoders.DecodeFlags[config.Providers](args)
assert.NilError(t, err) assert.NilError(t, err)
assert.DeepEqual(t, result, expected) assert.DeepEqual(t, result, expected)
} }

View File

@@ -1,19 +1,17 @@
package decoders package decoders
import ( import (
"tinyauth/internal/config"
"github.com/traefik/paerser/parser" "github.com/traefik/paerser/parser"
) )
func DecodeLabels(labels map[string]string) (config.Apps, error) { func DecodeLabels[T any](labels map[string]string) (T, error) {
var appLabels config.Apps var target T
err := parser.Decode(labels, &appLabels, "tinyauth", "tinyauth.apps") err := parser.Decode(labels, &target, "tinyauth")
if err != nil { if err != nil {
return config.Apps{}, err return target, err
} }
return appLabels, nil return target, nil
} }

View File

@@ -9,7 +9,24 @@ import (
) )
func TestDecodeLabels(t *testing.T) { func TestDecodeLabels(t *testing.T) {
// Variables // Setup
labels := map[string]string{
"tinyauth.apps.foo.config.domain": "example.com",
"tinyauth.apps.foo.users.allow": "user1,user2",
"tinyauth.apps.foo.users.block": "user3",
"tinyauth.apps.foo.oauth.whitelist": "somebody@example.com",
"tinyauth.apps.foo.oauth.groups": "group3",
"tinyauth.apps.foo.ip.allow": "10.71.0.1/24,10.71.0.2",
"tinyauth.apps.foo.ip.block": "10.10.10.10,10.0.0.0/24",
"tinyauth.apps.foo.ip.bypass": "192.168.1.1",
"tinyauth.apps.foo.response.headers": "X-Foo=Bar,X-Baz=Qux",
"tinyauth.apps.foo.response.basicauth.username": "admin",
"tinyauth.apps.foo.response.basicauth.password": "password",
"tinyauth.apps.foo.response.basicauth.passwordfile": "/path/to/passwordfile",
"tinyauth.apps.foo.path.allow": "/public",
"tinyauth.apps.foo.path.block": "/private",
}
expected := config.Apps{ expected := config.Apps{
Apps: map[string]config.App{ Apps: map[string]config.App{
"foo": { "foo": {
@@ -44,25 +61,9 @@ func TestDecodeLabels(t *testing.T) {
}, },
}, },
} }
test := map[string]string{
"tinyauth.apps.foo.config.domain": "example.com",
"tinyauth.apps.foo.users.allow": "user1,user2",
"tinyauth.apps.foo.users.block": "user3",
"tinyauth.apps.foo.oauth.whitelist": "somebody@example.com",
"tinyauth.apps.foo.oauth.groups": "group3",
"tinyauth.apps.foo.ip.allow": "10.71.0.1/24,10.71.0.2",
"tinyauth.apps.foo.ip.block": "10.10.10.10,10.0.0.0/24",
"tinyauth.apps.foo.ip.bypass": "192.168.1.1",
"tinyauth.apps.foo.response.headers": "X-Foo=Bar,X-Baz=Qux",
"tinyauth.apps.foo.response.basicauth.username": "admin",
"tinyauth.apps.foo.response.basicauth.password": "password",
"tinyauth.apps.foo.response.basicauth.passwordfile": "/path/to/passwordfile",
"tinyauth.apps.foo.path.allow": "/public",
"tinyauth.apps.foo.path.block": "/private",
}
// Test // Test
result, err := decoders.DecodeLabels(test) result, err := decoders.DecodeLabels[config.Apps](labels)
assert.NilError(t, err) assert.NilError(t, err)
assert.DeepEqual(t, expected, result) assert.DeepEqual(t, expected, result)
} }