From 5c866bad1ad6227b1cf9c1be4cead6d5c4be680e Mon Sep 17 00:00:00 2001 From: Stavros Date: Tue, 16 Sep 2025 13:28:28 +0300 Subject: [PATCH] feat: multiple oauth providers (#355) * feat: add flag decoder (candidate) * refactor: finalize flags decoder * feat: add env decoder * feat: add oauth config parsing logic * feat: implement backend logic for multiple oauth providers * feat: implement multiple oauth providers in the frontend * feat: add some default icons * chore: add credits for parser * feat: style oauth auto redirect screen * fix: bot suggestions * refactor: rework decoders using simpler and more efficient pattern * refactor: rework oauth name database migration --- cmd/root.go | 23 +--- frontend/src/components/icons/microsoft.tsx | 18 +++ .../icons/{generic.tsx => oauth.tsx} | 2 +- frontend/src/components/icons/pocket-id.tsx | 20 +++ frontend/src/components/icons/tailscale.tsx | 26 ++++ frontend/src/lib/i18n/locales/en-US.json | 3 + frontend/src/lib/i18n/locales/en.json | 3 + frontend/src/pages/continue-page.tsx | 2 +- frontend/src/pages/login-page.tsx | 120 +++++++++++------- frontend/src/pages/logout-page.tsx | 8 +- frontend/src/schemas/app-context-schema.ts | 11 +- frontend/src/schemas/user-context-schema.ts | 1 + .../migrations/000002_oauth_name.down.sql | 1 + .../migrations/000002_oauth_name.up.sql | 10 ++ internal/bootstrap/app_bootstrap.go | 76 ++++++----- internal/config/config.go | 89 ++++++------- internal/controller/context_controller.go | 33 +++-- .../controller/context_controller_test.go | 17 ++- internal/controller/oauth_controller.go | 1 + internal/middleware/context_middleware.go | 1 + internal/model/session_model.go | 1 + internal/service/auth_service.go | 2 + internal/service/generic_oauth_service.go | 6 + internal/service/github_oauth_service.go | 6 + internal/service/google_oauth_service.go | 6 + internal/service/oauth_broker_service.go | 1 + internal/utils/app_utils.go | 68 ++++++++++ internal/utils/app_utils_test.go | 69 ++++++++++ internal/utils/decoders/decoders.go | 81 ++++++++++++ internal/utils/decoders/decoders_test.go | 44 +++++++ internal/utils/decoders/env_decoder.go | 20 +++ internal/utils/decoders/env_decoder_test.go | 60 +++++++++ internal/utils/decoders/flags_decoder.go | 30 +++++ internal/utils/decoders/flags_decoder_test.go | 60 +++++++++ internal/utils/decoders/label_decoder_test.go | 13 +- 35 files changed, 745 insertions(+), 187 deletions(-) create mode 100644 frontend/src/components/icons/microsoft.tsx rename frontend/src/components/icons/{generic.tsx => oauth.tsx} (91%) create mode 100644 frontend/src/components/icons/pocket-id.tsx create mode 100644 frontend/src/components/icons/tailscale.tsx create mode 100644 internal/assets/migrations/000002_oauth_name.down.sql create mode 100644 internal/assets/migrations/000002_oauth_name.up.sql create mode 100644 internal/utils/decoders/decoders.go create mode 100644 internal/utils/decoders/decoders_test.go create mode 100644 internal/utils/decoders/env_decoder.go create mode 100644 internal/utils/decoders/env_decoder_test.go create mode 100644 internal/utils/decoders/flags_decoder.go create mode 100644 internal/utils/decoders/flags_decoder_test.go diff --git a/cmd/root.go b/cmd/root.go index 155ccd2..aeb96a5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -27,11 +27,6 @@ var rootCmd = &cobra.Command{ log.Fatal().Err(err).Msg("Failed to parse config") } - // Check if secrets have a file associated with them - conf.GithubClientSecret = utils.GetSecret(conf.GithubClientSecret, conf.GithubClientSecretFile) - conf.GoogleClientSecret = utils.GetSecret(conf.GoogleClientSecret, conf.GoogleClientSecretFile) - conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile) - // Validate config v := validator.New() @@ -57,6 +52,7 @@ var rootCmd = &cobra.Command{ } func Execute() { + rootCmd.FParseErrWhitelist.UnknownFlags = true err := rootCmd.Execute() if err != nil { log.Fatal().Err(err).Msg("Failed to execute command") @@ -80,21 +76,6 @@ func init() { {"users", "", "Comma separated list of users in the format username:hash."}, {"users-file", "", "Path to a file containing users in the format username:hash."}, {"secure-cookie", false, "Send cookie over secure connection only."}, - {"github-client-id", "", "Github OAuth client ID."}, - {"github-client-secret", "", "Github OAuth client secret."}, - {"github-client-secret-file", "", "Github OAuth client secret file."}, - {"google-client-id", "", "Google OAuth client ID."}, - {"google-client-secret", "", "Google OAuth client secret."}, - {"google-client-secret-file", "", "Google OAuth client secret file."}, - {"generic-client-id", "", "Generic OAuth client ID."}, - {"generic-client-secret", "", "Generic OAuth client secret."}, - {"generic-client-secret-file", "", "Generic OAuth client secret file."}, - {"generic-scopes", "", "Generic OAuth scopes."}, - {"generic-auth-url", "", "Generic OAuth auth URL."}, - {"generic-token-url", "", "Generic OAuth token URL."}, - {"generic-user-url", "", "Generic OAuth user info URL."}, - {"generic-name", "Generic", "Generic OAuth provider name."}, - {"generic-skip-ssl", false, "Skip SSL verification for the generic OAuth provider."}, {"oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth."}, {"oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)"}, {"session-expiry", 86400, "Session (cookie) expiration time in seconds."}, @@ -112,7 +93,7 @@ func init() { {"ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup."}, {"resources-dir", "/data/resources", "Path to a directory containing custom resources (e.g. background image)."}, {"database-path", "/data/tinyauth.db", "Path to the Sqlite database file."}, - {"trusted-proxies", "", "Comma separated list of trusted proxies (IP addresses) for correct client IP detection and for header ACLs."}, + {"trusted-proxies", "", "Comma separated list of trusted proxies (IP addresses or CIDRs) for correct client IP detection."}, } for _, opt := range configOptions { diff --git a/frontend/src/components/icons/microsoft.tsx b/frontend/src/components/icons/microsoft.tsx new file mode 100644 index 0000000..58d470c --- /dev/null +++ b/frontend/src/components/icons/microsoft.tsx @@ -0,0 +1,18 @@ +import type { SVGProps } from "react"; + +export function MicrosoftIcon(props: SVGProps) { + return ( + + + + + + + ); +} diff --git a/frontend/src/components/icons/generic.tsx b/frontend/src/components/icons/oauth.tsx similarity index 91% rename from frontend/src/components/icons/generic.tsx rename to frontend/src/components/icons/oauth.tsx index 6be8289..3ca531d 100644 --- a/frontend/src/components/icons/generic.tsx +++ b/frontend/src/components/icons/oauth.tsx @@ -1,6 +1,6 @@ import type { SVGProps } from "react"; -export function GenericIcon(props: SVGProps) { +export function OAuthIcon(props: SVGProps) { return ( ) { + return ( + + + + + ); +} diff --git a/frontend/src/components/icons/tailscale.tsx b/frontend/src/components/icons/tailscale.tsx new file mode 100644 index 0000000..9381b5c --- /dev/null +++ b/frontend/src/components/icons/tailscale.tsx @@ -0,0 +1,26 @@ +import type { SVGProps } from "react"; + +export function TailscaleIcon(props: SVGProps) { + return ( + + + + + + ); +} diff --git a/frontend/src/lib/i18n/locales/en-US.json b/frontend/src/lib/i18n/locales/en-US.json index 6338a88..4300428 100644 --- a/frontend/src/lib/i18n/locales/en-US.json +++ b/frontend/src/lib/i18n/locales/en-US.json @@ -14,6 +14,9 @@ "loginOauthFailSubtitle": "Failed to get OAuth URL", "loginOauthSuccessTitle": "Redirecting", "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", + "loginOauthAutoRedirectTitle": "OAuth Auto Redirect", + "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.", + "loginOauthAutoRedirectButton": "Redirect now", "continueTitle": "Continue", "continueRedirectingTitle": "Redirecting...", "continueRedirectingSubtitle": "You should be redirected to the app soon", diff --git a/frontend/src/lib/i18n/locales/en.json b/frontend/src/lib/i18n/locales/en.json index 6338a88..4300428 100644 --- a/frontend/src/lib/i18n/locales/en.json +++ b/frontend/src/lib/i18n/locales/en.json @@ -14,6 +14,9 @@ "loginOauthFailSubtitle": "Failed to get OAuth URL", "loginOauthSuccessTitle": "Redirecting", "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", + "loginOauthAutoRedirectTitle": "OAuth Auto Redirect", + "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.", + "loginOauthAutoRedirectButton": "Redirect now", "continueTitle": "Continue", "continueRedirectingTitle": "Redirecting...", "continueRedirectingSubtitle": "You should be redirected to the app soon", diff --git a/frontend/src/pages/continue-page.tsx b/frontend/src/pages/continue-page.tsx index f17bd97..dd03a4c 100644 --- a/frontend/src/pages/continue-page.tsx +++ b/frontend/src/pages/continue-page.tsx @@ -70,7 +70,7 @@ export const ContinuePage = () => { const reveal = setTimeout(() => { setLoading(false); setShowRedirectButton(true); - }, 1000); + }, 5000); return () => { clearTimeout(auto); diff --git a/frontend/src/pages/login-page.tsx b/frontend/src/pages/login-page.tsx index fd7108c..2f3bc99 100644 --- a/frontend/src/pages/login-page.tsx +++ b/frontend/src/pages/login-page.tsx @@ -1,13 +1,18 @@ import { LoginForm } from "@/components/auth/login-form"; -import { GenericIcon } from "@/components/icons/generic"; import { GithubIcon } from "@/components/icons/github"; import { GoogleIcon } from "@/components/icons/google"; +import { MicrosoftIcon } from "@/components/icons/microsoft"; +import { OAuthIcon } from "@/components/icons/oauth"; +import { PocketIDIcon } from "@/components/icons/pocket-id"; +import { TailscaleIcon } from "@/components/icons/tailscale"; +import { Button } from "@/components/ui/button"; import { Card, CardHeader, CardTitle, CardDescription, CardContent, + CardFooter, } from "@/components/ui/card"; import { OAuthButton } from "@/components/ui/oauth-button"; import { SeperatorWithChildren } from "@/components/ui/separator"; @@ -17,28 +22,40 @@ import { useIsMounted } from "@/lib/hooks/use-is-mounted"; import { LoginSchema } from "@/schemas/login-schema"; import { useMutation } from "@tanstack/react-query"; import axios, { AxiosError } from "axios"; -import { useEffect, useRef } from "react"; +import { useEffect, useRef, useState } from "react"; import { useTranslation } from "react-i18next"; import { Navigate, useLocation } from "react-router"; import { toast } from "sonner"; +const iconMap: Record = { + google: , + github: , + tailscale: , + microsoft: , + pocketid: , +}; + export const LoginPage = () => { const { isLoggedIn } = useUserContext(); - const { configuredProviders, title, oauthAutoRedirect, genericName } = - useAppContext(); + const { providers, title, oauthAutoRedirect } = useAppContext(); const { search } = useLocation(); const { t } = useTranslation(); const isMounted = useIsMounted(); + const [oauthAutoRedirectHandover, setOauthAutoRedirectHandover] = + useState(false); + const [showRedirectButton, setShowRedirectButton] = useState(false); const redirectTimer = useRef(null); + const redirectButtonTimer = useRef(null); const searchParams = new URLSearchParams(search); const redirectUri = searchParams.get("redirect_uri"); - const oauthConfigured = - configuredProviders.filter((provider) => provider !== "username").length > - 0; - const userAuthConfigured = configuredProviders.includes("username"); + const oauthProviders = providers.filter( + (provider) => provider.id !== "username", + ); + const userAuthConfigured = + providers.find((provider) => provider.id === "username") !== undefined; const oauthMutation = useMutation({ mutationFn: (provider: string) => @@ -56,6 +73,7 @@ export const LoginPage = () => { }, 500); }, onError: () => { + setOauthAutoRedirectHandover(false); toast.error(t("loginOauthFailTitle"), { description: t("loginOauthFailSubtitle"), }); @@ -96,12 +114,16 @@ export const LoginPage = () => { useEffect(() => { if (isMounted()) { if ( - oauthConfigured && - configuredProviders.includes(oauthAutoRedirect) && + oauthProviders.length !== 0 && + providers.find((provider) => provider.id === oauthAutoRedirect) && !isLoggedIn && redirectUri ) { + setOauthAutoRedirectHandover(true); oauthMutation.mutate(oauthAutoRedirect); + redirectButtonTimer.current = window.setTimeout(() => { + setShowRedirectButton(true); + }, 5000); } } }, []); @@ -109,6 +131,8 @@ export const LoginPage = () => { useEffect( () => () => { if (redirectTimer.current) clearTimeout(redirectTimer.current); + if (redirectButtonTimer.current) + clearTimeout(redirectButtonTimer.current); }, [], ); @@ -126,61 +150,63 @@ export const LoginPage = () => { return ; } + if (oauthAutoRedirectHandover) { + return ( + + + + {t("loginOauthAutoRedirectTitle")} + + + {t("loginOauthAutoRedirectSubtitle")} + + + {showRedirectButton && ( + + + + )} + + ); + } return ( {title} - {configuredProviders.length > 0 && ( + {providers.length > 0 && ( - {oauthConfigured ? t("loginTitle") : t("loginTitleSimple")} + {oauthProviders.length !== 0 + ? t("loginTitle") + : t("loginTitleSimple")} )} - {oauthConfigured && ( + {oauthProviders.length !== 0 && (
- {configuredProviders.includes("google") && ( + {oauthProviders.map((provider) => ( } + key={provider.id} + title={provider.name} + icon={iconMap[provider.id] ?? } className="w-full" - onClick={() => oauthMutation.mutate("google")} + onClick={() => oauthMutation.mutate(provider.id)} loading={ oauthMutation.isPending && - oauthMutation.variables === "google" + oauthMutation.variables === provider.id } disabled={oauthMutation.isPending || loginMutation.isPending} /> - )} - {configuredProviders.includes("github") && ( - } - className="w-full" - onClick={() => oauthMutation.mutate("github")} - loading={ - oauthMutation.isPending && - oauthMutation.variables === "github" - } - disabled={oauthMutation.isPending || loginMutation.isPending} - /> - )} - {configuredProviders.includes("generic") && ( - } - className="w-full" - onClick={() => oauthMutation.mutate("generic")} - loading={ - oauthMutation.isPending && - oauthMutation.variables === "generic" - } - disabled={oauthMutation.isPending || loginMutation.isPending} - /> - )} + ))}
)} - {userAuthConfigured && oauthConfigured && ( + {userAuthConfigured && oauthProviders.length !== 0 && ( {t("loginDivider")} )} {userAuthConfigured && ( @@ -189,7 +215,7 @@ export const LoginPage = () => { loading={loginMutation.isPending || oauthMutation.isPending} /> )} - {configuredProviders.length == 0 && ( + {providers.length == 0 && (

{t("failedToFetchProvidersTitle")}

diff --git a/frontend/src/pages/logout-page.tsx b/frontend/src/pages/logout-page.tsx index 17693bb..480d8ae 100644 --- a/frontend/src/pages/logout-page.tsx +++ b/frontend/src/pages/logout-page.tsx @@ -6,9 +6,7 @@ import { CardHeader, CardTitle, } from "@/components/ui/card"; -import { useAppContext } from "@/context/app-context"; import { useUserContext } from "@/context/user-context"; -import { capitalize } from "@/lib/utils"; import { useMutation } from "@tanstack/react-query"; import axios from "axios"; import { useEffect, useRef } from "react"; @@ -17,8 +15,7 @@ import { Navigate } from "react-router"; import { toast } from "sonner"; export const LogoutPage = () => { - const { provider, username, isLoggedIn, email } = useUserContext(); - const { genericName } = useAppContext(); + const { provider, username, isLoggedIn, email, oauthName } = useUserContext(); const { t } = useTranslation(); const redirectTimer = useRef(null); @@ -67,8 +64,7 @@ export const LogoutPage = () => { }} values={{ username: email, - provider: - provider === "generic" ? genericName : capitalize(provider), + provider: oauthName, }} /> ) : ( diff --git a/frontend/src/schemas/app-context-schema.ts b/frontend/src/schemas/app-context-schema.ts index 8931be1..ec766ee 100644 --- a/frontend/src/schemas/app-context-schema.ts +++ b/frontend/src/schemas/app-context-schema.ts @@ -1,14 +1,19 @@ import { z } from "zod"; +export const providerSchema = z.object({ + id: z.string(), + name: z.string(), + oauth: z.boolean(), +}); + export const appContextSchema = z.object({ - configuredProviders: z.array(z.string()), + providers: z.array(providerSchema), title: z.string(), - genericName: z.string(), appUrl: z.string(), cookieDomain: z.string(), forgotPasswordMessage: z.string(), - oauthAutoRedirect: z.enum(["none", "github", "google", "generic"]), backgroundImage: z.string(), + oauthAutoRedirect: z.string(), }); export type AppContextSchema = z.infer; diff --git a/frontend/src/schemas/user-context-schema.ts b/frontend/src/schemas/user-context-schema.ts index ee6682c..e7e057a 100644 --- a/frontend/src/schemas/user-context-schema.ts +++ b/frontend/src/schemas/user-context-schema.ts @@ -8,6 +8,7 @@ export const userContextSchema = z.object({ provider: z.string(), oauth: z.boolean(), totpPending: z.boolean(), + oauthName: z.string(), }); export type UserContextSchema = z.infer; diff --git a/internal/assets/migrations/000002_oauth_name.down.sql b/internal/assets/migrations/000002_oauth_name.down.sql new file mode 100644 index 0000000..75ce3b0 --- /dev/null +++ b/internal/assets/migrations/000002_oauth_name.down.sql @@ -0,0 +1 @@ +ALTER TABLE "sessions" DROP COLUMN "oauth_name"; \ No newline at end of file diff --git a/internal/assets/migrations/000002_oauth_name.up.sql b/internal/assets/migrations/000002_oauth_name.up.sql new file mode 100644 index 0000000..416bd29 --- /dev/null +++ b/internal/assets/migrations/000002_oauth_name.up.sql @@ -0,0 +1,10 @@ +ALTER TABLE "sessions" ADD COLUMN "oauth_name" TEXT; + +UPDATE "sessions" +SET "oauth_name" = CASE + WHEN LOWER("provider") = 'github' THEN 'GitHub' + WHEN LOWER("provider") = 'google' THEN 'Google' + ELSE UPPER(SUBSTR("provider", 1, 1)) || SUBSTR("provider", 2) +END +WHERE "oauth_name" IS NULL AND "provider" IS NOT NULL; + diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index db2e564..5301a76 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -3,6 +3,7 @@ package bootstrap import ( "fmt" "net/url" + "os" "strings" "tinyauth/internal/config" "tinyauth/internal/controller" @@ -45,6 +46,13 @@ func (app *BootstrapApp) Setup() error { return err } + // Get OAuth configs + oauthProviders, err := utils.GetOAuthProvidersConfig(os.Environ(), os.Args, app.Config.AppURL) + + if err != nil { + return err + } + // Get cookie domain cookieDomain, err := utils.GetCookieDomain(app.Config.AppURL) @@ -112,7 +120,7 @@ func (app *BootstrapApp) Setup() error { // Create services dockerService := service.NewDockerService() authService := service.NewAuthService(authConfig, dockerService, ldapService, database) - oauthBrokerService := service.NewOAuthBrokerService(app.getOAuthBrokerConfig()) + oauthBrokerService := service.NewOAuthBrokerService(oauthProviders) // Initialize services services := []Service{ @@ -132,13 +140,41 @@ func (app *BootstrapApp) Setup() error { } // Configured providers - var configuredProviders []string + babysit := map[string]string{ + "google": "Google", + "github": "GitHub", + } + configuredProviders := make([]controller.Provider, 0) - if authService.UserAuthConfigured() || ldapService != nil { - configuredProviders = append(configuredProviders, "username") + for id, provider := range oauthProviders { + if id == "" { + continue + } + + if provider.Name == "" { + if name, ok := babysit[id]; ok { + provider.Name = name + } else { + provider.Name = utils.Capitalize(id) + } + } + + configuredProviders = append(configuredProviders, controller.Provider{ + Name: provider.Name, + ID: id, + OAuth: true, + }) } - configuredProviders = append(configuredProviders, oauthBrokerService.GetConfiguredServices()...) + if authService.UserAuthConfigured() || ldapService != nil { + configuredProviders = append(configuredProviders, controller.Provider{ + Name: "Username", + ID: "username", + OAuth: false, + }) + } + + log.Debug().Interface("providers", configuredProviders).Msg("Authentication providers") if len(configuredProviders) == 0 { return fmt.Errorf("no authentication providers configured") @@ -179,9 +215,8 @@ func (app *BootstrapApp) Setup() error { // Create controllers contextController := controller.NewContextController(controller.ContextControllerConfig{ - ConfiguredProviders: configuredProviders, + Providers: configuredProviders, Title: app.Config.Title, - GenericName: app.Config.GenericName, AppURL: app.Config.AppURL, CookieDomain: cookieDomain, ForgotPasswordMessage: app.Config.ForgotPasswordMessage, @@ -235,30 +270,3 @@ func (app *BootstrapApp) Setup() error { return nil } - -// Temporary -func (app *BootstrapApp) getOAuthBrokerConfig() map[string]config.OAuthServiceConfig { - return map[string]config.OAuthServiceConfig{ - "google": { - ClientID: app.Config.GoogleClientId, - ClientSecret: app.Config.GoogleClientSecret, - RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", app.Config.AppURL), - }, - "github": { - ClientID: app.Config.GithubClientId, - ClientSecret: app.Config.GithubClientSecret, - RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", app.Config.AppURL), - }, - "generic": { - ClientID: app.Config.GenericClientId, - ClientSecret: app.Config.GenericClientSecret, - RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", app.Config.AppURL), - Scopes: strings.Split(app.Config.GenericScopes, ","), - AuthURL: app.Config.GenericAuthURL, - TokenURL: app.Config.GenericTokenURL, - UserinfoURL: app.Config.GenericUserURL, - InsecureSkipVerify: app.Config.GenericSkipSSL, - }, - } - -} diff --git a/internal/config/config.go b/internal/config/config.go index 7ccedd3..4fc66fc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,45 +15,30 @@ var RedirectCookieName = "tinyauth-redirect" // Main app config type Config struct { - Port int `mapstructure:"port" validate:"required"` - Address string `validate:"required,ip4_addr" mapstructure:"address"` - AppURL string `validate:"required,url" mapstructure:"app-url"` - Users string `mapstructure:"users"` - UsersFile string `mapstructure:"users-file"` - SecureCookie bool `mapstructure:"secure-cookie"` - GithubClientId string `mapstructure:"github-client-id"` - GithubClientSecret string `mapstructure:"github-client-secret"` - GithubClientSecretFile string `mapstructure:"github-client-secret-file"` - GoogleClientId string `mapstructure:"google-client-id"` - GoogleClientSecret string `mapstructure:"google-client-secret"` - GoogleClientSecretFile string `mapstructure:"google-client-secret-file"` - GenericClientId string `mapstructure:"generic-client-id"` - GenericClientSecret string `mapstructure:"generic-client-secret"` - GenericClientSecretFile string `mapstructure:"generic-client-secret-file"` - GenericScopes string `mapstructure:"generic-scopes"` - GenericAuthURL string `mapstructure:"generic-auth-url"` - GenericTokenURL string `mapstructure:"generic-token-url"` - GenericUserURL string `mapstructure:"generic-user-url"` - GenericName string `mapstructure:"generic-name"` - GenericSkipSSL bool `mapstructure:"generic-skip-ssl"` - OAuthWhitelist string `mapstructure:"oauth-whitelist"` - OAuthAutoRedirect string `mapstructure:"oauth-auto-redirect" validate:"oneof=none github google generic"` - SessionExpiry int `mapstructure:"session-expiry"` - LogLevel string `mapstructure:"log-level" validate:"oneof=trace debug info warn error fatal panic"` - Title string `mapstructure:"app-title"` - LoginTimeout int `mapstructure:"login-timeout"` - LoginMaxRetries int `mapstructure:"login-max-retries"` - ForgotPasswordMessage string `mapstructure:"forgot-password-message"` - BackgroundImage string `mapstructure:"background-image" validate:"required"` - LdapAddress string `mapstructure:"ldap-address"` - LdapBindDN string `mapstructure:"ldap-bind-dn"` - LdapBindPassword string `mapstructure:"ldap-bind-password"` - LdapBaseDN string `mapstructure:"ldap-base-dn"` - LdapInsecure bool `mapstructure:"ldap-insecure"` - LdapSearchFilter string `mapstructure:"ldap-search-filter"` - ResourcesDir string `mapstructure:"resources-dir"` - DatabasePath string `mapstructure:"database-path" validate:"required"` - TrustedProxies string `mapstructure:"trusted-proxies"` + Port int `mapstructure:"port" validate:"required"` + Address string `validate:"required,ip4_addr" mapstructure:"address"` + AppURL string `validate:"required,url" mapstructure:"app-url"` + Users string `mapstructure:"users"` + UsersFile string `mapstructure:"users-file"` + SecureCookie bool `mapstructure:"secure-cookie"` + OAuthWhitelist string `mapstructure:"oauth-whitelist"` + OAuthAutoRedirect string `mapstructure:"oauth-auto-redirect"` + SessionExpiry int `mapstructure:"session-expiry"` + LogLevel string `mapstructure:"log-level" validate:"oneof=trace debug info warn error fatal panic"` + Title string `mapstructure:"app-title"` + LoginTimeout int `mapstructure:"login-timeout"` + LoginMaxRetries int `mapstructure:"login-max-retries"` + ForgotPasswordMessage string `mapstructure:"forgot-password-message"` + BackgroundImage string `mapstructure:"background-image" validate:"required"` + LdapAddress string `mapstructure:"ldap-address"` + LdapBindDN string `mapstructure:"ldap-bind-dn"` + LdapBindPassword string `mapstructure:"ldap-bind-password"` + LdapBaseDN string `mapstructure:"ldap-base-dn"` + LdapInsecure bool `mapstructure:"ldap-insecure"` + LdapSearchFilter string `mapstructure:"ldap-search-filter"` + ResourcesDir string `mapstructure:"resources-dir"` + DatabasePath string `mapstructure:"database-path" validate:"required"` + TrustedProxies string `mapstructure:"trusted-proxies"` } // OAuth/OIDC config @@ -66,14 +51,16 @@ type Claims struct { } type OAuthServiceConfig struct { - ClientID string - ClientSecret string - Scopes []string - RedirectURL string - AuthURL string - TokenURL string - UserinfoURL string - InsecureSkipVerify bool + ClientID string `key:"client-id"` + ClientSecret string `key:"client-secret"` + ClientSecretFile string `key:"client-secret-file"` + Scopes []string `key:"scopes"` + RedirectURL string `key:"redirect-url"` + AuthURL string `key:"auth-url"` + TokenURL string `key:"token-url"` + UserinfoURL string `key:"user-info-url"` + InsecureSkipVerify bool `key:"insecure-skip-verify"` + Name string `key:"name"` } // User/session related stuff @@ -97,6 +84,7 @@ type SessionCookie struct { Provider string TotpPending bool OAuthGroups string + OAuthName string } type UserContext struct { @@ -109,6 +97,7 @@ type UserContext struct { TotpPending bool OAuthGroups string TotpEnabled bool + OAuthName string } // API responses and queries @@ -174,3 +163,9 @@ type AppPath struct { Allow string Block string } + +// Flags + +type Providers struct { + Providers map[string]OAuthServiceConfig +} diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index ee3eec6..80ec61a 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -19,25 +19,30 @@ type UserContextResponse struct { Provider string `json:"provider"` OAuth bool `json:"oauth"` TotpPending bool `json:"totpPending"` + OAuthName string `json:"oauthName"` } type AppContextResponse struct { - Status int `json:"status"` - Message string `json:"message"` - ConfiguredProviders []string `json:"configuredProviders"` - Title string `json:"title"` - GenericName string `json:"genericName"` - AppURL string `json:"appUrl"` - CookieDomain string `json:"cookieDomain"` - ForgotPasswordMessage string `json:"forgotPasswordMessage"` - BackgroundImage string `json:"backgroundImage"` - OAuthAutoRedirect string `json:"oauthAutoRedirect"` + Status int `json:"status"` + Message string `json:"message"` + Providers []Provider `json:"providers"` + Title string `json:"title"` + AppURL string `json:"appUrl"` + CookieDomain string `json:"cookieDomain"` + ForgotPasswordMessage string `json:"forgotPasswordMessage"` + BackgroundImage string `json:"backgroundImage"` + OAuthAutoRedirect string `json:"oauthAutoRedirect"` +} + +type Provider struct { + Name string `json:"name"` + ID string `json:"id"` + OAuth bool `json:"oauth"` } type ContextControllerConfig struct { - ConfiguredProviders []string + Providers []Provider Title string - GenericName string AppURL string CookieDomain string ForgotPasswordMessage string @@ -76,6 +81,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { Provider: context.Provider, OAuth: context.OAuth, TotpPending: context.TotpPending, + OAuthName: context.OAuthName, } if err != nil { @@ -96,9 +102,8 @@ func (controller *ContextController) appContextHandler(c *gin.Context) { c.JSON(200, AppContextResponse{ Status: 200, Message: "Success", - ConfiguredProviders: controller.config.ConfiguredProviders, + Providers: controller.config.Providers, Title: controller.config.Title, - GenericName: controller.config.GenericName, AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), CookieDomain: controller.config.CookieDomain, ForgotPasswordMessage: controller.config.ForgotPasswordMessage, diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go index 44f77a1..85be0b5 100644 --- a/internal/controller/context_controller_test.go +++ b/internal/controller/context_controller_test.go @@ -12,9 +12,19 @@ import ( ) var controllerCfg = controller.ContextControllerConfig{ - ConfiguredProviders: []string{"github", "google", "generic"}, + Providers: []controller.Provider{ + { + Name: "Username", + ID: "username", + OAuth: false, + }, + { + Name: "Google", + ID: "google", + OAuth: true, + }, + }, Title: "Test App", - GenericName: "Generic", AppURL: "http://localhost:8080", CookieDomain: "localhost", ForgotPasswordMessage: "Contact admin to reset your password.", @@ -58,9 +68,8 @@ func TestAppContextHandler(t *testing.T) { expectedRes := controller.AppContextResponse{ Status: 200, Message: "Success", - ConfiguredProviders: controllerCfg.ConfiguredProviders, + Providers: controllerCfg.Providers, Title: controllerCfg.Title, - GenericName: controllerCfg.GenericName, AppURL: controllerCfg.AppURL, CookieDomain: controllerCfg.CookieDomain, ForgotPasswordMessage: controllerCfg.ForgotPasswordMessage, diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index a65b53a..bf50ff9 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -186,6 +186,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { Email: user.Email, Provider: req.Provider, OAuthGroups: utils.CoalesceToString(user.Groups), + OAuthName: service.GetName(), }) if err != nil { diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 30fa623..2c903be 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -95,6 +95,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { Email: cookie.Email, Provider: cookie.Provider, OAuthGroups: cookie.OAuthGroups, + OAuthName: cookie.OAuthName, IsLoggedIn: true, OAuth: true, }) diff --git a/internal/model/session_model.go b/internal/model/session_model.go index 45e6065..0fdb6c3 100644 --- a/internal/model/session_model.go +++ b/internal/model/session_model.go @@ -9,4 +9,5 @@ type Session struct { TOTPPending bool `gorm:"column:totp_pending"` OAuthGroups string `gorm:"column:oauth_groups"` Expiry int64 `gorm:"column:expiry"` + OAuthName string `gorm:"column:oauth_name"` } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index a3f8ed0..8925e49 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -210,6 +210,7 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.Sessio TOTPPending: data.TotpPending, OAuthGroups: data.OAuthGroups, Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(), + OAuthName: data.OAuthName, } err = auth.database.Create(&session).Error @@ -278,6 +279,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, Provider: session.Provider, TotpPending: session.TOTPPending, OAuthGroups: session.OAuthGroups, + OAuthName: session.OAuthName, }, nil } diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go index 72c2357..aae89c4 100644 --- a/internal/service/generic_oauth_service.go +++ b/internal/service/generic_oauth_service.go @@ -22,6 +22,7 @@ type GenericOAuthService struct { verifier string insecureSkipVerify bool userinfoUrl string + name string } func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { @@ -38,6 +39,7 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi }, insecureSkipVerify: config.InsecureSkipVerify, userinfoUrl: config.UserinfoURL, + name: config.Name, } } @@ -115,3 +117,7 @@ func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { return user, nil } + +func (generic *GenericOAuthService) GetName() string { + return generic.name +} diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go index 26d73b1..163c2c8 100644 --- a/internal/service/github_oauth_service.go +++ b/internal/service/github_oauth_service.go @@ -33,6 +33,7 @@ type GithubOAuthService struct { context context.Context token *oauth2.Token verifier string + name string } func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { @@ -44,6 +45,7 @@ func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService Scopes: GithubOAuthScopes, Endpoint: endpoints.GitHub, }, + name: config.Name, } } @@ -167,3 +169,7 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) { return user, nil } + +func (github *GithubOAuthService) GetName() string { + return github.name +} diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go index 0f8c7eb..ab0597d 100644 --- a/internal/service/google_oauth_service.go +++ b/internal/service/google_oauth_service.go @@ -28,6 +28,7 @@ type GoogleOAuthService struct { context context.Context token *oauth2.Token verifier string + name string } func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { @@ -39,6 +40,7 @@ func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService Scopes: GoogleOAuthScopes, Endpoint: endpoints.Google, }, + name: config.Name, } } @@ -111,3 +113,7 @@ func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { return user, nil } + +func (google *GoogleOAuthService) GetName() string { + return google.name +} diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index f9df4f8..e6c6ddb 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -14,6 +14,7 @@ type OAuthService interface { GetAuthURL(state string) string VerifyCode(code string) error Userinfo() (config.Claims, error) + GetName() string } type OAuthBrokerService struct { diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index c4b98c6..643c9cf 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -6,6 +6,9 @@ import ( "net/url" "strings" "tinyauth/internal/config" + "tinyauth/internal/utils/decoders" + + "maps" "github.com/gin-gonic/gin" "github.com/rs/zerolog" @@ -130,3 +133,68 @@ func GetLogLevel(level string) zerolog.Level { return zerolog.InfoLevel } } + +func GetOAuthProvidersConfig(env []string, args []string, appUrl string) (map[string]config.OAuthServiceConfig, error) { + providers := make(map[string]config.OAuthServiceConfig) + + // Get from environment variables + envMap := make(map[string]string) + + for _, e := range env { + pair := strings.SplitN(e, "=", 2) + if len(pair) == 2 { + envMap[pair[0]] = pair[1] + } + } + + envProviders, err := decoders.DecodeEnv(envMap) + + if err != nil { + return nil, err + } + + maps.Copy(providers, envProviders.Providers) + + // Get from flags + flagsMap := make(map[string]string) + + for _, arg := range args[1:] { + if strings.HasPrefix(arg, "--") { + pair := strings.SplitN(arg[2:], "=", 2) + if len(pair) == 2 { + flagsMap[pair[0]] = pair[1] + } + } + } + + flagProviders, err := decoders.DecodeFlags(flagsMap) + + if err != nil { + return nil, err + } + + maps.Copy(providers, flagProviders.Providers) + + // For every provider get correct secret from file if set + for name, provider := range providers { + secret := GetSecret(provider.ClientSecret, provider.ClientSecretFile) + provider.ClientSecret = secret + provider.ClientSecretFile = "" + providers[name] = provider + } + + // If we have google/github providers and no redirect URL babysit them + babysitProviders := []string{"google", "github"} + + for _, name := range babysitProviders { + if provider, exists := providers[name]; exists { + if provider.RedirectURL == "" { + provider.RedirectURL = appUrl + "/api/oauth/callback/" + name + providers[name] = provider + } + } + } + + // Return combined providers + return providers, nil +} diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go index c35db3d..a7f09fe 100644 --- a/internal/utils/app_utils_test.go +++ b/internal/utils/app_utils_test.go @@ -1,6 +1,7 @@ package utils_test import ( + "os" "testing" "tinyauth/internal/config" "tinyauth/internal/utils" @@ -200,3 +201,71 @@ func TestIsRedirectSafe(t *testing.T) { result = utils.IsRedirectSafe(redirectURL, domain) assert.Equal(t, false, result) } + +func TestGetOAuthProvidersConfig(t *testing.T) { + env := []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET=client1-secret"} + args := []string{"/tinyauth/tinyauth", "--providers-client2-client-id=client2-id", "--providers-client2-client-secret=client2-secret"} + + expected := map[string]config.OAuthServiceConfig{ + "client1": { + ClientID: "client1-id", + ClientSecret: "client1-secret", + }, + "client2": { + ClientID: "client2-id", + ClientSecret: "client2-secret", + }, + } + + result, err := utils.GetOAuthProvidersConfig(env, args, "") + assert.NilError(t, err) + assert.DeepEqual(t, expected, result) + + // Case with no providers + env = []string{} + args = []string{"/tinyauth/tinyauth"} + expected = map[string]config.OAuthServiceConfig{} + + result, err = utils.GetOAuthProvidersConfig(env, args, "") + assert.NilError(t, err) + assert.DeepEqual(t, expected, result) + + // Case with secret from file + file, err := os.Create("/tmp/tinyauth_test_file") + assert.NilError(t, err) + + _, err = file.WriteString("file content\n") + assert.NilError(t, err) + + err = file.Close() + assert.NilError(t, err) + defer os.Remove("/tmp/tinyauth_test_file") + + env = []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET_FILE=/tmp/tinyauth_test_file"} + args = []string{"/tinyauth/tinyauth"} + expected = map[string]config.OAuthServiceConfig{ + "client1": { + ClientID: "client1-id", + ClientSecret: "file content", + }, + } + + result, err = utils.GetOAuthProvidersConfig(env, args, "") + assert.NilError(t, err) + assert.DeepEqual(t, expected, result) + + // Case with google provider and no redirect URL + env = []string{"PROVIDERS_GOOGLE_CLIENT_ID=google-id", "PROVIDERS_GOOGLE_CLIENT_SECRET=google-secret"} + args = []string{"/tinyauth/tinyauth"} + expected = map[string]config.OAuthServiceConfig{ + "google": { + ClientID: "google-id", + ClientSecret: "google-secret", + RedirectURL: "http://app.url/api/oauth/callback/google", + }, + } + + result, err = utils.GetOAuthProvidersConfig(env, args, "http://app.url") + assert.NilError(t, err) + assert.DeepEqual(t, expected, result) +} diff --git a/internal/utils/decoders/decoders.go b/internal/utils/decoders/decoders.go new file mode 100644 index 0000000..72a11d5 --- /dev/null +++ b/internal/utils/decoders/decoders.go @@ -0,0 +1,81 @@ +package decoders + +import ( + "reflect" + "strings" + "tinyauth/internal/config" +) + +func NormalizeKeys(keys map[string]string, rootName string, sep string) map[string]string { + normalized := make(map[string]string) + knownKeys := getKnownKeys() + + for k, v := range keys { + var finalKey []string + var suffix string + var camelClientName string + var camelField string + + finalKey = append(finalKey, rootName) + finalKey = append(finalKey, "providers") + cebabKey := strings.ToLower(k) + + for _, known := range knownKeys { + if strings.HasSuffix(cebabKey, strings.ReplaceAll(known, "-", sep)) { + suffix = known + break + } + } + + if suffix == "" { + continue + } + + clientNameParts := strings.Split(strings.TrimPrefix(strings.TrimSuffix(cebabKey, sep+strings.ReplaceAll(suffix, "-", sep)), "providers"+sep), sep) + + for i, p := range clientNameParts { + if i == 0 { + camelClientName += p + continue + } + if p == "" { + continue + } + camelClientName += strings.ToUpper(string([]rune(p)[0])) + string([]rune(p)[1:]) + } + + finalKey = append(finalKey, camelClientName) + + filedParts := strings.Split(suffix, "-") + + for i, p := range filedParts { + if i == 0 { + camelField += p + continue + } + if p == "" { + continue + } + camelField += strings.ToUpper(string([]rune(p)[0])) + string([]rune(p)[1:]) + } + + finalKey = append(finalKey, camelField) + normalized[strings.Join(finalKey, ".")] = v + } + + return normalized +} + +func getKnownKeys() []string { + var known []string + + p := config.OAuthServiceConfig{} + v := reflect.ValueOf(p) + typeOfP := v.Type() + + for field := range typeOfP.NumField() { + known = append(known, typeOfP.Field(field).Tag.Get("key")) + } + + return known +} diff --git a/internal/utils/decoders/decoders_test.go b/internal/utils/decoders/decoders_test.go new file mode 100644 index 0000000..285760c --- /dev/null +++ b/internal/utils/decoders/decoders_test.go @@ -0,0 +1,44 @@ +package decoders_test + +import ( + "testing" + "tinyauth/internal/utils/decoders" + + "gotest.tools/v3/assert" +) + +func TestNormalizeKeys(t *testing.T) { + // Test with env + test := map[string]string{ + "PROVIDERS_CLIENT1_CLIENT_ID": "my-client-id", + "PROVIDERS_CLIENT1_CLIENT_SECRET": "my-client-secret", + "PROVIDERS_MY_AWESOME_CLIENT_CLIENT_ID": "my-awesome-client-id", + "PROVIDERS_MY_AWESOME_CLIENT_CLIENT_SECRET_FILE": "/path/to/secret", + } + expected := map[string]string{ + "tinyauth.providers.client1.clientId": "my-client-id", + "tinyauth.providers.client1.clientSecret": "my-client-secret", + "tinyauth.providers.myAwesomeClient.clientId": "my-awesome-client-id", + "tinyauth.providers.myAwesomeClient.clientSecretFile": "/path/to/secret", + } + + normalized := decoders.NormalizeKeys(test, "tinyauth", "_") + assert.DeepEqual(t, normalized, expected) + + // Test with flags (assume -- is already stripped) + test = map[string]string{ + "providers-client1-client-id": "my-client-id", + "providers-client1-client-secret": "my-client-secret", + "providers-my-awesome-client-client-id": "my-awesome-client-id", + "providers-my-awesome-client-client-secret-file": "/path/to/secret", + } + expected = map[string]string{ + "tinyauth.providers.client1.clientId": "my-client-id", + "tinyauth.providers.client1.clientSecret": "my-client-secret", + "tinyauth.providers.myAwesomeClient.clientId": "my-awesome-client-id", + "tinyauth.providers.myAwesomeClient.clientSecretFile": "/path/to/secret", + } + + normalized = decoders.NormalizeKeys(test, "tinyauth", "-") + assert.DeepEqual(t, normalized, expected) +} diff --git a/internal/utils/decoders/env_decoder.go b/internal/utils/decoders/env_decoder.go new file mode 100644 index 0000000..4164aa5 --- /dev/null +++ b/internal/utils/decoders/env_decoder.go @@ -0,0 +1,20 @@ +package decoders + +import ( + "tinyauth/internal/config" + + "github.com/traefik/paerser/parser" +) + +func DecodeEnv(env map[string]string) (config.Providers, error) { + normalized := NormalizeKeys(env, "tinyauth", "_") + var providers config.Providers + + err := parser.Decode(normalized, &providers, "tinyauth", "tinyauth.providers") + + if err != nil { + return config.Providers{}, err + } + + return providers, nil +} diff --git a/internal/utils/decoders/env_decoder_test.go b/internal/utils/decoders/env_decoder_test.go new file mode 100644 index 0000000..2233241 --- /dev/null +++ b/internal/utils/decoders/env_decoder_test.go @@ -0,0 +1,60 @@ +package decoders_test + +import ( + "testing" + "tinyauth/internal/config" + "tinyauth/internal/utils/decoders" + + "gotest.tools/v3/assert" +) + +func TestDecodeEnv(t *testing.T) { + // Variables + expected := config.Providers{ + Providers: map[string]config.OAuthServiceConfig{ + "client1": { + ClientID: "client1-id", + ClientSecret: "client1-secret", + Scopes: []string{"client1-scope1", "client1-scope2"}, + RedirectURL: "client1-redirect-url", + AuthURL: "client1-auth-url", + UserinfoURL: "client1-user-info-url", + Name: "Client1", + InsecureSkipVerify: false, + }, + "client2": { + ClientID: "client2-id", + ClientSecret: "client2-secret", + Scopes: []string{"client2-scope1", "client2-scope2"}, + RedirectURL: "client2-redirect-url", + AuthURL: "client2-auth-url", + UserinfoURL: "client2-user-info-url", + Name: "My Awesome Client2", + InsecureSkipVerify: false, + }, + }, + } + test := map[string]string{ + "PROVIDERS_CLIENT1_CLIENT_ID": "client1-id", + "PROVIDERS_CLIENT1_CLIENT_SECRET": "client1-secret", + "PROVIDERS_CLIENT1_SCOPES": "client1-scope1,client1-scope2", + "PROVIDERS_CLIENT1_REDIRECT_URL": "client1-redirect-url", + "PROVIDERS_CLIENT1_AUTH_URL": "client1-auth-url", + "PROVIDERS_CLIENT1_USER_INFO_URL": "client1-user-info-url", + "PROVIDERS_CLIENT1_NAME": "Client1", + "PROVIDERS_CLIENT1_INSECURE_SKIP_VERIFY": "false", + "PROVIDERS_CLIENT2_CLIENT_ID": "client2-id", + "PROVIDERS_CLIENT2_CLIENT_SECRET": "client2-secret", + "PROVIDERS_CLIENT2_SCOPES": "client2-scope1,client2-scope2", + "PROVIDERS_CLIENT2_REDIRECT_URL": "client2-redirect-url", + "PROVIDERS_CLIENT2_AUTH_URL": "client2-auth-url", + "PROVIDERS_CLIENT2_USER_INFO_URL": "client2-user-info-url", + "PROVIDERS_CLIENT2_NAME": "My Awesome Client2", + "PROVIDERS_CLIENT2_INSECURE_SKIP_VERIFY": "false", + } + + // Test + res, err := decoders.DecodeEnv(test) + assert.NilError(t, err) + assert.DeepEqual(t, expected, res) +} diff --git a/internal/utils/decoders/flags_decoder.go b/internal/utils/decoders/flags_decoder.go new file mode 100644 index 0000000..d973d29 --- /dev/null +++ b/internal/utils/decoders/flags_decoder.go @@ -0,0 +1,30 @@ +package decoders + +import ( + "strings" + "tinyauth/internal/config" + + "github.com/traefik/paerser/parser" +) + +func DecodeFlags(flags map[string]string) (config.Providers, error) { + filtered := filterFlags(flags) + normalized := NormalizeKeys(filtered, "tinyauth", "-") + var providers config.Providers + + err := parser.Decode(normalized, &providers, "tinyauth", "tinyauth.providers") + + if err != nil { + return config.Providers{}, err + } + + return providers, nil +} + +func filterFlags(flags map[string]string) map[string]string { + filtered := make(map[string]string) + for k, v := range flags { + filtered[strings.TrimPrefix(k, "--")] = v + } + return filtered +} diff --git a/internal/utils/decoders/flags_decoder_test.go b/internal/utils/decoders/flags_decoder_test.go new file mode 100644 index 0000000..356b4ae --- /dev/null +++ b/internal/utils/decoders/flags_decoder_test.go @@ -0,0 +1,60 @@ +package decoders_test + +import ( + "testing" + "tinyauth/internal/config" + "tinyauth/internal/utils/decoders" + + "gotest.tools/v3/assert" +) + +func TestDecodeFlags(t *testing.T) { + // Variables + expected := config.Providers{ + Providers: map[string]config.OAuthServiceConfig{ + "client1": { + ClientID: "client1-id", + ClientSecret: "client1-secret", + Scopes: []string{"client1-scope1", "client1-scope2"}, + RedirectURL: "client1-redirect-url", + AuthURL: "client1-auth-url", + UserinfoURL: "client1-user-info-url", + Name: "Client1", + InsecureSkipVerify: false, + }, + "client2": { + ClientID: "client2-id", + ClientSecret: "client2-secret", + Scopes: []string{"client2-scope1", "client2-scope2"}, + RedirectURL: "client2-redirect-url", + AuthURL: "client2-auth-url", + UserinfoURL: "client2-user-info-url", + Name: "My Awesome Client2", + InsecureSkipVerify: false, + }, + }, + } + test := map[string]string{ + "--providers-client1-client-id": "client1-id", + "--providers-client1-client-secret": "client1-secret", + "--providers-client1-scopes": "client1-scope1,client1-scope2", + "--providers-client1-redirect-url": "client1-redirect-url", + "--providers-client1-auth-url": "client1-auth-url", + "--providers-client1-user-info-url": "client1-user-info-url", + "--providers-client1-name": "Client1", + "--providers-client1-insecure-skip-verify": "false", + "--providers-client2-client-id": "client2-id", + "--providers-client2-client-secret": "client2-secret", + "--providers-client2-scopes": "client2-scope1,client2-scope2", + "--providers-client2-redirect-url": "client2-redirect-url", + "--providers-client2-auth-url": "client2-auth-url", + "--providers-client2-user-info-url": "client2-user-info-url", + "--providers-client2-name": "My Awesome Client2", + "--providers-client2-insecure-skip-verify": "false", + } + + // Test + res, err := decoders.DecodeFlags(test) + assert.NilError(t, err) + assert.DeepEqual(t, expected, res) +} diff --git a/internal/utils/decoders/label_decoder_test.go b/internal/utils/decoders/label_decoder_test.go index 1df885c..63189d1 100644 --- a/internal/utils/decoders/label_decoder_test.go +++ b/internal/utils/decoders/label_decoder_test.go @@ -1,10 +1,11 @@ package decoders_test import ( - "reflect" "testing" "tinyauth/internal/config" "tinyauth/internal/utils/decoders" + + "gotest.tools/v3/assert" ) func TestDecodeLabels(t *testing.T) { @@ -62,12 +63,6 @@ func TestDecodeLabels(t *testing.T) { // Test result, err := decoders.DecodeLabels(test) - - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if reflect.DeepEqual(expected, result) == false { - t.Fatalf("Expected %v but got %v", expected, result) - } + assert.NilError(t, err) + assert.DeepEqual(t, expected, result) }