mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-29 13:15:46 +00:00
refactor: check cookie prior to basiv auth in context hook
This commit is contained in:
@@ -50,7 +50,7 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
|
|||||||
|
|
||||||
// If there was an error getting the session, it might be invalid so let's clear it and retry
|
// If there was an error getting the session, it might be invalid so let's clear it and retry
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Msg("Invalid session, clearing cookie and retrying")
|
log.Error().Err(err).Msg("Invalid session, clearing cookie and retrying")
|
||||||
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true)
|
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true)
|
||||||
session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName)
|
session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -79,7 +79,7 @@ func (auth *Auth) SearchUser(username string) types.UserSearch {
|
|||||||
log.Debug().Str("username", username).Msg("Checking LDAP for user")
|
log.Debug().Str("username", username).Msg("Checking LDAP for user")
|
||||||
userDN, err := auth.LDAP.Search(username)
|
userDN, err := auth.LDAP.Search(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP")
|
log.Error().Err(err).Str("username", username).Msg("Failed to find user in LDAP")
|
||||||
return types.UserSearch{}
|
return types.UserSearch{}
|
||||||
}
|
}
|
||||||
return types.UserSearch{
|
return types.UserSearch{
|
||||||
@@ -107,7 +107,7 @@ func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool {
|
|||||||
|
|
||||||
err := auth.LDAP.Bind(search.Username, password)
|
err := auth.LDAP.Bind(search.Username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
|
log.Error().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -372,7 +372,7 @@ func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) {
|
|||||||
|
|
||||||
// If there is an error, invalid regex, auth enabled
|
// If there is an error, invalid regex, auth enabled
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Msg("Invalid regex")
|
log.Error().Err(err).Msg("Invalid regex")
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -401,7 +401,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool {
|
|||||||
for _, blocked := range labels.IP.Block {
|
for _, blocked := range labels.IP.Block {
|
||||||
res, err := utils.FilterIP(blocked, ip)
|
res, err := utils.FilterIP(blocked, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
log.Error().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
@@ -414,7 +414,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool {
|
|||||||
for _, allowed := range labels.IP.Allow {
|
for _, allowed := range labels.IP.Allow {
|
||||||
res, err := utils.FilterIP(allowed, ip)
|
res, err := utils.FilterIP(allowed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
log.Error().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
@@ -438,7 +438,7 @@ func (auth *Auth) BypassedIP(labels types.Labels, ip string) bool {
|
|||||||
for _, bypassed := range labels.IP.Bypass {
|
for _, bypassed := range labels.IP.Bypass {
|
||||||
res, err := utils.FilterIP(bypassed, ip)
|
res, err := utils.FilterIP(bypassed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
log.Error().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"tinyauth/internal/auth"
|
"tinyauth/internal/auth"
|
||||||
|
"tinyauth/internal/oauth"
|
||||||
"tinyauth/internal/providers"
|
"tinyauth/internal/providers"
|
||||||
"tinyauth/internal/types"
|
"tinyauth/internal/types"
|
||||||
"tinyauth/internal/utils"
|
"tinyauth/internal/utils"
|
||||||
@@ -27,28 +28,92 @@ func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Pr
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
|
func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
|
||||||
// Get session cookie and basic auth
|
|
||||||
cookie, err := hooks.Auth.GetSessionCookie(c)
|
cookie, err := hooks.Auth.GetSessionCookie(c)
|
||||||
|
var provider *oauth.OAuth
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("Failed to get session cookie")
|
||||||
|
goto basic
|
||||||
|
}
|
||||||
|
|
||||||
|
if cookie.TotpPending {
|
||||||
|
log.Debug().Msg("Totp pending")
|
||||||
|
return types.UserContext{
|
||||||
|
Username: cookie.Username,
|
||||||
|
Name: cookie.Name,
|
||||||
|
Email: cookie.Email,
|
||||||
|
Provider: cookie.Provider,
|
||||||
|
TotpPending: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cookie.Provider == "username" {
|
||||||
|
log.Debug().Msg("Provider is username")
|
||||||
|
|
||||||
|
userSearch := hooks.Auth.SearchUser(cookie.Username)
|
||||||
|
|
||||||
|
if userSearch.Type == "unknown" {
|
||||||
|
log.Warn().Str("username", cookie.Username).Msg("User does not exist")
|
||||||
|
goto basic
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Str("type", userSearch.Type).Msg("User exists")
|
||||||
|
|
||||||
|
return types.UserContext{
|
||||||
|
Username: cookie.Username,
|
||||||
|
Name: cookie.Name,
|
||||||
|
Email: cookie.Email,
|
||||||
|
IsLoggedIn: true,
|
||||||
|
Provider: "username",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msg("Provider is not username")
|
||||||
|
|
||||||
|
provider = hooks.Providers.GetProvider(cookie.Provider)
|
||||||
|
|
||||||
|
if provider != nil {
|
||||||
|
log.Debug().Msg("Provider exists")
|
||||||
|
|
||||||
|
if !hooks.Auth.EmailWhitelisted(cookie.Email) {
|
||||||
|
log.Warn().Str("email", cookie.Email).Msg("Email is not whitelisted")
|
||||||
|
hooks.Auth.DeleteSessionCookie(c)
|
||||||
|
goto basic
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msg("Email is whitelisted")
|
||||||
|
|
||||||
|
return types.UserContext{
|
||||||
|
Username: cookie.Username,
|
||||||
|
Name: cookie.Name,
|
||||||
|
Email: cookie.Email,
|
||||||
|
IsLoggedIn: true,
|
||||||
|
OAuth: true,
|
||||||
|
Provider: cookie.Provider,
|
||||||
|
OAuthGroups: cookie.OAuthGroups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
basic:
|
||||||
|
log.Debug().Msg("Trying basic auth")
|
||||||
|
|
||||||
basic := hooks.Auth.GetBasicAuth(c)
|
basic := hooks.Auth.GetBasicAuth(c)
|
||||||
|
|
||||||
// Check if basic auth is set
|
|
||||||
if basic != nil {
|
if basic != nil {
|
||||||
log.Debug().Msg("Got basic auth")
|
log.Debug().Msg("Got basic auth")
|
||||||
|
|
||||||
userSearch := hooks.Auth.SearchUser(basic.Username)
|
userSearch := hooks.Auth.SearchUser(basic.Username)
|
||||||
|
|
||||||
if userSearch.Type == "unkown" {
|
if userSearch.Type == "unkown" {
|
||||||
log.Warn().Str("username", basic.Username).Msg("Basic auth user does not exist, skipping")
|
log.Error().Str("username", basic.Username).Msg("Basic auth user does not exist")
|
||||||
goto session
|
return types.UserContext{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the user
|
|
||||||
if !hooks.Auth.VerifyUser(userSearch, basic.Password) {
|
if !hooks.Auth.VerifyUser(userSearch, basic.Password) {
|
||||||
log.Error().Str("username", basic.Username).Msg("Basic auth user password incorrect, skipping")
|
log.Error().Str("username", basic.Username).Msg("Basic auth user password incorrect")
|
||||||
goto session
|
return types.UserContext{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the user type
|
|
||||||
if userSearch.Type == "ldap" {
|
if userSearch.Type == "ldap" {
|
||||||
log.Debug().Msg("User is LDAP")
|
log.Debug().Msg("User is LDAP")
|
||||||
|
|
||||||
@@ -75,74 +140,5 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
session:
|
|
||||||
// Check cookie error after basic auth
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("Failed to get session cookie")
|
|
||||||
return types.UserContext{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cookie.TotpPending {
|
|
||||||
log.Debug().Msg("Totp pending")
|
|
||||||
return types.UserContext{
|
|
||||||
Username: cookie.Username,
|
|
||||||
Name: cookie.Name,
|
|
||||||
Email: cookie.Email,
|
|
||||||
Provider: cookie.Provider,
|
|
||||||
TotpPending: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if session cookie is username/password auth
|
|
||||||
if cookie.Provider == "username" {
|
|
||||||
log.Debug().Msg("Provider is username")
|
|
||||||
|
|
||||||
userSearch := hooks.Auth.SearchUser(cookie.Username)
|
|
||||||
|
|
||||||
if userSearch.Type == "unknown" {
|
|
||||||
log.Error().Str("username", cookie.Username).Msg("User does not exist")
|
|
||||||
return types.UserContext{}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Str("type", userSearch.Type).Msg("User exists")
|
|
||||||
|
|
||||||
return types.UserContext{
|
|
||||||
Username: cookie.Username,
|
|
||||||
Name: cookie.Name,
|
|
||||||
Email: cookie.Email,
|
|
||||||
IsLoggedIn: true,
|
|
||||||
Provider: "username",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msg("Provider is not username")
|
|
||||||
|
|
||||||
// The provider is not username so we need to check if it is an oauth provider
|
|
||||||
provider := hooks.Providers.GetProvider(cookie.Provider)
|
|
||||||
|
|
||||||
// If we have a provider with this name
|
|
||||||
if provider != nil {
|
|
||||||
log.Debug().Msg("Provider exists")
|
|
||||||
|
|
||||||
// If the email is not whitelisted we delete the cookie and return an empty context
|
|
||||||
if !hooks.Auth.EmailWhitelisted(cookie.Email) {
|
|
||||||
log.Error().Str("email", cookie.Email).Msg("Email is not whitelisted")
|
|
||||||
hooks.Auth.DeleteSessionCookie(c)
|
|
||||||
return types.UserContext{}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msg("Email is whitelisted")
|
|
||||||
|
|
||||||
return types.UserContext{
|
|
||||||
Username: cookie.Username,
|
|
||||||
Name: cookie.Name,
|
|
||||||
Email: cookie.Email,
|
|
||||||
IsLoggedIn: true,
|
|
||||||
OAuth: true,
|
|
||||||
Provider: cookie.Provider,
|
|
||||||
OAuthGroups: cookie.OAuthGroups,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return types.UserContext{}
|
return types.UserContext{}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user