fix: use new context in user controller

This commit is contained in:
Stavros
2026-04-29 19:45:23 +03:00
parent 9a219046ac
commit 2f24f823eb
2 changed files with 102 additions and 47 deletions
+97 -46
View File
@@ -1,10 +1,12 @@
package controller package controller
import ( import (
"errors"
"fmt" "fmt"
"net/http"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/config" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -24,7 +26,8 @@ type TotpRequest struct {
} }
type UserControllerConfig struct { type UserControllerConfig struct {
CookieDomain string CookieDomain string
SessionCookieName string
} }
type UserController struct { type UserController struct {
@@ -77,20 +80,28 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
userSearch := controller.auth.SearchUser(req.Username) search, err := controller.auth.SearchUser(req.Username)
if userSearch.Type == "unknown" { if err != nil {
tlog.App.Warn().Str("username", req.Username).Msg("User not found") if errors.Is(err, service.ErrUserNotFound) {
controller.auth.RecordLoginAttempt(req.Username, false) tlog.App.Warn().Str("username", req.Username).Msg("User not found")
tlog.AuditLoginFailure(c, req.Username, "username", "user not found") controller.auth.RecordLoginAttempt(req.Username, false)
c.JSON(401, gin.H{ tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
"status": 401, c.JSON(401, gin.H{
"message": "Unauthorized", "status": 401,
"message": "Unauthorized",
})
return
}
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
}) })
return return
} }
if !controller.auth.VerifyUser(userSearch, req.Password) { if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil {
tlog.App.Warn().Str("username", req.Username).Msg("Invalid password") tlog.App.Warn().Str("username", req.Username).Msg("Invalid password")
controller.auth.RecordLoginAttempt(req.Username, false) controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password") tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
@@ -106,30 +117,26 @@ func (controller *UserController) loginHandler(c *gin.Context) {
controller.auth.RecordLoginAttempt(req.Username, true) controller.auth.RecordLoginAttempt(req.Username, true)
var localUser *config.User var localUser *model.LocalUser
if userSearch.Type == "local" {
user := controller.auth.GetLocalUser(userSearch.Username)
localUser = &user
}
if userSearch.Type == "local" && localUser != nil { if search.Type == model.UserLocal {
user := *localUser localUser = controller.auth.GetLocalUser(req.Username)
if user.TotpSecret != "" { if localUser.TOTPSecret != "" {
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
name := user.Attributes.Name name := localUser.Attributes.Name
if name == "" { if name == "" {
name = utils.Capitalize(user.Username) name = utils.Capitalize(localUser.Username)
} }
email := user.Attributes.Email email := localUser.Attributes.Email
if email == "" { if email == "" {
email = utils.CompileUserEmail(user.Username, controller.config.CookieDomain) email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain)
} }
err := controller.auth.CreateSessionCookie(c, &repository.Session{ cookie, err := controller.auth.CreateSession(c, repository.Session{
Username: user.Username, Username: localUser.Username,
Name: name, Name: name,
Email: email, Email: email,
Provider: "local", Provider: "local",
@@ -145,6 +152,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "TOTP required", "message": "TOTP required",
@@ -161,7 +170,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Provider: "local", Provider: "local",
} }
if userSearch.Type == "local" && localUser != nil { if search.Type == model.UserLocal {
if localUser.Attributes.Name != "" { if localUser.Attributes.Name != "" {
sessionCookie.Name = localUser.Attributes.Name sessionCookie.Name = localUser.Attributes.Name
} }
@@ -170,13 +179,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
} }
if userSearch.Type == "ldap" { if search.Type == model.UserLDAP {
sessionCookie.Provider = "ldap" sessionCookie.Provider = "ldap"
} }
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
err = controller.auth.CreateSessionCookie(c, &sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -187,6 +196,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Login successful", "message": "Login successful",
@@ -196,13 +207,51 @@ func (controller *UserController) loginHandler(c *gin.Context) {
func (controller *UserController) logoutHandler(c *gin.Context) { func (controller *UserController) logoutHandler(c *gin.Context) {
tlog.App.Debug().Msg("Logout request received") tlog.App.Debug().Msg("Logout request received")
controller.auth.DeleteSessionCookie(c) uuid, err := c.Cookie(controller.config.SessionCookieName)
context, err := utils.GetContext(c) if err != nil {
if err == nil && context.IsLoggedIn { if errors.Is(err, http.ErrNoCookie) {
tlog.AuditLogout(c, context.Username, context.Provider) tlog.App.Warn().Msg("No session cookie found on logout request")
c.JSON(200, gin.H{
"status": 200,
"message": "Logout successful",
})
return
}
tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
} }
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user context on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
cookie, err := controller.auth.DeleteSession(c, uuid)
if err != nil {
tlog.App.Error().Err(err).Msg("Error deleting session on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Logout successful", "message": "Logout successful",
@@ -222,7 +271,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
context, err := utils.GetContext(c) context, err := new(model.UserContext).NewFromGin(c)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user context") tlog.App.Error().Err(err).Msg("Failed to get user context")
@@ -233,7 +282,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
if !context.TotpPending { if !context.TOTPPending() {
tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session") tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
@@ -242,12 +291,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
tlog.App.Debug().Str("username", context.Username).Msg("TOTP verification attempt") tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt")
isLocked, remaining := controller.auth.IsAccountLocked(context.Username) isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername())
if isLocked { if isLocked {
tlog.App.Warn().Str("username", context.Username).Msg("Account is locked due to too many failed TOTP attempts") tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{ c.JSON(429, gin.H{
@@ -257,14 +306,14 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
user := controller.auth.GetLocalUser(context.Username) user := controller.auth.GetLocalUser(context.GetUsername())
ok := totp.Validate(req.Code, user.TotpSecret) ok := totp.Validate(req.Code, user.TOTPSecret)
if !ok { if !ok {
tlog.App.Warn().Str("username", context.Username).Msg("Invalid TOTP code") tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code")
controller.auth.RecordLoginAttempt(context.Username, false) controller.auth.RecordLoginAttempt(context.GetUsername(), false)
tlog.AuditLoginFailure(c, context.Username, "totp", "invalid totp code") tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -272,10 +321,10 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
tlog.App.Info().Str("username", context.Username).Msg("TOTP verification successful") tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful")
tlog.AuditLoginSuccess(c, context.Username, "totp") tlog.AuditLoginSuccess(c, context.GetUsername(), "totp")
controller.auth.RecordLoginAttempt(context.Username, true) controller.auth.RecordLoginAttempt(context.GetUsername(), true)
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: user.Username, Username: user.Username,
@@ -293,7 +342,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
err = controller.auth.CreateSessionCookie(c, &sessionCookie) cookie, err := controller.auth.CreateSession(c, sessionCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -304,6 +353,8 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Login successful", "message": "Login successful",
+5 -1
View File
@@ -30,6 +30,10 @@ const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16 const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256 const MaxLoginAttemptRecords = 256
var (
ErrUserNotFound = errors.New("user not found")
)
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all // slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
// parameters and pass them to the authorize page if needed // parameters and pass them to the authorize page if needed
type OAuthURLParams struct { type OAuthURLParams struct {
@@ -136,7 +140,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
}, nil }, nil
} }
return nil, fmt.Errorf("user not found") return nil, ErrUserNotFound
} }
func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error { func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error {