diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index 187b33b..14648dc 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -1,10 +1,12 @@ package controller import ( + "errors" "fmt" + "net/http" "time" - "github.com/tinyauthapp/tinyauth/internal/config" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" @@ -24,7 +26,8 @@ type TotpRequest struct { } type UserControllerConfig struct { - CookieDomain string + CookieDomain string + SessionCookieName string } type UserController struct { @@ -77,20 +80,28 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } - userSearch := controller.auth.SearchUser(req.Username) + search, err := controller.auth.SearchUser(req.Username) - if userSearch.Type == "unknown" { - tlog.App.Warn().Str("username", req.Username).Msg("User not found") - controller.auth.RecordLoginAttempt(req.Username, false) - tlog.AuditLoginFailure(c, req.Username, "username", "user not found") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", + if err != nil { + if errors.Is(err, service.ErrUserNotFound) { + tlog.App.Warn().Str("username", req.Username).Msg("User not found") + controller.auth.RecordLoginAttempt(req.Username, false) + tlog.AuditLoginFailure(c, req.Username, "username", "user not found") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", }) return } - if !controller.auth.VerifyUser(userSearch, req.Password) { + if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil { tlog.App.Warn().Str("username", req.Username).Msg("Invalid password") controller.auth.RecordLoginAttempt(req.Username, false) tlog.AuditLoginFailure(c, req.Username, "username", "invalid password") @@ -106,30 +117,26 @@ func (controller *UserController) loginHandler(c *gin.Context) { controller.auth.RecordLoginAttempt(req.Username, true) - var localUser *config.User - if userSearch.Type == "local" { - user := controller.auth.GetLocalUser(userSearch.Username) - localUser = &user - } + var localUser *model.LocalUser - if userSearch.Type == "local" && localUser != nil { - user := *localUser + if search.Type == model.UserLocal { + localUser = controller.auth.GetLocalUser(req.Username) - if user.TotpSecret != "" { + if localUser.TOTPSecret != "" { tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") - name := user.Attributes.Name + name := localUser.Attributes.Name if name == "" { - name = utils.Capitalize(user.Username) + name = utils.Capitalize(localUser.Username) } - email := user.Attributes.Email + email := localUser.Attributes.Email if email == "" { - email = utils.CompileUserEmail(user.Username, controller.config.CookieDomain) + email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain) } - err := controller.auth.CreateSessionCookie(c, &repository.Session{ - Username: user.Username, + cookie, err := controller.auth.CreateSession(c, repository.Session{ + Username: localUser.Username, Name: name, Email: email, Provider: "local", @@ -145,6 +152,8 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } + http.SetCookie(c.Writer, cookie) + c.JSON(200, gin.H{ "status": 200, "message": "TOTP required", @@ -161,7 +170,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { Provider: "local", } - if userSearch.Type == "local" && localUser != nil { + if search.Type == model.UserLocal { if localUser.Attributes.Name != "" { sessionCookie.Name = localUser.Attributes.Name } @@ -170,13 +179,13 @@ func (controller *UserController) loginHandler(c *gin.Context) { } } - if userSearch.Type == "ldap" { + if search.Type == model.UserLDAP { sessionCookie.Provider = "ldap" } tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") - err = controller.auth.CreateSessionCookie(c, &sessionCookie) + cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { tlog.App.Error().Err(err).Msg("Failed to create session cookie") @@ -187,6 +196,8 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } + http.SetCookie(c.Writer, cookie) + c.JSON(200, gin.H{ "status": 200, "message": "Login successful", @@ -196,13 +207,51 @@ func (controller *UserController) loginHandler(c *gin.Context) { func (controller *UserController) logoutHandler(c *gin.Context) { tlog.App.Debug().Msg("Logout request received") - controller.auth.DeleteSessionCookie(c) + uuid, err := c.Cookie(controller.config.SessionCookieName) - context, err := utils.GetContext(c) - if err == nil && context.IsLoggedIn { - tlog.AuditLogout(c, context.Username, context.Provider) + if err != nil { + if errors.Is(err, http.ErrNoCookie) { + tlog.App.Warn().Msg("No session cookie found on logout request") + c.JSON(200, gin.H{ + "status": 200, + "message": "Logout successful", + }) + return + } + tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return } + context, err := new(model.UserContext).NewFromGin(c) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to get user context on logout") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + cookie, err := controller.auth.DeleteSession(c, uuid) + + if err != nil { + tlog.App.Error().Err(err).Msg("Error deleting session on logout") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + tlog.AuditLogout(c, context.GetUsername(), context.ProviderName()) + + http.SetCookie(c.Writer, cookie) + c.JSON(200, gin.H{ "status": 200, "message": "Logout successful", @@ -222,7 +271,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - context, err := utils.GetContext(c) + context, err := new(model.UserContext).NewFromGin(c) if err != nil { tlog.App.Error().Err(err).Msg("Failed to get user context") @@ -233,7 +282,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - if !context.TotpPending { + if !context.TOTPPending() { tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session") c.JSON(401, gin.H{ "status": 401, @@ -242,12 +291,12 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - tlog.App.Debug().Str("username", context.Username).Msg("TOTP verification attempt") + tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt") - isLocked, remaining := controller.auth.IsAccountLocked(context.Username) + isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername()) if isLocked { - tlog.App.Warn().Str("username", context.Username).Msg("Account is locked due to too many failed TOTP attempts") + tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts") c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.JSON(429, gin.H{ @@ -257,14 +306,14 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - user := controller.auth.GetLocalUser(context.Username) + user := controller.auth.GetLocalUser(context.GetUsername()) - ok := totp.Validate(req.Code, user.TotpSecret) + ok := totp.Validate(req.Code, user.TOTPSecret) if !ok { - tlog.App.Warn().Str("username", context.Username).Msg("Invalid TOTP code") - controller.auth.RecordLoginAttempt(context.Username, false) - tlog.AuditLoginFailure(c, context.Username, "totp", "invalid totp code") + tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code") + controller.auth.RecordLoginAttempt(context.GetUsername(), false) + tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -272,10 +321,10 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - tlog.App.Info().Str("username", context.Username).Msg("TOTP verification successful") - tlog.AuditLoginSuccess(c, context.Username, "totp") + tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful") + tlog.AuditLoginSuccess(c, context.GetUsername(), "totp") - controller.auth.RecordLoginAttempt(context.Username, true) + controller.auth.RecordLoginAttempt(context.GetUsername(), true) sessionCookie := repository.Session{ Username: user.Username, @@ -293,7 +342,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") - err = controller.auth.CreateSessionCookie(c, &sessionCookie) + cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { tlog.App.Error().Err(err).Msg("Failed to create session cookie") @@ -304,6 +353,8 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } + http.SetCookie(c.Writer, cookie) + c.JSON(200, gin.H{ "status": 200, "message": "Login successful", diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 148340f..86743e4 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -30,6 +30,10 @@ const MaxOAuthPendingSessions = 256 const OAuthCleanupCount = 16 const MaxLoginAttemptRecords = 256 +var ( + ErrUserNotFound = errors.New("user not found") +) + // slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all // parameters and pass them to the authorize page if needed type OAuthURLParams struct { @@ -136,7 +140,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) }, nil } - return nil, fmt.Errorf("user not found") + return nil, ErrUserNotFound } func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error {