mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-10-30 21:55:43 +00:00 
			
		
		
		
	Compare commits
	
		
			12 Commits
		
	
	
		
			docs/updat
			...
			7795a989cd
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 7795a989cd | ||
|   | cebce1a92c | ||
|   | 120ae2c79d | ||
|   | 060e20e578 | ||
|   | e001f63eb5 | ||
|   | 9f97a4ddd5 | ||
|   | e5ecf6336f | ||
|   | fbf5843592 | ||
|   | 5fcc50d5fd | ||
|   | 68fd5ac24c | ||
|   | b30b908de3 | ||
|   | 91048c16f8 | 
							
								
								
									
										23
									
								
								cmd/root.go
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								cmd/root.go
									
									
									
									
									
								
							| @@ -27,11 +27,6 @@ var rootCmd = &cobra.Command{ | ||||
| 			log.Fatal().Err(err).Msg("Failed to parse config") | ||||
| 		} | ||||
|  | ||||
| 		// Check if secrets have a file associated with them | ||||
| 		conf.GithubClientSecret = utils.GetSecret(conf.GithubClientSecret, conf.GithubClientSecretFile) | ||||
| 		conf.GoogleClientSecret = utils.GetSecret(conf.GoogleClientSecret, conf.GoogleClientSecretFile) | ||||
| 		conf.GenericClientSecret = utils.GetSecret(conf.GenericClientSecret, conf.GenericClientSecretFile) | ||||
|  | ||||
| 		// Validate config | ||||
| 		v := validator.New() | ||||
|  | ||||
| @@ -57,6 +52,7 @@ var rootCmd = &cobra.Command{ | ||||
| } | ||||
|  | ||||
| func Execute() { | ||||
| 	rootCmd.FParseErrWhitelist.UnknownFlags = true | ||||
| 	err := rootCmd.Execute() | ||||
| 	if err != nil { | ||||
| 		log.Fatal().Err(err).Msg("Failed to execute command") | ||||
| @@ -80,21 +76,6 @@ func init() { | ||||
| 		{"users", "", "Comma separated list of users in the format username:hash."}, | ||||
| 		{"users-file", "", "Path to a file containing users in the format username:hash."}, | ||||
| 		{"secure-cookie", false, "Send cookie over secure connection only."}, | ||||
| 		{"github-client-id", "", "Github OAuth client ID."}, | ||||
| 		{"github-client-secret", "", "Github OAuth client secret."}, | ||||
| 		{"github-client-secret-file", "", "Github OAuth client secret file."}, | ||||
| 		{"google-client-id", "", "Google OAuth client ID."}, | ||||
| 		{"google-client-secret", "", "Google OAuth client secret."}, | ||||
| 		{"google-client-secret-file", "", "Google OAuth client secret file."}, | ||||
| 		{"generic-client-id", "", "Generic OAuth client ID."}, | ||||
| 		{"generic-client-secret", "", "Generic OAuth client secret."}, | ||||
| 		{"generic-client-secret-file", "", "Generic OAuth client secret file."}, | ||||
| 		{"generic-scopes", "", "Generic OAuth scopes."}, | ||||
| 		{"generic-auth-url", "", "Generic OAuth auth URL."}, | ||||
| 		{"generic-token-url", "", "Generic OAuth token URL."}, | ||||
| 		{"generic-user-url", "", "Generic OAuth user info URL."}, | ||||
| 		{"generic-name", "Generic", "Generic OAuth provider name."}, | ||||
| 		{"generic-skip-ssl", false, "Skip SSL verification for the generic OAuth provider."}, | ||||
| 		{"oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth."}, | ||||
| 		{"oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)"}, | ||||
| 		{"session-expiry", 86400, "Session (cookie) expiration time in seconds."}, | ||||
| @@ -112,7 +93,7 @@ func init() { | ||||
| 		{"ldap-search-filter", "(uid=%s)", "LDAP search filter for user lookup."}, | ||||
| 		{"resources-dir", "/data/resources", "Path to a directory containing custom resources (e.g. background image)."}, | ||||
| 		{"database-path", "/data/tinyauth.db", "Path to the Sqlite database file."}, | ||||
| 		{"trusted-proxies", "", "Comma separated list of trusted proxies (IP addresses) for correct client IP detection and for header ACLs."}, | ||||
| 		{"trusted-proxies", "", "Comma separated list of trusted proxies (IP addresses or CIDRs) for correct client IP detection."}, | ||||
| 	} | ||||
|  | ||||
| 	for _, opt := range configOptions { | ||||
|   | ||||
							
								
								
									
										18
									
								
								frontend/src/components/icons/microsoft.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								frontend/src/components/icons/microsoft.tsx
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,18 @@ | ||||
| import type { SVGProps } from "react"; | ||||
|  | ||||
| export function MicrosoftIcon(props: SVGProps<SVGSVGElement>) { | ||||
|   return ( | ||||
|     <svg | ||||
|       xmlns="http://www.w3.org/2000/svg" | ||||
|       width="2em" | ||||
|       height="2em" | ||||
|       viewBox="0 0 256 256" | ||||
|       {...props} | ||||
|     > | ||||
|       <path fill="#f1511b" d="M121.666 121.666H0V0h121.666z"></path> | ||||
|       <path fill="#80cc28" d="M256 121.666H134.335V0H256z"></path> | ||||
|       <path fill="#00adef" d="M121.663 256.002H0V134.336h121.663z"></path> | ||||
|       <path fill="#fbbc09" d="M256 256.002H134.335V134.336H256z"></path> | ||||
|     </svg> | ||||
|   ); | ||||
| } | ||||
| @@ -1,6 +1,6 @@ | ||||
| import type { SVGProps } from "react"; | ||||
| 
 | ||||
| export function GenericIcon(props: SVGProps<SVGSVGElement>) { | ||||
| export function OAuthIcon(props: SVGProps<SVGSVGElement>) { | ||||
|   return ( | ||||
|     <svg | ||||
|       xmlns="http://www.w3.org/2000/svg" | ||||
							
								
								
									
										20
									
								
								frontend/src/components/icons/pocket-id.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								frontend/src/components/icons/pocket-id.tsx
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| import type { SVGProps } from "react"; | ||||
|  | ||||
| export function PocketIDIcon(props: SVGProps<SVGSVGElement>) { | ||||
|   return ( | ||||
|     <svg | ||||
|       xmlns="http://www.w3.org/2000/svg" | ||||
|       xmlSpace="preserve" | ||||
|       width={512} | ||||
|       height={512} | ||||
|       viewBox="0 0 512 512" | ||||
|       {...props} | ||||
|     > | ||||
|       <circle cx="256" cy="256" r="256" /> | ||||
|       <path | ||||
|         d="M268.6 102.4c64.4 0 116.8 52.4 116.8 116.7 0 25.3-8 49.4-23 69.6-14.8 19.9-35 34.3-58.4 41.7l-6.5 2-15.5-76.2 4.3-2c14-6.7 23-21.1 23-36.6 0-22.4-18.2-40.6-40.6-40.6S228 195.2 228 217.6c0 15.5 9 29.8 23 36.6l4.2 2-25 153.4h-69.5V102.4z" | ||||
|         className="fill-white" | ||||
|       /> | ||||
|     </svg> | ||||
|   ); | ||||
| } | ||||
							
								
								
									
										26
									
								
								frontend/src/components/icons/tailscale.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								frontend/src/components/icons/tailscale.tsx
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
| import type { SVGProps } from "react"; | ||||
|  | ||||
| export function TailscaleIcon(props: SVGProps<SVGSVGElement>) { | ||||
|   return ( | ||||
|     <svg | ||||
|       xmlns="http://www.w3.org/2000/svg" | ||||
|       xmlSpace="preserve" | ||||
|       width={512} | ||||
|       height={512} | ||||
|       viewBox="0 0 512 512" | ||||
|       {...props} | ||||
|     > | ||||
|       <path | ||||
|         className="opacity-80" | ||||
|         fill="currentColor" | ||||
|         d="M65.6 318.1c35.3 0 63.9-28.6 63.9-63.9s-28.6-63.9-63.9-63.9S1.8 219 1.8 254.2s28.6 63.9 63.8 63.9m191.6 0c35.3 0 63.9-28.6 63.9-63.9s-28.6-63.9-63.9-63.9-63.9 28.6-63.9 63.9 28.6 63.9 63.9 63.9m0 193.9c35.3 0 63.9-28.6 63.9-63.9s-28.6-63.9-63.9-63.9-63.9 28.6-63.9 63.9 28.6 63.9 63.9 63.9m189.2-193.9c35.3 0 63.9-28.6 63.9-63.9s-28.6-63.9-63.9-63.9-63.9 28.6-63.9 63.9 28.6 63.9 63.9 63.9" | ||||
|       /> | ||||
|  | ||||
|       <path | ||||
|         d="M65.6 127.7c35.3 0 63.9-28.6 63.9-63.9S100.9 0 65.6 0 1.8 28.6 1.8 63.9s28.6 63.8 63.8 63.8m0 384.3c35.3 0 63.9-28.6 63.9-63.9s-28.6-63.9-63.9-63.9-63.8 28.7-63.8 63.9S30.4 512 65.6 512m191.6-384.3c35.3 0 63.9-28.6 63.9-63.9S292.5 0 257.2 0s-63.9 28.6-63.9 63.9 28.6 63.8 63.9 63.8m189.2 0c35.3 0 63.9-28.6 63.9-63.9S481.6 0 446.4 0c-35.3 0-63.9 28.6-63.9 63.9s28.6 63.8 63.9 63.8m0 384.3c35.3 0 63.9-28.6 63.9-63.9s-28.6-63.9-63.9-63.9-63.9 28.6-63.9 63.9 28.6 63.9 63.9 63.9" | ||||
|         className="opacity-20" | ||||
|         fill="currentColor" | ||||
|       /> | ||||
|     </svg> | ||||
|   ); | ||||
| } | ||||
| @@ -14,6 +14,9 @@ | ||||
|     "loginOauthFailSubtitle": "Failed to get OAuth URL", | ||||
|     "loginOauthSuccessTitle": "Redirecting", | ||||
|     "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", | ||||
|     "loginOauthAutoRedirectTitle": "OAuth Auto Redirect", | ||||
|     "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.", | ||||
|     "loginOauthAutoRedirectButton": "Redirect now", | ||||
|     "continueTitle": "Continue", | ||||
|     "continueRedirectingTitle": "Redirecting...", | ||||
|     "continueRedirectingSubtitle": "You should be redirected to the app soon", | ||||
|   | ||||
| @@ -14,6 +14,9 @@ | ||||
|     "loginOauthFailSubtitle": "Failed to get OAuth URL", | ||||
|     "loginOauthSuccessTitle": "Redirecting", | ||||
|     "loginOauthSuccessSubtitle": "Redirecting to your OAuth provider", | ||||
|     "loginOauthAutoRedirectTitle": "OAuth Auto Redirect", | ||||
|     "loginOauthAutoRedirectSubtitle": "You will be automatically redirected to your OAuth provider to authenticate.", | ||||
|     "loginOauthAutoRedirectButton": "Redirect now", | ||||
|     "continueTitle": "Continue", | ||||
|     "continueRedirectingTitle": "Redirecting...", | ||||
|     "continueRedirectingSubtitle": "You should be redirected to the app soon", | ||||
|   | ||||
| @@ -70,7 +70,7 @@ export const ContinuePage = () => { | ||||
|     const reveal = setTimeout(() => { | ||||
|       setLoading(false); | ||||
|       setShowRedirectButton(true); | ||||
|     }, 1000); | ||||
|     }, 5000); | ||||
|  | ||||
|     return () => { | ||||
|       clearTimeout(auto); | ||||
|   | ||||
| @@ -1,13 +1,18 @@ | ||||
| import { LoginForm } from "@/components/auth/login-form"; | ||||
| import { GenericIcon } from "@/components/icons/generic"; | ||||
| import { GithubIcon } from "@/components/icons/github"; | ||||
| import { GoogleIcon } from "@/components/icons/google"; | ||||
| import { MicrosoftIcon } from "@/components/icons/microsoft"; | ||||
| import { OAuthIcon } from "@/components/icons/oauth"; | ||||
| import { PocketIDIcon } from "@/components/icons/pocket-id"; | ||||
| import { TailscaleIcon } from "@/components/icons/tailscale"; | ||||
| import { Button } from "@/components/ui/button"; | ||||
| import { | ||||
|   Card, | ||||
|   CardHeader, | ||||
|   CardTitle, | ||||
|   CardDescription, | ||||
|   CardContent, | ||||
|   CardFooter, | ||||
| } from "@/components/ui/card"; | ||||
| import { OAuthButton } from "@/components/ui/oauth-button"; | ||||
| import { SeperatorWithChildren } from "@/components/ui/separator"; | ||||
| @@ -17,28 +22,40 @@ import { useIsMounted } from "@/lib/hooks/use-is-mounted"; | ||||
| import { LoginSchema } from "@/schemas/login-schema"; | ||||
| import { useMutation } from "@tanstack/react-query"; | ||||
| import axios, { AxiosError } from "axios"; | ||||
| import { useEffect, useRef } from "react"; | ||||
| import { useEffect, useRef, useState } from "react"; | ||||
| import { useTranslation } from "react-i18next"; | ||||
| import { Navigate, useLocation } from "react-router"; | ||||
| import { toast } from "sonner"; | ||||
|  | ||||
| const iconMap: Record<string, React.ReactNode> = { | ||||
|   google: <GoogleIcon />, | ||||
|   github: <GithubIcon />, | ||||
|   tailscale: <TailscaleIcon />, | ||||
|   microsoft: <MicrosoftIcon />, | ||||
|   pocketid: <PocketIDIcon />, | ||||
| }; | ||||
|  | ||||
| export const LoginPage = () => { | ||||
|   const { isLoggedIn } = useUserContext(); | ||||
|   const { configuredProviders, title, oauthAutoRedirect, genericName } = | ||||
|     useAppContext(); | ||||
|   const { providers, title, oauthAutoRedirect } = useAppContext(); | ||||
|   const { search } = useLocation(); | ||||
|   const { t } = useTranslation(); | ||||
|   const isMounted = useIsMounted(); | ||||
|   const [oauthAutoRedirectHandover, setOauthAutoRedirectHandover] = | ||||
|     useState(false); | ||||
|   const [showRedirectButton, setShowRedirectButton] = useState(false); | ||||
|  | ||||
|   const redirectTimer = useRef<number | null>(null); | ||||
|   const redirectButtonTimer = useRef<number | null>(null); | ||||
|  | ||||
|   const searchParams = new URLSearchParams(search); | ||||
|   const redirectUri = searchParams.get("redirect_uri"); | ||||
|  | ||||
|   const oauthConfigured = | ||||
|     configuredProviders.filter((provider) => provider !== "username").length > | ||||
|     0; | ||||
|   const userAuthConfigured = configuredProviders.includes("username"); | ||||
|   const oauthProviders = providers.filter( | ||||
|     (provider) => provider.id !== "username", | ||||
|   ); | ||||
|   const userAuthConfigured = | ||||
|     providers.find((provider) => provider.id === "username") !== undefined; | ||||
|  | ||||
|   const oauthMutation = useMutation({ | ||||
|     mutationFn: (provider: string) => | ||||
| @@ -56,6 +73,7 @@ export const LoginPage = () => { | ||||
|       }, 500); | ||||
|     }, | ||||
|     onError: () => { | ||||
|       setOauthAutoRedirectHandover(false); | ||||
|       toast.error(t("loginOauthFailTitle"), { | ||||
|         description: t("loginOauthFailSubtitle"), | ||||
|       }); | ||||
| @@ -96,12 +114,16 @@ export const LoginPage = () => { | ||||
|   useEffect(() => { | ||||
|     if (isMounted()) { | ||||
|       if ( | ||||
|         oauthConfigured && | ||||
|         configuredProviders.includes(oauthAutoRedirect) && | ||||
|         oauthProviders.length !== 0 && | ||||
|         providers.find((provider) => provider.id === oauthAutoRedirect) && | ||||
|         !isLoggedIn && | ||||
|         redirectUri | ||||
|       ) { | ||||
|         setOauthAutoRedirectHandover(true); | ||||
|         oauthMutation.mutate(oauthAutoRedirect); | ||||
|         redirectButtonTimer.current = window.setTimeout(() => { | ||||
|           setShowRedirectButton(true); | ||||
|         }, 5000); | ||||
|       } | ||||
|     } | ||||
|   }, []); | ||||
| @@ -109,6 +131,8 @@ export const LoginPage = () => { | ||||
|   useEffect( | ||||
|     () => () => { | ||||
|       if (redirectTimer.current) clearTimeout(redirectTimer.current); | ||||
|       if (redirectButtonTimer.current) | ||||
|         clearTimeout(redirectButtonTimer.current); | ||||
|     }, | ||||
|     [], | ||||
|   ); | ||||
| @@ -126,61 +150,63 @@ export const LoginPage = () => { | ||||
|     return <Navigate to="/logout" replace />; | ||||
|   } | ||||
|  | ||||
|   if (oauthAutoRedirectHandover) { | ||||
|     return ( | ||||
|       <Card className="min-w-xs sm:min-w-sm"> | ||||
|         <CardHeader> | ||||
|           <CardTitle className="text-3xl"> | ||||
|             {t("loginOauthAutoRedirectTitle")} | ||||
|           </CardTitle> | ||||
|           <CardDescription> | ||||
|             {t("loginOauthAutoRedirectSubtitle")} | ||||
|           </CardDescription> | ||||
|         </CardHeader> | ||||
|         {showRedirectButton && ( | ||||
|           <CardFooter className="flex flex-col items-stretch"> | ||||
|             <Button | ||||
|               onClick={() => { | ||||
|                 window.location.replace(oauthMutation.data?.data.url); | ||||
|               }} | ||||
|             > | ||||
|               {t("loginOauthAutoRedirectButton")} | ||||
|             </Button> | ||||
|           </CardFooter> | ||||
|         )} | ||||
|       </Card> | ||||
|     ); | ||||
|   } | ||||
|   return ( | ||||
|     <Card className="min-w-xs sm:min-w-sm"> | ||||
|       <CardHeader> | ||||
|         <CardTitle className="text-center text-3xl">{title}</CardTitle> | ||||
|         {configuredProviders.length > 0 && ( | ||||
|         {providers.length > 0 && ( | ||||
|           <CardDescription className="text-center"> | ||||
|             {oauthConfigured ? t("loginTitle") : t("loginTitleSimple")} | ||||
|             {oauthProviders.length !== 0 | ||||
|               ? t("loginTitle") | ||||
|               : t("loginTitleSimple")} | ||||
|           </CardDescription> | ||||
|         )} | ||||
|       </CardHeader> | ||||
|       <CardContent className="flex flex-col gap-4"> | ||||
|         {oauthConfigured && ( | ||||
|         {oauthProviders.length !== 0 && ( | ||||
|           <div className="flex flex-col gap-2 items-center justify-center"> | ||||
|             {configuredProviders.includes("google") && ( | ||||
|             {oauthProviders.map((provider) => ( | ||||
|               <OAuthButton | ||||
|                 title="Google" | ||||
|                 icon={<GoogleIcon />} | ||||
|                 key={provider.id} | ||||
|                 title={provider.name} | ||||
|                 icon={iconMap[provider.id] ?? <OAuthIcon />} | ||||
|                 className="w-full" | ||||
|                 onClick={() => oauthMutation.mutate("google")} | ||||
|                 onClick={() => oauthMutation.mutate(provider.id)} | ||||
|                 loading={ | ||||
|                   oauthMutation.isPending && | ||||
|                   oauthMutation.variables === "google" | ||||
|                   oauthMutation.variables === provider.id | ||||
|                 } | ||||
|                 disabled={oauthMutation.isPending || loginMutation.isPending} | ||||
|               /> | ||||
|             )} | ||||
|             {configuredProviders.includes("github") && ( | ||||
|               <OAuthButton | ||||
|                 title="Github" | ||||
|                 icon={<GithubIcon />} | ||||
|                 className="w-full" | ||||
|                 onClick={() => oauthMutation.mutate("github")} | ||||
|                 loading={ | ||||
|                   oauthMutation.isPending && | ||||
|                   oauthMutation.variables === "github" | ||||
|                 } | ||||
|                 disabled={oauthMutation.isPending || loginMutation.isPending} | ||||
|               /> | ||||
|             )} | ||||
|             {configuredProviders.includes("generic") && ( | ||||
|               <OAuthButton | ||||
|                 title={genericName} | ||||
|                 icon={<GenericIcon />} | ||||
|                 className="w-full" | ||||
|                 onClick={() => oauthMutation.mutate("generic")} | ||||
|                 loading={ | ||||
|                   oauthMutation.isPending && | ||||
|                   oauthMutation.variables === "generic" | ||||
|                 } | ||||
|                 disabled={oauthMutation.isPending || loginMutation.isPending} | ||||
|               /> | ||||
|             )} | ||||
|             ))} | ||||
|           </div> | ||||
|         )} | ||||
|         {userAuthConfigured && oauthConfigured && ( | ||||
|         {userAuthConfigured && oauthProviders.length !== 0 && ( | ||||
|           <SeperatorWithChildren>{t("loginDivider")}</SeperatorWithChildren> | ||||
|         )} | ||||
|         {userAuthConfigured && ( | ||||
| @@ -189,7 +215,7 @@ export const LoginPage = () => { | ||||
|             loading={loginMutation.isPending || oauthMutation.isPending} | ||||
|           /> | ||||
|         )} | ||||
|         {configuredProviders.length == 0 && ( | ||||
|         {providers.length == 0 && ( | ||||
|           <p className="text-center text-red-600 max-w-sm"> | ||||
|             {t("failedToFetchProvidersTitle")} | ||||
|           </p> | ||||
|   | ||||
| @@ -6,9 +6,7 @@ import { | ||||
|   CardHeader, | ||||
|   CardTitle, | ||||
| } from "@/components/ui/card"; | ||||
| import { useAppContext } from "@/context/app-context"; | ||||
| import { useUserContext } from "@/context/user-context"; | ||||
| import { capitalize } from "@/lib/utils"; | ||||
| import { useMutation } from "@tanstack/react-query"; | ||||
| import axios from "axios"; | ||||
| import { useEffect, useRef } from "react"; | ||||
| @@ -17,8 +15,7 @@ import { Navigate } from "react-router"; | ||||
| import { toast } from "sonner"; | ||||
|  | ||||
| export const LogoutPage = () => { | ||||
|   const { provider, username, isLoggedIn, email } = useUserContext(); | ||||
|   const { genericName } = useAppContext(); | ||||
|   const { provider, username, isLoggedIn, email, oauthName } = useUserContext(); | ||||
|   const { t } = useTranslation(); | ||||
|  | ||||
|   const redirectTimer = useRef<number | null>(null); | ||||
| @@ -67,8 +64,7 @@ export const LogoutPage = () => { | ||||
|               }} | ||||
|               values={{ | ||||
|                 username: email, | ||||
|                 provider: | ||||
|                   provider === "generic" ? genericName : capitalize(provider), | ||||
|                 provider: oauthName, | ||||
|               }} | ||||
|             /> | ||||
|           ) : ( | ||||
|   | ||||
| @@ -1,14 +1,19 @@ | ||||
| import { z } from "zod"; | ||||
|  | ||||
| export const providerSchema = z.object({ | ||||
|   id: z.string(), | ||||
|   name: z.string(), | ||||
|   oauth: z.boolean(), | ||||
| }); | ||||
|  | ||||
| export const appContextSchema = z.object({ | ||||
|   configuredProviders: z.array(z.string()), | ||||
|   providers: z.array(providerSchema), | ||||
|   title: z.string(), | ||||
|   genericName: z.string(), | ||||
|   appUrl: z.string(), | ||||
|   cookieDomain: z.string(), | ||||
|   forgotPasswordMessage: z.string(), | ||||
|   oauthAutoRedirect: z.enum(["none", "github", "google", "generic"]), | ||||
|   backgroundImage: z.string(), | ||||
|   oauthAutoRedirect: z.string(), | ||||
| }); | ||||
|  | ||||
| export type AppContextSchema = z.infer<typeof appContextSchema>; | ||||
|   | ||||
| @@ -8,6 +8,7 @@ export const userContextSchema = z.object({ | ||||
|   provider: z.string(), | ||||
|   oauth: z.boolean(), | ||||
|   totpPending: z.boolean(), | ||||
|   oauthName: z.string(), | ||||
| }); | ||||
|  | ||||
| export type UserContextSchema = z.infer<typeof userContextSchema>; | ||||
|   | ||||
							
								
								
									
										1
									
								
								internal/assets/migrations/000002_oauth_name.down.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								internal/assets/migrations/000002_oauth_name.down.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| ALTER TABLE "sessions" DROP COLUMN "oauth_name"; | ||||
							
								
								
									
										10
									
								
								internal/assets/migrations/000002_oauth_name.up.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								internal/assets/migrations/000002_oauth_name.up.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | ||||
| ALTER TABLE "sessions" ADD COLUMN "oauth_name" TEXT; | ||||
|  | ||||
| UPDATE "sessions" | ||||
| SET "oauth_name" = CASE | ||||
|   WHEN LOWER("provider") = 'github' THEN 'GitHub' | ||||
|   WHEN LOWER("provider") = 'google' THEN 'Google' | ||||
|   ELSE UPPER(SUBSTR("provider", 1, 1)) || SUBSTR("provider", 2) | ||||
| END | ||||
| WHERE "oauth_name" IS NULL AND "provider" IS NOT NULL; | ||||
|  | ||||
| @@ -3,6 +3,7 @@ package bootstrap | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| 	"tinyauth/internal/config" | ||||
| 	"tinyauth/internal/controller" | ||||
| @@ -45,6 +46,13 @@ func (app *BootstrapApp) Setup() error { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// Get OAuth configs | ||||
| 	oauthProviders, err := utils.GetOAuthProvidersConfig(os.Environ(), os.Args, app.Config.AppURL) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// Get cookie domain | ||||
| 	cookieDomain, err := utils.GetCookieDomain(app.Config.AppURL) | ||||
|  | ||||
| @@ -112,7 +120,7 @@ func (app *BootstrapApp) Setup() error { | ||||
| 	// Create services | ||||
| 	dockerService := service.NewDockerService() | ||||
| 	authService := service.NewAuthService(authConfig, dockerService, ldapService, database) | ||||
| 	oauthBrokerService := service.NewOAuthBrokerService(app.getOAuthBrokerConfig()) | ||||
| 	oauthBrokerService := service.NewOAuthBrokerService(oauthProviders) | ||||
|  | ||||
| 	// Initialize services | ||||
| 	services := []Service{ | ||||
| @@ -132,13 +140,41 @@ func (app *BootstrapApp) Setup() error { | ||||
| 	} | ||||
|  | ||||
| 	// Configured providers | ||||
| 	var configuredProviders []string | ||||
| 	babysit := map[string]string{ | ||||
| 		"google": "Google", | ||||
| 		"github": "GitHub", | ||||
| 	} | ||||
| 	configuredProviders := make([]controller.Provider, 0) | ||||
|  | ||||
| 	if authService.UserAuthConfigured() || ldapService != nil { | ||||
| 		configuredProviders = append(configuredProviders, "username") | ||||
| 	for id, provider := range oauthProviders { | ||||
| 		if id == "" { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 	configuredProviders = append(configuredProviders, oauthBrokerService.GetConfiguredServices()...) | ||||
| 		if provider.Name == "" { | ||||
| 			if name, ok := babysit[id]; ok { | ||||
| 				provider.Name = name | ||||
| 			} else { | ||||
| 				provider.Name = utils.Capitalize(id) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		configuredProviders = append(configuredProviders, controller.Provider{ | ||||
| 			Name:  provider.Name, | ||||
| 			ID:    id, | ||||
| 			OAuth: true, | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	if authService.UserAuthConfigured() || ldapService != nil { | ||||
| 		configuredProviders = append(configuredProviders, controller.Provider{ | ||||
| 			Name:  "Username", | ||||
| 			ID:    "username", | ||||
| 			OAuth: false, | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Interface("providers", configuredProviders).Msg("Authentication providers") | ||||
|  | ||||
| 	if len(configuredProviders) == 0 { | ||||
| 		return fmt.Errorf("no authentication providers configured") | ||||
| @@ -179,9 +215,8 @@ func (app *BootstrapApp) Setup() error { | ||||
|  | ||||
| 	// Create controllers | ||||
| 	contextController := controller.NewContextController(controller.ContextControllerConfig{ | ||||
| 		ConfiguredProviders:   configuredProviders, | ||||
| 		Providers:             configuredProviders, | ||||
| 		Title:                 app.Config.Title, | ||||
| 		GenericName:           app.Config.GenericName, | ||||
| 		AppURL:                app.Config.AppURL, | ||||
| 		CookieDomain:          cookieDomain, | ||||
| 		ForgotPasswordMessage: app.Config.ForgotPasswordMessage, | ||||
| @@ -235,30 +270,3 @@ func (app *BootstrapApp) Setup() error { | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Temporary | ||||
| func (app *BootstrapApp) getOAuthBrokerConfig() map[string]config.OAuthServiceConfig { | ||||
| 	return map[string]config.OAuthServiceConfig{ | ||||
| 		"google": { | ||||
| 			ClientID:     app.Config.GoogleClientId, | ||||
| 			ClientSecret: app.Config.GoogleClientSecret, | ||||
| 			RedirectURL:  fmt.Sprintf("%s/api/oauth/callback/google", app.Config.AppURL), | ||||
| 		}, | ||||
| 		"github": { | ||||
| 			ClientID:     app.Config.GithubClientId, | ||||
| 			ClientSecret: app.Config.GithubClientSecret, | ||||
| 			RedirectURL:  fmt.Sprintf("%s/api/oauth/callback/github", app.Config.AppURL), | ||||
| 		}, | ||||
| 		"generic": { | ||||
| 			ClientID:           app.Config.GenericClientId, | ||||
| 			ClientSecret:       app.Config.GenericClientSecret, | ||||
| 			RedirectURL:        fmt.Sprintf("%s/api/oauth/callback/generic", app.Config.AppURL), | ||||
| 			Scopes:             strings.Split(app.Config.GenericScopes, ","), | ||||
| 			AuthURL:            app.Config.GenericAuthURL, | ||||
| 			TokenURL:           app.Config.GenericTokenURL, | ||||
| 			UserinfoURL:        app.Config.GenericUserURL, | ||||
| 			InsecureSkipVerify: app.Config.GenericSkipSSL, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| } | ||||
|   | ||||
| @@ -21,23 +21,8 @@ type Config struct { | ||||
| 	Users                 string `mapstructure:"users"` | ||||
| 	UsersFile             string `mapstructure:"users-file"` | ||||
| 	SecureCookie          bool   `mapstructure:"secure-cookie"` | ||||
| 	GithubClientId          string `mapstructure:"github-client-id"` | ||||
| 	GithubClientSecret      string `mapstructure:"github-client-secret"` | ||||
| 	GithubClientSecretFile  string `mapstructure:"github-client-secret-file"` | ||||
| 	GoogleClientId          string `mapstructure:"google-client-id"` | ||||
| 	GoogleClientSecret      string `mapstructure:"google-client-secret"` | ||||
| 	GoogleClientSecretFile  string `mapstructure:"google-client-secret-file"` | ||||
| 	GenericClientId         string `mapstructure:"generic-client-id"` | ||||
| 	GenericClientSecret     string `mapstructure:"generic-client-secret"` | ||||
| 	GenericClientSecretFile string `mapstructure:"generic-client-secret-file"` | ||||
| 	GenericScopes           string `mapstructure:"generic-scopes"` | ||||
| 	GenericAuthURL          string `mapstructure:"generic-auth-url"` | ||||
| 	GenericTokenURL         string `mapstructure:"generic-token-url"` | ||||
| 	GenericUserURL          string `mapstructure:"generic-user-url"` | ||||
| 	GenericName             string `mapstructure:"generic-name"` | ||||
| 	GenericSkipSSL          bool   `mapstructure:"generic-skip-ssl"` | ||||
| 	OAuthWhitelist        string `mapstructure:"oauth-whitelist"` | ||||
| 	OAuthAutoRedirect       string `mapstructure:"oauth-auto-redirect" validate:"oneof=none github google generic"` | ||||
| 	OAuthAutoRedirect     string `mapstructure:"oauth-auto-redirect"` | ||||
| 	SessionExpiry         int    `mapstructure:"session-expiry"` | ||||
| 	LogLevel              string `mapstructure:"log-level" validate:"oneof=trace debug info warn error fatal panic"` | ||||
| 	Title                 string `mapstructure:"app-title"` | ||||
| @@ -66,14 +51,16 @@ type Claims struct { | ||||
| } | ||||
|  | ||||
| type OAuthServiceConfig struct { | ||||
| 	ClientID           string | ||||
| 	ClientSecret       string | ||||
| 	Scopes             []string | ||||
| 	RedirectURL        string | ||||
| 	AuthURL            string | ||||
| 	TokenURL           string | ||||
| 	UserinfoURL        string | ||||
| 	InsecureSkipVerify bool | ||||
| 	ClientID           string   `key:"client-id"` | ||||
| 	ClientSecret       string   `key:"client-secret"` | ||||
| 	ClientSecretFile   string   `key:"client-secret-file"` | ||||
| 	Scopes             []string `key:"scopes"` | ||||
| 	RedirectURL        string   `key:"redirect-url"` | ||||
| 	AuthURL            string   `key:"auth-url"` | ||||
| 	TokenURL           string   `key:"token-url"` | ||||
| 	UserinfoURL        string   `key:"user-info-url"` | ||||
| 	InsecureSkipVerify bool     `key:"insecure-skip-verify"` | ||||
| 	Name               string   `key:"name"` | ||||
| } | ||||
|  | ||||
| // User/session related stuff | ||||
| @@ -97,6 +84,7 @@ type SessionCookie struct { | ||||
| 	Provider    string | ||||
| 	TotpPending bool | ||||
| 	OAuthGroups string | ||||
| 	OAuthName   string | ||||
| } | ||||
|  | ||||
| type UserContext struct { | ||||
| @@ -109,6 +97,7 @@ type UserContext struct { | ||||
| 	TotpPending bool | ||||
| 	OAuthGroups string | ||||
| 	TotpEnabled bool | ||||
| 	OAuthName   string | ||||
| } | ||||
|  | ||||
| // API responses and queries | ||||
| @@ -174,3 +163,9 @@ type AppPath struct { | ||||
| 	Allow string | ||||
| 	Block string | ||||
| } | ||||
|  | ||||
| // Flags | ||||
|  | ||||
| type Providers struct { | ||||
| 	Providers map[string]OAuthServiceConfig | ||||
| } | ||||
|   | ||||
| @@ -19,14 +19,14 @@ type UserContextResponse struct { | ||||
| 	Provider    string `json:"provider"` | ||||
| 	OAuth       bool   `json:"oauth"` | ||||
| 	TotpPending bool   `json:"totpPending"` | ||||
| 	OAuthName   string `json:"oauthName"` | ||||
| } | ||||
|  | ||||
| type AppContextResponse struct { | ||||
| 	Status                int        `json:"status"` | ||||
| 	Message               string     `json:"message"` | ||||
| 	ConfiguredProviders   []string `json:"configuredProviders"` | ||||
| 	Providers             []Provider `json:"providers"` | ||||
| 	Title                 string     `json:"title"` | ||||
| 	GenericName           string   `json:"genericName"` | ||||
| 	AppURL                string     `json:"appUrl"` | ||||
| 	CookieDomain          string     `json:"cookieDomain"` | ||||
| 	ForgotPasswordMessage string     `json:"forgotPasswordMessage"` | ||||
| @@ -34,10 +34,15 @@ type AppContextResponse struct { | ||||
| 	OAuthAutoRedirect     string     `json:"oauthAutoRedirect"` | ||||
| } | ||||
|  | ||||
| type Provider struct { | ||||
| 	Name  string `json:"name"` | ||||
| 	ID    string `json:"id"` | ||||
| 	OAuth bool   `json:"oauth"` | ||||
| } | ||||
|  | ||||
| type ContextControllerConfig struct { | ||||
| 	ConfiguredProviders   []string | ||||
| 	Providers             []Provider | ||||
| 	Title                 string | ||||
| 	GenericName           string | ||||
| 	AppURL                string | ||||
| 	CookieDomain          string | ||||
| 	ForgotPasswordMessage string | ||||
| @@ -76,6 +81,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { | ||||
| 		Provider:    context.Provider, | ||||
| 		OAuth:       context.OAuth, | ||||
| 		TotpPending: context.TotpPending, | ||||
| 		OAuthName:   context.OAuthName, | ||||
| 	} | ||||
|  | ||||
| 	if err != nil { | ||||
| @@ -96,9 +102,8 @@ func (controller *ContextController) appContextHandler(c *gin.Context) { | ||||
| 	c.JSON(200, AppContextResponse{ | ||||
| 		Status:                200, | ||||
| 		Message:               "Success", | ||||
| 		ConfiguredProviders:   controller.config.ConfiguredProviders, | ||||
| 		Providers:             controller.config.Providers, | ||||
| 		Title:                 controller.config.Title, | ||||
| 		GenericName:           controller.config.GenericName, | ||||
| 		AppURL:                fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), | ||||
| 		CookieDomain:          controller.config.CookieDomain, | ||||
| 		ForgotPasswordMessage: controller.config.ForgotPasswordMessage, | ||||
|   | ||||
| @@ -12,9 +12,19 @@ import ( | ||||
| ) | ||||
|  | ||||
| var controllerCfg = controller.ContextControllerConfig{ | ||||
| 	ConfiguredProviders:   []string{"github", "google", "generic"}, | ||||
| 	Providers: []controller.Provider{ | ||||
| 		{ | ||||
| 			Name:  "Username", | ||||
| 			ID:    "username", | ||||
| 			OAuth: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Name:  "Google", | ||||
| 			ID:    "google", | ||||
| 			OAuth: true, | ||||
| 		}, | ||||
| 	}, | ||||
| 	Title:                 "Test App", | ||||
| 	GenericName:           "Generic", | ||||
| 	AppURL:                "http://localhost:8080", | ||||
| 	CookieDomain:          "localhost", | ||||
| 	ForgotPasswordMessage: "Contact admin to reset your password.", | ||||
| @@ -58,9 +68,8 @@ func TestAppContextHandler(t *testing.T) { | ||||
| 	expectedRes := controller.AppContextResponse{ | ||||
| 		Status:                200, | ||||
| 		Message:               "Success", | ||||
| 		ConfiguredProviders:   controllerCfg.ConfiguredProviders, | ||||
| 		Providers:             controllerCfg.Providers, | ||||
| 		Title:                 controllerCfg.Title, | ||||
| 		GenericName:           controllerCfg.GenericName, | ||||
| 		AppURL:                controllerCfg.AppURL, | ||||
| 		CookieDomain:          controllerCfg.CookieDomain, | ||||
| 		ForgotPasswordMessage: controllerCfg.ForgotPasswordMessage, | ||||
|   | ||||
| @@ -186,6 +186,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { | ||||
| 		Email:       user.Email, | ||||
| 		Provider:    req.Provider, | ||||
| 		OAuthGroups: utils.CoalesceToString(user.Groups), | ||||
| 		OAuthName:   service.GetName(), | ||||
| 	}) | ||||
|  | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -95,6 +95,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { | ||||
| 				Email:       cookie.Email, | ||||
| 				Provider:    cookie.Provider, | ||||
| 				OAuthGroups: cookie.OAuthGroups, | ||||
| 				OAuthName:   cookie.OAuthName, | ||||
| 				IsLoggedIn:  true, | ||||
| 				OAuth:       true, | ||||
| 			}) | ||||
|   | ||||
| @@ -9,4 +9,5 @@ type Session struct { | ||||
| 	TOTPPending bool   `gorm:"column:totp_pending"` | ||||
| 	OAuthGroups string `gorm:"column:oauth_groups"` | ||||
| 	Expiry      int64  `gorm:"column:expiry"` | ||||
| 	OAuthName   string `gorm:"column:oauth_name"` | ||||
| } | ||||
|   | ||||
| @@ -210,6 +210,7 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.Sessio | ||||
| 		TOTPPending: data.TotpPending, | ||||
| 		OAuthGroups: data.OAuthGroups, | ||||
| 		Expiry:      time.Now().Add(time.Duration(expiry) * time.Second).Unix(), | ||||
| 		OAuthName:   data.OAuthName, | ||||
| 	} | ||||
|  | ||||
| 	err = auth.database.Create(&session).Error | ||||
| @@ -278,6 +279,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, | ||||
| 		Provider:    session.Provider, | ||||
| 		TotpPending: session.TOTPPending, | ||||
| 		OAuthGroups: session.OAuthGroups, | ||||
| 		OAuthName:   session.OAuthName, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -22,6 +22,7 @@ type GenericOAuthService struct { | ||||
| 	verifier           string | ||||
| 	insecureSkipVerify bool | ||||
| 	userinfoUrl        string | ||||
| 	name               string | ||||
| } | ||||
|  | ||||
| func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { | ||||
| @@ -38,6 +39,7 @@ func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthServi | ||||
| 		}, | ||||
| 		insecureSkipVerify: config.InsecureSkipVerify, | ||||
| 		userinfoUrl:        config.UserinfoURL, | ||||
| 		name:               config.Name, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -115,3 +117,7 @@ func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { | ||||
|  | ||||
| 	return user, nil | ||||
| } | ||||
|  | ||||
| func (generic *GenericOAuthService) GetName() string { | ||||
| 	return generic.name | ||||
| } | ||||
|   | ||||
| @@ -33,6 +33,7 @@ type GithubOAuthService struct { | ||||
| 	context  context.Context | ||||
| 	token    *oauth2.Token | ||||
| 	verifier string | ||||
| 	name     string | ||||
| } | ||||
|  | ||||
| func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { | ||||
| @@ -44,6 +45,7 @@ func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService | ||||
| 			Scopes:       GithubOAuthScopes, | ||||
| 			Endpoint:     endpoints.GitHub, | ||||
| 		}, | ||||
| 		name: config.Name, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -167,3 +169,7 @@ func (github *GithubOAuthService) Userinfo() (config.Claims, error) { | ||||
|  | ||||
| 	return user, nil | ||||
| } | ||||
|  | ||||
| func (github *GithubOAuthService) GetName() string { | ||||
| 	return github.name | ||||
| } | ||||
|   | ||||
| @@ -28,6 +28,7 @@ type GoogleOAuthService struct { | ||||
| 	context  context.Context | ||||
| 	token    *oauth2.Token | ||||
| 	verifier string | ||||
| 	name     string | ||||
| } | ||||
|  | ||||
| func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { | ||||
| @@ -39,6 +40,7 @@ func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService | ||||
| 			Scopes:       GoogleOAuthScopes, | ||||
| 			Endpoint:     endpoints.Google, | ||||
| 		}, | ||||
| 		name: config.Name, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -111,3 +113,7 @@ func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { | ||||
|  | ||||
| 	return user, nil | ||||
| } | ||||
|  | ||||
| func (google *GoogleOAuthService) GetName() string { | ||||
| 	return google.name | ||||
| } | ||||
|   | ||||
| @@ -14,6 +14,7 @@ type OAuthService interface { | ||||
| 	GetAuthURL(state string) string | ||||
| 	VerifyCode(code string) error | ||||
| 	Userinfo() (config.Claims, error) | ||||
| 	GetName() string | ||||
| } | ||||
|  | ||||
| type OAuthBrokerService struct { | ||||
|   | ||||
| @@ -6,6 +6,9 @@ import ( | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 	"tinyauth/internal/config" | ||||
| 	"tinyauth/internal/utils/decoders" | ||||
|  | ||||
| 	"maps" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/rs/zerolog" | ||||
| @@ -130,3 +133,68 @@ func GetLogLevel(level string) zerolog.Level { | ||||
| 		return zerolog.InfoLevel | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func GetOAuthProvidersConfig(env []string, args []string, appUrl string) (map[string]config.OAuthServiceConfig, error) { | ||||
| 	providers := make(map[string]config.OAuthServiceConfig) | ||||
|  | ||||
| 	// Get from environment variables | ||||
| 	envMap := make(map[string]string) | ||||
|  | ||||
| 	for _, e := range env { | ||||
| 		pair := strings.SplitN(e, "=", 2) | ||||
| 		if len(pair) == 2 { | ||||
| 			envMap[pair[0]] = pair[1] | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	envProviders, err := decoders.DecodeEnv(envMap) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	maps.Copy(providers, envProviders.Providers) | ||||
|  | ||||
| 	// Get from flags | ||||
| 	flagsMap := make(map[string]string) | ||||
|  | ||||
| 	for _, arg := range args[1:] { | ||||
| 		if strings.HasPrefix(arg, "--") { | ||||
| 			pair := strings.SplitN(arg[2:], "=", 2) | ||||
| 			if len(pair) == 2 { | ||||
| 				flagsMap[pair[0]] = pair[1] | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	flagProviders, err := decoders.DecodeFlags(flagsMap) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	maps.Copy(providers, flagProviders.Providers) | ||||
|  | ||||
| 	// For every provider get correct secret from file if set | ||||
| 	for name, provider := range providers { | ||||
| 		secret := GetSecret(provider.ClientSecret, provider.ClientSecretFile) | ||||
| 		provider.ClientSecret = secret | ||||
| 		provider.ClientSecretFile = "" | ||||
| 		providers[name] = provider | ||||
| 	} | ||||
|  | ||||
| 	// If we have google/github providers and no redirect URL babysit them | ||||
| 	babysitProviders := []string{"google", "github"} | ||||
|  | ||||
| 	for _, name := range babysitProviders { | ||||
| 		if provider, exists := providers[name]; exists { | ||||
| 			if provider.RedirectURL == "" { | ||||
| 				provider.RedirectURL = appUrl + "/api/oauth/callback/" + name | ||||
| 				providers[name] = provider | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Return combined providers | ||||
| 	return providers, nil | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package utils_test | ||||
|  | ||||
| import ( | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 	"tinyauth/internal/config" | ||||
| 	"tinyauth/internal/utils" | ||||
| @@ -200,3 +201,71 @@ func TestIsRedirectSafe(t *testing.T) { | ||||
| 	result = utils.IsRedirectSafe(redirectURL, domain) | ||||
| 	assert.Equal(t, false, result) | ||||
| } | ||||
|  | ||||
| func TestGetOAuthProvidersConfig(t *testing.T) { | ||||
| 	env := []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET=client1-secret"} | ||||
| 	args := []string{"/tinyauth/tinyauth", "--providers-client2-client-id=client2-id", "--providers-client2-client-secret=client2-secret"} | ||||
|  | ||||
| 	expected := map[string]config.OAuthServiceConfig{ | ||||
| 		"client1": { | ||||
| 			ClientID:     "client1-id", | ||||
| 			ClientSecret: "client1-secret", | ||||
| 		}, | ||||
| 		"client2": { | ||||
| 			ClientID:     "client2-id", | ||||
| 			ClientSecret: "client2-secret", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	result, err := utils.GetOAuthProvidersConfig(env, args, "") | ||||
| 	assert.NilError(t, err) | ||||
| 	assert.DeepEqual(t, expected, result) | ||||
|  | ||||
| 	// Case with no providers | ||||
| 	env = []string{} | ||||
| 	args = []string{"/tinyauth/tinyauth"} | ||||
| 	expected = map[string]config.OAuthServiceConfig{} | ||||
|  | ||||
| 	result, err = utils.GetOAuthProvidersConfig(env, args, "") | ||||
| 	assert.NilError(t, err) | ||||
| 	assert.DeepEqual(t, expected, result) | ||||
|  | ||||
| 	// Case with secret from file | ||||
| 	file, err := os.Create("/tmp/tinyauth_test_file") | ||||
| 	assert.NilError(t, err) | ||||
|  | ||||
| 	_, err = file.WriteString("file content\n") | ||||
| 	assert.NilError(t, err) | ||||
|  | ||||
| 	err = file.Close() | ||||
| 	assert.NilError(t, err) | ||||
| 	defer os.Remove("/tmp/tinyauth_test_file") | ||||
|  | ||||
| 	env = []string{"PROVIDERS_CLIENT1_CLIENT_ID=client1-id", "PROVIDERS_CLIENT1_CLIENT_SECRET_FILE=/tmp/tinyauth_test_file"} | ||||
| 	args = []string{"/tinyauth/tinyauth"} | ||||
| 	expected = map[string]config.OAuthServiceConfig{ | ||||
| 		"client1": { | ||||
| 			ClientID:     "client1-id", | ||||
| 			ClientSecret: "file content", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	result, err = utils.GetOAuthProvidersConfig(env, args, "") | ||||
| 	assert.NilError(t, err) | ||||
| 	assert.DeepEqual(t, expected, result) | ||||
|  | ||||
| 	// Case with google provider and no redirect URL | ||||
| 	env = []string{"PROVIDERS_GOOGLE_CLIENT_ID=google-id", "PROVIDERS_GOOGLE_CLIENT_SECRET=google-secret"} | ||||
| 	args = []string{"/tinyauth/tinyauth"} | ||||
| 	expected = map[string]config.OAuthServiceConfig{ | ||||
| 		"google": { | ||||
| 			ClientID:     "google-id", | ||||
| 			ClientSecret: "google-secret", | ||||
| 			RedirectURL:  "http://app.url/api/oauth/callback/google", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	result, err = utils.GetOAuthProvidersConfig(env, args, "http://app.url") | ||||
| 	assert.NilError(t, err) | ||||
| 	assert.DeepEqual(t, expected, result) | ||||
| } | ||||
|   | ||||
							
								
								
									
										81
									
								
								internal/utils/decoders/decoders.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								internal/utils/decoders/decoders.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,81 @@ | ||||
| package decoders | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"tinyauth/internal/config" | ||||
| ) | ||||
|  | ||||
| func NormalizeKeys(keys map[string]string, rootName string, sep string) map[string]string { | ||||
| 	normalized := make(map[string]string) | ||||
| 	knownKeys := getKnownKeys() | ||||
|  | ||||
| 	for k, v := range keys { | ||||
| 		var finalKey []string | ||||
| 		var suffix string | ||||
| 		var camelClientName string | ||||
| 		var camelField string | ||||
|  | ||||
| 		finalKey = append(finalKey, rootName) | ||||
| 		finalKey = append(finalKey, "providers") | ||||
| 		cebabKey := strings.ToLower(k) | ||||
|  | ||||
| 		for _, known := range knownKeys { | ||||
| 			if strings.HasSuffix(cebabKey, strings.ReplaceAll(known, "-", sep)) { | ||||
| 				suffix = known | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if suffix == "" { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		clientNameParts := strings.Split(strings.TrimPrefix(strings.TrimSuffix(cebabKey, sep+strings.ReplaceAll(suffix, "-", sep)), "providers"+sep), sep) | ||||
|  | ||||
| 		for i, p := range clientNameParts { | ||||
| 			if i == 0 { | ||||
| 				camelClientName += p | ||||
| 				continue | ||||
| 			} | ||||
| 			if p == "" { | ||||
| 				continue | ||||
| 			} | ||||
| 			camelClientName += strings.ToUpper(string([]rune(p)[0])) + string([]rune(p)[1:]) | ||||
| 		} | ||||
|  | ||||
| 		finalKey = append(finalKey, camelClientName) | ||||
|  | ||||
| 		filedParts := strings.Split(suffix, "-") | ||||
|  | ||||
| 		for i, p := range filedParts { | ||||
| 			if i == 0 { | ||||
| 				camelField += p | ||||
| 				continue | ||||
| 			} | ||||
| 			if p == "" { | ||||
| 				continue | ||||
| 			} | ||||
| 			camelField += strings.ToUpper(string([]rune(p)[0])) + string([]rune(p)[1:]) | ||||
| 		} | ||||
|  | ||||
| 		finalKey = append(finalKey, camelField) | ||||
| 		normalized[strings.Join(finalKey, ".")] = v | ||||
| 	} | ||||
|  | ||||
| 	return normalized | ||||
| } | ||||
|  | ||||
| func getKnownKeys() []string { | ||||
| 	var known []string | ||||
|  | ||||
| 	p := config.OAuthServiceConfig{} | ||||
| 	v := reflect.ValueOf(p) | ||||
| 	typeOfP := v.Type() | ||||
|  | ||||
| 	for field := range typeOfP.NumField() { | ||||
| 		known = append(known, typeOfP.Field(field).Tag.Get("key")) | ||||
| 	} | ||||
|  | ||||
| 	return known | ||||
| } | ||||
							
								
								
									
										44
									
								
								internal/utils/decoders/decoders_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								internal/utils/decoders/decoders_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | ||||
| package decoders_test | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"tinyauth/internal/utils/decoders" | ||||
|  | ||||
| 	"gotest.tools/v3/assert" | ||||
| ) | ||||
|  | ||||
| func TestNormalizeKeys(t *testing.T) { | ||||
| 	// Test with env | ||||
| 	test := map[string]string{ | ||||
| 		"PROVIDERS_CLIENT1_CLIENT_ID":                    "my-client-id", | ||||
| 		"PROVIDERS_CLIENT1_CLIENT_SECRET":                "my-client-secret", | ||||
| 		"PROVIDERS_MY_AWESOME_CLIENT_CLIENT_ID":          "my-awesome-client-id", | ||||
| 		"PROVIDERS_MY_AWESOME_CLIENT_CLIENT_SECRET_FILE": "/path/to/secret", | ||||
| 	} | ||||
| 	expected := map[string]string{ | ||||
| 		"tinyauth.providers.client1.clientId":                 "my-client-id", | ||||
| 		"tinyauth.providers.client1.clientSecret":             "my-client-secret", | ||||
| 		"tinyauth.providers.myAwesomeClient.clientId":         "my-awesome-client-id", | ||||
| 		"tinyauth.providers.myAwesomeClient.clientSecretFile": "/path/to/secret", | ||||
| 	} | ||||
|  | ||||
| 	normalized := decoders.NormalizeKeys(test, "tinyauth", "_") | ||||
| 	assert.DeepEqual(t, normalized, expected) | ||||
|  | ||||
| 	// Test with flags (assume -- is already stripped) | ||||
| 	test = map[string]string{ | ||||
| 		"providers-client1-client-id":                    "my-client-id", | ||||
| 		"providers-client1-client-secret":                "my-client-secret", | ||||
| 		"providers-my-awesome-client-client-id":          "my-awesome-client-id", | ||||
| 		"providers-my-awesome-client-client-secret-file": "/path/to/secret", | ||||
| 	} | ||||
| 	expected = map[string]string{ | ||||
| 		"tinyauth.providers.client1.clientId":                 "my-client-id", | ||||
| 		"tinyauth.providers.client1.clientSecret":             "my-client-secret", | ||||
| 		"tinyauth.providers.myAwesomeClient.clientId":         "my-awesome-client-id", | ||||
| 		"tinyauth.providers.myAwesomeClient.clientSecretFile": "/path/to/secret", | ||||
| 	} | ||||
|  | ||||
| 	normalized = decoders.NormalizeKeys(test, "tinyauth", "-") | ||||
| 	assert.DeepEqual(t, normalized, expected) | ||||
| } | ||||
							
								
								
									
										20
									
								
								internal/utils/decoders/env_decoder.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								internal/utils/decoders/env_decoder.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| package decoders | ||||
|  | ||||
| import ( | ||||
| 	"tinyauth/internal/config" | ||||
|  | ||||
| 	"github.com/traefik/paerser/parser" | ||||
| ) | ||||
|  | ||||
| func DecodeEnv(env map[string]string) (config.Providers, error) { | ||||
| 	normalized := NormalizeKeys(env, "tinyauth", "_") | ||||
| 	var providers config.Providers | ||||
|  | ||||
| 	err := parser.Decode(normalized, &providers, "tinyauth", "tinyauth.providers") | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return config.Providers{}, err | ||||
| 	} | ||||
|  | ||||
| 	return providers, nil | ||||
| } | ||||
							
								
								
									
										60
									
								
								internal/utils/decoders/env_decoder_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								internal/utils/decoders/env_decoder_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | ||||
| package decoders_test | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"tinyauth/internal/config" | ||||
| 	"tinyauth/internal/utils/decoders" | ||||
|  | ||||
| 	"gotest.tools/v3/assert" | ||||
| ) | ||||
|  | ||||
| func TestDecodeEnv(t *testing.T) { | ||||
| 	// Variables | ||||
| 	expected := config.Providers{ | ||||
| 		Providers: map[string]config.OAuthServiceConfig{ | ||||
| 			"client1": { | ||||
| 				ClientID:           "client1-id", | ||||
| 				ClientSecret:       "client1-secret", | ||||
| 				Scopes:             []string{"client1-scope1", "client1-scope2"}, | ||||
| 				RedirectURL:        "client1-redirect-url", | ||||
| 				AuthURL:            "client1-auth-url", | ||||
| 				UserinfoURL:        "client1-user-info-url", | ||||
| 				Name:               "Client1", | ||||
| 				InsecureSkipVerify: false, | ||||
| 			}, | ||||
| 			"client2": { | ||||
| 				ClientID:           "client2-id", | ||||
| 				ClientSecret:       "client2-secret", | ||||
| 				Scopes:             []string{"client2-scope1", "client2-scope2"}, | ||||
| 				RedirectURL:        "client2-redirect-url", | ||||
| 				AuthURL:            "client2-auth-url", | ||||
| 				UserinfoURL:        "client2-user-info-url", | ||||
| 				Name:               "My Awesome Client2", | ||||
| 				InsecureSkipVerify: false, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	test := map[string]string{ | ||||
| 		"PROVIDERS_CLIENT1_CLIENT_ID":            "client1-id", | ||||
| 		"PROVIDERS_CLIENT1_CLIENT_SECRET":        "client1-secret", | ||||
| 		"PROVIDERS_CLIENT1_SCOPES":               "client1-scope1,client1-scope2", | ||||
| 		"PROVIDERS_CLIENT1_REDIRECT_URL":         "client1-redirect-url", | ||||
| 		"PROVIDERS_CLIENT1_AUTH_URL":             "client1-auth-url", | ||||
| 		"PROVIDERS_CLIENT1_USER_INFO_URL":        "client1-user-info-url", | ||||
| 		"PROVIDERS_CLIENT1_NAME":                 "Client1", | ||||
| 		"PROVIDERS_CLIENT1_INSECURE_SKIP_VERIFY": "false", | ||||
| 		"PROVIDERS_CLIENT2_CLIENT_ID":            "client2-id", | ||||
| 		"PROVIDERS_CLIENT2_CLIENT_SECRET":        "client2-secret", | ||||
| 		"PROVIDERS_CLIENT2_SCOPES":               "client2-scope1,client2-scope2", | ||||
| 		"PROVIDERS_CLIENT2_REDIRECT_URL":         "client2-redirect-url", | ||||
| 		"PROVIDERS_CLIENT2_AUTH_URL":             "client2-auth-url", | ||||
| 		"PROVIDERS_CLIENT2_USER_INFO_URL":        "client2-user-info-url", | ||||
| 		"PROVIDERS_CLIENT2_NAME":                 "My Awesome Client2", | ||||
| 		"PROVIDERS_CLIENT2_INSECURE_SKIP_VERIFY": "false", | ||||
| 	} | ||||
|  | ||||
| 	// Test | ||||
| 	res, err := decoders.DecodeEnv(test) | ||||
| 	assert.NilError(t, err) | ||||
| 	assert.DeepEqual(t, expected, res) | ||||
| } | ||||
							
								
								
									
										30
									
								
								internal/utils/decoders/flags_decoder.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								internal/utils/decoders/flags_decoder.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | ||||
| package decoders | ||||
|  | ||||
| import ( | ||||
| 	"strings" | ||||
| 	"tinyauth/internal/config" | ||||
|  | ||||
| 	"github.com/traefik/paerser/parser" | ||||
| ) | ||||
|  | ||||
| func DecodeFlags(flags map[string]string) (config.Providers, error) { | ||||
| 	filtered := filterFlags(flags) | ||||
| 	normalized := NormalizeKeys(filtered, "tinyauth", "-") | ||||
| 	var providers config.Providers | ||||
|  | ||||
| 	err := parser.Decode(normalized, &providers, "tinyauth", "tinyauth.providers") | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return config.Providers{}, err | ||||
| 	} | ||||
|  | ||||
| 	return providers, nil | ||||
| } | ||||
|  | ||||
| func filterFlags(flags map[string]string) map[string]string { | ||||
| 	filtered := make(map[string]string) | ||||
| 	for k, v := range flags { | ||||
| 		filtered[strings.TrimPrefix(k, "--")] = v | ||||
| 	} | ||||
| 	return filtered | ||||
| } | ||||
							
								
								
									
										60
									
								
								internal/utils/decoders/flags_decoder_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								internal/utils/decoders/flags_decoder_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | ||||
| package decoders_test | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"tinyauth/internal/config" | ||||
| 	"tinyauth/internal/utils/decoders" | ||||
|  | ||||
| 	"gotest.tools/v3/assert" | ||||
| ) | ||||
|  | ||||
| func TestDecodeFlags(t *testing.T) { | ||||
| 	// Variables | ||||
| 	expected := config.Providers{ | ||||
| 		Providers: map[string]config.OAuthServiceConfig{ | ||||
| 			"client1": { | ||||
| 				ClientID:           "client1-id", | ||||
| 				ClientSecret:       "client1-secret", | ||||
| 				Scopes:             []string{"client1-scope1", "client1-scope2"}, | ||||
| 				RedirectURL:        "client1-redirect-url", | ||||
| 				AuthURL:            "client1-auth-url", | ||||
| 				UserinfoURL:        "client1-user-info-url", | ||||
| 				Name:               "Client1", | ||||
| 				InsecureSkipVerify: false, | ||||
| 			}, | ||||
| 			"client2": { | ||||
| 				ClientID:           "client2-id", | ||||
| 				ClientSecret:       "client2-secret", | ||||
| 				Scopes:             []string{"client2-scope1", "client2-scope2"}, | ||||
| 				RedirectURL:        "client2-redirect-url", | ||||
| 				AuthURL:            "client2-auth-url", | ||||
| 				UserinfoURL:        "client2-user-info-url", | ||||
| 				Name:               "My Awesome Client2", | ||||
| 				InsecureSkipVerify: false, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	test := map[string]string{ | ||||
| 		"--providers-client1-client-id":            "client1-id", | ||||
| 		"--providers-client1-client-secret":        "client1-secret", | ||||
| 		"--providers-client1-scopes":               "client1-scope1,client1-scope2", | ||||
| 		"--providers-client1-redirect-url":         "client1-redirect-url", | ||||
| 		"--providers-client1-auth-url":             "client1-auth-url", | ||||
| 		"--providers-client1-user-info-url":        "client1-user-info-url", | ||||
| 		"--providers-client1-name":                 "Client1", | ||||
| 		"--providers-client1-insecure-skip-verify": "false", | ||||
| 		"--providers-client2-client-id":            "client2-id", | ||||
| 		"--providers-client2-client-secret":        "client2-secret", | ||||
| 		"--providers-client2-scopes":               "client2-scope1,client2-scope2", | ||||
| 		"--providers-client2-redirect-url":         "client2-redirect-url", | ||||
| 		"--providers-client2-auth-url":             "client2-auth-url", | ||||
| 		"--providers-client2-user-info-url":        "client2-user-info-url", | ||||
| 		"--providers-client2-name":                 "My Awesome Client2", | ||||
| 		"--providers-client2-insecure-skip-verify": "false", | ||||
| 	} | ||||
|  | ||||
| 	// Test | ||||
| 	res, err := decoders.DecodeFlags(test) | ||||
| 	assert.NilError(t, err) | ||||
| 	assert.DeepEqual(t, expected, res) | ||||
| } | ||||
| @@ -1,10 +1,11 @@ | ||||
| package decoders_test | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"tinyauth/internal/config" | ||||
| 	"tinyauth/internal/utils/decoders" | ||||
|  | ||||
| 	"gotest.tools/v3/assert" | ||||
| ) | ||||
|  | ||||
| func TestDecodeLabels(t *testing.T) { | ||||
| @@ -62,12 +63,6 @@ func TestDecodeLabels(t *testing.T) { | ||||
|  | ||||
| 	// Test | ||||
| 	result, err := decoders.DecodeLabels(test) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Unexpected error: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if reflect.DeepEqual(expected, result) == false { | ||||
| 		t.Fatalf("Expected %v but got %v", expected, result) | ||||
| 	} | ||||
| 	assert.NilError(t, err) | ||||
| 	assert.DeepEqual(t, expected, result) | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user