mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-10-29 13:15:46 +00:00 
			
		
		
		
	Compare commits
	
		
			8 Commits
		
	
	
		
			2328e17ff4
			...
			feat/oauth
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 86a18b4cac | ||
|   | d171c5940b | ||
|   | 3dff650e71 | ||
|   | 1fec583ead | ||
|   | 065b9eaf3d | ||
|   | dca09a3d9d | ||
|   | 5e4e2ddbd9 | ||
|   | 13032e564d | 
| @@ -111,6 +111,11 @@ var rootCmd = &cobra.Command{ | ||||
| 			LoginMaxRetries: config.LoginMaxRetries, | ||||
| 		} | ||||
|  | ||||
| 		// Create hooks config | ||||
| 		hooksConfig := types.HooksConfig{ | ||||
| 			Domain: domain, | ||||
| 		} | ||||
|  | ||||
| 		// Create docker service | ||||
| 		docker := docker.NewDocker() | ||||
|  | ||||
| @@ -128,7 +133,7 @@ var rootCmd = &cobra.Command{ | ||||
| 		providers.Init() | ||||
|  | ||||
| 		// Create hooks service | ||||
| 		hooks := hooks.NewHooks(auth, providers) | ||||
| 		hooks := hooks.NewHooks(hooksConfig, auth, providers) | ||||
|  | ||||
| 		// Create handlers | ||||
| 		handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) | ||||
| @@ -189,7 +194,7 @@ func init() { | ||||
| 	rootCmd.Flags().String("generic-auth-url", "", "Generic OAuth auth URL.") | ||||
| 	rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") | ||||
| 	rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") | ||||
| 	rootCmd.Flags().String("generic-name", "Other", "Generic OAuth provider name.") | ||||
| 	rootCmd.Flags().String("generic-name", "Generic", "Generic OAuth provider name.") | ||||
| 	rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") | ||||
| 	rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") | ||||
| 	rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.") | ||||
|   | ||||
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										4
									
								
								frontend/src/index.css
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								frontend/src/index.css
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,4 @@ | ||||
| span, | ||||
| p { | ||||
|   word-break: break-word; | ||||
| } | ||||
| @@ -41,7 +41,8 @@ | ||||
|     "totpTitle": "Enter your TOTP code", | ||||
|     "unauthorizedTitle": "Unauthorized", | ||||
|     "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.", | ||||
|     "unaothorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.", | ||||
|     "unauthorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.", | ||||
|     "unauthorizedGroupsSubtitle": "The user with username <Code>{{username}}</Code> is not in the groups required by the resource <Code>{{resource}}</Code>.", | ||||
|     "unauthorizedButton": "Try again", | ||||
|     "untrustedRedirectTitle": "Untrusted redirect", | ||||
|     "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?", | ||||
|   | ||||
| @@ -41,7 +41,8 @@ | ||||
|     "totpTitle": "Enter your TOTP code", | ||||
|     "unauthorizedTitle": "Unauthorized", | ||||
|     "unauthorizedResourceSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to access the resource <Code>{{resource}}</Code>.", | ||||
|     "unaothorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.", | ||||
|     "unauthorizedLoginSubtitle": "The user with username <Code>{{username}}</Code> is not authorized to login.", | ||||
|     "unauthorizedGroupsSubtitle": "The user with username <Code>{{username}}</Code> is not in the groups required by the resource <Code>{{resource}}</Code>.", | ||||
|     "unauthorizedButton": "Try again", | ||||
|     "untrustedRedirectTitle": "Untrusted redirect", | ||||
|     "untrustedRedirectSubtitle": "You are trying to redirect to a domain that does not match your configured domain (<Code>{{domain}}</Code>). Are you sure you want to continue?", | ||||
|   | ||||
| @@ -19,6 +19,7 @@ import { TotpPage } from "./pages/totp-page.tsx"; | ||||
| import { AppContextProvider } from "./context/app-context.tsx"; | ||||
| import "./lib/i18n/i18n.ts"; | ||||
| import { ForgotPasswordPage } from "./pages/forgot-password-page.tsx"; | ||||
| import "./index.css"; | ||||
|  | ||||
| const queryClient = new QueryClient(); | ||||
|  | ||||
| @@ -38,7 +39,10 @@ createRoot(document.getElementById("root")!).render( | ||||
|                 <Route path="/continue" element={<ContinuePage />} /> | ||||
|                 <Route path="/unauthorized" element={<UnauthorizedPage />} /> | ||||
|                 <Route path="/error" element={<InternalServerError />} /> | ||||
|                 <Route path="/forgot-password" element={<ForgotPasswordPage />} /> | ||||
|                 <Route | ||||
|                   path="/forgot-password" | ||||
|                   element={<ForgotPasswordPage />} | ||||
|                 /> | ||||
|                 <Route path="*" element={<NotFoundPage />} /> | ||||
|               </Routes> | ||||
|             </BrowserRouter> | ||||
|   | ||||
| @@ -10,7 +10,7 @@ import { useAppContext } from "../context/app-context"; | ||||
| import { Trans, useTranslation } from "react-i18next"; | ||||
|  | ||||
| export const LogoutPage = () => { | ||||
|   const { isLoggedIn, username, oauth, provider } = useUserContext(); | ||||
|   const { isLoggedIn, oauth, provider, email, username } = useUserContext(); | ||||
|   const { genericName } = useAppContext(); | ||||
|   const { t } = useTranslation(); | ||||
|  | ||||
| @@ -56,7 +56,7 @@ export const LogoutPage = () => { | ||||
|               values={{ | ||||
|                 provider: | ||||
|                   provider === "generic" ? genericName : capitalize(provider), | ||||
|                 username: username, | ||||
|                 username: email, | ||||
|               }} | ||||
|             /> | ||||
|           ) : ( | ||||
|   | ||||
| @@ -3,11 +3,13 @@ import { Layout } from "../components/layouts/layout"; | ||||
| import { Navigate } from "react-router"; | ||||
| import { isQueryValid } from "../utils/utils"; | ||||
| import { Trans, useTranslation } from "react-i18next"; | ||||
| import React from "react"; | ||||
|  | ||||
| export const UnauthorizedPage = () => { | ||||
|   const queryString = window.location.search; | ||||
|   const params = new URLSearchParams(queryString); | ||||
|   const username = params.get("username") ?? ""; | ||||
|   const groupErr = params.get("groupErr") ?? ""; | ||||
|   const resource = params.get("resource") ?? ""; | ||||
|  | ||||
|   const { t } = useTranslation(); | ||||
| @@ -16,33 +18,54 @@ export const UnauthorizedPage = () => { | ||||
|     return <Navigate to="/" />; | ||||
|   } | ||||
|  | ||||
|   if (isQueryValid(resource) && !isQueryValid(groupErr)) { | ||||
|     return ( | ||||
|       <UnauthorizedLayout> | ||||
|         <Trans | ||||
|           i18nKey="unauthorizedResourceSubtitle" | ||||
|           t={t} | ||||
|           components={{ Code: <Code /> }} | ||||
|           values={{ resource, username }} | ||||
|         /> | ||||
|       </UnauthorizedLayout> | ||||
|     ); | ||||
|   } | ||||
|  | ||||
|   if (isQueryValid(groupErr) && isQueryValid(resource)) { | ||||
|     return ( | ||||
|       <UnauthorizedLayout> | ||||
|         <Trans | ||||
|           i18nKey="unauthorizedGroupsSubtitle" | ||||
|           t={t} | ||||
|           components={{ Code: <Code /> }} | ||||
|           values={{ username, resource }} | ||||
|         /> | ||||
|       </UnauthorizedLayout> | ||||
|     ); | ||||
|   } | ||||
|  | ||||
|   return ( | ||||
|     <UnauthorizedLayout> | ||||
|       <Trans | ||||
|         i18nKey="unauthorizedLoginSubtitle" | ||||
|         t={t} | ||||
|         components={{ Code: <Code /> }} | ||||
|         values={{ username }} | ||||
|       /> | ||||
|     </UnauthorizedLayout> | ||||
|   ); | ||||
| }; | ||||
|  | ||||
| const UnauthorizedLayout = ({ children }: { children: React.ReactNode }) => { | ||||
|   const { t } = useTranslation(); | ||||
|  | ||||
|   return ( | ||||
|     <Layout> | ||||
|       <Paper shadow="md" p={30} mt={30} radius="md" withBorder> | ||||
|         <Text size="xl" fw={700}> | ||||
|           {t("Unauthorized")} | ||||
|         </Text> | ||||
|         <Text> | ||||
|           {isQueryValid(resource) ? ( | ||||
|             <Text> | ||||
|               <Trans | ||||
|                 i18nKey="unauthorizedResourceSubtitle" | ||||
|                 t={t} | ||||
|                 components={{ Code: <Code /> }} | ||||
|                 values={{ resource, username }} | ||||
|               /> | ||||
|             </Text> | ||||
|           ) : ( | ||||
|             <Text> | ||||
|               <Trans | ||||
|                 i18nKey="unaothorizedLoginSubtitle" | ||||
|                 t={t} | ||||
|                 components={{ Code: <Code /> }} | ||||
|                 values={{ username }} | ||||
|               /> | ||||
|             </Text> | ||||
|           )} | ||||
|         </Text> | ||||
|         <Text>{children}</Text> | ||||
|         <Button | ||||
|           fullWidth | ||||
|           mt="xl" | ||||
|   | ||||
| @@ -3,6 +3,8 @@ import { z } from "zod"; | ||||
| export const userContextSchema = z.object({ | ||||
|   isLoggedIn: z.boolean(), | ||||
|   username: z.string(), | ||||
|   name: z.string(), | ||||
|   email: z.string(), | ||||
|   oauth: z.boolean(), | ||||
|   provider: z.string(), | ||||
|   totpPending: z.boolean(), | ||||
|   | ||||
| @@ -45,6 +45,11 @@ var authConfig = types.AuthConfig{ | ||||
| 	LoginMaxRetries: 0, | ||||
| } | ||||
|  | ||||
| // Simple hooks config for tests | ||||
| var hooksConfig = types.HooksConfig{ | ||||
| 	Domain: "localhost", | ||||
| } | ||||
|  | ||||
| // Cookie | ||||
| var cookie string | ||||
|  | ||||
| @@ -83,7 +88,7 @@ func getAPI(t *testing.T) *api.API { | ||||
| 	providers.Init() | ||||
|  | ||||
| 	// Create hooks service | ||||
| 	hooks := hooks.NewHooks(auth, providers) | ||||
| 	hooks := hooks.NewHooks(hooksConfig, auth, providers) | ||||
|  | ||||
| 	// Create handlers service | ||||
| 	handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) | ||||
|   | ||||
| @@ -160,9 +160,12 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) | ||||
|  | ||||
| 	// Set data | ||||
| 	session.Values["username"] = data.Username | ||||
| 	session.Values["name"] = data.Name | ||||
| 	session.Values["email"] = data.Email | ||||
| 	session.Values["provider"] = data.Provider | ||||
| 	session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix() | ||||
| 	session.Values["totpPending"] = data.TotpPending | ||||
| 	session.Values["oauthGroups"] = data.OAuthGroups | ||||
|  | ||||
| 	// Save session | ||||
| 	err = session.Save(c.Request, c.Writer) | ||||
| @@ -211,14 +214,24 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) | ||||
| 		return types.SessionCookie{}, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Got session") | ||||
|  | ||||
| 	// Get data from session | ||||
| 	username, usernameOk := session.Values["username"].(string) | ||||
| 	email, emailOk := session.Values["email"].(string) | ||||
| 	name, nameOk := session.Values["name"].(string) | ||||
| 	provider, providerOK := session.Values["provider"].(string) | ||||
| 	expiry, expiryOk := session.Values["expiry"].(int64) | ||||
| 	totpPending, totpPendingOk := session.Values["totpPending"].(bool) | ||||
| 	oauthGroups, oauthGroupsOk := session.Values["oauthGroups"].(string) | ||||
|  | ||||
| 	if !usernameOk || !providerOK || !expiryOk || !totpPendingOk { | ||||
| 		log.Warn().Msg("Session cookie is missing data") | ||||
| 	if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk { | ||||
| 		log.Warn().Msg("Session cookie is invalid") | ||||
|  | ||||
| 		// If any data is missing, delete the session cookie | ||||
| 		auth.DeleteSessionCookie(c) | ||||
|  | ||||
| 		// Return empty cookie | ||||
| 		return types.SessionCookie{}, nil | ||||
| 	} | ||||
|  | ||||
| @@ -233,13 +246,16 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) | ||||
| 		return types.SessionCookie{}, nil | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Msg("Parsed cookie") | ||||
| 	log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie") | ||||
|  | ||||
| 	// Return the cookie | ||||
| 	return types.SessionCookie{ | ||||
| 		Username:    username, | ||||
| 		Name:        name, | ||||
| 		Email:       email, | ||||
| 		Provider:    provider, | ||||
| 		TotpPending: totpPending, | ||||
| 		OAuthGroups: oauthGroups, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| @@ -248,48 +264,52 @@ func (auth *Auth) UserAuthConfigured() bool { | ||||
| 	return len(auth.Config.Users) > 0 | ||||
| } | ||||
|  | ||||
| func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext) (bool, error) { | ||||
| 	// Get headers | ||||
| 	host := c.Request.Header.Get("X-Forwarded-Host") | ||||
|  | ||||
| 	// Get app id | ||||
| 	appId := strings.Split(host, ".")[0] | ||||
|  | ||||
| 	// Get the container labels | ||||
| 	labels, err := auth.Docker.GetLabels(appId) | ||||
|  | ||||
| 	// If there is an error, return false | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
|  | ||||
| func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.TinyauthLabels) bool { | ||||
| 	// Check if oauth is allowed | ||||
| 	if context.OAuth { | ||||
| 		log.Debug().Msg("Checking OAuth whitelist") | ||||
| 		return utils.CheckWhitelist(labels.OAuthWhitelist, context.Username), nil | ||||
| 		return utils.CheckWhitelist(labels.OAuthWhitelist, context.Email) | ||||
| 	} | ||||
|  | ||||
| 	// Check users | ||||
| 	log.Debug().Msg("Checking users") | ||||
|  | ||||
| 	return utils.CheckWhitelist(labels.Users, context.Username), nil | ||||
| 	return utils.CheckWhitelist(labels.Users, context.Username) | ||||
| } | ||||
|  | ||||
| func (auth *Auth) AuthEnabled(c *gin.Context) (bool, error) { | ||||
| func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.TinyauthLabels) bool { | ||||
| 	// Check if groups are required | ||||
| 	if labels.OAuthGroups == "" { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	// Check if we are using the generic oauth provider | ||||
| 	if context.Provider != "generic" { | ||||
| 		log.Debug().Msg("Not using generic provider, skipping group check") | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	// Split the groups by comma (no need to parse since they are from the API response) | ||||
| 	oauthGroups := strings.Split(context.OAuthGroups, ",") | ||||
|  | ||||
| 	// For every group check if it is in the required groups | ||||
| 	for _, group := range oauthGroups { | ||||
| 		if utils.CheckWhitelist(labels.OAuthGroups, group) { | ||||
| 			log.Debug().Str("group", group).Msg("Group is in required groups") | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// No groups matched | ||||
| 	log.Debug().Msg("No groups matched") | ||||
|  | ||||
| 	// Return false | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func (auth *Auth) AuthEnabled(c *gin.Context, labels types.TinyauthLabels) (bool, error) { | ||||
| 	// Get headers | ||||
| 	uri := c.Request.Header.Get("X-Forwarded-Uri") | ||||
| 	host := c.Request.Header.Get("X-Forwarded-Host") | ||||
|  | ||||
| 	// Get app id | ||||
| 	appId := strings.Split(host, ".")[0] | ||||
|  | ||||
| 	// Get the container labels | ||||
| 	labels, err := auth.Docker.GetLabels(appId) | ||||
|  | ||||
| 	// If there is an error, auth enabled | ||||
| 	if err != nil { | ||||
| 		return true, err | ||||
| 	} | ||||
|  | ||||
| 	// Check if the allowed label is empty | ||||
| 	if labels.Allowed == "" { | ||||
|   | ||||
| @@ -6,4 +6,13 @@ var TinyauthLabels = []string{ | ||||
| 	"tinyauth.users", | ||||
| 	"tinyauth.allowed", | ||||
| 	"tinyauth.headers", | ||||
| 	"tinyauth.oauth.groups", | ||||
| } | ||||
|  | ||||
| // Claims are the OIDC supported claims (including preferd username for some reason) | ||||
| type Claims struct { | ||||
| 	Name              string   `json:"name"` | ||||
| 	Email             string   `json:"email"` | ||||
| 	PreferredUsername string   `json:"preferred_username"` | ||||
| 	Groups            []string `json:"groups"` | ||||
| } | ||||
|   | ||||
| @@ -6,8 +6,7 @@ import ( | ||||
| 	"tinyauth/internal/types" | ||||
| 	"tinyauth/internal/utils" | ||||
|  | ||||
| 	apiTypes "github.com/docker/docker/api/types" | ||||
| 	containerTypes "github.com/docker/docker/api/types/container" | ||||
| 	container "github.com/docker/docker/api/types/container" | ||||
| 	"github.com/docker/docker/client" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| ) | ||||
| @@ -38,9 +37,9 @@ func (docker *Docker) Init() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (docker *Docker) GetContainers() ([]apiTypes.Container, error) { | ||||
| func (docker *Docker) GetContainers() ([]container.Summary, error) { | ||||
| 	// Get the list of containers | ||||
| 	containers, err := docker.Client.ContainerList(docker.Context, containerTypes.ListOptions{}) | ||||
| 	containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| @@ -51,13 +50,13 @@ func (docker *Docker) GetContainers() ([]apiTypes.Container, error) { | ||||
| 	return containers, nil | ||||
| } | ||||
|  | ||||
| func (docker *Docker) InspectContainer(containerId string) (apiTypes.ContainerJSON, error) { | ||||
| func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) { | ||||
| 	// Inspect the container | ||||
| 	inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return apiTypes.ContainerJSON{}, err | ||||
| 		return container.InspectResponse{}, err | ||||
| 	} | ||||
|  | ||||
| 	// Return the inspect | ||||
|   | ||||
| @@ -10,6 +10,7 @@ import ( | ||||
| 	"tinyauth/internal/hooks" | ||||
| 	"tinyauth/internal/providers" | ||||
| 	"tinyauth/internal/types" | ||||
| 	"tinyauth/internal/utils" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/google/go-querystring/query" | ||||
| @@ -68,12 +69,15 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | ||||
| 	proto := c.Request.Header.Get("X-Forwarded-Proto") | ||||
| 	host := c.Request.Header.Get("X-Forwarded-Host") | ||||
|  | ||||
| 	// Check if auth is enabled | ||||
| 	authEnabled, err := h.Auth.AuthEnabled(c) | ||||
| 	// Get the app id | ||||
| 	appId := strings.Split(host, ".")[0] | ||||
|  | ||||
| 	// Get the container labels | ||||
| 	labels, err := h.Docker.GetLabels(appId) | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		log.Error().Err(err).Msg("Failed to check if app is allowed") | ||||
| 		log.Error().Err(err).Msg("Failed to get container labels") | ||||
|  | ||||
| 		if proxy.Proxy == "nginx" || !isBrowser { | ||||
| 			c.JSON(500, gin.H{ | ||||
| @@ -87,11 +91,8 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Get the app id | ||||
| 	appId := strings.Split(host, ".")[0] | ||||
|  | ||||
| 	// Get the container labels | ||||
| 	labels, err := h.Docker.GetLabels(appId) | ||||
| 	// Check if auth is enabled | ||||
| 	authEnabled, err := h.Auth.AuthEnabled(c, labels) | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| @@ -113,7 +114,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | ||||
| 	if !authEnabled { | ||||
| 		for key, value := range labels.Headers { | ||||
| 			log.Debug().Str("key", key).Str("value", value).Msg("Setting header") | ||||
| 			c.Header(key, value) | ||||
| 			c.Header(key, utils.SanitizeHeader(value)) | ||||
| 		} | ||||
| 		c.JSON(200, gin.H{ | ||||
| 			"status":  200, | ||||
| @@ -130,23 +131,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | ||||
| 		log.Debug().Msg("Authenticated") | ||||
|  | ||||
| 		// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx | ||||
| 		appAllowed, err := h.Auth.ResourceAllowed(c, userContext) | ||||
|  | ||||
| 		// Check if there was an error | ||||
| 		if err != nil { | ||||
| 			log.Error().Err(err).Msg("Failed to check if app is allowed") | ||||
|  | ||||
| 			if proxy.Proxy == "nginx" || !isBrowser { | ||||
| 				c.JSON(500, gin.H{ | ||||
| 					"status":  500, | ||||
| 					"message": "Internal Server Error", | ||||
| 				}) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||
| 			return | ||||
| 		} | ||||
| 		appAllowed := h.Auth.ResourceAllowed(c, userContext, labels) | ||||
|  | ||||
| 		log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") | ||||
|  | ||||
| @@ -165,11 +150,20 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			// Build query | ||||
| 			queries, err := query.Values(types.UnauthorizedQuery{ | ||||
| 				Username: userContext.Username, | ||||
| 			// Values | ||||
| 			values := types.UnauthorizedQuery{ | ||||
| 				Resource: strings.Split(host, ".")[0], | ||||
| 			}) | ||||
| 			} | ||||
|  | ||||
| 			// Use either username or email | ||||
| 			if userContext.OAuth { | ||||
| 				values.Username = userContext.Email | ||||
| 			} else { | ||||
| 				values.Username = userContext.Username | ||||
| 			} | ||||
|  | ||||
| 			// Build query | ||||
| 			queries, err := query.Values(values) | ||||
|  | ||||
| 			// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) | ||||
| 			if err != nil { | ||||
| @@ -183,13 +177,65 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Set the user header | ||||
| 		c.Header("Remote-User", userContext.Username) | ||||
| 		log.Debug().Interface("labels", labels).Msg("Got labels") | ||||
|  | ||||
| 		// Check if user is in required groups | ||||
| 		groupOk := h.Auth.OAuthGroup(c, userContext, labels) | ||||
|  | ||||
| 		log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups") | ||||
|  | ||||
| 		// The user is not allowed to access the app | ||||
| 		if !groupOk { | ||||
| 			log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User is not in required groups") | ||||
|  | ||||
| 			// Set WWW-Authenticate header | ||||
| 			c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") | ||||
|  | ||||
| 			if proxy.Proxy == "nginx" || !isBrowser { | ||||
| 				c.JSON(401, gin.H{ | ||||
| 					"status":  401, | ||||
| 					"message": "Unauthorized", | ||||
| 				}) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			// Values | ||||
| 			values := types.UnauthorizedQuery{ | ||||
| 				Resource: strings.Split(host, ".")[0], | ||||
| 				GroupErr: true, | ||||
| 			} | ||||
|  | ||||
| 			// Use either username or email | ||||
| 			if userContext.OAuth { | ||||
| 				values.Username = userContext.Email | ||||
| 			} else { | ||||
| 				values.Username = userContext.Username | ||||
| 			} | ||||
|  | ||||
| 			// Build query | ||||
| 			queries, err := query.Values(values) | ||||
|  | ||||
| 			// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) | ||||
| 			if err != nil { | ||||
| 				log.Error().Err(err).Msg("Failed to build queries") | ||||
| 				c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			// We are using caddy/traefik so redirect | ||||
| 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) | ||||
| 		c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) | ||||
| 		c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) | ||||
| 		c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) | ||||
|  | ||||
| 		// Set the rest of the headers | ||||
| 		for key, value := range labels.Headers { | ||||
| 			log.Debug().Str("key", key).Str("value", value).Msg("Setting header") | ||||
| 			c.Header(key, value) | ||||
| 			c.Header(key, utils.SanitizeHeader(value)) | ||||
| 		} | ||||
|  | ||||
| 		// The user is allowed to access the app | ||||
| @@ -310,6 +356,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) { | ||||
| 		// Set totp pending cookie | ||||
| 		h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||
| 			Username:    login.Username, | ||||
| 			Name:        utils.Capitalize(login.Username), | ||||
| 			Email:       fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), | ||||
| 			Provider:    "username", | ||||
| 			TotpPending: true, | ||||
| 		}) | ||||
| @@ -328,6 +376,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) { | ||||
| 	// Create session cookie with username as provider | ||||
| 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||
| 		Username: login.Username, | ||||
| 		Name:     utils.Capitalize(login.Username), | ||||
| 		Email:    fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), | ||||
| 		Provider: "username", | ||||
| 	}) | ||||
|  | ||||
| @@ -402,6 +452,8 @@ func (h *Handlers) TotpHandler(c *gin.Context) { | ||||
| 	// Create session cookie with username as provider | ||||
| 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||
| 		Username: user.Username, | ||||
| 		Name:     utils.Capitalize(user.Username), | ||||
| 		Email:    fmt.Sprintf("%s@%s", strings.ToLower(user.Username), h.Config.Domain), | ||||
| 		Provider: "username", | ||||
| 	}) | ||||
|  | ||||
| @@ -465,6 +517,8 @@ func (h *Handlers) UserHandler(c *gin.Context) { | ||||
| 		Status:      200, | ||||
| 		IsLoggedIn:  userContext.IsLoggedIn, | ||||
| 		Username:    userContext.Username, | ||||
| 		Name:        userContext.Name, | ||||
| 		Email:       userContext.Email, | ||||
| 		Provider:    userContext.Provider, | ||||
| 		Oauth:       userContext.OAuth, | ||||
| 		TotpPending: userContext.TotpPending, | ||||
| @@ -613,25 +667,32 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Get email | ||||
| 	email, err := h.Providers.GetUser(providerName.Provider) | ||||
|  | ||||
| 	log.Debug().Str("email", email).Msg("Got email") | ||||
| 	// Get user | ||||
| 	user, err := h.Providers.GetUser(providerName.Provider) | ||||
|  | ||||
| 	// Handle error | ||||
| 	if err != nil { | ||||
| 		log.Error().Err(err).Msg("Failed to get email") | ||||
| 		log.Error().Msg("Failed to get user") | ||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Got user") | ||||
|  | ||||
| 	// Check that email is not empty | ||||
| 	if user.Email == "" { | ||||
| 		log.Error().Msg("Email is empty") | ||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Email is not whitelisted | ||||
| 	if !h.Auth.EmailWhitelisted(email) { | ||||
| 		log.Warn().Str("email", email).Msg("Email not whitelisted") | ||||
| 	if !h.Auth.EmailWhitelisted(user.Email) { | ||||
| 		log.Warn().Str("email", user.Email).Msg("Email not whitelisted") | ||||
|  | ||||
| 		// Build query | ||||
| 		queries, err := query.Values(types.UnauthorizedQuery{ | ||||
| 			Username: email, | ||||
| 			Username: user.Email, | ||||
| 		}) | ||||
|  | ||||
| 		// Handle error | ||||
| @@ -647,10 +708,31 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { | ||||
|  | ||||
| 	log.Debug().Msg("Email whitelisted") | ||||
|  | ||||
| 	// Get username | ||||
| 	var username string | ||||
|  | ||||
| 	if user.PreferredUsername != "" { | ||||
| 		username = user.PreferredUsername | ||||
| 	} else { | ||||
| 		username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1]) | ||||
| 	} | ||||
|  | ||||
| 	// Get name | ||||
| 	var name string | ||||
|  | ||||
| 	if user.Name != "" { | ||||
| 		name = user.Name | ||||
| 	} else { | ||||
| 		name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) | ||||
| 	} | ||||
|  | ||||
| 	// Create session cookie (also cleans up redirect cookie) | ||||
| 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||
| 		Username: email, | ||||
| 		Provider: providerName.Provider, | ||||
| 		Username:    username, | ||||
| 		Name:        name, | ||||
| 		Email:       user.Email, | ||||
| 		Provider:    providerName.Provider, | ||||
| 		OAuthGroups: strings.Join(user.Groups, ","), | ||||
| 	}) | ||||
|  | ||||
| 	// Check if we have a redirect URI | ||||
|   | ||||
| @@ -1,22 +1,27 @@ | ||||
| package hooks | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 	"tinyauth/internal/auth" | ||||
| 	"tinyauth/internal/providers" | ||||
| 	"tinyauth/internal/types" | ||||
| 	"tinyauth/internal/utils" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| ) | ||||
|  | ||||
| func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks { | ||||
| func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks { | ||||
| 	return &Hooks{ | ||||
| 		Config:    config, | ||||
| 		Auth:      auth, | ||||
| 		Providers: providers, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type Hooks struct { | ||||
| 	Config    types.HooksConfig | ||||
| 	Auth      *auth.Auth | ||||
| 	Providers *providers.Providers | ||||
| } | ||||
| @@ -36,11 +41,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | ||||
| 		if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) { | ||||
| 			// Return user context since we are logged in with basic auth | ||||
| 			return types.UserContext{ | ||||
| 				Username:    basic.Username, | ||||
| 				IsLoggedIn:  true, | ||||
| 				OAuth:       false, | ||||
| 				Provider:    "basic", | ||||
| 				TotpPending: false, | ||||
| 				Username:   basic.Username, | ||||
| 				Name:       utils.Capitalize(basic.Username), | ||||
| 				Email:      fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain), | ||||
| 				IsLoggedIn: true, | ||||
| 				Provider:   "basic", | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| @@ -50,13 +55,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | ||||
| 	if err != nil { | ||||
| 		log.Error().Err(err).Msg("Failed to get session cookie") | ||||
| 		// Return empty context | ||||
| 		return types.UserContext{ | ||||
| 			Username:    "", | ||||
| 			IsLoggedIn:  false, | ||||
| 			OAuth:       false, | ||||
| 			Provider:    "", | ||||
| 			TotpPending: false, | ||||
| 		} | ||||
| 		return types.UserContext{} | ||||
| 	} | ||||
|  | ||||
| 	// Check if session cookie has totp pending | ||||
| @@ -65,8 +64,8 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | ||||
| 		// Return empty context since we are pending totp | ||||
| 		return types.UserContext{ | ||||
| 			Username:    cookie.Username, | ||||
| 			IsLoggedIn:  false, | ||||
| 			OAuth:       false, | ||||
| 			Name:        cookie.Name, | ||||
| 			Email:       cookie.Email, | ||||
| 			Provider:    cookie.Provider, | ||||
| 			TotpPending: true, | ||||
| 		} | ||||
| @@ -82,11 +81,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | ||||
|  | ||||
| 			// It exists so we are logged in | ||||
| 			return types.UserContext{ | ||||
| 				Username:    cookie.Username, | ||||
| 				IsLoggedIn:  true, | ||||
| 				OAuth:       false, | ||||
| 				Provider:    "username", | ||||
| 				TotpPending: false, | ||||
| 				Username:   cookie.Username, | ||||
| 				Name:       cookie.Name, | ||||
| 				Email:      cookie.Email, | ||||
| 				IsLoggedIn: true, | ||||
| 				Provider:   "username", | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| @@ -108,13 +107,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | ||||
| 			hooks.Auth.DeleteSessionCookie(c) | ||||
|  | ||||
| 			// Return empty context | ||||
| 			return types.UserContext{ | ||||
| 				Username:    "", | ||||
| 				IsLoggedIn:  false, | ||||
| 				OAuth:       false, | ||||
| 				Provider:    "", | ||||
| 				TotpPending: false, | ||||
| 			} | ||||
| 			return types.UserContext{} | ||||
| 		} | ||||
|  | ||||
| 		log.Debug().Msg("Email is whitelisted") | ||||
| @@ -122,19 +115,15 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | ||||
| 		// Return user context since we are logged in with oauth | ||||
| 		return types.UserContext{ | ||||
| 			Username:    cookie.Username, | ||||
| 			Name:        cookie.Name, | ||||
| 			Email:       cookie.Email, | ||||
| 			IsLoggedIn:  true, | ||||
| 			OAuth:       true, | ||||
| 			Provider:    cookie.Provider, | ||||
| 			TotpPending: false, | ||||
| 			OAuthGroups: cookie.OAuthGroups, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Neither basic auth or oauth is set so we return an empty context | ||||
| 	return types.UserContext{ | ||||
| 		Username:    "", | ||||
| 		IsLoggedIn:  false, | ||||
| 		OAuth:       false, | ||||
| 		Provider:    "", | ||||
| 		TotpPending: false, | ||||
| 	} | ||||
| 	return types.UserContext{} | ||||
| } | ||||
|   | ||||
| @@ -4,24 +4,25 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"tinyauth/internal/constants" | ||||
|  | ||||
| 	"github.com/rs/zerolog/log" | ||||
| ) | ||||
|  | ||||
| // We are assuming that the generic provider will return a JSON object with an email field | ||||
| type GenericUserInfoResponse struct { | ||||
| 	Email string `json:"email"` | ||||
| } | ||||
| func GetGenericUser(client *http.Client, url string) (constants.Claims, error) { | ||||
| 	// Create user struct | ||||
| 	var user constants.Claims | ||||
|  | ||||
| func GetGenericEmail(client *http.Client, url string) (string, error) { | ||||
| 	// Using the oauth client get the user info url | ||||
| 	res, err := client.Get(url) | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	log.Debug().Msg("Got response from generic provider") | ||||
|  | ||||
| 	// Read the body of the response | ||||
| @@ -29,24 +30,21 @@ func GetGenericEmail(client *http.Client, url string) (string, error) { | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Read body from generic provider") | ||||
|  | ||||
| 	// Parse the body into a user struct | ||||
| 	var user GenericUserInfoResponse | ||||
|  | ||||
| 	// Unmarshal the body into the user struct | ||||
| 	err = json.Unmarshal(body, &user) | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Parsed user from generic provider") | ||||
|  | ||||
| 	// Return the email | ||||
| 	return user.Email, nil | ||||
| 	// Return the user | ||||
| 	return user, nil | ||||
| } | ||||
|   | ||||
| @@ -5,51 +5,96 @@ import ( | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"tinyauth/internal/constants" | ||||
|  | ||||
| 	"github.com/rs/zerolog/log" | ||||
| ) | ||||
|  | ||||
| // Github has a different response than the generic provider | ||||
| type GithubUserInfoResponse []struct { | ||||
| // Response for the github email endpoint | ||||
| type GithubEmailResponse []struct { | ||||
| 	Email   string `json:"email"` | ||||
| 	Primary bool   `json:"primary"` | ||||
| } | ||||
|  | ||||
| // The scopes required for the github provider | ||||
| func GithubScopes() []string { | ||||
| 	return []string{"user:email"} | ||||
| // Response for the github user endpoint | ||||
| type GithubUserInfoResponse struct { | ||||
| 	Login string `json:"login"` | ||||
| 	Name  string `json:"name"` | ||||
| } | ||||
|  | ||||
| func GetGithubEmail(client *http.Client) (string, error) { | ||||
| 	// Get the user emails from github using the oauth http client | ||||
| 	res, err := client.Get("https://api.github.com/user/emails") | ||||
| // The scopes required for the github provider | ||||
| func GithubScopes() []string { | ||||
| 	return []string{"user:email", "read:user"} | ||||
| } | ||||
|  | ||||
| func GetGithubUser(client *http.Client) (constants.Claims, error) { | ||||
| 	// Create user struct | ||||
| 	var user constants.Claims | ||||
|  | ||||
| 	// Get the user info from github using the oauth http client | ||||
| 	res, err := client.Get("https://api.github.com/user") | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Got response from github") | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	log.Debug().Msg("Got user response from github") | ||||
|  | ||||
| 	// Read the body of the response | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Read body from github") | ||||
| 	log.Debug().Msg("Read user body from github") | ||||
|  | ||||
| 	// Parse the body into a user struct | ||||
| 	var emails GithubUserInfoResponse | ||||
| 	var userInfo GithubUserInfoResponse | ||||
|  | ||||
| 	// Unmarshal the body into the user struct | ||||
| 	err = json.Unmarshal(body, &userInfo) | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	// Get the user emails from github using the oauth http client | ||||
| 	res, err = client.Get("https://api.github.com/user/emails") | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	log.Debug().Msg("Got email response from github") | ||||
|  | ||||
| 	// Read the body of the response | ||||
| 	body, err = io.ReadAll(res.Body) | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Read email body from github") | ||||
|  | ||||
| 	// Parse the body into a user struct | ||||
| 	var emails GithubEmailResponse | ||||
|  | ||||
| 	// Unmarshal the body into the user struct | ||||
| 	err = json.Unmarshal(body, &emails) | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Parsed emails from github") | ||||
| @@ -57,10 +102,26 @@ func GetGithubEmail(client *http.Client) (string, error) { | ||||
| 	// Find and return the primary email | ||||
| 	for _, email := range emails { | ||||
| 		if email.Primary { | ||||
| 			return email.Email, nil | ||||
| 			// Set the email then exit | ||||
| 			user.Email = email.Email | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// User does not have a primary email? | ||||
| 	return "", errors.New("no primary email found") | ||||
| 	// If no primary email was found, use the first available email | ||||
| 	if len(emails) == 0 { | ||||
| 		return user, errors.New("no emails found") | ||||
| 	} | ||||
|  | ||||
| 	// Set the email if it is not set picking the first one | ||||
| 	if user.Email == "" { | ||||
| 		user.Email = emails[0].Email | ||||
| 	} | ||||
|  | ||||
| 	// Set the username and name | ||||
| 	user.PreferredUsername = userInfo.Login | ||||
| 	user.Name = userInfo.Name | ||||
|  | ||||
| 	// Return | ||||
| 	return user, nil | ||||
| } | ||||
|   | ||||
| @@ -4,29 +4,37 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"tinyauth/internal/constants" | ||||
|  | ||||
| 	"github.com/rs/zerolog/log" | ||||
| ) | ||||
|  | ||||
| // Google works the same as the generic provider | ||||
| // Response for the google user endpoint | ||||
| type GoogleUserInfoResponse struct { | ||||
| 	Email string `json:"email"` | ||||
| 	Name  string `json:"name"` | ||||
| } | ||||
|  | ||||
| // The scopes required for the google provider | ||||
| func GoogleScopes() []string { | ||||
| 	return []string{"https://www.googleapis.com/auth/userinfo.email"} | ||||
| 	return []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} | ||||
| } | ||||
|  | ||||
| func GetGoogleEmail(client *http.Client) (string, error) { | ||||
| func GetGoogleUser(client *http.Client) (constants.Claims, error) { | ||||
| 	// Create user struct | ||||
| 	var user constants.Claims | ||||
|  | ||||
| 	// Get the user info from google using the oauth http client | ||||
| 	res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	log.Debug().Msg("Got response from google") | ||||
|  | ||||
| 	// Read the body of the response | ||||
| @@ -34,24 +42,29 @@ func GetGoogleEmail(client *http.Client) (string, error) { | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Read body from google") | ||||
|  | ||||
| 	// Parse the body into a user struct | ||||
| 	var user GoogleUserInfoResponse | ||||
| 	// Create a new user info struct | ||||
| 	var userInfo GoogleUserInfoResponse | ||||
|  | ||||
| 	// Unmarshal the body into the user struct | ||||
| 	err = json.Unmarshal(body, &user) | ||||
| 	err = json.Unmarshal(body, &userInfo) | ||||
|  | ||||
| 	// Check if there was an error | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Parsed user from google") | ||||
|  | ||||
| 	// Return the email | ||||
| 	return user.Email, nil | ||||
| 	// Map the user info to the user struct | ||||
| 	user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] | ||||
| 	user.Name = userInfo.Name | ||||
| 	user.Email = userInfo.Email | ||||
|  | ||||
| 	// Return the user | ||||
| 	return user, nil | ||||
| } | ||||
|   | ||||
| @@ -2,6 +2,7 @@ package providers | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"tinyauth/internal/constants" | ||||
| 	"tinyauth/internal/oauth" | ||||
| 	"tinyauth/internal/types" | ||||
|  | ||||
| @@ -93,14 +94,17 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (providers *Providers) GetUser(provider string) (string, error) { | ||||
| 	// Get the email from the provider | ||||
| func (providers *Providers) GetUser(provider string) (constants.Claims, error) { | ||||
| 	// Create user struct | ||||
| 	var user constants.Claims | ||||
|  | ||||
| 	// Get the user from the provider | ||||
| 	switch provider { | ||||
| 	case "github": | ||||
| 		// If the github provider is not configured, return an error | ||||
| 		if providers.Github == nil { | ||||
| 			log.Debug().Msg("Github provider not configured") | ||||
| 			return "", nil | ||||
| 			return user, nil | ||||
| 		} | ||||
|  | ||||
| 		// Get the client from the github provider | ||||
| @@ -108,23 +112,23 @@ func (providers *Providers) GetUser(provider string) (string, error) { | ||||
|  | ||||
| 		log.Debug().Msg("Got client from github") | ||||
|  | ||||
| 		// Get the email from the github provider | ||||
| 		email, err := GetGithubEmail(client) | ||||
| 		// Get the user from the github provider | ||||
| 		user, err := GetGithubUser(client) | ||||
|  | ||||
| 		// Check if there was an error | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 			return user, err | ||||
| 		} | ||||
|  | ||||
| 		log.Debug().Msg("Got email from github") | ||||
| 		log.Debug().Msg("Got user from github") | ||||
|  | ||||
| 		// Return the email | ||||
| 		return email, nil | ||||
| 		// Return the user | ||||
| 		return user, nil | ||||
| 	case "google": | ||||
| 		// If the google provider is not configured, return an error | ||||
| 		if providers.Google == nil { | ||||
| 			log.Debug().Msg("Google provider not configured") | ||||
| 			return "", nil | ||||
| 			return user, nil | ||||
| 		} | ||||
|  | ||||
| 		// Get the client from the google provider | ||||
| @@ -132,23 +136,23 @@ func (providers *Providers) GetUser(provider string) (string, error) { | ||||
|  | ||||
| 		log.Debug().Msg("Got client from google") | ||||
|  | ||||
| 		// Get the email from the google provider | ||||
| 		email, err := GetGoogleEmail(client) | ||||
| 		// Get the user from the google provider | ||||
| 		user, err := GetGoogleUser(client) | ||||
|  | ||||
| 		// Check if there was an error | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 			return user, err | ||||
| 		} | ||||
|  | ||||
| 		log.Debug().Msg("Got email from google") | ||||
| 		log.Debug().Msg("Got user from google") | ||||
|  | ||||
| 		// Return the email | ||||
| 		return email, nil | ||||
| 		// Return the user | ||||
| 		return user, nil | ||||
| 	case "generic": | ||||
| 		// If the generic provider is not configured, return an error | ||||
| 		if providers.Generic == nil { | ||||
| 			log.Debug().Msg("Generic provider not configured") | ||||
| 			return "", nil | ||||
| 			return user, nil | ||||
| 		} | ||||
|  | ||||
| 		// Get the client from the generic provider | ||||
| @@ -156,20 +160,20 @@ func (providers *Providers) GetUser(provider string) (string, error) { | ||||
|  | ||||
| 		log.Debug().Msg("Got client from generic") | ||||
|  | ||||
| 		// Get the email from the generic provider | ||||
| 		email, err := GetGenericEmail(client, providers.Config.GenericUserURL) | ||||
| 		// Get the user from the generic provider | ||||
| 		user, err := GetGenericUser(client, providers.Config.GenericUserURL) | ||||
|  | ||||
| 		// Check if there was an error | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 			return user, err | ||||
| 		} | ||||
|  | ||||
| 		log.Debug().Msg("Got email from generic") | ||||
| 		log.Debug().Msg("Got user from generic") | ||||
|  | ||||
| 		// Return the email | ||||
| 		return email, nil | ||||
| 		return user, nil | ||||
| 	default: | ||||
| 		return "", nil | ||||
| 		return user, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -20,6 +20,7 @@ type OAuthRequest struct { | ||||
| type UnauthorizedQuery struct { | ||||
| 	Username string `url:"username"` | ||||
| 	Resource string `url:"resource"` | ||||
| 	GroupErr bool   `url:"groupErr"` | ||||
| } | ||||
|  | ||||
| // Proxy is the uri parameters for the proxy endpoint | ||||
| @@ -33,6 +34,8 @@ type UserContextResponse struct { | ||||
| 	Message     string `json:"message"` | ||||
| 	IsLoggedIn  bool   `json:"isLoggedIn"` | ||||
| 	Username    string `json:"username"` | ||||
| 	Name        string `json:"name"` | ||||
| 	Email       string `json:"email"` | ||||
| 	Provider    string `json:"provider"` | ||||
| 	Oauth       bool   `json:"oauth"` | ||||
| 	TotpPending bool   `json:"totpPending"` | ||||
|   | ||||
| @@ -78,3 +78,8 @@ type AuthConfig struct { | ||||
| 	LoginTimeout    int | ||||
| 	LoginMaxRetries int | ||||
| } | ||||
|  | ||||
| // HooksConfig is the configuration for the hooks service | ||||
| type HooksConfig struct { | ||||
| 	Domain string | ||||
| } | ||||
|   | ||||
| @@ -25,8 +25,11 @@ type OAuthProviders struct { | ||||
| // SessionCookie is the cookie for the session (exculding the expiry) | ||||
| type SessionCookie struct { | ||||
| 	Username    string | ||||
| 	Name        string | ||||
| 	Email       string | ||||
| 	Provider    string | ||||
| 	TotpPending bool | ||||
| 	OAuthGroups string | ||||
| } | ||||
|  | ||||
| // TinyauthLabels is the labels for the tinyauth container | ||||
| @@ -35,15 +38,19 @@ type TinyauthLabels struct { | ||||
| 	Users          string | ||||
| 	Allowed        string | ||||
| 	Headers        map[string]string | ||||
| 	OAuthGroups    string | ||||
| } | ||||
|  | ||||
| // UserContext is the context for the user | ||||
| type UserContext struct { | ||||
| 	Username    string | ||||
| 	Name        string | ||||
| 	Email       string | ||||
| 	IsLoggedIn  bool | ||||
| 	OAuth       bool | ||||
| 	Provider    string | ||||
| 	TotpPending bool | ||||
| 	OAuthGroups string | ||||
| } | ||||
|  | ||||
| // LoginAttempt tracks information about login attempts for rate limiting | ||||
|   | ||||
| @@ -204,6 +204,8 @@ func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels { | ||||
| 					} | ||||
| 					tinyauthLabels.Headers[headerSplit[0]] = headerSplit[1] | ||||
| 				} | ||||
| 			case "tinyauth.oauth.groups": | ||||
| 				tinyauthLabels.OAuthGroups = value | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| @@ -323,3 +325,22 @@ func CheckWhitelist(whitelist string, str string) bool { | ||||
| 	// Return false if no match was found | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // Capitalize just the first letter of a string | ||||
| func Capitalize(str string) string { | ||||
| 	if len(str) == 0 { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:]) | ||||
| } | ||||
|  | ||||
| // Sanitize header removes all control characters from a string | ||||
| func SanitizeHeader(header string) string { | ||||
| 	return strings.Map(func(r rune) rune { | ||||
| 		// Allow only printable ASCII characters (32-126) and safe whitespace (space, tab) | ||||
| 		if r == ' ' || r == '\t' || (r >= 32 && r <= 126) { | ||||
| 			return r | ||||
| 		} | ||||
| 		return -1 | ||||
| 	}, header) | ||||
| } | ||||
|   | ||||
| @@ -467,3 +467,65 @@ func TestCheckWhitelist(t *testing.T) { | ||||
| 		t.Fatalf("Expected %v, got %v", expected, result) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Test capitalize | ||||
| func TestCapitalize(t *testing.T) { | ||||
| 	t.Log("Testing capitalize with a valid string") | ||||
|  | ||||
| 	// Create variables | ||||
| 	str := "test" | ||||
| 	expected := "Test" | ||||
|  | ||||
| 	// Test the capitalize function | ||||
| 	result := utils.Capitalize(str) | ||||
|  | ||||
| 	// Check if the result is equal to the expected | ||||
| 	if result != expected { | ||||
| 		t.Fatalf("Expected %v, got %v", expected, result) | ||||
| 	} | ||||
|  | ||||
| 	t.Log("Testing capitalize with an empty string") | ||||
|  | ||||
| 	// Create variables | ||||
| 	str = "" | ||||
| 	expected = "" | ||||
|  | ||||
| 	// Test the capitalize function | ||||
| 	result = utils.Capitalize(str) | ||||
|  | ||||
| 	// Check if the result is equal to the expected | ||||
| 	if result != expected { | ||||
| 		t.Fatalf("Expected %v, got %v", expected, result) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Test the header sanitizer | ||||
| func TestSanitizeHeader(t *testing.T) { | ||||
| 	t.Log("Testing sanitize header with a valid string") | ||||
|  | ||||
| 	// Create variables | ||||
| 	str := "X-Header=value" | ||||
| 	expected := "X-Header=value" | ||||
|  | ||||
| 	// Test the sanitize header function | ||||
| 	result := utils.SanitizeHeader(str) | ||||
|  | ||||
| 	// Check if the result is equal to the expected | ||||
| 	if result != expected { | ||||
| 		t.Fatalf("Expected %v, got %v", expected, result) | ||||
| 	} | ||||
|  | ||||
| 	t.Log("Testing sanitize header with an invalid string") | ||||
|  | ||||
| 	// Create variables | ||||
| 	str = "X-Header=val\nue" | ||||
| 	expected = "X-Header=value" | ||||
|  | ||||
| 	// Test the sanitize header function | ||||
| 	result = utils.SanitizeHeader(str) | ||||
|  | ||||
| 	// Check if the result is equal to the expected | ||||
| 	if result != expected { | ||||
| 		t.Fatalf("Expected %v, got %v", expected, result) | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user