From 3b50d9303b6915e183a45d64ec57c58cffa0d716 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 26 Jan 2025 19:20:34 +0200 Subject: [PATCH] refactor: use cookie store correctly --- cmd/root.go | 2 +- internal/api/api.go | 40 +++++------------- internal/auth/auth.go | 36 ++++++++++++++++ internal/hooks/hooks.go | 94 +++++++++++++---------------------------- internal/types/types.go | 5 +++ 5 files changed, 83 insertions(+), 94 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 411fa1d..cd74805 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -160,7 +160,7 @@ func init() { viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL") viper.BindEnv("generic-user-url", "GENERIC_USER_URL") viper.BindEnv("disable-continue", "DISABLE_CONTINUE") - viper.BindEnv("oauth-whitelist", "WHITELIST") + viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST") viper.BindEnv("cookie-expiry", "COOKIE_EXPIRY") viper.BindPFlags(rootCmd.Flags()) } diff --git a/internal/api/api.go b/internal/api/api.go index b10a010..2016732 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -90,16 +90,7 @@ func (api *API) Init() { func (api *API) SetupRoutes() { api.Router.GET("/api/auth", func(c *gin.Context) { - userContext, userContextErr := api.Hooks.UseUserContext(c) - - if userContextErr != nil { - log.Error().Err(userContextErr).Msg("Failed to get user context") - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } + userContext := api.Hooks.UseUserContext(c) if userContext.IsLoggedIn { c.JSON(200, gin.H{ @@ -160,9 +151,10 @@ func (api *API) SetupRoutes() { return } - session := sessions.Default(c) - session.Set("tinyauth_sid", fmt.Sprintf("username:%s", login.Username)) - session.Save() + api.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: login.Username, + Provider: "username", + }) c.JSON(200, gin.H{ "status": 200, @@ -171,9 +163,7 @@ func (api *API) SetupRoutes() { }) api.Router.POST("/api/logout", func(c *gin.Context) { - session := sessions.Default(c) - session.Delete("tinyauth_sid") - session.Save() + api.Auth.DeleteSessionCookie(c) c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) @@ -184,16 +174,7 @@ func (api *API) SetupRoutes() { }) api.Router.GET("/api/status", func(c *gin.Context) { - userContext, userContextErr := api.Hooks.UseUserContext(c) - - if userContextErr != nil { - log.Error().Err(userContextErr).Msg("Failed to get user context") - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } + userContext := api.Hooks.UseUserContext(c) if !userContext.IsLoggedIn { c.JSON(200, gin.H{ @@ -314,9 +295,10 @@ func (api *API) SetupRoutes() { c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode())) } - session := sessions.Default(c) - session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, email)) - session.Save() + api.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: email, + Provider: providerName.Provider, + }) redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 5cfa784..b32b355 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -3,6 +3,9 @@ package auth import ( "tinyauth/internal/types" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" "golang.org/x/crypto/bcrypt" ) @@ -43,3 +46,36 @@ func (auth *Auth) EmailWhitelisted(emailSrc string) bool { } return false } + +func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) { + sessions := sessions.Default(c) + sessions.Set("username", data.Username) + sessions.Set("provider", data.Provider) + sessions.Save() +} + +func (auth *Auth) DeleteSessionCookie(c *gin.Context) { + sessions := sessions.Default(c) + sessions.Clear() + sessions.Save() +} + +func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { + sessions := sessions.Default(c) + + cookieUsername := sessions.Get("username") + cookieProvider := sessions.Get("provider") + + username, usernameOk := cookieUsername.(string) + provider, providerOk := cookieProvider.(string) + + if !usernameOk || !providerOk { + log.Warn().Msg("Session cookie invalid") + return types.SessionCookie{}, nil + } + + return types.SessionCookie{ + Username: username, + Provider: provider, + }, nil +} diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index 7dc4e02..3d1ce4c 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -1,13 +1,12 @@ package hooks import ( - "strings" "tinyauth/internal/auth" "tinyauth/internal/providers" "tinyauth/internal/types" - "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" ) func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks { @@ -22,88 +21,55 @@ type Hooks struct { Providers *providers.Providers } -func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) { - session := sessions.Default(c) - sessionCookie := session.Get("tinyauth_sid") +func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { + cookie, cookiErr := hooks.Auth.GetSessionCookie(c) - if sessionCookie == nil { + if cookiErr != nil { + log.Error().Err(cookiErr).Msg("Failed to get session cookie") return types.UserContext{ Username: "", IsLoggedIn: false, OAuth: false, Provider: "", - }, nil + } } - data, dataOk := sessionCookie.(string) - - if !dataOk { - return types.UserContext{ - Username: "", - IsLoggedIn: false, - OAuth: false, - Provider: "", - }, nil + if cookie.Provider == "username" { + if hooks.Auth.GetUser(cookie.Username) != nil { + return types.UserContext{ + Username: cookie.Username, + IsLoggedIn: true, + OAuth: false, + Provider: "", + } + } } - split := strings.Split(data, ":") + provider := hooks.Providers.GetProvider(cookie.Provider) - if len(split) != 2 { - return types.UserContext{ - Username: "", - IsLoggedIn: false, - OAuth: false, - Provider: "", - }, nil - } - - sessionType := split[0] - sessionValue := split[1] - - if sessionType == "username" { - user := hooks.Auth.GetUser(sessionValue) - if user == nil { + if provider != nil { + if !hooks.Auth.EmailWhitelisted(cookie.Username) { + log.Error().Msgf("Email %s not whitelisted", cookie.Username) + hooks.Auth.DeleteSessionCookie(c) return types.UserContext{ Username: "", IsLoggedIn: false, OAuth: false, Provider: "", - }, nil + } } return types.UserContext{ - Username: sessionValue, + Username: cookie.Username, IsLoggedIn: true, - OAuth: false, - Provider: "", - }, nil - } - - provider := hooks.Providers.GetProvider(sessionType) - - if provider == nil { - return types.UserContext{ - Username: "", - IsLoggedIn: false, - OAuth: false, - Provider: "", - }, nil - } - - if !hooks.Auth.EmailWhitelisted(sessionValue) { - session.Delete("tinyauth_sid") - session.Save() - return types.UserContext{ - Username: "", - IsLoggedIn: false, - OAuth: false, - Provider: "", - }, nil + OAuth: true, + Provider: cookie.Provider, + } } return types.UserContext{ - Username: sessionValue, - IsLoggedIn: true, - OAuth: true, - Provider: sessionType, - }, nil + Username: "", + IsLoggedIn: false, + OAuth: false, + Provider: "", + } } diff --git a/internal/types/types.go b/internal/types/types.go index b85adc9..4b98a30 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -85,3 +85,8 @@ type OAuthProviders struct { type UnauthorizedQuery struct { Username string `url:"username"` } + +type SessionCookie struct { + Username string + Provider string +}