refactor: rework user context handling throughout tinyauth (#829)

* wip

* fix: fix util imports

* fix: fix bootstrap import issues

* fix: fix cli imports

* fix: context controller

* fix: use new context in user controller

* fix: fix imports and context in proxy controller

* fix: fix oauth and oidc controller imports and context

* feat: finalize context functionality

* refactor: simplify acls checking logic by passing the entire acl struct

* chore: rename get basic auth to encode basic auth for clarity

* fix: fix controller tests

* tests: fix service tests

* tests: fix utils tests

* tests: move to testify for testing in utils

* fix: fix config reference generator

* tests: add tests for context parsing

* tests: add tests for context middleware

* tests: remove error wrapper from context tests

* tests: fix log wrapper tests

* fix: fix verion setting in cd and dockerfiles

* fix: review comments batch 1

* fix: review comments batch 2

* fix: review comments batch 3

* fix: delete totp pending session cookie on totp success

* tests: fix user controller tests

* fix: don't audit login too early

* fix: own comments
This commit is contained in:
Stavros
2026-05-07 15:41:07 +03:00
committed by GitHub
parent 24f2da4e58
commit 1382ab41e7
58 changed files with 2070 additions and 1117 deletions
+21 -20
View File
@@ -4,7 +4,7 @@ import (
"fmt"
"net/url"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin"
@@ -19,7 +19,7 @@ type UserContextResponse struct {
Email string `json:"email"`
Provider string `json:"provider"`
OAuth bool `json:"oauth"`
TotpPending bool `json:"totpPending"`
TOTPPending bool `json:"totpPending"`
OAuthName string `json:"oauthName"`
}
@@ -76,28 +76,29 @@ func (controller *ContextController) SetupRoutes() {
}
func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := utils.GetContext(c)
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request")
c.JSON(200, UserContextResponse{
Status: 401,
Message: "Unauthorized",
IsLoggedIn: false,
})
return
}
userContext := UserContextResponse{
Status: 200,
Message: "Success",
IsLoggedIn: context.IsLoggedIn,
Username: context.Username,
Name: context.Name,
Email: context.Email,
Provider: context.Provider,
OAuth: context.OAuth,
TotpPending: context.TotpPending,
OAuthName: context.OAuthName,
}
if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request")
userContext.Status = 401
userContext.Message = "Unauthorized"
userContext.IsLoggedIn = false
c.JSON(200, userContext)
return
IsLoggedIn: context.Authenticated,
Username: context.GetUsername(),
Name: context.GetName(),
Email: context.GetEmail(),
Provider: context.GetProviderID(),
OAuth: context.IsOAuth(),
TOTPPending: context.TOTPPending(),
OAuthName: context.OAuthName(),
}
c.JSON(200, userContext)
+12 -8
View File
@@ -7,11 +7,11 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
)
func TestContextController(t *testing.T) {
@@ -79,12 +79,16 @@ func TestContextController(t *testing.T) {
description: "Ensure user context returns when authorized",
middlewares: []gin.HandlerFunc{
func(c *gin.Context) {
c.Set("context", &config.UserContext{
Username: "johndoe",
Name: "John Doe",
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
Provider: "local",
IsLoggedIn: true,
c.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "johndoe",
Name: "John Doe",
Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain),
},
},
})
},
},
+12
View File
@@ -0,0 +1,12 @@
package controller
type UnauthorizedQuery struct {
Username string `url:"username"`
Resource string `url:"resource"`
GroupErr bool `url:"groupErr"`
IP string `url:"ip"`
}
type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"`
}
+5 -4
View File
@@ -6,7 +6,6 @@ import (
"strings"
"time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
@@ -176,7 +175,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted")
tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted")
queries, err := query.Values(config.UnauthorizedQuery{
queries, err := query.Values(UnauthorizedQuery{
Username: user.Email,
})
@@ -236,7 +235,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -244,6 +243,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return
}
http.SetCookie(c.Writer, cookie)
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
@@ -259,7 +260,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
}
if oauthPendingSession.CallbackParams.RedirectURI != "" {
queries, err := query.Values(config.RedirectQuery{
queries, err := query.Values(RedirectQuery{
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
})
+5 -4
View File
@@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -111,14 +112,14 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return
}
userContext, err := utils.GetContext(c)
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
return
}
if !userContext.IsLoggedIn {
if !userContext.Authenticated {
controller.authorizeError(c, errors.New("err user not logged in"), "User not logged in", "The user is not logged in", "", "", "")
return
}
@@ -151,7 +152,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
}
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID))
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID))
code := utils.GenerateString(32)
// Before storing the code, delete old session
@@ -170,7 +171,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
// We also need a snapshot of the user that authorized this (skip if no openid scope)
if slices.Contains(strings.Fields(req.Scope), "openid") {
err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
+15 -11
View File
@@ -12,14 +12,14 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOIDCController(t *testing.T) {
@@ -27,7 +27,7 @@ func TestOIDCController(t *testing.T) {
tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{
Clients: map[string]config.OIDCClientConfig{
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
@@ -44,12 +44,16 @@ func TestOIDCController(t *testing.T) {
controllerCfg := controller.OIDCControllerConfig{}
simpleCtx := func(c *gin.Context) {
c.Set("context", &config.UserContext{
Username: "test",
Name: "Test User",
Email: "test@example.com",
IsLoggedIn: true,
Provider: "local",
c.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "test",
Name: "Test User",
Email: "test@example.com",
},
},
})
c.Next()
}
@@ -848,7 +852,7 @@ func TestOIDCController(t *testing.T) {
},
}
app := bootstrap.NewBootstrapApp(config.Config{})
app := bootstrap.NewBootstrapApp(model.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
+41 -40
View File
@@ -8,7 +8,7 @@ import (
"regexp"
"strings"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -103,7 +103,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
clientIP := c.ClientIP()
if controller.auth.IsBypassedIP(acls.IP, clientIP) {
if controller.auth.IsBypassedIP(clientIP, acls) {
controller.setHeaders(c, acls)
c.JSON(200, gin.H{
"status": 200,
@@ -112,7 +112,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls.Path)
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
@@ -130,8 +130,8 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if !controller.auth.CheckIP(acls.IP, clientIP) {
queries, err := query.Values(config.UnauthorizedQuery{
if !controller.auth.CheckIP(clientIP, acls) {
queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
IP: clientIP,
})
@@ -157,28 +157,24 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
var userContext config.UserContext
context, err := utils.GetContext(c)
userContext, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Debug().Msg("No user context found in request, treating as not logged in")
userContext = config.UserContext{
IsLoggedIn: false,
tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated")
userContext = &model.UserContext{
Authenticated: false,
}
} else {
userContext = context
}
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
if userContext.IsLoggedIn {
userAllowed := controller.auth.IsUserAllowed(c, userContext, acls)
if userContext.Authenticated {
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls)
if !userAllowed {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
queries, err := query.Values(config.UnauthorizedQuery{
queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
})
@@ -188,10 +184,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth {
queries.Set("username", userContext.Email)
if userContext.IsOAuth() {
queries.Set("username", userContext.GetEmail())
} else {
queries.Set("username", userContext.Username)
queries.Set("username", userContext.GetUsername())
}
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
@@ -209,19 +205,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth || userContext.Provider == "ldap" {
if userContext.IsOAuth() || userContext.IsLDAP() {
var groupOK bool
if userContext.OAuth {
groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
if userContext.IsOAuth() {
groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls)
} else {
groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups)
groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls)
}
if !groupOK {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
queries, err := query.Values(config.UnauthorizedQuery{
queries, err := query.Values(UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0],
GroupErr: true,
})
@@ -232,10 +228,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth {
queries.Set("username", userContext.Email)
if userContext.IsOAuth() {
queries.Set("username", userContext.GetEmail())
} else {
queries.Set("username", userContext.Username)
queries.Set("username", userContext.GetUsername())
}
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
@@ -254,17 +250,18 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
}
}
c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
c.Header("Remote-User", utils.SanitizeHeader(userContext.GetUsername()))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.GetName()))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.GetEmail()))
if userContext.Provider == "ldap" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups))
} else if userContext.Provider != "local" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
if userContext.IsLDAP() {
c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.LDAP.Groups, ",")))
}
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub))
if userContext.IsOAuth() {
c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.OAuth.Groups, ",")))
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuth.Sub))
}
controller.setHeaders(c, acls)
@@ -275,7 +272,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
queries, err := query.Values(config.RedirectQuery{
queries, err := query.Values(RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
})
@@ -299,9 +296,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
}
func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) {
c.Header("Authorization", c.Request.Header.Get("Authorization"))
if acls == nil {
return
}
headers := utils.ParseHeaders(acls.Response.Headers)
for key, value := range headers {
@@ -313,7 +314,7 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
if acls.Response.BasicAuth.Username != "" && basicPassword != "" {
tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword)))
}
}
+34 -27
View File
@@ -6,14 +6,14 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestProxyController(t *testing.T) {
@@ -21,7 +21,7 @@ func TestProxyController(t *testing.T) {
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{
Users: []config.User{
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
@@ -29,7 +29,7 @@ func TestProxyController(t *testing.T) {
{
Username: "totpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
},
},
SessionExpiry: 10, // 10 seconds, useful for testing
@@ -43,28 +43,28 @@ func TestProxyController(t *testing.T) {
AppURL: "https://tinyauth.example.com",
}
acls := map[string]config.App{
acls := map[string]model.App{
"app_path_allow": {
Config: config.AppConfig{
Config: model.AppConfig{
Domain: "path-allow.example.com",
},
Path: config.AppPath{
Path: model.AppPath{
Allow: "/allowed",
},
},
"app_user_allow": {
Config: config.AppConfig{
Config: model.AppConfig{
Domain: "user-allow.example.com",
},
Users: config.AppUsers{
Users: model.AppUsers{
Allow: "testuser",
},
},
"ip_bypass": {
Config: config.AppConfig{
Config: model.AppConfig{
Domain: "ip-bypass.example.com",
},
IP: config.AppIP{
IP: model.AppIP{
Bypass: []string{"10.10.10.10"},
},
},
@@ -74,24 +74,31 @@ func TestProxyController(t *testing.T) {
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
simpleCtx := func(c *gin.Context) {
c.Set("context", &config.UserContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
IsLoggedIn: true,
Provider: "local",
c.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Testuser",
Email: "testuser@example.com",
},
},
})
c.Next()
}
simpleCtxTotp := func(c *gin.Context) {
c.Set("context", &config.UserContext{
Username: "totpuser",
Name: "Totpuser",
Email: "totpuser@example.com",
IsLoggedIn: true,
Provider: "local",
TotpEnabled: true,
c.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "totpuser",
Name: "Totpuser",
Email: "totpuser@example.com",
},
},
})
c.Next()
}
@@ -391,9 +398,9 @@ func TestProxyController(t *testing.T) {
},
}
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(config.Config{})
app := bootstrap.NewBootstrapApp(model.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
+129 -53
View File
@@ -1,10 +1,12 @@
package controller
import (
"errors"
"fmt"
"net/http"
"time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils"
@@ -24,7 +26,8 @@ type TotpRequest struct {
}
type UserControllerConfig struct {
CookieDomain string
CookieDomain string
SessionCookieName string
}
type UserController struct {
@@ -77,21 +80,29 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
userSearch := controller.auth.SearchUser(req.Username)
search, err := controller.auth.SearchUser(req.Username)
if userSearch.Type == "unknown" {
tlog.App.Warn().Str("username", req.Username).Msg("User not found")
controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
if err != nil {
if errors.Is(err, service.ErrUserNotFound) {
tlog.App.Warn().Str("username", req.Username).Msg("User not found")
controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
if !controller.auth.VerifyUser(userSearch, req.Password) {
tlog.App.Warn().Str("username", req.Username).Msg("Invalid password")
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
tlog.App.Warn().Err(err).Str("username", req.Username).Msg("Failed to verify password")
controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
c.JSON(401, gin.H{
@@ -101,35 +112,35 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
tlog.App.Info().Str("username", req.Username).Msg("Login successful")
tlog.AuditLoginSuccess(c, req.Username, "username")
var localUser *model.LocalUser
controller.auth.RecordLoginAttempt(req.Username, true)
if search.Type == model.UserLocal {
localUser = controller.auth.GetLocalUser(req.Username)
var localUser *config.User
if userSearch.Type == "local" {
user := controller.auth.GetLocalUser(userSearch.Username)
localUser = &user
}
if localUser == nil {
tlog.App.Warn().Str("username", req.Username).Msg("User disappeared during login")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
if userSearch.Type == "local" && localUser != nil {
user := *localUser
if user.TotpSecret != "" {
if localUser.TOTPSecret != "" {
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
name := user.Attributes.Name
name := localUser.Attributes.Name
if name == "" {
name = utils.Capitalize(user.Username)
name = utils.Capitalize(localUser.Username)
}
email := user.Attributes.Email
email := localUser.Attributes.Email
if email == "" {
email = utils.CompileUserEmail(user.Username, controller.config.CookieDomain)
email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain)
}
err := controller.auth.CreateSessionCookie(c, &repository.Session{
Username: user.Username,
cookie, err := controller.auth.CreateSession(c, repository.Session{
Username: localUser.Username,
Name: name,
Email: email,
Provider: "local",
@@ -145,6 +156,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{
"status": 200,
"message": "TOTP required",
@@ -161,7 +174,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Provider: "local",
}
if userSearch.Type == "local" && localUser != nil {
if search.Type == model.UserLocal {
if localUser.Attributes.Name != "" {
sessionCookie.Name = localUser.Attributes.Name
}
@@ -170,13 +183,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}
}
if userSearch.Type == "ldap" {
if search.Type == model.UserLDAP {
sessionCookie.Provider = "ldap"
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -187,6 +200,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return
}
http.SetCookie(c.Writer, cookie)
tlog.App.Info().Str("username", req.Username).Msg("Login successful")
tlog.AuditLoginSuccess(c, req.Username, "username")
controller.auth.RecordLoginAttempt(req.Username, true)
c.JSON(200, gin.H{
"status": 200,
"message": "Login successful",
@@ -196,13 +216,47 @@ func (controller *UserController) loginHandler(c *gin.Context) {
func (controller *UserController) logoutHandler(c *gin.Context) {
tlog.App.Debug().Msg("Logout request received")
controller.auth.DeleteSessionCookie(c)
uuid, err := c.Cookie(controller.config.SessionCookieName)
context, err := utils.GetContext(c)
if err == nil && context.IsLoggedIn {
tlog.AuditLogout(c, context.Username, context.Provider)
if err != nil {
if errors.Is(err, http.ErrNoCookie) {
tlog.App.Warn().Msg("No session cookie found on logout request")
c.JSON(200, gin.H{
"status": 200,
"message": "Logout successful",
})
return
}
tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
cookie, err := controller.auth.DeleteSession(c, uuid)
if err != nil {
tlog.App.Error().Err(err).Msg("Error deleting session on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
context, err := new(model.UserContext).NewFromGin(c)
if err == nil {
tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID())
} else {
tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username")
tlog.AuditLogout(c, "unknown", "unknown")
}
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{
"status": 200,
"message": "Logout successful",
@@ -222,7 +276,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
context, err := utils.GetContext(c)
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user context")
@@ -233,7 +287,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
if !context.TotpPending {
if !context.TOTPPending() {
tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
c.JSON(401, gin.H{
"status": 401,
@@ -242,12 +296,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
tlog.App.Debug().Str("username", context.Username).Msg("TOTP verification attempt")
tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
isLocked, remaining := controller.auth.IsAccountLocked(context.Username)
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
if isLocked {
tlog.App.Warn().Str("username", context.Username).Msg("Account is locked due to too many failed TOTP attempts")
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{
@@ -257,14 +311,10 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
user := controller.auth.GetLocalUser(context.Username)
user := controller.auth.GetLocalUser(context.GetUsername())
ok := totp.Validate(req.Code, user.TotpSecret)
if !ok {
tlog.App.Warn().Str("username", context.Username).Msg("Invalid TOTP code")
controller.auth.RecordLoginAttempt(context.Username, false)
tlog.AuditLoginFailure(c, context.Username, "totp", "invalid totp code")
if user == nil {
tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
@@ -272,10 +322,31 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
tlog.App.Info().Str("username", context.Username).Msg("TOTP verification successful")
tlog.AuditLoginSuccess(c, context.Username, "totp")
ok := totp.Validate(req.Code, user.TOTPSecret)
controller.auth.RecordLoginAttempt(context.Username, true)
if !ok {
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code")
controller.auth.RecordLoginAttempt(context.GetUsername(), false)
tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code")
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
uuid, err := c.Cookie(controller.config.SessionCookieName)
if err == nil {
_, err = controller.auth.DeleteSession(c, uuid)
if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete pending TOTP session")
}
} else {
tlog.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, proceeding without deleting it")
}
controller.auth.RecordLoginAttempt(context.GetUsername(), true)
sessionCookie := repository.Session{
Username: user.Username,
@@ -293,7 +364,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -304,6 +375,11 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return
}
http.SetCookie(c.Writer, cookie)
tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful")
tlog.AuditLoginSuccess(c, context.GetUsername(), "totp")
c.JSON(200, gin.H{
"status": 200,
"message": "Login successful",
+131 -69
View File
@@ -1,7 +1,9 @@
package controller_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"path"
"strings"
@@ -10,14 +12,14 @@ import (
"github.com/gin-gonic/gin"
"github.com/pquerna/otp/totp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUserController(t *testing.T) {
@@ -25,7 +27,7 @@ func TestUserController(t *testing.T) {
tempDir := t.TempDir()
authServiceCfg := service.AuthServiceConfig{
Users: []config.User{
LocalUsers: &[]model.LocalUser{
{
Username: "testuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
@@ -33,12 +35,12 @@ func TestUserController(t *testing.T) {
{
Username: "totpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
},
{
Username: "attruser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
Attributes: config.UserAttributes{
Attributes: model.UserAttributes{
Name: "Alice Smith",
Email: "alice@example.com",
},
@@ -46,8 +48,8 @@ func TestUserController(t *testing.T) {
{
Username: "attrtotpuser",
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
TotpSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
Attributes: config.UserAttributes{
TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK",
Attributes: model.UserAttributes{
Name: "Bob Jones",
Email: "bob@example.com",
},
@@ -61,9 +63,63 @@ func TestUserController(t *testing.T) {
}
userControllerCfg := controller.UserControllerConfig{
CookieDomain: "example.com",
CookieDomain: "example.com",
SessionCookieName: "tinyauth-session",
}
totpCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "totpuser",
Name: "Totpuser",
Email: "totpuser@example.com",
},
TOTPPending: true,
},
})
}
totpAttrCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: false,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "attrtotpuser",
Name: "Bob Jones",
Email: "bob@example.com",
},
TOTPPending: true,
},
})
}
simpleCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{
Authenticated: true,
Provider: model.ProviderLocal,
Local: &model.LocalContext{
BaseContext: model.BaseContext{
Username: "testuser",
Name: "Test User",
Email: "testuser@example.com",
},
},
})
}
oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(model.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
type testCase struct {
description string
middlewares []gin.HandlerFunc
@@ -94,7 +150,9 @@ func TestUserController(t *testing.T) {
assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly)
assert.Equal(t, "example.com", cookie.Domain)
assert.Equal(t, 10, cookie.MaxAge)
// 3 seconds should be more than enough for even slow test environments
assert.GreaterOrEqual(t, cookie.MaxAge, 7)
assert.LessOrEqual(t, cookie.MaxAge, 10)
},
},
{
@@ -183,12 +241,15 @@ func TestUserController(t *testing.T) {
assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly)
assert.Equal(t, "example.com", cookie.Domain)
assert.Equal(t, 3600, cookie.MaxAge) // 1 hour, default for totp pending sessions
assert.GreaterOrEqual(t, cookie.MaxAge, 3597)
assert.LessOrEqual(t, cookie.MaxAge, 3600)
},
},
{
description: "Should be able to logout",
middlewares: []gin.HandlerFunc{},
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
// First login to get a session cookie
loginReq := controller.LoginRequest{
@@ -204,9 +265,10 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 1)
cookies := recorder.Result().Cookies()
assert.Len(t, cookies, 1)
cookie := recorder.Result().Cookies()[0]
cookie := cookies[0]
assert.Equal(t, "tinyauth-session", cookie.Name)
// Now logout using the session cookie
@@ -217,18 +279,33 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 1)
cookies = recorder.Result().Cookies()
assert.Len(t, cookies, 1)
logoutCookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", logoutCookie.Name)
assert.Equal(t, "", logoutCookie.Value)
assert.Equal(t, -1, logoutCookie.MaxAge) // MaxAge -1 means delete cookie
cookie = cookies[0]
assert.Equal(t, "tinyauth-session", cookie.Name)
assert.Equal(t, "", cookie.Value)
assert.Equal(t, -1, cookie.MaxAge) // MaxAge -1 means delete cookie
},
},
{
description: "Should be able to login with totp",
middlewares: []gin.HandlerFunc{},
middlewares: []gin.HandlerFunc{
totpCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
UUID: "test-totp-login-uuid",
Username: "test",
Email: "test@example.com",
Name: "Test",
Provider: "local",
TotpPending: true,
Expiry: time.Now().Add(1 * time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
})
require.NoError(t, err)
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
assert.NoError(t, err)
@@ -242,7 +319,13 @@ func TestUserController(t *testing.T) {
recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{
Name: "tinyauth-session",
Value: "test-totp-login-uuid",
HttpOnly: true,
MaxAge: 3600,
Expires: time.Now().Add(1 * time.Hour),
})
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)
@@ -253,12 +336,15 @@ func TestUserController(t *testing.T) {
assert.Equal(t, "tinyauth-session", totpCookie.Name)
assert.True(t, totpCookie.HttpOnly)
assert.Equal(t, "example.com", totpCookie.Domain)
assert.Equal(t, 10, totpCookie.MaxAge) // should use the regular session expiry time
assert.GreaterOrEqual(t, totpCookie.MaxAge, 7)
assert.LessOrEqual(t, totpCookie.MaxAge, 10)
},
},
{
description: "Totp should rate limit on multiple invalid attempts",
middlewares: []gin.HandlerFunc{},
middlewares: []gin.HandlerFunc{
totpCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
for range 3 {
totpReq := controller.TotpRequest{
@@ -328,8 +414,22 @@ func TestUserController(t *testing.T) {
},
{
description: "TOTP completion uses name and email from user attributes",
middlewares: []gin.HandlerFunc{},
middlewares: []gin.HandlerFunc{
totpAttrCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
_, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{
UUID: "test-totp-login-attributes-uuid",
Username: "test",
Email: "test@example.com",
Name: "Test",
Provider: "local",
TotpPending: true,
Expiry: time.Now().Add(1 * time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
})
require.NoError(t, err)
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
require.NoError(t, err)
@@ -339,6 +439,13 @@ func TestUserController(t *testing.T) {
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(body)))
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{
Name: "tinyauth-session",
Value: "test-totp-login-attributes-uuid",
HttpOnly: true,
MaxAge: 3600,
Expires: time.Now().Add(1 * time.Hour),
})
router.ServeHTTP(recorder, req)
require.Equal(t, 200, recorder.Code)
@@ -349,15 +456,6 @@ func TestUserController(t *testing.T) {
},
}
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
app := bootstrap.NewBootstrapApp(config.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)
queries := repository.New(db)
docker := service.NewDockerService()
err = docker.Init()
require.NoError(t, err)
@@ -379,33 +477,6 @@ func TestUserController(t *testing.T) {
authService.ClearRateLimitsTestingOnly()
}
setTotpMiddlewareOverrides := map[string]config.UserContext{
"Should be able to login with totp": {
Username: "totpuser",
Name: "Totpuser",
Email: "totpuser@example.com",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
},
"Totp should rate limit on multiple invalid attempts": {
Username: "totpuser",
Name: "Totpuser",
Email: "totpuser@example.com",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
},
"TOTP completion uses name and email from user attributes": {
Username: "attrtotpuser",
Name: "Bob Jones",
Email: "bob@example.com",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
},
}
for _, test := range tests {
beforeEach()
t.Run(test.description, func(t *testing.T) {
@@ -415,15 +486,6 @@ func TestUserController(t *testing.T) {
router.Use(middleware)
}
// Gin is stupid and doesn't allow setting a middleware after the groups
// so we need to do some stupid overrides here
if ctx, ok := setTotpMiddlewareOverrides[test.description]; ok {
ctx := ctx
router.Use(func(c *gin.Context) {
c.Set("context", &ctx)
})
}
group := router.Group("/api")
gin.SetMode(gin.TestMode)
@@ -8,14 +8,14 @@ import (
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWellKnownController(t *testing.T) {
@@ -23,7 +23,7 @@ func TestWellKnownController(t *testing.T) {
tempDir := t.TempDir()
oidcServiceCfg := service.OIDCServiceConfig{
Clients: map[string]config.OIDCClientConfig{
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
@@ -101,7 +101,7 @@ func TestWellKnownController(t *testing.T) {
},
}
app := bootstrap.NewBootstrapApp(config.Config{})
app := bootstrap.NewBootstrapApp(model.Config{})
db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db"))
require.NoError(t, err)