diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 93255b0..24fd442 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -125,9 +125,9 @@ jobs: with: images: ghcr.io/${{ github.repository_owner }}/tinyauth tags: | - type=semver,pattern=v{{version}} - type=semver,pattern=v{{major}} - type=semver,pattern=v{{major}}.{{minor}} + type=semver,pattern={{version}},prefix=v + type=semver,pattern={{major}},prefix=v + type=semver,pattern={{major}}.{{minor}},prefix=v - name: Create manifest list and push working-directory: ${{ runner.temp }}/digests diff --git a/cmd/root.go b/cmd/root.go index 453d76f..2676b02 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -2,6 +2,7 @@ package cmd import ( "errors" + "fmt" "os" "strings" "time" @@ -11,6 +12,7 @@ import ( "tinyauth/internal/assets" "tinyauth/internal/auth" "tinyauth/internal/docker" + "tinyauth/internal/handlers" "tinyauth/internal/hooks" "tinyauth/internal/providers" "tinyauth/internal/types" @@ -33,8 +35,8 @@ var rootCmd = &cobra.Command{ // Get config var config types.Config - parseErr := viper.Unmarshal(&config) - HandleError(parseErr, "Failed to parse config") + err := viper.Unmarshal(&config) + HandleError(err, "Failed to parse config") // Secrets config.Secret = utils.GetSecret(config.Secret, config.SecretFile) @@ -45,8 +47,8 @@ var rootCmd = &cobra.Command{ // Validate config validator := validator.New() - validateErr := validator.Struct(config) - HandleError(validateErr, "Failed to validate config") + err = validator.Struct(config) + HandleError(err, "Failed to validate config") // Logger log.Logger = log.Level(zerolog.Level(config.LogLevel)) @@ -54,9 +56,8 @@ var rootCmd = &cobra.Command{ // Users log.Info().Msg("Parsing users") - users, usersErr := utils.GetUsers(config.Users, config.UsersFile) - - HandleError(usersErr, "Failed to parse users") + users, err := utils.GetUsers(config.Users, config.UsersFile) + HandleError(err, "Failed to parse users") if len(users) == 0 && !utils.OAuthConfigured(config) { HandleError(errors.New("no users or OAuth configured"), "No users or OAuth configured") @@ -66,8 +67,15 @@ var rootCmd = &cobra.Command{ oauthWhitelist := utils.Filter(strings.Split(config.OAuthWhitelist, ","), func(val string) bool { return val != "" }) + log.Debug().Msg("Parsed OAuth whitelist") + // Get domain + log.Debug().Msg("Getting domain") + domain, err := utils.GetUpperDomain(config.AppURL) + HandleError(err, "Failed to get upper domain") + log.Info().Str("domain", domain).Msg("Using domain for cookie store") + // Create OAuth config oauthConfig := types.OAuthConfig{ GithubClientId: config.GithubClientId, @@ -85,7 +93,25 @@ var rootCmd = &cobra.Command{ AppURL: config.AppURL, } - log.Debug().Msg("Parsed OAuth config") + // Create handlers config + serverConfig := types.HandlersConfig{ + AppURL: config.AppURL, + Domain: fmt.Sprintf(".%s", domain), + CookieSecure: config.CookieSecure, + DisableContinue: config.DisableContinue, + Title: config.Title, + GenericName: config.GenericName, + } + + // Create api config + apiConfig := types.APIConfig{ + Port: config.Port, + Address: config.Address, + Secret: config.Secret, + CookieSecure: config.CookieSecure, + SessionExpiry: config.SessionExpiry, + Domain: domain, + } // Create docker service docker := docker.NewDocker() @@ -106,18 +132,11 @@ var rootCmd = &cobra.Command{ // Create hooks service hooks := hooks.NewHooks(auth, providers) + // Create handlers + handlers := handlers.NewHandlers(serverConfig, auth, hooks, providers) + // Create API - api := api.NewAPI(types.APIConfig{ - Port: config.Port, - Address: config.Address, - Secret: config.Secret, - AppURL: config.AppURL, - CookieSecure: config.CookieSecure, - DisableContinue: config.DisableContinue, - SessionExpiry: config.SessionExpiry, - Title: config.Title, - GenericName: config.GenericName, - }, hooks, auth, providers) + api := api.NewAPI(apiConfig, handlers) // Setup routes api.Init() @@ -134,7 +153,7 @@ func Execute() { } func HandleError(err error, msg string) { - // If error log it and exit + // If error, log it and exit if err != nil { log.Fatal().Err(err).Msg(msg) } diff --git a/internal/api/api.go b/internal/api/api.go index 7001ea5..7d861f3 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -3,42 +3,30 @@ package api import ( "fmt" "io/fs" - "math/rand/v2" "net/http" - "os" "strings" "time" "tinyauth/internal/assets" - "tinyauth/internal/auth" - "tinyauth/internal/hooks" - "tinyauth/internal/providers" + "tinyauth/internal/handlers" "tinyauth/internal/types" - "tinyauth/internal/utils" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" - "github.com/google/go-querystring/query" - "github.com/pquerna/otp/totp" "github.com/rs/zerolog/log" ) -func NewAPI(config types.APIConfig, hooks *hooks.Hooks, auth *auth.Auth, providers *providers.Providers) *API { +func NewAPI(config types.APIConfig, handlers *handlers.Handlers) *API { return &API{ - Config: config, - Hooks: hooks, - Auth: auth, - Providers: providers, + Config: config, + Handlers: handlers, } } type API struct { - Config types.APIConfig - Router *gin.Engine - Hooks *hooks.Hooks - Auth *auth.Auth - Providers *providers.Providers - Domain string + Config types.APIConfig + Router *gin.Engine + Handlers *handlers.Handlers } func (api *API) Init() { @@ -52,10 +40,10 @@ func (api *API) Init() { // Read UI assets log.Debug().Msg("Setting up assets") - dist, distErr := fs.Sub(assets.Assets, "dist") + dist, err := fs.Sub(assets.Assets, "dist") - if distErr != nil { - log.Fatal().Err(distErr).Msg("Failed to get UI assets") + if err != nil { + log.Fatal().Err(err).Msg("Failed to get UI assets") } // Create file server @@ -66,22 +54,9 @@ func (api *API) Init() { log.Debug().Msg("Setting up cookie store") store := cookie.NewStore([]byte(api.Config.Secret)) - // Get domain to use for session cookies - log.Debug().Msg("Getting domain") - domain, domainErr := utils.GetRootURL(api.Config.AppURL) - - if domainErr != nil { - log.Fatal().Err(domainErr).Msg("Failed to get domain") - os.Exit(1) - } - - log.Info().Str("domain", domain).Msg("Using domain for cookies") - - api.Domain = fmt.Sprintf(".%s", domain) - // Use session middleware store.Options(sessions.Options{ - Domain: api.Domain, + Domain: api.Config.Domain, Path: "/", HttpOnly: true, Secure: api.Config.CookieSecure, @@ -94,17 +69,7 @@ func (api *API) Init() { router.Use(func(c *gin.Context) { // If not an API request, serve the UI if !strings.HasPrefix(c.Request.URL.Path, "/api") { - _, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) - - // If the file doesn't exist, serve the index.html - if os.IsNotExist(err) { - c.Request.URL.Path = "/" - } - - // Serve the file fileServer.ServeHTTP(c.Writer, c.Request) - - // Stop further processing c.Abort() } }) @@ -114,608 +79,24 @@ func (api *API) Init() { } func (api *API) SetupRoutes() { - api.Router.GET("/api/auth/:proxy", func(c *gin.Context) { - // Create struct for proxy - var proxy types.Proxy + // Proxy + api.Router.GET("/api/auth/:proxy", api.Handlers.AuthHandler) - // Bind URI - bindErr := c.BindUri(&proxy) + // Auth + api.Router.POST("/api/login", api.Handlers.LoginHandler) + api.Router.POST("/api/totp", api.Handlers.TotpHandler) + api.Router.POST("/api/logout", api.Handlers.LogoutHandler) - // Handle error - if bindErr != nil { - log.Error().Err(bindErr).Msg("Failed to bind URI") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } + // Context + api.Router.GET("/api/app", api.Handlers.AppHandler) + api.Router.GET("/api/user", api.Handlers.UserHandler) - // Check if the request is coming from a browser (tools like curl/bruno use */* and they don't include the text/html) - isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") + // OAuth + api.Router.GET("/api/oauth/url/:provider", api.Handlers.OauthUrlHandler) + api.Router.GET("/api/oauth/callback/:provider", api.Handlers.OauthCallbackHandler) - if isBrowser { - log.Debug().Msg("Request is most likely coming from a browser") - } else { - log.Debug().Msg("Request is most likely not coming from a browser") - } - - log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") - - // Check if auth is enabled - authEnabled, authEnabledErr := api.Auth.AuthEnabled(c) - - // Handle error - if authEnabledErr != nil { - // Return 500 if nginx is the proxy or if the request is not coming from a browser - if proxy.Proxy == "nginx" || !isBrowser { - log.Error().Err(authEnabledErr).Msg("Failed to check if auth is enabled") - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - // Return the internal server error page - if api.handleError(c, "Failed to check if auth is enabled", authEnabledErr) { - return - } - } - - // If auth is not enabled, return 200 - if !authEnabled { - // The user is allowed to access the app - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - - // Stop further processing - return - } - - // Get user context - userContext := api.Hooks.UseUserContext(c) - - // Get headers - uri := c.Request.Header.Get("X-Forwarded-Uri") - proto := c.Request.Header.Get("X-Forwarded-Proto") - host := c.Request.Header.Get("X-Forwarded-Host") - - // Check if user is logged in - if userContext.IsLoggedIn { - log.Debug().Msg("Authenticated") - - // Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx - appAllowed, appAllowedErr := api.Auth.ResourceAllowed(c, userContext) - - // Check if there was an error - if appAllowedErr != nil { - // Return 500 if nginx is the proxy or if the request is not coming from a browser - if proxy.Proxy == "nginx" || !isBrowser { - log.Error().Err(appAllowedErr).Msg("Failed to check if app is allowed") - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - // Return the internal server error page - if api.handleError(c, "Failed to check if app is allowed", appAllowedErr) { - return - } - } - - log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") - - // The user is not allowed to access the app - if !appAllowed { - log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") - - // Set WWW-Authenticate header - c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") - - // Return 401 if nginx is the proxy or if the request is not coming from a browser - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - // Build query - queries, queryErr := query.Values(types.UnauthorizedQuery{ - Username: userContext.Username, - Resource: strings.Split(host, ".")[0], - }) - - // Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) - if api.handleError(c, "Failed to build query", queryErr) { - return - } - - // We are using caddy/traefik so redirect - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, queries.Encode())) - - // Stop further processing - return - } - - // Set the user header - c.Header("Remote-User", userContext.Username) - - // The user is allowed to access the app - c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - }) - - // Stop further processing - return - } - - // The user is not logged in - log.Debug().Msg("Unauthorized") - - // Set www-authenticate header - c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") - - // Return 401 if nginx is the proxy or if the request is not coming from a browser - if proxy.Proxy == "nginx" || !isBrowser { - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - // Build query - queries, queryErr := query.Values(types.LoginQuery{ - RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), - }) - - // Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) - if api.handleError(c, "Failed to build query", queryErr) { - return - } - - log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") - - // Redirect to login - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode())) - }) - - api.Router.POST("/api/login", func(c *gin.Context) { - // Create login struct - var login types.LoginRequest - - // Bind JSON - err := c.BindJSON(&login) - - // Handle error - if err != nil { - log.Error().Err(err).Msg("Failed to bind JSON") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Got login request") - - // Get user based on username - user := api.Auth.GetUser(login.Username) - - // User does not exist - if user == nil { - log.Debug().Str("username", login.Username).Msg("User not found") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Got user") - - // Check if password is correct - if !api.Auth.CheckPassword(*user, login.Password) { - log.Debug().Str("username", login.Username).Msg("Password incorrect") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Password correct, checking totp") - - // Check if user has totp enabled - if user.TotpSecret != "" { - log.Debug().Msg("Totp enabled") - - // Set totp pending cookie - api.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: login.Username, - Provider: "username", - TotpPending: true, - }) - - // Return totp required - c.JSON(200, gin.H{ - "status": 200, - "message": "Waiting for totp", - "totpPending": true, - }) - - // Stop further processing - return - } - - // Create session cookie with username as provider - api.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: login.Username, - Provider: "username", - }) - - // Return logged in - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged in", - "totpPending": false, - }) - }) - - api.Router.POST("/api/totp", func(c *gin.Context) { - // Create totp struct - var totpReq types.TotpRequest - - // Bind JSON - err := c.BindJSON(&totpReq) - - // Handle error - if err != nil { - log.Error().Err(err).Msg("Failed to bind JSON") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Checking totp") - - // Get user context - userContext := api.Hooks.UseUserContext(c) - - // Check if we have a user - if userContext.Username == "" { - log.Debug().Msg("No user context") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - // Get user - user := api.Auth.GetUser(userContext.Username) - - // Check if user exists - if user == nil { - log.Debug().Msg("User not found") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - // Check if totp is correct - totpOk := totp.Validate(totpReq.Code, user.TotpSecret) - - // TOTP is incorrect - if !totpOk { - log.Debug().Msg("Totp incorrect") - c.JSON(401, gin.H{ - "status": 401, - "message": "Unauthorized", - }) - return - } - - log.Debug().Msg("Totp correct") - - // Create session cookie with username as provider - api.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: user.Username, - Provider: "username", - }) - - // Return logged in - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged in", - }) - }) - - api.Router.POST("/api/logout", func(c *gin.Context) { - log.Debug().Msg("Logging out") - - // Delete session cookie - api.Auth.DeleteSessionCookie(c) - - log.Debug().Msg("Cleaning up redirect cookie") - - // Clean up redirect cookie if it exists - c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) - - // Return logged out - c.JSON(200, gin.H{ - "status": 200, - "message": "Logged out", - }) - }) - - api.Router.GET("/api/app", func(c *gin.Context) { - log.Debug().Msg("Getting app context") - - // Get configured providers - configuredProviders := api.Providers.GetConfiguredProviders() - - // We have username/password configured so add it to our providers - if api.Auth.UserAuthConfigured() { - configuredProviders = append(configuredProviders, "username") - } - - // Create app context struct - appContext := types.AppContext{ - Status: 200, - Message: "Ok", - ConfiguredProviders: configuredProviders, - DisableContinue: api.Config.DisableContinue, - Title: api.Config.Title, - GenericName: api.Config.GenericName, - } - - // Return app context - c.JSON(200, appContext) - }) - - api.Router.GET("/api/user", func(c *gin.Context) { - log.Debug().Msg("Getting user context") - - // Get user context - userContext := api.Hooks.UseUserContext(c) - - // Create user context response - userContextResponse := types.UserContextResponse{ - Status: 200, - IsLoggedIn: userContext.IsLoggedIn, - Username: userContext.Username, - Provider: userContext.Provider, - Oauth: userContext.OAuth, - TotpPending: userContext.TotpPending, - } - - // If we are not logged in we set the status to 401 and add the WWW-Authenticate header else we set it to 200 - if !userContext.IsLoggedIn { - log.Debug().Msg("Unauthorized") - c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") - userContextResponse.Message = "Unauthorized" - } else { - log.Debug().Interface("userContext", userContext).Msg("Authenticated") - userContextResponse.Message = "Authenticated" - } - - // Return user context - c.JSON(200, userContextResponse) - }) - - api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) { - // Create struct for OAuth request - var request types.OAuthRequest - - // Bind URI - bindErr := c.BindUri(&request) - - // Handle error - if bindErr != nil { - log.Error().Err(bindErr).Msg("Failed to bind URI") - c.JSON(400, gin.H{ - "status": 400, - "message": "Bad Request", - }) - return - } - - log.Debug().Msg("Got OAuth request") - - // Check if provider exists - provider := api.Providers.GetProvider(request.Provider) - - // Provider does not exist - if provider == nil { - c.JSON(404, gin.H{ - "status": 404, - "message": "Not Found", - }) - return - } - - log.Debug().Str("provider", request.Provider).Msg("Got provider") - - // Get auth URL - authURL := provider.GetAuthURL() - - log.Debug().Msg("Got auth URL") - - // Get redirect URI - redirectURI := c.Query("redirect_uri") - - // Set redirect cookie if redirect URI is provided - if redirectURI != "" { - log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") - c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true) - } - - // Tailscale does not have an auth url so we create a random code (does not need to be secure) to avoid caching and send it - if request.Provider == "tailscale" { - // Build tailscale query - tailscaleQuery, tailscaleQueryErr := query.Values(types.TailscaleQuery{ - Code: (1000 + rand.IntN(9000)), - }) - - // Handle error - if tailscaleQueryErr != nil { - log.Error().Err(tailscaleQueryErr).Msg("Failed to build query") - c.JSON(500, gin.H{ - "status": 500, - "message": "Internal Server Error", - }) - return - } - - // Return tailscale URL (immidiately redirects to the callback) - c.JSON(200, gin.H{ - "status": 200, - "message": "Ok", - "url": fmt.Sprintf("%s/api/oauth/callback/tailscale?%s", api.Config.AppURL, tailscaleQuery.Encode()), - }) - return - } - - // Return auth URL - c.JSON(200, gin.H{ - "status": 200, - "message": "Ok", - "url": authURL, - }) - }) - - api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) { - // Create struct for OAuth request - var providerName types.OAuthRequest - - // Bind URI - bindErr := c.BindUri(&providerName) - - // Handle error - if api.handleError(c, "Failed to bind URI", bindErr) { - return - } - - log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") - - // Get code - code := c.Query("code") - - // Code empty so redirect to error - if code == "" { - log.Error().Msg("No code provided") - c.Redirect(http.StatusPermanentRedirect, "/error") - return - } - - log.Debug().Msg("Got code") - - // Get provider - provider := api.Providers.GetProvider(providerName.Provider) - - log.Debug().Str("provider", providerName.Provider).Msg("Got provider") - - // Provider does not exist - if provider == nil { - c.Redirect(http.StatusPermanentRedirect, "/not-found") - return - } - - // Exchange token (authenticates user) - _, tokenErr := provider.ExchangeToken(code) - - log.Debug().Msg("Got token") - - // Handle error - if api.handleError(c, "Failed to exchange token", tokenErr) { - return - } - - // Get email - email, emailErr := api.Providers.GetUser(providerName.Provider) - - log.Debug().Str("email", email).Msg("Got email") - - // Handle error - if api.handleError(c, "Failed to get user", emailErr) { - return - } - - // Email is not whitelisted - if !api.Auth.EmailWhitelisted(email) { - log.Warn().Str("email", email).Msg("Email not whitelisted") - - // Build query - unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{ - Username: email, - }) - - // Handle error - if api.handleError(c, "Failed to build query", unauthorizedQueryErr) { - return - } - - // Redirect to unauthorized - c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode())) - } - - log.Debug().Msg("Email whitelisted") - - // Create session cookie - api.Auth.CreateSessionCookie(c, &types.SessionCookie{ - Username: email, - Provider: providerName.Provider, - }) - - // Get redirect URI - redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") - - // If it is empty it means that no redirect_uri was provided to the login screen so we just log in - if redirectURIErr != nil { - c.Redirect(http.StatusPermanentRedirect, api.Config.AppURL) - } - - log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI") - - // Clean up redirect cookie since we already have the value - c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) - - // Build query - redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{ - RedirectURI: redirectURI, - }) - - log.Debug().Msg("Got redirect query") - - // Handle error - if api.handleError(c, "Failed to build query", redirectQueryErr) { - return - } - - // Redirect to continue with the redirect URI - c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode())) - }) - - // Simple healthcheck - api.Router.GET("/api/healthcheck", func(c *gin.Context) { - c.JSON(200, gin.H{ - "status": 200, - "message": "OK", - }) - }) + // App + api.Router.GET("/api/healthcheck", api.Handlers.HealthcheckHandler) } func (api *API) Run() { @@ -724,23 +105,12 @@ func (api *API) Run() { // Run server err := api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port)) - // Check error + // Check for errors if err != nil { log.Fatal().Err(err).Msg("Failed to start server") } } -// handleError logs the error and redirects to the error page (only meant for stuff the user may access does not apply for login paths) -func (api *API) handleError(c *gin.Context, msg string, err error) bool { - // If error is not nil log it and redirect to error page also return true so we can stop further processing - if err != nil { - log.Error().Err(err).Msg(msg) - c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", api.Config.AppURL)) - return true - } - return false -} - // zerolog is a middleware for gin that logs requests using zerolog func zerolog() gin.HandlerFunc { return func(c *gin.Context) { diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 28c01b2..a110519 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -5,11 +5,13 @@ import ( "io" "net/http" "net/http/httptest" + "reflect" "strings" "testing" "tinyauth/internal/api" "tinyauth/internal/auth" "tinyauth/internal/docker" + "tinyauth/internal/handlers" "tinyauth/internal/hooks" "tinyauth/internal/providers" "tinyauth/internal/types" @@ -19,13 +21,21 @@ import ( // Simple API config for tests var apiConfig = types.APIConfig{ - Port: 8080, - Address: "0.0.0.0", - Secret: "super-secret-api-thing-for-tests", // It is 32 chars long - AppURL: "http://tinyauth.localhost", + Port: 8080, + Address: "0.0.0.0", + Secret: "super-secret-api-thing-for-tests", // It is 32 chars long + CookieSecure: false, + SessionExpiry: 3600, +} + +// Simple handlers config for tests +var handlersConfig = types.HandlersConfig{ + AppURL: "http://localhost:8080", + Domain: ".localhost", CookieSecure: false, - SessionExpiry: 3600, DisableContinue: false, + Title: "Tinyauth", + GenericName: "Generic", } // Cookie @@ -67,8 +77,11 @@ func getAPI(t *testing.T) *api.API { // Create hooks service hooks := hooks.NewHooks(auth, providers) + // Create handlers service + handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers) + // Create API - api := api.NewAPI(apiConfig, hooks, auth, providers) + api := api.NewAPI(apiConfig, handlers) // Setup routes api.Init() @@ -123,6 +136,70 @@ func TestLogin(t *testing.T) { } } +// Test app context +func TestAppContext(t *testing.T) { + t.Log("Testing app context") + + // Get API + api := getAPI(t) + + // Create recorder + recorder := httptest.NewRecorder() + + // Create request + req, err := http.NewRequest("GET", "/api/app", nil) + + // Check if there was an error + if err != nil { + t.Fatalf("Error creating request: %v", err) + } + + // Set the cookie + req.AddCookie(&http.Cookie{ + Name: "tinyauth", + Value: cookie, + }) + + // Serve the request + api.Router.ServeHTTP(recorder, req) + + // Assert + assert.Equal(t, recorder.Code, http.StatusOK) + + // Read the body of the response + body, bodyErr := io.ReadAll(recorder.Body) + + // Check if there was an error + if bodyErr != nil { + t.Fatalf("Error getting body: %v", bodyErr) + } + + // Unmarshal the body into the user struct + var app types.AppContext + + jsonErr := json.Unmarshal(body, &app) + + // Check if there was an error + if jsonErr != nil { + t.Fatalf("Error unmarshalling body: %v", jsonErr) + } + + // Create tests values + expected := types.AppContext{ + Status: 200, + Message: "OK", + ConfiguredProviders: []string{"username"}, + DisableContinue: false, + Title: "Tinyauth", + GenericName: "Generic", + } + + // We should get the username back + if !reflect.DeepEqual(app, expected) { + t.Fatalf("Expected %v, got %v", expected, app) + } +} + // Test user context func TestUserContext(t *testing.T) { t.Log("Testing user context") diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go new file mode 100644 index 0000000..fc98bd1 --- /dev/null +++ b/internal/handlers/handlers.go @@ -0,0 +1,634 @@ +package handlers + +import ( + "fmt" + "math/rand/v2" + "net/http" + "strings" + "tinyauth/internal/auth" + "tinyauth/internal/hooks" + "tinyauth/internal/providers" + "tinyauth/internal/types" + + "github.com/gin-gonic/gin" + "github.com/google/go-querystring/query" + "github.com/pquerna/otp/totp" + "github.com/rs/zerolog/log" +) + +func NewHandlers(config types.HandlersConfig, auth *auth.Auth, hooks *hooks.Hooks, providers *providers.Providers) *Handlers { + return &Handlers{ + Config: config, + Auth: auth, + Hooks: hooks, + Providers: providers, + } +} + +type Handlers struct { + Config types.HandlersConfig + Auth *auth.Auth + Hooks *hooks.Hooks + Providers *providers.Providers +} + +func (h *Handlers) AuthHandler(c *gin.Context) { + // Create struct for proxy + var proxy types.Proxy + + // Bind URI + err := c.BindUri(&proxy) + + // Handle error + if err != nil { + log.Error().Err(err).Msg("Failed to bind URI") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + // Check if the request is coming from a browser (tools like curl/bruno use */* and they don't include the text/html) + isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") + + if isBrowser { + log.Debug().Msg("Request is most likely coming from a browser") + } else { + log.Debug().Msg("Request is most likely not coming from a browser") + } + + log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") + + // Check if auth is enabled + authEnabled, err := h.Auth.AuthEnabled(c) + + // Handle error + if err != nil { + log.Error().Err(err).Msg("Failed to check if auth is enabled") + + if proxy.Proxy == "nginx" || !isBrowser { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + // If auth is not enabled, return 200 + if !authEnabled { + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + // Get user context + userContext := h.Hooks.UseUserContext(c) + + // Get headers + uri := c.Request.Header.Get("X-Forwarded-Uri") + proto := c.Request.Header.Get("X-Forwarded-Proto") + host := c.Request.Header.Get("X-Forwarded-Host") + + // Check if user is logged in + if userContext.IsLoggedIn { + log.Debug().Msg("Authenticated") + + // Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx + appAllowed, err := h.Auth.ResourceAllowed(c, userContext) + + // Check if there was an error + if err != nil { + log.Error().Err(err).Msg("Failed to check if app is allowed") + + if proxy.Proxy == "nginx" || !isBrowser { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") + + // The user is not allowed to access the app + if !appAllowed { + log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") + + // Set WWW-Authenticate header + c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") + + if proxy.Proxy == "nginx" || !isBrowser { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + // Build query + queries, err := query.Values(types.UnauthorizedQuery{ + Username: userContext.Username, + Resource: strings.Split(host, ".")[0], + }) + + // Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) + if err != nil { + log.Error().Err(err).Msg("Failed to build query") + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + // We are using caddy/traefik so redirect + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) + return + } + + // Set the user header + c.Header("Remote-User", userContext.Username) + + // The user is allowed to access the app + c.JSON(200, gin.H{ + "status": 200, + "message": "Authenticated", + }) + return + } + + // The user is not logged in + log.Debug().Msg("Unauthorized") + + // Set www-authenticate header + c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") + + if proxy.Proxy == "nginx" || !isBrowser { + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + queries, err := query.Values(types.LoginQuery{ + RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), + }) + + if err != nil { + log.Error().Err(err).Msg("Failed to build query") + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") + + // Redirect to login + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", h.Config.AppURL, queries.Encode())) +} + +func (h *Handlers) LoginHandler(c *gin.Context) { + // Create login struct + var login types.LoginRequest + + // Bind JSON + err := c.BindJSON(&login) + + // Handle error + if err != nil { + log.Error().Err(err).Msg("Failed to bind JSON") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + log.Debug().Msg("Got login request") + + // Get user based on username + user := h.Auth.GetUser(login.Username) + + // User does not exist + if user == nil { + log.Debug().Str("username", login.Username).Msg("User not found") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + log.Debug().Msg("Got user") + + // Check if password is correct + if !h.Auth.CheckPassword(*user, login.Password) { + log.Debug().Str("username", login.Username).Msg("Password incorrect") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + log.Debug().Msg("Password correct, checking totp") + + // Check if user has totp enabled + if user.TotpSecret != "" { + log.Debug().Msg("Totp enabled") + + // Set totp pending cookie + h.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: login.Username, + Provider: "username", + TotpPending: true, + }) + + // Return totp required + c.JSON(200, gin.H{ + "status": 200, + "message": "Waiting for totp", + "totpPending": true, + }) + + // Stop further processing + return + } + + // Create session cookie with username as provider + h.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: login.Username, + Provider: "username", + }) + + // Return logged in + c.JSON(200, gin.H{ + "status": 200, + "message": "Logged in", + "totpPending": false, + }) +} + +func (h *Handlers) TotpHandler(c *gin.Context) { + // Create totp struct + var totpReq types.TotpRequest + + // Bind JSON + err := c.BindJSON(&totpReq) + + // Handle error + if err != nil { + log.Error().Err(err).Msg("Failed to bind JSON") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + log.Debug().Msg("Checking totp") + + // Get user context + userContext := h.Hooks.UseUserContext(c) + + // Check if we have a user + if userContext.Username == "" { + log.Debug().Msg("No user context") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + // Get user + user := h.Auth.GetUser(userContext.Username) + + // Check if user exists + if user == nil { + log.Debug().Msg("User not found") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + // Check if totp is correct + totpOk := totp.Validate(totpReq.Code, user.TotpSecret) + + // TOTP is incorrect + if !totpOk { + log.Debug().Msg("Totp incorrect") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } + + log.Debug().Msg("Totp correct") + + // Create session cookie with username as provider + h.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: user.Username, + Provider: "username", + }) + + // Return logged in + c.JSON(200, gin.H{ + "status": 200, + "message": "Logged in", + }) +} + +func (h *Handlers) LogoutHandler(c *gin.Context) { + log.Debug().Msg("Logging out") + + // Delete session cookie + h.Auth.DeleteSessionCookie(c) + + log.Debug().Msg("Cleaning up redirect cookie") + + // Clean up redirect cookie if it exists + c.SetCookie("tinyauth_redirect_uri", "", -1, "/", h.Config.Domain, h.Config.CookieSecure, true) + + // Return logged out + c.JSON(200, gin.H{ + "status": 200, + "message": "Logged out", + }) +} + +func (h *Handlers) AppHandler(c *gin.Context) { + log.Debug().Msg("Getting app context") + + // Get configured providers + configuredProviders := h.Providers.GetConfiguredProviders() + + // We have username/password configured so add it to our providers + if h.Auth.UserAuthConfigured() { + configuredProviders = append(configuredProviders, "username") + } + + // Create app context struct + appContext := types.AppContext{ + Status: 200, + Message: "OK", + ConfiguredProviders: configuredProviders, + DisableContinue: h.Config.DisableContinue, + Title: h.Config.Title, + GenericName: h.Config.GenericName, + } + + // Return app context + c.JSON(200, appContext) +} + +func (h *Handlers) UserHandler(c *gin.Context) { + log.Debug().Msg("Getting user context") + + // Get user context + userContext := h.Hooks.UseUserContext(c) + + // Create user context response + userContextResponse := types.UserContextResponse{ + Status: 200, + IsLoggedIn: userContext.IsLoggedIn, + Username: userContext.Username, + Provider: userContext.Provider, + Oauth: userContext.OAuth, + TotpPending: userContext.TotpPending, + } + + // If we are not logged in we set the status to 401 and add the WWW-Authenticate header else we set it to 200 + if !userContext.IsLoggedIn { + log.Debug().Msg("Unauthorized") + c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") + userContextResponse.Message = "Unauthorized" + } else { + log.Debug().Interface("userContext", userContext).Msg("Authenticated") + userContextResponse.Message = "Authenticated" + } + + // Return user context + c.JSON(200, userContextResponse) +} + +func (h *Handlers) OauthUrlHandler(c *gin.Context) { + // Create struct for OAuth request + var request types.OAuthRequest + + // Bind URI + err := c.BindUri(&request) + + // Handle error + if err != nil { + log.Error().Err(err).Msg("Failed to bind URI") + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + log.Debug().Msg("Got OAuth request") + + // Check if provider exists + provider := h.Providers.GetProvider(request.Provider) + + // Provider does not exist + if provider == nil { + c.JSON(404, gin.H{ + "status": 404, + "message": "Not Found", + }) + return + } + + log.Debug().Str("provider", request.Provider).Msg("Got provider") + + // Get auth URL + authURL := provider.GetAuthURL() + + log.Debug().Msg("Got auth URL") + + // Get redirect URI + redirectURI := c.Query("redirect_uri") + + // Set redirect cookie if redirect URI is provided + if redirectURI != "" { + log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") + c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", h.Config.Domain, h.Config.CookieSecure, true) + } + + // Tailscale does not have an auth url so we create a random code (does not need to be secure) to avoid caching and send it + if request.Provider == "tailscale" { + // Build tailscale query + tailscaleQuery, err := query.Values(types.TailscaleQuery{ + Code: (1000 + rand.IntN(9000)), + }) + + // Handle error + if err != nil { + log.Error().Err(err).Msg("Failed to build query") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + // Return tailscale URL (immidiately redirects to the callback) + c.JSON(200, gin.H{ + "status": 200, + "message": "OK", + "url": fmt.Sprintf("%s/api/oauth/callback/tailscale?%s", h.Config.AppURL, tailscaleQuery.Encode()), + }) + return + } + + // Return auth URL + c.JSON(200, gin.H{ + "status": 200, + "message": "OK", + "url": authURL, + }) +} + +func (h *Handlers) OauthCallbackHandler(c *gin.Context) { + // Create struct for OAuth request + var providerName types.OAuthRequest + + // Bind URI + err := c.BindUri(&providerName) + + // Handle error + if err != nil { + log.Error().Err(err).Msg("Failed to bind URI") + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") + + // Get code + code := c.Query("code") + + // Code empty so redirect to error + if code == "" { + log.Error().Msg("No code provided") + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + log.Debug().Msg("Got code") + + // Get provider + provider := h.Providers.GetProvider(providerName.Provider) + + log.Debug().Str("provider", providerName.Provider).Msg("Got provider") + + // Provider does not exist + if provider == nil { + c.Redirect(http.StatusPermanentRedirect, "/not-found") + return + } + + // Exchange token (authenticates user) + _, err = provider.ExchangeToken(code) + + log.Debug().Msg("Got token") + + // Handle error + if err != nil { + log.Error().Msg("Failed to exchange token") + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + // Get email + email, err := h.Providers.GetUser(providerName.Provider) + + log.Debug().Str("email", email).Msg("Got email") + + // Handle error + if err != nil { + log.Error().Msg("Failed to get email") + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + // Email is not whitelisted + if !h.Auth.EmailWhitelisted(email) { + log.Warn().Str("email", email).Msg("Email not whitelisted") + + // Build query + unauthorizedQuery, err := query.Values(types.UnauthorizedQuery{ + Username: email, + }) + + // Handle error + if err != nil { + log.Error().Msg("Failed to build query") + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + // Redirect to unauthorized + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, unauthorizedQuery.Encode())) + } + + log.Debug().Msg("Email whitelisted") + + // Create session cookie + h.Auth.CreateSessionCookie(c, &types.SessionCookie{ + Username: email, + Provider: providerName.Provider, + }) + + // Get redirect URI + redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") + + // If it is empty it means that no redirect_uri was provided to the login screen so we just log in + if redirectURIErr != nil { + c.Redirect(http.StatusPermanentRedirect, h.Config.AppURL) + } + + log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI") + + // Clean up redirect cookie since we already have the value + c.SetCookie("tinyauth_redirect_uri", "", -1, "/", h.Config.Domain, h.Config.CookieSecure, true) + + // Build query + redirectQuery, err := query.Values(types.LoginQuery{ + RedirectURI: redirectURI, + }) + + log.Debug().Msg("Got redirect query") + + // Handle error + if err != nil { + log.Error().Msg("Failed to build query") + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) + return + } + + // Redirect to continue with the redirect URI + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", h.Config.AppURL, redirectQuery.Encode())) +} + +func (h *Handlers) HealthcheckHandler(c *gin.Context) { + c.JSON(200, gin.H{ + "status": 200, + "message": "OK", + }) +} diff --git a/internal/types/types.go b/internal/types/types.go index 0f4ead7..16f2482 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -69,15 +69,12 @@ type UserContext struct { // APIConfig is the configuration for the API type APIConfig struct { - Port int - Address string - Secret string - AppURL string - CookieSecure bool - SessionExpiry int - DisableContinue bool - GenericName string - Title string + Port int + Address string + Secret string + CookieSecure bool + SessionExpiry int + Domain string } // OAuthConfig is the configuration for the providers @@ -164,3 +161,13 @@ type AppContext struct { type TotpRequest struct { Code string `json:"code"` } + +// Server configuration +type HandlersConfig struct { + AppURL string + Domain string + CookieSecure bool + DisableContinue bool + GenericName string + Title string +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 1e68aee..efda749 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -46,14 +46,14 @@ func ParseUsers(users string) (types.Users, error) { return usersParsed, nil } -// Root url parses parses a hostname and returns the root domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) -func GetRootURL(urlSrc string) (string, error) { +// Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) +func GetUpperDomain(urlSrc string) (string, error) { // Make sure the url is valid - urlParsed, parseErr := url.Parse(urlSrc) + urlParsed, err := url.Parse(urlSrc) // Check if there was an error - if parseErr != nil { - return "", parseErr + if err != nil { + return "", err } // Split the hostname by period diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index b3774ce..52e4dcd 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -38,15 +38,15 @@ func TestParseUsers(t *testing.T) { } } -// Test the get root url function -func TestGetRootURL(t *testing.T) { - t.Log("Testing get root url with a valid url") +// Test the get upper domain function +func TestGetUpperDomain(t *testing.T) { + t.Log("Testing get upper domain with a valid url") - // Test the get root url function with a valid url + // Test the get upper domain function with a valid url url := "https://sub1.sub2.domain.com:8080" expected := "sub2.domain.com" - result, err := utils.GetRootURL(url) + result, err := utils.GetUpperDomain(url) // Check if there was an error if err != nil {