feat: add oauth config parsing logic

This commit is contained in:
Stavros
2025-09-12 13:16:45 +03:00
parent 68fd5ac24c
commit 5fcc50d5fd
5 changed files with 111 additions and 4 deletions

View File

@@ -68,6 +68,7 @@ type Claims struct {
type OAuthServiceConfig struct {
ClientID string
ClientSecret string
ClientSecretFile string
Scopes []string
RedirectURL string
AuthURL string

View File

@@ -6,6 +6,9 @@ import (
"net/url"
"strings"
"tinyauth/internal/config"
"tinyauth/internal/utils/decoders"
"maps"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog"
@@ -130,3 +133,54 @@ func GetLogLevel(level string) zerolog.Level {
return zerolog.InfoLevel
}
}
func GetOAuthProvidersConfig(env []string, args []string) (map[string]config.OAuthServiceConfig, error) {
providers := make(map[string]config.OAuthServiceConfig)
// Get from environment variables
envMap := make(map[string]string)
for _, e := range env {
pair := strings.SplitN(e, "=", 2)
envMap[pair[0]] = pair[1]
}
envProviders, err := decoders.DecodeEnv(envMap)
if err != nil {
return nil, err
}
maps.Copy(providers, envProviders.Providers)
// Get from flags
flagsMap := make(map[string]string)
for _, arg := range args[1:] {
if strings.HasPrefix(arg, "--") {
pair := strings.SplitN(arg[2:], "=", 2)
if len(pair) == 2 {
flagsMap[pair[0]] = pair[1]
}
}
}
flagProviders, err := decoders.DecodeFlags(flagsMap)
if err != nil {
return nil, err
}
maps.Copy(providers, flagProviders.Providers)
// For every provider get correct secret from file if set
for name, provider := range providers {
secret := GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret
provider.ClientSecretFile = ""
providers[name] = provider
}
// Return combined providers
return providers, nil
}

View File

@@ -1,6 +1,7 @@
package utils_test
import (
"os"
"testing"
"tinyauth/internal/config"
"tinyauth/internal/utils"
@@ -200,3 +201,56 @@ func TestIsRedirectSafe(t *testing.T) {
result = utils.IsRedirectSafe(redirectURL, domain)
assert.Equal(t, false, result)
}
func TestGetOAuthProvidersConfig(t *testing.T) {
env := []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET=client1-secret"}
args := []string{"/tinyauth/tinyauth", "--providers-client2-client-id=client2-id", "--providers-client2-client-secret=client2-secret"}
expected := map[string]config.OAuthServiceConfig{
"client1": {
ClientID: "client1-id",
ClientSecret: "client1-secret",
},
"client2": {
ClientID: "client2-id",
ClientSecret: "client2-secret",
},
}
result, err := utils.GetOAuthProvidersConfig(env, args)
assert.NilError(t, err)
assert.DeepEqual(t, expected, result)
// Case with no providers
env = []string{}
args = []string{"/tinyauth/tinyauth"}
expected = map[string]config.OAuthServiceConfig{}
result, err = utils.GetOAuthProvidersConfig(env, args)
assert.NilError(t, err)
assert.DeepEqual(t, expected, result)
// Case with secret from file
file, err := os.Create("/tmp/tinyauth_test_file")
assert.NilError(t, err)
_, err = file.WriteString("file content\n")
assert.NilError(t, err)
err = file.Close()
assert.NilError(t, err)
defer os.Remove("/tmp/tinyauth_test_file")
env = []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET_FILE=/tmp/tinyauth_test_file"}
args = []string{"/tinyauth/tinyauth"}
expected = map[string]config.OAuthServiceConfig{
"client1": {
ClientID: "client1-id",
ClientSecret: "file content",
},
}
result, err = utils.GetOAuthProvidersConfig(env, args)
assert.NilError(t, err)
assert.DeepEqual(t, expected, result)
}

View File

@@ -6,7 +6,6 @@ import (
"sort"
"strings"
"tinyauth/internal/config"
"tinyauth/internal/utils"
"github.com/traefik/paerser/parser"
)
@@ -127,7 +126,7 @@ func normalizeEnv(env map[string]string, rootName string) map[string]string {
fkb += s
continue
}
fkb += utils.Capitalize(s)
fkb += strings.ToUpper(string([]rune(s)[0])) + string([]rune(s)[1:])
}
fk = rootName + "_" + strings.Join(fks[:len(fks)-1], "_") + "_" + fkb
n[fk] = v

View File

@@ -6,7 +6,6 @@ import (
"sort"
"strings"
"tinyauth/internal/config"
"tinyauth/internal/utils"
"github.com/traefik/paerser/parser"
)
@@ -127,7 +126,7 @@ func normalizeFlags(flags map[string]string, rootName string) map[string]string
fkb += s
continue
}
fkb += utils.Capitalize(s)
fkb += strings.ToUpper(string([]rune(s)[0])) + string([]rune(s)[1:])
}
fk = rootName + "_" + strings.Join(fks[:len(fks)-1], "_") + "_" + fkb
n[fk] = v