mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-10-30 21:55:43 +00:00 
			
		
		
		
	Compare commits
	
		
			17 Commits
		
	
	
		
			v3.3.0-alp
			...
			v3.3.0-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| ![dependabot[bot]](/assets/img/avatar_default.png)  | 5278fbea68 | ||
|   | 773942dc3b | ||
|   | 83483d6374 | ||
|   | aab01b3195 | ||
| ![github-actions[bot]](/assets/img/avatar_default.png)  | fe5e07139f | ||
|   | 93a75324b8 | ||
| ![github-actions[bot]](/assets/img/avatar_default.png)  | 67a01c196f | ||
|   | 483b1de701 | ||
|   | 40ceed6686 | ||
|   | 3878c629c6 | ||
|   | 31e874a34f | ||
| ![dependabot[bot]](/assets/img/avatar_default.png)  | 74a346349a | ||
|   | a9e8bf89a9 | ||
|   | f824b84787 | ||
| ![dependabot[bot]](/assets/img/avatar_default.png)  | 71b0c301f6 | ||
|   | 1c738b718a | ||
|   | 4dc6bc0c98 | 
| @@ -28,3 +28,4 @@ LOGIN_MAX_RETRIES=5 | |||||||
| LOG_LEVEL=0 | LOG_LEVEL=0 | ||||||
| APP_TITLE=Tinyauth SSO | APP_TITLE=Tinyauth SSO | ||||||
| FORGOT_PASSWORD_MESSAGE=Some message about resetting the password | FORGOT_PASSWORD_MESSAGE=Some message about resetting the password | ||||||
|  | OAUTH_AUTO_REDIRECT=none | ||||||
							
								
								
									
										30
									
								
								.github/workflows/sponsors.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								.github/workflows/sponsors.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | |||||||
|  | name: Generate Sponsors List | ||||||
|  | on: | ||||||
|  |   workflow_dispatch: | ||||||
|  |  | ||||||
|  | jobs: | ||||||
|  |   generate-sponsors: | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     steps: | ||||||
|  |       - name: Checkout | ||||||
|  |         uses: actions/checkout@v4 | ||||||
|  |  | ||||||
|  |       - name: Generate Sponsors | ||||||
|  |         uses: JamesIves/github-sponsors-readme-action@v1 | ||||||
|  |         with: | ||||||
|  |           token: ${{ secrets.SPONSORS_GENERATOR_PAT }} | ||||||
|  |           file: README.md | ||||||
|  |           template: '<a href="https://github.com/{{{ login }}}"><img src="{{{ avatarUrl }}}" width="64px" alt="User avatar: {{{ login }}}" /></a>  ' | ||||||
|  |  | ||||||
|  |       - name: Create Pull Request | ||||||
|  |         uses: peter-evans/create-pull-request@v7 | ||||||
|  |         with: | ||||||
|  |           token: ${{ secrets.GITHUB_TOKEN }} | ||||||
|  |           commit-message: | | ||||||
|  |             docs: regenerate readme sponsors list | ||||||
|  |           committer: GitHub <noreply@github.com> | ||||||
|  |           author: GitHub <noreply@github.com> | ||||||
|  |           branch: docs/update-readme | ||||||
|  |           title: | | ||||||
|  |             docs: regenerate readme sponsors list | ||||||
|  |           labels: bot | ||||||
| @@ -1,5 +1,5 @@ | |||||||
| # Site builder | # Site builder | ||||||
| FROM oven/bun:1.2.10-alpine AS frontend-builder | FROM oven/bun:1.2.11-alpine AS frontend-builder | ||||||
|  |  | ||||||
| WORKDIR /frontend | WORKDIR /frontend | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,5 +1,5 @@ | |||||||
| <div align="center"> | <div align="center"> | ||||||
|     <img alt="Tinyauth" title="Tinyauth" width="256" src="frontend/public/logo.png"> |     <img alt="Tinyauth" title="Tinyauth" height="256" src="frontend/public/logo.png"> | ||||||
|     <h1>Tinyauth</h1> |     <h1>Tinyauth</h1> | ||||||
|     <p>The easiest way to secure your apps with a login screen.</p> |     <p>The easiest way to secure your apps with a login screen.</p> | ||||||
| </div> | </div> | ||||||
| @@ -53,9 +53,7 @@ Tinyauth is licensed under the GNU General Public License v3.0. TL;DR — You ma | |||||||
|  |  | ||||||
| Thanks a lot to the following people for providing me with more coffee: | Thanks a lot to the following people for providing me with more coffee: | ||||||
|  |  | ||||||
| | <img height="64" src="https://avatars.githubusercontent.com/u/47644445?v=4" alt="Nicolas"> | <img height="64" src="https://avatars.githubusercontent.com/u/4255748?v=4" alt="Erwin"> | <img height="64" src="https://avatars.githubusercontent.com/u/7935041?v=4" alt="SimpleHomelab" /> | | <!-- sponsors --><a href="https://github.com/nicotsx"><img src="https://github.com/nicotsx.png" width="64px" alt="User avatar: nicotsx" /></a>  <a href="https://github.com/SimpleHomelab"><img src="https://github.com/SimpleHomelab.png" width="64px" alt="User avatar: SimpleHomelab" /></a>  <a href="https://github.com/jmadden91"><img src="https://github.com/jmadden91.png" width="64px" alt="User avatar: jmadden91" /></a>  <a href="https://github.com/tribor"><img src="https://github.com/tribor.png" width="64px" alt="User avatar: tribor" /></a>  <!-- sponsors --> | ||||||
| | ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------- | |  | ||||||
| | <div align="center"><a href="https://github.com/nicotsx">Nicolas</a></div>                 | <div align="center"><a href="https://github.com/erwinkramer">Erwin</a></div>            | <div align="center"><a href="https://github.com/SimpleHomelab">SimpleHomelab</a></div>            | |  | ||||||
|  |  | ||||||
| ## Acknowledgements | ## Acknowledgements | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										10
									
								
								cmd/root.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								cmd/root.go
									
									
									
									
									
								
							| @@ -91,6 +91,7 @@ var rootCmd = &cobra.Command{ | |||||||
| 			CookieSecure:          config.CookieSecure, | 			CookieSecure:          config.CookieSecure, | ||||||
| 			Domain:                domain, | 			Domain:                domain, | ||||||
| 			ForgotPasswordMessage: config.FogotPasswordMessage, | 			ForgotPasswordMessage: config.FogotPasswordMessage, | ||||||
|  | 			OAuthAutoRedirect:     config.OAuthAutoRedirect, | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Create api config | 		// Create api config | ||||||
| @@ -111,6 +112,11 @@ var rootCmd = &cobra.Command{ | |||||||
| 			LoginMaxRetries: config.LoginMaxRetries, | 			LoginMaxRetries: config.LoginMaxRetries, | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Create hooks config | ||||||
|  | 		hooksConfig := types.HooksConfig{ | ||||||
|  | 			Domain: domain, | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		// Create docker service | 		// Create docker service | ||||||
| 		docker := docker.NewDocker() | 		docker := docker.NewDocker() | ||||||
|  |  | ||||||
| @@ -128,7 +134,7 @@ var rootCmd = &cobra.Command{ | |||||||
| 		providers.Init() | 		providers.Init() | ||||||
|  |  | ||||||
| 		// Create hooks service | 		// Create hooks service | ||||||
| 		hooks := hooks.NewHooks(auth, providers) | 		hooks := hooks.NewHooks(hooksConfig, auth, providers) | ||||||
|  |  | ||||||
| 		// Create handlers | 		// Create handlers | ||||||
| 		handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) | 		handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) | ||||||
| @@ -192,6 +198,7 @@ func init() { | |||||||
| 	rootCmd.Flags().String("generic-name", "Generic", "Generic OAuth provider name.") | 	rootCmd.Flags().String("generic-name", "Generic", "Generic OAuth provider name.") | ||||||
| 	rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") | 	rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") | ||||||
| 	rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") | 	rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") | ||||||
|  | 	rootCmd.Flags().String("oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)") | ||||||
| 	rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.") | 	rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.") | ||||||
| 	rootCmd.Flags().Int("login-timeout", 300, "Login timeout in seconds after max retries reached (0 to disable).") | 	rootCmd.Flags().Int("login-timeout", 300, "Login timeout in seconds after max retries reached (0 to disable).") | ||||||
| 	rootCmd.Flags().Int("login-max-retries", 5, "Maximum login attempts before timeout (0 to disable).") | 	rootCmd.Flags().Int("login-max-retries", 5, "Maximum login attempts before timeout (0 to disable).") | ||||||
| @@ -224,6 +231,7 @@ func init() { | |||||||
| 	viper.BindEnv("generic-name", "GENERIC_NAME") | 	viper.BindEnv("generic-name", "GENERIC_NAME") | ||||||
| 	viper.BindEnv("disable-continue", "DISABLE_CONTINUE") | 	viper.BindEnv("disable-continue", "DISABLE_CONTINUE") | ||||||
| 	viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST") | 	viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST") | ||||||
|  | 	viper.BindEnv("oauth-auto-redirect", "OAUTH_AUTO_REDIRECT") | ||||||
| 	viper.BindEnv("session-expiry", "SESSION_EXPIRY") | 	viper.BindEnv("session-expiry", "SESSION_EXPIRY") | ||||||
| 	viper.BindEnv("log-level", "LOG_LEVEL") | 	viper.BindEnv("log-level", "LOG_LEVEL") | ||||||
| 	viper.BindEnv("app-title", "APP_TITLE") | 	viper.BindEnv("app-title", "APP_TITLE") | ||||||
|   | |||||||
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										2561
									
								
								frontend/package-lock.json
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										2561
									
								
								frontend/package-lock.json
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -25,7 +25,7 @@ | |||||||
|     "react-dom": "^19.1.0", |     "react-dom": "^19.1.0", | ||||||
|     "react-i18next": "^15.4.1", |     "react-i18next": "^15.4.1", | ||||||
|     "react-markdown": "^10.1.0", |     "react-markdown": "^10.1.0", | ||||||
|     "react-router": "^7.1.3", |     "react-router": "^7.5.2", | ||||||
|     "zod": "^3.24.1" |     "zod": "^3.24.1" | ||||||
|   }, |   }, | ||||||
|   "devDependencies": { |   "devDependencies": { | ||||||
| @@ -43,6 +43,6 @@ | |||||||
|     "prettier": "3.5.3", |     "prettier": "3.5.3", | ||||||
|     "typescript": "~5.8.3", |     "typescript": "~5.8.3", | ||||||
|     "typescript-eslint": "^8.18.2", |     "typescript-eslint": "^8.18.2", | ||||||
|     "vite": "^6.0.5" |     "vite": "^6.3.4" | ||||||
|   } |   } | ||||||
| } | } | ||||||
| @@ -25,7 +25,7 @@ export const LoginForm = (props: LoginFormProps) => { | |||||||
|     <form onSubmit={form.onSubmit(onSubmit)}> |     <form onSubmit={form.onSubmit(onSubmit)}> | ||||||
|       <TextInput |       <TextInput | ||||||
|         label={t("loginUsername")} |         label={t("loginUsername")} | ||||||
|         placeholder="username" |         placeholder="Username" | ||||||
|         disabled={isPending} |         disabled={isPending} | ||||||
|         required |         required | ||||||
|         withAsterisk={false} |         withAsterisk={false} | ||||||
| @@ -43,7 +43,7 @@ export const LoginForm = (props: LoginFormProps) => { | |||||||
|       </Group> |       </Group> | ||||||
|       <PasswordInput |       <PasswordInput | ||||||
|         className="password-input" |         className="password-input" | ||||||
|         placeholder="password" |         placeholder="Password" | ||||||
|         required |         required | ||||||
|         disabled={isPending} |         disabled={isPending} | ||||||
|         key={form.key("password")} |         key={form.key("password")} | ||||||
|   | |||||||
							
								
								
									
										4
									
								
								frontend/src/index.css
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								frontend/src/index.css
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,4 @@ | |||||||
|  | span, | ||||||
|  | p { | ||||||
|  |   word-break: break-word; | ||||||
|  | } | ||||||
							
								
								
									
										26
									
								
								frontend/src/lib/hooks/use-is-mounted.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								frontend/src/lib/hooks/use-is-mounted.ts
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | |||||||
|  | import { useCallback, useEffect, useRef } from 'react' | ||||||
|  |  | ||||||
|  | /** | ||||||
|  |  * Custom hook that determines if the component is currently mounted. | ||||||
|  |  * @returns {() => boolean} A function that returns a boolean value indicating whether the component is mounted. | ||||||
|  |  * @public | ||||||
|  |  * @see [Documentation](https://usehooks-ts.com/react-hook/use-is-mounted) | ||||||
|  |  * @example | ||||||
|  |  * ```tsx | ||||||
|  |  * const isComponentMounted = useIsMounted(); | ||||||
|  |  * // Use isComponentMounted() to check if the component is currently mounted before performing certain actions. | ||||||
|  |  * ``` | ||||||
|  |  */ | ||||||
|  | export function useIsMounted(): () => boolean { | ||||||
|  |   const isMounted = useRef(false) | ||||||
|  |  | ||||||
|  |   useEffect(() => { | ||||||
|  |     isMounted.current = true | ||||||
|  |  | ||||||
|  |     return () => { | ||||||
|  |       isMounted.current = false | ||||||
|  |     } | ||||||
|  |   }, []) | ||||||
|  |  | ||||||
|  |   return useCallback(() => isMounted.current, []) | ||||||
|  | } | ||||||
| @@ -41,7 +41,8 @@ | |||||||
|     "totpTitle": "Enter your TOTP code", |     "totpTitle": "Enter your TOTP code", | ||||||
|     "unauthorizedTitle": "Unauthorized", |     "unauthorizedTitle": "Unauthorized", | ||||||
|     "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.", |     "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.", | ||||||
|     "unaothorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.", |     "unauthorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.", | ||||||
|  |     "unauthorizedGroupsSubtitle": "The user with username <Code>{{username}}</Code> is not in the groups required by the resource <Code>{{resource}}</Code>.", | ||||||
|     "unauthorizedButton": "Try again", |     "unauthorizedButton": "Try again", | ||||||
|     "untrustedRedirectTitle": "Untrusted redirect", |     "untrustedRedirectTitle": "Untrusted redirect", | ||||||
|     "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?", |     "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?", | ||||||
|   | |||||||
| @@ -41,7 +41,8 @@ | |||||||
|     "totpTitle": "Enter your TOTP code", |     "totpTitle": "Enter your TOTP code", | ||||||
|     "unauthorizedTitle": "Unauthorized", |     "unauthorizedTitle": "Unauthorized", | ||||||
|     "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.", |     "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.", | ||||||
|     "unaothorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.", |     "unauthorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.", | ||||||
|  |     "unauthorizedGroupsSubtitle": "The user with username <Code>{{username}}</Code> is not in the groups required by the resource <Code>{{resource}}</Code>.", | ||||||
|     "unauthorizedButton": "Try again", |     "unauthorizedButton": "Try again", | ||||||
|     "untrustedRedirectTitle": "Untrusted redirect", |     "untrustedRedirectTitle": "Untrusted redirect", | ||||||
|     "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?", |     "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?", | ||||||
|   | |||||||
| @@ -19,6 +19,7 @@ import { TotpPage } from "./pages/totp-page.tsx"; | |||||||
| import { AppContextProvider } from "./context/app-context.tsx"; | import { AppContextProvider } from "./context/app-context.tsx"; | ||||||
| import "./lib/i18n/i18n.ts"; | import "./lib/i18n/i18n.ts"; | ||||||
| import { ForgotPasswordPage } from "./pages/forgot-password-page.tsx"; | import { ForgotPasswordPage } from "./pages/forgot-password-page.tsx"; | ||||||
|  | import "./index.css"; | ||||||
|  |  | ||||||
| const queryClient = new QueryClient(); | const queryClient = new QueryClient(); | ||||||
|  |  | ||||||
| @@ -38,7 +39,10 @@ createRoot(document.getElementById("root")!).render( | |||||||
|                 <Route path="/continue" element={<ContinuePage />} /> |                 <Route path="/continue" element={<ContinuePage />} /> | ||||||
|                 <Route path="/unauthorized" element={<UnauthorizedPage />} /> |                 <Route path="/unauthorized" element={<UnauthorizedPage />} /> | ||||||
|                 <Route path="/error" element={<InternalServerError />} /> |                 <Route path="/error" element={<InternalServerError />} /> | ||||||
|                 <Route path="/forgot-password" element={<ForgotPasswordPage />} /> |                 <Route | ||||||
|  |                   path="/forgot-password" | ||||||
|  |                   element={<ForgotPasswordPage />} | ||||||
|  |                 /> | ||||||
|                 <Route path="*" element={<NotFoundPage />} /> |                 <Route path="*" element={<NotFoundPage />} /> | ||||||
|               </Routes> |               </Routes> | ||||||
|             </BrowserRouter> |             </BrowserRouter> | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ import { Navigate } from "react-router"; | |||||||
| import { useUserContext } from "../context/user-context"; | import { useUserContext } from "../context/user-context"; | ||||||
| import { Layout } from "../components/layouts/layout"; | import { Layout } from "../components/layouts/layout"; | ||||||
| import { ReactNode } from "react"; | import { ReactNode } from "react"; | ||||||
| import { escapeRegex, isQueryValid } from "../utils/utils"; | import { escapeRegex, isValidRedirectUri } from "../utils/utils"; | ||||||
| import { useAppContext } from "../context/app-context"; | import { useAppContext } from "../context/app-context"; | ||||||
| import { Trans, useTranslation } from "react-i18next"; | import { Trans, useTranslation } from "react-i18next"; | ||||||
|  |  | ||||||
| @@ -21,7 +21,7 @@ export const ContinuePage = () => { | |||||||
|     return <Navigate to={`/login?redirect_uri=${redirectUri}`} />; |     return <Navigate to={`/login?redirect_uri=${redirectUri}`} />; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   if (!isQueryValid(redirectUri)) { |   if (!isValidRedirectUri(redirectUri)) { | ||||||
|     return <Navigate to="/" />; |     return <Navigate to="/" />; | ||||||
|   } |   } | ||||||
|  |  | ||||||
| @@ -51,7 +51,7 @@ export const ContinuePage = () => { | |||||||
|     ); |     ); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   const regex = new RegExp(`^.*${escapeRegex(domain)}$`) |   const regex = new RegExp(`^.*${escapeRegex(domain)}$`); | ||||||
|  |  | ||||||
|   if (!regex.test(uri.hostname)) { |   if (!regex.test(uri.hostname)) { | ||||||
|     return ( |     return ( | ||||||
| @@ -66,13 +66,18 @@ export const ContinuePage = () => { | |||||||
|           values={{ domain: domain }} |           values={{ domain: domain }} | ||||||
|         /> |         /> | ||||||
|         <Button fullWidth mt="xl" color="red" onClick={redirect}> |         <Button fullWidth mt="xl" color="red" onClick={redirect}> | ||||||
|           {t('continueTitle')} |           {t("continueTitle")} | ||||||
|         </Button> |         </Button> | ||||||
|         <Button fullWidth mt="sm" color="gray" onClick={() => window.location.href = "/"}> |         <Button | ||||||
|           {t('cancelTitle')} |           fullWidth | ||||||
|  |           mt="sm" | ||||||
|  |           color="gray" | ||||||
|  |           onClick={() => (window.location.href = "/")} | ||||||
|  |         > | ||||||
|  |           {t("cancelTitle")} | ||||||
|         </Button> |         </Button> | ||||||
|       </ContinuePageLayout> |       </ContinuePageLayout> | ||||||
|     ) |     ); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   if (disableContinue) { |   if (disableContinue) { | ||||||
| @@ -103,8 +108,13 @@ export const ContinuePage = () => { | |||||||
|         <Button fullWidth mt="xl" color="yellow" onClick={redirect}> |         <Button fullWidth mt="xl" color="yellow" onClick={redirect}> | ||||||
|           {t("continueTitle")} |           {t("continueTitle")} | ||||||
|         </Button> |         </Button> | ||||||
|         <Button fullWidth mt="sm" color="gray" onClick={() => window.location.href = "/"}> |         <Button | ||||||
|           {t('cancelTitle')} |           fullWidth | ||||||
|  |           mt="sm" | ||||||
|  |           color="gray" | ||||||
|  |           onClick={() => (window.location.href = "/")} | ||||||
|  |         > | ||||||
|  |           {t("cancelTitle")} | ||||||
|         </Button> |         </Button> | ||||||
|       </ContinuePageLayout> |       </ContinuePageLayout> | ||||||
|     ); |     ); | ||||||
|   | |||||||
| @@ -8,9 +8,11 @@ import { Layout } from "../components/layouts/layout"; | |||||||
| import { OAuthButtons } from "../components/auth/oauth-buttons"; | import { OAuthButtons } from "../components/auth/oauth-buttons"; | ||||||
| import { LoginFormValues } from "../schemas/login-schema"; | import { LoginFormValues } from "../schemas/login-schema"; | ||||||
| import { LoginForm } from "../components/auth/login-forn"; | import { LoginForm } from "../components/auth/login-forn"; | ||||||
| import { isQueryValid } from "../utils/utils"; |  | ||||||
| import { useAppContext } from "../context/app-context"; | import { useAppContext } from "../context/app-context"; | ||||||
| import { useTranslation } from "react-i18next"; | import { useTranslation } from "react-i18next"; | ||||||
|  | import { useEffect, useState } from "react"; | ||||||
|  | import { useIsMounted } from "../lib/hooks/use-is-mounted"; | ||||||
|  | import { isValidRedirectUri } from "../utils/utils"; | ||||||
|  |  | ||||||
| export const LoginPage = () => { | export const LoginPage = () => { | ||||||
|   const queryString = window.location.search; |   const queryString = window.location.search; | ||||||
| @@ -18,16 +20,29 @@ export const LoginPage = () => { | |||||||
|   const redirectUri = params.get("redirect_uri") ?? ""; |   const redirectUri = params.get("redirect_uri") ?? ""; | ||||||
|  |  | ||||||
|   const { isLoggedIn } = useUserContext(); |   const { isLoggedIn } = useUserContext(); | ||||||
|   const { configuredProviders, title, genericName } = useAppContext(); |  | ||||||
|  |   if (isLoggedIn) { | ||||||
|  |     return <Navigate to="/logout" />; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   const { | ||||||
|  |     configuredProviders, | ||||||
|  |     title, | ||||||
|  |     genericName, | ||||||
|  |     oauthAutoRedirect: oauthAutoRedirectContext, | ||||||
|  |   } = useAppContext(); | ||||||
|  |  | ||||||
|   const { t } = useTranslation(); |   const { t } = useTranslation(); | ||||||
|  |  | ||||||
|  |   const [oauthAutoRedirect, setOAuthAutoRedirect] = useState( | ||||||
|  |     oauthAutoRedirectContext, | ||||||
|  |   ); | ||||||
|  |  | ||||||
|   const oauthProviders = configuredProviders.filter( |   const oauthProviders = configuredProviders.filter( | ||||||
|     (value) => value !== "username", |     (value) => value !== "username", | ||||||
|   ); |   ); | ||||||
|  |  | ||||||
|   if (isLoggedIn) { |   const isMounted = useIsMounted(); | ||||||
|     return <Navigate to="/logout" />; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   const loginMutation = useMutation({ |   const loginMutation = useMutation({ | ||||||
|     mutationFn: (login: LoginFormValues) => { |     mutationFn: (login: LoginFormValues) => { | ||||||
| @@ -63,7 +78,7 @@ export const LoginPage = () => { | |||||||
|       }); |       }); | ||||||
|  |  | ||||||
|       setTimeout(() => { |       setTimeout(() => { | ||||||
|         if (!isQueryValid(redirectUri)) { |         if (!isValidRedirectUri(redirectUri)) { | ||||||
|           window.location.replace("/"); |           window.location.replace("/"); | ||||||
|           return; |           return; | ||||||
|         } |         } | ||||||
| @@ -85,6 +100,7 @@ export const LoginPage = () => { | |||||||
|         message: t("loginOauthFailSubtitle"), |         message: t("loginOauthFailSubtitle"), | ||||||
|         color: "red", |         color: "red", | ||||||
|       }); |       }); | ||||||
|  |       setOAuthAutoRedirect("none"); | ||||||
|     }, |     }, | ||||||
|     onSuccess: (data) => { |     onSuccess: (data) => { | ||||||
|       notifications.show({ |       notifications.show({ | ||||||
| @@ -102,6 +118,33 @@ export const LoginPage = () => { | |||||||
|     loginMutation.mutate(values); |     loginMutation.mutate(values); | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   useEffect(() => { | ||||||
|  |     if (isMounted()) { | ||||||
|  |       if ( | ||||||
|  |         oauthProviders.includes(oauthAutoRedirect) && | ||||||
|  |         isValidRedirectUri(redirectUri) | ||||||
|  |       ) { | ||||||
|  |         loginOAuthMutation.mutate(oauthAutoRedirect); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   }, []); | ||||||
|  |  | ||||||
|  |   if ( | ||||||
|  |     oauthProviders.includes(oauthAutoRedirect) && | ||||||
|  |     isValidRedirectUri(redirectUri) | ||||||
|  |   ) { | ||||||
|  |     return ( | ||||||
|  |       <Layout> | ||||||
|  |         <Paper shadow="md" p="xl" mt={30} radius="md" withBorder> | ||||||
|  |           <Text size="xl" fw={700}> | ||||||
|  |             {t("continueRedirectingTitle")} | ||||||
|  |           </Text> | ||||||
|  |           <Text>{t("loginOauthSuccessSubtitle")}</Text> | ||||||
|  |         </Paper> | ||||||
|  |       </Layout> | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  |  | ||||||
|   return ( |   return ( | ||||||
|     <Layout> |     <Layout> | ||||||
|       <Title ta="center">{title}</Title> |       <Title ta="center">{title}</Title> | ||||||
|   | |||||||
| @@ -10,7 +10,7 @@ import { useAppContext } from "../context/app-context"; | |||||||
| import { Trans, useTranslation } from "react-i18next"; | import { Trans, useTranslation } from "react-i18next"; | ||||||
|  |  | ||||||
| export const LogoutPage = () => { | export const LogoutPage = () => { | ||||||
|   const { isLoggedIn, username, oauth, provider } = useUserContext(); |   const { isLoggedIn, oauth, provider, email, username } = useUserContext(); | ||||||
|   const { genericName } = useAppContext(); |   const { genericName } = useAppContext(); | ||||||
|   const { t } = useTranslation(); |   const { t } = useTranslation(); | ||||||
|  |  | ||||||
| @@ -56,7 +56,7 @@ export const LogoutPage = () => { | |||||||
|               values={{ |               values={{ | ||||||
|                 provider: |                 provider: | ||||||
|                   provider === "generic" ? genericName : capitalize(provider), |                   provider === "generic" ? genericName : capitalize(provider), | ||||||
|                 username: username, |                 username: email, | ||||||
|               }} |               }} | ||||||
|             /> |             /> | ||||||
|           ) : ( |           ) : ( | ||||||
|   | |||||||
| @@ -1,48 +1,71 @@ | |||||||
| import { Button, Code, Paper, Text } from "@mantine/core"; | import { Button, Code, Paper, Text } from "@mantine/core"; | ||||||
| import { Layout } from "../components/layouts/layout"; | import { Layout } from "../components/layouts/layout"; | ||||||
| import { Navigate } from "react-router"; | import { Navigate } from "react-router"; | ||||||
| import { isQueryValid } from "../utils/utils"; |  | ||||||
| import { Trans, useTranslation } from "react-i18next"; | import { Trans, useTranslation } from "react-i18next"; | ||||||
|  | import React from "react"; | ||||||
|  | import { isValidQuery } from "../utils/utils"; | ||||||
|  |  | ||||||
| export const UnauthorizedPage = () => { | export const UnauthorizedPage = () => { | ||||||
|   const queryString = window.location.search; |   const queryString = window.location.search; | ||||||
|   const params = new URLSearchParams(queryString); |   const params = new URLSearchParams(queryString); | ||||||
|   const username = params.get("username") ?? ""; |   const username = params.get("username") ?? ""; | ||||||
|  |   const groupErr = params.get("groupErr") ?? ""; | ||||||
|   const resource = params.get("resource") ?? ""; |   const resource = params.get("resource") ?? ""; | ||||||
|  |  | ||||||
|   const { t } = useTranslation(); |   const { t } = useTranslation(); | ||||||
|  |  | ||||||
|   if (!isQueryValid(username)) { |   if (!isValidQuery(username)) { | ||||||
|     return <Navigate to="/" />; |     return <Navigate to="/" />; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   if (isValidQuery(resource) && !isValidQuery(groupErr)) { | ||||||
|  |     return ( | ||||||
|  |       <UnauthorizedLayout> | ||||||
|  |         <Trans | ||||||
|  |           i18nKey="unauthorizedResourceSubtitle" | ||||||
|  |           t={t} | ||||||
|  |           components={{ Code: <Code /> }} | ||||||
|  |           values={{ resource, username }} | ||||||
|  |         /> | ||||||
|  |       </UnauthorizedLayout> | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (isValidQuery(groupErr) && isValidQuery(resource)) { | ||||||
|  |     return ( | ||||||
|  |       <UnauthorizedLayout> | ||||||
|  |         <Trans | ||||||
|  |           i18nKey="unauthorizedGroupsSubtitle" | ||||||
|  |           t={t} | ||||||
|  |           components={{ Code: <Code /> }} | ||||||
|  |           values={{ username, resource }} | ||||||
|  |         /> | ||||||
|  |       </UnauthorizedLayout> | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   return ( | ||||||
|  |     <UnauthorizedLayout> | ||||||
|  |       <Trans | ||||||
|  |         i18nKey="unauthorizedLoginSubtitle" | ||||||
|  |         t={t} | ||||||
|  |         components={{ Code: <Code /> }} | ||||||
|  |         values={{ username }} | ||||||
|  |       /> | ||||||
|  |     </UnauthorizedLayout> | ||||||
|  |   ); | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | const UnauthorizedLayout = ({ children }: { children: React.ReactNode }) => { | ||||||
|  |   const { t } = useTranslation(); | ||||||
|  |  | ||||||
|   return ( |   return ( | ||||||
|     <Layout> |     <Layout> | ||||||
|       <Paper shadow="md" p={30} mt={30} radius="md" withBorder> |       <Paper shadow="md" p={30} mt={30} radius="md" withBorder> | ||||||
|         <Text size="xl" fw={700}> |         <Text size="xl" fw={700}> | ||||||
|           {t("Unauthorized")} |           {t("Unauthorized")} | ||||||
|         </Text> |         </Text> | ||||||
|         <Text> |         <Text>{children}</Text> | ||||||
|           {isQueryValid(resource) ? ( |  | ||||||
|             <Text> |  | ||||||
|               <Trans |  | ||||||
|                 i18nKey="unauthorizedResourceSubtitle" |  | ||||||
|                 t={t} |  | ||||||
|                 components={{ Code: <Code /> }} |  | ||||||
|                 values={{ resource, username }} |  | ||||||
|               /> |  | ||||||
|             </Text> |  | ||||||
|           ) : ( |  | ||||||
|             <Text> |  | ||||||
|               <Trans |  | ||||||
|                 i18nKey="unaothorizedLoginSubtitle" |  | ||||||
|                 t={t} |  | ||||||
|                 components={{ Code: <Code /> }} |  | ||||||
|                 values={{ username }} |  | ||||||
|               /> |  | ||||||
|             </Text> |  | ||||||
|           )} |  | ||||||
|         </Text> |  | ||||||
|         <Button |         <Button | ||||||
|           fullWidth |           fullWidth | ||||||
|           mt="xl" |           mt="xl" | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ export const appContextSchema = z.object({ | |||||||
|   genericName: z.string(), |   genericName: z.string(), | ||||||
|   domain: z.string(), |   domain: z.string(), | ||||||
|   forgotPasswordMessage: z.string(), |   forgotPasswordMessage: z.string(), | ||||||
|  |   oauthAutoRedirect: z.enum(["none", "github", "google", "generic"]), | ||||||
| }); | }); | ||||||
|  |  | ||||||
| export type AppContextSchemaType = z.infer<typeof appContextSchema>; | export type AppContextSchemaType = z.infer<typeof appContextSchema>; | ||||||
|   | |||||||
| @@ -3,6 +3,8 @@ import { z } from "zod"; | |||||||
| export const userContextSchema = z.object({ | export const userContextSchema = z.object({ | ||||||
|   isLoggedIn: z.boolean(), |   isLoggedIn: z.boolean(), | ||||||
|   username: z.string(), |   username: z.string(), | ||||||
|  |   name: z.string(), | ||||||
|  |   email: z.string(), | ||||||
|   oauth: z.boolean(), |   oauth: z.boolean(), | ||||||
|   provider: z.string(), |   provider: z.string(), | ||||||
|   totpPending: z.boolean(), |   totpPending: z.boolean(), | ||||||
|   | |||||||
| @@ -1,3 +1,17 @@ | |||||||
| export const capitalize = (s: string) => s.charAt(0).toUpperCase() + s.slice(1); | export const capitalize = (s: string) => s.charAt(0).toUpperCase() + s.slice(1); | ||||||
| export const isQueryValid = (value: string) => value.trim() !== "" && value !== "null"; |  | ||||||
| export const escapeRegex = (value: string) => value.replace(/[-\/\\^$.*+?()[\]{}|]/g, "\\$&"); | export const escapeRegex = (value: string) => value.replace(/[-\/\\^$.*+?()[\]{}|]/g, "\\$&"); | ||||||
|  | export const isValidQuery = (query: string) => query && query.trim() !== ""; | ||||||
|  |  | ||||||
|  | export const isValidRedirectUri = (value: string) => { | ||||||
|  |     if (!isValidQuery(value)) { | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     try { | ||||||
|  |         new URL(value); | ||||||
|  |     } catch { | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     return true; | ||||||
|  | } | ||||||
| @@ -45,6 +45,11 @@ var authConfig = types.AuthConfig{ | |||||||
| 	LoginMaxRetries: 0, | 	LoginMaxRetries: 0, | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Simple hooks config for tests | ||||||
|  | var hooksConfig = types.HooksConfig{ | ||||||
|  | 	Domain: "localhost", | ||||||
|  | } | ||||||
|  |  | ||||||
| // Cookie | // Cookie | ||||||
| var cookie string | var cookie string | ||||||
|  |  | ||||||
| @@ -83,7 +88,7 @@ func getAPI(t *testing.T) *api.API { | |||||||
| 	providers.Init() | 	providers.Init() | ||||||
|  |  | ||||||
| 	// Create hooks service | 	// Create hooks service | ||||||
| 	hooks := hooks.NewHooks(auth, providers) | 	hooks := hooks.NewHooks(hooksConfig, auth, providers) | ||||||
|  |  | ||||||
| 	// Create handlers service | 	// Create handlers service | ||||||
| 	handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) | 	handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) | ||||||
|   | |||||||
| @@ -160,9 +160,12 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) | |||||||
|  |  | ||||||
| 	// Set data | 	// Set data | ||||||
| 	session.Values["username"] = data.Username | 	session.Values["username"] = data.Username | ||||||
|  | 	session.Values["name"] = data.Name | ||||||
|  | 	session.Values["email"] = data.Email | ||||||
| 	session.Values["provider"] = data.Provider | 	session.Values["provider"] = data.Provider | ||||||
| 	session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix() | 	session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix() | ||||||
| 	session.Values["totpPending"] = data.TotpPending | 	session.Values["totpPending"] = data.TotpPending | ||||||
|  | 	session.Values["oauthGroups"] = data.OAuthGroups | ||||||
|  |  | ||||||
| 	// Save session | 	// Save session | ||||||
| 	err = session.Save(c.Request, c.Writer) | 	err = session.Save(c.Request, c.Writer) | ||||||
| @@ -211,14 +214,24 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) | |||||||
| 		return types.SessionCookie{}, err | 		return types.SessionCookie{}, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got session") | ||||||
|  |  | ||||||
| 	// Get data from session | 	// Get data from session | ||||||
| 	username, usernameOk := session.Values["username"].(string) | 	username, usernameOk := session.Values["username"].(string) | ||||||
|  | 	email, emailOk := session.Values["email"].(string) | ||||||
|  | 	name, nameOk := session.Values["name"].(string) | ||||||
| 	provider, providerOK := session.Values["provider"].(string) | 	provider, providerOK := session.Values["provider"].(string) | ||||||
| 	expiry, expiryOk := session.Values["expiry"].(int64) | 	expiry, expiryOk := session.Values["expiry"].(int64) | ||||||
| 	totpPending, totpPendingOk := session.Values["totpPending"].(bool) | 	totpPending, totpPendingOk := session.Values["totpPending"].(bool) | ||||||
|  | 	oauthGroups, oauthGroupsOk := session.Values["oauthGroups"].(string) | ||||||
|  |  | ||||||
| 	if !usernameOk || !providerOK || !expiryOk || !totpPendingOk { | 	if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk { | ||||||
| 		log.Warn().Msg("Session cookie is missing data") | 		log.Warn().Msg("Session cookie is invalid") | ||||||
|  |  | ||||||
|  | 		// If any data is missing, delete the session cookie | ||||||
|  | 		auth.DeleteSessionCookie(c) | ||||||
|  |  | ||||||
|  | 		// Return empty cookie | ||||||
| 		return types.SessionCookie{}, nil | 		return types.SessionCookie{}, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -233,13 +246,16 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) | |||||||
| 		return types.SessionCookie{}, nil | 		return types.SessionCookie{}, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Msg("Parsed cookie") | 	log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie") | ||||||
|  |  | ||||||
| 	// Return the cookie | 	// Return the cookie | ||||||
| 	return types.SessionCookie{ | 	return types.SessionCookie{ | ||||||
| 		Username:    username, | 		Username:    username, | ||||||
|  | 		Name:        name, | ||||||
|  | 		Email:       email, | ||||||
| 		Provider:    provider, | 		Provider:    provider, | ||||||
| 		TotpPending: totpPending, | 		TotpPending: totpPending, | ||||||
|  | 		OAuthGroups: oauthGroups, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -248,48 +264,52 @@ func (auth *Auth) UserAuthConfigured() bool { | |||||||
| 	return len(auth.Config.Users) > 0 | 	return len(auth.Config.Users) > 0 | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext) (bool, error) { | func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.TinyauthLabels) bool { | ||||||
| 	// Get headers |  | ||||||
| 	host := c.Request.Header.Get("X-Forwarded-Host") |  | ||||||
|  |  | ||||||
| 	// Get app id |  | ||||||
| 	appId := strings.Split(host, ".")[0] |  | ||||||
|  |  | ||||||
| 	// Get the container labels |  | ||||||
| 	labels, err := auth.Docker.GetLabels(appId) |  | ||||||
|  |  | ||||||
| 	// If there is an error, return false |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Check if oauth is allowed | 	// Check if oauth is allowed | ||||||
| 	if context.OAuth { | 	if context.OAuth { | ||||||
| 		log.Debug().Msg("Checking OAuth whitelist") | 		log.Debug().Msg("Checking OAuth whitelist") | ||||||
| 		return utils.CheckWhitelist(labels.OAuthWhitelist, context.Username), nil | 		return utils.CheckWhitelist(labels.OAuthWhitelist, context.Email) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Check users | 	// Check users | ||||||
| 	log.Debug().Msg("Checking users") | 	log.Debug().Msg("Checking users") | ||||||
|  |  | ||||||
| 	return utils.CheckWhitelist(labels.Users, context.Username), nil | 	return utils.CheckWhitelist(labels.Users, context.Username) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) AuthEnabled(c *gin.Context) (bool, error) { | func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.TinyauthLabels) bool { | ||||||
|  | 	// Check if groups are required | ||||||
|  | 	if labels.OAuthGroups == "" { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check if we are using the generic oauth provider | ||||||
|  | 	if context.Provider != "generic" { | ||||||
|  | 		log.Debug().Msg("Not using generic provider, skipping group check") | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Split the groups by comma (no need to parse since they are from the API response) | ||||||
|  | 	oauthGroups := strings.Split(context.OAuthGroups, ",") | ||||||
|  |  | ||||||
|  | 	// For every group check if it is in the required groups | ||||||
|  | 	for _, group := range oauthGroups { | ||||||
|  | 		if utils.CheckWhitelist(labels.OAuthGroups, group) { | ||||||
|  | 			log.Debug().Str("group", group).Msg("Group is in required groups") | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// No groups matched | ||||||
|  | 	log.Debug().Msg("No groups matched") | ||||||
|  |  | ||||||
|  | 	// Return false | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (auth *Auth) AuthEnabled(c *gin.Context, labels types.TinyauthLabels) (bool, error) { | ||||||
| 	// Get headers | 	// Get headers | ||||||
| 	uri := c.Request.Header.Get("X-Forwarded-Uri") | 	uri := c.Request.Header.Get("X-Forwarded-Uri") | ||||||
| 	host := c.Request.Header.Get("X-Forwarded-Host") |  | ||||||
|  |  | ||||||
| 	// Get app id |  | ||||||
| 	appId := strings.Split(host, ".")[0] |  | ||||||
|  |  | ||||||
| 	// Get the container labels |  | ||||||
| 	labels, err := auth.Docker.GetLabels(appId) |  | ||||||
|  |  | ||||||
| 	// If there is an error, auth enabled |  | ||||||
| 	if err != nil { |  | ||||||
| 		return true, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Check if the allowed label is empty | 	// Check if the allowed label is empty | ||||||
| 	if labels.Allowed == "" { | 	if labels.Allowed == "" { | ||||||
|   | |||||||
| @@ -6,4 +6,13 @@ var TinyauthLabels = []string{ | |||||||
| 	"tinyauth.users", | 	"tinyauth.users", | ||||||
| 	"tinyauth.allowed", | 	"tinyauth.allowed", | ||||||
| 	"tinyauth.headers", | 	"tinyauth.headers", | ||||||
|  | 	"tinyauth.oauth.groups", | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Claims are the OIDC supported claims (including preferd username for some reason) | ||||||
|  | type Claims struct { | ||||||
|  | 	Name              string   `json:"name"` | ||||||
|  | 	Email             string   `json:"email"` | ||||||
|  | 	PreferredUsername string   `json:"preferred_username"` | ||||||
|  | 	Groups            []string `json:"groups"` | ||||||
| } | } | ||||||
|   | |||||||
| @@ -6,8 +6,7 @@ import ( | |||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
| 	"tinyauth/internal/utils" | 	"tinyauth/internal/utils" | ||||||
|  |  | ||||||
| 	apiTypes "github.com/docker/docker/api/types" | 	container "github.com/docker/docker/api/types/container" | ||||||
| 	containerTypes "github.com/docker/docker/api/types/container" |  | ||||||
| 	"github.com/docker/docker/client" | 	"github.com/docker/docker/client" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
| @@ -38,9 +37,9 @@ func (docker *Docker) Init() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (docker *Docker) GetContainers() ([]apiTypes.Container, error) { | func (docker *Docker) GetContainers() ([]container.Summary, error) { | ||||||
| 	// Get the list of containers | 	// Get the list of containers | ||||||
| 	containers, err := docker.Client.ContainerList(docker.Context, containerTypes.ListOptions{}) | 	containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -51,13 +50,13 @@ func (docker *Docker) GetContainers() ([]apiTypes.Container, error) { | |||||||
| 	return containers, nil | 	return containers, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (docker *Docker) InspectContainer(containerId string) (apiTypes.ContainerJSON, error) { | func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) { | ||||||
| 	// Inspect the container | 	// Inspect the container | ||||||
| 	inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) | 	inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return apiTypes.ContainerJSON{}, err | 		return container.InspectResponse{}, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Return the inspect | 	// Return the inspect | ||||||
|   | |||||||
| @@ -10,6 +10,7 @@ import ( | |||||||
| 	"tinyauth/internal/hooks" | 	"tinyauth/internal/hooks" | ||||||
| 	"tinyauth/internal/providers" | 	"tinyauth/internal/providers" | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
|  | 	"tinyauth/internal/utils" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/google/go-querystring/query" | 	"github.com/google/go-querystring/query" | ||||||
| @@ -68,12 +69,15 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 	proto := c.Request.Header.Get("X-Forwarded-Proto") | 	proto := c.Request.Header.Get("X-Forwarded-Proto") | ||||||
| 	host := c.Request.Header.Get("X-Forwarded-Host") | 	host := c.Request.Header.Get("X-Forwarded-Host") | ||||||
|  |  | ||||||
| 	// Check if auth is enabled | 	// Get the app id | ||||||
| 	authEnabled, err := h.Auth.AuthEnabled(c) | 	appId := strings.Split(host, ".")[0] | ||||||
|  |  | ||||||
|  | 	// Get the container labels | ||||||
|  | 	labels, err := h.Docker.GetLabels(appId) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(err).Msg("Failed to check if app is allowed") | 		log.Error().Err(err).Msg("Failed to get container labels") | ||||||
|  |  | ||||||
| 		if proxy.Proxy == "nginx" || !isBrowser { | 		if proxy.Proxy == "nginx" || !isBrowser { | ||||||
| 			c.JSON(500, gin.H{ | 			c.JSON(500, gin.H{ | ||||||
| @@ -87,11 +91,8 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Get the app id | 	// Check if auth is enabled | ||||||
| 	appId := strings.Split(host, ".")[0] | 	authEnabled, err := h.Auth.AuthEnabled(c, labels) | ||||||
|  |  | ||||||
| 	// Get the container labels |  | ||||||
| 	labels, err := h.Docker.GetLabels(appId) |  | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -113,7 +114,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 	if !authEnabled { | 	if !authEnabled { | ||||||
| 		for key, value := range labels.Headers { | 		for key, value := range labels.Headers { | ||||||
| 			log.Debug().Str("key", key).Str("value", value).Msg("Setting header") | 			log.Debug().Str("key", key).Str("value", value).Msg("Setting header") | ||||||
| 			c.Header(key, value) | 			c.Header(key, utils.SanitizeHeader(value)) | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"status":  200, | 			"status":  200, | ||||||
| @@ -125,28 +126,18 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 	// Get user context | 	// Get user context | ||||||
| 	userContext := h.Hooks.UseUserContext(c) | 	userContext := h.Hooks.UseUserContext(c) | ||||||
|  |  | ||||||
|  | 	// If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth | ||||||
|  | 	if userContext.Provider == "basic" && userContext.TotpEnabled { | ||||||
|  | 		log.Warn().Str("username", userContext.Username).Msg("User has totp enabled, disabling basic auth") | ||||||
|  | 		userContext.IsLoggedIn = false | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Check if user is logged in | 	// Check if user is logged in | ||||||
| 	if userContext.IsLoggedIn { | 	if userContext.IsLoggedIn { | ||||||
| 		log.Debug().Msg("Authenticated") | 		log.Debug().Msg("Authenticated") | ||||||
|  |  | ||||||
| 		// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx | 		// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx | ||||||
| 		appAllowed, err := h.Auth.ResourceAllowed(c, userContext) | 		appAllowed := h.Auth.ResourceAllowed(c, userContext, labels) | ||||||
|  |  | ||||||
| 		// Check if there was an error |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Error().Err(err).Msg("Failed to check if app is allowed") |  | ||||||
|  |  | ||||||
| 			if proxy.Proxy == "nginx" || !isBrowser { |  | ||||||
| 				c.JSON(500, gin.H{ |  | ||||||
| 					"status":  500, |  | ||||||
| 					"message": "Internal Server Error", |  | ||||||
| 				}) |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") | 		log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") | ||||||
|  |  | ||||||
| @@ -165,11 +156,20 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			// Build query | 			// Values | ||||||
| 			queries, err := query.Values(types.UnauthorizedQuery{ | 			values := types.UnauthorizedQuery{ | ||||||
| 				Username: userContext.Username, |  | ||||||
| 				Resource: strings.Split(host, ".")[0], | 				Resource: strings.Split(host, ".")[0], | ||||||
| 			}) | 			} | ||||||
|  |  | ||||||
|  | 			// Use either username or email | ||||||
|  | 			if userContext.OAuth { | ||||||
|  | 				values.Username = userContext.Email | ||||||
|  | 			} else { | ||||||
|  | 				values.Username = userContext.Username | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// Build query | ||||||
|  | 			queries, err := query.Values(values) | ||||||
|  |  | ||||||
| 			// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) | 			// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -183,13 +183,65 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Set the user header | 		log.Debug().Interface("labels", labels).Msg("Got labels") | ||||||
| 		c.Header("Remote-User", userContext.Username) |  | ||||||
|  | 		// Check if user is in required groups | ||||||
|  | 		groupOk := h.Auth.OAuthGroup(c, userContext, labels) | ||||||
|  |  | ||||||
|  | 		log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups") | ||||||
|  |  | ||||||
|  | 		// The user is not allowed to access the app | ||||||
|  | 		if !groupOk { | ||||||
|  | 			log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User is not in required groups") | ||||||
|  |  | ||||||
|  | 			// Set WWW-Authenticate header | ||||||
|  | 			c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") | ||||||
|  |  | ||||||
|  | 			if proxy.Proxy == "nginx" || !isBrowser { | ||||||
|  | 				c.JSON(401, gin.H{ | ||||||
|  | 					"status":  401, | ||||||
|  | 					"message": "Unauthorized", | ||||||
|  | 				}) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// Values | ||||||
|  | 			values := types.UnauthorizedQuery{ | ||||||
|  | 				Resource: strings.Split(host, ".")[0], | ||||||
|  | 				GroupErr: true, | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// Use either username or email | ||||||
|  | 			if userContext.OAuth { | ||||||
|  | 				values.Username = userContext.Email | ||||||
|  | 			} else { | ||||||
|  | 				values.Username = userContext.Username | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// Build query | ||||||
|  | 			queries, err := query.Values(values) | ||||||
|  |  | ||||||
|  | 			// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) | ||||||
|  | 			if err != nil { | ||||||
|  | 				log.Error().Err(err).Msg("Failed to build queries") | ||||||
|  | 				c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// We are using caddy/traefik so redirect | ||||||
|  | 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) | ||||||
|  | 		c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) | ||||||
|  | 		c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) | ||||||
|  | 		c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) | ||||||
|  |  | ||||||
| 		// Set the rest of the headers | 		// Set the rest of the headers | ||||||
| 		for key, value := range labels.Headers { | 		for key, value := range labels.Headers { | ||||||
| 			log.Debug().Str("key", key).Str("value", value).Msg("Setting header") | 			log.Debug().Str("key", key).Str("value", value).Msg("Setting header") | ||||||
| 			c.Header(key, value) | 			c.Header(key, utils.SanitizeHeader(value)) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// The user is allowed to access the app | 		// The user is allowed to access the app | ||||||
| @@ -310,6 +362,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) { | |||||||
| 		// Set totp pending cookie | 		// Set totp pending cookie | ||||||
| 		h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | 		h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
| 			Username:    login.Username, | 			Username:    login.Username, | ||||||
|  | 			Name:        utils.Capitalize(login.Username), | ||||||
|  | 			Email:       fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), | ||||||
| 			Provider:    "username", | 			Provider:    "username", | ||||||
| 			TotpPending: true, | 			TotpPending: true, | ||||||
| 		}) | 		}) | ||||||
| @@ -328,6 +382,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) { | |||||||
| 	// Create session cookie with username as provider | 	// Create session cookie with username as provider | ||||||
| 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
| 		Username: login.Username, | 		Username: login.Username, | ||||||
|  | 		Name:     utils.Capitalize(login.Username), | ||||||
|  | 		Email:    fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), | ||||||
| 		Provider: "username", | 		Provider: "username", | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| @@ -402,6 +458,8 @@ func (h *Handlers) TotpHandler(c *gin.Context) { | |||||||
| 	// Create session cookie with username as provider | 	// Create session cookie with username as provider | ||||||
| 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
| 		Username: user.Username, | 		Username: user.Username, | ||||||
|  | 		Name:     utils.Capitalize(user.Username), | ||||||
|  | 		Email:    fmt.Sprintf("%s@%s", strings.ToLower(user.Username), h.Config.Domain), | ||||||
| 		Provider: "username", | 		Provider: "username", | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| @@ -448,6 +506,7 @@ func (h *Handlers) AppHandler(c *gin.Context) { | |||||||
| 		GenericName:           h.Config.GenericName, | 		GenericName:           h.Config.GenericName, | ||||||
| 		Domain:                h.Config.Domain, | 		Domain:                h.Config.Domain, | ||||||
| 		ForgotPasswordMessage: h.Config.ForgotPasswordMessage, | 		ForgotPasswordMessage: h.Config.ForgotPasswordMessage, | ||||||
|  | 		OAuthAutoRedirect:     h.Config.OAuthAutoRedirect, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Return app context | 	// Return app context | ||||||
| @@ -465,6 +524,8 @@ func (h *Handlers) UserHandler(c *gin.Context) { | |||||||
| 		Status:      200, | 		Status:      200, | ||||||
| 		IsLoggedIn:  userContext.IsLoggedIn, | 		IsLoggedIn:  userContext.IsLoggedIn, | ||||||
| 		Username:    userContext.Username, | 		Username:    userContext.Username, | ||||||
|  | 		Name:        userContext.Name, | ||||||
|  | 		Email:       userContext.Email, | ||||||
| 		Provider:    userContext.Provider, | 		Provider:    userContext.Provider, | ||||||
| 		Oauth:       userContext.OAuth, | 		Oauth:       userContext.OAuth, | ||||||
| 		TotpPending: userContext.TotpPending, | 		TotpPending: userContext.TotpPending, | ||||||
| @@ -608,35 +669,42 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { | |||||||
|  |  | ||||||
| 	// Handle error | 	// Handle error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Msg("Failed to exchange token") | 		log.Error().Err(err).Msg("Failed to exchange token") | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Get email | 	// Get user | ||||||
| 	email, err := h.Providers.GetUser(providerName.Provider) | 	user, err := h.Providers.GetUser(providerName.Provider) | ||||||
|  |  | ||||||
| 	log.Debug().Str("email", email).Msg("Got email") |  | ||||||
|  |  | ||||||
| 	// Handle error | 	// Handle error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Msg("Failed to get email") | 		log.Error().Msg("Failed to get user") | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got user") | ||||||
|  |  | ||||||
|  | 	// Check that email is not empty | ||||||
|  | 	if user.Email == "" { | ||||||
|  | 		log.Error().Msg("Email is empty") | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Email is not whitelisted | 	// Email is not whitelisted | ||||||
| 	if !h.Auth.EmailWhitelisted(email) { | 	if !h.Auth.EmailWhitelisted(user.Email) { | ||||||
| 		log.Warn().Str("email", email).Msg("Email not whitelisted") | 		log.Warn().Str("email", user.Email).Msg("Email not whitelisted") | ||||||
|  |  | ||||||
| 		// Build query | 		// Build query | ||||||
| 		queries, err := query.Values(types.UnauthorizedQuery{ | 		queries, err := query.Values(types.UnauthorizedQuery{ | ||||||
| 			Username: email, | 			Username: user.Email, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
| 		// Handle error | 		// Handle error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error().Msg("Failed to build queries") | 			log.Error().Err(err).Msg("Failed to build queries") | ||||||
| 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @@ -647,10 +715,31 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { | |||||||
|  |  | ||||||
| 	log.Debug().Msg("Email whitelisted") | 	log.Debug().Msg("Email whitelisted") | ||||||
|  |  | ||||||
|  | 	// Get username | ||||||
|  | 	var username string | ||||||
|  |  | ||||||
|  | 	if user.PreferredUsername != "" { | ||||||
|  | 		username = user.PreferredUsername | ||||||
|  | 	} else { | ||||||
|  | 		username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1]) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get name | ||||||
|  | 	var name string | ||||||
|  |  | ||||||
|  | 	if user.Name != "" { | ||||||
|  | 		name = user.Name | ||||||
|  | 	} else { | ||||||
|  | 		name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Create session cookie (also cleans up redirect cookie) | 	// Create session cookie (also cleans up redirect cookie) | ||||||
| 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
| 		Username: email, | 		Username:    username, | ||||||
|  | 		Name:        name, | ||||||
|  | 		Email:       user.Email, | ||||||
| 		Provider:    providerName.Provider, | 		Provider:    providerName.Provider, | ||||||
|  | 		OAuthGroups: strings.Join(user.Groups, ","), | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	// Check if we have a redirect URI | 	// Check if we have a redirect URI | ||||||
| @@ -673,7 +762,7 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { | |||||||
|  |  | ||||||
| 	// Handle error | 	// Handle error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Msg("Failed to build queries") | 		log.Error().Err(err).Msg("Failed to build queries") | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -1,22 +1,27 @@ | |||||||
| package hooks | package hooks | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"strings" | ||||||
| 	"tinyauth/internal/auth" | 	"tinyauth/internal/auth" | ||||||
| 	"tinyauth/internal/providers" | 	"tinyauth/internal/providers" | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
|  | 	"tinyauth/internal/utils" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks { | func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks { | ||||||
| 	return &Hooks{ | 	return &Hooks{ | ||||||
|  | 		Config:    config, | ||||||
| 		Auth:      auth, | 		Auth:      auth, | ||||||
| 		Providers: providers, | 		Providers: providers, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| type Hooks struct { | type Hooks struct { | ||||||
|  | 	Config    types.HooksConfig | ||||||
| 	Auth      *auth.Auth | 	Auth      *auth.Auth | ||||||
| 	Providers *providers.Providers | 	Providers *providers.Providers | ||||||
| } | } | ||||||
| @@ -30,17 +35,27 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 	if basic != nil { | 	if basic != nil { | ||||||
| 		log.Debug().Msg("Got basic auth") | 		log.Debug().Msg("Got basic auth") | ||||||
|  |  | ||||||
| 		// Check if user exists and password is correct | 		// Get user | ||||||
| 		user := hooks.Auth.GetUser(basic.Username) | 		user := hooks.Auth.GetUser(basic.Username) | ||||||
|  |  | ||||||
| 		if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) { | 		// Check we have a user | ||||||
|  | 		if user == nil { | ||||||
|  | 			log.Error().Str("username", basic.Username).Msg("User does not exist") | ||||||
|  |  | ||||||
|  | 			// Return empty context | ||||||
|  | 			return types.UserContext{} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Check if the user has a correct password | ||||||
|  | 		if hooks.Auth.CheckPassword(*user, basic.Password) { | ||||||
| 			// Return user context since we are logged in with basic auth | 			// Return user context since we are logged in with basic auth | ||||||
| 			return types.UserContext{ | 			return types.UserContext{ | ||||||
| 				Username:    basic.Username, | 				Username:    basic.Username, | ||||||
|  | 				Name:        utils.Capitalize(basic.Username), | ||||||
|  | 				Email:       fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain), | ||||||
| 				IsLoggedIn:  true, | 				IsLoggedIn:  true, | ||||||
| 				OAuth:       false, |  | ||||||
| 				Provider:    "basic", | 				Provider:    "basic", | ||||||
| 				TotpPending: false, | 				TotpEnabled: user.TotpSecret != "", | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| @@ -50,13 +65,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(err).Msg("Failed to get session cookie") | 		log.Error().Err(err).Msg("Failed to get session cookie") | ||||||
| 		// Return empty context | 		// Return empty context | ||||||
| 		return types.UserContext{ | 		return types.UserContext{} | ||||||
| 			Username:    "", |  | ||||||
| 			IsLoggedIn:  false, |  | ||||||
| 			OAuth:       false, |  | ||||||
| 			Provider:    "", |  | ||||||
| 			TotpPending: false, |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Check if session cookie has totp pending | 	// Check if session cookie has totp pending | ||||||
| @@ -65,8 +74,8 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 		// Return empty context since we are pending totp | 		// Return empty context since we are pending totp | ||||||
| 		return types.UserContext{ | 		return types.UserContext{ | ||||||
| 			Username:    cookie.Username, | 			Username:    cookie.Username, | ||||||
| 			IsLoggedIn:  false, | 			Name:        cookie.Name, | ||||||
| 			OAuth:       false, | 			Email:       cookie.Email, | ||||||
| 			Provider:    cookie.Provider, | 			Provider:    cookie.Provider, | ||||||
| 			TotpPending: true, | 			TotpPending: true, | ||||||
| 		} | 		} | ||||||
| @@ -83,10 +92,10 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 			// It exists so we are logged in | 			// It exists so we are logged in | ||||||
| 			return types.UserContext{ | 			return types.UserContext{ | ||||||
| 				Username:   cookie.Username, | 				Username:   cookie.Username, | ||||||
|  | 				Name:       cookie.Name, | ||||||
|  | 				Email:      cookie.Email, | ||||||
| 				IsLoggedIn: true, | 				IsLoggedIn: true, | ||||||
| 				OAuth:       false, |  | ||||||
| 				Provider:   "username", | 				Provider:   "username", | ||||||
| 				TotpPending: false, |  | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @@ -108,13 +117,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 			hooks.Auth.DeleteSessionCookie(c) | 			hooks.Auth.DeleteSessionCookie(c) | ||||||
|  |  | ||||||
| 			// Return empty context | 			// Return empty context | ||||||
| 			return types.UserContext{ | 			return types.UserContext{} | ||||||
| 				Username:    "", |  | ||||||
| 				IsLoggedIn:  false, |  | ||||||
| 				OAuth:       false, |  | ||||||
| 				Provider:    "", |  | ||||||
| 				TotpPending: false, |  | ||||||
| 			} |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Email is whitelisted") | 		log.Debug().Msg("Email is whitelisted") | ||||||
| @@ -122,19 +125,15 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 		// Return user context since we are logged in with oauth | 		// Return user context since we are logged in with oauth | ||||||
| 		return types.UserContext{ | 		return types.UserContext{ | ||||||
| 			Username:    cookie.Username, | 			Username:    cookie.Username, | ||||||
|  | 			Name:        cookie.Name, | ||||||
|  | 			Email:       cookie.Email, | ||||||
| 			IsLoggedIn:  true, | 			IsLoggedIn:  true, | ||||||
| 			OAuth:       true, | 			OAuth:       true, | ||||||
| 			Provider:    cookie.Provider, | 			Provider:    cookie.Provider, | ||||||
| 			TotpPending: false, | 			OAuthGroups: cookie.OAuthGroups, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Neither basic auth or oauth is set so we return an empty context | 	// Neither basic auth or oauth is set so we return an empty context | ||||||
| 	return types.UserContext{ | 	return types.UserContext{} | ||||||
| 		Username:    "", |  | ||||||
| 		IsLoggedIn:  false, |  | ||||||
| 		OAuth:       false, |  | ||||||
| 		Provider:    "", |  | ||||||
| 		TotpPending: false, |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -4,24 +4,25 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"tinyauth/internal/constants" | ||||||
|  |  | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // We are assuming that the generic provider will return a JSON object with an email field | func GetGenericUser(client *http.Client, url string) (constants.Claims, error) { | ||||||
| type GenericUserInfoResponse struct { | 	// Create user struct | ||||||
| 	Email string `json:"email"` | 	var user constants.Claims | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetGenericEmail(client *http.Client, url string) (string, error) { |  | ||||||
| 	// Using the oauth client get the user info url | 	// Using the oauth client get the user info url | ||||||
| 	res, err := client.Get(url) | 	res, err := client.Get(url) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	defer res.Body.Close() | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from generic provider") | 	log.Debug().Msg("Got response from generic provider") | ||||||
|  |  | ||||||
| 	// Read the body of the response | 	// Read the body of the response | ||||||
| @@ -29,24 +30,21 @@ func GetGenericEmail(client *http.Client, url string) (string, error) { | |||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from generic provider") | 	log.Debug().Msg("Read body from generic provider") | ||||||
|  |  | ||||||
| 	// Parse the body into a user struct |  | ||||||
| 	var user GenericUserInfoResponse |  | ||||||
|  |  | ||||||
| 	// Unmarshal the body into the user struct | 	// Unmarshal the body into the user struct | ||||||
| 	err = json.Unmarshal(body, &user) | 	err = json.Unmarshal(body, &user) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed user from generic provider") | 	log.Debug().Msg("Parsed user from generic provider") | ||||||
|  |  | ||||||
| 	// Return the email | 	// Return the user | ||||||
| 	return user.Email, nil | 	return user, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -5,51 +5,96 @@ import ( | |||||||
| 	"errors" | 	"errors" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"tinyauth/internal/constants" | ||||||
|  |  | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Github has a different response than the generic provider | // Response for the github email endpoint | ||||||
| type GithubUserInfoResponse []struct { | type GithubEmailResponse []struct { | ||||||
| 	Email   string `json:"email"` | 	Email   string `json:"email"` | ||||||
| 	Primary bool   `json:"primary"` | 	Primary bool   `json:"primary"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // The scopes required for the github provider | // Response for the github user endpoint | ||||||
| func GithubScopes() []string { | type GithubUserInfoResponse struct { | ||||||
| 	return []string{"user:email"} | 	Login string `json:"login"` | ||||||
|  | 	Name  string `json:"name"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetGithubEmail(client *http.Client) (string, error) { | // The scopes required for the github provider | ||||||
| 	// Get the user emails from github using the oauth http client | func GithubScopes() []string { | ||||||
| 	res, err := client.Get("https://api.github.com/user/emails") | 	return []string{"user:email", "read:user"} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetGithubUser(client *http.Client) (constants.Claims, error) { | ||||||
|  | 	// Create user struct | ||||||
|  | 	var user constants.Claims | ||||||
|  |  | ||||||
|  | 	// Get the user info from github using the oauth http client | ||||||
|  | 	res, err := client.Get("https://api.github.com/user") | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from github") | 	defer res.Body.Close() | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got user response from github") | ||||||
|  |  | ||||||
| 	// Read the body of the response | 	// Read the body of the response | ||||||
| 	body, err := io.ReadAll(res.Body) | 	body, err := io.ReadAll(res.Body) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from github") | 	log.Debug().Msg("Read user body from github") | ||||||
|  |  | ||||||
| 	// Parse the body into a user struct | 	// Parse the body into a user struct | ||||||
| 	var emails GithubUserInfoResponse | 	var userInfo GithubUserInfoResponse | ||||||
|  |  | ||||||
|  | 	// Unmarshal the body into the user struct | ||||||
|  | 	err = json.Unmarshal(body, &userInfo) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get the user emails from github using the oauth http client | ||||||
|  | 	res, err = client.Get("https://api.github.com/user/emails") | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	defer res.Body.Close() | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got email response from github") | ||||||
|  |  | ||||||
|  | 	// Read the body of the response | ||||||
|  | 	body, err = io.ReadAll(res.Body) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Read email body from github") | ||||||
|  |  | ||||||
|  | 	// Parse the body into a user struct | ||||||
|  | 	var emails GithubEmailResponse | ||||||
|  |  | ||||||
| 	// Unmarshal the body into the user struct | 	// Unmarshal the body into the user struct | ||||||
| 	err = json.Unmarshal(body, &emails) | 	err = json.Unmarshal(body, &emails) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed emails from github") | 	log.Debug().Msg("Parsed emails from github") | ||||||
| @@ -57,10 +102,26 @@ func GetGithubEmail(client *http.Client) (string, error) { | |||||||
| 	// Find and return the primary email | 	// Find and return the primary email | ||||||
| 	for _, email := range emails { | 	for _, email := range emails { | ||||||
| 		if email.Primary { | 		if email.Primary { | ||||||
| 			return email.Email, nil | 			// Set the email then exit | ||||||
|  | 			user.Email = email.Email | ||||||
|  | 			break | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// User does not have a primary email? | 	// If no primary email was found, use the first available email | ||||||
| 	return "", errors.New("no primary email found") | 	if len(emails) == 0 { | ||||||
|  | 		return user, errors.New("no emails found") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Set the email if it is not set picking the first one | ||||||
|  | 	if user.Email == "" { | ||||||
|  | 		user.Email = emails[0].Email | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Set the username and name | ||||||
|  | 	user.PreferredUsername = userInfo.Login | ||||||
|  | 	user.Name = userInfo.Name | ||||||
|  |  | ||||||
|  | 	// Return | ||||||
|  | 	return user, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -4,29 +4,37 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  | 	"tinyauth/internal/constants" | ||||||
|  |  | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Google works the same as the generic provider | // Response for the google user endpoint | ||||||
| type GoogleUserInfoResponse struct { | type GoogleUserInfoResponse struct { | ||||||
| 	Email string `json:"email"` | 	Email string `json:"email"` | ||||||
|  | 	Name  string `json:"name"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // The scopes required for the google provider | // The scopes required for the google provider | ||||||
| func GoogleScopes() []string { | func GoogleScopes() []string { | ||||||
| 	return []string{"https://www.googleapis.com/auth/userinfo.email"} | 	return []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetGoogleEmail(client *http.Client) (string, error) { | func GetGoogleUser(client *http.Client) (constants.Claims, error) { | ||||||
|  | 	// Create user struct | ||||||
|  | 	var user constants.Claims | ||||||
|  |  | ||||||
| 	// Get the user info from google using the oauth http client | 	// Get the user info from google using the oauth http client | ||||||
| 	res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") | 	res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	defer res.Body.Close() | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from google") | 	log.Debug().Msg("Got response from google") | ||||||
|  |  | ||||||
| 	// Read the body of the response | 	// Read the body of the response | ||||||
| @@ -34,24 +42,29 @@ func GetGoogleEmail(client *http.Client) (string, error) { | |||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from google") | 	log.Debug().Msg("Read body from google") | ||||||
|  |  | ||||||
| 	// Parse the body into a user struct | 	// Create a new user info struct | ||||||
| 	var user GoogleUserInfoResponse | 	var userInfo GoogleUserInfoResponse | ||||||
|  |  | ||||||
| 	// Unmarshal the body into the user struct | 	// Unmarshal the body into the user struct | ||||||
| 	err = json.Unmarshal(body, &user) | 	err = json.Unmarshal(body, &userInfo) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed user from google") | 	log.Debug().Msg("Parsed user from google") | ||||||
|  |  | ||||||
| 	// Return the email | 	// Map the user info to the user struct | ||||||
| 	return user.Email, nil | 	user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] | ||||||
|  | 	user.Name = userInfo.Name | ||||||
|  | 	user.Email = userInfo.Email | ||||||
|  |  | ||||||
|  | 	// Return the user | ||||||
|  | 	return user, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package providers | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"tinyauth/internal/constants" | ||||||
| 	"tinyauth/internal/oauth" | 	"tinyauth/internal/oauth" | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
|  |  | ||||||
| @@ -93,14 +94,17 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (providers *Providers) GetUser(provider string) (string, error) { | func (providers *Providers) GetUser(provider string) (constants.Claims, error) { | ||||||
| 	// Get the email from the provider | 	// Create user struct | ||||||
|  | 	var user constants.Claims | ||||||
|  |  | ||||||
|  | 	// Get the user from the provider | ||||||
| 	switch provider { | 	switch provider { | ||||||
| 	case "github": | 	case "github": | ||||||
| 		// If the github provider is not configured, return an error | 		// If the github provider is not configured, return an error | ||||||
| 		if providers.Github == nil { | 		if providers.Github == nil { | ||||||
| 			log.Debug().Msg("Github provider not configured") | 			log.Debug().Msg("Github provider not configured") | ||||||
| 			return "", nil | 			return user, nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Get the client from the github provider | 		// Get the client from the github provider | ||||||
| @@ -108,23 +112,23 @@ func (providers *Providers) GetUser(provider string) (string, error) { | |||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from github") | 		log.Debug().Msg("Got client from github") | ||||||
|  |  | ||||||
| 		// Get the email from the github provider | 		// Get the user from the github provider | ||||||
| 		email, err := GetGithubEmail(client) | 		user, err := GetGithubUser(client) | ||||||
|  |  | ||||||
| 		// Check if there was an error | 		// Check if there was an error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return "", err | 			return user, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from github") | 		log.Debug().Msg("Got user from github") | ||||||
|  |  | ||||||
| 		// Return the email | 		// Return the user | ||||||
| 		return email, nil | 		return user, nil | ||||||
| 	case "google": | 	case "google": | ||||||
| 		// If the google provider is not configured, return an error | 		// If the google provider is not configured, return an error | ||||||
| 		if providers.Google == nil { | 		if providers.Google == nil { | ||||||
| 			log.Debug().Msg("Google provider not configured") | 			log.Debug().Msg("Google provider not configured") | ||||||
| 			return "", nil | 			return user, nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Get the client from the google provider | 		// Get the client from the google provider | ||||||
| @@ -132,23 +136,23 @@ func (providers *Providers) GetUser(provider string) (string, error) { | |||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from google") | 		log.Debug().Msg("Got client from google") | ||||||
|  |  | ||||||
| 		// Get the email from the google provider | 		// Get the user from the google provider | ||||||
| 		email, err := GetGoogleEmail(client) | 		user, err := GetGoogleUser(client) | ||||||
|  |  | ||||||
| 		// Check if there was an error | 		// Check if there was an error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return "", err | 			return user, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from google") | 		log.Debug().Msg("Got user from google") | ||||||
|  |  | ||||||
| 		// Return the email | 		// Return the user | ||||||
| 		return email, nil | 		return user, nil | ||||||
| 	case "generic": | 	case "generic": | ||||||
| 		// If the generic provider is not configured, return an error | 		// If the generic provider is not configured, return an error | ||||||
| 		if providers.Generic == nil { | 		if providers.Generic == nil { | ||||||
| 			log.Debug().Msg("Generic provider not configured") | 			log.Debug().Msg("Generic provider not configured") | ||||||
| 			return "", nil | 			return user, nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Get the client from the generic provider | 		// Get the client from the generic provider | ||||||
| @@ -156,20 +160,20 @@ func (providers *Providers) GetUser(provider string) (string, error) { | |||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from generic") | 		log.Debug().Msg("Got client from generic") | ||||||
|  |  | ||||||
| 		// Get the email from the generic provider | 		// Get the user from the generic provider | ||||||
| 		email, err := GetGenericEmail(client, providers.Config.GenericUserURL) | 		user, err := GetGenericUser(client, providers.Config.GenericUserURL) | ||||||
|  |  | ||||||
| 		// Check if there was an error | 		// Check if there was an error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return "", err | 			return user, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from generic") | 		log.Debug().Msg("Got user from generic") | ||||||
|  |  | ||||||
| 		// Return the email | 		// Return the email | ||||||
| 		return email, nil | 		return user, nil | ||||||
| 	default: | 	default: | ||||||
| 		return "", nil | 		return user, nil | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -20,6 +20,7 @@ type OAuthRequest struct { | |||||||
| type UnauthorizedQuery struct { | type UnauthorizedQuery struct { | ||||||
| 	Username string `url:"username"` | 	Username string `url:"username"` | ||||||
| 	Resource string `url:"resource"` | 	Resource string `url:"resource"` | ||||||
|  | 	GroupErr bool   `url:"groupErr"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // Proxy is the uri parameters for the proxy endpoint | // Proxy is the uri parameters for the proxy endpoint | ||||||
| @@ -33,6 +34,8 @@ type UserContextResponse struct { | |||||||
| 	Message     string `json:"message"` | 	Message     string `json:"message"` | ||||||
| 	IsLoggedIn  bool   `json:"isLoggedIn"` | 	IsLoggedIn  bool   `json:"isLoggedIn"` | ||||||
| 	Username    string `json:"username"` | 	Username    string `json:"username"` | ||||||
|  | 	Name        string `json:"name"` | ||||||
|  | 	Email       string `json:"email"` | ||||||
| 	Provider    string `json:"provider"` | 	Provider    string `json:"provider"` | ||||||
| 	Oauth       bool   `json:"oauth"` | 	Oauth       bool   `json:"oauth"` | ||||||
| 	TotpPending bool   `json:"totpPending"` | 	TotpPending bool   `json:"totpPending"` | ||||||
| @@ -48,6 +51,7 @@ type AppContext struct { | |||||||
| 	GenericName           string   `json:"genericName"` | 	GenericName           string   `json:"genericName"` | ||||||
| 	Domain                string   `json:"domain"` | 	Domain                string   `json:"domain"` | ||||||
| 	ForgotPasswordMessage string   `json:"forgotPasswordMessage"` | 	ForgotPasswordMessage string   `json:"forgotPasswordMessage"` | ||||||
|  | 	OAuthAutoRedirect     string   `json:"oauthAutoRedirect"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // Totp request is the request for the totp endpoint | // Totp request is the request for the totp endpoint | ||||||
|   | |||||||
| @@ -26,6 +26,7 @@ type Config struct { | |||||||
| 	GenericName             string `mapstructure:"generic-name"` | 	GenericName             string `mapstructure:"generic-name"` | ||||||
| 	DisableContinue         bool   `mapstructure:"disable-continue"` | 	DisableContinue         bool   `mapstructure:"disable-continue"` | ||||||
| 	OAuthWhitelist          string `mapstructure:"oauth-whitelist"` | 	OAuthWhitelist          string `mapstructure:"oauth-whitelist"` | ||||||
|  | 	OAuthAutoRedirect       string `mapstructure:"oauth-auto-redirect" validate:"oneof=none github google generic"` | ||||||
| 	SessionExpiry           int    `mapstructure:"session-expiry"` | 	SessionExpiry           int    `mapstructure:"session-expiry"` | ||||||
| 	LogLevel                int8   `mapstructure:"log-level" validate:"min=-1,max=5"` | 	LogLevel                int8   `mapstructure:"log-level" validate:"min=-1,max=5"` | ||||||
| 	Title                   string `mapstructure:"app-title"` | 	Title                   string `mapstructure:"app-title"` | ||||||
| @@ -44,6 +45,7 @@ type HandlersConfig struct { | |||||||
| 	GenericName           string | 	GenericName           string | ||||||
| 	Title                 string | 	Title                 string | ||||||
| 	ForgotPasswordMessage string | 	ForgotPasswordMessage string | ||||||
|  | 	OAuthAutoRedirect     string | ||||||
| } | } | ||||||
|  |  | ||||||
| // OAuthConfig is the configuration for the providers | // OAuthConfig is the configuration for the providers | ||||||
| @@ -78,3 +80,8 @@ type AuthConfig struct { | |||||||
| 	LoginTimeout    int | 	LoginTimeout    int | ||||||
| 	LoginMaxRetries int | 	LoginMaxRetries int | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // HooksConfig is the configuration for the hooks service | ||||||
|  | type HooksConfig struct { | ||||||
|  | 	Domain string | ||||||
|  | } | ||||||
|   | |||||||
| @@ -25,8 +25,11 @@ type OAuthProviders struct { | |||||||
| // SessionCookie is the cookie for the session (exculding the expiry) | // SessionCookie is the cookie for the session (exculding the expiry) | ||||||
| type SessionCookie struct { | type SessionCookie struct { | ||||||
| 	Username    string | 	Username    string | ||||||
|  | 	Name        string | ||||||
|  | 	Email       string | ||||||
| 	Provider    string | 	Provider    string | ||||||
| 	TotpPending bool | 	TotpPending bool | ||||||
|  | 	OAuthGroups string | ||||||
| } | } | ||||||
|  |  | ||||||
| // TinyauthLabels is the labels for the tinyauth container | // TinyauthLabels is the labels for the tinyauth container | ||||||
| @@ -35,15 +38,20 @@ type TinyauthLabels struct { | |||||||
| 	Users          string | 	Users          string | ||||||
| 	Allowed        string | 	Allowed        string | ||||||
| 	Headers        map[string]string | 	Headers        map[string]string | ||||||
|  | 	OAuthGroups    string | ||||||
| } | } | ||||||
|  |  | ||||||
| // UserContext is the context for the user | // UserContext is the context for the user | ||||||
| type UserContext struct { | type UserContext struct { | ||||||
| 	Username    string | 	Username    string | ||||||
|  | 	Name        string | ||||||
|  | 	Email       string | ||||||
| 	IsLoggedIn  bool | 	IsLoggedIn  bool | ||||||
| 	OAuth       bool | 	OAuth       bool | ||||||
| 	Provider    string | 	Provider    string | ||||||
| 	TotpPending bool | 	TotpPending bool | ||||||
|  | 	OAuthGroups string | ||||||
|  | 	TotpEnabled bool | ||||||
| } | } | ||||||
|  |  | ||||||
| // LoginAttempt tracks information about login attempts for rate limiting | // LoginAttempt tracks information about login attempts for rate limiting | ||||||
|   | |||||||
| @@ -204,6 +204,8 @@ func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels { | |||||||
| 					} | 					} | ||||||
| 					tinyauthLabels.Headers[headerSplit[0]] = headerSplit[1] | 					tinyauthLabels.Headers[headerSplit[0]] = headerSplit[1] | ||||||
| 				} | 				} | ||||||
|  | 			case "tinyauth.oauth.groups": | ||||||
|  | 				tinyauthLabels.OAuthGroups = value | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @@ -323,3 +325,22 @@ func CheckWhitelist(whitelist string, str string) bool { | |||||||
| 	// Return false if no match was found | 	// Return false if no match was found | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Capitalize just the first letter of a string | ||||||
|  | func Capitalize(str string) string { | ||||||
|  | 	if len(str) == 0 { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:]) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Sanitize header removes all control characters from a string | ||||||
|  | func SanitizeHeader(header string) string { | ||||||
|  | 	return strings.Map(func(r rune) rune { | ||||||
|  | 		// Allow only printable ASCII characters (32-126) and safe whitespace (space, tab) | ||||||
|  | 		if r == ' ' || r == '\t' || (r >= 32 && r <= 126) { | ||||||
|  | 			return r | ||||||
|  | 		} | ||||||
|  | 		return -1 | ||||||
|  | 	}, header) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -467,3 +467,65 @@ func TestCheckWhitelist(t *testing.T) { | |||||||
| 		t.Fatalf("Expected %v, got %v", expected, result) | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Test capitalize | ||||||
|  | func TestCapitalize(t *testing.T) { | ||||||
|  | 	t.Log("Testing capitalize with a valid string") | ||||||
|  |  | ||||||
|  | 	// Create variables | ||||||
|  | 	str := "test" | ||||||
|  | 	expected := "Test" | ||||||
|  |  | ||||||
|  | 	// Test the capitalize function | ||||||
|  | 	result := utils.Capitalize(str) | ||||||
|  |  | ||||||
|  | 	// Check if the result is equal to the expected | ||||||
|  | 	if result != expected { | ||||||
|  | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	t.Log("Testing capitalize with an empty string") | ||||||
|  |  | ||||||
|  | 	// Create variables | ||||||
|  | 	str = "" | ||||||
|  | 	expected = "" | ||||||
|  |  | ||||||
|  | 	// Test the capitalize function | ||||||
|  | 	result = utils.Capitalize(str) | ||||||
|  |  | ||||||
|  | 	// Check if the result is equal to the expected | ||||||
|  | 	if result != expected { | ||||||
|  | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Test the header sanitizer | ||||||
|  | func TestSanitizeHeader(t *testing.T) { | ||||||
|  | 	t.Log("Testing sanitize header with a valid string") | ||||||
|  |  | ||||||
|  | 	// Create variables | ||||||
|  | 	str := "X-Header=value" | ||||||
|  | 	expected := "X-Header=value" | ||||||
|  |  | ||||||
|  | 	// Test the sanitize header function | ||||||
|  | 	result := utils.SanitizeHeader(str) | ||||||
|  |  | ||||||
|  | 	// Check if the result is equal to the expected | ||||||
|  | 	if result != expected { | ||||||
|  | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	t.Log("Testing sanitize header with an invalid string") | ||||||
|  |  | ||||||
|  | 	// Create variables | ||||||
|  | 	str = "X-Header=val\nue" | ||||||
|  | 	expected = "X-Header=value" | ||||||
|  |  | ||||||
|  | 	// Test the sanitize header function | ||||||
|  | 	result = utils.SanitizeHeader(str) | ||||||
|  |  | ||||||
|  | 	// Check if the result is equal to the expected | ||||||
|  | 	if result != expected { | ||||||
|  | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user