feat: oauth email whitelist

This commit is contained in:
Stavros
2025-01-24 20:17:08 +02:00
parent b87cb54d91
commit c5a8639822
8 changed files with 121 additions and 9 deletions

View File

@@ -56,6 +56,9 @@ var rootCmd = &cobra.Command{
users, parseErr := utils.ParseUsers(usersString) users, parseErr := utils.ParseUsers(usersString)
HandleError(parseErr, "Failed to parse users") HandleError(parseErr, "Failed to parse users")
// Create whitelist
whitelist := utils.ParseWhitelist(config.Whitelist)
// Create OAuth config // Create OAuth config
oauthConfig := types.OAuthConfig{ oauthConfig := types.OAuthConfig{
GithubClientId: config.GithubClientId, GithubClientId: config.GithubClientId,
@@ -72,7 +75,7 @@ var rootCmd = &cobra.Command{
} }
// Create auth service // Create auth service
auth := auth.NewAuth(users) auth := auth.NewAuth(users, whitelist)
// Create OAuth providers service // Create OAuth providers service
providers := providers.NewProviders(oauthConfig) providers := providers.NewProviders(oauthConfig)
@@ -136,6 +139,7 @@ func init() {
rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.")
rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.")
rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.")
rootCmd.Flags().String("whitelist", "", "Comma separated list of email addresses to whitelist (only for oauth).")
viper.BindEnv("port", "PORT") viper.BindEnv("port", "PORT")
viper.BindEnv("address", "ADDRESS") viper.BindEnv("address", "ADDRESS")
viper.BindEnv("secret", "SECRET") viper.BindEnv("secret", "SECRET")
@@ -154,5 +158,6 @@ func init() {
viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL") viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL")
viper.BindEnv("generic-user-url", "GENERIC_USER_URL") viper.BindEnv("generic-user-url", "GENERIC_USER_URL")
viper.BindEnv("disable-continue", "DISABLE_CONTINUE") viper.BindEnv("disable-continue", "DISABLE_CONTINUE")
viper.BindEnv("whitelist", "WHITELIST")
viper.BindPFlags(rootCmd.Flags()) viper.BindPFlags(rootCmd.Flags())
} }

View File

@@ -319,6 +319,33 @@ func (api *API) SetupRoutes() {
return return
} }
email, emailErr := api.Providers.GetUser(providerName.Provider)
if emailErr != nil {
log.Error().Err(emailErr).Msg("Failed to get user")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
if !api.Auth.EmailWhitelisted(email) {
log.Warn().Str("email", email).Msg("Email not whitelisted")
unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{
Email: email,
})
if unauthorizedQueryErr != nil {
log.Error().Err(unauthorizedQueryErr).Msg("Failed to build query")
c.JSON(501, gin.H{
"status": 501,
"message": "Internal Server Error",
})
return
}
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode()))
}
session := sessions.Default(c) session := sessions.Default(c)
session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, token)) session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, token))
session.Save() session.Save()
@@ -334,12 +361,12 @@ func (api *API) SetupRoutes() {
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
queries, queryErr := query.Values(types.LoginQuery{ redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{
RedirectURI: redirectURI, RedirectURI: redirectURI,
}) })
if queryErr != nil { if redirectQueryErr != nil {
log.Error().Err(queryErr).Msg("Failed to build query") log.Error().Err(redirectQueryErr).Msg("Failed to build query")
c.JSON(501, gin.H{ c.JSON(501, gin.H{
"status": 501, "status": 501,
"message": "Internal Server Error", "message": "Internal Server Error",
@@ -347,7 +374,7 @@ func (api *API) SetupRoutes() {
return return
} }
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, queries.Encode())) c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode()))
}) })
} }

View File

@@ -6,14 +6,16 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
func NewAuth(userList types.Users) *Auth { func NewAuth(userList types.Users, whitelist []string) *Auth {
return &Auth{ return &Auth{
Users: userList, Users: userList,
Whitelist: whitelist,
} }
} }
type Auth struct { type Auth struct {
Users types.Users Users types.Users
Whitelist []string
} }
func (auth *Auth) GetUser(email string) *types.User { func (auth *Auth) GetUser(email string) *types.User {
@@ -28,4 +30,16 @@ func (auth *Auth) GetUser(email string) *types.User {
func (auth *Auth) CheckPassword(user types.User, password string) bool { func (auth *Auth) CheckPassword(user types.User, password string) bool {
hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
return hashedPasswordErr == nil return hashedPasswordErr == nil
} }
func (auth *Auth) EmailWhitelisted(emailSrc string) bool {
if len(auth.Whitelist) == 0 {
return true
}
for _, email := range auth.Whitelist {
if email == emailSrc {
return true
}
}
return false
}

View File

@@ -105,6 +105,17 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
}, nil }, nil
} }
if !hooks.Auth.EmailWhitelisted(email) {
session.Delete("tinyauth_sid")
session.Save()
return types.UserContext{
Email: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
}, nil
}
return types.UserContext{ return types.UserContext{
Email: email, Email: email,
IsLoggedIn: true, IsLoggedIn: true,

View File

@@ -37,6 +37,7 @@ type Config struct {
GenericTokenURL string `mapstructure:"generic-token-url"` GenericTokenURL string `mapstructure:"generic-token-url"`
GenericUserURL string `mapstructure:"generic-user-info-url"` GenericUserURL string `mapstructure:"generic-user-info-url"`
DisableContinue bool `mapstructure:"disable-continue"` DisableContinue bool `mapstructure:"disable-continue"`
Whitelist string `mapstructure:"whitelist"`
} }
type UserContext struct { type UserContext struct {
@@ -78,3 +79,7 @@ type OAuthProviders struct {
Google *oauth.OAuth Google *oauth.OAuth
Microsoft *oauth.OAuth Microsoft *oauth.OAuth
} }
type UnauthorizedQuery struct {
Email string `url:"email"`
}

View File

@@ -74,3 +74,10 @@ func ParseFileToLine(content string) string {
return strings.Join(users, ",") return strings.Join(users, ",")
} }
func ParseWhitelist(whitelist string) []string {
if whitelist == "" {
return []string{}
}
return strings.Split(whitelist, ",")
}

View File

@@ -13,6 +13,7 @@ import { LoginPage } from "./pages/login-page.tsx";
import { LogoutPage } from "./pages/logout-page.tsx"; import { LogoutPage } from "./pages/logout-page.tsx";
import { ContinuePage } from "./pages/continue-page.tsx"; import { ContinuePage } from "./pages/continue-page.tsx";
import { NotFoundPage } from "./pages/not-found-page.tsx"; import { NotFoundPage } from "./pages/not-found-page.tsx";
import { UnauthorizedPage } from "./pages/unauthorized-page.tsx";
const queryClient = new QueryClient({ const queryClient = new QueryClient({
defaultOptions: { defaultOptions: {
@@ -34,6 +35,7 @@ createRoot(document.getElementById("root")!).render(
<Route path="/login" element={<LoginPage />} /> <Route path="/login" element={<LoginPage />} />
<Route path="/logout" element={<LogoutPage />} /> <Route path="/logout" element={<LogoutPage />} />
<Route path="/continue" element={<ContinuePage />} /> <Route path="/continue" element={<ContinuePage />} />
<Route path="/unauthorized" element={<UnauthorizedPage />} />
<Route path="*" element={<NotFoundPage />} /> <Route path="*" element={<NotFoundPage />} />
</Routes> </Routes>
</BrowserRouter> </BrowserRouter>

View File

@@ -0,0 +1,41 @@
import { Button, Code, Paper, Text } from "@mantine/core";
import { Layout } from "../components/layouts/layout";
import { useUserContext } from "../context/user-context";
import { Navigate } from "react-router";
export const UnauthorizedPage = () => {
const queryString = window.location.search;
const params = new URLSearchParams(queryString);
const email = params.get("email");
const { isLoggedIn } = useUserContext();
if (isLoggedIn) {
return <Navigate to="/" />;
}
if (email === "null") {
return <Navigate to="/" />;
}
return (
<Layout>
<Paper shadow="md" p={30} mt={30} radius="md" withBorder>
<Text size="xl" fw={700}>
Unauthorized
</Text>
<Text>
The user with email address <Code>{email}</Code> is not authorized to
login.
</Text>
<Button
fullWidth
mt="xl"
onClick={() => window.location.replace("/login")}
>
Try again
</Button>
</Paper>
</Layout>
);
};