refactor: use cookie store correctly

This commit is contained in:
Stavros
2025-01-26 19:20:34 +02:00
parent d67133aca7
commit 3b50d9303b
5 changed files with 83 additions and 94 deletions

View File

@@ -160,7 +160,7 @@ func init() {
viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL") viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL")
viper.BindEnv("generic-user-url", "GENERIC_USER_URL") viper.BindEnv("generic-user-url", "GENERIC_USER_URL")
viper.BindEnv("disable-continue", "DISABLE_CONTINUE") viper.BindEnv("disable-continue", "DISABLE_CONTINUE")
viper.BindEnv("oauth-whitelist", "WHITELIST") viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST")
viper.BindEnv("cookie-expiry", "COOKIE_EXPIRY") viper.BindEnv("cookie-expiry", "COOKIE_EXPIRY")
viper.BindPFlags(rootCmd.Flags()) viper.BindPFlags(rootCmd.Flags())
} }

View File

@@ -90,16 +90,7 @@ func (api *API) Init() {
func (api *API) SetupRoutes() { func (api *API) SetupRoutes() {
api.Router.GET("/api/auth", func(c *gin.Context) { api.Router.GET("/api/auth", func(c *gin.Context) {
userContext, userContextErr := api.Hooks.UseUserContext(c) userContext := api.Hooks.UseUserContext(c)
if userContextErr != nil {
log.Error().Err(userContextErr).Msg("Failed to get user context")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
if userContext.IsLoggedIn { if userContext.IsLoggedIn {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
@@ -160,9 +151,10 @@ func (api *API) SetupRoutes() {
return return
} }
session := sessions.Default(c) api.Auth.CreateSessionCookie(c, &types.SessionCookie{
session.Set("tinyauth_sid", fmt.Sprintf("username:%s", login.Username)) Username: login.Username,
session.Save() Provider: "username",
})
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -171,9 +163,7 @@ func (api *API) SetupRoutes() {
}) })
api.Router.POST("/api/logout", func(c *gin.Context) { api.Router.POST("/api/logout", func(c *gin.Context) {
session := sessions.Default(c) api.Auth.DeleteSessionCookie(c)
session.Delete("tinyauth_sid")
session.Save()
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
@@ -184,16 +174,7 @@ func (api *API) SetupRoutes() {
}) })
api.Router.GET("/api/status", func(c *gin.Context) { api.Router.GET("/api/status", func(c *gin.Context) {
userContext, userContextErr := api.Hooks.UseUserContext(c) userContext := api.Hooks.UseUserContext(c)
if userContextErr != nil {
log.Error().Err(userContextErr).Msg("Failed to get user context")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
if !userContext.IsLoggedIn { if !userContext.IsLoggedIn {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
@@ -314,9 +295,10 @@ func (api *API) SetupRoutes() {
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode())) c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode()))
} }
session := sessions.Default(c) api.Auth.CreateSessionCookie(c, &types.SessionCookie{
session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, email)) Username: email,
session.Save() Provider: providerName.Provider,
})
redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri")

View File

@@ -3,6 +3,9 @@ package auth
import ( import (
"tinyauth/internal/types" "tinyauth/internal/types"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@@ -43,3 +46,36 @@ func (auth *Auth) EmailWhitelisted(emailSrc string) bool {
} }
return false return false
} }
func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) {
sessions := sessions.Default(c)
sessions.Set("username", data.Username)
sessions.Set("provider", data.Provider)
sessions.Save()
}
func (auth *Auth) DeleteSessionCookie(c *gin.Context) {
sessions := sessions.Default(c)
sessions.Clear()
sessions.Save()
}
func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) {
sessions := sessions.Default(c)
cookieUsername := sessions.Get("username")
cookieProvider := sessions.Get("provider")
username, usernameOk := cookieUsername.(string)
provider, providerOk := cookieProvider.(string)
if !usernameOk || !providerOk {
log.Warn().Msg("Session cookie invalid")
return types.SessionCookie{}, nil
}
return types.SessionCookie{
Username: username,
Provider: provider,
}, nil
}

View File

@@ -1,13 +1,12 @@
package hooks package hooks
import ( import (
"strings"
"tinyauth/internal/auth" "tinyauth/internal/auth"
"tinyauth/internal/providers" "tinyauth/internal/providers"
"tinyauth/internal/types" "tinyauth/internal/types"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
) )
func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks { func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks {
@@ -22,88 +21,55 @@ type Hooks struct {
Providers *providers.Providers Providers *providers.Providers
} }
func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) { func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
session := sessions.Default(c) cookie, cookiErr := hooks.Auth.GetSessionCookie(c)
sessionCookie := session.Get("tinyauth_sid")
if sessionCookie == nil { if cookiErr != nil {
log.Error().Err(cookiErr).Msg("Failed to get session cookie")
return types.UserContext{ return types.UserContext{
Username: "", Username: "",
IsLoggedIn: false, IsLoggedIn: false,
OAuth: false, OAuth: false,
Provider: "", Provider: "",
}, nil }
} }
data, dataOk := sessionCookie.(string) if cookie.Provider == "username" {
if hooks.Auth.GetUser(cookie.Username) != nil {
if !dataOk { return types.UserContext{
return types.UserContext{ Username: cookie.Username,
Username: "", IsLoggedIn: true,
IsLoggedIn: false, OAuth: false,
OAuth: false, Provider: "",
Provider: "", }
}, nil }
} }
split := strings.Split(data, ":") provider := hooks.Providers.GetProvider(cookie.Provider)
if len(split) != 2 { if provider != nil {
return types.UserContext{ if !hooks.Auth.EmailWhitelisted(cookie.Username) {
Username: "", log.Error().Msgf("Email %s not whitelisted", cookie.Username)
IsLoggedIn: false, hooks.Auth.DeleteSessionCookie(c)
OAuth: false,
Provider: "",
}, nil
}
sessionType := split[0]
sessionValue := split[1]
if sessionType == "username" {
user := hooks.Auth.GetUser(sessionValue)
if user == nil {
return types.UserContext{ return types.UserContext{
Username: "", Username: "",
IsLoggedIn: false, IsLoggedIn: false,
OAuth: false, OAuth: false,
Provider: "", Provider: "",
}, nil }
} }
return types.UserContext{ return types.UserContext{
Username: sessionValue, Username: cookie.Username,
IsLoggedIn: true, IsLoggedIn: true,
OAuth: false, OAuth: true,
Provider: "", Provider: cookie.Provider,
}, nil }
}
provider := hooks.Providers.GetProvider(sessionType)
if provider == nil {
return types.UserContext{
Username: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
}, nil
}
if !hooks.Auth.EmailWhitelisted(sessionValue) {
session.Delete("tinyauth_sid")
session.Save()
return types.UserContext{
Username: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
}, nil
} }
return types.UserContext{ return types.UserContext{
Username: sessionValue, Username: "",
IsLoggedIn: true, IsLoggedIn: false,
OAuth: true, OAuth: false,
Provider: sessionType, Provider: "",
}, nil }
} }

View File

@@ -85,3 +85,8 @@ type OAuthProviders struct {
type UnauthorizedQuery struct { type UnauthorizedQuery struct {
Username string `url:"username"` Username string `url:"username"`
} }
type SessionCookie struct {
Username string
Provider string
}