diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index da53303..3362d0d 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -4,7 +4,7 @@ import ( "fmt" "net/url" - "github.com/tinyauthapp/tinyauth/internal/utils" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/gin-gonic/gin" @@ -19,7 +19,7 @@ type UserContextResponse struct { Email string `json:"email"` Provider string `json:"provider"` OAuth bool `json:"oauth"` - TotpPending bool `json:"totpPending"` + TOTPPending bool `json:"totpPending"` OAuthName string `json:"oauthName"` } @@ -76,28 +76,29 @@ func (controller *ContextController) SetupRoutes() { } func (controller *ContextController) userContextHandler(c *gin.Context) { - context, err := utils.GetContext(c) + context, err := new(model.UserContext).NewFromGin(c) + + if err != nil { + tlog.App.Debug().Err(err).Msg("No user context found in request") + c.JSON(200, UserContextResponse{ + Status: 401, + Message: "Unauthorized", + IsLoggedIn: false, + }) + return + } userContext := UserContextResponse{ Status: 200, Message: "Success", - IsLoggedIn: context.IsLoggedIn, - Username: context.Username, - Name: context.Name, - Email: context.Email, - Provider: context.Provider, - OAuth: context.OAuth, - TotpPending: context.TotpPending, - OAuthName: context.OAuthName, - } - - if err != nil { - tlog.App.Debug().Err(err).Msg("No user context found in request") - userContext.Status = 401 - userContext.Message = "Unauthorized" - userContext.IsLoggedIn = false - c.JSON(200, userContext) - return + IsLoggedIn: context.Authenticated, + Username: context.GetUsername(), + Name: context.GetName(), + Email: context.GetEmail(), + Provider: context.ProviderName(), + OAuth: context.IsOAuth(), + TOTPPending: context.TOTPPending(), + OAuthName: context.OAuthName(), } c.JSON(200, userContext) diff --git a/internal/model/context.go b/internal/model/context.go index ad75b4f..64202cf 100644 --- a/internal/model/context.go +++ b/internal/model/context.go @@ -177,3 +177,30 @@ func (c *UserContext) GetName() string { return "" } } + +func (c *UserContext) ProviderName() string { + switch c.Provider { + case ProviderBasicAuth, ProviderLocal: + return "local" + case ProviderLDAP: + return "ldap" + case ProviderOAuth: + return c.OAuth.DisplayName // compatability + default: + return "unknown" + } +} + +func (c *UserContext) TOTPPending() bool { + if c.Provider == ProviderLocal { + return c.Local.TOTPPending + } + return false +} + +func (c *UserContext) OAuthName() string { + if c.Provider == ProviderOAuth { + return c.OAuth.DisplayName + } + return "" +}