mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-10-30 21:55:43 +00:00 
			
		
		
		
	feat: oauth email whitelist
This commit is contained in:
		| @@ -56,6 +56,9 @@ var rootCmd = &cobra.Command{ | |||||||
| 		users, parseErr := utils.ParseUsers(usersString) | 		users, parseErr := utils.ParseUsers(usersString) | ||||||
| 		HandleError(parseErr, "Failed to parse users") | 		HandleError(parseErr, "Failed to parse users") | ||||||
|  |  | ||||||
|  | 		// Create whitelist | ||||||
|  | 		whitelist := utils.ParseWhitelist(config.Whitelist) | ||||||
|  |  | ||||||
| 		// Create OAuth config | 		// Create OAuth config | ||||||
| 		oauthConfig := types.OAuthConfig{ | 		oauthConfig := types.OAuthConfig{ | ||||||
| 			GithubClientId:      config.GithubClientId, | 			GithubClientId:      config.GithubClientId, | ||||||
| @@ -72,7 +75,7 @@ var rootCmd = &cobra.Command{ | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Create auth service | 		// Create auth service | ||||||
| 		auth := auth.NewAuth(users) | 		auth := auth.NewAuth(users, whitelist) | ||||||
|  |  | ||||||
| 		// Create OAuth providers service | 		// Create OAuth providers service | ||||||
| 		providers := providers.NewProviders(oauthConfig) | 		providers := providers.NewProviders(oauthConfig) | ||||||
| @@ -136,6 +139,7 @@ func init() { | |||||||
| 	rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") | 	rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") | ||||||
| 	rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") | 	rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") | ||||||
| 	rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") | 	rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") | ||||||
|  | 	rootCmd.Flags().String("whitelist", "", "Comma separated list of email addresses to whitelist (only for oauth).") | ||||||
| 	viper.BindEnv("port", "PORT") | 	viper.BindEnv("port", "PORT") | ||||||
| 	viper.BindEnv("address", "ADDRESS") | 	viper.BindEnv("address", "ADDRESS") | ||||||
| 	viper.BindEnv("secret", "SECRET") | 	viper.BindEnv("secret", "SECRET") | ||||||
| @@ -154,5 +158,6 @@ func init() { | |||||||
| 	viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL") | 	viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL") | ||||||
| 	viper.BindEnv("generic-user-url", "GENERIC_USER_URL") | 	viper.BindEnv("generic-user-url", "GENERIC_USER_URL") | ||||||
| 	viper.BindEnv("disable-continue", "DISABLE_CONTINUE") | 	viper.BindEnv("disable-continue", "DISABLE_CONTINUE") | ||||||
|  | 	viper.BindEnv("whitelist", "WHITELIST") | ||||||
| 	viper.BindPFlags(rootCmd.Flags()) | 	viper.BindPFlags(rootCmd.Flags()) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -319,6 +319,33 @@ func (api *API) SetupRoutes() { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		email, emailErr := api.Providers.GetUser(providerName.Provider) | ||||||
|  |  | ||||||
|  | 		if emailErr != nil { | ||||||
|  | 			log.Error().Err(emailErr).Msg("Failed to get user") | ||||||
|  | 			c.JSON(500, gin.H{ | ||||||
|  | 				"status":  500, | ||||||
|  | 				"message": "Internal Server Error", | ||||||
|  | 			}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !api.Auth.EmailWhitelisted(email) { | ||||||
|  | 			log.Warn().Str("email", email).Msg("Email not whitelisted") | ||||||
|  | 			unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{ | ||||||
|  | 				Email: email, | ||||||
|  | 			}) | ||||||
|  | 			if unauthorizedQueryErr != nil { | ||||||
|  | 				log.Error().Err(unauthorizedQueryErr).Msg("Failed to build query") | ||||||
|  | 				c.JSON(501, gin.H{ | ||||||
|  | 					"status":  501, | ||||||
|  | 					"message": "Internal Server Error", | ||||||
|  | 				}) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode())) | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		session := sessions.Default(c) | 		session := sessions.Default(c) | ||||||
| 		session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, token)) | 		session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, token)) | ||||||
| 		session.Save() | 		session.Save() | ||||||
| @@ -334,12 +361,12 @@ func (api *API) SetupRoutes() { | |||||||
|  |  | ||||||
| 		c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) | 		c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) | ||||||
|  |  | ||||||
| 		queries, queryErr := query.Values(types.LoginQuery{ | 		redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{ | ||||||
| 			RedirectURI: redirectURI, | 			RedirectURI: redirectURI, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
| 		if queryErr != nil { | 		if redirectQueryErr != nil { | ||||||
| 			log.Error().Err(queryErr).Msg("Failed to build query") | 			log.Error().Err(redirectQueryErr).Msg("Failed to build query") | ||||||
| 			c.JSON(501, gin.H{ | 			c.JSON(501, gin.H{ | ||||||
| 				"status":  501, | 				"status":  501, | ||||||
| 				"message": "Internal Server Error", | 				"message": "Internal Server Error", | ||||||
| @@ -347,7 +374,7 @@ func (api *API) SetupRoutes() { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, queries.Encode())) | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode())) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -6,14 +6,16 @@ import ( | |||||||
| 	"golang.org/x/crypto/bcrypt" | 	"golang.org/x/crypto/bcrypt" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func NewAuth(userList types.Users) *Auth { | func NewAuth(userList types.Users, whitelist []string) *Auth { | ||||||
| 	return &Auth{ | 	return &Auth{ | ||||||
| 		Users:     userList, | 		Users:     userList, | ||||||
|  | 		Whitelist: whitelist, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| type Auth struct { | type Auth struct { | ||||||
| 	Users     types.Users | 	Users     types.Users | ||||||
|  | 	Whitelist []string | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) GetUser(email string) *types.User { | func (auth *Auth) GetUser(email string) *types.User { | ||||||
| @@ -29,3 +31,15 @@ func (auth *Auth) CheckPassword(user types.User, password string) bool { | |||||||
| 	hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) | 	hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) | ||||||
| 	return hashedPasswordErr == nil | 	return hashedPasswordErr == nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (auth *Auth) EmailWhitelisted(emailSrc string) bool { | ||||||
|  | 	if len(auth.Whitelist) == 0 { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	for _, email := range auth.Whitelist { | ||||||
|  | 		if email == emailSrc { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|   | |||||||
| @@ -105,6 +105,17 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) { | |||||||
| 		}, nil | 		}, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if !hooks.Auth.EmailWhitelisted(email) { | ||||||
|  | 		session.Delete("tinyauth_sid") | ||||||
|  | 		session.Save() | ||||||
|  | 		return types.UserContext{ | ||||||
|  | 			Email:      "", | ||||||
|  | 			IsLoggedIn: false, | ||||||
|  | 			OAuth:      false, | ||||||
|  | 			Provider:   "", | ||||||
|  | 		}, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return types.UserContext{ | 	return types.UserContext{ | ||||||
| 		Email:      email, | 		Email:      email, | ||||||
| 		IsLoggedIn: true, | 		IsLoggedIn: true, | ||||||
|   | |||||||
| @@ -37,6 +37,7 @@ type Config struct { | |||||||
| 	GenericTokenURL     string `mapstructure:"generic-token-url"` | 	GenericTokenURL     string `mapstructure:"generic-token-url"` | ||||||
| 	GenericUserURL      string `mapstructure:"generic-user-info-url"` | 	GenericUserURL      string `mapstructure:"generic-user-info-url"` | ||||||
| 	DisableContinue     bool   `mapstructure:"disable-continue"` | 	DisableContinue     bool   `mapstructure:"disable-continue"` | ||||||
|  | 	Whitelist           string `mapstructure:"whitelist"` | ||||||
| } | } | ||||||
|  |  | ||||||
| type UserContext struct { | type UserContext struct { | ||||||
| @@ -78,3 +79,7 @@ type OAuthProviders struct { | |||||||
| 	Google    *oauth.OAuth | 	Google    *oauth.OAuth | ||||||
| 	Microsoft *oauth.OAuth | 	Microsoft *oauth.OAuth | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type UnauthorizedQuery struct { | ||||||
|  | 	Email string `url:"email"` | ||||||
|  | } | ||||||
|   | |||||||
| @@ -74,3 +74,10 @@ func ParseFileToLine(content string) string { | |||||||
|  |  | ||||||
| 	return strings.Join(users, ",") | 	return strings.Join(users, ",") | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func ParseWhitelist(whitelist string) []string { | ||||||
|  | 	if whitelist == "" { | ||||||
|  | 		return []string{} | ||||||
|  | 	} | ||||||
|  | 	return strings.Split(whitelist, ",") | ||||||
|  | } | ||||||
|   | |||||||
| @@ -13,6 +13,7 @@ import { LoginPage } from "./pages/login-page.tsx"; | |||||||
| import { LogoutPage } from "./pages/logout-page.tsx"; | import { LogoutPage } from "./pages/logout-page.tsx"; | ||||||
| import { ContinuePage } from "./pages/continue-page.tsx"; | import { ContinuePage } from "./pages/continue-page.tsx"; | ||||||
| import { NotFoundPage } from "./pages/not-found-page.tsx"; | import { NotFoundPage } from "./pages/not-found-page.tsx"; | ||||||
|  | import { UnauthorizedPage } from "./pages/unauthorized-page.tsx"; | ||||||
|  |  | ||||||
| const queryClient = new QueryClient({ | const queryClient = new QueryClient({ | ||||||
|   defaultOptions: { |   defaultOptions: { | ||||||
| @@ -34,6 +35,7 @@ createRoot(document.getElementById("root")!).render( | |||||||
|               <Route path="/login" element={<LoginPage />} /> |               <Route path="/login" element={<LoginPage />} /> | ||||||
|               <Route path="/logout" element={<LogoutPage />} /> |               <Route path="/logout" element={<LogoutPage />} /> | ||||||
|               <Route path="/continue" element={<ContinuePage />} /> |               <Route path="/continue" element={<ContinuePage />} /> | ||||||
|  |               <Route path="/unauthorized" element={<UnauthorizedPage />} /> | ||||||
|               <Route path="*" element={<NotFoundPage />} /> |               <Route path="*" element={<NotFoundPage />} /> | ||||||
|             </Routes> |             </Routes> | ||||||
|           </BrowserRouter> |           </BrowserRouter> | ||||||
|   | |||||||
							
								
								
									
										41
									
								
								site/src/pages/unauthorized-page.tsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								site/src/pages/unauthorized-page.tsx
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | |||||||
|  | import { Button, Code, Paper, Text } from "@mantine/core"; | ||||||
|  | import { Layout } from "../components/layouts/layout"; | ||||||
|  | import { useUserContext } from "../context/user-context"; | ||||||
|  | import { Navigate } from "react-router"; | ||||||
|  |  | ||||||
|  | export const UnauthorizedPage = () => { | ||||||
|  |   const queryString = window.location.search; | ||||||
|  |   const params = new URLSearchParams(queryString); | ||||||
|  |   const email = params.get("email"); | ||||||
|  |  | ||||||
|  |   const { isLoggedIn } = useUserContext(); | ||||||
|  |  | ||||||
|  |   if (isLoggedIn) { | ||||||
|  |     return <Navigate to="/" />; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (email === "null") { | ||||||
|  |     return <Navigate to="/" />; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   return ( | ||||||
|  |     <Layout> | ||||||
|  |       <Paper shadow="md" p={30} mt={30} radius="md" withBorder> | ||||||
|  |         <Text size="xl" fw={700}> | ||||||
|  |           Unauthorized | ||||||
|  |         </Text> | ||||||
|  |         <Text> | ||||||
|  |           The user with email address <Code>{email}</Code> is not authorized to | ||||||
|  |           login. | ||||||
|  |         </Text> | ||||||
|  |         <Button | ||||||
|  |           fullWidth | ||||||
|  |           mt="xl" | ||||||
|  |           onClick={() => window.location.replace("/login")} | ||||||
|  |         > | ||||||
|  |           Try again | ||||||
|  |         </Button> | ||||||
|  |       </Paper> | ||||||
|  |     </Layout> | ||||||
|  |   ); | ||||||
|  | }; | ||||||
		Reference in New Issue
	
	Block a user
	 Stavros
					Stavros