From f978ae155a097990b33a70c194f4fd40eb6296a2 Mon Sep 17 00:00:00 2001 From: Nicolas Meienberger Date: Tue, 28 Oct 2025 19:14:57 +0100 Subject: [PATCH] feat: parse apps acl flags and env dynamically --- cmd/root.go | 37 ++++++-- internal/bootstrap/app_bootstrap.go | 11 ++- internal/service/access_controls_service.go | 55 +++--------- internal/utils/app_utils.go | 50 +++++++++++ internal/utils/decoders/acl_decoder_test.go | 92 ++++++++++++++++++++ internal/utils/decoders/decoders.go | 95 +++++++++++++++++++++ internal/utils/decoders/env_decoder.go | 14 +++ internal/utils/decoders/flags_decoder.go | 15 ++++ 8 files changed, 316 insertions(+), 53 deletions(-) create mode 100644 internal/utils/decoders/acl_decoder_test.go diff --git a/cmd/root.go b/cmd/root.go index 99b6a45..334d918 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,6 +1,7 @@ package cmd import ( + "os" "strings" "tinyauth/internal/bootstrap" "tinyauth/internal/config" @@ -14,15 +15,16 @@ import ( ) type rootCmd struct { - root *cobra.Command - cmd *cobra.Command - - viper *viper.Viper + root *cobra.Command + cmd *cobra.Command + viper *viper.Viper + aclFlags map[string]string } func newRootCmd() *rootCmd { return &rootCmd{ - viper: viper.New(), + viper: viper.New(), + aclFlags: make(map[string]string), } } @@ -116,7 +118,7 @@ func (c *rootCmd) run(cmd *cobra.Command, args []string) { log.Warn().Msg("Log level set to trace, this will log sensitive information!") } - app := bootstrap.NewBootstrapApp(conf) + app := bootstrap.NewBootstrapApp(conf, c.aclFlags) err = app.Setup() if err != nil { @@ -126,6 +128,9 @@ func (c *rootCmd) run(cmd *cobra.Command, args []string) { func Run() { rootCmd := newRootCmd() + rootCmd.aclFlags = utils.ExtractACLFlags(os.Args[1:]) + os.Args = filterACLFlags(os.Args) + rootCmd.Register() root := rootCmd.GetCmd() @@ -155,3 +160,23 @@ func Run() { log.Fatal().Err(err).Msg("Failed to execute root command") } } + +func filterACLFlags(args []string) []string { + filtered := make([]string, 0) + + for i, arg := range args { + // Program name + if i == 0 { + filtered = append(filtered, arg) + continue + } + + if strings.HasPrefix(arg, "--apps-") || strings.HasPrefix(arg, "--tinyauth-apps-") { + continue + } + + filtered = append(filtered, arg) + } + + return filtered +} diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index fdbd382..583bef4 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -37,13 +37,15 @@ type Service interface { } type BootstrapApp struct { - config config.Config - uuid string + config config.Config + aclFlags map[string]string + uuid string } -func NewBootstrapApp(config config.Config) *BootstrapApp { +func NewBootstrapApp(config config.Config, aclFlags map[string]string) *BootstrapApp { return &BootstrapApp{ - config: config, + config: config, + aclFlags: aclFlags, } } @@ -140,6 +142,7 @@ func (app *BootstrapApp) Setup() error { // Create services dockerService := service.NewDockerService() aclsService := service.NewAccessControlsService(dockerService) + aclsService.SetACLFlags(app.aclFlags) authService := service.NewAuthService(authConfig, dockerService, ldapService, database) oauthBrokerService := service.NewOAuthBrokerService(oauthProviders) diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index cde27e5..f5d5bcf 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -4,70 +4,39 @@ import ( "os" "strings" "tinyauth/internal/config" - "tinyauth/internal/utils/decoders" + "tinyauth/internal/utils" "github.com/rs/zerolog/log" ) type AccessControlsService struct { - docker *DockerService - envACLs config.Apps + docker *DockerService + envACLs config.Apps + aclFlags map[string]string } func NewAccessControlsService(docker *DockerService) *AccessControlsService { return &AccessControlsService{ - docker: docker, + docker: docker, + aclFlags: make(map[string]string), } } +func (acls *AccessControlsService) SetACLFlags(flags map[string]string) { + acls.aclFlags = flags +} + func (acls *AccessControlsService) Init() error { - acls.envACLs = config.Apps{} env := os.Environ() - appEnvVars := []string{} - for _, e := range env { - if strings.HasPrefix(e, "TINYAUTH_APPS_") { - appEnvVars = append(appEnvVars, e) - } - } - - err := acls.loadEnvACLs(appEnvVars) - - if err != nil { - return err - } - - return nil -} - -func (acls *AccessControlsService) loadEnvACLs(appEnvVars []string) error { - if len(appEnvVars) == 0 { - return nil - } - - envAcls := map[string]string{} - - for _, e := range appEnvVars { - parts := strings.SplitN(e, "=", 2) - if len(parts) != 2 { - continue - } - - // Normalize key, this should use the same normalization logic as in utils/decoders/decoders.go - key := parts[0] - key = strings.ToLower(key) - key = strings.ReplaceAll(key, "_", ".") - value := parts[1] - envAcls[key] = value - } - - apps, err := decoders.DecodeLabels(envAcls) + apps, err := utils.GetACLsConfig(env, acls.aclFlags) if err != nil { return err } acls.envACLs = apps + return nil } diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index 76044c9..77e2fac 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -208,3 +208,53 @@ func GetOAuthProvidersConfig(env []string, args []string, appUrl string) (map[st // Return combined providers return providers, nil } + +func GetACLsConfig(env []string, flagsMap map[string]string) (config.Apps, error) { + apps := config.Apps{Apps: make(map[string]config.App)} + + envMap := make(map[string]string) + + for _, e := range env { + pair := strings.SplitN(e, "=", 2) + if len(pair) == 2 { + envMap[pair[0]] = pair[1] + } + } + + envApps, err := decoders.DecodeACLEnv[config.Apps](envMap, "apps") + + if err != nil { + return config.Apps{}, err + } + + if envApps.Apps != nil { + maps.Copy(apps.Apps, envApps.Apps) + } + + flagApps, err := decoders.DecodeACLFlags[config.Apps](flagsMap, "apps") + + if err != nil { + return config.Apps{}, err + } + + if flagApps.Apps != nil { + maps.Copy(apps.Apps, flagApps.Apps) + } + + return apps, nil +} + +func ExtractACLFlags(args []string) map[string]string { + aclFlags := make(map[string]string) + + for _, arg := range args { + if strings.HasPrefix(arg, "--apps-") || strings.HasPrefix(arg, "--tinyauth-apps-") { + pair := strings.SplitN(arg[2:], "=", 2) + if len(pair) == 2 { + aclFlags[pair[0]] = pair[1] + } + } + } + + return aclFlags +} diff --git a/internal/utils/decoders/acl_decoder_test.go b/internal/utils/decoders/acl_decoder_test.go new file mode 100644 index 0000000..3d0bb3c --- /dev/null +++ b/internal/utils/decoders/acl_decoder_test.go @@ -0,0 +1,92 @@ +package decoders_test + +import ( + "testing" + "tinyauth/internal/config" + "tinyauth/internal/utils/decoders" + + "gotest.tools/v3/assert" +) + +func TestDecodeACLEnv(t *testing.T) { + env := map[string]string{ + "TINYAUTH_APPS_MY_COOL_APP_CONFIG_DOMAIN": "example.com", + "TINYAUTH_APPS_MY_COOL_APP_USERS_ALLOW": "user1,user2", + "TINYAUTH_APPS_MY_COOL_APP_USERS_BLOCK": "user3", + "TINYAUTH_APPS_MY_COOL_APP_OAUTH_WHITELIST": "provider1", + "TINYAUTH_APPS_MY_COOL_APP_OAUTH_GROUPS": "group1,group2", + "TINYAUTH_APPS_OTHERAPP_CONFIG_DOMAIN": "test.com", + "TINYAUTH_APPS_OTHERAPP_USERS_ALLOW": "admin", + } + + expected := config.Apps{ + Apps: map[string]config.App{ + "my_cool_app": { + Config: config.AppConfig{ + Domain: "example.com", + }, + Users: config.AppUsers{ + Allow: "user1,user2", + Block: "user3", + }, + OAuth: config.AppOAuth{ + Whitelist: "provider1", + Groups: "group1,group2", + }, + }, + "otherapp": { + Config: config.AppConfig{ + Domain: "test.com", + }, + Users: config.AppUsers{ + Allow: "admin", + }, + }, + }, + } + + // Execute + result, err := decoders.DecodeACLEnv[config.Apps](env, "apps") + assert.NilError(t, err) + assert.DeepEqual(t, result, expected) +} + +func TestDecodeACLFlags(t *testing.T) { + // Setup + flags := map[string]string{ + "tinyauth-apps-webapp-config-domain": "webapp.example.com", + "tinyauth-apps-webapp-users-allow": "alice,bob", + "tinyauth-apps-webapp-oauth-whitelist": "google", + "tinyauth-apps-api-config-domain": "api.example.com", + "tinyauth-apps-api-users-block": "banned", + } + + expected := config.Apps{ + Apps: map[string]config.App{ + "webapp": { + Config: config.AppConfig{ + Domain: "webapp.example.com", + }, + Users: config.AppUsers{ + Allow: "alice,bob", + }, + OAuth: config.AppOAuth{ + Whitelist: "google", + }, + }, + "api": { + Config: config.AppConfig{ + Domain: "api.example.com", + }, + Users: config.AppUsers{ + Block: "banned", + }, + }, + }, + } + + // Execute + result, err := decoders.DecodeACLFlags[config.Apps](flags, "apps") + assert.NilError(t, err) + assert.DeepEqual(t, result, expected) +} diff --git a/internal/utils/decoders/decoders.go b/internal/utils/decoders/decoders.go index 28b72fb..0c7f515 100644 --- a/internal/utils/decoders/decoders.go +++ b/internal/utils/decoders/decoders.go @@ -7,6 +7,60 @@ import ( "github.com/stoewer/go-strcase" ) +func ParsePath(parts []string, idx int, t reflect.Type) []string { + if idx >= len(parts) { + return []string{} + } + + if t.Kind() == reflect.Map { + mapName := strings.ToLower(parts[idx]) + + if idx+1 >= len(parts) { + return []string{mapName} + } + + elemType := t.Elem() + keyEndIdx := idx + 1 + + if elemType.Kind() == reflect.Struct { + for i := idx + 1; i < len(parts); i++ { + found := false + + for j := 0; j < elemType.NumField(); j++ { + field := elemType.Field(j) + if strings.EqualFold(parts[i], field.Name) { + keyEndIdx = i + found = true + break + } + } + + if found { + break + } + } + } + + keyParts := parts[idx+1 : keyEndIdx] + keyName := strings.ToLower(strings.Join(keyParts, "_")) + + rest := ParsePath(parts, keyEndIdx, elemType) + return append([]string{mapName, keyName}, rest...) + } + + if t.Kind() == reflect.Struct { + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if strings.EqualFold(parts[idx], field.Name) { + rest := ParsePath(parts, idx+1, field.Type) + return append([]string{strings.ToLower(field.Name)}, rest...) + } + } + } + + return []string{} +} + func normalizeKeys[T any](input map[string]string, root string, sep string) map[string]string { knownKeys := getKnownKeys[T]() normalized := make(map[string]string) @@ -74,3 +128,44 @@ func getKnownKeys[T any]() []string { return keys } + +func normalizeACLKeys[T any](input map[string]string, root string, sep string) map[string]string { + normalized := make(map[string]string) + var t T + rootType := reflect.TypeOf(t) + + for k, v := range input { + parts := strings.Split(strings.ToLower(k), sep) + + if len(parts) < 2 { + continue + } + + if parts[0] != "tinyauth" { + continue + } + + if parts[1] != root { + continue + } + + if len(parts) > 2 { + parsedParts := ParsePath(parts[2:], 0, rootType) + + if len(parsedParts) == 0 { + continue + } + + final := "tinyauth" + final += "." + root + + for _, part := range parsedParts { + final += "." + strcase.LowerCamelCase(part) + } + + normalized[final] = v + } + } + + return normalized +} diff --git a/internal/utils/decoders/env_decoder.go b/internal/utils/decoders/env_decoder.go index 532ec64..0132adb 100644 --- a/internal/utils/decoders/env_decoder.go +++ b/internal/utils/decoders/env_decoder.go @@ -17,3 +17,17 @@ func DecodeEnv[T any, C any](env map[string]string, subName string) (T, error) { return result, nil } + +func DecodeACLEnv[T any](env map[string]string, subName string) (T, error) { + var result T + + normalized := normalizeACLKeys[T](env, subName, "_") + + err := parser.Decode(normalized, &result, "tinyauth", "tinyauth."+subName) + + if err != nil { + return result, err + } + + return result, nil +} diff --git a/internal/utils/decoders/flags_decoder.go b/internal/utils/decoders/flags_decoder.go index 0aae234..72b623e 100644 --- a/internal/utils/decoders/flags_decoder.go +++ b/internal/utils/decoders/flags_decoder.go @@ -21,6 +21,21 @@ func DecodeFlags[T any, C any](flags map[string]string, subName string) (T, erro return result, nil } +func DecodeACLFlags[T any](flags map[string]string, subName string) (T, error) { + var result T + + filtered := filterFlags(flags) + normalized := normalizeACLKeys[T](filtered, subName, "-") + + err := parser.Decode(normalized, &result, "tinyauth", "tinyauth."+subName) + + if err != nil { + return result, err + } + + return result, nil +} + func filterFlags(flags map[string]string) map[string]string { filtered := make(map[string]string) for k, v := range flags {