mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-10-30 21:55:43 +00:00 
			
		
		
		
	Compare commits
	
		
			9 Commits
		
	
	
		
			v3.3.0-alp
			...
			feat/oauth
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 86a18b4cac | ||
|   | d171c5940b | ||
|   | 3dff650e71 | ||
|   | 1fec583ead | ||
|   | 065b9eaf3d | ||
|   | f824b84787 | ||
|   | dca09a3d9d | ||
|   | 5e4e2ddbd9 | ||
|   | 13032e564d | 
| @@ -53,9 +53,9 @@ Tinyauth is licensed under the GNU General Public License v3.0. TL;DR — You ma | |||||||
|  |  | ||||||
| Thanks a lot to the following people for providing me with more coffee: | Thanks a lot to the following people for providing me with more coffee: | ||||||
|  |  | ||||||
| | <img height="64" src="https://avatars.githubusercontent.com/u/47644445?v=4" alt="Nicolas"> | <img height="64" src="https://avatars.githubusercontent.com/u/4255748?v=4" alt="Erwin"> | <img height="64" src="https://avatars.githubusercontent.com/u/7935041?v=4" alt="SimpleHomelab" /> | | | <div align="center"><img height="64" src="https://avatars.githubusercontent.com/u/47644445?v=4" alt="Nicolas"></div> | <div align="center"><img height="64" src="https://avatars.githubusercontent.com/u/4255748?v=4" alt="Erwin"></div> | <div align="center"><img height="64" src="https://avatars.githubusercontent.com/u/7935041?v=4" alt="SimpleHomelab"></div> | <div align="center"><img height="64" src="https://avatars.githubusercontent.com/u/30562276?v=4" alt="jmadden91"></div> | | ||||||
| | ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------- | | | -------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------- | | ||||||
| | <div align="center"><a href="https://github.com/nicotsx">Nicolas</a></div>                 | <div align="center"><a href="https://github.com/erwinkramer">Erwin</a></div>            | <div align="center"><a href="https://github.com/SimpleHomelab">SimpleHomelab</a></div>            | | | <div align="center"><a href="https://github.com/nicotsx">Nicolas</a></div>                                           | <div align="center"><a href="https://github.com/erwinkramer">Erwin</a></div>                                      | <div align="center"><a href="https://github.com/SimpleHomelab">SimpleHomelab</a></div>                                    | <div align="center"><a href="https://github.com/jmadden91">jmadden91</a></div>                                         | | ||||||
|  |  | ||||||
| ## Acknowledgements | ## Acknowledgements | ||||||
|  |  | ||||||
|   | |||||||
| @@ -111,6 +111,11 @@ var rootCmd = &cobra.Command{ | |||||||
| 			LoginMaxRetries: config.LoginMaxRetries, | 			LoginMaxRetries: config.LoginMaxRetries, | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Create hooks config | ||||||
|  | 		hooksConfig := types.HooksConfig{ | ||||||
|  | 			Domain: domain, | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		// Create docker service | 		// Create docker service | ||||||
| 		docker := docker.NewDocker() | 		docker := docker.NewDocker() | ||||||
|  |  | ||||||
| @@ -128,7 +133,7 @@ var rootCmd = &cobra.Command{ | |||||||
| 		providers.Init() | 		providers.Init() | ||||||
|  |  | ||||||
| 		// Create hooks service | 		// Create hooks service | ||||||
| 		hooks := hooks.NewHooks(auth, providers) | 		hooks := hooks.NewHooks(hooksConfig, auth, providers) | ||||||
|  |  | ||||||
| 		// Create handlers | 		// Create handlers | ||||||
| 		handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) | 		handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) | ||||||
| @@ -189,7 +194,7 @@ func init() { | |||||||
| 	rootCmd.Flags().String("generic-auth-url", "", "Generic OAuth auth URL.") | 	rootCmd.Flags().String("generic-auth-url", "", "Generic OAuth auth URL.") | ||||||
| 	rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") | 	rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") | ||||||
| 	rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") | 	rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") | ||||||
| 	rootCmd.Flags().String("generic-name", "Other", "Generic OAuth provider name.") | 	rootCmd.Flags().String("generic-name", "Generic", "Generic OAuth provider name.") | ||||||
| 	rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") | 	rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") | ||||||
| 	rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") | 	rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") | ||||||
| 	rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.") | 	rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.") | ||||||
|   | |||||||
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										4
									
								
								frontend/src/index.css
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								frontend/src/index.css
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,4 @@ | |||||||
|  | span, | ||||||
|  | p { | ||||||
|  |   word-break: break-word; | ||||||
|  | } | ||||||
| @@ -41,7 +41,8 @@ | |||||||
|     "totpTitle": "Enter your TOTP code", |     "totpTitle": "Enter your TOTP code", | ||||||
|     "unauthorizedTitle": "Unauthorized", |     "unauthorizedTitle": "Unauthorized", | ||||||
|     "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.", |     "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.", | ||||||
|     "unaothorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.", |     "unauthorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.", | ||||||
|  |     "unauthorizedGroupsSubtitle": "The user with username <Code>{{username}}</Code> is not in the groups required by the resource <Code>{{resource}}</Code>.", | ||||||
|     "unauthorizedButton": "Try again", |     "unauthorizedButton": "Try again", | ||||||
|     "untrustedRedirectTitle": "Untrusted redirect", |     "untrustedRedirectTitle": "Untrusted redirect", | ||||||
|     "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?", |     "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?", | ||||||
|   | |||||||
| @@ -41,7 +41,8 @@ | |||||||
|     "totpTitle": "Enter your TOTP code", |     "totpTitle": "Enter your TOTP code", | ||||||
|     "unauthorizedTitle": "Unauthorized", |     "unauthorizedTitle": "Unauthorized", | ||||||
|     "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.", |     "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.", | ||||||
|     "unaothorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.", |     "unauthorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.", | ||||||
|  |     "unauthorizedGroupsSubtitle": "The user with username <Code>{{username}}</Code> is not in the groups required by the resource <Code>{{resource}}</Code>.", | ||||||
|     "unauthorizedButton": "Try again", |     "unauthorizedButton": "Try again", | ||||||
|     "untrustedRedirectTitle": "Untrusted redirect", |     "untrustedRedirectTitle": "Untrusted redirect", | ||||||
|     "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?", |     "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?", | ||||||
|   | |||||||
| @@ -19,6 +19,7 @@ import { TotpPage } from "./pages/totp-page.tsx"; | |||||||
| import { AppContextProvider } from "./context/app-context.tsx"; | import { AppContextProvider } from "./context/app-context.tsx"; | ||||||
| import "./lib/i18n/i18n.ts"; | import "./lib/i18n/i18n.ts"; | ||||||
| import { ForgotPasswordPage } from "./pages/forgot-password-page.tsx"; | import { ForgotPasswordPage } from "./pages/forgot-password-page.tsx"; | ||||||
|  | import "./index.css"; | ||||||
|  |  | ||||||
| const queryClient = new QueryClient(); | const queryClient = new QueryClient(); | ||||||
|  |  | ||||||
| @@ -38,7 +39,10 @@ createRoot(document.getElementById("root")!).render( | |||||||
|                 <Route path="/continue" element={<ContinuePage />} /> |                 <Route path="/continue" element={<ContinuePage />} /> | ||||||
|                 <Route path="/unauthorized" element={<UnauthorizedPage />} /> |                 <Route path="/unauthorized" element={<UnauthorizedPage />} /> | ||||||
|                 <Route path="/error" element={<InternalServerError />} /> |                 <Route path="/error" element={<InternalServerError />} /> | ||||||
|                 <Route path="/forgot-password" element={<ForgotPasswordPage />} /> |                 <Route | ||||||
|  |                   path="/forgot-password" | ||||||
|  |                   element={<ForgotPasswordPage />} | ||||||
|  |                 /> | ||||||
|                 <Route path="*" element={<NotFoundPage />} /> |                 <Route path="*" element={<NotFoundPage />} /> | ||||||
|               </Routes> |               </Routes> | ||||||
|             </BrowserRouter> |             </BrowserRouter> | ||||||
|   | |||||||
| @@ -10,7 +10,7 @@ import { useAppContext } from "../context/app-context"; | |||||||
| import { Trans, useTranslation } from "react-i18next"; | import { Trans, useTranslation } from "react-i18next"; | ||||||
|  |  | ||||||
| export const LogoutPage = () => { | export const LogoutPage = () => { | ||||||
|   const { isLoggedIn, username, oauth, provider } = useUserContext(); |   const { isLoggedIn, oauth, provider, email, username } = useUserContext(); | ||||||
|   const { genericName } = useAppContext(); |   const { genericName } = useAppContext(); | ||||||
|   const { t } = useTranslation(); |   const { t } = useTranslation(); | ||||||
|  |  | ||||||
| @@ -56,7 +56,7 @@ export const LogoutPage = () => { | |||||||
|               values={{ |               values={{ | ||||||
|                 provider: |                 provider: | ||||||
|                   provider === "generic" ? genericName : capitalize(provider), |                   provider === "generic" ? genericName : capitalize(provider), | ||||||
|                 username: username, |                 username: email, | ||||||
|               }} |               }} | ||||||
|             /> |             /> | ||||||
|           ) : ( |           ) : ( | ||||||
|   | |||||||
| @@ -3,11 +3,13 @@ import { Layout } from "../components/layouts/layout"; | |||||||
| import { Navigate } from "react-router"; | import { Navigate } from "react-router"; | ||||||
| import { isQueryValid } from "../utils/utils"; | import { isQueryValid } from "../utils/utils"; | ||||||
| import { Trans, useTranslation } from "react-i18next"; | import { Trans, useTranslation } from "react-i18next"; | ||||||
|  | import React from "react"; | ||||||
|  |  | ||||||
| export const UnauthorizedPage = () => { | export const UnauthorizedPage = () => { | ||||||
|   const queryString = window.location.search; |   const queryString = window.location.search; | ||||||
|   const params = new URLSearchParams(queryString); |   const params = new URLSearchParams(queryString); | ||||||
|   const username = params.get("username") ?? ""; |   const username = params.get("username") ?? ""; | ||||||
|  |   const groupErr = params.get("groupErr") ?? ""; | ||||||
|   const resource = params.get("resource") ?? ""; |   const resource = params.get("resource") ?? ""; | ||||||
|  |  | ||||||
|   const { t } = useTranslation(); |   const { t } = useTranslation(); | ||||||
| @@ -16,33 +18,54 @@ export const UnauthorizedPage = () => { | |||||||
|     return <Navigate to="/" />; |     return <Navigate to="/" />; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   if (isQueryValid(resource) && !isQueryValid(groupErr)) { | ||||||
|  |     return ( | ||||||
|  |       <UnauthorizedLayout> | ||||||
|  |         <Trans | ||||||
|  |           i18nKey="unauthorizedResourceSubtitle" | ||||||
|  |           t={t} | ||||||
|  |           components={{ Code: <Code /> }} | ||||||
|  |           values={{ resource, username }} | ||||||
|  |         /> | ||||||
|  |       </UnauthorizedLayout> | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (isQueryValid(groupErr) && isQueryValid(resource)) { | ||||||
|  |     return ( | ||||||
|  |       <UnauthorizedLayout> | ||||||
|  |         <Trans | ||||||
|  |           i18nKey="unauthorizedGroupsSubtitle" | ||||||
|  |           t={t} | ||||||
|  |           components={{ Code: <Code /> }} | ||||||
|  |           values={{ username, resource }} | ||||||
|  |         /> | ||||||
|  |       </UnauthorizedLayout> | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   return ( | ||||||
|  |     <UnauthorizedLayout> | ||||||
|  |       <Trans | ||||||
|  |         i18nKey="unauthorizedLoginSubtitle" | ||||||
|  |         t={t} | ||||||
|  |         components={{ Code: <Code /> }} | ||||||
|  |         values={{ username }} | ||||||
|  |       /> | ||||||
|  |     </UnauthorizedLayout> | ||||||
|  |   ); | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | const UnauthorizedLayout = ({ children }: { children: React.ReactNode }) => { | ||||||
|  |   const { t } = useTranslation(); | ||||||
|  |  | ||||||
|   return ( |   return ( | ||||||
|     <Layout> |     <Layout> | ||||||
|       <Paper shadow="md" p={30} mt={30} radius="md" withBorder> |       <Paper shadow="md" p={30} mt={30} radius="md" withBorder> | ||||||
|         <Text size="xl" fw={700}> |         <Text size="xl" fw={700}> | ||||||
|           {t("Unauthorized")} |           {t("Unauthorized")} | ||||||
|         </Text> |         </Text> | ||||||
|         <Text> |         <Text>{children}</Text> | ||||||
|           {isQueryValid(resource) ? ( |  | ||||||
|             <Text> |  | ||||||
|               <Trans |  | ||||||
|                 i18nKey="unauthorizedResourceSubtitle" |  | ||||||
|                 t={t} |  | ||||||
|                 components={{ Code: <Code /> }} |  | ||||||
|                 values={{ resource, username }} |  | ||||||
|               /> |  | ||||||
|             </Text> |  | ||||||
|           ) : ( |  | ||||||
|             <Text> |  | ||||||
|               <Trans |  | ||||||
|                 i18nKey="unaothorizedLoginSubtitle" |  | ||||||
|                 t={t} |  | ||||||
|                 components={{ Code: <Code /> }} |  | ||||||
|                 values={{ username }} |  | ||||||
|               /> |  | ||||||
|             </Text> |  | ||||||
|           )} |  | ||||||
|         </Text> |  | ||||||
|         <Button |         <Button | ||||||
|           fullWidth |           fullWidth | ||||||
|           mt="xl" |           mt="xl" | ||||||
|   | |||||||
| @@ -3,6 +3,8 @@ import { z } from "zod"; | |||||||
| export const userContextSchema = z.object({ | export const userContextSchema = z.object({ | ||||||
|   isLoggedIn: z.boolean(), |   isLoggedIn: z.boolean(), | ||||||
|   username: z.string(), |   username: z.string(), | ||||||
|  |   name: z.string(), | ||||||
|  |   email: z.string(), | ||||||
|   oauth: z.boolean(), |   oauth: z.boolean(), | ||||||
|   provider: z.string(), |   provider: z.string(), | ||||||
|   totpPending: z.boolean(), |   totpPending: z.boolean(), | ||||||
|   | |||||||
| @@ -45,6 +45,11 @@ var authConfig = types.AuthConfig{ | |||||||
| 	LoginMaxRetries: 0, | 	LoginMaxRetries: 0, | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Simple hooks config for tests | ||||||
|  | var hooksConfig = types.HooksConfig{ | ||||||
|  | 	Domain: "localhost", | ||||||
|  | } | ||||||
|  |  | ||||||
| // Cookie | // Cookie | ||||||
| var cookie string | var cookie string | ||||||
|  |  | ||||||
| @@ -83,7 +88,7 @@ func getAPI(t *testing.T) *api.API { | |||||||
| 	providers.Init() | 	providers.Init() | ||||||
|  |  | ||||||
| 	// Create hooks service | 	// Create hooks service | ||||||
| 	hooks := hooks.NewHooks(auth, providers) | 	hooks := hooks.NewHooks(hooksConfig, auth, providers) | ||||||
|  |  | ||||||
| 	// Create handlers service | 	// Create handlers service | ||||||
| 	handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) | 	handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) | ||||||
|   | |||||||
| @@ -160,9 +160,12 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) | |||||||
|  |  | ||||||
| 	// Set data | 	// Set data | ||||||
| 	session.Values["username"] = data.Username | 	session.Values["username"] = data.Username | ||||||
|  | 	session.Values["name"] = data.Name | ||||||
|  | 	session.Values["email"] = data.Email | ||||||
| 	session.Values["provider"] = data.Provider | 	session.Values["provider"] = data.Provider | ||||||
| 	session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix() | 	session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix() | ||||||
| 	session.Values["totpPending"] = data.TotpPending | 	session.Values["totpPending"] = data.TotpPending | ||||||
|  | 	session.Values["oauthGroups"] = data.OAuthGroups | ||||||
|  |  | ||||||
| 	// Save session | 	// Save session | ||||||
| 	err = session.Save(c.Request, c.Writer) | 	err = session.Save(c.Request, c.Writer) | ||||||
| @@ -211,14 +214,24 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) | |||||||
| 		return types.SessionCookie{}, err | 		return types.SessionCookie{}, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got session") | ||||||
|  |  | ||||||
| 	// Get data from session | 	// Get data from session | ||||||
| 	username, usernameOk := session.Values["username"].(string) | 	username, usernameOk := session.Values["username"].(string) | ||||||
|  | 	email, emailOk := session.Values["email"].(string) | ||||||
|  | 	name, nameOk := session.Values["name"].(string) | ||||||
| 	provider, providerOK := session.Values["provider"].(string) | 	provider, providerOK := session.Values["provider"].(string) | ||||||
| 	expiry, expiryOk := session.Values["expiry"].(int64) | 	expiry, expiryOk := session.Values["expiry"].(int64) | ||||||
| 	totpPending, totpPendingOk := session.Values["totpPending"].(bool) | 	totpPending, totpPendingOk := session.Values["totpPending"].(bool) | ||||||
|  | 	oauthGroups, oauthGroupsOk := session.Values["oauthGroups"].(string) | ||||||
|  |  | ||||||
| 	if !usernameOk || !providerOK || !expiryOk || !totpPendingOk { | 	if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk { | ||||||
| 		log.Warn().Msg("Session cookie is missing data") | 		log.Warn().Msg("Session cookie is invalid") | ||||||
|  |  | ||||||
|  | 		// If any data is missing, delete the session cookie | ||||||
|  | 		auth.DeleteSessionCookie(c) | ||||||
|  |  | ||||||
|  | 		// Return empty cookie | ||||||
| 		return types.SessionCookie{}, nil | 		return types.SessionCookie{}, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -233,13 +246,16 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) | |||||||
| 		return types.SessionCookie{}, nil | 		return types.SessionCookie{}, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Msg("Parsed cookie") | 	log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie") | ||||||
|  |  | ||||||
| 	// Return the cookie | 	// Return the cookie | ||||||
| 	return types.SessionCookie{ | 	return types.SessionCookie{ | ||||||
| 		Username:    username, | 		Username:    username, | ||||||
|  | 		Name:        name, | ||||||
|  | 		Email:       email, | ||||||
| 		Provider:    provider, | 		Provider:    provider, | ||||||
| 		TotpPending: totpPending, | 		TotpPending: totpPending, | ||||||
|  | 		OAuthGroups: oauthGroups, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -248,48 +264,52 @@ func (auth *Auth) UserAuthConfigured() bool { | |||||||
| 	return len(auth.Config.Users) > 0 | 	return len(auth.Config.Users) > 0 | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext) (bool, error) { | func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.TinyauthLabels) bool { | ||||||
| 	// Get headers |  | ||||||
| 	host := c.Request.Header.Get("X-Forwarded-Host") |  | ||||||
|  |  | ||||||
| 	// Get app id |  | ||||||
| 	appId := strings.Split(host, ".")[0] |  | ||||||
|  |  | ||||||
| 	// Get the container labels |  | ||||||
| 	labels, err := auth.Docker.GetLabels(appId) |  | ||||||
|  |  | ||||||
| 	// If there is an error, return false |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Check if oauth is allowed | 	// Check if oauth is allowed | ||||||
| 	if context.OAuth { | 	if context.OAuth { | ||||||
| 		log.Debug().Msg("Checking OAuth whitelist") | 		log.Debug().Msg("Checking OAuth whitelist") | ||||||
| 		return utils.CheckWhitelist(labels.OAuthWhitelist, context.Username), nil | 		return utils.CheckWhitelist(labels.OAuthWhitelist, context.Email) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Check users | 	// Check users | ||||||
| 	log.Debug().Msg("Checking users") | 	log.Debug().Msg("Checking users") | ||||||
|  |  | ||||||
| 	return utils.CheckWhitelist(labels.Users, context.Username), nil | 	return utils.CheckWhitelist(labels.Users, context.Username) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) AuthEnabled(c *gin.Context) (bool, error) { | func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.TinyauthLabels) bool { | ||||||
|  | 	// Check if groups are required | ||||||
|  | 	if labels.OAuthGroups == "" { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check if we are using the generic oauth provider | ||||||
|  | 	if context.Provider != "generic" { | ||||||
|  | 		log.Debug().Msg("Not using generic provider, skipping group check") | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Split the groups by comma (no need to parse since they are from the API response) | ||||||
|  | 	oauthGroups := strings.Split(context.OAuthGroups, ",") | ||||||
|  |  | ||||||
|  | 	// For every group check if it is in the required groups | ||||||
|  | 	for _, group := range oauthGroups { | ||||||
|  | 		if utils.CheckWhitelist(labels.OAuthGroups, group) { | ||||||
|  | 			log.Debug().Str("group", group).Msg("Group is in required groups") | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// No groups matched | ||||||
|  | 	log.Debug().Msg("No groups matched") | ||||||
|  |  | ||||||
|  | 	// Return false | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (auth *Auth) AuthEnabled(c *gin.Context, labels types.TinyauthLabels) (bool, error) { | ||||||
| 	// Get headers | 	// Get headers | ||||||
| 	uri := c.Request.Header.Get("X-Forwarded-Uri") | 	uri := c.Request.Header.Get("X-Forwarded-Uri") | ||||||
| 	host := c.Request.Header.Get("X-Forwarded-Host") |  | ||||||
|  |  | ||||||
| 	// Get app id |  | ||||||
| 	appId := strings.Split(host, ".")[0] |  | ||||||
|  |  | ||||||
| 	// Get the container labels |  | ||||||
| 	labels, err := auth.Docker.GetLabels(appId) |  | ||||||
|  |  | ||||||
| 	// If there is an error, auth enabled |  | ||||||
| 	if err != nil { |  | ||||||
| 		return true, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Check if the allowed label is empty | 	// Check if the allowed label is empty | ||||||
| 	if labels.Allowed == "" { | 	if labels.Allowed == "" { | ||||||
|   | |||||||
| @@ -6,4 +6,13 @@ var TinyauthLabels = []string{ | |||||||
| 	"tinyauth.users", | 	"tinyauth.users", | ||||||
| 	"tinyauth.allowed", | 	"tinyauth.allowed", | ||||||
| 	"tinyauth.headers", | 	"tinyauth.headers", | ||||||
|  | 	"tinyauth.oauth.groups", | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Claims are the OIDC supported claims (including preferd username for some reason) | ||||||
|  | type Claims struct { | ||||||
|  | 	Name              string   `json:"name"` | ||||||
|  | 	Email             string   `json:"email"` | ||||||
|  | 	PreferredUsername string   `json:"preferred_username"` | ||||||
|  | 	Groups            []string `json:"groups"` | ||||||
| } | } | ||||||
|   | |||||||
| @@ -6,8 +6,7 @@ import ( | |||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
| 	"tinyauth/internal/utils" | 	"tinyauth/internal/utils" | ||||||
|  |  | ||||||
| 	apiTypes "github.com/docker/docker/api/types" | 	container "github.com/docker/docker/api/types/container" | ||||||
| 	containerTypes "github.com/docker/docker/api/types/container" |  | ||||||
| 	"github.com/docker/docker/client" | 	"github.com/docker/docker/client" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
| @@ -38,9 +37,9 @@ func (docker *Docker) Init() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (docker *Docker) GetContainers() ([]apiTypes.Container, error) { | func (docker *Docker) GetContainers() ([]container.Summary, error) { | ||||||
| 	// Get the list of containers | 	// Get the list of containers | ||||||
| 	containers, err := docker.Client.ContainerList(docker.Context, containerTypes.ListOptions{}) | 	containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -51,13 +50,13 @@ func (docker *Docker) GetContainers() ([]apiTypes.Container, error) { | |||||||
| 	return containers, nil | 	return containers, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (docker *Docker) InspectContainer(containerId string) (apiTypes.ContainerJSON, error) { | func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) { | ||||||
| 	// Inspect the container | 	// Inspect the container | ||||||
| 	inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) | 	inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return apiTypes.ContainerJSON{}, err | 		return container.InspectResponse{}, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Return the inspect | 	// Return the inspect | ||||||
|   | |||||||
| @@ -10,6 +10,7 @@ import ( | |||||||
| 	"tinyauth/internal/hooks" | 	"tinyauth/internal/hooks" | ||||||
| 	"tinyauth/internal/providers" | 	"tinyauth/internal/providers" | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
|  | 	"tinyauth/internal/utils" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/google/go-querystring/query" | 	"github.com/google/go-querystring/query" | ||||||
| @@ -68,12 +69,15 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 	proto := c.Request.Header.Get("X-Forwarded-Proto") | 	proto := c.Request.Header.Get("X-Forwarded-Proto") | ||||||
| 	host := c.Request.Header.Get("X-Forwarded-Host") | 	host := c.Request.Header.Get("X-Forwarded-Host") | ||||||
|  |  | ||||||
| 	// Check if auth is enabled | 	// Get the app id | ||||||
| 	authEnabled, err := h.Auth.AuthEnabled(c) | 	appId := strings.Split(host, ".")[0] | ||||||
|  |  | ||||||
|  | 	// Get the container labels | ||||||
|  | 	labels, err := h.Docker.GetLabels(appId) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(err).Msg("Failed to check if app is allowed") | 		log.Error().Err(err).Msg("Failed to get container labels") | ||||||
|  |  | ||||||
| 		if proxy.Proxy == "nginx" || !isBrowser { | 		if proxy.Proxy == "nginx" || !isBrowser { | ||||||
| 			c.JSON(500, gin.H{ | 			c.JSON(500, gin.H{ | ||||||
| @@ -87,11 +91,8 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Get the app id | 	// Check if auth is enabled | ||||||
| 	appId := strings.Split(host, ".")[0] | 	authEnabled, err := h.Auth.AuthEnabled(c, labels) | ||||||
|  |  | ||||||
| 	// Get the container labels |  | ||||||
| 	labels, err := h.Docker.GetLabels(appId) |  | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -113,7 +114,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 	if !authEnabled { | 	if !authEnabled { | ||||||
| 		for key, value := range labels.Headers { | 		for key, value := range labels.Headers { | ||||||
| 			log.Debug().Str("key", key).Str("value", value).Msg("Setting header") | 			log.Debug().Str("key", key).Str("value", value).Msg("Setting header") | ||||||
| 			c.Header(key, value) | 			c.Header(key, utils.SanitizeHeader(value)) | ||||||
| 		} | 		} | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"status":  200, | 			"status":  200, | ||||||
| @@ -130,23 +131,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 		log.Debug().Msg("Authenticated") | 		log.Debug().Msg("Authenticated") | ||||||
|  |  | ||||||
| 		// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx | 		// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx | ||||||
| 		appAllowed, err := h.Auth.ResourceAllowed(c, userContext) | 		appAllowed := h.Auth.ResourceAllowed(c, userContext, labels) | ||||||
|  |  | ||||||
| 		// Check if there was an error |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Error().Err(err).Msg("Failed to check if app is allowed") |  | ||||||
|  |  | ||||||
| 			if proxy.Proxy == "nginx" || !isBrowser { |  | ||||||
| 				c.JSON(500, gin.H{ |  | ||||||
| 					"status":  500, |  | ||||||
| 					"message": "Internal Server Error", |  | ||||||
| 				}) |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") | 		log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") | ||||||
|  |  | ||||||
| @@ -165,11 +150,20 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			// Build query | 			// Values | ||||||
| 			queries, err := query.Values(types.UnauthorizedQuery{ | 			values := types.UnauthorizedQuery{ | ||||||
| 				Username: userContext.Username, |  | ||||||
| 				Resource: strings.Split(host, ".")[0], | 				Resource: strings.Split(host, ".")[0], | ||||||
| 			}) | 			} | ||||||
|  |  | ||||||
|  | 			// Use either username or email | ||||||
|  | 			if userContext.OAuth { | ||||||
|  | 				values.Username = userContext.Email | ||||||
|  | 			} else { | ||||||
|  | 				values.Username = userContext.Username | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// Build query | ||||||
|  | 			queries, err := query.Values(values) | ||||||
|  |  | ||||||
| 			// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) | 			// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -183,13 +177,65 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Set the user header | 		log.Debug().Interface("labels", labels).Msg("Got labels") | ||||||
| 		c.Header("Remote-User", userContext.Username) |  | ||||||
|  | 		// Check if user is in required groups | ||||||
|  | 		groupOk := h.Auth.OAuthGroup(c, userContext, labels) | ||||||
|  |  | ||||||
|  | 		log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups") | ||||||
|  |  | ||||||
|  | 		// The user is not allowed to access the app | ||||||
|  | 		if !groupOk { | ||||||
|  | 			log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User is not in required groups") | ||||||
|  |  | ||||||
|  | 			// Set WWW-Authenticate header | ||||||
|  | 			c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") | ||||||
|  |  | ||||||
|  | 			if proxy.Proxy == "nginx" || !isBrowser { | ||||||
|  | 				c.JSON(401, gin.H{ | ||||||
|  | 					"status":  401, | ||||||
|  | 					"message": "Unauthorized", | ||||||
|  | 				}) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// Values | ||||||
|  | 			values := types.UnauthorizedQuery{ | ||||||
|  | 				Resource: strings.Split(host, ".")[0], | ||||||
|  | 				GroupErr: true, | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// Use either username or email | ||||||
|  | 			if userContext.OAuth { | ||||||
|  | 				values.Username = userContext.Email | ||||||
|  | 			} else { | ||||||
|  | 				values.Username = userContext.Username | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// Build query | ||||||
|  | 			queries, err := query.Values(values) | ||||||
|  |  | ||||||
|  | 			// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) | ||||||
|  | 			if err != nil { | ||||||
|  | 				log.Error().Err(err).Msg("Failed to build queries") | ||||||
|  | 				c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// We are using caddy/traefik so redirect | ||||||
|  | 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) | ||||||
|  | 		c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) | ||||||
|  | 		c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) | ||||||
|  | 		c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) | ||||||
|  |  | ||||||
| 		// Set the rest of the headers | 		// Set the rest of the headers | ||||||
| 		for key, value := range labels.Headers { | 		for key, value := range labels.Headers { | ||||||
| 			log.Debug().Str("key", key).Str("value", value).Msg("Setting header") | 			log.Debug().Str("key", key).Str("value", value).Msg("Setting header") | ||||||
| 			c.Header(key, value) | 			c.Header(key, utils.SanitizeHeader(value)) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// The user is allowed to access the app | 		// The user is allowed to access the app | ||||||
| @@ -310,6 +356,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) { | |||||||
| 		// Set totp pending cookie | 		// Set totp pending cookie | ||||||
| 		h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | 		h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
| 			Username:    login.Username, | 			Username:    login.Username, | ||||||
|  | 			Name:        utils.Capitalize(login.Username), | ||||||
|  | 			Email:       fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), | ||||||
| 			Provider:    "username", | 			Provider:    "username", | ||||||
| 			TotpPending: true, | 			TotpPending: true, | ||||||
| 		}) | 		}) | ||||||
| @@ -328,6 +376,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) { | |||||||
| 	// Create session cookie with username as provider | 	// Create session cookie with username as provider | ||||||
| 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
| 		Username: login.Username, | 		Username: login.Username, | ||||||
|  | 		Name:     utils.Capitalize(login.Username), | ||||||
|  | 		Email:    fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), | ||||||
| 		Provider: "username", | 		Provider: "username", | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| @@ -402,6 +452,8 @@ func (h *Handlers) TotpHandler(c *gin.Context) { | |||||||
| 	// Create session cookie with username as provider | 	// Create session cookie with username as provider | ||||||
| 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
| 		Username: user.Username, | 		Username: user.Username, | ||||||
|  | 		Name:     utils.Capitalize(user.Username), | ||||||
|  | 		Email:    fmt.Sprintf("%s@%s", strings.ToLower(user.Username), h.Config.Domain), | ||||||
| 		Provider: "username", | 		Provider: "username", | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| @@ -465,6 +517,8 @@ func (h *Handlers) UserHandler(c *gin.Context) { | |||||||
| 		Status:      200, | 		Status:      200, | ||||||
| 		IsLoggedIn:  userContext.IsLoggedIn, | 		IsLoggedIn:  userContext.IsLoggedIn, | ||||||
| 		Username:    userContext.Username, | 		Username:    userContext.Username, | ||||||
|  | 		Name:        userContext.Name, | ||||||
|  | 		Email:       userContext.Email, | ||||||
| 		Provider:    userContext.Provider, | 		Provider:    userContext.Provider, | ||||||
| 		Oauth:       userContext.OAuth, | 		Oauth:       userContext.OAuth, | ||||||
| 		TotpPending: userContext.TotpPending, | 		TotpPending: userContext.TotpPending, | ||||||
| @@ -613,25 +667,32 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Get email | 	// Get user | ||||||
| 	email, err := h.Providers.GetUser(providerName.Provider) | 	user, err := h.Providers.GetUser(providerName.Provider) | ||||||
|  |  | ||||||
| 	log.Debug().Str("email", email).Msg("Got email") |  | ||||||
|  |  | ||||||
| 	// Handle error | 	// Handle error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(err).Msg("Failed to get email") | 		log.Error().Msg("Failed to get user") | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got user") | ||||||
|  |  | ||||||
|  | 	// Check that email is not empty | ||||||
|  | 	if user.Email == "" { | ||||||
|  | 		log.Error().Msg("Email is empty") | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Email is not whitelisted | 	// Email is not whitelisted | ||||||
| 	if !h.Auth.EmailWhitelisted(email) { | 	if !h.Auth.EmailWhitelisted(user.Email) { | ||||||
| 		log.Warn().Str("email", email).Msg("Email not whitelisted") | 		log.Warn().Str("email", user.Email).Msg("Email not whitelisted") | ||||||
|  |  | ||||||
| 		// Build query | 		// Build query | ||||||
| 		queries, err := query.Values(types.UnauthorizedQuery{ | 		queries, err := query.Values(types.UnauthorizedQuery{ | ||||||
| 			Username: email, | 			Username: user.Email, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
| 		// Handle error | 		// Handle error | ||||||
| @@ -647,10 +708,31 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { | |||||||
|  |  | ||||||
| 	log.Debug().Msg("Email whitelisted") | 	log.Debug().Msg("Email whitelisted") | ||||||
|  |  | ||||||
|  | 	// Get username | ||||||
|  | 	var username string | ||||||
|  |  | ||||||
|  | 	if user.PreferredUsername != "" { | ||||||
|  | 		username = user.PreferredUsername | ||||||
|  | 	} else { | ||||||
|  | 		username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1]) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get name | ||||||
|  | 	var name string | ||||||
|  |  | ||||||
|  | 	if user.Name != "" { | ||||||
|  | 		name = user.Name | ||||||
|  | 	} else { | ||||||
|  | 		name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	// Create session cookie (also cleans up redirect cookie) | 	// Create session cookie (also cleans up redirect cookie) | ||||||
| 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
| 		Username: email, | 		Username:    username, | ||||||
| 		Provider: providerName.Provider, | 		Name:        name, | ||||||
|  | 		Email:       user.Email, | ||||||
|  | 		Provider:    providerName.Provider, | ||||||
|  | 		OAuthGroups: strings.Join(user.Groups, ","), | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	// Check if we have a redirect URI | 	// Check if we have a redirect URI | ||||||
|   | |||||||
| @@ -1,22 +1,27 @@ | |||||||
| package hooks | package hooks | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"strings" | ||||||
| 	"tinyauth/internal/auth" | 	"tinyauth/internal/auth" | ||||||
| 	"tinyauth/internal/providers" | 	"tinyauth/internal/providers" | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
|  | 	"tinyauth/internal/utils" | ||||||
|  |  | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks { | func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks { | ||||||
| 	return &Hooks{ | 	return &Hooks{ | ||||||
|  | 		Config:    config, | ||||||
| 		Auth:      auth, | 		Auth:      auth, | ||||||
| 		Providers: providers, | 		Providers: providers, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| type Hooks struct { | type Hooks struct { | ||||||
|  | 	Config    types.HooksConfig | ||||||
| 	Auth      *auth.Auth | 	Auth      *auth.Auth | ||||||
| 	Providers *providers.Providers | 	Providers *providers.Providers | ||||||
| } | } | ||||||
| @@ -36,11 +41,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 		if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) { | 		if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) { | ||||||
| 			// Return user context since we are logged in with basic auth | 			// Return user context since we are logged in with basic auth | ||||||
| 			return types.UserContext{ | 			return types.UserContext{ | ||||||
| 				Username:    basic.Username, | 				Username:   basic.Username, | ||||||
| 				IsLoggedIn:  true, | 				Name:       utils.Capitalize(basic.Username), | ||||||
| 				OAuth:       false, | 				Email:      fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain), | ||||||
| 				Provider:    "basic", | 				IsLoggedIn: true, | ||||||
| 				TotpPending: false, | 				Provider:   "basic", | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| @@ -50,13 +55,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(err).Msg("Failed to get session cookie") | 		log.Error().Err(err).Msg("Failed to get session cookie") | ||||||
| 		// Return empty context | 		// Return empty context | ||||||
| 		return types.UserContext{ | 		return types.UserContext{} | ||||||
| 			Username:    "", |  | ||||||
| 			IsLoggedIn:  false, |  | ||||||
| 			OAuth:       false, |  | ||||||
| 			Provider:    "", |  | ||||||
| 			TotpPending: false, |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Check if session cookie has totp pending | 	// Check if session cookie has totp pending | ||||||
| @@ -65,8 +64,8 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 		// Return empty context since we are pending totp | 		// Return empty context since we are pending totp | ||||||
| 		return types.UserContext{ | 		return types.UserContext{ | ||||||
| 			Username:    cookie.Username, | 			Username:    cookie.Username, | ||||||
| 			IsLoggedIn:  false, | 			Name:        cookie.Name, | ||||||
| 			OAuth:       false, | 			Email:       cookie.Email, | ||||||
| 			Provider:    cookie.Provider, | 			Provider:    cookie.Provider, | ||||||
| 			TotpPending: true, | 			TotpPending: true, | ||||||
| 		} | 		} | ||||||
| @@ -82,11 +81,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
|  |  | ||||||
| 			// It exists so we are logged in | 			// It exists so we are logged in | ||||||
| 			return types.UserContext{ | 			return types.UserContext{ | ||||||
| 				Username:    cookie.Username, | 				Username:   cookie.Username, | ||||||
| 				IsLoggedIn:  true, | 				Name:       cookie.Name, | ||||||
| 				OAuth:       false, | 				Email:      cookie.Email, | ||||||
| 				Provider:    "username", | 				IsLoggedIn: true, | ||||||
| 				TotpPending: false, | 				Provider:   "username", | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @@ -108,13 +107,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 			hooks.Auth.DeleteSessionCookie(c) | 			hooks.Auth.DeleteSessionCookie(c) | ||||||
|  |  | ||||||
| 			// Return empty context | 			// Return empty context | ||||||
| 			return types.UserContext{ | 			return types.UserContext{} | ||||||
| 				Username:    "", |  | ||||||
| 				IsLoggedIn:  false, |  | ||||||
| 				OAuth:       false, |  | ||||||
| 				Provider:    "", |  | ||||||
| 				TotpPending: false, |  | ||||||
| 			} |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Email is whitelisted") | 		log.Debug().Msg("Email is whitelisted") | ||||||
| @@ -122,19 +115,15 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 		// Return user context since we are logged in with oauth | 		// Return user context since we are logged in with oauth | ||||||
| 		return types.UserContext{ | 		return types.UserContext{ | ||||||
| 			Username:    cookie.Username, | 			Username:    cookie.Username, | ||||||
|  | 			Name:        cookie.Name, | ||||||
|  | 			Email:       cookie.Email, | ||||||
| 			IsLoggedIn:  true, | 			IsLoggedIn:  true, | ||||||
| 			OAuth:       true, | 			OAuth:       true, | ||||||
| 			Provider:    cookie.Provider, | 			Provider:    cookie.Provider, | ||||||
| 			TotpPending: false, | 			OAuthGroups: cookie.OAuthGroups, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Neither basic auth or oauth is set so we return an empty context | 	// Neither basic auth or oauth is set so we return an empty context | ||||||
| 	return types.UserContext{ | 	return types.UserContext{} | ||||||
| 		Username:    "", |  | ||||||
| 		IsLoggedIn:  false, |  | ||||||
| 		OAuth:       false, |  | ||||||
| 		Provider:    "", |  | ||||||
| 		TotpPending: false, |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|   | |||||||
| @@ -4,24 +4,25 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"tinyauth/internal/constants" | ||||||
|  |  | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // We are assuming that the generic provider will return a JSON object with an email field | func GetGenericUser(client *http.Client, url string) (constants.Claims, error) { | ||||||
| type GenericUserInfoResponse struct { | 	// Create user struct | ||||||
| 	Email string `json:"email"` | 	var user constants.Claims | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetGenericEmail(client *http.Client, url string) (string, error) { |  | ||||||
| 	// Using the oauth client get the user info url | 	// Using the oauth client get the user info url | ||||||
| 	res, err := client.Get(url) | 	res, err := client.Get(url) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	defer res.Body.Close() | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from generic provider") | 	log.Debug().Msg("Got response from generic provider") | ||||||
|  |  | ||||||
| 	// Read the body of the response | 	// Read the body of the response | ||||||
| @@ -29,24 +30,21 @@ func GetGenericEmail(client *http.Client, url string) (string, error) { | |||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from generic provider") | 	log.Debug().Msg("Read body from generic provider") | ||||||
|  |  | ||||||
| 	// Parse the body into a user struct |  | ||||||
| 	var user GenericUserInfoResponse |  | ||||||
|  |  | ||||||
| 	// Unmarshal the body into the user struct | 	// Unmarshal the body into the user struct | ||||||
| 	err = json.Unmarshal(body, &user) | 	err = json.Unmarshal(body, &user) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed user from generic provider") | 	log.Debug().Msg("Parsed user from generic provider") | ||||||
|  |  | ||||||
| 	// Return the email | 	// Return the user | ||||||
| 	return user.Email, nil | 	return user, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -5,51 +5,96 @@ import ( | |||||||
| 	"errors" | 	"errors" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"tinyauth/internal/constants" | ||||||
|  |  | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Github has a different response than the generic provider | // Response for the github email endpoint | ||||||
| type GithubUserInfoResponse []struct { | type GithubEmailResponse []struct { | ||||||
| 	Email   string `json:"email"` | 	Email   string `json:"email"` | ||||||
| 	Primary bool   `json:"primary"` | 	Primary bool   `json:"primary"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // The scopes required for the github provider | // Response for the github user endpoint | ||||||
| func GithubScopes() []string { | type GithubUserInfoResponse struct { | ||||||
| 	return []string{"user:email"} | 	Login string `json:"login"` | ||||||
|  | 	Name  string `json:"name"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetGithubEmail(client *http.Client) (string, error) { | // The scopes required for the github provider | ||||||
| 	// Get the user emails from github using the oauth http client | func GithubScopes() []string { | ||||||
| 	res, err := client.Get("https://api.github.com/user/emails") | 	return []string{"user:email", "read:user"} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func GetGithubUser(client *http.Client) (constants.Claims, error) { | ||||||
|  | 	// Create user struct | ||||||
|  | 	var user constants.Claims | ||||||
|  |  | ||||||
|  | 	// Get the user info from github using the oauth http client | ||||||
|  | 	res, err := client.Get("https://api.github.com/user") | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from github") | 	defer res.Body.Close() | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got user response from github") | ||||||
|  |  | ||||||
| 	// Read the body of the response | 	// Read the body of the response | ||||||
| 	body, err := io.ReadAll(res.Body) | 	body, err := io.ReadAll(res.Body) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from github") | 	log.Debug().Msg("Read user body from github") | ||||||
|  |  | ||||||
| 	// Parse the body into a user struct | 	// Parse the body into a user struct | ||||||
| 	var emails GithubUserInfoResponse | 	var userInfo GithubUserInfoResponse | ||||||
|  |  | ||||||
|  | 	// Unmarshal the body into the user struct | ||||||
|  | 	err = json.Unmarshal(body, &userInfo) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get the user emails from github using the oauth http client | ||||||
|  | 	res, err = client.Get("https://api.github.com/user/emails") | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	defer res.Body.Close() | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got email response from github") | ||||||
|  |  | ||||||
|  | 	// Read the body of the response | ||||||
|  | 	body, err = io.ReadAll(res.Body) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Read email body from github") | ||||||
|  |  | ||||||
|  | 	// Parse the body into a user struct | ||||||
|  | 	var emails GithubEmailResponse | ||||||
|  |  | ||||||
| 	// Unmarshal the body into the user struct | 	// Unmarshal the body into the user struct | ||||||
| 	err = json.Unmarshal(body, &emails) | 	err = json.Unmarshal(body, &emails) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed emails from github") | 	log.Debug().Msg("Parsed emails from github") | ||||||
| @@ -57,10 +102,26 @@ func GetGithubEmail(client *http.Client) (string, error) { | |||||||
| 	// Find and return the primary email | 	// Find and return the primary email | ||||||
| 	for _, email := range emails { | 	for _, email := range emails { | ||||||
| 		if email.Primary { | 		if email.Primary { | ||||||
| 			return email.Email, nil | 			// Set the email then exit | ||||||
|  | 			user.Email = email.Email | ||||||
|  | 			break | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// User does not have a primary email? | 	// If no primary email was found, use the first available email | ||||||
| 	return "", errors.New("no primary email found") | 	if len(emails) == 0 { | ||||||
|  | 		return user, errors.New("no emails found") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Set the email if it is not set picking the first one | ||||||
|  | 	if user.Email == "" { | ||||||
|  | 		user.Email = emails[0].Email | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Set the username and name | ||||||
|  | 	user.PreferredUsername = userInfo.Login | ||||||
|  | 	user.Name = userInfo.Name | ||||||
|  |  | ||||||
|  | 	// Return | ||||||
|  | 	return user, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -4,29 +4,37 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  | 	"tinyauth/internal/constants" | ||||||
|  |  | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Google works the same as the generic provider | // Response for the google user endpoint | ||||||
| type GoogleUserInfoResponse struct { | type GoogleUserInfoResponse struct { | ||||||
| 	Email string `json:"email"` | 	Email string `json:"email"` | ||||||
|  | 	Name  string `json:"name"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // The scopes required for the google provider | // The scopes required for the google provider | ||||||
| func GoogleScopes() []string { | func GoogleScopes() []string { | ||||||
| 	return []string{"https://www.googleapis.com/auth/userinfo.email"} | 	return []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetGoogleEmail(client *http.Client) (string, error) { | func GetGoogleUser(client *http.Client) (constants.Claims, error) { | ||||||
|  | 	// Create user struct | ||||||
|  | 	var user constants.Claims | ||||||
|  |  | ||||||
| 	// Get the user info from google using the oauth http client | 	// Get the user info from google using the oauth http client | ||||||
| 	res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") | 	res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	defer res.Body.Close() | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from google") | 	log.Debug().Msg("Got response from google") | ||||||
|  |  | ||||||
| 	// Read the body of the response | 	// Read the body of the response | ||||||
| @@ -34,24 +42,29 @@ func GetGoogleEmail(client *http.Client) (string, error) { | |||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from google") | 	log.Debug().Msg("Read body from google") | ||||||
|  |  | ||||||
| 	// Parse the body into a user struct | 	// Create a new user info struct | ||||||
| 	var user GoogleUserInfoResponse | 	var userInfo GoogleUserInfoResponse | ||||||
|  |  | ||||||
| 	// Unmarshal the body into the user struct | 	// Unmarshal the body into the user struct | ||||||
| 	err = json.Unmarshal(body, &user) | 	err = json.Unmarshal(body, &userInfo) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return user, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed user from google") | 	log.Debug().Msg("Parsed user from google") | ||||||
|  |  | ||||||
| 	// Return the email | 	// Map the user info to the user struct | ||||||
| 	return user.Email, nil | 	user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] | ||||||
|  | 	user.Name = userInfo.Name | ||||||
|  | 	user.Email = userInfo.Email | ||||||
|  |  | ||||||
|  | 	// Return the user | ||||||
|  | 	return user, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ package providers | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"tinyauth/internal/constants" | ||||||
| 	"tinyauth/internal/oauth" | 	"tinyauth/internal/oauth" | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
|  |  | ||||||
| @@ -93,14 +94,17 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (providers *Providers) GetUser(provider string) (string, error) { | func (providers *Providers) GetUser(provider string) (constants.Claims, error) { | ||||||
| 	// Get the email from the provider | 	// Create user struct | ||||||
|  | 	var user constants.Claims | ||||||
|  |  | ||||||
|  | 	// Get the user from the provider | ||||||
| 	switch provider { | 	switch provider { | ||||||
| 	case "github": | 	case "github": | ||||||
| 		// If the github provider is not configured, return an error | 		// If the github provider is not configured, return an error | ||||||
| 		if providers.Github == nil { | 		if providers.Github == nil { | ||||||
| 			log.Debug().Msg("Github provider not configured") | 			log.Debug().Msg("Github provider not configured") | ||||||
| 			return "", nil | 			return user, nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Get the client from the github provider | 		// Get the client from the github provider | ||||||
| @@ -108,23 +112,23 @@ func (providers *Providers) GetUser(provider string) (string, error) { | |||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from github") | 		log.Debug().Msg("Got client from github") | ||||||
|  |  | ||||||
| 		// Get the email from the github provider | 		// Get the user from the github provider | ||||||
| 		email, err := GetGithubEmail(client) | 		user, err := GetGithubUser(client) | ||||||
|  |  | ||||||
| 		// Check if there was an error | 		// Check if there was an error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return "", err | 			return user, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from github") | 		log.Debug().Msg("Got user from github") | ||||||
|  |  | ||||||
| 		// Return the email | 		// Return the user | ||||||
| 		return email, nil | 		return user, nil | ||||||
| 	case "google": | 	case "google": | ||||||
| 		// If the google provider is not configured, return an error | 		// If the google provider is not configured, return an error | ||||||
| 		if providers.Google == nil { | 		if providers.Google == nil { | ||||||
| 			log.Debug().Msg("Google provider not configured") | 			log.Debug().Msg("Google provider not configured") | ||||||
| 			return "", nil | 			return user, nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Get the client from the google provider | 		// Get the client from the google provider | ||||||
| @@ -132,23 +136,23 @@ func (providers *Providers) GetUser(provider string) (string, error) { | |||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from google") | 		log.Debug().Msg("Got client from google") | ||||||
|  |  | ||||||
| 		// Get the email from the google provider | 		// Get the user from the google provider | ||||||
| 		email, err := GetGoogleEmail(client) | 		user, err := GetGoogleUser(client) | ||||||
|  |  | ||||||
| 		// Check if there was an error | 		// Check if there was an error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return "", err | 			return user, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from google") | 		log.Debug().Msg("Got user from google") | ||||||
|  |  | ||||||
| 		// Return the email | 		// Return the user | ||||||
| 		return email, nil | 		return user, nil | ||||||
| 	case "generic": | 	case "generic": | ||||||
| 		// If the generic provider is not configured, return an error | 		// If the generic provider is not configured, return an error | ||||||
| 		if providers.Generic == nil { | 		if providers.Generic == nil { | ||||||
| 			log.Debug().Msg("Generic provider not configured") | 			log.Debug().Msg("Generic provider not configured") | ||||||
| 			return "", nil | 			return user, nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Get the client from the generic provider | 		// Get the client from the generic provider | ||||||
| @@ -156,20 +160,20 @@ func (providers *Providers) GetUser(provider string) (string, error) { | |||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from generic") | 		log.Debug().Msg("Got client from generic") | ||||||
|  |  | ||||||
| 		// Get the email from the generic provider | 		// Get the user from the generic provider | ||||||
| 		email, err := GetGenericEmail(client, providers.Config.GenericUserURL) | 		user, err := GetGenericUser(client, providers.Config.GenericUserURL) | ||||||
|  |  | ||||||
| 		// Check if there was an error | 		// Check if there was an error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return "", err | 			return user, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from generic") | 		log.Debug().Msg("Got user from generic") | ||||||
|  |  | ||||||
| 		// Return the email | 		// Return the email | ||||||
| 		return email, nil | 		return user, nil | ||||||
| 	default: | 	default: | ||||||
| 		return "", nil | 		return user, nil | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -20,6 +20,7 @@ type OAuthRequest struct { | |||||||
| type UnauthorizedQuery struct { | type UnauthorizedQuery struct { | ||||||
| 	Username string `url:"username"` | 	Username string `url:"username"` | ||||||
| 	Resource string `url:"resource"` | 	Resource string `url:"resource"` | ||||||
|  | 	GroupErr bool   `url:"groupErr"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // Proxy is the uri parameters for the proxy endpoint | // Proxy is the uri parameters for the proxy endpoint | ||||||
| @@ -33,6 +34,8 @@ type UserContextResponse struct { | |||||||
| 	Message     string `json:"message"` | 	Message     string `json:"message"` | ||||||
| 	IsLoggedIn  bool   `json:"isLoggedIn"` | 	IsLoggedIn  bool   `json:"isLoggedIn"` | ||||||
| 	Username    string `json:"username"` | 	Username    string `json:"username"` | ||||||
|  | 	Name        string `json:"name"` | ||||||
|  | 	Email       string `json:"email"` | ||||||
| 	Provider    string `json:"provider"` | 	Provider    string `json:"provider"` | ||||||
| 	Oauth       bool   `json:"oauth"` | 	Oauth       bool   `json:"oauth"` | ||||||
| 	TotpPending bool   `json:"totpPending"` | 	TotpPending bool   `json:"totpPending"` | ||||||
|   | |||||||
| @@ -78,3 +78,8 @@ type AuthConfig struct { | |||||||
| 	LoginTimeout    int | 	LoginTimeout    int | ||||||
| 	LoginMaxRetries int | 	LoginMaxRetries int | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // HooksConfig is the configuration for the hooks service | ||||||
|  | type HooksConfig struct { | ||||||
|  | 	Domain string | ||||||
|  | } | ||||||
|   | |||||||
| @@ -25,8 +25,11 @@ type OAuthProviders struct { | |||||||
| // SessionCookie is the cookie for the session (exculding the expiry) | // SessionCookie is the cookie for the session (exculding the expiry) | ||||||
| type SessionCookie struct { | type SessionCookie struct { | ||||||
| 	Username    string | 	Username    string | ||||||
|  | 	Name        string | ||||||
|  | 	Email       string | ||||||
| 	Provider    string | 	Provider    string | ||||||
| 	TotpPending bool | 	TotpPending bool | ||||||
|  | 	OAuthGroups string | ||||||
| } | } | ||||||
|  |  | ||||||
| // TinyauthLabels is the labels for the tinyauth container | // TinyauthLabels is the labels for the tinyauth container | ||||||
| @@ -35,15 +38,19 @@ type TinyauthLabels struct { | |||||||
| 	Users          string | 	Users          string | ||||||
| 	Allowed        string | 	Allowed        string | ||||||
| 	Headers        map[string]string | 	Headers        map[string]string | ||||||
|  | 	OAuthGroups    string | ||||||
| } | } | ||||||
|  |  | ||||||
| // UserContext is the context for the user | // UserContext is the context for the user | ||||||
| type UserContext struct { | type UserContext struct { | ||||||
| 	Username    string | 	Username    string | ||||||
|  | 	Name        string | ||||||
|  | 	Email       string | ||||||
| 	IsLoggedIn  bool | 	IsLoggedIn  bool | ||||||
| 	OAuth       bool | 	OAuth       bool | ||||||
| 	Provider    string | 	Provider    string | ||||||
| 	TotpPending bool | 	TotpPending bool | ||||||
|  | 	OAuthGroups string | ||||||
| } | } | ||||||
|  |  | ||||||
| // LoginAttempt tracks information about login attempts for rate limiting | // LoginAttempt tracks information about login attempts for rate limiting | ||||||
|   | |||||||
| @@ -204,6 +204,8 @@ func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels { | |||||||
| 					} | 					} | ||||||
| 					tinyauthLabels.Headers[headerSplit[0]] = headerSplit[1] | 					tinyauthLabels.Headers[headerSplit[0]] = headerSplit[1] | ||||||
| 				} | 				} | ||||||
|  | 			case "tinyauth.oauth.groups": | ||||||
|  | 				tinyauthLabels.OAuthGroups = value | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @@ -323,3 +325,22 @@ func CheckWhitelist(whitelist string, str string) bool { | |||||||
| 	// Return false if no match was found | 	// Return false if no match was found | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Capitalize just the first letter of a string | ||||||
|  | func Capitalize(str string) string { | ||||||
|  | 	if len(str) == 0 { | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:]) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Sanitize header removes all control characters from a string | ||||||
|  | func SanitizeHeader(header string) string { | ||||||
|  | 	return strings.Map(func(r rune) rune { | ||||||
|  | 		// Allow only printable ASCII characters (32-126) and safe whitespace (space, tab) | ||||||
|  | 		if r == ' ' || r == '\t' || (r >= 32 && r <= 126) { | ||||||
|  | 			return r | ||||||
|  | 		} | ||||||
|  | 		return -1 | ||||||
|  | 	}, header) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -467,3 +467,65 @@ func TestCheckWhitelist(t *testing.T) { | |||||||
| 		t.Fatalf("Expected %v, got %v", expected, result) | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Test capitalize | ||||||
|  | func TestCapitalize(t *testing.T) { | ||||||
|  | 	t.Log("Testing capitalize with a valid string") | ||||||
|  |  | ||||||
|  | 	// Create variables | ||||||
|  | 	str := "test" | ||||||
|  | 	expected := "Test" | ||||||
|  |  | ||||||
|  | 	// Test the capitalize function | ||||||
|  | 	result := utils.Capitalize(str) | ||||||
|  |  | ||||||
|  | 	// Check if the result is equal to the expected | ||||||
|  | 	if result != expected { | ||||||
|  | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	t.Log("Testing capitalize with an empty string") | ||||||
|  |  | ||||||
|  | 	// Create variables | ||||||
|  | 	str = "" | ||||||
|  | 	expected = "" | ||||||
|  |  | ||||||
|  | 	// Test the capitalize function | ||||||
|  | 	result = utils.Capitalize(str) | ||||||
|  |  | ||||||
|  | 	// Check if the result is equal to the expected | ||||||
|  | 	if result != expected { | ||||||
|  | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Test the header sanitizer | ||||||
|  | func TestSanitizeHeader(t *testing.T) { | ||||||
|  | 	t.Log("Testing sanitize header with a valid string") | ||||||
|  |  | ||||||
|  | 	// Create variables | ||||||
|  | 	str := "X-Header=value" | ||||||
|  | 	expected := "X-Header=value" | ||||||
|  |  | ||||||
|  | 	// Test the sanitize header function | ||||||
|  | 	result := utils.SanitizeHeader(str) | ||||||
|  |  | ||||||
|  | 	// Check if the result is equal to the expected | ||||||
|  | 	if result != expected { | ||||||
|  | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	t.Log("Testing sanitize header with an invalid string") | ||||||
|  |  | ||||||
|  | 	// Create variables | ||||||
|  | 	str = "X-Header=val\nue" | ||||||
|  | 	expected = "X-Header=value" | ||||||
|  |  | ||||||
|  | 	// Test the sanitize header function | ||||||
|  | 	result = utils.SanitizeHeader(str) | ||||||
|  |  | ||||||
|  | 	// Check if the result is equal to the expected | ||||||
|  | 	if result != expected { | ||||||
|  | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user