diff --git a/internal/handlers/context.go b/internal/handlers/context.go new file mode 100644 index 0000000..d0fff5e --- /dev/null +++ b/internal/handlers/context.go @@ -0,0 +1,64 @@ +package handlers + +import ( + "tinyauth/internal/types" + + "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" +) + +func (h *Handlers) AppContextHandler(c *gin.Context) { + log.Debug().Msg("Getting app context") + + // Get configured providers + configuredProviders := h.Providers.GetConfiguredProviders() + + // We have username/password configured so add it to our providers + if h.Auth.UserAuthConfigured() { + configuredProviders = append(configuredProviders, "username") + } + + // Return app context + appContext := types.AppContext{ + Status: 200, + Message: "OK", + ConfiguredProviders: configuredProviders, + DisableContinue: h.Config.DisableContinue, + Title: h.Config.Title, + GenericName: h.Config.GenericName, + Domain: h.Config.Domain, + ForgotPasswordMessage: h.Config.ForgotPasswordMessage, + BackgroundImage: h.Config.BackgroundImage, + OAuthAutoRedirect: h.Config.OAuthAutoRedirect, + } + c.JSON(200, appContext) +} + +func (h *Handlers) UserContextHandler(c *gin.Context) { + log.Debug().Msg("Getting user context") + + // Create user context using hooks + userContext := h.Hooks.UseUserContext(c) + + userContextResponse := types.UserContextResponse{ + Status: 200, + IsLoggedIn: userContext.IsLoggedIn, + Username: userContext.Username, + Name: userContext.Name, + Email: userContext.Email, + Provider: userContext.Provider, + Oauth: userContext.OAuth, + TotpPending: userContext.TotpPending, + } + + // If we are not logged in we set the status to 401 else we set it to 200 + if !userContext.IsLoggedIn { + log.Debug().Msg("Unauthorized") + userContextResponse.Message = "Unauthorized" + } else { + log.Debug().Interface("userContext", userContext).Msg("Authenticated") + userContextResponse.Message = "Authenticated" + } + + c.JSON(200, userContextResponse) +} diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index b23eda1..0e8ebe2 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -1,21 +1,13 @@ package handlers import ( - "fmt" - "net/http" - "strings" - "time" "tinyauth/internal/auth" "tinyauth/internal/docker" "tinyauth/internal/hooks" "tinyauth/internal/providers" "tinyauth/internal/types" - "tinyauth/internal/utils" "github.com/gin-gonic/gin" - "github.com/google/go-querystring/query" - "github.com/pquerna/otp/totp" - "github.com/rs/zerolog/log" ) type Handlers struct { @@ -36,733 +28,6 @@ func NewHandlers(config types.HandlersConfig, auth *auth.Auth, hooks *hooks.Hook } } -func (h *Handlers) AuthHandler(c *gin.Context) { - var proxy types.Proxy - - err := c.BindUri(&proxy) - if err != nil { - log.Error().Err(err).Msg("Failed to bind URI") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - // Check if the request is coming from a browser (tools like curl/bruno use */* and they don't include the text/html) - isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") - - if isBrowser { - log.Debug().Msg("Request is most likely coming from a browser") - } else { - log.Debug().Msg("Request is most likely not coming from a browser") - } - - log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") - - uri := c.Request.Header.Get("X-Forwarded-Uri") - proto := c.Request.Header.Get("X-Forwarded-Proto") - host := c.Request.Header.Get("X-Forwarded-Host") - - // Remove the port from the host if it exists - hostPortless := strings.Split(host, ":")[0] // *lol* - - // Get the id - id := strings.Split(hostPortless, ".")[0] - - labels, err := h.Docker.GetLabels(id, hostPortless) - if err != nil { - log.Error().Err(err).Msg("Failed to get container labels") - - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("labels", labels).Msg("Got labels") - - ip := c.ClientIP() - - // Check if the IP is in bypass list - if h.Auth.BypassedIP(labels, ip) { - headersParsed := utils.ParseHeaders(labels.Headers) - - for key, value := range headersParsed { - log.Debug().Str("key", key).Msg("Setting header") - c.Header(key, value) - } - - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) - } - - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - return - } - - // Check if the IP is allowed/blocked - if !h.Auth.CheckIP(labels, ip) { - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(403, gin.H{ - "status": 403, - "message": "Forbidden", - }) - return - } - - values := types.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - IP: ip, - } - - queries, err := query.Values(values) - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - return - } - - // Check if auth is enabled - authEnabled, err := h.Auth.AuthEnabled(uri, labels) - if err != nil { - log.Error().Err(err).Msg("Failed to check if app is allowed") - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - // If auth is not enabled, return 200 - if !authEnabled { - headersParsed := utils.ParseHeaders(labels.Headers) - for key, value := range headersParsed { - log.Debug().Str("key", key).Msg("Setting header") - c.Header(key, value) - } - - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) - } - - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - - return - } - - // Get user context - userContext := h.Hooks.UseUserContext(c) - - // If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth - if userContext.Provider == "basic" && userContext.TotpEnabled { - log.Warn().Str("username", userContext.Username).Msg("User has totp enabled, disabling basic auth") - userContext.IsLoggedIn = false - } - - // Check if user is logged in - if userContext.IsLoggedIn { - log.Debug().Msg("Authenticated") - - // Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx - appAllowed := h.Auth.ResourceAllowed(c, userContext, labels) - - log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") - - if !appAllowed { - log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") - - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - values := types.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - } - - if userContext.OAuth { - values.Username = userContext.Email - } else { - values.Username = userContext.Username - } - - queries, err := query.Values(values) - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - return - } - - // Check groups if using OAuth - if userContext.OAuth { - groupOk := h.Auth.OAuthGroup(c, userContext, labels) - - log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups") - - if !groupOk { - log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User is not in required groups") - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - values := types.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - GroupErr: true, - } - - if userContext.OAuth { - values.Username = userContext.Email - } else { - values.Username = userContext.Username - } - - queries, err := query.Values(values) - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - return - } - } - - c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) - c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) - c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) - c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) - - // Set the rest of the headers - parsedHeaders := utils.ParseHeaders(labels.Headers) - for key, value := range parsedHeaders { - log.Debug().Str("key", key).Msg("Setting header") - c.Header(key, value) - } - - // Set basic auth headers if configured - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) - } - - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - return - } - - // The user is not logged in - log.Debug().Msg("Unauthorized") - - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - queries, err := query.Values(types.LoginQuery{ - RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", h.Config.AppURL, queries.Encode())) -} - -func (h *Handlers) LoginHandler(c *gin.Context) { - var login types.LoginRequest - - err := c.BindJSON(&login) - if err != nil { - log.Error().Err(err).Msg("Failed to bind JSON") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Got login request") - - clientIP := c.ClientIP() - - // Create an identifier for rate limiting (username or IP if username doesn't exist yet) - rateIdentifier := login.Username - if rateIdentifier == "" { - rateIdentifier = clientIP - } - - // Check if the account is locked due to too many failed attempts - locked, remainingTime := h.Auth.IsAccountLocked(rateIdentifier) - if locked { - log.Warn().Str("identifier", rateIdentifier).Int("remaining_seconds", remainingTime).Msg("Account is locked due to too many failed login attempts") - c.JSON(429, gin.H{ - "status": 429, - "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), - }) - return - } - - // Search for a user based on username - log.Debug().Interface("username", login.Username).Msg("Searching for user") - - userSearch := h.Auth.SearchUser(login.Username) - - // User does not exist - if userSearch.Type == "" { - log.Debug().Str("username", login.Username).Msg("User not found") - // Record failed login attempt - h.Auth.RecordLoginAttempt(rateIdentifier, false) - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Got user") - - // Check if password is correct - if !h.Auth.VerifyUser(userSearch, login.Password) { - log.Debug().Str("username", login.Username).Msg("Password incorrect") - // Record failed login attempt - h.Auth.RecordLoginAttempt(rateIdentifier, false) - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Password correct, checking totp") - - // Record successful login attempt (will reset failed attempt counter) - h.Auth.RecordLoginAttempt(rateIdentifier, true) - - // Check if user is using TOTP - if userSearch.Type == "local" { - // Get local user - localUser := h.Auth.GetLocalUser(login.Username) - - // Check if TOTP is enabled - if localUser.TotpSecret != "" { - log.Debug().Msg("Totp enabled") - - // Set totp pending cookie - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: login.Username, - Name: utils.Capitalize(login.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), - Provider: "username", - TotpPending: true, - }) - - // Return totp required - c.JSON(200, gin.H{ - "status": 200, - "message": "Waiting for totp", - "totpPending": true, - }) - return - } - } - - // Create session cookie with username as provider - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: login.Username, - Name: utils.Capitalize(login.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), - Provider: "username", - }) - - // Return logged in - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged in", - "totpPending": false, - }) -} - -func (h *Handlers) TotpHandler(c *gin.Context) { - var totpReq types.TotpRequest - - err := c.BindJSON(&totpReq) - if err != nil { - log.Error().Err(err).Msg("Failed to bind JSON") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Checking totp") - - // Get user context - userContext := h.Hooks.UseUserContext(c) - - // Check if we have a user - if userContext.Username == "" { - log.Debug().Msg("No user context") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - // Get user - user := h.Auth.GetLocalUser(userContext.Username) - - // Check if totp is correct - ok := totp.Validate(totpReq.Code, user.TotpSecret) - - if !ok { - log.Debug().Msg("Totp incorrect") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Totp correct") - - // Create session cookie with username as provider - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: user.Username, - Name: utils.Capitalize(user.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), h.Config.Domain), - Provider: "username", - }) - - // Return logged in - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged in", - }) -} - -func (h *Handlers) LogoutHandler(c *gin.Context) { - log.Debug().Msg("Cleaning up redirect cookie") - - h.Auth.DeleteSessionCookie(c) - - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged out", - }) -} - -func (h *Handlers) AppHandler(c *gin.Context) { - log.Debug().Msg("Getting app context") - - // Get configured providers - configuredProviders := h.Providers.GetConfiguredProviders() - - // We have username/password configured so add it to our providers - if h.Auth.UserAuthConfigured() { - configuredProviders = append(configuredProviders, "username") - } - - // Return app context - appContext := types.AppContext{ - Status: 200, - Message: "OK", - ConfiguredProviders: configuredProviders, - DisableContinue: h.Config.DisableContinue, - Title: h.Config.Title, - GenericName: h.Config.GenericName, - Domain: h.Config.Domain, - ForgotPasswordMessage: h.Config.ForgotPasswordMessage, - BackgroundImage: h.Config.BackgroundImage, - OAuthAutoRedirect: h.Config.OAuthAutoRedirect, - } - c.JSON(200, appContext) -} - -func (h *Handlers) UserHandler(c *gin.Context) { - log.Debug().Msg("Getting user context") - - // Create user context using hooks - userContext := h.Hooks.UseUserContext(c) - - userContextResponse := types.UserContextResponse{ - Status: 200, - IsLoggedIn: userContext.IsLoggedIn, - Username: userContext.Username, - Name: userContext.Name, - Email: userContext.Email, - Provider: userContext.Provider, - Oauth: userContext.OAuth, - TotpPending: userContext.TotpPending, - } - - // If we are not logged in we set the status to 401 else we set it to 200 - if !userContext.IsLoggedIn { - log.Debug().Msg("Unauthorized") - userContextResponse.Message = "Unauthorized" - } else { - log.Debug().Interface("userContext", userContext).Msg("Authenticated") - userContextResponse.Message = "Authenticated" - } - - c.JSON(200, userContextResponse) -} - -func (h *Handlers) OauthUrlHandler(c *gin.Context) { - var request types.OAuthRequest - - err := c.BindUri(&request) - if err != nil { - log.Error().Err(err).Msg("Failed to bind URI") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Got OAuth request") - - // Check if provider exists - provider := h.Providers.GetProvider(request.Provider) - - if provider == nil { - c.JSON(404, gin.H{ - "status": 404, - "message": "Not Found", - }) - return - } - - log.Debug().Str("provider", request.Provider).Msg("Got provider") - - // Create state - state := provider.GenerateState() - - // Get auth URL - authURL := provider.GetAuthURL(state) - - log.Debug().Msg("Got auth URL") - - // Set CSRF cookie - c.SetCookie(h.Config.CsrfCookieName, state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) - - // Get redirect URI - redirectURI := c.Query("redirect_uri") - - // Set redirect cookie if redirect URI is provided - if redirectURI != "" { - log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") - c.SetCookie(h.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) - } - - // Return auth URL - c.JSON(200, gin.H{ - "status": 200, - "message": "OK", - "url": authURL, - }) -} - -func (h *Handlers) OauthCallbackHandler(c *gin.Context) { - var providerName types.OAuthRequest - - err := c.BindUri(&providerName) - if err != nil { - log.Error().Err(err).Msg("Failed to bind URI") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") - - // Get state - state := c.Query("state") - - // Get CSRF cookie - csrfCookie, err := c.Cookie(h.Config.CsrfCookieName) - - if err != nil { - log.Debug().Msg("No CSRF cookie") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Str("csrfCookie", csrfCookie).Msg("Got CSRF cookie") - - // Check if CSRF cookie is valid - if csrfCookie != state { - log.Warn().Msg("Invalid CSRF cookie or CSRF cookie does not match with the state") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - // Clean up CSRF cookie - c.SetCookie(h.Config.CsrfCookieName, "", -1, "/", "", h.Config.CookieSecure, true) - - // Get code - code := c.Query("code") - - log.Debug().Msg("Got code") - - // Get provider - provider := h.Providers.GetProvider(providerName.Provider) - - if provider == nil { - c.Redirect(http.StatusTemporaryRedirect, "/not-found") - return - } - - log.Debug().Str("provider", providerName.Provider).Msg("Got provider") - - // Exchange token (authenticates user) - _, err = provider.ExchangeToken(code) - if err != nil { - log.Error().Err(err).Msg("Failed to exchange token") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Msg("Got token") - - // Get user - user, err := h.Providers.GetUser(providerName.Provider) - if err != nil { - log.Error().Err(err).Msg("Failed to get user") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Msg("Got user") - - // Check that email is not empty - if user.Email == "" { - log.Error().Msg("Email is empty") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - // Email is not whitelisted - if !h.Auth.EmailWhitelisted(user.Email) { - log.Warn().Str("email", user.Email).Msg("Email not whitelisted") - queries, err := query.Values(types.UnauthorizedQuery{ - Username: user.Email, - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - } - - log.Debug().Msg("Email whitelisted") - - // Get username - var username string - - if user.PreferredUsername != "" { - username = user.PreferredUsername - } else { - username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1]) - } - - // Get name - var name string - - if user.Name != "" { - name = user.Name - } else { - name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) - } - - // Create session cookie - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: username, - Name: name, - Email: user.Email, - Provider: providerName.Provider, - OAuthGroups: strings.Join(user.Groups, ","), - }) - - // Check if we have a redirect URI - redirectCookie, err := c.Cookie(h.Config.RedirectCookieName) - - if err != nil { - log.Debug().Msg("No redirect cookie") - c.Redirect(http.StatusTemporaryRedirect, h.Config.AppURL) - return - } - - log.Debug().Str("redirectURI", redirectCookie).Msg("Got redirect URI") - - queries, err := query.Values(types.LoginQuery{ - RedirectURI: redirectCookie, - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Msg("Got redirect query") - - // Clean up redirect cookie - c.SetCookie(h.Config.RedirectCookieName, "", -1, "/", "", h.Config.CookieSecure, true) - - // Redirect to continue with the redirect URI - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", h.Config.AppURL, queries.Encode())) -} - func (h *Handlers) HealthcheckHandler(c *gin.Context) { c.JSON(200, gin.H{ "status": 200, diff --git a/internal/handlers/oauth.go b/internal/handlers/oauth.go new file mode 100644 index 0000000..6e1528f --- /dev/null +++ b/internal/handlers/oauth.go @@ -0,0 +1,223 @@ +package handlers + +import ( + "fmt" + "net/http" + "strings" + "time" + "tinyauth/internal/types" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/google/go-querystring/query" + "github.com/rs/zerolog/log" +) + +func (h *Handlers) OAuthURLHandler(c *gin.Context) { + var request types.OAuthRequest + + err := c.BindUri(&request) + if err != nil { + log.Error().Err(err).Msg("Failed to bind URI") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + log.Debug().Msg("Got OAuth request") + + // Check if provider exists + provider := h.Providers.GetProvider(request.Provider) + + if provider == nil { + c.JSON(404, gin.H{ + "status": 404, + "message": "Not Found", + }) + return + } + + log.Debug().Str("provider", request.Provider).Msg("Got provider") + + // Create state + state := provider.GenerateState() + + // Get auth URL + authURL := provider.GetAuthURL(state) + + log.Debug().Msg("Got auth URL") + + // Set CSRF cookie + c.SetCookie(h.Config.CsrfCookieName, state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) + + // Get redirect URI + redirectURI := c.Query("redirect_uri") + + // Set redirect cookie if redirect URI is provided + if redirectURI != "" { + log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") + c.SetCookie(h.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) + } + + // Return auth URL + c.JSON(200, gin.H{ + "status": 200, + "message": "OK", + "url": authURL, + }) +} + +func (h *Handlers) OAuthCallbackHandler(c *gin.Context) { + var providerName types.OAuthRequest + + err := c.BindUri(&providerName) + if err != nil { + log.Error().Err(err).Msg("Failed to bind URI") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") + + // Get state + state := c.Query("state") + + // Get CSRF cookie + csrfCookie, err := c.Cookie(h.Config.CsrfCookieName) + + if err != nil { + log.Debug().Msg("No CSRF cookie") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + log.Debug().Str("csrfCookie", csrfCookie).Msg("Got CSRF cookie") + + // Check if CSRF cookie is valid + if csrfCookie != state { + log.Warn().Msg("Invalid CSRF cookie or CSRF cookie does not match with the state") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + // Clean up CSRF cookie + c.SetCookie(h.Config.CsrfCookieName, "", -1, "/", "", h.Config.CookieSecure, true) + + // Get code + code := c.Query("code") + + log.Debug().Msg("Got code") + + // Get provider + provider := h.Providers.GetProvider(providerName.Provider) + + if provider == nil { + c.Redirect(http.StatusTemporaryRedirect, "/not-found") + return + } + + log.Debug().Str("provider", providerName.Provider).Msg("Got provider") + + // Exchange token (authenticates user) + _, err = provider.ExchangeToken(code) + if err != nil { + log.Error().Err(err).Msg("Failed to exchange token") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + log.Debug().Msg("Got token") + + // Get user + user, err := h.Providers.GetUser(providerName.Provider) + if err != nil { + log.Error().Err(err).Msg("Failed to get user") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + log.Debug().Msg("Got user") + + // Check that email is not empty + if user.Email == "" { + log.Error().Msg("Email is empty") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + // Email is not whitelisted + if !h.Auth.EmailWhitelisted(user.Email) { + log.Warn().Str("email", user.Email).Msg("Email not whitelisted") + queries, err := query.Values(types.UnauthorizedQuery{ + Username: user.Email, + }) + + if err != nil { + log.Error().Err(err).Msg("Failed to build queries") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) + } + + log.Debug().Msg("Email whitelisted") + + // Get username + var username string + + if user.PreferredUsername != "" { + username = user.PreferredUsername + } else { + username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1]) + } + + // Get name + var name string + + if user.Name != "" { + name = user.Name + } else { + name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) + } + + // Create session cookie + h.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: username, + Name: name, + Email: user.Email, + Provider: providerName.Provider, + OAuthGroups: strings.Join(user.Groups, ","), + }) + + // Check if we have a redirect URI + redirectCookie, err := c.Cookie(h.Config.RedirectCookieName) + + if err != nil { + log.Debug().Msg("No redirect cookie") + c.Redirect(http.StatusTemporaryRedirect, h.Config.AppURL) + return + } + + log.Debug().Str("redirectURI", redirectCookie).Msg("Got redirect URI") + + queries, err := query.Values(types.LoginQuery{ + RedirectURI: redirectCookie, + }) + + if err != nil { + log.Error().Err(err).Msg("Failed to build queries") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + log.Debug().Msg("Got redirect query") + + // Clean up redirect cookie + c.SetCookie(h.Config.RedirectCookieName, "", -1, "/", "", h.Config.CookieSecure, true) + + // Redirect to continue with the redirect URI + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", h.Config.AppURL, queries.Encode())) +} diff --git a/internal/handlers/proxy.go b/internal/handlers/proxy.go new file mode 100644 index 0000000..8e15c68 --- /dev/null +++ b/internal/handlers/proxy.go @@ -0,0 +1,290 @@ +package handlers + +import ( + "fmt" + "net/http" + "strings" + "tinyauth/internal/types" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/google/go-querystring/query" + "github.com/rs/zerolog/log" +) + +func (h *Handlers) ProxyHandler(c *gin.Context) { + var proxy types.Proxy + + err := c.BindUri(&proxy) + if err != nil { + log.Error().Err(err).Msg("Failed to bind URI") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + // Check if the request is coming from a browser (tools like curl/bruno use */* and they don't include the text/html) + isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") + + if isBrowser { + log.Debug().Msg("Request is most likely coming from a browser") + } else { + log.Debug().Msg("Request is most likely not coming from a browser") + } + + log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") + + uri := c.Request.Header.Get("X-Forwarded-Uri") + proto := c.Request.Header.Get("X-Forwarded-Proto") + host := c.Request.Header.Get("X-Forwarded-Host") + + // Remove the port from the host if it exists + hostPortless := strings.Split(host, ":")[0] // *lol* + + // Get the id + id := strings.Split(hostPortless, ".")[0] + + labels, err := h.Docker.GetLabels(id, hostPortless) + if err != nil { + log.Error().Err(err).Msg("Failed to get container labels") + + if proxy.Proxy == "nginx" || !isBrowser { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + log.Debug().Interface("labels", labels).Msg("Got labels") + + ip := c.ClientIP() + + // Check if the IP is in bypass list + if h.Auth.BypassedIP(labels, ip) { + headersParsed := utils.ParseHeaders(labels.Headers) + + for key, value := range headersParsed { + log.Debug().Str("key", key).Msg("Setting header") + c.Header(key, value) + } + + if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + // Check if the IP is allowed/blocked + if !h.Auth.CheckIP(labels, ip) { + if proxy.Proxy == "nginx" || !isBrowser { + c.JSON(403, gin.H{ + "status": 403, + "message": "Forbidden", + }) + return + } + + values := types.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + IP: ip, + } + + queries, err := query.Values(values) + if err != nil { + log.Error().Err(err).Msg("Failed to build queries") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) + return + } + + // Check if auth is enabled + authEnabled, err := h.Auth.AuthEnabled(uri, labels) + if err != nil { + log.Error().Err(err).Msg("Failed to check if app is allowed") + if proxy.Proxy == "nginx" || !isBrowser { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + // If auth is not enabled, return 200 + if !authEnabled { + headersParsed := utils.ParseHeaders(labels.Headers) + for key, value := range headersParsed { + log.Debug().Str("key", key).Msg("Setting header") + c.Header(key, value) + } + + if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + + return + } + + // Get user context + userContext := h.Hooks.UseUserContext(c) + + // If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth + if userContext.Provider == "basic" && userContext.TotpEnabled { + log.Warn().Str("username", userContext.Username).Msg("User has totp enabled, disabling basic auth") + userContext.IsLoggedIn = false + } + + // Check if user is logged in + if userContext.IsLoggedIn { + log.Debug().Msg("Authenticated") + + // Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx + appAllowed := h.Auth.ResourceAllowed(c, userContext, labels) + + log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") + + if !appAllowed { + log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") + + if proxy.Proxy == "nginx" || !isBrowser { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + values := types.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + } + + if userContext.OAuth { + values.Username = userContext.Email + } else { + values.Username = userContext.Username + } + + queries, err := query.Values(values) + if err != nil { + log.Error().Err(err).Msg("Failed to build queries") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) + return + } + + // Check groups if using OAuth + if userContext.OAuth { + groupOk := h.Auth.OAuthGroup(c, userContext, labels) + + log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups") + + if !groupOk { + log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User is not in required groups") + if proxy.Proxy == "nginx" || !isBrowser { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + values := types.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + GroupErr: true, + } + + if userContext.OAuth { + values.Username = userContext.Email + } else { + values.Username = userContext.Username + } + + queries, err := query.Values(values) + if err != nil { + log.Error().Err(err).Msg("Failed to build queries") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) + return + } + } + + c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) + c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) + c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) + c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) + + // Set the rest of the headers + parsedHeaders := utils.ParseHeaders(labels.Headers) + for key, value := range parsedHeaders { + log.Debug().Str("key", key).Msg("Setting header") + c.Header(key, value) + } + + // Set basic auth headers if configured + if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + // The user is not logged in + log.Debug().Msg("Unauthorized") + + if proxy.Proxy == "nginx" || !isBrowser { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + queries, err := query.Values(types.LoginQuery{ + RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), + }) + + if err != nil { + log.Error().Err(err).Msg("Failed to build queries") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", h.Config.AppURL, queries.Encode())) +} diff --git a/internal/handlers/user.go b/internal/handlers/user.go new file mode 100644 index 0000000..91d0fef --- /dev/null +++ b/internal/handlers/user.go @@ -0,0 +1,197 @@ +package handlers + +import ( + "fmt" + "strings" + "tinyauth/internal/types" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/pquerna/otp/totp" + "github.com/rs/zerolog/log" +) + +func (h *Handlers) LoginHandler(c *gin.Context) { + var login types.LoginRequest + + err := c.BindJSON(&login) + if err != nil { + log.Error().Err(err).Msg("Failed to bind JSON") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + log.Debug().Msg("Got login request") + + clientIP := c.ClientIP() + + // Create an identifier for rate limiting (username or IP if username doesn't exist yet) + rateIdentifier := login.Username + if rateIdentifier == "" { + rateIdentifier = clientIP + } + + // Check if the account is locked due to too many failed attempts + locked, remainingTime := h.Auth.IsAccountLocked(rateIdentifier) + if locked { + log.Warn().Str("identifier", rateIdentifier).Int("remaining_seconds", remainingTime).Msg("Account is locked due to too many failed login attempts") + c.JSON(429, gin.H{ + "status": 429, + "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), + }) + return + } + + // Search for a user based on username + log.Debug().Interface("username", login.Username).Msg("Searching for user") + + userSearch := h.Auth.SearchUser(login.Username) + + // User does not exist + if userSearch.Type == "" { + log.Debug().Str("username", login.Username).Msg("User not found") + // Record failed login attempt + h.Auth.RecordLoginAttempt(rateIdentifier, false) + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + log.Debug().Msg("Got user") + + // Check if password is correct + if !h.Auth.VerifyUser(userSearch, login.Password) { + log.Debug().Str("username", login.Username).Msg("Password incorrect") + // Record failed login attempt + h.Auth.RecordLoginAttempt(rateIdentifier, false) + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + log.Debug().Msg("Password correct, checking totp") + + // Record successful login attempt (will reset failed attempt counter) + h.Auth.RecordLoginAttempt(rateIdentifier, true) + + // Check if user is using TOTP + if userSearch.Type == "local" { + // Get local user + localUser := h.Auth.GetLocalUser(login.Username) + + // Check if TOTP is enabled + if localUser.TotpSecret != "" { + log.Debug().Msg("Totp enabled") + + // Set totp pending cookie + h.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: login.Username, + Name: utils.Capitalize(login.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), + Provider: "username", + TotpPending: true, + }) + + // Return totp required + c.JSON(200, gin.H{ + "status": 200, + "message": "Waiting for totp", + "totpPending": true, + }) + return + } + } + + // Create session cookie with username as provider + h.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: login.Username, + Name: utils.Capitalize(login.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), + Provider: "username", + }) + + // Return logged in + c.JSON(200, gin.H{ + "status": 200, + "message": "Logged in", + "totpPending": false, + }) +} + +func (h *Handlers) TOTPHandler(c *gin.Context) { + var totpReq types.TotpRequest + + err := c.BindJSON(&totpReq) + if err != nil { + log.Error().Err(err).Msg("Failed to bind JSON") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + log.Debug().Msg("Checking totp") + + // Get user context + userContext := h.Hooks.UseUserContext(c) + + // Check if we have a user + if userContext.Username == "" { + log.Debug().Msg("No user context") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + // Get user + user := h.Auth.GetLocalUser(userContext.Username) + + // Check if totp is correct + ok := totp.Validate(totpReq.Code, user.TotpSecret) + + if !ok { + log.Debug().Msg("Totp incorrect") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + log.Debug().Msg("Totp correct") + + // Create session cookie with username as provider + h.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: user.Username, + Name: utils.Capitalize(user.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), h.Config.Domain), + Provider: "username", + }) + + // Return logged in + c.JSON(200, gin.H{ + "status": 200, + "message": "Logged in", + }) +} + +func (h *Handlers) LogoutHandler(c *gin.Context) { + log.Debug().Msg("Cleaning up redirect cookie") + + h.Auth.DeleteSessionCookie(c) + + c.JSON(200, gin.H{ + "status": 200, + "message": "Logged out", + }) +} diff --git a/internal/server/server.go b/internal/server/server.go index 97cf2e2..cacc10c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -51,20 +51,20 @@ func NewServer(config types.ServerConfig, handlers *handlers.Handlers) (*Server, }) // Proxy routes - router.GET("/api/auth/:proxy", handlers.AuthHandler) + router.GET("/api/auth/:proxy", handlers.ProxyHandler) // Auth routes router.POST("/api/login", handlers.LoginHandler) - router.POST("/api/totp", handlers.TotpHandler) + router.POST("/api/totp", handlers.TOTPHandler) router.POST("/api/logout", handlers.LogoutHandler) // Context routes - router.GET("/api/app", handlers.AppHandler) - router.GET("/api/user", handlers.UserHandler) + router.GET("/api/app", handlers.AppContextHandler) + router.GET("/api/user", handlers.UserContextHandler) // OAuth routes - router.GET("/api/oauth/url/:provider", handlers.OauthUrlHandler) - router.GET("/api/oauth/callback/:provider", handlers.OauthCallbackHandler) + router.GET("/api/oauth/url/:provider", handlers.OAuthURLHandler) + router.GET("/api/oauth/callback/:provider", handlers.OAuthCallbackHandler) // App routes router.GET("/api/healthcheck", handlers.HealthcheckHandler)