mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-10-30 21:55:43 +00:00 
			
		
		
		
	Compare commits
	
		
			12 Commits
		
	
	
		
			docs/updat
			...
			7795a989cd
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 7795a989cd | ||
|   | cebce1a92c | ||
|   | 120ae2c79d | ||
|   | 060e20e578 | ||
|   | e001f63eb5 | ||
|   | 9f97a4ddd5 | ||
|   | e5ecf6336f | ||
|   | fbf5843592 | ||
|   | 5fcc50d5fd | ||
|   | 68fd5ac24c | ||
|   | b30b908de3 | ||
|   | 91048c16f8 | 
							
								
								
									
										23
									
								
								cmd/root.go
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								cmd/root.go
									
									
									
									
									
								
							| @@ -27,11 +27,6 @@ var rootCmd = &cobra.Command{ | |||||||
| 			log.Fatal().Err(err).Msg("Failed to parse config") | 			log.Fatal().Err(err).Msg("Failed to parse config") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Check if secrets have a file associated with them |  | ||||||
| 		conf.GithubClientSecret = utils.GetSecret(conf.GithubClientSecret, conf.GithubClientSecretFile) |  | ||||||
| 		conf.GoogleClientSecret = utils.GetSecret(conf.GoogleClientSecret, conf.GoogleClientSecretFile) |  | ||||||
| 		conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile) |  | ||||||
|  |  | ||||||
| 		// Validate config | 		// Validate config | ||||||
| 		v := validator.New() | 		v := validator.New() | ||||||
|  |  | ||||||
| @@ -57,6 +52,7 @@ var rootCmd = &cobra.Command{ | |||||||
| } | } | ||||||
|  |  | ||||||
| func Execute() { | func Execute() { | ||||||
|  | 	rootCmd.FParseErrWhitelist.UnknownFlags = true | ||||||
| 	err := rootCmd.Execute() | 	err := rootCmd.Execute() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatal().Err(err).Msg("Failed to execute command") | 		log.Fatal().Err(err).Msg("Failed to execute command") | ||||||
| @@ -80,21 +76,6 @@ func init() { | |||||||
| 		{"users", "", "Comma separated list of users in the format username:hash."}, | 		{"users", "", "Comma separated list of users in the format username:hash."}, | ||||||
| 		{"users-file", "", "Path to a file containing users in the format username:hash."}, | 		{"users-file", "", "Path to a file containing users in the format username:hash."}, | ||||||
| 		{"secure-cookie", false, "Send cookie over secure connection only."}, | 		{"secure-cookie", false, "Send cookie over secure connection only."}, | ||||||
| 		{"github-client-id", "", "Github OAuth client ID."}, |  | ||||||
| 		{"github-client-secret", "", "Github OAuth client secret."}, |  | ||||||
| 		{"github-client-secret-file", "", "Github OAuth client secret file."}, |  | ||||||
| 		{"google-client-id", "", "Google OAuth client ID."}, |  | ||||||
| 		{"google-client-secret", "", "Google OAuth client secret."}, |  | ||||||
| 		{"google-client-secret-file", "", "Google OAuth client secret file."}, |  | ||||||
| 		{"generic-client-id", "", "Generic OAuth client ID."}, |  | ||||||
| 		{"generic-client-secret", "", "Generic OAuth client secret."}, |  | ||||||
| 		{"generic-client-secret-file", "", "Generic OAuth client secret file."}, |  | ||||||
| 		{"generic-scopes", "", "Generic OAuth scopes."}, |  | ||||||
| 		{"generic-auth-url", "", "Generic OAuth auth URL."}, |  | ||||||
| 		{"generic-token-url", "", "Generic OAuth token URL."}, |  | ||||||
| 		{"generic-user-url", "", "Generic OAuth user info URL."}, |  | ||||||
| 		{"generic-name", "Generic", "Generic OAuth provider name."}, |  | ||||||
| 		{"generic-skip-ssl", false, "Skip SSL verification for the generic OAuth provider."}, |  | ||||||
| 		{"oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth."}, | 		{"oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth."}, | ||||||
| 		{"oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)"}, | 		{"oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)"}, | ||||||
| 		{"session-expiry", 86400, "Session (cookie) expiration time in seconds."}, | 		{"session-expiry", 86400, "Session (cookie) expiration time in seconds."}, | ||||||
| @@ -112,7 +93,7 @@ func init() { | |||||||
| 		{"ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup."}, | 		{"ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup."}, | ||||||
| 		{"resources-dir", "/data/resources", "Path to a directory containing custom resources (e.g. background image)."}, | 		{"resources-dir", "/data/resources", "Path to a directory containing custom resources (e.g. background image)."}, | ||||||
| 		{"database-path", "/data/tinyauth.db", "Path to the Sqlite database file."}, | 		{"database-path", "/data/tinyauth.db", "Path to the Sqlite database file."}, | ||||||
| 		{"trusted-proxies", "", "Comma separated list of trusted proxies (IP addresses) for correct client IP detection and for header ACLs."}, | 		{"trusted-proxies", "", "Comma separated list of trusted proxies (IP addresses or CIDRs) for correct client IP detection."}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, opt := range configOptions { | 	for _, opt := range configOptions { | ||||||
|   | |||||||
							
								
								
									
										18
									
								
								frontend/src/components/icons/microsoft.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								frontend/src/components/icons/microsoft.tsx
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,18 @@ | |||||||
|  | import type { SVGProps } from "react"; | ||||||
|  |  | ||||||
|  | export function MicrosoftIcon(props: SVGProps<SVGSVGElement>) { | ||||||
|  |   return ( | ||||||
|  |     <svg | ||||||
|  |       xmlns="http://www.w3.org/2000/svg" | ||||||
|  |       width="2em" | ||||||
|  |       height="2em" | ||||||
|  |       viewBox="0 0 256 256" | ||||||
|  |       {...props} | ||||||
|  |     > | ||||||
|  |       <path fill="#f1511b" d="M121.666 121.666H0V0h121.666z"></path> | ||||||
|  |       <path fill="#80cc28" d="M256 121.666H134.335V0H256z"></path> | ||||||
|  |       <path fill="#00adef" d="M121.663 256.002H0V134.336h121.663z"></path> | ||||||
|  |       <path fill="#fbbc09" d="M256 256.002H134.335V134.336H256z"></path> | ||||||
|  |     </svg> | ||||||
|  |   ); | ||||||
|  | } | ||||||
| @@ -1,6 +1,6 @@ | |||||||
| import type { SVGProps } from "react"; | import type { SVGProps } from "react"; | ||||||
| 
 | 
 | ||||||
| export function GenericIcon(props: SVGProps<SVGSVGElement>) { | export function OAuthIcon(props: SVGProps<SVGSVGElement>) { | ||||||
|   return ( |   return ( | ||||||
|     <svg |     <svg | ||||||
|       xmlns="http://www.w3.org/2000/svg" |       xmlns="http://www.w3.org/2000/svg" | ||||||
							
								
								
									
										20
									
								
								frontend/src/components/icons/pocket-id.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								frontend/src/components/icons/pocket-id.tsx
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | |||||||
|  | import type { SVGProps } from "react"; | ||||||
|  |  | ||||||
|  | export function PocketIDIcon(props: SVGProps<SVGSVGElement>) { | ||||||
|  |   return ( | ||||||
|  |     <svg | ||||||
|  |       xmlns="http://www.w3.org/2000/svg" | ||||||
|  |       xmlSpace="preserve" | ||||||
|  |       width={512} | ||||||
|  |       height={512} | ||||||
|  |       viewBox="0 0 512 512" | ||||||
|  |       {...props} | ||||||
|  |     > | ||||||
|  |       <circle cx="256" cy="256" r="256" /> | ||||||
|  |       <path | ||||||
|  |         d="M268.6 102.4c64.4 0 116.8 52.4 116.8 116.7 0 25.3-8 49.4-23 69.6-14.8 19.9-35 34.3-58.4 41.7l-6.5 2-15.5-76.2 4.3-2c14-6.7 23-21.1 23-36.6 0-22.4-18.2-40.6-40.6-40.6S228 195.2 228 217.6c0 15.5 9 29.8 23 36.6l4.2 2-25 153.4h-69.5V102.4z" | ||||||
|  |         className="fill-white" | ||||||
|  |       /> | ||||||
|  |     </svg> | ||||||
|  |   ); | ||||||
|  | } | ||||||
							
								
								
									
										26
									
								
								frontend/src/components/icons/tailscale.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								frontend/src/components/icons/tailscale.tsx
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | |||||||
|  | import type { SVGProps } from "react"; | ||||||
|  |  | ||||||
|  | export function TailscaleIcon(props: SVGProps<SVGSVGElement>) { | ||||||
|  |   return ( | ||||||
|  |     <svg | ||||||
|  |       xmlns="http://www.w3.org/2000/svg" | ||||||
|  |       xmlSpace="preserve" | ||||||
|  |       width={512} | ||||||
|  |       height={512} | ||||||
|  |       viewBox="0 0 512 512" | ||||||
|  |       {...props} | ||||||
|  |     > | ||||||
|  |       <path | ||||||
|  |         className="opacity-80" | ||||||
|  |         fill="currentColor" | ||||||
|  |         d="M65.6 318.1c35.3 0 63.9-28.6 63.9-63.9s-28.6-63.9-63.9-63.9S1.8 219 1.8 254.2s28.6 63.9 63.8 63.9m191.6 0c35.3 0 63.9-28.6 63.9-63.9s-28.6-63.9-63.9-63.9-63.9 28.6-63.9 63.9 28.6 63.9 63.9 63.9m0 193.9c35.3 0 63.9-28.6 63.9-63.9s-28.6-63.9-63.9-63.9-63.9 28.6-63.9 63.9 28.6 63.9 63.9 63.9m189.2-193.9c35.3 0 63.9-28.6 63.9-63.9s-28.6-63.9-63.9-63.9-63.9 28.6-63.9 63.9 28.6 63.9 63.9 63.9" | ||||||
|  |       /> | ||||||
|  |  | ||||||
|  |       <path | ||||||
|  |         d="M65.6 127.7c35.3 0 63.9-28.6 63.9-63.9S100.9 0 65.6 0 1.8 28.6 1.8 63.9s28.6 63.8 63.8 63.8m0 384.3c35.3 0 63.9-28.6 63.9-63.9s-28.6-63.9-63.9-63.9-63.8 28.7-63.8 63.9S30.4 512 65.6 512m191.6-384.3c35.3 0 63.9-28.6 63.9-63.9S292.5 0 257.2 0s-63.9 28.6-63.9 63.9 28.6 63.8 63.9 63.8m189.2 0c35.3 0 63.9-28.6 63.9-63.9S481.6 0 446.4 0c-35.3 0-63.9 28.6-63.9 63.9s28.6 63.8 63.9 63.8m0 384.3c35.3 0 63.9-28.6 63.9-63.9s-28.6-63.9-63.9-63.9-63.9 28.6-63.9 63.9 28.6 63.9 63.9 63.9" | ||||||
|  |         className="opacity-20" | ||||||
|  |         fill="currentColor" | ||||||
|  |       /> | ||||||
|  |     </svg> | ||||||
|  |   ); | ||||||
|  | } | ||||||
| @@ -14,6 +14,9 @@ | |||||||
|     "loginOauthFailSubtitle": "Failed to get OAuth URL", |     "loginOauthFailSubtitle": "Failed to get OAuth URL", | ||||||
|     "loginOauthSuccessTitle": "Redirecting", |     "loginOauthSuccessTitle": "Redirecting", | ||||||
|     "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", |     "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", | ||||||
|  |     "loginOauthAutoRedirectTitle": "OAuth Auto Redirect", | ||||||
|  |     "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.", | ||||||
|  |     "loginOauthAutoRedirectButton": "Redirect now", | ||||||
|     "continueTitle": "Continue", |     "continueTitle": "Continue", | ||||||
|     "continueRedirectingTitle": "Redirecting...", |     "continueRedirectingTitle": "Redirecting...", | ||||||
|     "continueRedirectingSubtitle": "You should be redirected to the app soon", |     "continueRedirectingSubtitle": "You should be redirected to the app soon", | ||||||
|   | |||||||
| @@ -14,6 +14,9 @@ | |||||||
|     "loginOauthFailSubtitle": "Failed to get OAuth URL", |     "loginOauthFailSubtitle": "Failed to get OAuth URL", | ||||||
|     "loginOauthSuccessTitle": "Redirecting", |     "loginOauthSuccessTitle": "Redirecting", | ||||||
|     "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", |     "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", | ||||||
|  |     "loginOauthAutoRedirectTitle": "OAuth Auto Redirect", | ||||||
|  |     "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.", | ||||||
|  |     "loginOauthAutoRedirectButton": "Redirect now", | ||||||
|     "continueTitle": "Continue", |     "continueTitle": "Continue", | ||||||
|     "continueRedirectingTitle": "Redirecting...", |     "continueRedirectingTitle": "Redirecting...", | ||||||
|     "continueRedirectingSubtitle": "You should be redirected to the app soon", |     "continueRedirectingSubtitle": "You should be redirected to the app soon", | ||||||
|   | |||||||
| @@ -70,7 +70,7 @@ export const ContinuePage = () => { | |||||||
|     const reveal = setTimeout(() => { |     const reveal = setTimeout(() => { | ||||||
|       setLoading(false); |       setLoading(false); | ||||||
|       setShowRedirectButton(true); |       setShowRedirectButton(true); | ||||||
|     }, 1000); |     }, 5000); | ||||||
|  |  | ||||||
|     return () => { |     return () => { | ||||||
|       clearTimeout(auto); |       clearTimeout(auto); | ||||||
|   | |||||||
| @@ -1,13 +1,18 @@ | |||||||
| import { LoginForm } from "@/components/auth/login-form"; | import { LoginForm } from "@/components/auth/login-form"; | ||||||
| import { GenericIcon } from "@/components/icons/generic"; |  | ||||||
| import { GithubIcon } from "@/components/icons/github"; | import { GithubIcon } from "@/components/icons/github"; | ||||||
| import { GoogleIcon } from "@/components/icons/google"; | import { GoogleIcon } from "@/components/icons/google"; | ||||||
|  | import { MicrosoftIcon } from "@/components/icons/microsoft"; | ||||||
|  | import { OAuthIcon } from "@/components/icons/oauth"; | ||||||
|  | import { PocketIDIcon } from "@/components/icons/pocket-id"; | ||||||
|  | import { TailscaleIcon } from "@/components/icons/tailscale"; | ||||||
|  | import { Button } from "@/components/ui/button"; | ||||||
| import { | import { | ||||||
|   Card, |   Card, | ||||||
|   CardHeader, |   CardHeader, | ||||||
|   CardTitle, |   CardTitle, | ||||||
|   CardDescription, |   CardDescription, | ||||||
|   CardContent, |   CardContent, | ||||||
|  |   CardFooter, | ||||||
| } from "@/components/ui/card"; | } from "@/components/ui/card"; | ||||||
| import { OAuthButton } from "@/components/ui/oauth-button"; | import { OAuthButton } from "@/components/ui/oauth-button"; | ||||||
| import { SeperatorWithChildren } from "@/components/ui/separator"; | import { SeperatorWithChildren } from "@/components/ui/separator"; | ||||||
| @@ -17,28 +22,40 @@ import { useIsMounted } from "@/lib/hooks/use-is-mounted"; | |||||||
| import { LoginSchema } from "@/schemas/login-schema"; | import { LoginSchema } from "@/schemas/login-schema"; | ||||||
| import { useMutation } from "@tanstack/react-query"; | import { useMutation } from "@tanstack/react-query"; | ||||||
| import axios, { AxiosError } from "axios"; | import axios, { AxiosError } from "axios"; | ||||||
| import { useEffect, useRef } from "react"; | import { useEffect, useRef, useState } from "react"; | ||||||
| import { useTranslation } from "react-i18next"; | import { useTranslation } from "react-i18next"; | ||||||
| import { Navigate, useLocation } from "react-router"; | import { Navigate, useLocation } from "react-router"; | ||||||
| import { toast } from "sonner"; | import { toast } from "sonner"; | ||||||
|  |  | ||||||
|  | const iconMap: Record<string, React.ReactNode> = { | ||||||
|  |   google: <GoogleIcon />, | ||||||
|  |   github: <GithubIcon />, | ||||||
|  |   tailscale: <TailscaleIcon />, | ||||||
|  |   microsoft: <MicrosoftIcon />, | ||||||
|  |   pocketid: <PocketIDIcon />, | ||||||
|  | }; | ||||||
|  |  | ||||||
| export const LoginPage = () => { | export const LoginPage = () => { | ||||||
|   const { isLoggedIn } = useUserContext(); |   const { isLoggedIn } = useUserContext(); | ||||||
|   const { configuredProviders, title, oauthAutoRedirect, genericName } = |   const { providers, title, oauthAutoRedirect } = useAppContext(); | ||||||
|     useAppContext(); |  | ||||||
|   const { search } = useLocation(); |   const { search } = useLocation(); | ||||||
|   const { t } = useTranslation(); |   const { t } = useTranslation(); | ||||||
|   const isMounted = useIsMounted(); |   const isMounted = useIsMounted(); | ||||||
|  |   const [oauthAutoRedirectHandover, setOauthAutoRedirectHandover] = | ||||||
|  |     useState(false); | ||||||
|  |   const [showRedirectButton, setShowRedirectButton] = useState(false); | ||||||
|  |  | ||||||
|   const redirectTimer = useRef<number | null>(null); |   const redirectTimer = useRef<number | null>(null); | ||||||
|  |   const redirectButtonTimer = useRef<number | null>(null); | ||||||
|  |  | ||||||
|   const searchParams = new URLSearchParams(search); |   const searchParams = new URLSearchParams(search); | ||||||
|   const redirectUri = searchParams.get("redirect_uri"); |   const redirectUri = searchParams.get("redirect_uri"); | ||||||
|  |  | ||||||
|   const oauthConfigured = |   const oauthProviders = providers.filter( | ||||||
|     configuredProviders.filter((provider) => provider !== "username").length > |     (provider) => provider.id !== "username", | ||||||
|     0; |   ); | ||||||
|   const userAuthConfigured = configuredProviders.includes("username"); |   const userAuthConfigured = | ||||||
|  |     providers.find((provider) => provider.id === "username") !== undefined; | ||||||
|  |  | ||||||
|   const oauthMutation = useMutation({ |   const oauthMutation = useMutation({ | ||||||
|     mutationFn: (provider: string) => |     mutationFn: (provider: string) => | ||||||
| @@ -56,6 +73,7 @@ export const LoginPage = () => { | |||||||
|       }, 500); |       }, 500); | ||||||
|     }, |     }, | ||||||
|     onError: () => { |     onError: () => { | ||||||
|  |       setOauthAutoRedirectHandover(false); | ||||||
|       toast.error(t("loginOauthFailTitle"), { |       toast.error(t("loginOauthFailTitle"), { | ||||||
|         description: t("loginOauthFailSubtitle"), |         description: t("loginOauthFailSubtitle"), | ||||||
|       }); |       }); | ||||||
| @@ -96,12 +114,16 @@ export const LoginPage = () => { | |||||||
|   useEffect(() => { |   useEffect(() => { | ||||||
|     if (isMounted()) { |     if (isMounted()) { | ||||||
|       if ( |       if ( | ||||||
|         oauthConfigured && |         oauthProviders.length !== 0 && | ||||||
|         configuredProviders.includes(oauthAutoRedirect) && |         providers.find((provider) => provider.id === oauthAutoRedirect) && | ||||||
|         !isLoggedIn && |         !isLoggedIn && | ||||||
|         redirectUri |         redirectUri | ||||||
|       ) { |       ) { | ||||||
|  |         setOauthAutoRedirectHandover(true); | ||||||
|         oauthMutation.mutate(oauthAutoRedirect); |         oauthMutation.mutate(oauthAutoRedirect); | ||||||
|  |         redirectButtonTimer.current = window.setTimeout(() => { | ||||||
|  |           setShowRedirectButton(true); | ||||||
|  |         }, 5000); | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|   }, []); |   }, []); | ||||||
| @@ -109,6 +131,8 @@ export const LoginPage = () => { | |||||||
|   useEffect( |   useEffect( | ||||||
|     () => () => { |     () => () => { | ||||||
|       if (redirectTimer.current) clearTimeout(redirectTimer.current); |       if (redirectTimer.current) clearTimeout(redirectTimer.current); | ||||||
|  |       if (redirectButtonTimer.current) | ||||||
|  |         clearTimeout(redirectButtonTimer.current); | ||||||
|     }, |     }, | ||||||
|     [], |     [], | ||||||
|   ); |   ); | ||||||
| @@ -126,61 +150,63 @@ export const LoginPage = () => { | |||||||
|     return <Navigate to="/logout" replace />; |     return <Navigate to="/logout" replace />; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   if (oauthAutoRedirectHandover) { | ||||||
|  |     return ( | ||||||
|  |       <Card className="min-w-xs sm:min-w-sm"> | ||||||
|  |         <CardHeader> | ||||||
|  |           <CardTitle className="text-3xl"> | ||||||
|  |             {t("loginOauthAutoRedirectTitle")} | ||||||
|  |           </CardTitle> | ||||||
|  |           <CardDescription> | ||||||
|  |             {t("loginOauthAutoRedirectSubtitle")} | ||||||
|  |           </CardDescription> | ||||||
|  |         </CardHeader> | ||||||
|  |         {showRedirectButton && ( | ||||||
|  |           <CardFooter className="flex flex-col items-stretch"> | ||||||
|  |             <Button | ||||||
|  |               onClick={() => { | ||||||
|  |                 window.location.replace(oauthMutation.data?.data.url); | ||||||
|  |               }} | ||||||
|  |             > | ||||||
|  |               {t("loginOauthAutoRedirectButton")} | ||||||
|  |             </Button> | ||||||
|  |           </CardFooter> | ||||||
|  |         )} | ||||||
|  |       </Card> | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|   return ( |   return ( | ||||||
|     <Card className="min-w-xs sm:min-w-sm"> |     <Card className="min-w-xs sm:min-w-sm"> | ||||||
|       <CardHeader> |       <CardHeader> | ||||||
|         <CardTitle className="text-center text-3xl">{title}</CardTitle> |         <CardTitle className="text-center text-3xl">{title}</CardTitle> | ||||||
|         {configuredProviders.length > 0 && ( |         {providers.length > 0 && ( | ||||||
|           <CardDescription className="text-center"> |           <CardDescription className="text-center"> | ||||||
|             {oauthConfigured ? t("loginTitle") : t("loginTitleSimple")} |             {oauthProviders.length !== 0 | ||||||
|  |               ? t("loginTitle") | ||||||
|  |               : t("loginTitleSimple")} | ||||||
|           </CardDescription> |           </CardDescription> | ||||||
|         )} |         )} | ||||||
|       </CardHeader> |       </CardHeader> | ||||||
|       <CardContent className="flex flex-col gap-4"> |       <CardContent className="flex flex-col gap-4"> | ||||||
|         {oauthConfigured && ( |         {oauthProviders.length !== 0 && ( | ||||||
|           <div className="flex flex-col gap-2 items-center justify-center"> |           <div className="flex flex-col gap-2 items-center justify-center"> | ||||||
|             {configuredProviders.includes("google") && ( |             {oauthProviders.map((provider) => ( | ||||||
|               <OAuthButton |               <OAuthButton | ||||||
|                 title="Google" |                 key={provider.id} | ||||||
|                 icon={<GoogleIcon />} |                 title={provider.name} | ||||||
|  |                 icon={iconMap[provider.id] ?? <OAuthIcon />} | ||||||
|                 className="w-full" |                 className="w-full" | ||||||
|                 onClick={() => oauthMutation.mutate("google")} |                 onClick={() => oauthMutation.mutate(provider.id)} | ||||||
|                 loading={ |                 loading={ | ||||||
|                   oauthMutation.isPending && |                   oauthMutation.isPending && | ||||||
|                   oauthMutation.variables === "google" |                   oauthMutation.variables === provider.id | ||||||
|                 } |                 } | ||||||
|                 disabled={oauthMutation.isPending || loginMutation.isPending} |                 disabled={oauthMutation.isPending || loginMutation.isPending} | ||||||
|               /> |               /> | ||||||
|             )} |             ))} | ||||||
|             {configuredProviders.includes("github") && ( |  | ||||||
|               <OAuthButton |  | ||||||
|                 title="Github" |  | ||||||
|                 icon={<GithubIcon />} |  | ||||||
|                 className="w-full" |  | ||||||
|                 onClick={() => oauthMutation.mutate("github")} |  | ||||||
|                 loading={ |  | ||||||
|                   oauthMutation.isPending && |  | ||||||
|                   oauthMutation.variables === "github" |  | ||||||
|                 } |  | ||||||
|                 disabled={oauthMutation.isPending || loginMutation.isPending} |  | ||||||
|               /> |  | ||||||
|             )} |  | ||||||
|             {configuredProviders.includes("generic") && ( |  | ||||||
|               <OAuthButton |  | ||||||
|                 title={genericName} |  | ||||||
|                 icon={<GenericIcon />} |  | ||||||
|                 className="w-full" |  | ||||||
|                 onClick={() => oauthMutation.mutate("generic")} |  | ||||||
|                 loading={ |  | ||||||
|                   oauthMutation.isPending && |  | ||||||
|                   oauthMutation.variables === "generic" |  | ||||||
|                 } |  | ||||||
|                 disabled={oauthMutation.isPending || loginMutation.isPending} |  | ||||||
|               /> |  | ||||||
|             )} |  | ||||||
|           </div> |           </div> | ||||||
|         )} |         )} | ||||||
|         {userAuthConfigured && oauthConfigured && ( |         {userAuthConfigured && oauthProviders.length !== 0 && ( | ||||||
|           <SeperatorWithChildren>{t("loginDivider")}</SeperatorWithChildren> |           <SeperatorWithChildren>{t("loginDivider")}</SeperatorWithChildren> | ||||||
|         )} |         )} | ||||||
|         {userAuthConfigured && ( |         {userAuthConfigured && ( | ||||||
| @@ -189,7 +215,7 @@ export const LoginPage = () => { | |||||||
|             loading={loginMutation.isPending || oauthMutation.isPending} |             loading={loginMutation.isPending || oauthMutation.isPending} | ||||||
|           /> |           /> | ||||||
|         )} |         )} | ||||||
|         {configuredProviders.length == 0 && ( |         {providers.length == 0 && ( | ||||||
|           <p className="text-center text-red-600 max-w-sm"> |           <p className="text-center text-red-600 max-w-sm"> | ||||||
|             {t("failedToFetchProvidersTitle")} |             {t("failedToFetchProvidersTitle")} | ||||||
|           </p> |           </p> | ||||||
|   | |||||||
| @@ -6,9 +6,7 @@ import { | |||||||
|   CardHeader, |   CardHeader, | ||||||
|   CardTitle, |   CardTitle, | ||||||
| } from "@/components/ui/card"; | } from "@/components/ui/card"; | ||||||
| import { useAppContext } from "@/context/app-context"; |  | ||||||
| import { useUserContext } from "@/context/user-context"; | import { useUserContext } from "@/context/user-context"; | ||||||
| import { capitalize } from "@/lib/utils"; |  | ||||||
| import { useMutation } from "@tanstack/react-query"; | import { useMutation } from "@tanstack/react-query"; | ||||||
| import axios from "axios"; | import axios from "axios"; | ||||||
| import { useEffect, useRef } from "react"; | import { useEffect, useRef } from "react"; | ||||||
| @@ -17,8 +15,7 @@ import { Navigate } from "react-router"; | |||||||
| import { toast } from "sonner"; | import { toast } from "sonner"; | ||||||
|  |  | ||||||
| export const LogoutPage = () => { | export const LogoutPage = () => { | ||||||
|   const { provider, username, isLoggedIn, email } = useUserContext(); |   const { provider, username, isLoggedIn, email, oauthName } = useUserContext(); | ||||||
|   const { genericName } = useAppContext(); |  | ||||||
|   const { t } = useTranslation(); |   const { t } = useTranslation(); | ||||||
|  |  | ||||||
|   const redirectTimer = useRef<number | null>(null); |   const redirectTimer = useRef<number | null>(null); | ||||||
| @@ -67,8 +64,7 @@ export const LogoutPage = () => { | |||||||
|               }} |               }} | ||||||
|               values={{ |               values={{ | ||||||
|                 username: email, |                 username: email, | ||||||
|                 provider: |                 provider: oauthName, | ||||||
|                   provider === "generic" ? genericName : capitalize(provider), |  | ||||||
|               }} |               }} | ||||||
|             /> |             /> | ||||||
|           ) : ( |           ) : ( | ||||||
|   | |||||||
| @@ -1,14 +1,19 @@ | |||||||
| import { z } from "zod"; | import { z } from "zod"; | ||||||
|  |  | ||||||
|  | export const providerSchema = z.object({ | ||||||
|  |   id: z.string(), | ||||||
|  |   name: z.string(), | ||||||
|  |   oauth: z.boolean(), | ||||||
|  | }); | ||||||
|  |  | ||||||
| export const appContextSchema = z.object({ | export const appContextSchema = z.object({ | ||||||
|   configuredProviders: z.array(z.string()), |   providers: z.array(providerSchema), | ||||||
|   title: z.string(), |   title: z.string(), | ||||||
|   genericName: z.string(), |  | ||||||
|   appUrl: z.string(), |   appUrl: z.string(), | ||||||
|   cookieDomain: z.string(), |   cookieDomain: z.string(), | ||||||
|   forgotPasswordMessage: z.string(), |   forgotPasswordMessage: z.string(), | ||||||
|   oauthAutoRedirect: z.enum(["none", "github", "google", "generic"]), |  | ||||||
|   backgroundImage: z.string(), |   backgroundImage: z.string(), | ||||||
|  |   oauthAutoRedirect: z.string(), | ||||||
| }); | }); | ||||||
|  |  | ||||||
| export type AppContextSchema = z.infer<typeof appContextSchema>; | export type AppContextSchema = z.infer<typeof appContextSchema>; | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ export const userContextSchema = z.object({ | |||||||
|   provider: z.string(), |   provider: z.string(), | ||||||
|   oauth: z.boolean(), |   oauth: z.boolean(), | ||||||
|   totpPending: z.boolean(), |   totpPending: z.boolean(), | ||||||
|  |   oauthName: z.string(), | ||||||
| }); | }); | ||||||
|  |  | ||||||
| export type UserContextSchema = z.infer<typeof userContextSchema>; | export type UserContextSchema = z.infer<typeof userContextSchema>; | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								internal/assets/migrations/000002_oauth_name.down.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								internal/assets/migrations/000002_oauth_name.down.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | ALTER TABLE "sessions" DROP COLUMN "oauth_name"; | ||||||
							
								
								
									
										10
									
								
								internal/assets/migrations/000002_oauth_name.up.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								internal/assets/migrations/000002_oauth_name.up.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | |||||||
|  | ALTER TABLE "sessions" ADD COLUMN "oauth_name" TEXT; | ||||||
|  |  | ||||||
|  | UPDATE "sessions" | ||||||
|  | SET "oauth_name" = CASE | ||||||
|  |   WHEN LOWER("provider") = 'github' THEN 'GitHub' | ||||||
|  |   WHEN LOWER("provider") = 'google' THEN 'Google' | ||||||
|  |   ELSE UPPER(SUBSTR("provider", 1, 1)) || SUBSTR("provider", 2) | ||||||
|  | END | ||||||
|  | WHERE "oauth_name" IS NULL AND "provider" IS NOT NULL; | ||||||
|  |  | ||||||
| @@ -3,6 +3,7 @@ package bootstrap | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
|  | 	"os" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"tinyauth/internal/config" | 	"tinyauth/internal/config" | ||||||
| 	"tinyauth/internal/controller" | 	"tinyauth/internal/controller" | ||||||
| @@ -45,6 +46,13 @@ func (app *BootstrapApp) Setup() error { | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Get OAuth configs | ||||||
|  | 	oauthProviders, err := utils.GetOAuthProvidersConfig(os.Environ(), os.Args, app.Config.AppURL) | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Get cookie domain | 	// Get cookie domain | ||||||
| 	cookieDomain, err := utils.GetCookieDomain(app.Config.AppURL) | 	cookieDomain, err := utils.GetCookieDomain(app.Config.AppURL) | ||||||
|  |  | ||||||
| @@ -112,7 +120,7 @@ func (app *BootstrapApp) Setup() error { | |||||||
| 	// Create services | 	// Create services | ||||||
| 	dockerService := service.NewDockerService() | 	dockerService := service.NewDockerService() | ||||||
| 	authService := service.NewAuthService(authConfig, dockerService, ldapService, database) | 	authService := service.NewAuthService(authConfig, dockerService, ldapService, database) | ||||||
| 	oauthBrokerService := service.NewOAuthBrokerService(app.getOAuthBrokerConfig()) | 	oauthBrokerService := service.NewOAuthBrokerService(oauthProviders) | ||||||
|  |  | ||||||
| 	// Initialize services | 	// Initialize services | ||||||
| 	services := []Service{ | 	services := []Service{ | ||||||
| @@ -132,13 +140,41 @@ func (app *BootstrapApp) Setup() error { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Configured providers | 	// Configured providers | ||||||
| 	var configuredProviders []string | 	babysit := map[string]string{ | ||||||
|  | 		"google": "Google", | ||||||
|  | 		"github": "GitHub", | ||||||
|  | 	} | ||||||
|  | 	configuredProviders := make([]controller.Provider, 0) | ||||||
|  |  | ||||||
| 	if authService.UserAuthConfigured() || ldapService != nil { | 	for id, provider := range oauthProviders { | ||||||
| 		configuredProviders = append(configuredProviders, "username") | 		if id == "" { | ||||||
|  | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 	configuredProviders = append(configuredProviders, oauthBrokerService.GetConfiguredServices()...) | 		if provider.Name == "" { | ||||||
|  | 			if name, ok := babysit[id]; ok { | ||||||
|  | 				provider.Name = name | ||||||
|  | 			} else { | ||||||
|  | 				provider.Name = utils.Capitalize(id) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		configuredProviders = append(configuredProviders, controller.Provider{ | ||||||
|  | 			Name:  provider.Name, | ||||||
|  | 			ID:    id, | ||||||
|  | 			OAuth: true, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if authService.UserAuthConfigured() || ldapService != nil { | ||||||
|  | 		configuredProviders = append(configuredProviders, controller.Provider{ | ||||||
|  | 			Name:  "Username", | ||||||
|  | 			ID:    "username", | ||||||
|  | 			OAuth: false, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Interface("providers", configuredProviders).Msg("Authentication providers") | ||||||
|  |  | ||||||
| 	if len(configuredProviders) == 0 { | 	if len(configuredProviders) == 0 { | ||||||
| 		return fmt.Errorf("no authentication providers configured") | 		return fmt.Errorf("no authentication providers configured") | ||||||
| @@ -179,9 +215,8 @@ func (app *BootstrapApp) Setup() error { | |||||||
|  |  | ||||||
| 	// Create controllers | 	// Create controllers | ||||||
| 	contextController := controller.NewContextController(controller.ContextControllerConfig{ | 	contextController := controller.NewContextController(controller.ContextControllerConfig{ | ||||||
| 		ConfiguredProviders:   configuredProviders, | 		Providers:             configuredProviders, | ||||||
| 		Title:                 app.Config.Title, | 		Title:                 app.Config.Title, | ||||||
| 		GenericName:           app.Config.GenericName, |  | ||||||
| 		AppURL:                app.Config.AppURL, | 		AppURL:                app.Config.AppURL, | ||||||
| 		CookieDomain:          cookieDomain, | 		CookieDomain:          cookieDomain, | ||||||
| 		ForgotPasswordMessage: app.Config.ForgotPasswordMessage, | 		ForgotPasswordMessage: app.Config.ForgotPasswordMessage, | ||||||
| @@ -235,30 +270,3 @@ func (app *BootstrapApp) Setup() error { | |||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Temporary |  | ||||||
| func (app *BootstrapApp) getOAuthBrokerConfig() map[string]config.OAuthServiceConfig { |  | ||||||
| 	return map[string]config.OAuthServiceConfig{ |  | ||||||
| 		"google": { |  | ||||||
| 			ClientID:     app.Config.GoogleClientId, |  | ||||||
| 			ClientSecret: app.Config.GoogleClientSecret, |  | ||||||
| 			RedirectURL:  fmt.Sprintf("%s/api/oauth/callback/google", app.Config.AppURL), |  | ||||||
| 		}, |  | ||||||
| 		"github": { |  | ||||||
| 			ClientID:     app.Config.GithubClientId, |  | ||||||
| 			ClientSecret: app.Config.GithubClientSecret, |  | ||||||
| 			RedirectURL:  fmt.Sprintf("%s/api/oauth/callback/github", app.Config.AppURL), |  | ||||||
| 		}, |  | ||||||
| 		"generic": { |  | ||||||
| 			ClientID:           app.Config.GenericClientId, |  | ||||||
| 			ClientSecret:       app.Config.GenericClientSecret, |  | ||||||
| 			RedirectURL:        fmt.Sprintf("%s/api/oauth/callback/generic", app.Config.AppURL), |  | ||||||
| 			Scopes:             strings.Split(app.Config.GenericScopes, ","), |  | ||||||
| 			AuthURL:            app.Config.GenericAuthURL, |  | ||||||
| 			TokenURL:           app.Config.GenericTokenURL, |  | ||||||
| 			UserinfoURL:        app.Config.GenericUserURL, |  | ||||||
| 			InsecureSkipVerify: app.Config.GenericSkipSSL, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -21,23 +21,8 @@ type Config struct { | |||||||
| 	Users                 string `mapstructure:"users"` | 	Users                 string `mapstructure:"users"` | ||||||
| 	UsersFile             string `mapstructure:"users-file"` | 	UsersFile             string `mapstructure:"users-file"` | ||||||
| 	SecureCookie          bool   `mapstructure:"secure-cookie"` | 	SecureCookie          bool   `mapstructure:"secure-cookie"` | ||||||
| 	GithubClientId          string `mapstructure:"github-client-id"` |  | ||||||
| 	GithubClientSecret      string `mapstructure:"github-client-secret"` |  | ||||||
| 	GithubClientSecretFile  string `mapstructure:"github-client-secret-file"` |  | ||||||
| 	GoogleClientId          string `mapstructure:"google-client-id"` |  | ||||||
| 	GoogleClientSecret      string `mapstructure:"google-client-secret"` |  | ||||||
| 	GoogleClientSecretFile  string `mapstructure:"google-client-secret-file"` |  | ||||||
| 	GenericClientId         string `mapstructure:"generic-client-id"` |  | ||||||
| 	GenericClientSecret     string `mapstructure:"generic-client-secret"` |  | ||||||
| 	GenericClientSecretFile string `mapstructure:"generic-client-secret-file"` |  | ||||||
| 	GenericScopes           string `mapstructure:"generic-scopes"` |  | ||||||
| 	GenericAuthURL          string `mapstructure:"generic-auth-url"` |  | ||||||
| 	GenericTokenURL         string `mapstructure:"generic-token-url"` |  | ||||||
| 	GenericUserURL          string `mapstructure:"generic-user-url"` |  | ||||||
| 	GenericName             string `mapstructure:"generic-name"` |  | ||||||
| 	GenericSkipSSL          bool   `mapstructure:"generic-skip-ssl"` |  | ||||||
| 	OAuthWhitelist        string `mapstructure:"oauth-whitelist"` | 	OAuthWhitelist        string `mapstructure:"oauth-whitelist"` | ||||||
| 	OAuthAutoRedirect       string `mapstructure:"oauth-auto-redirect" validate:"oneof=none github google generic"` | 	OAuthAutoRedirect     string `mapstructure:"oauth-auto-redirect"` | ||||||
| 	SessionExpiry         int    `mapstructure:"session-expiry"` | 	SessionExpiry         int    `mapstructure:"session-expiry"` | ||||||
| 	LogLevel              string `mapstructure:"log-level" validate:"oneof=trace debug info warn error fatal panic"` | 	LogLevel              string `mapstructure:"log-level" validate:"oneof=trace debug info warn error fatal panic"` | ||||||
| 	Title                 string `mapstructure:"app-title"` | 	Title                 string `mapstructure:"app-title"` | ||||||
| @@ -66,14 +51,16 @@ type Claims struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| type OAuthServiceConfig struct { | type OAuthServiceConfig struct { | ||||||
| 	ClientID           string | 	ClientID           string   `key:"client-id"` | ||||||
| 	ClientSecret       string | 	ClientSecret       string   `key:"client-secret"` | ||||||
| 	Scopes             []string | 	ClientSecretFile   string   `key:"client-secret-file"` | ||||||
| 	RedirectURL        string | 	Scopes             []string `key:"scopes"` | ||||||
| 	AuthURL            string | 	RedirectURL        string   `key:"redirect-url"` | ||||||
| 	TokenURL           string | 	AuthURL            string   `key:"auth-url"` | ||||||
| 	UserinfoURL        string | 	TokenURL           string   `key:"token-url"` | ||||||
| 	InsecureSkipVerify bool | 	UserinfoURL        string   `key:"user-info-url"` | ||||||
|  | 	InsecureSkipVerify bool     `key:"insecure-skip-verify"` | ||||||
|  | 	Name               string   `key:"name"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // User/session related stuff | // User/session related stuff | ||||||
| @@ -97,6 +84,7 @@ type SessionCookie struct { | |||||||
| 	Provider    string | 	Provider    string | ||||||
| 	TotpPending bool | 	TotpPending bool | ||||||
| 	OAuthGroups string | 	OAuthGroups string | ||||||
|  | 	OAuthName   string | ||||||
| } | } | ||||||
|  |  | ||||||
| type UserContext struct { | type UserContext struct { | ||||||
| @@ -109,6 +97,7 @@ type UserContext struct { | |||||||
| 	TotpPending bool | 	TotpPending bool | ||||||
| 	OAuthGroups string | 	OAuthGroups string | ||||||
| 	TotpEnabled bool | 	TotpEnabled bool | ||||||
|  | 	OAuthName   string | ||||||
| } | } | ||||||
|  |  | ||||||
| // API responses and queries | // API responses and queries | ||||||
| @@ -174,3 +163,9 @@ type AppPath struct { | |||||||
| 	Allow string | 	Allow string | ||||||
| 	Block string | 	Block string | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Flags | ||||||
|  |  | ||||||
|  | type Providers struct { | ||||||
|  | 	Providers map[string]OAuthServiceConfig | ||||||
|  | } | ||||||
|   | |||||||
| @@ -19,14 +19,14 @@ type UserContextResponse struct { | |||||||
| 	Provider    string `json:"provider"` | 	Provider    string `json:"provider"` | ||||||
| 	OAuth       bool   `json:"oauth"` | 	OAuth       bool   `json:"oauth"` | ||||||
| 	TotpPending bool   `json:"totpPending"` | 	TotpPending bool   `json:"totpPending"` | ||||||
|  | 	OAuthName   string `json:"oauthName"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type AppContextResponse struct { | type AppContextResponse struct { | ||||||
| 	Status                int        `json:"status"` | 	Status                int        `json:"status"` | ||||||
| 	Message               string     `json:"message"` | 	Message               string     `json:"message"` | ||||||
| 	ConfiguredProviders   []string `json:"configuredProviders"` | 	Providers             []Provider `json:"providers"` | ||||||
| 	Title                 string     `json:"title"` | 	Title                 string     `json:"title"` | ||||||
| 	GenericName           string   `json:"genericName"` |  | ||||||
| 	AppURL                string     `json:"appUrl"` | 	AppURL                string     `json:"appUrl"` | ||||||
| 	CookieDomain          string     `json:"cookieDomain"` | 	CookieDomain          string     `json:"cookieDomain"` | ||||||
| 	ForgotPasswordMessage string     `json:"forgotPasswordMessage"` | 	ForgotPasswordMessage string     `json:"forgotPasswordMessage"` | ||||||
| @@ -34,10 +34,15 @@ type AppContextResponse struct { | |||||||
| 	OAuthAutoRedirect     string     `json:"oauthAutoRedirect"` | 	OAuthAutoRedirect     string     `json:"oauthAutoRedirect"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type Provider struct { | ||||||
|  | 	Name  string `json:"name"` | ||||||
|  | 	ID    string `json:"id"` | ||||||
|  | 	OAuth bool   `json:"oauth"` | ||||||
|  | } | ||||||
|  |  | ||||||
| type ContextControllerConfig struct { | type ContextControllerConfig struct { | ||||||
| 	ConfiguredProviders   []string | 	Providers             []Provider | ||||||
| 	Title                 string | 	Title                 string | ||||||
| 	GenericName           string |  | ||||||
| 	AppURL                string | 	AppURL                string | ||||||
| 	CookieDomain          string | 	CookieDomain          string | ||||||
| 	ForgotPasswordMessage string | 	ForgotPasswordMessage string | ||||||
| @@ -76,6 +81,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { | |||||||
| 		Provider:    context.Provider, | 		Provider:    context.Provider, | ||||||
| 		OAuth:       context.OAuth, | 		OAuth:       context.OAuth, | ||||||
| 		TotpPending: context.TotpPending, | 		TotpPending: context.TotpPending, | ||||||
|  | 		OAuthName:   context.OAuthName, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -96,9 +102,8 @@ func (controller *ContextController) appContextHandler(c *gin.Context) { | |||||||
| 	c.JSON(200, AppContextResponse{ | 	c.JSON(200, AppContextResponse{ | ||||||
| 		Status:                200, | 		Status:                200, | ||||||
| 		Message:               "Success", | 		Message:               "Success", | ||||||
| 		ConfiguredProviders:   controller.config.ConfiguredProviders, | 		Providers:             controller.config.Providers, | ||||||
| 		Title:                 controller.config.Title, | 		Title:                 controller.config.Title, | ||||||
| 		GenericName:           controller.config.GenericName, |  | ||||||
| 		AppURL:                fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), | 		AppURL:                fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), | ||||||
| 		CookieDomain:          controller.config.CookieDomain, | 		CookieDomain:          controller.config.CookieDomain, | ||||||
| 		ForgotPasswordMessage: controller.config.ForgotPasswordMessage, | 		ForgotPasswordMessage: controller.config.ForgotPasswordMessage, | ||||||
|   | |||||||
| @@ -12,9 +12,19 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| var controllerCfg = controller.ContextControllerConfig{ | var controllerCfg = controller.ContextControllerConfig{ | ||||||
| 	ConfiguredProviders:   []string{"github", "google", "generic"}, | 	Providers: []controller.Provider{ | ||||||
|  | 		{ | ||||||
|  | 			Name:  "Username", | ||||||
|  | 			ID:    "username", | ||||||
|  | 			OAuth: false, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Name:  "Google", | ||||||
|  | 			ID:    "google", | ||||||
|  | 			OAuth: true, | ||||||
|  | 		}, | ||||||
|  | 	}, | ||||||
| 	Title:                 "Test App", | 	Title:                 "Test App", | ||||||
| 	GenericName:           "Generic", |  | ||||||
| 	AppURL:                "http://localhost:8080", | 	AppURL:                "http://localhost:8080", | ||||||
| 	CookieDomain:          "localhost", | 	CookieDomain:          "localhost", | ||||||
| 	ForgotPasswordMessage: "Contact admin to reset your password.", | 	ForgotPasswordMessage: "Contact admin to reset your password.", | ||||||
| @@ -58,9 +68,8 @@ func TestAppContextHandler(t *testing.T) { | |||||||
| 	expectedRes := controller.AppContextResponse{ | 	expectedRes := controller.AppContextResponse{ | ||||||
| 		Status:                200, | 		Status:                200, | ||||||
| 		Message:               "Success", | 		Message:               "Success", | ||||||
| 		ConfiguredProviders:   controllerCfg.ConfiguredProviders, | 		Providers:             controllerCfg.Providers, | ||||||
| 		Title:                 controllerCfg.Title, | 		Title:                 controllerCfg.Title, | ||||||
| 		GenericName:           controllerCfg.GenericName, |  | ||||||
| 		AppURL:                controllerCfg.AppURL, | 		AppURL:                controllerCfg.AppURL, | ||||||
| 		CookieDomain:          controllerCfg.CookieDomain, | 		CookieDomain:          controllerCfg.CookieDomain, | ||||||
| 		ForgotPasswordMessage: controllerCfg.ForgotPasswordMessage, | 		ForgotPasswordMessage: controllerCfg.ForgotPasswordMessage, | ||||||
|   | |||||||
| @@ -186,6 +186,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { | |||||||
| 		Email:       user.Email, | 		Email:       user.Email, | ||||||
| 		Provider:    req.Provider, | 		Provider:    req.Provider, | ||||||
| 		OAuthGroups: utils.CoalesceToString(user.Groups), | 		OAuthGroups: utils.CoalesceToString(user.Groups), | ||||||
|  | 		OAuthName:   service.GetName(), | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|   | |||||||
| @@ -95,6 +95,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { | |||||||
| 				Email:       cookie.Email, | 				Email:       cookie.Email, | ||||||
| 				Provider:    cookie.Provider, | 				Provider:    cookie.Provider, | ||||||
| 				OAuthGroups: cookie.OAuthGroups, | 				OAuthGroups: cookie.OAuthGroups, | ||||||
|  | 				OAuthName:   cookie.OAuthName, | ||||||
| 				IsLoggedIn:  true, | 				IsLoggedIn:  true, | ||||||
| 				OAuth:       true, | 				OAuth:       true, | ||||||
| 			}) | 			}) | ||||||
|   | |||||||
| @@ -9,4 +9,5 @@ type Session struct { | |||||||
| 	TOTPPending bool   `gorm:"column:totp_pending"` | 	TOTPPending bool   `gorm:"column:totp_pending"` | ||||||
| 	OAuthGroups string `gorm:"column:oauth_groups"` | 	OAuthGroups string `gorm:"column:oauth_groups"` | ||||||
| 	Expiry      int64  `gorm:"column:expiry"` | 	Expiry      int64  `gorm:"column:expiry"` | ||||||
|  | 	OAuthName   string `gorm:"column:oauth_name"` | ||||||
| } | } | ||||||
|   | |||||||
| @@ -210,6 +210,7 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.Sessio | |||||||
| 		TOTPPending: data.TotpPending, | 		TOTPPending: data.TotpPending, | ||||||
| 		OAuthGroups: data.OAuthGroups, | 		OAuthGroups: data.OAuthGroups, | ||||||
| 		Expiry:      time.Now().Add(time.Duration(expiry) * time.Second).Unix(), | 		Expiry:      time.Now().Add(time.Duration(expiry) * time.Second).Unix(), | ||||||
|  | 		OAuthName:   data.OAuthName, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = auth.database.Create(&session).Error | 	err = auth.database.Create(&session).Error | ||||||
| @@ -278,6 +279,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, | |||||||
| 		Provider:    session.Provider, | 		Provider:    session.Provider, | ||||||
| 		TotpPending: session.TOTPPending, | 		TotpPending: session.TOTPPending, | ||||||
| 		OAuthGroups: session.OAuthGroups, | 		OAuthGroups: session.OAuthGroups, | ||||||
|  | 		OAuthName:   session.OAuthName, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -22,6 +22,7 @@ type GenericOAuthService struct { | |||||||
| 	verifier           string | 	verifier           string | ||||||
| 	insecureSkipVerify bool | 	insecureSkipVerify bool | ||||||
| 	userinfoUrl        string | 	userinfoUrl        string | ||||||
|  | 	name               string | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { | func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { | ||||||
| @@ -38,6 +39,7 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi | |||||||
| 		}, | 		}, | ||||||
| 		insecureSkipVerify: config.InsecureSkipVerify, | 		insecureSkipVerify: config.InsecureSkipVerify, | ||||||
| 		userinfoUrl:        config.UserinfoURL, | 		userinfoUrl:        config.UserinfoURL, | ||||||
|  | 		name:               config.Name, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -115,3 +117,7 @@ func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { | |||||||
|  |  | ||||||
| 	return user, nil | 	return user, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (generic *GenericOAuthService) GetName() string { | ||||||
|  | 	return generic.name | ||||||
|  | } | ||||||
|   | |||||||
| @@ -33,6 +33,7 @@ type GithubOAuthService struct { | |||||||
| 	context  context.Context | 	context  context.Context | ||||||
| 	token    *oauth2.Token | 	token    *oauth2.Token | ||||||
| 	verifier string | 	verifier string | ||||||
|  | 	name     string | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { | func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { | ||||||
| @@ -44,6 +45,7 @@ func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService | |||||||
| 			Scopes:       GithubOAuthScopes, | 			Scopes:       GithubOAuthScopes, | ||||||
| 			Endpoint:     endpoints.GitHub, | 			Endpoint:     endpoints.GitHub, | ||||||
| 		}, | 		}, | ||||||
|  | 		name: config.Name, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -167,3 +169,7 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) { | |||||||
|  |  | ||||||
| 	return user, nil | 	return user, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (github *GithubOAuthService) GetName() string { | ||||||
|  | 	return github.name | ||||||
|  | } | ||||||
|   | |||||||
| @@ -28,6 +28,7 @@ type GoogleOAuthService struct { | |||||||
| 	context  context.Context | 	context  context.Context | ||||||
| 	token    *oauth2.Token | 	token    *oauth2.Token | ||||||
| 	verifier string | 	verifier string | ||||||
|  | 	name     string | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { | func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { | ||||||
| @@ -39,6 +40,7 @@ func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService | |||||||
| 			Scopes:       GoogleOAuthScopes, | 			Scopes:       GoogleOAuthScopes, | ||||||
| 			Endpoint:     endpoints.Google, | 			Endpoint:     endpoints.Google, | ||||||
| 		}, | 		}, | ||||||
|  | 		name: config.Name, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -111,3 +113,7 @@ func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { | |||||||
|  |  | ||||||
| 	return user, nil | 	return user, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (google *GoogleOAuthService) GetName() string { | ||||||
|  | 	return google.name | ||||||
|  | } | ||||||
|   | |||||||
| @@ -14,6 +14,7 @@ type OAuthService interface { | |||||||
| 	GetAuthURL(state string) string | 	GetAuthURL(state string) string | ||||||
| 	VerifyCode(code string) error | 	VerifyCode(code string) error | ||||||
| 	Userinfo() (config.Claims, error) | 	Userinfo() (config.Claims, error) | ||||||
|  | 	GetName() string | ||||||
| } | } | ||||||
|  |  | ||||||
| type OAuthBrokerService struct { | type OAuthBrokerService struct { | ||||||
|   | |||||||
| @@ -6,6 +6,9 @@ import ( | |||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"tinyauth/internal/config" | 	"tinyauth/internal/config" | ||||||
|  | 	"tinyauth/internal/utils/decoders" | ||||||
|  |  | ||||||
|  | 	"maps" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/rs/zerolog" | 	"github.com/rs/zerolog" | ||||||
| @@ -130,3 +133,68 @@ func GetLogLevel(level string) zerolog.Level { | |||||||
| 		return zerolog.InfoLevel | 		return zerolog.InfoLevel | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func GetOAuthProvidersConfig(env []string, args []string, appUrl string) (map[string]config.OAuthServiceConfig, error) { | ||||||
|  | 	providers := make(map[string]config.OAuthServiceConfig) | ||||||
|  |  | ||||||
|  | 	// Get from environment variables | ||||||
|  | 	envMap := make(map[string]string) | ||||||
|  |  | ||||||
|  | 	for _, e := range env { | ||||||
|  | 		pair := strings.SplitN(e, "=", 2) | ||||||
|  | 		if len(pair) == 2 { | ||||||
|  | 			envMap[pair[0]] = pair[1] | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	envProviders, err := decoders.DecodeEnv(envMap) | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	maps.Copy(providers, envProviders.Providers) | ||||||
|  |  | ||||||
|  | 	// Get from flags | ||||||
|  | 	flagsMap := make(map[string]string) | ||||||
|  |  | ||||||
|  | 	for _, arg := range args[1:] { | ||||||
|  | 		if strings.HasPrefix(arg, "--") { | ||||||
|  | 			pair := strings.SplitN(arg[2:], "=", 2) | ||||||
|  | 			if len(pair) == 2 { | ||||||
|  | 				flagsMap[pair[0]] = pair[1] | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	flagProviders, err := decoders.DecodeFlags(flagsMap) | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	maps.Copy(providers, flagProviders.Providers) | ||||||
|  |  | ||||||
|  | 	// For every provider get correct secret from file if set | ||||||
|  | 	for name, provider := range providers { | ||||||
|  | 		secret := GetSecret(provider.ClientSecret, provider.ClientSecretFile) | ||||||
|  | 		provider.ClientSecret = secret | ||||||
|  | 		provider.ClientSecretFile = "" | ||||||
|  | 		providers[name] = provider | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If we have google/github providers and no redirect URL babysit them | ||||||
|  | 	babysitProviders := []string{"google", "github"} | ||||||
|  |  | ||||||
|  | 	for _, name := range babysitProviders { | ||||||
|  | 		if provider, exists := providers[name]; exists { | ||||||
|  | 			if provider.RedirectURL == "" { | ||||||
|  | 				provider.RedirectURL = appUrl + "/api/oauth/callback/" + name | ||||||
|  | 				providers[name] = provider | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Return combined providers | ||||||
|  | 	return providers, nil | ||||||
|  | } | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package utils_test | package utils_test | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"os" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"tinyauth/internal/config" | 	"tinyauth/internal/config" | ||||||
| 	"tinyauth/internal/utils" | 	"tinyauth/internal/utils" | ||||||
| @@ -200,3 +201,71 @@ func TestIsRedirectSafe(t *testing.T) { | |||||||
| 	result = utils.IsRedirectSafe(redirectURL, domain) | 	result = utils.IsRedirectSafe(redirectURL, domain) | ||||||
| 	assert.Equal(t, false, result) | 	assert.Equal(t, false, result) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestGetOAuthProvidersConfig(t *testing.T) { | ||||||
|  | 	env := []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET=client1-secret"} | ||||||
|  | 	args := []string{"/tinyauth/tinyauth", "--providers-client2-client-id=client2-id", "--providers-client2-client-secret=client2-secret"} | ||||||
|  |  | ||||||
|  | 	expected := map[string]config.OAuthServiceConfig{ | ||||||
|  | 		"client1": { | ||||||
|  | 			ClientID:     "client1-id", | ||||||
|  | 			ClientSecret: "client1-secret", | ||||||
|  | 		}, | ||||||
|  | 		"client2": { | ||||||
|  | 			ClientID:     "client2-id", | ||||||
|  | 			ClientSecret: "client2-secret", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	result, err := utils.GetOAuthProvidersConfig(env, args, "") | ||||||
|  | 	assert.NilError(t, err) | ||||||
|  | 	assert.DeepEqual(t, expected, result) | ||||||
|  |  | ||||||
|  | 	// Case with no providers | ||||||
|  | 	env = []string{} | ||||||
|  | 	args = []string{"/tinyauth/tinyauth"} | ||||||
|  | 	expected = map[string]config.OAuthServiceConfig{} | ||||||
|  |  | ||||||
|  | 	result, err = utils.GetOAuthProvidersConfig(env, args, "") | ||||||
|  | 	assert.NilError(t, err) | ||||||
|  | 	assert.DeepEqual(t, expected, result) | ||||||
|  |  | ||||||
|  | 	// Case with secret from file | ||||||
|  | 	file, err := os.Create("/tmp/tinyauth_test_file") | ||||||
|  | 	assert.NilError(t, err) | ||||||
|  |  | ||||||
|  | 	_, err = file.WriteString("file content\n") | ||||||
|  | 	assert.NilError(t, err) | ||||||
|  |  | ||||||
|  | 	err = file.Close() | ||||||
|  | 	assert.NilError(t, err) | ||||||
|  | 	defer os.Remove("/tmp/tinyauth_test_file") | ||||||
|  |  | ||||||
|  | 	env = []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET_FILE=/tmp/tinyauth_test_file"} | ||||||
|  | 	args = []string{"/tinyauth/tinyauth"} | ||||||
|  | 	expected = map[string]config.OAuthServiceConfig{ | ||||||
|  | 		"client1": { | ||||||
|  | 			ClientID:     "client1-id", | ||||||
|  | 			ClientSecret: "file content", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	result, err = utils.GetOAuthProvidersConfig(env, args, "") | ||||||
|  | 	assert.NilError(t, err) | ||||||
|  | 	assert.DeepEqual(t, expected, result) | ||||||
|  |  | ||||||
|  | 	// Case with google provider and no redirect URL | ||||||
|  | 	env = []string{"PROVIDERS_GOOGLE_CLIENT_ID=google-id", "PROVIDERS_GOOGLE_CLIENT_SECRET=google-secret"} | ||||||
|  | 	args = []string{"/tinyauth/tinyauth"} | ||||||
|  | 	expected = map[string]config.OAuthServiceConfig{ | ||||||
|  | 		"google": { | ||||||
|  | 			ClientID:     "google-id", | ||||||
|  | 			ClientSecret: "google-secret", | ||||||
|  | 			RedirectURL:  "http://app.url/api/oauth/callback/google", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	result, err = utils.GetOAuthProvidersConfig(env, args, "http://app.url") | ||||||
|  | 	assert.NilError(t, err) | ||||||
|  | 	assert.DeepEqual(t, expected, result) | ||||||
|  | } | ||||||
|   | |||||||
							
								
								
									
										81
									
								
								internal/utils/decoders/decoders.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								internal/utils/decoders/decoders.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,81 @@ | |||||||
|  | package decoders | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 	"strings" | ||||||
|  | 	"tinyauth/internal/config" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func NormalizeKeys(keys map[string]string, rootName string, sep string) map[string]string { | ||||||
|  | 	normalized := make(map[string]string) | ||||||
|  | 	knownKeys := getKnownKeys() | ||||||
|  |  | ||||||
|  | 	for k, v := range keys { | ||||||
|  | 		var finalKey []string | ||||||
|  | 		var suffix string | ||||||
|  | 		var camelClientName string | ||||||
|  | 		var camelField string | ||||||
|  |  | ||||||
|  | 		finalKey = append(finalKey, rootName) | ||||||
|  | 		finalKey = append(finalKey, "providers") | ||||||
|  | 		cebabKey := strings.ToLower(k) | ||||||
|  |  | ||||||
|  | 		for _, known := range knownKeys { | ||||||
|  | 			if strings.HasSuffix(cebabKey, strings.ReplaceAll(known, "-", sep)) { | ||||||
|  | 				suffix = known | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if suffix == "" { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		clientNameParts := strings.Split(strings.TrimPrefix(strings.TrimSuffix(cebabKey, sep+strings.ReplaceAll(suffix, "-", sep)), "providers"+sep), sep) | ||||||
|  |  | ||||||
|  | 		for i, p := range clientNameParts { | ||||||
|  | 			if i == 0 { | ||||||
|  | 				camelClientName += p | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if p == "" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			camelClientName += strings.ToUpper(string([]rune(p)[0])) + string([]rune(p)[1:]) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		finalKey = append(finalKey, camelClientName) | ||||||
|  |  | ||||||
|  | 		filedParts := strings.Split(suffix, "-") | ||||||
|  |  | ||||||
|  | 		for i, p := range filedParts { | ||||||
|  | 			if i == 0 { | ||||||
|  | 				camelField += p | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if p == "" { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			camelField += strings.ToUpper(string([]rune(p)[0])) + string([]rune(p)[1:]) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		finalKey = append(finalKey, camelField) | ||||||
|  | 		normalized[strings.Join(finalKey, ".")] = v | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return normalized | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getKnownKeys() []string { | ||||||
|  | 	var known []string | ||||||
|  |  | ||||||
|  | 	p := config.OAuthServiceConfig{} | ||||||
|  | 	v := reflect.ValueOf(p) | ||||||
|  | 	typeOfP := v.Type() | ||||||
|  |  | ||||||
|  | 	for field := range typeOfP.NumField() { | ||||||
|  | 		known = append(known, typeOfP.Field(field).Tag.Get("key")) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return known | ||||||
|  | } | ||||||
							
								
								
									
										44
									
								
								internal/utils/decoders/decoders_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								internal/utils/decoders/decoders_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | |||||||
|  | package decoders_test | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 	"tinyauth/internal/utils/decoders" | ||||||
|  |  | ||||||
|  | 	"gotest.tools/v3/assert" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestNormalizeKeys(t *testing.T) { | ||||||
|  | 	// Test with env | ||||||
|  | 	test := map[string]string{ | ||||||
|  | 		"PROVIDERS_CLIENT1_CLIENT_ID":                    "my-client-id", | ||||||
|  | 		"PROVIDERS_CLIENT1_CLIENT_SECRET":                "my-client-secret", | ||||||
|  | 		"PROVIDERS_MY_AWESOME_CLIENT_CLIENT_ID":          "my-awesome-client-id", | ||||||
|  | 		"PROVIDERS_MY_AWESOME_CLIENT_CLIENT_SECRET_FILE": "/path/to/secret", | ||||||
|  | 	} | ||||||
|  | 	expected := map[string]string{ | ||||||
|  | 		"tinyauth.providers.client1.clientId":                 "my-client-id", | ||||||
|  | 		"tinyauth.providers.client1.clientSecret":             "my-client-secret", | ||||||
|  | 		"tinyauth.providers.myAwesomeClient.clientId":         "my-awesome-client-id", | ||||||
|  | 		"tinyauth.providers.myAwesomeClient.clientSecretFile": "/path/to/secret", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	normalized := decoders.NormalizeKeys(test, "tinyauth", "_") | ||||||
|  | 	assert.DeepEqual(t, normalized, expected) | ||||||
|  |  | ||||||
|  | 	// Test with flags (assume -- is already stripped) | ||||||
|  | 	test = map[string]string{ | ||||||
|  | 		"providers-client1-client-id":                    "my-client-id", | ||||||
|  | 		"providers-client1-client-secret":                "my-client-secret", | ||||||
|  | 		"providers-my-awesome-client-client-id":          "my-awesome-client-id", | ||||||
|  | 		"providers-my-awesome-client-client-secret-file": "/path/to/secret", | ||||||
|  | 	} | ||||||
|  | 	expected = map[string]string{ | ||||||
|  | 		"tinyauth.providers.client1.clientId":                 "my-client-id", | ||||||
|  | 		"tinyauth.providers.client1.clientSecret":             "my-client-secret", | ||||||
|  | 		"tinyauth.providers.myAwesomeClient.clientId":         "my-awesome-client-id", | ||||||
|  | 		"tinyauth.providers.myAwesomeClient.clientSecretFile": "/path/to/secret", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	normalized = decoders.NormalizeKeys(test, "tinyauth", "-") | ||||||
|  | 	assert.DeepEqual(t, normalized, expected) | ||||||
|  | } | ||||||
							
								
								
									
										20
									
								
								internal/utils/decoders/env_decoder.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								internal/utils/decoders/env_decoder.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | |||||||
|  | package decoders | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"tinyauth/internal/config" | ||||||
|  |  | ||||||
|  | 	"github.com/traefik/paerser/parser" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func DecodeEnv(env map[string]string) (config.Providers, error) { | ||||||
|  | 	normalized := NormalizeKeys(env, "tinyauth", "_") | ||||||
|  | 	var providers config.Providers | ||||||
|  |  | ||||||
|  | 	err := parser.Decode(normalized, &providers, "tinyauth", "tinyauth.providers") | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		return config.Providers{}, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return providers, nil | ||||||
|  | } | ||||||
							
								
								
									
										60
									
								
								internal/utils/decoders/env_decoder_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								internal/utils/decoders/env_decoder_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | |||||||
|  | package decoders_test | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 	"tinyauth/internal/config" | ||||||
|  | 	"tinyauth/internal/utils/decoders" | ||||||
|  |  | ||||||
|  | 	"gotest.tools/v3/assert" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestDecodeEnv(t *testing.T) { | ||||||
|  | 	// Variables | ||||||
|  | 	expected := config.Providers{ | ||||||
|  | 		Providers: map[string]config.OAuthServiceConfig{ | ||||||
|  | 			"client1": { | ||||||
|  | 				ClientID:           "client1-id", | ||||||
|  | 				ClientSecret:       "client1-secret", | ||||||
|  | 				Scopes:             []string{"client1-scope1", "client1-scope2"}, | ||||||
|  | 				RedirectURL:        "client1-redirect-url", | ||||||
|  | 				AuthURL:            "client1-auth-url", | ||||||
|  | 				UserinfoURL:        "client1-user-info-url", | ||||||
|  | 				Name:               "Client1", | ||||||
|  | 				InsecureSkipVerify: false, | ||||||
|  | 			}, | ||||||
|  | 			"client2": { | ||||||
|  | 				ClientID:           "client2-id", | ||||||
|  | 				ClientSecret:       "client2-secret", | ||||||
|  | 				Scopes:             []string{"client2-scope1", "client2-scope2"}, | ||||||
|  | 				RedirectURL:        "client2-redirect-url", | ||||||
|  | 				AuthURL:            "client2-auth-url", | ||||||
|  | 				UserinfoURL:        "client2-user-info-url", | ||||||
|  | 				Name:               "My Awesome Client2", | ||||||
|  | 				InsecureSkipVerify: false, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	test := map[string]string{ | ||||||
|  | 		"PROVIDERS_CLIENT1_CLIENT_ID":            "client1-id", | ||||||
|  | 		"PROVIDERS_CLIENT1_CLIENT_SECRET":        "client1-secret", | ||||||
|  | 		"PROVIDERS_CLIENT1_SCOPES":               "client1-scope1,client1-scope2", | ||||||
|  | 		"PROVIDERS_CLIENT1_REDIRECT_URL":         "client1-redirect-url", | ||||||
|  | 		"PROVIDERS_CLIENT1_AUTH_URL":             "client1-auth-url", | ||||||
|  | 		"PROVIDERS_CLIENT1_USER_INFO_URL":        "client1-user-info-url", | ||||||
|  | 		"PROVIDERS_CLIENT1_NAME":                 "Client1", | ||||||
|  | 		"PROVIDERS_CLIENT1_INSECURE_SKIP_VERIFY": "false", | ||||||
|  | 		"PROVIDERS_CLIENT2_CLIENT_ID":            "client2-id", | ||||||
|  | 		"PROVIDERS_CLIENT2_CLIENT_SECRET":        "client2-secret", | ||||||
|  | 		"PROVIDERS_CLIENT2_SCOPES":               "client2-scope1,client2-scope2", | ||||||
|  | 		"PROVIDERS_CLIENT2_REDIRECT_URL":         "client2-redirect-url", | ||||||
|  | 		"PROVIDERS_CLIENT2_AUTH_URL":             "client2-auth-url", | ||||||
|  | 		"PROVIDERS_CLIENT2_USER_INFO_URL":        "client2-user-info-url", | ||||||
|  | 		"PROVIDERS_CLIENT2_NAME":                 "My Awesome Client2", | ||||||
|  | 		"PROVIDERS_CLIENT2_INSECURE_SKIP_VERIFY": "false", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Test | ||||||
|  | 	res, err := decoders.DecodeEnv(test) | ||||||
|  | 	assert.NilError(t, err) | ||||||
|  | 	assert.DeepEqual(t, expected, res) | ||||||
|  | } | ||||||
							
								
								
									
										30
									
								
								internal/utils/decoders/flags_decoder.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								internal/utils/decoders/flags_decoder.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | |||||||
|  | package decoders | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"strings" | ||||||
|  | 	"tinyauth/internal/config" | ||||||
|  |  | ||||||
|  | 	"github.com/traefik/paerser/parser" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func DecodeFlags(flags map[string]string) (config.Providers, error) { | ||||||
|  | 	filtered := filterFlags(flags) | ||||||
|  | 	normalized := NormalizeKeys(filtered, "tinyauth", "-") | ||||||
|  | 	var providers config.Providers | ||||||
|  |  | ||||||
|  | 	err := parser.Decode(normalized, &providers, "tinyauth", "tinyauth.providers") | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		return config.Providers{}, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return providers, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func filterFlags(flags map[string]string) map[string]string { | ||||||
|  | 	filtered := make(map[string]string) | ||||||
|  | 	for k, v := range flags { | ||||||
|  | 		filtered[strings.TrimPrefix(k, "--")] = v | ||||||
|  | 	} | ||||||
|  | 	return filtered | ||||||
|  | } | ||||||
							
								
								
									
										60
									
								
								internal/utils/decoders/flags_decoder_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								internal/utils/decoders/flags_decoder_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | |||||||
|  | package decoders_test | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 	"tinyauth/internal/config" | ||||||
|  | 	"tinyauth/internal/utils/decoders" | ||||||
|  |  | ||||||
|  | 	"gotest.tools/v3/assert" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestDecodeFlags(t *testing.T) { | ||||||
|  | 	// Variables | ||||||
|  | 	expected := config.Providers{ | ||||||
|  | 		Providers: map[string]config.OAuthServiceConfig{ | ||||||
|  | 			"client1": { | ||||||
|  | 				ClientID:           "client1-id", | ||||||
|  | 				ClientSecret:       "client1-secret", | ||||||
|  | 				Scopes:             []string{"client1-scope1", "client1-scope2"}, | ||||||
|  | 				RedirectURL:        "client1-redirect-url", | ||||||
|  | 				AuthURL:            "client1-auth-url", | ||||||
|  | 				UserinfoURL:        "client1-user-info-url", | ||||||
|  | 				Name:               "Client1", | ||||||
|  | 				InsecureSkipVerify: false, | ||||||
|  | 			}, | ||||||
|  | 			"client2": { | ||||||
|  | 				ClientID:           "client2-id", | ||||||
|  | 				ClientSecret:       "client2-secret", | ||||||
|  | 				Scopes:             []string{"client2-scope1", "client2-scope2"}, | ||||||
|  | 				RedirectURL:        "client2-redirect-url", | ||||||
|  | 				AuthURL:            "client2-auth-url", | ||||||
|  | 				UserinfoURL:        "client2-user-info-url", | ||||||
|  | 				Name:               "My Awesome Client2", | ||||||
|  | 				InsecureSkipVerify: false, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	test := map[string]string{ | ||||||
|  | 		"--providers-client1-client-id":            "client1-id", | ||||||
|  | 		"--providers-client1-client-secret":        "client1-secret", | ||||||
|  | 		"--providers-client1-scopes":               "client1-scope1,client1-scope2", | ||||||
|  | 		"--providers-client1-redirect-url":         "client1-redirect-url", | ||||||
|  | 		"--providers-client1-auth-url":             "client1-auth-url", | ||||||
|  | 		"--providers-client1-user-info-url":        "client1-user-info-url", | ||||||
|  | 		"--providers-client1-name":                 "Client1", | ||||||
|  | 		"--providers-client1-insecure-skip-verify": "false", | ||||||
|  | 		"--providers-client2-client-id":            "client2-id", | ||||||
|  | 		"--providers-client2-client-secret":        "client2-secret", | ||||||
|  | 		"--providers-client2-scopes":               "client2-scope1,client2-scope2", | ||||||
|  | 		"--providers-client2-redirect-url":         "client2-redirect-url", | ||||||
|  | 		"--providers-client2-auth-url":             "client2-auth-url", | ||||||
|  | 		"--providers-client2-user-info-url":        "client2-user-info-url", | ||||||
|  | 		"--providers-client2-name":                 "My Awesome Client2", | ||||||
|  | 		"--providers-client2-insecure-skip-verify": "false", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Test | ||||||
|  | 	res, err := decoders.DecodeFlags(test) | ||||||
|  | 	assert.NilError(t, err) | ||||||
|  | 	assert.DeepEqual(t, expected, res) | ||||||
|  | } | ||||||
| @@ -1,10 +1,11 @@ | |||||||
| package decoders_test | package decoders_test | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"reflect" |  | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"tinyauth/internal/config" | 	"tinyauth/internal/config" | ||||||
| 	"tinyauth/internal/utils/decoders" | 	"tinyauth/internal/utils/decoders" | ||||||
|  |  | ||||||
|  | 	"gotest.tools/v3/assert" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestDecodeLabels(t *testing.T) { | func TestDecodeLabels(t *testing.T) { | ||||||
| @@ -62,12 +63,6 @@ func TestDecodeLabels(t *testing.T) { | |||||||
|  |  | ||||||
| 	// Test | 	// Test | ||||||
| 	result, err := decoders.DecodeLabels(test) | 	result, err := decoders.DecodeLabels(test) | ||||||
|  | 	assert.NilError(t, err) | ||||||
| 	if err != nil { | 	assert.DeepEqual(t, expected, result) | ||||||
| 		t.Fatalf("Unexpected error: %v", err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if reflect.DeepEqual(expected, result) == false { |  | ||||||
| 		t.Fatalf("Expected %v but got %v", expected, result) |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user