diff --git a/internal/config/config.go b/internal/config/config.go index 6b143bf..f925b0c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -74,6 +74,7 @@ type OAuthServiceConfig struct { TokenURL string UserinfoURL string InsecureSkipVerify bool + Name string } // User/session related stuff @@ -178,9 +179,5 @@ type AppPath struct { // Flags type Providers struct { - Providers map[string]ProviderConfig -} - -type ProviderConfig struct { - Config OAuthServiceConfig + Providers map[string]OAuthServiceConfig } diff --git a/internal/utils/decoders/flags_decoder.go b/internal/utils/decoders/flags_decoder.go index 4154157..6a29d3a 100644 --- a/internal/utils/decoders/flags_decoder.go +++ b/internal/utils/decoders/flags_decoder.go @@ -1,6 +1,9 @@ package decoders import ( + "fmt" + "slices" + "sort" "strings" "tinyauth/internal/config" "tinyauth/internal/utils" @@ -9,43 +12,25 @@ import ( ) func DecodeFlags(flags map[string]string) (config.Providers, error) { - // Normalize flags (sorry to whoever has to read this) - // --providers-client1-client-id -> tinyauth.providers.client1.clientId - normalized := make(map[string]string) - for k, v := range flags { - newKey := "" + normalized := normalizeFlags(flags, "tinyauth") - nk := strings.TrimPrefix(k, "--") - parts := strings.SplitN(nk, "-", 4) - - for i, part := range parts { - if i == 3 { - subParts := strings.Split(part, "-") - for j, subPart := range subParts { - if j == 0 { - newKey += "." + subPart - } else { - newKey += utils.Capitalize(subPart) - } - } - continue - } - if i == 0 { - newKey += part - continue - } - newKey += "." + part - } - - newKey = "tinyauth." + newKey - normalized[newKey] = v + node, err := decodeFlagsToNode(normalized, "tinyauth", "tinyauth_providers") + if err != nil { + return config.Providers{}, err } - // Decode var providers config.Providers - err := parser.Decode(normalized, &providers, "tinyauth", "tinyauth.providers") + metaOpts := parser.MetadataOpts{TagName: "flag", AllowSliceAsStruct: true} + + err = parser.AddMetadata(&providers, node, metaOpts) + + if err != nil { + return config.Providers{}, err + } + + err = parser.Fill(&providers, node, parser.FillerOpts{AllowSliceAsStruct: true}) if err != nil { return config.Providers{}, err @@ -53,3 +38,99 @@ func DecodeFlags(flags map[string]string) (config.Providers, error) { return providers, nil } + +func decodeFlagsToNode(flags map[string]string, rootName string, filters ...string) (*parser.Node, error) { + sorted := sortFlagKeys(flags, filters) + + var node *parser.Node + + for i, k := range sorted { + split := strings.SplitN(k, "_", 4) + + if split[0] != rootName { + return nil, fmt.Errorf("invalid flag root %s", split[0]) + } + + if slices.Contains(split, "") { + return nil, fmt.Errorf("invalid element: %s", k) + } + + if i == 0 { + node = &parser.Node{} + } + + decodeFlagToNode(node, split, flags[k]) + } + + return node, nil +} + +func decodeFlagToNode(root *parser.Node, path []string, value string) { + if len(root.Name) == 0 { + root.Name = path[0] + } + + if !(len(path) > 1) { + root.Value = value + return + } + + if n := containsFlagNode(root.Children, path[1]); n != nil { + decodeFlagToNode(n, path[1:], value) + return + } + + child := &parser.Node{Name: path[1]} + decodeFlagToNode(child, path[1:], value) + root.Children = append(root.Children, child) +} + +func containsFlagNode(node []*parser.Node, name string) *parser.Node { + for _, n := range node { + if strings.EqualFold(n.Name, name) { + return n + } + } + return nil +} + +func sortFlagKeys(flags map[string]string, filters []string) []string { + var sorted []string + + for k := range flags { + if len(filters) == 0 { + sorted = append(sorted, k) + continue + } + + for _, f := range filters { + if strings.HasPrefix(k, f) { + sorted = append(sorted, k) + break + } + } + } + + sort.Strings(sorted) + return sorted +} + +// normalizeFlags converts flags from --providers-client-client-id to tinyauth_providers_client_clientId +func normalizeFlags(flags map[string]string, rootName string) map[string]string { + n := make(map[string]string) + for k, v := range flags { + fk := strings.TrimPrefix(k, "--") + fks := strings.SplitN(fk, "-", 3) + fkb := "" + for i, s := range strings.Split(fks[len(fks)-1], "-") { + if i == 0 { + fkb += s + continue + } + fkb += utils.Capitalize(s) + } + fk = rootName + "_" + strings.Join(fks[:len(fks)-1], "_") + "_" + fkb + n[fk] = v + } + return n +} diff --git a/internal/utils/decoders/flags_decoder_test.go b/internal/utils/decoders/flags_decoder_test.go index a10760a..356b4ae 100644 --- a/internal/utils/decoders/flags_decoder_test.go +++ b/internal/utils/decoders/flags_decoder_test.go @@ -11,46 +11,46 @@ import ( func TestDecodeFlags(t *testing.T) { // Variables expected := config.Providers{ - Providers: map[string]config.ProviderConfig{ + Providers: map[string]config.OAuthServiceConfig{ "client1": { - Config: config.OAuthServiceConfig{ - ClientID: "client1-id", - ClientSecret: "client1-secret", - Scopes: []string{"client1-scope1", "client1-scope2"}, - RedirectURL: "client1-redirect-url", - AuthURL: "client1-auth-url", - UserinfoURL: "client1-user-info-url", - InsecureSkipVerify: false, - }, + ClientID: "client1-id", + ClientSecret: "client1-secret", + Scopes: []string{"client1-scope1", "client1-scope2"}, + RedirectURL: "client1-redirect-url", + AuthURL: "client1-auth-url", + UserinfoURL: "client1-user-info-url", + Name: "Client1", + InsecureSkipVerify: false, }, "client2": { - Config: config.OAuthServiceConfig{ - ClientID: "client2-id", - ClientSecret: "client2-secret", - Scopes: []string{"client2-scope1", "client2-scope2"}, - RedirectURL: "client2-redirect-url", - AuthURL: "client2-auth-url", - UserinfoURL: "client2-user-info-url", - InsecureSkipVerify: false, - }, + ClientID: "client2-id", + ClientSecret: "client2-secret", + Scopes: []string{"client2-scope1", "client2-scope2"}, + RedirectURL: "client2-redirect-url", + AuthURL: "client2-auth-url", + UserinfoURL: "client2-user-info-url", + Name: "My Awesome Client2", + InsecureSkipVerify: false, }, }, } test := map[string]string{ - "--providers-client1-config-client-id": "client1-id", - "--providers-client1-config-client-secret": "client1-secret", - "--providers-client1-config-scopes": "client1-scope1,client1-scope2", - "--providers-client1-config-redirect-url": "client1-redirect-url", - "--providers-client1-config-auth-url": "client1-auth-url", - "--providers-client1-config-user-info-url": "client1-user-info-url", - "--providers-client1-config-insecure-skip-verify": "false", - "--providers-client2-config-client-id": "client2-id", - "--providers-client2-config-client-secret": "client2-secret", - "--providers-client2-config-scopes": "client2-scope1,client2-scope2", - "--providers-client2-config-redirect-url": "client2-redirect-url", - "--providers-client2-config-auth-url": "client2-auth-url", - "--providers-client2-config-user-info-url": "client2-user-info-url", - "--providers-client2-config-insecure-skip-verify": "false", + "--providers-client1-client-id": "client1-id", + "--providers-client1-client-secret": "client1-secret", + "--providers-client1-scopes": "client1-scope1,client1-scope2", + "--providers-client1-redirect-url": "client1-redirect-url", + "--providers-client1-auth-url": "client1-auth-url", + "--providers-client1-user-info-url": "client1-user-info-url", + "--providers-client1-name": "Client1", + "--providers-client1-insecure-skip-verify": "false", + "--providers-client2-client-id": "client2-id", + "--providers-client2-client-secret": "client2-secret", + "--providers-client2-scopes": "client2-scope1,client2-scope2", + "--providers-client2-redirect-url": "client2-redirect-url", + "--providers-client2-auth-url": "client2-auth-url", + "--providers-client2-user-info-url": "client2-user-info-url", + "--providers-client2-name": "My Awesome Client2", + "--providers-client2-insecure-skip-verify": "false", } // Test