From c5a86398228c42a8f7e4564e2ef41e8031fb095c Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 24 Jan 2025 20:17:08 +0200 Subject: [PATCH] feat: oauth email whitelist --- cmd/root.go | 7 ++++- internal/api/api.go | 35 +++++++++++++++++++++--- internal/auth/auth.go | 22 ++++++++++++--- internal/hooks/hooks.go | 11 ++++++++ internal/types/types.go | 5 ++++ internal/utils/utils.go | 7 +++++ site/src/main.tsx | 2 ++ site/src/pages/unauthorized-page.tsx | 41 ++++++++++++++++++++++++++++ 8 files changed, 121 insertions(+), 9 deletions(-) create mode 100644 site/src/pages/unauthorized-page.tsx diff --git a/cmd/root.go b/cmd/root.go index 9d0a0ae..a16a089 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -56,6 +56,9 @@ var rootCmd = &cobra.Command{ users, parseErr := utils.ParseUsers(usersString) HandleError(parseErr, "Failed to parse users") + // Create whitelist + whitelist := utils.ParseWhitelist(config.Whitelist) + // Create OAuth config oauthConfig := types.OAuthConfig{ GithubClientId: config.GithubClientId, @@ -72,7 +75,7 @@ var rootCmd = &cobra.Command{ } // Create auth service - auth := auth.NewAuth(users) + auth := auth.NewAuth(users, whitelist) // Create OAuth providers service providers := providers.NewProviders(oauthConfig) @@ -136,6 +139,7 @@ func init() { rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") + rootCmd.Flags().String("whitelist", "", "Comma separated list of email addresses to whitelist (only for oauth).") viper.BindEnv("port", "PORT") viper.BindEnv("address", "ADDRESS") viper.BindEnv("secret", "SECRET") @@ -154,5 +158,6 @@ func init() { viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL") viper.BindEnv("generic-user-url", "GENERIC_USER_URL") viper.BindEnv("disable-continue", "DISABLE_CONTINUE") + viper.BindEnv("whitelist", "WHITELIST") viper.BindPFlags(rootCmd.Flags()) } diff --git a/internal/api/api.go b/internal/api/api.go index 93b4d17..6229410 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -319,6 +319,33 @@ func (api *API) SetupRoutes() { return } + email, emailErr := api.Providers.GetUser(providerName.Provider) + + if emailErr != nil { + log.Error().Err(emailErr).Msg("Failed to get user") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + if !api.Auth.EmailWhitelisted(email) { + log.Warn().Str("email", email).Msg("Email not whitelisted") + unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{ + Email: email, + }) + if unauthorizedQueryErr != nil { + log.Error().Err(unauthorizedQueryErr).Msg("Failed to build query") + c.JSON(501, gin.H{ + "status": 501, + "message": "Internal Server Error", + }) + return + } + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode())) + } + session := sessions.Default(c) session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, token)) session.Save() @@ -334,12 +361,12 @@ func (api *API) SetupRoutes() { c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) - queries, queryErr := query.Values(types.LoginQuery{ + redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{ RedirectURI: redirectURI, }) - if queryErr != nil { - log.Error().Err(queryErr).Msg("Failed to build query") + if redirectQueryErr != nil { + log.Error().Err(redirectQueryErr).Msg("Failed to build query") c.JSON(501, gin.H{ "status": 501, "message": "Internal Server Error", @@ -347,7 +374,7 @@ func (api *API) SetupRoutes() { return } - c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, queries.Encode())) + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode())) }) } diff --git a/internal/auth/auth.go b/internal/auth/auth.go index a82f737..d7bf1ea 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -6,14 +6,16 @@ import ( "golang.org/x/crypto/bcrypt" ) -func NewAuth(userList types.Users) *Auth { +func NewAuth(userList types.Users, whitelist []string) *Auth { return &Auth{ - Users: userList, + Users: userList, + Whitelist: whitelist, } } type Auth struct { - Users types.Users + Users types.Users + Whitelist []string } func (auth *Auth) GetUser(email string) *types.User { @@ -28,4 +30,16 @@ func (auth *Auth) GetUser(email string) *types.User { func (auth *Auth) CheckPassword(user types.User, password string) bool { hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) return hashedPasswordErr == nil -} \ No newline at end of file +} + +func (auth *Auth) EmailWhitelisted(emailSrc string) bool { + if len(auth.Whitelist) == 0 { + return true + } + for _, email := range auth.Whitelist { + if email == emailSrc { + return true + } + } + return false +} diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index bfb1ac3..0fdff3b 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -105,6 +105,17 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) { }, nil } + if !hooks.Auth.EmailWhitelisted(email) { + session.Delete("tinyauth_sid") + session.Save() + return types.UserContext{ + Email: "", + IsLoggedIn: false, + OAuth: false, + Provider: "", + }, nil + } + return types.UserContext{ Email: email, IsLoggedIn: true, diff --git a/internal/types/types.go b/internal/types/types.go index d7ef1bf..f82b19c 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -37,6 +37,7 @@ type Config struct { GenericTokenURL string `mapstructure:"generic-token-url"` GenericUserURL string `mapstructure:"generic-user-info-url"` DisableContinue bool `mapstructure:"disable-continue"` + Whitelist string `mapstructure:"whitelist"` } type UserContext struct { @@ -78,3 +79,7 @@ type OAuthProviders struct { Google *oauth.OAuth Microsoft *oauth.OAuth } + +type UnauthorizedQuery struct { + Email string `url:"email"` +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 3e4a79c..1eaf38a 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -74,3 +74,10 @@ func ParseFileToLine(content string) string { return strings.Join(users, ",") } + +func ParseWhitelist(whitelist string) []string { + if whitelist == "" { + return []string{} + } + return strings.Split(whitelist, ",") +} diff --git a/site/src/main.tsx b/site/src/main.tsx index c0964a3..2cd216a 100644 --- a/site/src/main.tsx +++ b/site/src/main.tsx @@ -13,6 +13,7 @@ import { LoginPage } from "./pages/login-page.tsx"; import { LogoutPage } from "./pages/logout-page.tsx"; import { ContinuePage } from "./pages/continue-page.tsx"; import { NotFoundPage } from "./pages/not-found-page.tsx"; +import { UnauthorizedPage } from "./pages/unauthorized-page.tsx"; const queryClient = new QueryClient({ defaultOptions: { @@ -34,6 +35,7 @@ createRoot(document.getElementById("root")!).render( } /> } /> } /> + } /> } /> diff --git a/site/src/pages/unauthorized-page.tsx b/site/src/pages/unauthorized-page.tsx new file mode 100644 index 0000000..5648191 --- /dev/null +++ b/site/src/pages/unauthorized-page.tsx @@ -0,0 +1,41 @@ +import { Button, Code, Paper, Text } from "@mantine/core"; +import { Layout } from "../components/layouts/layout"; +import { useUserContext } from "../context/user-context"; +import { Navigate } from "react-router"; + +export const UnauthorizedPage = () => { + const queryString = window.location.search; + const params = new URLSearchParams(queryString); + const email = params.get("email"); + + const { isLoggedIn } = useUserContext(); + + if (isLoggedIn) { + return ; + } + + if (email === "null") { + return ; + } + + return ( + + + + Unauthorized + + + The user with email address {email} is not authorized to + login. + + + + + ); +};