From 5fcc50d5fd4fc33d9c3928582b4acdbacf208e00 Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 12 Sep 2025 13:16:45 +0300 Subject: [PATCH] feat: add oauth config parsing logic --- internal/config/config.go | 1 + internal/utils/app_utils.go | 54 ++++++++++++++++++++++++ internal/utils/app_utils_test.go | 54 ++++++++++++++++++++++++ internal/utils/decoders/env_decoder.go | 3 +- internal/utils/decoders/flags_decoder.go | 3 +- 5 files changed, 111 insertions(+), 4 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index f925b0c..880d663 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -68,6 +68,7 @@ type Claims struct { type OAuthServiceConfig struct { ClientID string ClientSecret string + ClientSecretFile string Scopes []string RedirectURL string AuthURL string diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index c4b98c6..ed06746 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -6,6 +6,9 @@ import ( "net/url" "strings" "tinyauth/internal/config" + "tinyauth/internal/utils/decoders" + + "maps" "github.com/gin-gonic/gin" "github.com/rs/zerolog" @@ -130,3 +133,54 @@ func GetLogLevel(level string) zerolog.Level { return zerolog.InfoLevel } } + +func GetOAuthProvidersConfig(env []string, args []string) (map[string]config.OAuthServiceConfig, error) { + providers := make(map[string]config.OAuthServiceConfig) + + // Get from environment variables + envMap := make(map[string]string) + + for _, e := range env { + pair := strings.SplitN(e, "=", 2) + envMap[pair[0]] = pair[1] + } + + envProviders, err := decoders.DecodeEnv(envMap) + + if err != nil { + return nil, err + } + + maps.Copy(providers, envProviders.Providers) + + // Get from flags + flagsMap := make(map[string]string) + + for _, arg := range args[1:] { + if strings.HasPrefix(arg, "--") { + pair := strings.SplitN(arg[2:], "=", 2) + if len(pair) == 2 { + flagsMap[pair[0]] = pair[1] + } + } + } + + flagProviders, err := decoders.DecodeFlags(flagsMap) + + if err != nil { + return nil, err + } + + maps.Copy(providers, flagProviders.Providers) + + // For every provider get correct secret from file if set + for name, provider := range providers { + secret := GetSecret(provider.ClientSecret, provider.ClientSecretFile) + provider.ClientSecret = secret + provider.ClientSecretFile = "" + providers[name] = provider + } + + // Return combined providers + return providers, nil +} diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go index c35db3d..48bb915 100644 --- a/internal/utils/app_utils_test.go +++ b/internal/utils/app_utils_test.go @@ -1,6 +1,7 @@ package utils_test import ( + "os" "testing" "tinyauth/internal/config" "tinyauth/internal/utils" @@ -200,3 +201,56 @@ func TestIsRedirectSafe(t *testing.T) { result = utils.IsRedirectSafe(redirectURL, domain) assert.Equal(t, false, result) } + +func TestGetOAuthProvidersConfig(t *testing.T) { + env := []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET=client1-secret"} + args := []string{"/tinyauth/tinyauth", "--providers-client2-client-id=client2-id", "--providers-client2-client-secret=client2-secret"} + + expected := map[string]config.OAuthServiceConfig{ + "client1": { + ClientID: "client1-id", + ClientSecret: "client1-secret", + }, + "client2": { + ClientID: "client2-id", + ClientSecret: "client2-secret", + }, + } + + result, err := utils.GetOAuthProvidersConfig(env, args) + assert.NilError(t, err) + assert.DeepEqual(t, expected, result) + + // Case with no providers + env = []string{} + args = []string{"/tinyauth/tinyauth"} + expected = map[string]config.OAuthServiceConfig{} + + result, err = utils.GetOAuthProvidersConfig(env, args) + assert.NilError(t, err) + assert.DeepEqual(t, expected, result) + + // Case with secret from file + file, err := os.Create("/tmp/tinyauth_test_file") + assert.NilError(t, err) + + _, err = file.WriteString("file content\n") + assert.NilError(t, err) + + err = file.Close() + assert.NilError(t, err) + defer os.Remove("/tmp/tinyauth_test_file") + + env = []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET_FILE=/tmp/tinyauth_test_file"} + args = []string{"/tinyauth/tinyauth"} + expected = map[string]config.OAuthServiceConfig{ + "client1": { + ClientID: "client1-id", + ClientSecret: "file content", + }, + } + + result, err = utils.GetOAuthProvidersConfig(env, args) + assert.NilError(t, err) + assert.DeepEqual(t, expected, result) +} diff --git a/internal/utils/decoders/env_decoder.go b/internal/utils/decoders/env_decoder.go index 467c68b..fc90945 100644 --- a/internal/utils/decoders/env_decoder.go +++ b/internal/utils/decoders/env_decoder.go @@ -6,7 +6,6 @@ import ( "sort" "strings" "tinyauth/internal/config" - "tinyauth/internal/utils" "github.com/traefik/paerser/parser" ) @@ -127,7 +126,7 @@ func normalizeEnv(env map[string]string, rootName string) map[string]string { fkb += s continue } - fkb += utils.Capitalize(s) + fkb += strings.ToUpper(string([]rune(s)[0])) + string([]rune(s)[1:]) } fk = rootName + "_" + strings.Join(fks[:len(fks)-1], "_") + "_" + fkb n[fk] = v diff --git a/internal/utils/decoders/flags_decoder.go b/internal/utils/decoders/flags_decoder.go index 6a29d3a..97aac72 100644 --- a/internal/utils/decoders/flags_decoder.go +++ b/internal/utils/decoders/flags_decoder.go @@ -6,7 +6,6 @@ import ( "sort" "strings" "tinyauth/internal/config" - "tinyauth/internal/utils" "github.com/traefik/paerser/parser" ) @@ -127,7 +126,7 @@ func normalizeFlags(flags map[string]string, rootName string) map[string]string fkb += s continue } - fkb += utils.Capitalize(s) + fkb += strings.ToUpper(string([]rune(s)[0])) + string([]rune(s)[1:]) } fk = rootName + "_" + strings.Join(fks[:len(fks)-1], "_") + "_" + fkb n[fk] = v