From dfdc656145a12556c9ad7ea1b9100c237706115f Mon Sep 17 00:00:00 2001 From: Stavros Date: Mon, 25 Aug 2025 16:40:06 +0300 Subject: [PATCH] refactor: use controller approach in handlers --- cmd/root.go | 96 ++++-- frontend/src/context/app-context.tsx | 2 +- frontend/src/context/user-context.tsx | 2 +- frontend/src/pages/logout-page.tsx | 2 +- frontend/src/pages/totp-page.tsx | 2 +- go.mod | 1 + go.sum | 2 + internal/controller/context_controller.go | 102 ++++++ internal/controller/health_controller.go | 24 ++ internal/controller/oauth_controller.go | 185 ++++++++++ internal/controller/proxy_controller.go | 281 +++++++++++++++ internal/controller/user_controller.go | 216 ++++++++++++ internal/handlers/context.go | 84 ----- internal/handlers/handlers.go | 46 --- internal/handlers/handlers_test.go | 394 ---------------------- internal/handlers/oauth.go | 223 ------------ internal/handlers/proxy.go | 299 ---------------- internal/handlers/user.go | 215 ------------ internal/server/server.go | 66 ---- internal/types/api.go | 62 ---- internal/types/config.go | 6 - internal/types/types.go | 11 + internal/utils/utils.go | 17 + 23 files changed, 910 insertions(+), 1428 deletions(-) create mode 100644 internal/controller/context_controller.go create mode 100644 internal/controller/health_controller.go create mode 100644 internal/controller/oauth_controller.go create mode 100644 internal/controller/proxy_controller.go create mode 100644 internal/controller/user_controller.go delete mode 100644 internal/handlers/context.go delete mode 100644 internal/handlers/handlers.go delete mode 100644 internal/handlers/handlers_test.go delete mode 100644 internal/handlers/oauth.go delete mode 100644 internal/handlers/proxy.go delete mode 100644 internal/handlers/user.go delete mode 100644 internal/server/server.go delete mode 100644 internal/types/api.go diff --git a/cmd/root.go b/cmd/root.go index 927b375..8dadd5d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,22 +8,28 @@ import ( userCmd "tinyauth/cmd/user" "tinyauth/internal/auth" "tinyauth/internal/constants" + "tinyauth/internal/controller" "tinyauth/internal/docker" - "tinyauth/internal/handlers" "tinyauth/internal/ldap" "tinyauth/internal/middleware" "tinyauth/internal/providers" - "tinyauth/internal/server" "tinyauth/internal/types" "tinyauth/internal/utils" - "github.com/go-playground/validator/v10" + "github.com/gin-gonic/gin" + "github.com/go-playground/validator" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "github.com/spf13/viper" ) +type Middleware interface { + Middleware() gin.HandlerFunc + Init() error + Name() string +} + var rootCmd = &cobra.Command{ Use: "tinyauth", Short: "The simplest way to protect your apps with a login screen.", @@ -84,25 +90,6 @@ var rootCmd = &cobra.Command{ AppURL: config.AppURL, } - handlersConfig := handlers.HandlersConfig{ - AppURL: config.AppURL, - DisableContinue: config.DisableContinue, - Title: config.Title, - GenericName: config.GenericName, - CookieSecure: config.CookieSecure, - Domain: domain, - ForgotPasswordMessage: config.FogotPasswordMessage, - BackgroundImage: config.BackgroundImage, - OAuthAutoRedirect: config.OAuthAutoRedirect, - CsrfCookieName: csrfCookieName, - RedirectCookieName: redirectCookieName, - } - - serverConfig := types.ServerConfig{ - Port: config.Port, - Address: config.Address, - } - authConfig := types.AuthConfig{ Users: users, OauthWhitelist: config.OAuthWhitelist, @@ -147,10 +134,15 @@ var rootCmd = &cobra.Command{ HandleError(err, "Failed to initialize docker") auth := auth.NewAuth(authConfig, docker, ldapService) providers := providers.NewProviders(oauthConfig) - handlers := handlers.NewHandlers(handlersConfig, auth, providers, docker) + + // Create the engine + engine := gin.New() + + // Create the group + router := engine.Group("/api") // Setup the middlewares - var middlewares []server.Middleware + var middlewares []Middleware contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ Domain: domain, @@ -160,12 +152,58 @@ var rootCmd = &cobra.Command{ middlewares = append(middlewares, contextMiddleware, uiMiddleware, zerologMiddleware) - srv, err := server.NewServer(serverConfig, handlers, middlewares) - HandleError(err, "Failed to create server") + for _, middleware := range middlewares { + log.Debug().Str("middleware", middleware.Name()).Msg("Initializing middleware") + err := middleware.Init() + HandleError(err, fmt.Sprintf("Failed to initialize middleware %s", middleware.Name())) + router.Use(middleware.Middleware()) + } - // Start up - err = srv.Start() - HandleError(err, "Failed to start server") + // Create configured providers + var configuredProviders []string + + configuredProviders = append(configuredProviders, providers.GetConfiguredProviders()...) + + if auth.UserAuthConfigured() { + configuredProviders = append(configuredProviders, "username") + } + + // Create controllers + contextController := controller.NewContextController(controller.ContextControllerConfig{ + ConfiguredProviders: configuredProviders, + DisableContinue: config.DisableContinue, + Title: config.Title, + GenericName: config.GenericName, + Domain: domain, + ForgotPasswordMessage: config.FogotPasswordMessage, + BackgroundImage: config.BackgroundImage, + OAuthAutoRedirect: config.OAuthAutoRedirect, + }, router) + contextController.SetupRoutes() + + healthController := controller.NewHealthController(router) + healthController.SetupRoutes() + + oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ + AppURL: config.AppURL, + SecureCookie: config.CookieSecure, + CSRFCookieName: csrfCookieName, + RedirectCookieName: redirectCookieName, + }, router, auth, providers) + oauthController.SetupRoutes() + + proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ + AppURL: config.AppURL, + }, router, docker, auth) + proxyController.SetupRoutes() + + userController := controller.NewUserController(controller.UserControllerConfig{ + Domain: domain, + }, router, auth) + userController.SetupRoutes() + + // Run server + engine.Run(fmt.Sprintf("%s:%d", config.Address, config.Port)) }, } diff --git a/frontend/src/context/app-context.tsx b/frontend/src/context/app-context.tsx index 13abf50..8f76c11 100644 --- a/frontend/src/context/app-context.tsx +++ b/frontend/src/context/app-context.tsx @@ -15,7 +15,7 @@ export const AppContextProvider = ({ }) => { const { isFetching, data, error } = useSuspenseQuery({ queryKey: ["app"], - queryFn: () => axios.get("/api/app").then((res) => res.data), + queryFn: () => axios.get("/api/context/app").then((res) => res.data), }); if (error && !isFetching) { diff --git a/frontend/src/context/user-context.tsx b/frontend/src/context/user-context.tsx index 43b3c00..a3cfeaa 100644 --- a/frontend/src/context/user-context.tsx +++ b/frontend/src/context/user-context.tsx @@ -15,7 +15,7 @@ export const UserContextProvider = ({ }) => { const { isFetching, data, error } = useSuspenseQuery({ queryKey: ["user"], - queryFn: () => axios.get("/api/user").then((res) => res.data), + queryFn: () => axios.get("/api/context/user").then((res) => res.data), }); if (error && !isFetching) { diff --git a/frontend/src/pages/logout-page.tsx b/frontend/src/pages/logout-page.tsx index 8c28500..30b2af8 100644 --- a/frontend/src/pages/logout-page.tsx +++ b/frontend/src/pages/logout-page.tsx @@ -26,7 +26,7 @@ export const LogoutPage = () => { const { t } = useTranslation(); const logoutMutation = useMutation({ - mutationFn: () => axios.post("/api/logout"), + mutationFn: () => axios.post("/api/user/logout"), mutationKey: ["logout"], onSuccess: () => { toast.success(t("logoutSuccessTitle"), { diff --git a/frontend/src/pages/totp-page.tsx b/frontend/src/pages/totp-page.tsx index e04fb2f..7d4ebad 100644 --- a/frontend/src/pages/totp-page.tsx +++ b/frontend/src/pages/totp-page.tsx @@ -32,7 +32,7 @@ export const TotpPage = () => { const redirectUri = searchParams.get("redirect_uri"); const totpMutation = useMutation({ - mutationFn: (values: TotpSchema) => axios.post("/api/totp", values), + mutationFn: (values: TotpSchema) => axios.post("/api/user/totp", values), mutationKey: ["totp"], onSuccess: () => { toast.success(t("totpSuccessTitle"), { diff --git a/go.mod b/go.mod index 0a6f885..8388b2a 100644 --- a/go.mod +++ b/go.mod @@ -68,6 +68,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator v9.31.0+incompatible github.com/goccy/go-json v0.10.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/gorilla/securecookie v1.1.2 // indirect diff --git a/go.sum b/go.sum index dabff47..b43990c 100644 --- a/go.sum +++ b/go.sum @@ -111,6 +111,8 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator v9.31.0+incompatible h1:UA72EPEogEnq76ehGdEDp4Mit+3FDh548oRqwVgNsHA= +github.com/go-playground/validator v9.31.0+incompatible/go.mod h1:yrEkQXlcI+PugkyDjY2bRrL/UBU4f3rvrgkN3V8JEig= github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4= github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= github.com/go-viper/mapstructure/v2 v2.3.0 h1:27XbWsHIqhbdR5TIC911OfYvgSaW93HM+dX7970Q7jk= diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go new file mode 100644 index 0000000..c7dfccf --- /dev/null +++ b/internal/controller/context_controller.go @@ -0,0 +1,102 @@ +package controller + +import ( + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" +) + +type UserContextResponse struct { + Status int `json:"status"` + Message string `json:"message"` + IsLoggedIn bool `json:"isLoggedIn"` + Username string `json:"username"` + Name string `json:"name"` + Email string `json:"email"` + Provider string `json:"provider"` + Oauth bool `json:"oauth"` + TotpPending bool `json:"totpPending"` +} + +type AppContextResponse struct { + Status int `json:"status"` + Message string `json:"message"` + ConfiguredProviders []string `json:"configuredProviders"` + DisableContinue bool `json:"disableContinue"` + Title string `json:"title"` + GenericName string `json:"genericName"` + Domain string `json:"domain"` + ForgotPasswordMessage string `json:"forgotPasswordMessage"` + BackgroundImage string `json:"backgroundImage"` + OAuthAutoRedirect string `json:"oauthAutoRedirect"` +} + +type ContextControllerConfig struct { + ConfiguredProviders []string + DisableContinue bool + Title string + GenericName string + Domain string + ForgotPasswordMessage string + BackgroundImage string + OAuthAutoRedirect string +} + +type ContextController struct { + Config ContextControllerConfig + Router *gin.RouterGroup +} + +func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController { + return &ContextController{ + Config: config, + Router: router, + } +} + +func (controller *ContextController) SetupRoutes() { + contextGroup := controller.Router.Group("/context") + contextGroup.GET("/user", controller.userContextHandler) + contextGroup.GET("/app", controller.appContextHandler) +} + +func (controller *ContextController) userContextHandler(c *gin.Context) { + context, err := utils.GetContext(c) + + userContext := UserContextResponse{ + Status: 200, + Message: "Success", + IsLoggedIn: context.IsLoggedIn, + Username: context.Username, + Name: context.Name, + Email: context.Email, + Provider: context.Provider, + Oauth: context.OAuth, + TotpPending: context.TotpPending, + } + + if err != nil { + userContext.Status = 401 + userContext.Message = "Unauthorized" + userContext.IsLoggedIn = false + c.JSON(200, userContext) + return + } + + c.JSON(200, userContext) +} + +func (controller *ContextController) appContextHandler(c *gin.Context) { + c.JSON(200, AppContextResponse{ + Status: 200, + Message: "Success", + ConfiguredProviders: controller.Config.ConfiguredProviders, + DisableContinue: controller.Config.DisableContinue, + Title: controller.Config.Title, + GenericName: controller.Config.GenericName, + Domain: controller.Config.Domain, + ForgotPasswordMessage: controller.Config.ForgotPasswordMessage, + BackgroundImage: controller.Config.BackgroundImage, + OAuthAutoRedirect: controller.Config.OAuthAutoRedirect, + }) +} diff --git a/internal/controller/health_controller.go b/internal/controller/health_controller.go new file mode 100644 index 0000000..2330fb1 --- /dev/null +++ b/internal/controller/health_controller.go @@ -0,0 +1,24 @@ +package controller + +import "github.com/gin-gonic/gin" + +type HealthController struct { + Router *gin.RouterGroup +} + +func NewHealthController(router *gin.RouterGroup) *HealthController { + return &HealthController{ + Router: router, + } +} + +func (controller *HealthController) SetupRoutes() { + controller.Router.GET("/health", controller.healthHandler) +} + +func (controller *HealthController) healthHandler(c *gin.Context) { + c.JSON(200, gin.H{ + "status": "ok", + "message": "Healthy", + }) +} diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go new file mode 100644 index 0000000..63b6322 --- /dev/null +++ b/internal/controller/oauth_controller.go @@ -0,0 +1,185 @@ +package controller + +import ( + "fmt" + "net/http" + "strings" + "time" + "tinyauth/internal/auth" + "tinyauth/internal/providers" + "tinyauth/internal/types" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/google/go-querystring/query" +) + +type OAuthRequest struct { + Provider string `uri:"provider" binding:"required"` +} + +type OAuthControllerConfig struct { + CSRFCookieName string + RedirectCookieName string + SecureCookie bool + AppURL string +} + +type OAuthController struct { + Config OAuthControllerConfig + Router *gin.RouterGroup + Auth *auth.Auth + Providers *providers.Providers +} + +func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *auth.Auth, providers *providers.Providers) *OAuthController { + return &OAuthController{ + Config: config, + Router: router, + Auth: auth, + Providers: providers, + } +} + +func (controller *OAuthController) SetupRoutes() { + oauthGroup := controller.Router.Group("/oauth") + oauthGroup.GET("/url/:provider", controller.oauthURLHandler) + oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) +} + +func (controller *OAuthController) oauthURLHandler(c *gin.Context) { + var req OAuthRequest + + err := c.BindUri(&req) + if err != nil { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + provider := controller.Providers.GetProvider(req.Provider) + + if provider == nil { + c.JSON(404, gin.H{ + "status": 404, + "message": "Not Found", + }) + return + } + + state := provider.GenerateState() + authURL := provider.GetAuthURL(state) + c.SetCookie(controller.Config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) + + redirectURI := c.Query("redirect_uri") + + if redirectURI != "" { + c.SetCookie(controller.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", controller.Config.SecureCookie, true) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "OK", + "url": authURL, + }) +} + +func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { + var req OAuthRequest + + err := c.BindUri(&req) + if err != nil { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + state := c.Query("state") + csrfCookie, err := c.Cookie(controller.Config.CSRFCookieName) + + if err != nil || state != csrfCookie { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.SetCookie(controller.Config.CSRFCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) + + code := c.Query("code") + provider := controller.Providers.GetProvider(req.Provider) + + if provider == nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + _, err = provider.ExchangeToken(code) + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + user, err := controller.Providers.GetUser(req.Provider) + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + if user.Email == "" { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + if !controller.Auth.EmailWhitelisted(user.Email) { + queries, err := query.Values(types.UnauthorizedQuery{ + Username: user.Email, + }) + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + + var name string + + if user.Name != "" { + name = user.Name + } else { + name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) + } + + controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: user.Email, + Name: name, + Email: user.Email, + Provider: req.Provider, + OAuthGroups: utils.CoalesceToString(user.Groups), + }) + + redirectURI, err := c.Cookie(controller.Config.RedirectCookieName) + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, controller.Config.AppURL) + return + } + + queries, err := query.Values(types.RedirectQuery{ + RedirectURI: redirectURI, + }) + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.SetCookie(controller.Config.RedirectCookieName, "", -1, "/", "", controller.Config.SecureCookie, true) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.Config.AppURL, queries.Encode())) +} diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go new file mode 100644 index 0000000..f8476f0 --- /dev/null +++ b/internal/controller/proxy_controller.go @@ -0,0 +1,281 @@ +package controller + +import ( + "fmt" + "net/http" + "strings" + "tinyauth/internal/auth" + "tinyauth/internal/docker" + "tinyauth/internal/types" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/google/go-querystring/query" +) + +type Proxy struct { + Proxy string `uri:"proxy" binding:"required"` +} + +type ProxyControllerConfig struct { + AppURL string +} + +type ProxyController struct { + Config ProxyControllerConfig + Router *gin.RouterGroup + Docker *docker.Docker + Auth *auth.Auth +} + +func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, docker *docker.Docker, auth *auth.Auth) *ProxyController { + return &ProxyController{ + Config: config, + Router: router, + Docker: docker, + Auth: auth, + } +} + +func (controller *ProxyController) SetupRoutes() { + proxyGroup := controller.Router.Group("/api/auth") + proxyGroup.GET("/:proxy", controller.proxyHandler) +} + +func (controller *ProxyController) proxyHandler(c *gin.Context) { + var req Proxy + + err := c.BindUri(&req) + if err != nil { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") + + uri := c.Request.Header.Get("X-Forwarded-Uri") + proto := c.Request.Header.Get("X-Forwarded-Proto") + host := c.Request.Header.Get("X-Forwarded-Host") + + hostWithoutPort := strings.Split(host, ":")[0] + id := strings.Split(hostWithoutPort, ".")[0] + + labels, err := controller.Docker.GetLabels(id, hostWithoutPort) + + if err != nil { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + clientIP := c.ClientIP() + + if controller.Auth.BypassedIP(labels, clientIP) { + c.Header("Authorization", c.Request.Header.Get("Authorization")) + + headers := utils.ParseHeaders(labels.Headers) + + for key, value := range headers { + c.Header(key, value) + } + + if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + if !controller.Auth.CheckIP(labels, clientIP) { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + queries, err := query.Values(types.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + IP: clientIP, + }) + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + + authEnabled, err := controller.Auth.AuthEnabled(uri, labels) + + if err != nil { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + if !authEnabled { + c.Header("Authorization", c.Request.Header.Get("Authorization")) + + headers := utils.ParseHeaders(labels.Headers) + + for key, value := range headers { + c.Header(key, value) + } + + if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + var userContext types.UserContext + + context, err := utils.GetContext(c) + + if err != nil { + userContext = types.UserContext{ + IsLoggedIn: false, + } + } else { + userContext = context + } + + if userContext.Provider == "basic" && userContext.TotpEnabled { + userContext.IsLoggedIn = false + } + + if userContext.IsLoggedIn { + appAllowed := controller.Auth.ResourceAllowed(c, userContext, labels) + + if !appAllowed { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(403, gin.H{ + "status": 403, + "message": "Forbidden", + }) + return + } + + queries, err := query.Values(types.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + }) + + if userContext.OAuth { + queries.Set("username", userContext.Username) + } else { + queries.Set("username", userContext.Email) + } + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + + if userContext.OAuth { + groupOK := controller.Auth.OAuthGroup(c, userContext, labels) + + if !groupOK { + if req.Proxy == "nginx" || !isBrowser { + c.JSON(403, gin.H{ + "status": 403, + "message": "Forbidden", + }) + return + } + + queries, err := query.Values(types.UnauthorizedQuery{ + Resource: strings.Split(host, ".")[0], + GroupErr: true, + }) + + if userContext.OAuth { + queries.Set("username", userContext.Username) + } else { + queries.Set("username", userContext.Email) + } + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.Config.AppURL, queries.Encode())) + return + } + } + + c.Header("Authorization", c.Request.Header.Get("Authorization")) + c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) + c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) + c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) + c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) + + headers := utils.ParseHeaders(labels.Headers) + + for key, value := range headers { + c.Header(key, value) + } + + if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { + c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + if req.Proxy == "nginx" || !isBrowser { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + queries, err := query.Values(types.RedirectQuery{ + RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), + }) + + if err != nil { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.Config.AppURL)) + return + } + + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", controller.Config.AppURL, queries.Encode())) +} diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go new file mode 100644 index 0000000..e017826 --- /dev/null +++ b/internal/controller/user_controller.go @@ -0,0 +1,216 @@ +package controller + +import ( + "fmt" + "strings" + "tinyauth/internal/auth" + "tinyauth/internal/types" + "tinyauth/internal/utils" + + "github.com/gin-gonic/gin" + "github.com/pquerna/otp/totp" +) + +type LoginRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type TotpRequest struct { + Code string `json:"code"` +} + +type UserControllerConfig struct { + Domain string +} + +type UserController struct { + Config UserControllerConfig + Router *gin.RouterGroup + Auth *auth.Auth +} + +func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *auth.Auth) *UserController { + return &UserController{ + Config: config, + Router: router, + Auth: auth, + } +} + +func (controller *UserController) SetupRoutes() { + userGroup := controller.Router.Group("/user") + userGroup.POST("/login", controller.loginHandler) + userGroup.POST("/logout", controller.logoutHandler) + userGroup.POST("/totp", controller.totpHandler) +} + +func (controller *UserController) loginHandler(c *gin.Context) { + var req LoginRequest + + err := c.BindJSON(&req) + if err != nil { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + clientIP := c.ClientIP() + + rateIdentifier := req.Username + + if rateIdentifier == "" { + rateIdentifier = clientIP + } + + isLocked, remainingTime := controller.Auth.IsAccountLocked(rateIdentifier) + + if isLocked { + c.JSON(429, gin.H{ + "status": 429, + "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), + }) + return + } + + userSearch := controller.Auth.SearchUser(req.Username) + + if userSearch.Type == "" { + controller.Auth.RecordLoginAttempt(rateIdentifier, false) + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + if !controller.Auth.VerifyUser(userSearch, req.Password) { + controller.Auth.RecordLoginAttempt(rateIdentifier, false) + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + controller.Auth.RecordLoginAttempt(rateIdentifier, true) + + if userSearch.Type == "local" { + user := controller.Auth.GetLocalUser(userSearch.Username) + + if user.TotpSecret != "" { + controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: user.Username, + Name: utils.Capitalize(req.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), + Provider: "username", + TotpPending: true, + }) + + c.JSON(200, gin.H{ + "status": 200, + "message": "TOTP required", + "totpPending": true, + }) + return + } + } + + controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: req.Username, + Name: utils.Capitalize(req.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.Config.Domain), + Provider: "username", + }) + + c.JSON(200, gin.H{ + "status": 200, + "message": "Login successful", + }) +} + +func (controller *UserController) logoutHandler(c *gin.Context) { + controller.Auth.DeleteSessionCookie(c) + c.JSON(200, gin.H{ + "status": 200, + "message": "Logout successful", + }) +} + +func (controller *UserController) totpHandler(c *gin.Context) { + var req TotpRequest + + err := c.BindJSON(&req) + if err != nil { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + context, err := utils.GetContext(c) + + if err != nil { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + if !context.IsLoggedIn { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + clientIP := c.ClientIP() + + rateIdentifier := context.Username + + if rateIdentifier == "" { + rateIdentifier = clientIP + } + + isLocked, remainingTime := controller.Auth.IsAccountLocked(rateIdentifier) + + if isLocked { + c.JSON(429, gin.H{ + "status": 429, + "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), + }) + return + } + + user := controller.Auth.GetLocalUser(context.Username) + + ok := totp.Validate(req.Code, user.TotpSecret) + + if !ok { + controller.Auth.RecordLoginAttempt(rateIdentifier, false) + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + controller.Auth.RecordLoginAttempt(rateIdentifier, true) + + controller.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: user.Username, + Name: utils.Capitalize(user.Username), + Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.Config.Domain), + Provider: "username", + }) + + c.JSON(200, gin.H{ + "status": 200, + "message": "Login successful", + }) +} diff --git a/internal/handlers/context.go b/internal/handlers/context.go deleted file mode 100644 index 0bbe392..0000000 --- a/internal/handlers/context.go +++ /dev/null @@ -1,84 +0,0 @@ -package handlers - -import ( - "tinyauth/internal/types" - - "github.com/gin-gonic/gin" - "github.com/rs/zerolog/log" -) - -func (h *Handlers) AppContextHandler(c *gin.Context) { - log.Debug().Msg("Getting app context") - - // Get configured providers - configuredProviders := h.Providers.GetConfiguredProviders() - - // We have username/password configured so add it to our providers - if h.Auth.UserAuthConfigured() { - configuredProviders = append(configuredProviders, "username") - } - - // Return app context - appContext := types.AppContext{ - Status: 200, - Message: "OK", - ConfiguredProviders: configuredProviders, - DisableContinue: h.Config.DisableContinue, - Title: h.Config.Title, - GenericName: h.Config.GenericName, - Domain: h.Config.Domain, - ForgotPasswordMessage: h.Config.ForgotPasswordMessage, - BackgroundImage: h.Config.BackgroundImage, - OAuthAutoRedirect: h.Config.OAuthAutoRedirect, - } - c.JSON(200, appContext) -} - -func (h *Handlers) UserContextHandler(c *gin.Context) { - log.Debug().Msg("Getting user context") - - // Get user context from middleware - userContextValue, exists := c.Get("context") - - if !exists { - c.JSON(200, types.UserContextResponse{ - Status: 200, - Message: "Unauthorized", - IsLoggedIn: false, - }) - return - } - - userContext, ok := userContextValue.(*types.UserContext) - - if !ok { - c.JSON(200, types.UserContextResponse{ - Status: 200, - Message: "Unauthorized", - IsLoggedIn: false, - }) - return - } - - userContextResponse := types.UserContextResponse{ - Status: 200, - IsLoggedIn: userContext.IsLoggedIn, - Username: userContext.Username, - Name: userContext.Name, - Email: userContext.Email, - Provider: userContext.Provider, - Oauth: userContext.OAuth, - TotpPending: userContext.TotpPending, - } - - // If we are not logged in we set the status to 401 else we set it to 200 - if !userContext.IsLoggedIn { - log.Debug().Msg("Unauthorized") - userContextResponse.Message = "Unauthorized" - } else { - log.Debug().Interface("userContext", userContext).Msg("Authenticated") - userContextResponse.Message = "Authenticated" - } - - c.JSON(200, userContextResponse) -} diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go deleted file mode 100644 index e24f7fa..0000000 --- a/internal/handlers/handlers.go +++ /dev/null @@ -1,46 +0,0 @@ -package handlers - -import ( - "tinyauth/internal/auth" - "tinyauth/internal/docker" - "tinyauth/internal/providers" - - "github.com/gin-gonic/gin" -) - -type HandlersConfig struct { - AppURL string - Domain string - CookieSecure bool - DisableContinue bool - GenericName string - Title string - ForgotPasswordMessage string - BackgroundImage string - OAuthAutoRedirect string - CsrfCookieName string - RedirectCookieName string -} - -type Handlers struct { - Config HandlersConfig - Auth *auth.Auth - Providers *providers.Providers - Docker *docker.Docker -} - -func NewHandlers(config HandlersConfig, auth *auth.Auth, providers *providers.Providers, docker *docker.Docker) *Handlers { - return &Handlers{ - Config: config, - Auth: auth, - Providers: providers, - Docker: docker, - } -} - -func (h *Handlers) HealthcheckHandler(c *gin.Context) { - c.JSON(200, gin.H{ - "status": 200, - "message": "OK", - }) -} diff --git a/internal/handlers/handlers_test.go b/internal/handlers/handlers_test.go deleted file mode 100644 index 279534d..0000000 --- a/internal/handlers/handlers_test.go +++ /dev/null @@ -1,394 +0,0 @@ -package handlers_test - -import ( - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "reflect" - "strings" - "testing" - "time" - "tinyauth/internal/auth" - "tinyauth/internal/docker" - "tinyauth/internal/handlers" - "tinyauth/internal/hooks" - "tinyauth/internal/providers" - "tinyauth/internal/server" - "tinyauth/internal/types" - - "github.com/magiconair/properties/assert" - "github.com/pquerna/otp/totp" -) - -// Simple server config -var serverConfig = types.ServerConfig{ - Port: 8080, - Address: "0.0.0.0", -} - -// Simple handlers config -var handlersConfig = types.HandlersConfig{ - AppURL: "http://localhost:8080", - Domain: "localhost", - DisableContinue: false, - CookieSecure: false, - Title: "Tinyauth", - GenericName: "Generic", - ForgotPasswordMessage: "Message", - CsrfCookieName: "tinyauth-csrf", - RedirectCookieName: "tinyauth-redirect", - BackgroundImage: "https://example.com/image.png", - OAuthAutoRedirect: "none", -} - -// Simple auth config -var authConfig = types.AuthConfig{ - Users: types.Users{}, - OauthWhitelist: "", - HMACSecret: "4bZ9K.*:;zH=,9zG!meUxu.B5-S[7.V.", // Complex on purpose - EncryptionSecret: "\\:!R(u[Sbv6ZLm.7es)H|OqH4y}0u\\rj", - CookieSecure: false, - SessionExpiry: 3600, - LoginTimeout: 0, - LoginMaxRetries: 0, - SessionCookieName: "tinyauth-session", - Domain: "localhost", -} - -// Simple hooks config -var hooksConfig = types.HooksConfig{ - Domain: "localhost", -} - -// Cookie -var cookie string - -// User -var user = types.User{ - Username: "user", - Password: "$2a$10$AvGHLTYv3xiRJ0xV9xs3XeVIlkGTygI9nqIamFYB5Xu.5.0UWF7B6", // pass -} - -// Initialize the server for tests -func getServer(t *testing.T) *server.Server { - // Create services - authConfig.Users = types.Users{ - { - Username: user.Username, - Password: user.Password, - TotpSecret: user.TotpSecret, - }, - } - docker, err := docker.NewDocker() - if err != nil { - t.Fatalf("Failed to create docker client: %v", err) - } - auth := auth.NewAuth(authConfig, nil, nil) - providers := providers.NewProviders(types.OAuthConfig{}) - hooks := hooks.NewHooks(hooksConfig, auth, providers) - handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) - - // Create server - srv, err := server.NewServer(serverConfig, handlers) - if err != nil { - t.Fatalf("Failed to create server: %v", err) - } - - return srv -} - -func TestLogin(t *testing.T) { - t.Log("Testing login") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - user := types.LoginRequest{ - Username: "user", - Password: "pass", - } - - json, err := json.Marshal(user) - if err != nil { - t.Fatalf("Error marshalling json: %v", err) - } - - req, err := http.NewRequest("POST", "/api/login", strings.NewReader(string(json))) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - cookies := recorder.Result().Cookies() - - if len(cookies) == 0 { - t.Fatalf("Cookie not set") - } - - // Set the cookie for further tests - cookie = cookies[0].Value -} - -func TestAppContext(t *testing.T) { - // Refresh the cookie - TestLogin(t) - - t.Log("Testing app context") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("GET", "/api/app", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - // Set the cookie from the previous test - req.AddCookie(&http.Cookie{ - Name: "tinyauth", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - body, err := io.ReadAll(recorder.Body) - if err != nil { - t.Fatalf("Error getting body: %v", err) - } - - var app types.AppContext - - err = json.Unmarshal(body, &app) - if err != nil { - t.Fatalf("Error unmarshalling body: %v", err) - } - - expected := types.AppContext{ - Status: 200, - Message: "OK", - ConfiguredProviders: []string{"username"}, - DisableContinue: false, - Title: "Tinyauth", - GenericName: "Generic", - ForgotPasswordMessage: "Message", - BackgroundImage: "https://example.com/image.png", - OAuthAutoRedirect: "none", - Domain: "localhost", - } - - // We should get the username back - if !reflect.DeepEqual(app, expected) { - t.Fatalf("Expected %v, got %v", expected, app) - } -} - -func TestUserContext(t *testing.T) { - // Refresh the cookie - TestLogin(t) - - t.Log("Testing user context") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("GET", "/api/user", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - body, err := io.ReadAll(recorder.Body) - if err != nil { - t.Fatalf("Error getting body: %v", err) - } - - type User struct { - Username string `json:"username"` - } - - var user User - - err = json.Unmarshal(body, &user) - if err != nil { - t.Fatalf("Error unmarshalling body: %v", err) - } - - // We should get the user back - if user.Username != "user" { - t.Fatalf("Expected user, got %s", user.Username) - } -} - -func TestLogout(t *testing.T) { - // Refresh the cookie - TestLogin(t) - - t.Log("Testing logout") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("POST", "/api/logout", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Check if the cookie is different (means the cookie is gone) - if recorder.Result().Cookies()[0].Value == cookie { - t.Fatalf("Cookie not flushed") - } -} - -func TestAuth(t *testing.T) { - // Refresh the cookie - TestLogin(t) - - t.Log("Testing auth endpoint") - - srv := getServer(t) - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("GET", "/api/auth/traefik", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.Header.Set("Accept", "text/html") - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusTemporaryRedirect) - - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("GET", "/api/auth/traefik", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("GET", "/api/auth/nginx", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusUnauthorized) - - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("GET", "/api/auth/nginx", nil) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) -} - -func TestTOTP(t *testing.T) { - t.Log("Testing TOTP") - - key, err := totp.Generate(totp.GenerateOpts{ - Issuer: "Tinyauth", - AccountName: user.Username, - }) - if err != nil { - t.Fatalf("Failed to generate TOTP secret: %v", err) - } - - secret := key.Secret() - - user.TotpSecret = secret - - srv := getServer(t) - - user := types.LoginRequest{ - Username: "user", - Password: "pass", - } - - loginJson, err := json.Marshal(user) - if err != nil { - t.Fatalf("Error marshalling json: %v", err) - } - - recorder := httptest.NewRecorder() - - req, err := http.NewRequest("POST", "/api/login", strings.NewReader(string(loginJson))) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) - - // Set the cookie for next test - cookie = recorder.Result().Cookies()[0].Value - - code, err := totp.GenerateCode(secret, time.Now()) - if err != nil { - t.Fatalf("Failed to generate TOTP code: %v", err) - } - - totpRequest := types.TotpRequest{ - Code: code, - } - - totpJson, err := json.Marshal(totpRequest) - if err != nil { - t.Fatalf("Error marshalling TOTP request: %v", err) - } - - recorder = httptest.NewRecorder() - - req, err = http.NewRequest("POST", "/api/totp", strings.NewReader(string(totpJson))) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - - req.AddCookie(&http.Cookie{ - Name: "tinyauth-session", - Value: cookie, - }) - - srv.Router.ServeHTTP(recorder, req) - assert.Equal(t, recorder.Code, http.StatusOK) -} diff --git a/internal/handlers/oauth.go b/internal/handlers/oauth.go deleted file mode 100644 index 13c3a47..0000000 --- a/internal/handlers/oauth.go +++ /dev/null @@ -1,223 +0,0 @@ -package handlers - -import ( - "fmt" - "net/http" - "strings" - "time" - "tinyauth/internal/types" - "tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "github.com/google/go-querystring/query" - "github.com/rs/zerolog/log" -) - -func (h *Handlers) OAuthURLHandler(c *gin.Context) { - var request types.OAuthRequest - - err := c.BindUri(&request) - if err != nil { - log.Error().Err(err).Msg("Failed to bind URI") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Got OAuth request") - - // Check if provider exists - provider := h.Providers.GetProvider(request.Provider) - - if provider == nil { - c.JSON(404, gin.H{ - "status": 404, - "message": "Not Found", - }) - return - } - - log.Debug().Str("provider", request.Provider).Msg("Got provider") - - // Create state - state := provider.GenerateState() - - // Get auth URL - authURL := provider.GetAuthURL(state) - - log.Debug().Msg("Got auth URL") - - // Set CSRF cookie - c.SetCookie(h.Config.CsrfCookieName, state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) - - // Get redirect URI - redirectURI := c.Query("redirect_uri") - - // Set redirect cookie if redirect URI is provided - if redirectURI != "" { - log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") - c.SetCookie(h.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) - } - - // Return auth URL - c.JSON(200, gin.H{ - "status": 200, - "message": "OK", - "url": authURL, - }) -} - -func (h *Handlers) OAuthCallbackHandler(c *gin.Context) { - var providerName types.OAuthRequest - - err := c.BindUri(&providerName) - if err != nil { - log.Error().Err(err).Msg("Failed to bind URI") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") - - // Get state - state := c.Query("state") - - // Get CSRF cookie - csrfCookie, err := c.Cookie(h.Config.CsrfCookieName) - - if err != nil { - log.Debug().Msg("No CSRF cookie") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Str("csrfCookie", csrfCookie).Msg("Got CSRF cookie") - - // Check if CSRF cookie is valid - if csrfCookie != state { - log.Warn().Msg("Invalid CSRF cookie or CSRF cookie does not match with the state") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - // Clean up CSRF cookie - c.SetCookie(h.Config.CsrfCookieName, "", -1, "/", "", h.Config.CookieSecure, true) - - // Get code - code := c.Query("code") - - log.Debug().Msg("Got code") - - // Get provider - provider := h.Providers.GetProvider(providerName.Provider) - - if provider == nil { - c.Redirect(http.StatusTemporaryRedirect, "/not-found") - return - } - - log.Debug().Str("provider", providerName.Provider).Msg("Got provider") - - // Exchange token (authenticates user) - _, err = provider.ExchangeToken(code) - if err != nil { - log.Error().Err(err).Msg("Failed to exchange token") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Msg("Got token") - - // Get user - user, err := h.Providers.GetUser(providerName.Provider) - if err != nil { - log.Error().Err(err).Msg("Failed to get user") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("user", user).Msg("Got user") - - // Check that email is not empty - if user.Email == "" { - log.Error().Msg("Email is empty") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - // Email is not whitelisted - if !h.Auth.EmailWhitelisted(user.Email) { - log.Warn().Str("email", user.Email).Msg("Email not whitelisted") - queries, err := query.Values(types.UnauthorizedQuery{ - Username: user.Email, - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - } - - log.Debug().Msg("Email whitelisted") - - // Get username - var username string - - if user.PreferredUsername != "" { - username = user.PreferredUsername - } else { - username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1]) - } - - // Get name - var name string - - if user.Name != "" { - name = user.Name - } else { - name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) - } - - // Create session cookie - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: username, - Name: name, - Email: user.Email, - Provider: providerName.Provider, - OAuthGroups: utils.CoalesceToString(user.Groups), - }) - - // Check if we have a redirect URI - redirectCookie, err := c.Cookie(h.Config.RedirectCookieName) - - if err != nil { - log.Debug().Msg("No redirect cookie") - c.Redirect(http.StatusTemporaryRedirect, h.Config.AppURL) - return - } - - log.Debug().Str("redirectURI", redirectCookie).Msg("Got redirect URI") - - queries, err := query.Values(types.LoginQuery{ - RedirectURI: redirectCookie, - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Msg("Got redirect query") - - // Clean up redirect cookie - c.SetCookie(h.Config.RedirectCookieName, "", -1, "/", "", h.Config.CookieSecure, true) - - // Redirect to continue with the redirect URI - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", h.Config.AppURL, queries.Encode())) -} diff --git a/internal/handlers/proxy.go b/internal/handlers/proxy.go deleted file mode 100644 index c9d234e..0000000 --- a/internal/handlers/proxy.go +++ /dev/null @@ -1,299 +0,0 @@ -package handlers - -import ( - "fmt" - "net/http" - "strings" - "tinyauth/internal/types" - "tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "github.com/google/go-querystring/query" - "github.com/rs/zerolog/log" -) - -func (h *Handlers) ProxyHandler(c *gin.Context) { - var proxy types.Proxy - - err := c.BindUri(&proxy) - if err != nil { - log.Error().Err(err).Msg("Failed to bind URI") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - // Check if the request is coming from a browser (tools like curl/bruno use */* and they don't include the text/html) - isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") - - if isBrowser { - log.Debug().Msg("Request is most likely coming from a browser") - } else { - log.Debug().Msg("Request is most likely not coming from a browser") - } - - log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") - - uri := c.Request.Header.Get("X-Forwarded-Uri") - proto := c.Request.Header.Get("X-Forwarded-Proto") - host := c.Request.Header.Get("X-Forwarded-Host") - - hostPortless := strings.Split(host, ":")[0] // *lol* - id := strings.Split(hostPortless, ".")[0] - - labels, err := h.Docker.GetLabels(id, hostPortless) - if err != nil { - log.Error().Err(err).Msg("Failed to get container labels") - - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("labels", labels).Msg("Got labels") - - ip := c.ClientIP() - - if h.Auth.BypassedIP(labels, ip) { - c.Header("Authorization", c.Request.Header.Get("Authorization")) - - headersParsed := utils.ParseHeaders(labels.Headers) - for key, value := range headersParsed { - log.Debug().Str("key", key).Msg("Setting header") - c.Header(key, value) - } - - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) - } - - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - return - } - - if !h.Auth.CheckIP(labels, ip) { - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(403, gin.H{ - "status": 403, - "message": "Forbidden", - }) - return - } - - values := types.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - IP: ip, - } - - queries, err := query.Values(values) - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - return - } - - authEnabled, err := h.Auth.AuthEnabled(uri, labels) - if err != nil { - log.Error().Err(err).Msg("Failed to check if app is allowed") - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - if !authEnabled { - c.Header("Authorization", c.Request.Header.Get("Authorization")) - - headersParsed := utils.ParseHeaders(labels.Headers) - for key, value := range headersParsed { - log.Debug().Str("key", key).Msg("Setting header") - c.Header(key, value) - } - - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) - } - - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - - return - } - - var userContext *types.UserContext - - userContextValue, exists := c.Get("context") - - if !exists { - userContext = &types.UserContext{ - IsLoggedIn: false, - } - } else { - var ok bool - userContext, ok = userContextValue.(*types.UserContext) - - if !ok { - userContext = &types.UserContext{ - IsLoggedIn: false, - } - } - } - - // If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth - if userContext.Provider == "basic" && userContext.TotpEnabled { - log.Warn().Str("username", userContext.Username).Msg("User has totp enabled, disabling basic auth") - userContext.IsLoggedIn = false - } - - if userContext.IsLoggedIn { - log.Debug().Msg("Authenticated") - - // Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx - appAllowed := h.Auth.ResourceAllowed(c, *userContext, labels) - - log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") - - if !appAllowed { - log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") - - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - values := types.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - } - - if userContext.OAuth { - values.Username = userContext.Email - } else { - values.Username = userContext.Username - } - - queries, err := query.Values(values) - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - return - } - - if userContext.OAuth { - groupOk := h.Auth.OAuthGroup(c, *userContext, labels) - - log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups") - - if !groupOk { - log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User is not in required groups") - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - values := types.UnauthorizedQuery{ - Resource: strings.Split(host, ".")[0], - GroupErr: true, - } - - if userContext.OAuth { - values.Username = userContext.Email - } else { - values.Username = userContext.Username - } - - queries, err := query.Values(values) - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) - return - } - } - - c.Header("Authorization", c.Request.Header.Get("Authorization")) - c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) - c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) - c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) - c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) - - parsedHeaders := utils.ParseHeaders(labels.Headers) - for key, value := range parsedHeaders { - log.Debug().Str("key", key).Msg("Setting header") - c.Header(key, value) - } - - if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { - log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") - c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) - } - - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - return - } - - // The user is not logged in - log.Debug().Msg("Unauthorized") - - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - queries, err := query.Values(types.LoginQuery{ - RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), - }) - - if err != nil { - log.Error().Err(err).Msg("Failed to build queries") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) - return - } - - log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/login?%s", h.Config.AppURL, queries.Encode())) -} diff --git a/internal/handlers/user.go b/internal/handlers/user.go deleted file mode 100644 index 86a18ee..0000000 --- a/internal/handlers/user.go +++ /dev/null @@ -1,215 +0,0 @@ -package handlers - -import ( - "fmt" - "strings" - "tinyauth/internal/types" - "tinyauth/internal/utils" - - "github.com/gin-gonic/gin" - "github.com/pquerna/otp/totp" - "github.com/rs/zerolog/log" -) - -func (h *Handlers) LoginHandler(c *gin.Context) { - var login types.LoginRequest - - err := c.BindJSON(&login) - if err != nil { - log.Error().Err(err).Msg("Failed to bind JSON") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Got login request") - - clientIP := c.ClientIP() - - // Create an identifier for rate limiting (username or IP if username doesn't exist yet) - rateIdentifier := login.Username - if rateIdentifier == "" { - rateIdentifier = clientIP - } - - // Check if the account is locked due to too many failed attempts - locked, remainingTime := h.Auth.IsAccountLocked(rateIdentifier) - if locked { - log.Warn().Str("identifier", rateIdentifier).Int("remaining_seconds", remainingTime).Msg("Account is locked due to too many failed login attempts") - c.JSON(429, gin.H{ - "status": 429, - "message": fmt.Sprintf("Too many failed login attempts. Try again in %d seconds", remainingTime), - }) - return - } - - // Search for a user based on username - log.Debug().Interface("username", login.Username).Msg("Searching for user") - - userSearch := h.Auth.SearchUser(login.Username) - - // User does not exist - if userSearch.Type == "" { - log.Debug().Str("username", login.Username).Msg("User not found") - // Record failed login attempt - h.Auth.RecordLoginAttempt(rateIdentifier, false) - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Got user") - - // Check if password is correct - if !h.Auth.VerifyUser(userSearch, login.Password) { - log.Debug().Str("username", login.Username).Msg("Password incorrect") - // Record failed login attempt - h.Auth.RecordLoginAttempt(rateIdentifier, false) - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Password correct, checking totp") - - // Record successful login attempt (will reset failed attempt counter) - h.Auth.RecordLoginAttempt(rateIdentifier, true) - - // Check if user is using TOTP - if userSearch.Type == "local" { - // Get local user - localUser := h.Auth.GetLocalUser(login.Username) - - // Check if TOTP is enabled - if localUser.TotpSecret != "" { - log.Debug().Msg("Totp enabled") - - // Set totp pending cookie - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: login.Username, - Name: utils.Capitalize(login.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), - Provider: "username", - TotpPending: true, - }) - - // Return totp required - c.JSON(200, gin.H{ - "status": 200, - "message": "Waiting for totp", - "totpPending": true, - }) - return - } - } - - // Create session cookie with username as provider - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: login.Username, - Name: utils.Capitalize(login.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain), - Provider: "username", - }) - - // Return logged in - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged in", - "totpPending": false, - }) -} - -func (h *Handlers) TOTPHandler(c *gin.Context) { - var totpReq types.TotpRequest - - err := c.BindJSON(&totpReq) - if err != nil { - log.Error().Err(err).Msg("Failed to bind JSON") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Checking totp") - - // Get user context - userContextValue, exists := c.Get("context") - - if !exists { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - userContext, ok := userContextValue.(*types.UserContext) - - if !ok { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - // Check if we have a user - if userContext.Username == "" { - log.Debug().Msg("No user context") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - // Get user - user := h.Auth.GetLocalUser(userContext.Username) - - // Check if totp is correct - ok = totp.Validate(totpReq.Code, user.TotpSecret) - - if !ok { - log.Debug().Msg("Totp incorrect") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Totp correct") - - // Create session cookie with username as provider - h.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: user.Username, - Name: utils.Capitalize(user.Username), - Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), h.Config.Domain), - Provider: "username", - }) - - // Return logged in - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged in", - }) -} - -func (h *Handlers) LogoutHandler(c *gin.Context) { - log.Debug().Msg("Cleaning up redirect cookie") - - h.Auth.DeleteSessionCookie(c) - - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged out", - }) -} diff --git a/internal/server/server.go b/internal/server/server.go deleted file mode 100644 index 78150fe..0000000 --- a/internal/server/server.go +++ /dev/null @@ -1,66 +0,0 @@ -package server - -import ( - "fmt" - "tinyauth/internal/handlers" - "tinyauth/internal/types" - - "github.com/gin-gonic/gin" - "github.com/rs/zerolog/log" -) - -type Server struct { - Config types.ServerConfig - Handlers *handlers.Handlers - Router *gin.Engine -} - -type Middleware interface { - Middleware() gin.HandlerFunc - Init() error - Name() string -} - -func NewServer(config types.ServerConfig, handlers *handlers.Handlers, middlewares []Middleware) (*Server, error) { - router := gin.New() - - for _, middleware := range middlewares { - log.Debug().Str("middleware", middleware.Name()).Msg("Initializing middleware") - err := middleware.Init() - if err != nil { - return nil, fmt.Errorf("failed to initialize middleware %s: %w", middleware.Name(), err) - } - router.Use(middleware.Middleware()) - } - - // Proxy routes - router.GET("/api/auth/:proxy", handlers.ProxyHandler) - - // Auth routes - router.POST("/api/login", handlers.LoginHandler) - router.POST("/api/totp", handlers.TOTPHandler) - router.POST("/api/logout", handlers.LogoutHandler) - - // Context routes - router.GET("/api/app", handlers.AppContextHandler) - router.GET("/api/user", handlers.UserContextHandler) - - // OAuth routes - router.GET("/api/oauth/url/:provider", handlers.OAuthURLHandler) - router.GET("/api/oauth/callback/:provider", handlers.OAuthCallbackHandler) - - // App routes - router.GET("/api/healthcheck", handlers.HealthcheckHandler) - router.HEAD("/api/healthcheck", handlers.HealthcheckHandler) - - return &Server{ - Config: config, - Handlers: handlers, - Router: router, - }, nil -} - -func (s *Server) Start() error { - log.Info().Str("address", s.Config.Address).Int("port", s.Config.Port).Msg("Starting server") - return s.Router.Run(fmt.Sprintf("%s:%d", s.Config.Address, s.Config.Port)) -} diff --git a/internal/types/api.go b/internal/types/api.go deleted file mode 100644 index fbf8bf7..0000000 --- a/internal/types/api.go +++ /dev/null @@ -1,62 +0,0 @@ -package types - -// LoginQuery is the query parameters for the login endpoint -type LoginQuery struct { - RedirectURI string `url:"redirect_uri"` -} - -// LoginRequest is the request body for the login endpoint -type LoginRequest struct { - Username string `json:"username"` - Password string `json:"password"` -} - -// OAuthRequest is the request for the OAuth endpoint -type OAuthRequest struct { - Provider string `uri:"provider" binding:"required"` -} - -// UnauthorizedQuery is the query parameters for the unauthorized endpoint -type UnauthorizedQuery struct { - Username string `url:"username"` - Resource string `url:"resource"` - GroupErr bool `url:"groupErr"` - IP string `url:"ip"` -} - -// Proxy is the uri parameters for the proxy endpoint -type Proxy struct { - Proxy string `uri:"proxy" binding:"required"` -} - -// User Context response is the response for the user context endpoint -type UserContextResponse struct { - Status int `json:"status"` - Message string `json:"message"` - IsLoggedIn bool `json:"isLoggedIn"` - Username string `json:"username"` - Name string `json:"name"` - Email string `json:"email"` - Provider string `json:"provider"` - Oauth bool `json:"oauth"` - TotpPending bool `json:"totpPending"` -} - -// App Context is the response for the app context endpoint -type AppContext struct { - Status int `json:"status"` - Message string `json:"message"` - ConfiguredProviders []string `json:"configuredProviders"` - DisableContinue bool `json:"disableContinue"` - Title string `json:"title"` - GenericName string `json:"genericName"` - Domain string `json:"domain"` - ForgotPasswordMessage string `json:"forgotPasswordMessage"` - BackgroundImage string `json:"backgroundImage"` - OAuthAutoRedirect string `json:"oauthAutoRedirect"` -} - -// Totp request is the request for the totp endpoint -type TotpRequest struct { - Code string `json:"code"` -} diff --git a/internal/types/config.go b/internal/types/config.go index 4b32ad9..dfb9e98 100644 --- a/internal/types/config.go +++ b/internal/types/config.go @@ -60,12 +60,6 @@ type OAuthConfig struct { AppURL string } -// ServerConfig is the configuration for the server -type ServerConfig struct { - Port int - Address string -} - // AuthConfig is the configuration for the auth service type AuthConfig struct { Users Users diff --git a/internal/types/types.go b/internal/types/types.go index 2c40ae5..1cb6bed 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -57,3 +57,14 @@ type LoginAttempt struct { LastAttempt time.Time LockedUntil time.Time } + +type UnauthorizedQuery struct { + Username string `url:"username"` + Resource string `url:"resource"` + GroupErr bool `url:"groupErr"` + IP string `url:"ip"` +} + +type RedirectQuery struct { + RedirectURI string `url:"redirect_uri"` +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 39b1518..8c2f4ea 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -13,6 +13,7 @@ import ( "strings" "tinyauth/internal/types" + "github.com/gin-gonic/gin" "github.com/traefik/paerser/parser" "golang.org/x/crypto/hkdf" @@ -348,3 +349,19 @@ func CoalesceToString(value any) string { return "" } } + +func GetContext(c *gin.Context) (types.UserContext, error) { + userContextValue, exists := c.Get("context") + + if !exists { + return types.UserContext{}, errors.New("no user context in request") + } + + userContext, ok := userContextValue.(*types.UserContext) + + if !ok { + return types.UserContext{}, errors.New("invalid user context in request") + } + + return *userContext, nil +}