refactor: rework user context handling throughout tinyauth (#829)

* wip

* fix: fix util imports

* fix: fix bootstrap import issues

* fix: fix cli imports

* fix: context controller

* fix: use new context in user controller

* fix: fix imports and context in proxy controller

* fix: fix oauth and oidc controller imports and context

* feat: finalize context functionality

* refactor: simplify acls checking logic by passing the entire acl struct

* chore: rename get basic auth to encode basic auth for clarity

* fix: fix controller tests

* tests: fix service tests

* tests: fix utils tests

* tests: move to testify for testing in utils

* fix: fix config reference generator

* tests: add tests for context parsing

* tests: add tests for context middleware

* tests: remove error wrapper from context tests

* tests: fix log wrapper tests

* fix: fix verion setting in cd and dockerfiles

* fix: review comments batch 1

* fix: review comments batch 2

* fix: review comments batch 3

* fix: delete totp pending session cookie on totp success

* tests: fix user controller tests

* fix: don't audit login too early

* fix: own comments
This commit is contained in:
Stavros
2026-05-07 15:41:07 +03:00
committed by GitHub
parent 24f2da4e58
commit 1382ab41e7
58 changed files with 2070 additions and 1117 deletions
+41 -40
View File
@@ -8,7 +8,7 @@ import (
"regexp"
"strings"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -103,7 +103,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
clientIP := c.ClientIP()
if controller.auth.IsBypassedIP(acls.IP, clientIP) {
if controller.auth.IsBypassedIP(clientIP, acls) {
controller.setHeaders(c, acls)
c.JSON(200, gin.H{
"status": 200,
@@ -112,7 +112,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls.Path)
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
@@ -130,8 +130,8 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if !controller.auth.CheckIP(acls.IP, clientIP) {
queries, err := query.Values(config.UnauthorizedQuery{
if !controller.auth.CheckIP(clientIP, acls) {
queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
IP: clientIP,
})
@@ -157,28 +157,24 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
var userContext config.UserContext
context, err := utils.GetContext(c)
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Debug().Msg("No user context found in request, treating as not logged in")
userContext = config.UserContext{
IsLoggedIn: false,
tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated")
userContext = &model.UserContext{
Authenticated: false,
}
} else {
userContext = context
}
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
if userContext.IsLoggedIn {
userAllowed := controller.auth.IsUserAllowed(c, userContext, acls)
if userContext.Authenticated {
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
if !userAllowed {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
queries, err := query.Values(config.UnauthorizedQuery{
queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
})
@@ -188,10 +184,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth {
queries.Set("username", userContext.Email)
if userContext.IsOAuth() {
queries.Set("username", userContext.GetEmail())
} else {
queries.Set("username", userContext.Username)
queries.Set("username", userContext.GetUsername())
}
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
@@ -209,19 +205,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth || userContext.Provider == "ldap" {
if userContext.IsOAuth() || userContext.IsLDAP() {
var groupOK bool
if userContext.OAuth {
groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
if userContext.IsOAuth() {
groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls)
} else {
groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups)
groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls)
}
if !groupOK {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
queries, err := query.Values(config.UnauthorizedQuery{
queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
GroupErr: true,
})
@@ -232,10 +228,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth {
queries.Set("username", userContext.Email)
if userContext.IsOAuth() {
queries.Set("username", userContext.GetEmail())
} else {
queries.Set("username", userContext.Username)
queries.Set("username", userContext.GetUsername())
}
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
@@ -254,17 +250,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}
}
c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
c.Header("Remote-User", utils.SanitizeHeader(userContext.GetUsername()))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.GetName()))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.GetEmail()))
if userContext.Provider == "ldap" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups))
} else if userContext.Provider != "local" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
if userContext.IsLDAP() {
c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.LDAP.Groups, ",")))
}
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub))
if userContext.IsOAuth() {
c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.OAuth.Groups, ",")))
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuth.Sub))
}
controller.setHeaders(c, acls)
@@ -275,7 +272,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
queries, err := query.Values(config.RedirectQuery{
queries, err := query.Values(RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
})
@@ -299,9 +296,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
}
func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
c.Header("Authorization", c.Request.Header.Get("Authorization"))
if acls == nil {
return
}
headers := utils.ParseHeaders(acls.Response.Headers)
for key, value := range headers {
@@ -313,7 +314,7 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
}
}