mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-28 04:35:40 +00:00
feat: oauth email whitelist
This commit is contained in:
@@ -56,6 +56,9 @@ var rootCmd = &cobra.Command{
|
|||||||
users, parseErr := utils.ParseUsers(usersString)
|
users, parseErr := utils.ParseUsers(usersString)
|
||||||
HandleError(parseErr, "Failed to parse users")
|
HandleError(parseErr, "Failed to parse users")
|
||||||
|
|
||||||
|
// Create whitelist
|
||||||
|
whitelist := utils.ParseWhitelist(config.Whitelist)
|
||||||
|
|
||||||
// Create OAuth config
|
// Create OAuth config
|
||||||
oauthConfig := types.OAuthConfig{
|
oauthConfig := types.OAuthConfig{
|
||||||
GithubClientId: config.GithubClientId,
|
GithubClientId: config.GithubClientId,
|
||||||
@@ -72,7 +75,7 @@ var rootCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create auth service
|
// Create auth service
|
||||||
auth := auth.NewAuth(users)
|
auth := auth.NewAuth(users, whitelist)
|
||||||
|
|
||||||
// Create OAuth providers service
|
// Create OAuth providers service
|
||||||
providers := providers.NewProviders(oauthConfig)
|
providers := providers.NewProviders(oauthConfig)
|
||||||
@@ -136,6 +139,7 @@ func init() {
|
|||||||
rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.")
|
rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.")
|
||||||
rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.")
|
rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.")
|
||||||
rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.")
|
rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.")
|
||||||
|
rootCmd.Flags().String("whitelist", "", "Comma separated list of email addresses to whitelist (only for oauth).")
|
||||||
viper.BindEnv("port", "PORT")
|
viper.BindEnv("port", "PORT")
|
||||||
viper.BindEnv("address", "ADDRESS")
|
viper.BindEnv("address", "ADDRESS")
|
||||||
viper.BindEnv("secret", "SECRET")
|
viper.BindEnv("secret", "SECRET")
|
||||||
@@ -154,5 +158,6 @@ func init() {
|
|||||||
viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL")
|
viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL")
|
||||||
viper.BindEnv("generic-user-url", "GENERIC_USER_URL")
|
viper.BindEnv("generic-user-url", "GENERIC_USER_URL")
|
||||||
viper.BindEnv("disable-continue", "DISABLE_CONTINUE")
|
viper.BindEnv("disable-continue", "DISABLE_CONTINUE")
|
||||||
|
viper.BindEnv("whitelist", "WHITELIST")
|
||||||
viper.BindPFlags(rootCmd.Flags())
|
viper.BindPFlags(rootCmd.Flags())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -319,6 +319,33 @@ func (api *API) SetupRoutes() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
email, emailErr := api.Providers.GetUser(providerName.Provider)
|
||||||
|
|
||||||
|
if emailErr != nil {
|
||||||
|
log.Error().Err(emailErr).Msg("Failed to get user")
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": 500,
|
||||||
|
"message": "Internal Server Error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !api.Auth.EmailWhitelisted(email) {
|
||||||
|
log.Warn().Str("email", email).Msg("Email not whitelisted")
|
||||||
|
unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{
|
||||||
|
Email: email,
|
||||||
|
})
|
||||||
|
if unauthorizedQueryErr != nil {
|
||||||
|
log.Error().Err(unauthorizedQueryErr).Msg("Failed to build query")
|
||||||
|
c.JSON(501, gin.H{
|
||||||
|
"status": 501,
|
||||||
|
"message": "Internal Server Error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode()))
|
||||||
|
}
|
||||||
|
|
||||||
session := sessions.Default(c)
|
session := sessions.Default(c)
|
||||||
session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, token))
|
session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, token))
|
||||||
session.Save()
|
session.Save()
|
||||||
@@ -334,12 +361,12 @@ func (api *API) SetupRoutes() {
|
|||||||
|
|
||||||
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
|
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
|
||||||
|
|
||||||
queries, queryErr := query.Values(types.LoginQuery{
|
redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{
|
||||||
RedirectURI: redirectURI,
|
RedirectURI: redirectURI,
|
||||||
})
|
})
|
||||||
|
|
||||||
if queryErr != nil {
|
if redirectQueryErr != nil {
|
||||||
log.Error().Err(queryErr).Msg("Failed to build query")
|
log.Error().Err(redirectQueryErr).Msg("Failed to build query")
|
||||||
c.JSON(501, gin.H{
|
c.JSON(501, gin.H{
|
||||||
"status": 501,
|
"status": 501,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
@@ -347,7 +374,7 @@ func (api *API) SetupRoutes() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, queries.Encode()))
|
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode()))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,14 +6,16 @@ import (
|
|||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewAuth(userList types.Users) *Auth {
|
func NewAuth(userList types.Users, whitelist []string) *Auth {
|
||||||
return &Auth{
|
return &Auth{
|
||||||
Users: userList,
|
Users: userList,
|
||||||
|
Whitelist: whitelist,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Auth struct {
|
type Auth struct {
|
||||||
Users types.Users
|
Users types.Users
|
||||||
|
Whitelist []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Auth) GetUser(email string) *types.User {
|
func (auth *Auth) GetUser(email string) *types.User {
|
||||||
@@ -28,4 +30,16 @@ func (auth *Auth) GetUser(email string) *types.User {
|
|||||||
func (auth *Auth) CheckPassword(user types.User, password string) bool {
|
func (auth *Auth) CheckPassword(user types.User, password string) bool {
|
||||||
hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||||
return hashedPasswordErr == nil
|
return hashedPasswordErr == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *Auth) EmailWhitelisted(emailSrc string) bool {
|
||||||
|
if len(auth.Whitelist) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, email := range auth.Whitelist {
|
||||||
|
if email == emailSrc {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -105,6 +105,17 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !hooks.Auth.EmailWhitelisted(email) {
|
||||||
|
session.Delete("tinyauth_sid")
|
||||||
|
session.Save()
|
||||||
|
return types.UserContext{
|
||||||
|
Email: "",
|
||||||
|
IsLoggedIn: false,
|
||||||
|
OAuth: false,
|
||||||
|
Provider: "",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
return types.UserContext{
|
return types.UserContext{
|
||||||
Email: email,
|
Email: email,
|
||||||
IsLoggedIn: true,
|
IsLoggedIn: true,
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ type Config struct {
|
|||||||
GenericTokenURL string `mapstructure:"generic-token-url"`
|
GenericTokenURL string `mapstructure:"generic-token-url"`
|
||||||
GenericUserURL string `mapstructure:"generic-user-info-url"`
|
GenericUserURL string `mapstructure:"generic-user-info-url"`
|
||||||
DisableContinue bool `mapstructure:"disable-continue"`
|
DisableContinue bool `mapstructure:"disable-continue"`
|
||||||
|
Whitelist string `mapstructure:"whitelist"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserContext struct {
|
type UserContext struct {
|
||||||
@@ -78,3 +79,7 @@ type OAuthProviders struct {
|
|||||||
Google *oauth.OAuth
|
Google *oauth.OAuth
|
||||||
Microsoft *oauth.OAuth
|
Microsoft *oauth.OAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UnauthorizedQuery struct {
|
||||||
|
Email string `url:"email"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -74,3 +74,10 @@ func ParseFileToLine(content string) string {
|
|||||||
|
|
||||||
return strings.Join(users, ",")
|
return strings.Join(users, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ParseWhitelist(whitelist string) []string {
|
||||||
|
if whitelist == "" {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
return strings.Split(whitelist, ",")
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import { LoginPage } from "./pages/login-page.tsx";
|
|||||||
import { LogoutPage } from "./pages/logout-page.tsx";
|
import { LogoutPage } from "./pages/logout-page.tsx";
|
||||||
import { ContinuePage } from "./pages/continue-page.tsx";
|
import { ContinuePage } from "./pages/continue-page.tsx";
|
||||||
import { NotFoundPage } from "./pages/not-found-page.tsx";
|
import { NotFoundPage } from "./pages/not-found-page.tsx";
|
||||||
|
import { UnauthorizedPage } from "./pages/unauthorized-page.tsx";
|
||||||
|
|
||||||
const queryClient = new QueryClient({
|
const queryClient = new QueryClient({
|
||||||
defaultOptions: {
|
defaultOptions: {
|
||||||
@@ -34,6 +35,7 @@ createRoot(document.getElementById("root")!).render(
|
|||||||
<Route path="/login" element={<LoginPage />} />
|
<Route path="/login" element={<LoginPage />} />
|
||||||
<Route path="/logout" element={<LogoutPage />} />
|
<Route path="/logout" element={<LogoutPage />} />
|
||||||
<Route path="/continue" element={<ContinuePage />} />
|
<Route path="/continue" element={<ContinuePage />} />
|
||||||
|
<Route path="/unauthorized" element={<UnauthorizedPage />} />
|
||||||
<Route path="*" element={<NotFoundPage />} />
|
<Route path="*" element={<NotFoundPage />} />
|
||||||
</Routes>
|
</Routes>
|
||||||
</BrowserRouter>
|
</BrowserRouter>
|
||||||
|
|||||||
41
site/src/pages/unauthorized-page.tsx
Normal file
41
site/src/pages/unauthorized-page.tsx
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import { Button, Code, Paper, Text } from "@mantine/core";
|
||||||
|
import { Layout } from "../components/layouts/layout";
|
||||||
|
import { useUserContext } from "../context/user-context";
|
||||||
|
import { Navigate } from "react-router";
|
||||||
|
|
||||||
|
export const UnauthorizedPage = () => {
|
||||||
|
const queryString = window.location.search;
|
||||||
|
const params = new URLSearchParams(queryString);
|
||||||
|
const email = params.get("email");
|
||||||
|
|
||||||
|
const { isLoggedIn } = useUserContext();
|
||||||
|
|
||||||
|
if (isLoggedIn) {
|
||||||
|
return <Navigate to="/" />;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (email === "null") {
|
||||||
|
return <Navigate to="/" />;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Layout>
|
||||||
|
<Paper shadow="md" p={30} mt={30} radius="md" withBorder>
|
||||||
|
<Text size="xl" fw={700}>
|
||||||
|
Unauthorized
|
||||||
|
</Text>
|
||||||
|
<Text>
|
||||||
|
The user with email address <Code>{email}</Code> is not authorized to
|
||||||
|
login.
|
||||||
|
</Text>
|
||||||
|
<Button
|
||||||
|
fullWidth
|
||||||
|
mt="xl"
|
||||||
|
onClick={() => window.location.replace("/login")}
|
||||||
|
>
|
||||||
|
Try again
|
||||||
|
</Button>
|
||||||
|
</Paper>
|
||||||
|
</Layout>
|
||||||
|
);
|
||||||
|
};
|
||||||
Reference in New Issue
Block a user