From f25ab72747f99258951859b6295d6a3d0b3e1dd3 Mon Sep 17 00:00:00 2001 From: Stavros Date: Tue, 15 Jul 2025 02:10:16 +0300 Subject: [PATCH] refactor: check cookie prior to basiv auth in context hook --- internal/auth/auth.go | 14 ++-- internal/hooks/hooks.go | 150 +++++++++++++++++++--------------------- 2 files changed, 80 insertions(+), 84 deletions(-) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 36e7db0..9f6f2c1 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -50,7 +50,7 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { // If there was an error getting the session, it might be invalid so let's clear it and retry if err != nil { - log.Warn().Err(err).Msg("Invalid session, clearing cookie and retrying") + log.Error().Err(err).Msg("Invalid session, clearing cookie and retrying") c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true) session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName) if err != nil { @@ -79,7 +79,7 @@ func (auth *Auth) SearchUser(username string) types.UserSearch { log.Debug().Str("username", username).Msg("Checking LDAP for user") userDN, err := auth.LDAP.Search(username) if err != nil { - log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP") + log.Error().Err(err).Str("username", username).Msg("Failed to find user in LDAP") return types.UserSearch{} } return types.UserSearch{ @@ -107,7 +107,7 @@ func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { err := auth.LDAP.Bind(search.Username, password) if err != nil { - log.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP") + log.Error().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP") return false } @@ -372,7 +372,7 @@ func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) { // If there is an error, invalid regex, auth enabled if err != nil { - log.Warn().Err(err).Msg("Invalid regex") + log.Error().Err(err).Msg("Invalid regex") return true, err } @@ -401,7 +401,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { for _, blocked := range labels.IP.Block { res, err := utils.FilterIP(blocked, ip) if err != nil { - log.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") + log.Error().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") continue } if res { @@ -414,7 +414,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { for _, allowed := range labels.IP.Allow { res, err := utils.FilterIP(allowed, ip) if err != nil { - log.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") + log.Error().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") continue } if res { @@ -438,7 +438,7 @@ func (auth *Auth) BypassedIP(labels types.Labels, ip string) bool { for _, bypassed := range labels.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) if err != nil { - log.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") + log.Error().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") continue } if res { diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index c57b338..3083b98 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" "tinyauth/internal/auth" + "tinyauth/internal/oauth" "tinyauth/internal/providers" "tinyauth/internal/types" "tinyauth/internal/utils" @@ -27,28 +28,92 @@ func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Pr } func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { - // Get session cookie and basic auth cookie, err := hooks.Auth.GetSessionCookie(c) + var provider *oauth.OAuth + + if err != nil { + log.Error().Err(err).Msg("Failed to get session cookie") + goto basic + } + + if cookie.TotpPending { + log.Debug().Msg("Totp pending") + return types.UserContext{ + Username: cookie.Username, + Name: cookie.Name, + Email: cookie.Email, + Provider: cookie.Provider, + TotpPending: true, + } + } + + if cookie.Provider == "username" { + log.Debug().Msg("Provider is username") + + userSearch := hooks.Auth.SearchUser(cookie.Username) + + if userSearch.Type == "unknown" { + log.Warn().Str("username", cookie.Username).Msg("User does not exist") + goto basic + } + + log.Debug().Str("type", userSearch.Type).Msg("User exists") + + return types.UserContext{ + Username: cookie.Username, + Name: cookie.Name, + Email: cookie.Email, + IsLoggedIn: true, + Provider: "username", + } + } + + log.Debug().Msg("Provider is not username") + + provider = hooks.Providers.GetProvider(cookie.Provider) + + if provider != nil { + log.Debug().Msg("Provider exists") + + if !hooks.Auth.EmailWhitelisted(cookie.Email) { + log.Warn().Str("email", cookie.Email).Msg("Email is not whitelisted") + hooks.Auth.DeleteSessionCookie(c) + goto basic + } + + log.Debug().Msg("Email is whitelisted") + + return types.UserContext{ + Username: cookie.Username, + Name: cookie.Name, + Email: cookie.Email, + IsLoggedIn: true, + OAuth: true, + Provider: cookie.Provider, + OAuthGroups: cookie.OAuthGroups, + } + } + +basic: + log.Debug().Msg("Trying basic auth") + basic := hooks.Auth.GetBasicAuth(c) - // Check if basic auth is set if basic != nil { log.Debug().Msg("Got basic auth") userSearch := hooks.Auth.SearchUser(basic.Username) if userSearch.Type == "unkown" { - log.Warn().Str("username", basic.Username).Msg("Basic auth user does not exist, skipping") - goto session + log.Error().Str("username", basic.Username).Msg("Basic auth user does not exist") + return types.UserContext{} } - // Verify the user if !hooks.Auth.VerifyUser(userSearch, basic.Password) { - log.Error().Str("username", basic.Username).Msg("Basic auth user password incorrect, skipping") - goto session + log.Error().Str("username", basic.Username).Msg("Basic auth user password incorrect") + return types.UserContext{} } - // Get the user type if userSearch.Type == "ldap" { log.Debug().Msg("User is LDAP") @@ -75,74 +140,5 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { } -session: - // Check cookie error after basic auth - if err != nil { - log.Error().Err(err).Msg("Failed to get session cookie") - return types.UserContext{} - } - - if cookie.TotpPending { - log.Debug().Msg("Totp pending") - return types.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - Provider: cookie.Provider, - TotpPending: true, - } - } - - // Check if session cookie is username/password auth - if cookie.Provider == "username" { - log.Debug().Msg("Provider is username") - - userSearch := hooks.Auth.SearchUser(cookie.Username) - - if userSearch.Type == "unknown" { - log.Error().Str("username", cookie.Username).Msg("User does not exist") - return types.UserContext{} - } - - log.Debug().Str("type", userSearch.Type).Msg("User exists") - - return types.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - IsLoggedIn: true, - Provider: "username", - } - } - - log.Debug().Msg("Provider is not username") - - // The provider is not username so we need to check if it is an oauth provider - provider := hooks.Providers.GetProvider(cookie.Provider) - - // If we have a provider with this name - if provider != nil { - log.Debug().Msg("Provider exists") - - // If the email is not whitelisted we delete the cookie and return an empty context - if !hooks.Auth.EmailWhitelisted(cookie.Email) { - log.Error().Str("email", cookie.Email).Msg("Email is not whitelisted") - hooks.Auth.DeleteSessionCookie(c) - return types.UserContext{} - } - - log.Debug().Msg("Email is whitelisted") - - return types.UserContext{ - Username: cookie.Username, - Name: cookie.Name, - Email: cookie.Email, - IsLoggedIn: true, - OAuth: true, - Provider: cookie.Provider, - OAuthGroups: cookie.OAuthGroups, - } - } - return types.UserContext{} }