mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-10-31 14:15:50 +00:00 
			
		
		
		
	Compare commits
	
		
			3 Commits
		
	
	
		
			81944e770e
			...
			refactor/h
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 156bba6141 | ||
|   | 7d1252f3c7 | ||
|   | 0c91465c63 | 
							
								
								
									
										6
									
								
								.github/workflows/release.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/release.yml
									
									
									
									
										vendored
									
									
								
							| @@ -125,9 +125,9 @@ jobs: | |||||||
|         with: |         with: | ||||||
|           images: ghcr.io/${{ github.repository_owner }}/tinyauth |           images: ghcr.io/${{ github.repository_owner }}/tinyauth | ||||||
|           tags: | |           tags: | | ||||||
|             type=semver,pattern=v{{version}} |             type=semver,pattern={{version}},prefix=v | ||||||
|             type=semver,pattern=v{{major}} |             type=semver,pattern={{major}},prefix=v | ||||||
|             type=semver,pattern=v{{major}}.{{minor}} |             type=semver,pattern={{major}}.{{minor}},prefix=v | ||||||
|  |  | ||||||
|       - name: Create manifest list and push |       - name: Create manifest list and push | ||||||
|         working-directory: ${{ runner.temp }}/digests |         working-directory: ${{ runner.temp }}/digests | ||||||
|   | |||||||
							
								
								
									
										59
									
								
								cmd/root.go
									
									
									
									
									
								
							
							
						
						
									
										59
									
								
								cmd/root.go
									
									
									
									
									
								
							| @@ -2,6 +2,7 @@ package cmd | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
| 	"os" | 	"os" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -11,6 +12,7 @@ import ( | |||||||
| 	"tinyauth/internal/assets" | 	"tinyauth/internal/assets" | ||||||
| 	"tinyauth/internal/auth" | 	"tinyauth/internal/auth" | ||||||
| 	"tinyauth/internal/docker" | 	"tinyauth/internal/docker" | ||||||
|  | 	"tinyauth/internal/handlers" | ||||||
| 	"tinyauth/internal/hooks" | 	"tinyauth/internal/hooks" | ||||||
| 	"tinyauth/internal/providers" | 	"tinyauth/internal/providers" | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
| @@ -33,8 +35,8 @@ var rootCmd = &cobra.Command{ | |||||||
|  |  | ||||||
| 		// Get config | 		// Get config | ||||||
| 		var config types.Config | 		var config types.Config | ||||||
| 		parseErr := viper.Unmarshal(&config) | 		err := viper.Unmarshal(&config) | ||||||
| 		HandleError(parseErr, "Failed to parse config") | 		HandleError(err, "Failed to parse config") | ||||||
|  |  | ||||||
| 		// Secrets | 		// Secrets | ||||||
| 		config.Secret = utils.GetSecret(config.Secret, config.SecretFile) | 		config.Secret = utils.GetSecret(config.Secret, config.SecretFile) | ||||||
| @@ -45,8 +47,8 @@ var rootCmd = &cobra.Command{ | |||||||
|  |  | ||||||
| 		// Validate config | 		// Validate config | ||||||
| 		validator := validator.New() | 		validator := validator.New() | ||||||
| 		validateErr := validator.Struct(config) | 		err = validator.Struct(config) | ||||||
| 		HandleError(validateErr, "Failed to validate config") | 		HandleError(err, "Failed to validate config") | ||||||
|  |  | ||||||
| 		// Logger | 		// Logger | ||||||
| 		log.Logger = log.Level(zerolog.Level(config.LogLevel)) | 		log.Logger = log.Level(zerolog.Level(config.LogLevel)) | ||||||
| @@ -54,9 +56,8 @@ var rootCmd = &cobra.Command{ | |||||||
|  |  | ||||||
| 		// Users | 		// Users | ||||||
| 		log.Info().Msg("Parsing users") | 		log.Info().Msg("Parsing users") | ||||||
| 		users, usersErr := utils.GetUsers(config.Users, config.UsersFile) | 		users, err := utils.GetUsers(config.Users, config.UsersFile) | ||||||
|  | 		HandleError(err, "Failed to parse users") | ||||||
| 		HandleError(usersErr, "Failed to parse users") |  | ||||||
|  |  | ||||||
| 		if len(users) == 0 && !utils.OAuthConfigured(config) { | 		if len(users) == 0 && !utils.OAuthConfigured(config) { | ||||||
| 			HandleError(errors.New("no users or OAuth configured"), "No users or OAuth configured") | 			HandleError(errors.New("no users or OAuth configured"), "No users or OAuth configured") | ||||||
| @@ -66,8 +67,15 @@ var rootCmd = &cobra.Command{ | |||||||
| 		oauthWhitelist := utils.Filter(strings.Split(config.OAuthWhitelist, ","), func(val string) bool { | 		oauthWhitelist := utils.Filter(strings.Split(config.OAuthWhitelist, ","), func(val string) bool { | ||||||
| 			return val != "" | 			return val != "" | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Parsed OAuth whitelist") | 		log.Debug().Msg("Parsed OAuth whitelist") | ||||||
|  |  | ||||||
|  | 		// Get domain | ||||||
|  | 		log.Debug().Msg("Getting domain") | ||||||
|  | 		domain, err := utils.GetUpperDomain(config.AppURL) | ||||||
|  | 		HandleError(err, "Failed to get upper domain") | ||||||
|  | 		log.Info().Str("domain", domain).Msg("Using domain for cookie store") | ||||||
|  |  | ||||||
| 		// Create OAuth config | 		// Create OAuth config | ||||||
| 		oauthConfig := types.OAuthConfig{ | 		oauthConfig := types.OAuthConfig{ | ||||||
| 			GithubClientId:        config.GithubClientId, | 			GithubClientId:        config.GithubClientId, | ||||||
| @@ -85,7 +93,25 @@ var rootCmd = &cobra.Command{ | |||||||
| 			AppURL:                config.AppURL, | 			AppURL:                config.AppURL, | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Parsed OAuth config") | 		// Create handlers config | ||||||
|  | 		serverConfig := types.HandlersConfig{ | ||||||
|  | 			AppURL:          config.AppURL, | ||||||
|  | 			Domain:          fmt.Sprintf(".%s", domain), | ||||||
|  | 			CookieSecure:    config.CookieSecure, | ||||||
|  | 			DisableContinue: config.DisableContinue, | ||||||
|  | 			Title:           config.Title, | ||||||
|  | 			GenericName:     config.GenericName, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Create api config | ||||||
|  | 		apiConfig := types.APIConfig{ | ||||||
|  | 			Port:          config.Port, | ||||||
|  | 			Address:       config.Address, | ||||||
|  | 			Secret:        config.Secret, | ||||||
|  | 			CookieSecure:  config.CookieSecure, | ||||||
|  | 			SessionExpiry: config.SessionExpiry, | ||||||
|  | 			Domain:        domain, | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		// Create docker service | 		// Create docker service | ||||||
| 		docker := docker.NewDocker() | 		docker := docker.NewDocker() | ||||||
| @@ -106,18 +132,11 @@ var rootCmd = &cobra.Command{ | |||||||
| 		// Create hooks service | 		// Create hooks service | ||||||
| 		hooks := hooks.NewHooks(auth, providers) | 		hooks := hooks.NewHooks(auth, providers) | ||||||
|  |  | ||||||
|  | 		// Create handlers | ||||||
|  | 		handlers := handlers.NewHandlers(serverConfig, auth, hooks, providers) | ||||||
|  |  | ||||||
| 		// Create API | 		// Create API | ||||||
| 		api := api.NewAPI(types.APIConfig{ | 		api := api.NewAPI(apiConfig, handlers) | ||||||
| 			Port:            config.Port, |  | ||||||
| 			Address:         config.Address, |  | ||||||
| 			Secret:          config.Secret, |  | ||||||
| 			AppURL:          config.AppURL, |  | ||||||
| 			CookieSecure:    config.CookieSecure, |  | ||||||
| 			DisableContinue: config.DisableContinue, |  | ||||||
| 			SessionExpiry:   config.SessionExpiry, |  | ||||||
| 			Title:           config.Title, |  | ||||||
| 			GenericName:     config.GenericName, |  | ||||||
| 		}, hooks, auth, providers) |  | ||||||
|  |  | ||||||
| 		// Setup routes | 		// Setup routes | ||||||
| 		api.Init() | 		api.Init() | ||||||
| @@ -134,7 +153,7 @@ func Execute() { | |||||||
| } | } | ||||||
|  |  | ||||||
| func HandleError(err error, msg string) { | func HandleError(err error, msg string) { | ||||||
| 	// If error log it and exit | 	// If error, log it and exit | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatal().Err(err).Msg(msg) | 		log.Fatal().Err(err).Msg(msg) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -3,42 +3,30 @@ package api | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/fs" | 	"io/fs" | ||||||
| 	"math/rand/v2" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"os" |  | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| 	"tinyauth/internal/assets" | 	"tinyauth/internal/assets" | ||||||
| 	"tinyauth/internal/auth" | 	"tinyauth/internal/handlers" | ||||||
| 	"tinyauth/internal/hooks" |  | ||||||
| 	"tinyauth/internal/providers" |  | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
| 	"tinyauth/internal/utils" |  | ||||||
|  |  | ||||||
| 	"github.com/gin-contrib/sessions" | 	"github.com/gin-contrib/sessions" | ||||||
| 	"github.com/gin-contrib/sessions/cookie" | 	"github.com/gin-contrib/sessions/cookie" | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/google/go-querystring/query" |  | ||||||
| 	"github.com/pquerna/otp/totp" |  | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func NewAPI(config types.APIConfig, hooks *hooks.Hooks, auth *auth.Auth, providers *providers.Providers) *API { | func NewAPI(config types.APIConfig, handlers *handlers.Handlers) *API { | ||||||
| 	return &API{ | 	return &API{ | ||||||
| 		Config:    config, | 		Config:   config, | ||||||
| 		Hooks:     hooks, | 		Handlers: handlers, | ||||||
| 		Auth:      auth, |  | ||||||
| 		Providers: providers, |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| type API struct { | type API struct { | ||||||
| 	Config    types.APIConfig | 	Config   types.APIConfig | ||||||
| 	Router    *gin.Engine | 	Router   *gin.Engine | ||||||
| 	Hooks     *hooks.Hooks | 	Handlers *handlers.Handlers | ||||||
| 	Auth      *auth.Auth |  | ||||||
| 	Providers *providers.Providers |  | ||||||
| 	Domain    string |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (api *API) Init() { | func (api *API) Init() { | ||||||
| @@ -52,10 +40,10 @@ func (api *API) Init() { | |||||||
|  |  | ||||||
| 	// Read UI assets | 	// Read UI assets | ||||||
| 	log.Debug().Msg("Setting up assets") | 	log.Debug().Msg("Setting up assets") | ||||||
| 	dist, distErr := fs.Sub(assets.Assets, "dist") | 	dist, err := fs.Sub(assets.Assets, "dist") | ||||||
|  |  | ||||||
| 	if distErr != nil { | 	if err != nil { | ||||||
| 		log.Fatal().Err(distErr).Msg("Failed to get UI assets") | 		log.Fatal().Err(err).Msg("Failed to get UI assets") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Create file server | 	// Create file server | ||||||
| @@ -66,22 +54,9 @@ func (api *API) Init() { | |||||||
| 	log.Debug().Msg("Setting up cookie store") | 	log.Debug().Msg("Setting up cookie store") | ||||||
| 	store := cookie.NewStore([]byte(api.Config.Secret)) | 	store := cookie.NewStore([]byte(api.Config.Secret)) | ||||||
|  |  | ||||||
| 	// Get domain to use for session cookies |  | ||||||
| 	log.Debug().Msg("Getting domain") |  | ||||||
| 	domain, domainErr := utils.GetRootURL(api.Config.AppURL) |  | ||||||
|  |  | ||||||
| 	if domainErr != nil { |  | ||||||
| 		log.Fatal().Err(domainErr).Msg("Failed to get domain") |  | ||||||
| 		os.Exit(1) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	log.Info().Str("domain", domain).Msg("Using domain for cookies") |  | ||||||
|  |  | ||||||
| 	api.Domain = fmt.Sprintf(".%s", domain) |  | ||||||
|  |  | ||||||
| 	// Use session middleware | 	// Use session middleware | ||||||
| 	store.Options(sessions.Options{ | 	store.Options(sessions.Options{ | ||||||
| 		Domain:   api.Domain, | 		Domain:   api.Config.Domain, | ||||||
| 		Path:     "/", | 		Path:     "/", | ||||||
| 		HttpOnly: true, | 		HttpOnly: true, | ||||||
| 		Secure:   api.Config.CookieSecure, | 		Secure:   api.Config.CookieSecure, | ||||||
| @@ -94,17 +69,7 @@ func (api *API) Init() { | |||||||
| 	router.Use(func(c *gin.Context) { | 	router.Use(func(c *gin.Context) { | ||||||
| 		// If not an API request, serve the UI | 		// If not an API request, serve the UI | ||||||
| 		if !strings.HasPrefix(c.Request.URL.Path, "/api") { | 		if !strings.HasPrefix(c.Request.URL.Path, "/api") { | ||||||
| 			_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) |  | ||||||
|  |  | ||||||
| 			// If the file doesn't exist, serve the index.html |  | ||||||
| 			if os.IsNotExist(err) { |  | ||||||
| 				c.Request.URL.Path = "/" |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Serve the file |  | ||||||
| 			fileServer.ServeHTTP(c.Writer, c.Request) | 			fileServer.ServeHTTP(c.Writer, c.Request) | ||||||
|  |  | ||||||
| 			// Stop further processing |  | ||||||
| 			c.Abort() | 			c.Abort() | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
| @@ -114,608 +79,24 @@ func (api *API) Init() { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (api *API) SetupRoutes() { | func (api *API) SetupRoutes() { | ||||||
| 	api.Router.GET("/api/auth/:proxy", func(c *gin.Context) { | 	// Proxy | ||||||
| 		// Create struct for proxy | 	api.Router.GET("/api/auth/:proxy", api.Handlers.AuthHandler) | ||||||
| 		var proxy types.Proxy |  | ||||||
|  |  | ||||||
| 		// Bind URI | 	// Auth | ||||||
| 		bindErr := c.BindUri(&proxy) | 	api.Router.POST("/api/login", api.Handlers.LoginHandler) | ||||||
|  | 	api.Router.POST("/api/totp", api.Handlers.TotpHandler) | ||||||
|  | 	api.Router.POST("/api/logout", api.Handlers.LogoutHandler) | ||||||
|  |  | ||||||
| 		// Handle error | 	// Context | ||||||
| 		if bindErr != nil { | 	api.Router.GET("/api/app", api.Handlers.AppHandler) | ||||||
| 			log.Error().Err(bindErr).Msg("Failed to bind URI") | 	api.Router.GET("/api/user", api.Handlers.UserHandler) | ||||||
| 			c.JSON(400, gin.H{ |  | ||||||
| 				"status":  400, |  | ||||||
| 				"message": "Bad Request", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Check if the request is coming from a browser (tools like curl/bruno use */* and they don't include the text/html) | 	// OAuth | ||||||
| 		isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") | 	api.Router.GET("/api/oauth/url/:provider", api.Handlers.OauthUrlHandler) | ||||||
|  | 	api.Router.GET("/api/oauth/callback/:provider", api.Handlers.OauthCallbackHandler) | ||||||
|  |  | ||||||
| 		if isBrowser { | 	// App | ||||||
| 			log.Debug().Msg("Request is most likely coming from a browser") | 	api.Router.GET("/api/healthcheck", api.Handlers.HealthcheckHandler) | ||||||
| 		} else { |  | ||||||
| 			log.Debug().Msg("Request is most likely not coming from a browser") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") |  | ||||||
|  |  | ||||||
| 		// Check if auth is enabled |  | ||||||
| 		authEnabled, authEnabledErr := api.Auth.AuthEnabled(c) |  | ||||||
|  |  | ||||||
| 		// Handle error |  | ||||||
| 		if authEnabledErr != nil { |  | ||||||
| 			// Return 500 if nginx is the proxy or if the request is not coming from a browser |  | ||||||
| 			if proxy.Proxy == "nginx" || !isBrowser { |  | ||||||
| 				log.Error().Err(authEnabledErr).Msg("Failed to check if auth is enabled") |  | ||||||
| 				c.JSON(500, gin.H{ |  | ||||||
| 					"status":  500, |  | ||||||
| 					"message": "Internal Server Error", |  | ||||||
| 				}) |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Return the internal server error page |  | ||||||
| 			if api.handleError(c, "Failed to check if auth is enabled", authEnabledErr) { |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// If auth is not enabled, return 200 |  | ||||||
| 		if !authEnabled { |  | ||||||
| 			// The user is allowed to access the app |  | ||||||
| 			c.JSON(200, gin.H{ |  | ||||||
| 				"status":  200, |  | ||||||
| 				"message": "Authenticated", |  | ||||||
| 			}) |  | ||||||
|  |  | ||||||
| 			// Stop further processing |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Get user context |  | ||||||
| 		userContext := api.Hooks.UseUserContext(c) |  | ||||||
|  |  | ||||||
| 		// Get headers |  | ||||||
| 		uri := c.Request.Header.Get("X-Forwarded-Uri") |  | ||||||
| 		proto := c.Request.Header.Get("X-Forwarded-Proto") |  | ||||||
| 		host := c.Request.Header.Get("X-Forwarded-Host") |  | ||||||
|  |  | ||||||
| 		// Check if user is logged in |  | ||||||
| 		if userContext.IsLoggedIn { |  | ||||||
| 			log.Debug().Msg("Authenticated") |  | ||||||
|  |  | ||||||
| 			// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx |  | ||||||
| 			appAllowed, appAllowedErr := api.Auth.ResourceAllowed(c, userContext) |  | ||||||
|  |  | ||||||
| 			// Check if there was an error |  | ||||||
| 			if appAllowedErr != nil { |  | ||||||
| 				// Return 500 if nginx is the proxy or if the request is not coming from a browser |  | ||||||
| 				if proxy.Proxy == "nginx" || !isBrowser { |  | ||||||
| 					log.Error().Err(appAllowedErr).Msg("Failed to check if app is allowed") |  | ||||||
| 					c.JSON(500, gin.H{ |  | ||||||
| 						"status":  500, |  | ||||||
| 						"message": "Internal Server Error", |  | ||||||
| 					}) |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				// Return the internal server error page |  | ||||||
| 				if api.handleError(c, "Failed to check if app is allowed", appAllowedErr) { |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") |  | ||||||
|  |  | ||||||
| 			// The user is not allowed to access the app |  | ||||||
| 			if !appAllowed { |  | ||||||
| 				log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") |  | ||||||
|  |  | ||||||
| 				// Set WWW-Authenticate header |  | ||||||
| 				c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") |  | ||||||
|  |  | ||||||
| 				// Return 401 if nginx is the proxy or if the request is not coming from a browser |  | ||||||
| 				if proxy.Proxy == "nginx" || !isBrowser { |  | ||||||
| 					c.JSON(401, gin.H{ |  | ||||||
| 						"status":  401, |  | ||||||
| 						"message": "Unauthorized", |  | ||||||
| 					}) |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				// Build query |  | ||||||
| 				queries, queryErr := query.Values(types.UnauthorizedQuery{ |  | ||||||
| 					Username: userContext.Username, |  | ||||||
| 					Resource: strings.Split(host, ".")[0], |  | ||||||
| 				}) |  | ||||||
|  |  | ||||||
| 				// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) |  | ||||||
| 				if api.handleError(c, "Failed to build query", queryErr) { |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				// We are using caddy/traefik so redirect |  | ||||||
| 				c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, queries.Encode())) |  | ||||||
|  |  | ||||||
| 				// Stop further processing |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Set the user header |  | ||||||
| 			c.Header("Remote-User", userContext.Username) |  | ||||||
|  |  | ||||||
| 			// The user is allowed to access the app |  | ||||||
| 			c.JSON(200, gin.H{ |  | ||||||
| 				"status":  200, |  | ||||||
| 				"message": "Authenticated", |  | ||||||
| 			}) |  | ||||||
|  |  | ||||||
| 			// Stop further processing |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// The user is not logged in |  | ||||||
| 		log.Debug().Msg("Unauthorized") |  | ||||||
|  |  | ||||||
| 		// Set www-authenticate header |  | ||||||
| 		c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") |  | ||||||
|  |  | ||||||
| 		// Return 401 if nginx is the proxy or if the request is not coming from a browser |  | ||||||
| 		if proxy.Proxy == "nginx" || !isBrowser { |  | ||||||
| 			c.JSON(401, gin.H{ |  | ||||||
| 				"status":  401, |  | ||||||
| 				"message": "Unauthorized", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Build query |  | ||||||
| 		queries, queryErr := query.Values(types.LoginQuery{ |  | ||||||
| 			RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), |  | ||||||
| 		}) |  | ||||||
|  |  | ||||||
| 		// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) |  | ||||||
| 		if api.handleError(c, "Failed to build query", queryErr) { |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") |  | ||||||
|  |  | ||||||
| 		// Redirect to login |  | ||||||
| 		c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode())) |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	api.Router.POST("/api/login", func(c *gin.Context) { |  | ||||||
| 		// Create login struct |  | ||||||
| 		var login types.LoginRequest |  | ||||||
|  |  | ||||||
| 		// Bind JSON |  | ||||||
| 		err := c.BindJSON(&login) |  | ||||||
|  |  | ||||||
| 		// Handle error |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Error().Err(err).Msg("Failed to bind JSON") |  | ||||||
| 			c.JSON(400, gin.H{ |  | ||||||
| 				"status":  400, |  | ||||||
| 				"message": "Bad Request", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got login request") |  | ||||||
|  |  | ||||||
| 		// Get user based on username |  | ||||||
| 		user := api.Auth.GetUser(login.Username) |  | ||||||
|  |  | ||||||
| 		// User does not exist |  | ||||||
| 		if user == nil { |  | ||||||
| 			log.Debug().Str("username", login.Username).Msg("User not found") |  | ||||||
| 			c.JSON(401, gin.H{ |  | ||||||
| 				"status":  401, |  | ||||||
| 				"message": "Unauthorized", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got user") |  | ||||||
|  |  | ||||||
| 		// Check if password is correct |  | ||||||
| 		if !api.Auth.CheckPassword(*user, login.Password) { |  | ||||||
| 			log.Debug().Str("username", login.Username).Msg("Password incorrect") |  | ||||||
| 			c.JSON(401, gin.H{ |  | ||||||
| 				"status":  401, |  | ||||||
| 				"message": "Unauthorized", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Password correct, checking totp") |  | ||||||
|  |  | ||||||
| 		// Check if user has totp enabled |  | ||||||
| 		if user.TotpSecret != "" { |  | ||||||
| 			log.Debug().Msg("Totp enabled") |  | ||||||
|  |  | ||||||
| 			// Set totp pending cookie |  | ||||||
| 			api.Auth.CreateSessionCookie(c, &types.SessionCookie{ |  | ||||||
| 				Username:    login.Username, |  | ||||||
| 				Provider:    "username", |  | ||||||
| 				TotpPending: true, |  | ||||||
| 			}) |  | ||||||
|  |  | ||||||
| 			// Return totp required |  | ||||||
| 			c.JSON(200, gin.H{ |  | ||||||
| 				"status":      200, |  | ||||||
| 				"message":     "Waiting for totp", |  | ||||||
| 				"totpPending": true, |  | ||||||
| 			}) |  | ||||||
|  |  | ||||||
| 			// Stop further processing |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Create session cookie with username as provider |  | ||||||
| 		api.Auth.CreateSessionCookie(c, &types.SessionCookie{ |  | ||||||
| 			Username: login.Username, |  | ||||||
| 			Provider: "username", |  | ||||||
| 		}) |  | ||||||
|  |  | ||||||
| 		// Return logged in |  | ||||||
| 		c.JSON(200, gin.H{ |  | ||||||
| 			"status":      200, |  | ||||||
| 			"message":     "Logged in", |  | ||||||
| 			"totpPending": false, |  | ||||||
| 		}) |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	api.Router.POST("/api/totp", func(c *gin.Context) { |  | ||||||
| 		// Create totp struct |  | ||||||
| 		var totpReq types.TotpRequest |  | ||||||
|  |  | ||||||
| 		// Bind JSON |  | ||||||
| 		err := c.BindJSON(&totpReq) |  | ||||||
|  |  | ||||||
| 		// Handle error |  | ||||||
| 		if err != nil { |  | ||||||
| 			log.Error().Err(err).Msg("Failed to bind JSON") |  | ||||||
| 			c.JSON(400, gin.H{ |  | ||||||
| 				"status":  400, |  | ||||||
| 				"message": "Bad Request", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Checking totp") |  | ||||||
|  |  | ||||||
| 		// Get user context |  | ||||||
| 		userContext := api.Hooks.UseUserContext(c) |  | ||||||
|  |  | ||||||
| 		// Check if we have a user |  | ||||||
| 		if userContext.Username == "" { |  | ||||||
| 			log.Debug().Msg("No user context") |  | ||||||
| 			c.JSON(401, gin.H{ |  | ||||||
| 				"status":  401, |  | ||||||
| 				"message": "Unauthorized", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Get user |  | ||||||
| 		user := api.Auth.GetUser(userContext.Username) |  | ||||||
|  |  | ||||||
| 		// Check if user exists |  | ||||||
| 		if user == nil { |  | ||||||
| 			log.Debug().Msg("User not found") |  | ||||||
| 			c.JSON(401, gin.H{ |  | ||||||
| 				"status":  401, |  | ||||||
| 				"message": "Unauthorized", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Check if totp is correct |  | ||||||
| 		totpOk := totp.Validate(totpReq.Code, user.TotpSecret) |  | ||||||
|  |  | ||||||
| 		// TOTP is incorrect |  | ||||||
| 		if !totpOk { |  | ||||||
| 			log.Debug().Msg("Totp incorrect") |  | ||||||
| 			c.JSON(401, gin.H{ |  | ||||||
| 				"status":  401, |  | ||||||
| 				"message": "Unauthorized", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Totp correct") |  | ||||||
|  |  | ||||||
| 		// Create session cookie with username as provider |  | ||||||
| 		api.Auth.CreateSessionCookie(c, &types.SessionCookie{ |  | ||||||
| 			Username: user.Username, |  | ||||||
| 			Provider: "username", |  | ||||||
| 		}) |  | ||||||
|  |  | ||||||
| 		// Return logged in |  | ||||||
| 		c.JSON(200, gin.H{ |  | ||||||
| 			"status":  200, |  | ||||||
| 			"message": "Logged in", |  | ||||||
| 		}) |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	api.Router.POST("/api/logout", func(c *gin.Context) { |  | ||||||
| 		log.Debug().Msg("Logging out") |  | ||||||
|  |  | ||||||
| 		// Delete session cookie |  | ||||||
| 		api.Auth.DeleteSessionCookie(c) |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Cleaning up redirect cookie") |  | ||||||
|  |  | ||||||
| 		// Clean up redirect cookie if it exists |  | ||||||
| 		c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) |  | ||||||
|  |  | ||||||
| 		// Return logged out |  | ||||||
| 		c.JSON(200, gin.H{ |  | ||||||
| 			"status":  200, |  | ||||||
| 			"message": "Logged out", |  | ||||||
| 		}) |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	api.Router.GET("/api/app", func(c *gin.Context) { |  | ||||||
| 		log.Debug().Msg("Getting app context") |  | ||||||
|  |  | ||||||
| 		// Get configured providers |  | ||||||
| 		configuredProviders := api.Providers.GetConfiguredProviders() |  | ||||||
|  |  | ||||||
| 		// We have username/password configured so add it to our providers |  | ||||||
| 		if api.Auth.UserAuthConfigured() { |  | ||||||
| 			configuredProviders = append(configuredProviders, "username") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Create app context struct |  | ||||||
| 		appContext := types.AppContext{ |  | ||||||
| 			Status:              200, |  | ||||||
| 			Message:             "Ok", |  | ||||||
| 			ConfiguredProviders: configuredProviders, |  | ||||||
| 			DisableContinue:     api.Config.DisableContinue, |  | ||||||
| 			Title:               api.Config.Title, |  | ||||||
| 			GenericName:         api.Config.GenericName, |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Return app context |  | ||||||
| 		c.JSON(200, appContext) |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	api.Router.GET("/api/user", func(c *gin.Context) { |  | ||||||
| 		log.Debug().Msg("Getting user context") |  | ||||||
|  |  | ||||||
| 		// Get user context |  | ||||||
| 		userContext := api.Hooks.UseUserContext(c) |  | ||||||
|  |  | ||||||
| 		// Create user context response |  | ||||||
| 		userContextResponse := types.UserContextResponse{ |  | ||||||
| 			Status:      200, |  | ||||||
| 			IsLoggedIn:  userContext.IsLoggedIn, |  | ||||||
| 			Username:    userContext.Username, |  | ||||||
| 			Provider:    userContext.Provider, |  | ||||||
| 			Oauth:       userContext.OAuth, |  | ||||||
| 			TotpPending: userContext.TotpPending, |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// If we are not logged in we set the status to 401 and add the WWW-Authenticate header else we set it to 200 |  | ||||||
| 		if !userContext.IsLoggedIn { |  | ||||||
| 			log.Debug().Msg("Unauthorized") |  | ||||||
| 			c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") |  | ||||||
| 			userContextResponse.Message = "Unauthorized" |  | ||||||
| 		} else { |  | ||||||
| 			log.Debug().Interface("userContext", userContext).Msg("Authenticated") |  | ||||||
| 			userContextResponse.Message = "Authenticated" |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Return user context |  | ||||||
| 		c.JSON(200, userContextResponse) |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) { |  | ||||||
| 		// Create struct for OAuth request |  | ||||||
| 		var request types.OAuthRequest |  | ||||||
|  |  | ||||||
| 		// Bind URI |  | ||||||
| 		bindErr := c.BindUri(&request) |  | ||||||
|  |  | ||||||
| 		// Handle error |  | ||||||
| 		if bindErr != nil { |  | ||||||
| 			log.Error().Err(bindErr).Msg("Failed to bind URI") |  | ||||||
| 			c.JSON(400, gin.H{ |  | ||||||
| 				"status":  400, |  | ||||||
| 				"message": "Bad Request", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got OAuth request") |  | ||||||
|  |  | ||||||
| 		// Check if provider exists |  | ||||||
| 		provider := api.Providers.GetProvider(request.Provider) |  | ||||||
|  |  | ||||||
| 		// Provider does not exist |  | ||||||
| 		if provider == nil { |  | ||||||
| 			c.JSON(404, gin.H{ |  | ||||||
| 				"status":  404, |  | ||||||
| 				"message": "Not Found", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Str("provider", request.Provider).Msg("Got provider") |  | ||||||
|  |  | ||||||
| 		// Get auth URL |  | ||||||
| 		authURL := provider.GetAuthURL() |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got auth URL") |  | ||||||
|  |  | ||||||
| 		// Get redirect URI |  | ||||||
| 		redirectURI := c.Query("redirect_uri") |  | ||||||
|  |  | ||||||
| 		// Set redirect cookie if redirect URI is provided |  | ||||||
| 		if redirectURI != "" { |  | ||||||
| 			log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") |  | ||||||
| 			c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Tailscale does not have an auth url so we create a random code (does not need to be secure) to avoid caching and send it |  | ||||||
| 		if request.Provider == "tailscale" { |  | ||||||
| 			// Build tailscale query |  | ||||||
| 			tailscaleQuery, tailscaleQueryErr := query.Values(types.TailscaleQuery{ |  | ||||||
| 				Code: (1000 + rand.IntN(9000)), |  | ||||||
| 			}) |  | ||||||
|  |  | ||||||
| 			// Handle error |  | ||||||
| 			if tailscaleQueryErr != nil { |  | ||||||
| 				log.Error().Err(tailscaleQueryErr).Msg("Failed to build query") |  | ||||||
| 				c.JSON(500, gin.H{ |  | ||||||
| 					"status":  500, |  | ||||||
| 					"message": "Internal Server Error", |  | ||||||
| 				}) |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Return tailscale URL (immidiately redirects to the callback) |  | ||||||
| 			c.JSON(200, gin.H{ |  | ||||||
| 				"status":  200, |  | ||||||
| 				"message": "Ok", |  | ||||||
| 				"url":     fmt.Sprintf("%s/api/oauth/callback/tailscale?%s", api.Config.AppURL, tailscaleQuery.Encode()), |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Return auth URL |  | ||||||
| 		c.JSON(200, gin.H{ |  | ||||||
| 			"status":  200, |  | ||||||
| 			"message": "Ok", |  | ||||||
| 			"url":     authURL, |  | ||||||
| 		}) |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) { |  | ||||||
| 		// Create struct for OAuth request |  | ||||||
| 		var providerName types.OAuthRequest |  | ||||||
|  |  | ||||||
| 		// Bind URI |  | ||||||
| 		bindErr := c.BindUri(&providerName) |  | ||||||
|  |  | ||||||
| 		// Handle error |  | ||||||
| 		if api.handleError(c, "Failed to bind URI", bindErr) { |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") |  | ||||||
|  |  | ||||||
| 		// Get code |  | ||||||
| 		code := c.Query("code") |  | ||||||
|  |  | ||||||
| 		// Code empty so redirect to error |  | ||||||
| 		if code == "" { |  | ||||||
| 			log.Error().Msg("No code provided") |  | ||||||
| 			c.Redirect(http.StatusPermanentRedirect, "/error") |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got code") |  | ||||||
|  |  | ||||||
| 		// Get provider |  | ||||||
| 		provider := api.Providers.GetProvider(providerName.Provider) |  | ||||||
|  |  | ||||||
| 		log.Debug().Str("provider", providerName.Provider).Msg("Got provider") |  | ||||||
|  |  | ||||||
| 		// Provider does not exist |  | ||||||
| 		if provider == nil { |  | ||||||
| 			c.Redirect(http.StatusPermanentRedirect, "/not-found") |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Exchange token (authenticates user) |  | ||||||
| 		_, tokenErr := provider.ExchangeToken(code) |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got token") |  | ||||||
|  |  | ||||||
| 		// Handle error |  | ||||||
| 		if api.handleError(c, "Failed to exchange token", tokenErr) { |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Get email |  | ||||||
| 		email, emailErr := api.Providers.GetUser(providerName.Provider) |  | ||||||
|  |  | ||||||
| 		log.Debug().Str("email", email).Msg("Got email") |  | ||||||
|  |  | ||||||
| 		// Handle error |  | ||||||
| 		if api.handleError(c, "Failed to get user", emailErr) { |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Email is not whitelisted |  | ||||||
| 		if !api.Auth.EmailWhitelisted(email) { |  | ||||||
| 			log.Warn().Str("email", email).Msg("Email not whitelisted") |  | ||||||
|  |  | ||||||
| 			// Build query |  | ||||||
| 			unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{ |  | ||||||
| 				Username: email, |  | ||||||
| 			}) |  | ||||||
|  |  | ||||||
| 			// Handle error |  | ||||||
| 			if api.handleError(c, "Failed to build query", unauthorizedQueryErr) { |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			// Redirect to unauthorized |  | ||||||
| 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode())) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Email whitelisted") |  | ||||||
|  |  | ||||||
| 		// Create session cookie |  | ||||||
| 		api.Auth.CreateSessionCookie(c, &types.SessionCookie{ |  | ||||||
| 			Username: email, |  | ||||||
| 			Provider: providerName.Provider, |  | ||||||
| 		}) |  | ||||||
|  |  | ||||||
| 		// Get redirect URI |  | ||||||
| 		redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") |  | ||||||
|  |  | ||||||
| 		// If it is empty it means that no redirect_uri was provided to the login screen so we just log in |  | ||||||
| 		if redirectURIErr != nil { |  | ||||||
| 			c.Redirect(http.StatusPermanentRedirect, api.Config.AppURL) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI") |  | ||||||
|  |  | ||||||
| 		// Clean up redirect cookie since we already have the value |  | ||||||
| 		c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) |  | ||||||
|  |  | ||||||
| 		// Build query |  | ||||||
| 		redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{ |  | ||||||
| 			RedirectURI: redirectURI, |  | ||||||
| 		}) |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got redirect query") |  | ||||||
|  |  | ||||||
| 		// Handle error |  | ||||||
| 		if api.handleError(c, "Failed to build query", redirectQueryErr) { |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Redirect to continue with the redirect URI |  | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode())) |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	// Simple healthcheck |  | ||||||
| 	api.Router.GET("/api/healthcheck", func(c *gin.Context) { |  | ||||||
| 		c.JSON(200, gin.H{ |  | ||||||
| 			"status":  200, |  | ||||||
| 			"message": "OK", |  | ||||||
| 		}) |  | ||||||
| 	}) |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (api *API) Run() { | func (api *API) Run() { | ||||||
| @@ -724,23 +105,12 @@ func (api *API) Run() { | |||||||
| 	// Run server | 	// Run server | ||||||
| 	err := api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port)) | 	err := api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port)) | ||||||
|  |  | ||||||
| 	// Check error | 	// Check for errors | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatal().Err(err).Msg("Failed to start server") | 		log.Fatal().Err(err).Msg("Failed to start server") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // handleError logs the error and redirects to the error page (only meant for stuff the user may access does not apply for login paths) |  | ||||||
| func (api *API) handleError(c *gin.Context, msg string, err error) bool { |  | ||||||
| 	// If error is not nil log it and redirect to error page also return true so we can stop further processing |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error().Err(err).Msg(msg) |  | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", api.Config.AppURL)) |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // zerolog is a middleware for gin that logs requests using zerolog | // zerolog is a middleware for gin that logs requests using zerolog | ||||||
| func zerolog() gin.HandlerFunc { | func zerolog() gin.HandlerFunc { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
|   | |||||||
| @@ -5,11 +5,13 @@ import ( | |||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
|  | 	"reflect" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"tinyauth/internal/api" | 	"tinyauth/internal/api" | ||||||
| 	"tinyauth/internal/auth" | 	"tinyauth/internal/auth" | ||||||
| 	"tinyauth/internal/docker" | 	"tinyauth/internal/docker" | ||||||
|  | 	"tinyauth/internal/handlers" | ||||||
| 	"tinyauth/internal/hooks" | 	"tinyauth/internal/hooks" | ||||||
| 	"tinyauth/internal/providers" | 	"tinyauth/internal/providers" | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
| @@ -19,13 +21,21 @@ import ( | |||||||
|  |  | ||||||
| // Simple API config for tests | // Simple API config for tests | ||||||
| var apiConfig = types.APIConfig{ | var apiConfig = types.APIConfig{ | ||||||
| 	Port:            8080, | 	Port:          8080, | ||||||
| 	Address:         "0.0.0.0", | 	Address:       "0.0.0.0", | ||||||
| 	Secret:          "super-secret-api-thing-for-tests", // It is 32 chars long | 	Secret:        "super-secret-api-thing-for-tests", // It is 32 chars long | ||||||
| 	AppURL:          "http://tinyauth.localhost", | 	CookieSecure:  false, | ||||||
|  | 	SessionExpiry: 3600, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Simple handlers config for tests | ||||||
|  | var handlersConfig = types.HandlersConfig{ | ||||||
|  | 	AppURL:          "http://localhost:8080", | ||||||
|  | 	Domain:          ".localhost", | ||||||
| 	CookieSecure:    false, | 	CookieSecure:    false, | ||||||
| 	SessionExpiry:   3600, |  | ||||||
| 	DisableContinue: false, | 	DisableContinue: false, | ||||||
|  | 	Title:           "Tinyauth", | ||||||
|  | 	GenericName:     "Generic", | ||||||
| } | } | ||||||
|  |  | ||||||
| // Cookie | // Cookie | ||||||
| @@ -67,8 +77,11 @@ func getAPI(t *testing.T) *api.API { | |||||||
| 	// Create hooks service | 	// Create hooks service | ||||||
| 	hooks := hooks.NewHooks(auth, providers) | 	hooks := hooks.NewHooks(auth, providers) | ||||||
|  |  | ||||||
|  | 	// Create handlers service | ||||||
|  | 	handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers) | ||||||
|  |  | ||||||
| 	// Create API | 	// Create API | ||||||
| 	api := api.NewAPI(apiConfig, hooks, auth, providers) | 	api := api.NewAPI(apiConfig, handlers) | ||||||
|  |  | ||||||
| 	// Setup routes | 	// Setup routes | ||||||
| 	api.Init() | 	api.Init() | ||||||
| @@ -123,6 +136,70 @@ func TestLogin(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Test app context | ||||||
|  | func TestAppContext(t *testing.T) { | ||||||
|  | 	t.Log("Testing app context") | ||||||
|  |  | ||||||
|  | 	// Get API | ||||||
|  | 	api := getAPI(t) | ||||||
|  |  | ||||||
|  | 	// Create recorder | ||||||
|  | 	recorder := httptest.NewRecorder() | ||||||
|  |  | ||||||
|  | 	// Create request | ||||||
|  | 	req, err := http.NewRequest("GET", "/api/app", nil) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("Error creating request: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Set the cookie | ||||||
|  | 	req.AddCookie(&http.Cookie{ | ||||||
|  | 		Name:  "tinyauth", | ||||||
|  | 		Value: cookie, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	// Serve the request | ||||||
|  | 	api.Router.ServeHTTP(recorder, req) | ||||||
|  |  | ||||||
|  | 	// Assert | ||||||
|  | 	assert.Equal(t, recorder.Code, http.StatusOK) | ||||||
|  |  | ||||||
|  | 	// Read the body of the response | ||||||
|  | 	body, bodyErr := io.ReadAll(recorder.Body) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
|  | 	if bodyErr != nil { | ||||||
|  | 		t.Fatalf("Error getting body: %v", bodyErr) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Unmarshal the body into the user struct | ||||||
|  | 	var app types.AppContext | ||||||
|  |  | ||||||
|  | 	jsonErr := json.Unmarshal(body, &app) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
|  | 	if jsonErr != nil { | ||||||
|  | 		t.Fatalf("Error unmarshalling body: %v", jsonErr) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create tests values | ||||||
|  | 	expected := types.AppContext{ | ||||||
|  | 		Status:              200, | ||||||
|  | 		Message:             "OK", | ||||||
|  | 		ConfiguredProviders: []string{"username"}, | ||||||
|  | 		DisableContinue:     false, | ||||||
|  | 		Title:               "Tinyauth", | ||||||
|  | 		GenericName:         "Generic", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// We should get the username back | ||||||
|  | 	if !reflect.DeepEqual(app, expected) { | ||||||
|  | 		t.Fatalf("Expected %v, got %v", expected, app) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| // Test user context | // Test user context | ||||||
| func TestUserContext(t *testing.T) { | func TestUserContext(t *testing.T) { | ||||||
| 	t.Log("Testing user context") | 	t.Log("Testing user context") | ||||||
|   | |||||||
							
								
								
									
										634
									
								
								internal/handlers/handlers.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										634
									
								
								internal/handlers/handlers.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,634 @@ | |||||||
|  | package handlers | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"math/rand/v2" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  | 	"tinyauth/internal/auth" | ||||||
|  | 	"tinyauth/internal/hooks" | ||||||
|  | 	"tinyauth/internal/providers" | ||||||
|  | 	"tinyauth/internal/types" | ||||||
|  |  | ||||||
|  | 	"github.com/gin-gonic/gin" | ||||||
|  | 	"github.com/google/go-querystring/query" | ||||||
|  | 	"github.com/pquerna/otp/totp" | ||||||
|  | 	"github.com/rs/zerolog/log" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func NewHandlers(config types.HandlersConfig, auth *auth.Auth, hooks *hooks.Hooks, providers *providers.Providers) *Handlers { | ||||||
|  | 	return &Handlers{ | ||||||
|  | 		Config:    config, | ||||||
|  | 		Auth:      auth, | ||||||
|  | 		Hooks:     hooks, | ||||||
|  | 		Providers: providers, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Handlers struct { | ||||||
|  | 	Config    types.HandlersConfig | ||||||
|  | 	Auth      *auth.Auth | ||||||
|  | 	Hooks     *hooks.Hooks | ||||||
|  | 	Providers *providers.Providers | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *Handlers) AuthHandler(c *gin.Context) { | ||||||
|  | 	// Create struct for proxy | ||||||
|  | 	var proxy types.Proxy | ||||||
|  |  | ||||||
|  | 	// Bind URI | ||||||
|  | 	err := c.BindUri(&proxy) | ||||||
|  |  | ||||||
|  | 	// Handle error | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error().Err(err).Msg("Failed to bind URI") | ||||||
|  | 		c.JSON(400, gin.H{ | ||||||
|  | 			"status":  400, | ||||||
|  | 			"message": "Bad Request", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check if the request is coming from a browser (tools like curl/bruno use */* and they don't include the text/html) | ||||||
|  | 	isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html") | ||||||
|  |  | ||||||
|  | 	if isBrowser { | ||||||
|  | 		log.Debug().Msg("Request is most likely coming from a browser") | ||||||
|  | 	} else { | ||||||
|  | 		log.Debug().Msg("Request is most likely not coming from a browser") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") | ||||||
|  |  | ||||||
|  | 	// Check if auth is enabled | ||||||
|  | 	authEnabled, err := h.Auth.AuthEnabled(c) | ||||||
|  |  | ||||||
|  | 	// Handle error | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error().Err(err).Msg("Failed to check if auth is enabled") | ||||||
|  |  | ||||||
|  | 		if proxy.Proxy == "nginx" || !isBrowser { | ||||||
|  | 			c.JSON(500, gin.H{ | ||||||
|  | 				"status":  500, | ||||||
|  | 				"message": "Internal Server Error", | ||||||
|  | 			}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If auth is not enabled, return 200 | ||||||
|  | 	if !authEnabled { | ||||||
|  | 		c.JSON(200, gin.H{ | ||||||
|  | 			"status":  200, | ||||||
|  | 			"message": "Authenticated", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get user context | ||||||
|  | 	userContext := h.Hooks.UseUserContext(c) | ||||||
|  |  | ||||||
|  | 	// Get headers | ||||||
|  | 	uri := c.Request.Header.Get("X-Forwarded-Uri") | ||||||
|  | 	proto := c.Request.Header.Get("X-Forwarded-Proto") | ||||||
|  | 	host := c.Request.Header.Get("X-Forwarded-Host") | ||||||
|  |  | ||||||
|  | 	// Check if user is logged in | ||||||
|  | 	if userContext.IsLoggedIn { | ||||||
|  | 		log.Debug().Msg("Authenticated") | ||||||
|  |  | ||||||
|  | 		// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx | ||||||
|  | 		appAllowed, err := h.Auth.ResourceAllowed(c, userContext) | ||||||
|  |  | ||||||
|  | 		// Check if there was an error | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Error().Err(err).Msg("Failed to check if app is allowed") | ||||||
|  |  | ||||||
|  | 			if proxy.Proxy == "nginx" || !isBrowser { | ||||||
|  | 				c.JSON(500, gin.H{ | ||||||
|  | 					"status":  500, | ||||||
|  | 					"message": "Internal Server Error", | ||||||
|  | 				}) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") | ||||||
|  |  | ||||||
|  | 		// The user is not allowed to access the app | ||||||
|  | 		if !appAllowed { | ||||||
|  | 			log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") | ||||||
|  |  | ||||||
|  | 			// Set WWW-Authenticate header | ||||||
|  | 			c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") | ||||||
|  |  | ||||||
|  | 			if proxy.Proxy == "nginx" || !isBrowser { | ||||||
|  | 				c.JSON(401, gin.H{ | ||||||
|  | 					"status":  401, | ||||||
|  | 					"message": "Unauthorized", | ||||||
|  | 				}) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// Build query | ||||||
|  | 			queries, err := query.Values(types.UnauthorizedQuery{ | ||||||
|  | 				Username: userContext.Username, | ||||||
|  | 				Resource: strings.Split(host, ".")[0], | ||||||
|  | 			}) | ||||||
|  |  | ||||||
|  | 			// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) | ||||||
|  | 			if err != nil { | ||||||
|  | 				log.Error().Err(err).Msg("Failed to build query") | ||||||
|  | 				c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// We are using caddy/traefik so redirect | ||||||
|  | 			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Set the user header | ||||||
|  | 		c.Header("Remote-User", userContext.Username) | ||||||
|  |  | ||||||
|  | 		// The user is allowed to access the app | ||||||
|  | 		c.JSON(200, gin.H{ | ||||||
|  | 			"status":  200, | ||||||
|  | 			"message": "Authenticated", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// The user is not logged in | ||||||
|  | 	log.Debug().Msg("Unauthorized") | ||||||
|  |  | ||||||
|  | 	// Set www-authenticate header | ||||||
|  | 	c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") | ||||||
|  |  | ||||||
|  | 	if proxy.Proxy == "nginx" || !isBrowser { | ||||||
|  | 		c.JSON(401, gin.H{ | ||||||
|  | 			"status":  401, | ||||||
|  | 			"message": "Unauthorized", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	queries, err := query.Values(types.LoginQuery{ | ||||||
|  | 		RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error().Err(err).Msg("Failed to build query") | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") | ||||||
|  |  | ||||||
|  | 	// Redirect to login | ||||||
|  | 	c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", h.Config.AppURL, queries.Encode())) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *Handlers) LoginHandler(c *gin.Context) { | ||||||
|  | 	// Create login struct | ||||||
|  | 	var login types.LoginRequest | ||||||
|  |  | ||||||
|  | 	// Bind JSON | ||||||
|  | 	err := c.BindJSON(&login) | ||||||
|  |  | ||||||
|  | 	// Handle error | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error().Err(err).Msg("Failed to bind JSON") | ||||||
|  | 		c.JSON(400, gin.H{ | ||||||
|  | 			"status":  400, | ||||||
|  | 			"message": "Bad Request", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got login request") | ||||||
|  |  | ||||||
|  | 	// Get user based on username | ||||||
|  | 	user := h.Auth.GetUser(login.Username) | ||||||
|  |  | ||||||
|  | 	// User does not exist | ||||||
|  | 	if user == nil { | ||||||
|  | 		log.Debug().Str("username", login.Username).Msg("User not found") | ||||||
|  | 		c.JSON(401, gin.H{ | ||||||
|  | 			"status":  401, | ||||||
|  | 			"message": "Unauthorized", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got user") | ||||||
|  |  | ||||||
|  | 	// Check if password is correct | ||||||
|  | 	if !h.Auth.CheckPassword(*user, login.Password) { | ||||||
|  | 		log.Debug().Str("username", login.Username).Msg("Password incorrect") | ||||||
|  | 		c.JSON(401, gin.H{ | ||||||
|  | 			"status":  401, | ||||||
|  | 			"message": "Unauthorized", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Password correct, checking totp") | ||||||
|  |  | ||||||
|  | 	// Check if user has totp enabled | ||||||
|  | 	if user.TotpSecret != "" { | ||||||
|  | 		log.Debug().Msg("Totp enabled") | ||||||
|  |  | ||||||
|  | 		// Set totp pending cookie | ||||||
|  | 		h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
|  | 			Username:    login.Username, | ||||||
|  | 			Provider:    "username", | ||||||
|  | 			TotpPending: true, | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		// Return totp required | ||||||
|  | 		c.JSON(200, gin.H{ | ||||||
|  | 			"status":      200, | ||||||
|  | 			"message":     "Waiting for totp", | ||||||
|  | 			"totpPending": true, | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		// Stop further processing | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create session cookie with username as provider | ||||||
|  | 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
|  | 		Username: login.Username, | ||||||
|  | 		Provider: "username", | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	// Return logged in | ||||||
|  | 	c.JSON(200, gin.H{ | ||||||
|  | 		"status":      200, | ||||||
|  | 		"message":     "Logged in", | ||||||
|  | 		"totpPending": false, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *Handlers) TotpHandler(c *gin.Context) { | ||||||
|  | 	// Create totp struct | ||||||
|  | 	var totpReq types.TotpRequest | ||||||
|  |  | ||||||
|  | 	// Bind JSON | ||||||
|  | 	err := c.BindJSON(&totpReq) | ||||||
|  |  | ||||||
|  | 	// Handle error | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error().Err(err).Msg("Failed to bind JSON") | ||||||
|  | 		c.JSON(400, gin.H{ | ||||||
|  | 			"status":  400, | ||||||
|  | 			"message": "Bad Request", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Checking totp") | ||||||
|  |  | ||||||
|  | 	// Get user context | ||||||
|  | 	userContext := h.Hooks.UseUserContext(c) | ||||||
|  |  | ||||||
|  | 	// Check if we have a user | ||||||
|  | 	if userContext.Username == "" { | ||||||
|  | 		log.Debug().Msg("No user context") | ||||||
|  | 		c.JSON(401, gin.H{ | ||||||
|  | 			"status":  401, | ||||||
|  | 			"message": "Unauthorized", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get user | ||||||
|  | 	user := h.Auth.GetUser(userContext.Username) | ||||||
|  |  | ||||||
|  | 	// Check if user exists | ||||||
|  | 	if user == nil { | ||||||
|  | 		log.Debug().Msg("User not found") | ||||||
|  | 		c.JSON(401, gin.H{ | ||||||
|  | 			"status":  401, | ||||||
|  | 			"message": "Unauthorized", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check if totp is correct | ||||||
|  | 	totpOk := totp.Validate(totpReq.Code, user.TotpSecret) | ||||||
|  |  | ||||||
|  | 	// TOTP is incorrect | ||||||
|  | 	if !totpOk { | ||||||
|  | 		log.Debug().Msg("Totp incorrect") | ||||||
|  | 		c.JSON(401, gin.H{ | ||||||
|  | 			"status":  401, | ||||||
|  | 			"message": "Unauthorized", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Totp correct") | ||||||
|  |  | ||||||
|  | 	// Create session cookie with username as provider | ||||||
|  | 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
|  | 		Username: user.Username, | ||||||
|  | 		Provider: "username", | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	// Return logged in | ||||||
|  | 	c.JSON(200, gin.H{ | ||||||
|  | 		"status":  200, | ||||||
|  | 		"message": "Logged in", | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *Handlers) LogoutHandler(c *gin.Context) { | ||||||
|  | 	log.Debug().Msg("Logging out") | ||||||
|  |  | ||||||
|  | 	// Delete session cookie | ||||||
|  | 	h.Auth.DeleteSessionCookie(c) | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Cleaning up redirect cookie") | ||||||
|  |  | ||||||
|  | 	// Clean up redirect cookie if it exists | ||||||
|  | 	c.SetCookie("tinyauth_redirect_uri", "", -1, "/", h.Config.Domain, h.Config.CookieSecure, true) | ||||||
|  |  | ||||||
|  | 	// Return logged out | ||||||
|  | 	c.JSON(200, gin.H{ | ||||||
|  | 		"status":  200, | ||||||
|  | 		"message": "Logged out", | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *Handlers) AppHandler(c *gin.Context) { | ||||||
|  | 	log.Debug().Msg("Getting app context") | ||||||
|  |  | ||||||
|  | 	// Get configured providers | ||||||
|  | 	configuredProviders := h.Providers.GetConfiguredProviders() | ||||||
|  |  | ||||||
|  | 	// We have username/password configured so add it to our providers | ||||||
|  | 	if h.Auth.UserAuthConfigured() { | ||||||
|  | 		configuredProviders = append(configuredProviders, "username") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create app context struct | ||||||
|  | 	appContext := types.AppContext{ | ||||||
|  | 		Status:              200, | ||||||
|  | 		Message:             "OK", | ||||||
|  | 		ConfiguredProviders: configuredProviders, | ||||||
|  | 		DisableContinue:     h.Config.DisableContinue, | ||||||
|  | 		Title:               h.Config.Title, | ||||||
|  | 		GenericName:         h.Config.GenericName, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Return app context | ||||||
|  | 	c.JSON(200, appContext) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *Handlers) UserHandler(c *gin.Context) { | ||||||
|  | 	log.Debug().Msg("Getting user context") | ||||||
|  |  | ||||||
|  | 	// Get user context | ||||||
|  | 	userContext := h.Hooks.UseUserContext(c) | ||||||
|  |  | ||||||
|  | 	// Create user context response | ||||||
|  | 	userContextResponse := types.UserContextResponse{ | ||||||
|  | 		Status:      200, | ||||||
|  | 		IsLoggedIn:  userContext.IsLoggedIn, | ||||||
|  | 		Username:    userContext.Username, | ||||||
|  | 		Provider:    userContext.Provider, | ||||||
|  | 		Oauth:       userContext.OAuth, | ||||||
|  | 		TotpPending: userContext.TotpPending, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If we are not logged in we set the status to 401 and add the WWW-Authenticate header else we set it to 200 | ||||||
|  | 	if !userContext.IsLoggedIn { | ||||||
|  | 		log.Debug().Msg("Unauthorized") | ||||||
|  | 		c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"") | ||||||
|  | 		userContextResponse.Message = "Unauthorized" | ||||||
|  | 	} else { | ||||||
|  | 		log.Debug().Interface("userContext", userContext).Msg("Authenticated") | ||||||
|  | 		userContextResponse.Message = "Authenticated" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Return user context | ||||||
|  | 	c.JSON(200, userContextResponse) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *Handlers) OauthUrlHandler(c *gin.Context) { | ||||||
|  | 	// Create struct for OAuth request | ||||||
|  | 	var request types.OAuthRequest | ||||||
|  |  | ||||||
|  | 	// Bind URI | ||||||
|  | 	err := c.BindUri(&request) | ||||||
|  |  | ||||||
|  | 	// Handle error | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error().Err(err).Msg("Failed to bind URI") | ||||||
|  | 		c.JSON(400, gin.H{ | ||||||
|  | 			"status":  400, | ||||||
|  | 			"message": "Bad Request", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got OAuth request") | ||||||
|  |  | ||||||
|  | 	// Check if provider exists | ||||||
|  | 	provider := h.Providers.GetProvider(request.Provider) | ||||||
|  |  | ||||||
|  | 	// Provider does not exist | ||||||
|  | 	if provider == nil { | ||||||
|  | 		c.JSON(404, gin.H{ | ||||||
|  | 			"status":  404, | ||||||
|  | 			"message": "Not Found", | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Str("provider", request.Provider).Msg("Got provider") | ||||||
|  |  | ||||||
|  | 	// Get auth URL | ||||||
|  | 	authURL := provider.GetAuthURL() | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got auth URL") | ||||||
|  |  | ||||||
|  | 	// Get redirect URI | ||||||
|  | 	redirectURI := c.Query("redirect_uri") | ||||||
|  |  | ||||||
|  | 	// Set redirect cookie if redirect URI is provided | ||||||
|  | 	if redirectURI != "" { | ||||||
|  | 		log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") | ||||||
|  | 		c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", h.Config.Domain, h.Config.CookieSecure, true) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Tailscale does not have an auth url so we create a random code (does not need to be secure) to avoid caching and send it | ||||||
|  | 	if request.Provider == "tailscale" { | ||||||
|  | 		// Build tailscale query | ||||||
|  | 		tailscaleQuery, err := query.Values(types.TailscaleQuery{ | ||||||
|  | 			Code: (1000 + rand.IntN(9000)), | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		// Handle error | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Error().Err(err).Msg("Failed to build query") | ||||||
|  | 			c.JSON(500, gin.H{ | ||||||
|  | 				"status":  500, | ||||||
|  | 				"message": "Internal Server Error", | ||||||
|  | 			}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Return tailscale URL (immidiately redirects to the callback) | ||||||
|  | 		c.JSON(200, gin.H{ | ||||||
|  | 			"status":  200, | ||||||
|  | 			"message": "OK", | ||||||
|  | 			"url":     fmt.Sprintf("%s/api/oauth/callback/tailscale?%s", h.Config.AppURL, tailscaleQuery.Encode()), | ||||||
|  | 		}) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Return auth URL | ||||||
|  | 	c.JSON(200, gin.H{ | ||||||
|  | 		"status":  200, | ||||||
|  | 		"message": "OK", | ||||||
|  | 		"url":     authURL, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *Handlers) OauthCallbackHandler(c *gin.Context) { | ||||||
|  | 	// Create struct for OAuth request | ||||||
|  | 	var providerName types.OAuthRequest | ||||||
|  |  | ||||||
|  | 	// Bind URI | ||||||
|  | 	err := c.BindUri(&providerName) | ||||||
|  |  | ||||||
|  | 	// Handle error | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error().Err(err).Msg("Failed to bind URI") | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") | ||||||
|  |  | ||||||
|  | 	// Get code | ||||||
|  | 	code := c.Query("code") | ||||||
|  |  | ||||||
|  | 	// Code empty so redirect to error | ||||||
|  | 	if code == "" { | ||||||
|  | 		log.Error().Msg("No code provided") | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got code") | ||||||
|  |  | ||||||
|  | 	// Get provider | ||||||
|  | 	provider := h.Providers.GetProvider(providerName.Provider) | ||||||
|  |  | ||||||
|  | 	log.Debug().Str("provider", providerName.Provider).Msg("Got provider") | ||||||
|  |  | ||||||
|  | 	// Provider does not exist | ||||||
|  | 	if provider == nil { | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, "/not-found") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Exchange token (authenticates user) | ||||||
|  | 	_, err = provider.ExchangeToken(code) | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got token") | ||||||
|  |  | ||||||
|  | 	// Handle error | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error().Msg("Failed to exchange token") | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get email | ||||||
|  | 	email, err := h.Providers.GetUser(providerName.Provider) | ||||||
|  |  | ||||||
|  | 	log.Debug().Str("email", email).Msg("Got email") | ||||||
|  |  | ||||||
|  | 	// Handle error | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error().Msg("Failed to get email") | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Email is not whitelisted | ||||||
|  | 	if !h.Auth.EmailWhitelisted(email) { | ||||||
|  | 		log.Warn().Str("email", email).Msg("Email not whitelisted") | ||||||
|  |  | ||||||
|  | 		// Build query | ||||||
|  | 		unauthorizedQuery, err := query.Values(types.UnauthorizedQuery{ | ||||||
|  | 			Username: email, | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		// Handle error | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Error().Msg("Failed to build query") | ||||||
|  | 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Redirect to unauthorized | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, unauthorizedQuery.Encode())) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Email whitelisted") | ||||||
|  |  | ||||||
|  | 	// Create session cookie | ||||||
|  | 	h.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
|  | 		Username: email, | ||||||
|  | 		Provider: providerName.Provider, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	// Get redirect URI | ||||||
|  | 	redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") | ||||||
|  |  | ||||||
|  | 	// If it is empty it means that no redirect_uri was provided to the login screen so we just log in | ||||||
|  | 	if redirectURIErr != nil { | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, h.Config.AppURL) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI") | ||||||
|  |  | ||||||
|  | 	// Clean up redirect cookie since we already have the value | ||||||
|  | 	c.SetCookie("tinyauth_redirect_uri", "", -1, "/", h.Config.Domain, h.Config.CookieSecure, true) | ||||||
|  |  | ||||||
|  | 	// Build query | ||||||
|  | 	redirectQuery, err := query.Values(types.LoginQuery{ | ||||||
|  | 		RedirectURI: redirectURI, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Got redirect query") | ||||||
|  |  | ||||||
|  | 	// Handle error | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error().Msg("Failed to build query") | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Redirect to continue with the redirect URI | ||||||
|  | 	c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", h.Config.AppURL, redirectQuery.Encode())) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *Handlers) HealthcheckHandler(c *gin.Context) { | ||||||
|  | 	c.JSON(200, gin.H{ | ||||||
|  | 		"status":  200, | ||||||
|  | 		"message": "OK", | ||||||
|  | 	}) | ||||||
|  | } | ||||||
| @@ -69,15 +69,12 @@ type UserContext struct { | |||||||
|  |  | ||||||
| // APIConfig is the configuration for the API | // APIConfig is the configuration for the API | ||||||
| type APIConfig struct { | type APIConfig struct { | ||||||
| 	Port            int | 	Port          int | ||||||
| 	Address         string | 	Address       string | ||||||
| 	Secret          string | 	Secret        string | ||||||
| 	AppURL          string | 	CookieSecure  bool | ||||||
| 	CookieSecure    bool | 	SessionExpiry int | ||||||
| 	SessionExpiry   int | 	Domain        string | ||||||
| 	DisableContinue bool |  | ||||||
| 	GenericName     string |  | ||||||
| 	Title           string |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // OAuthConfig is the configuration for the providers | // OAuthConfig is the configuration for the providers | ||||||
| @@ -164,3 +161,13 @@ type AppContext struct { | |||||||
| type TotpRequest struct { | type TotpRequest struct { | ||||||
| 	Code string `json:"code"` | 	Code string `json:"code"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Server configuration | ||||||
|  | type HandlersConfig struct { | ||||||
|  | 	AppURL          string | ||||||
|  | 	Domain          string | ||||||
|  | 	CookieSecure    bool | ||||||
|  | 	DisableContinue bool | ||||||
|  | 	GenericName     string | ||||||
|  | 	Title           string | ||||||
|  | } | ||||||
|   | |||||||
| @@ -46,14 +46,14 @@ func ParseUsers(users string) (types.Users, error) { | |||||||
| 	return usersParsed, nil | 	return usersParsed, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Root url parses parses a hostname and returns the root domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) | // Get upper domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) | ||||||
| func GetRootURL(urlSrc string) (string, error) { | func GetUpperDomain(urlSrc string) (string, error) { | ||||||
| 	// Make sure the url is valid | 	// Make sure the url is valid | ||||||
| 	urlParsed, parseErr := url.Parse(urlSrc) | 	urlParsed, err := url.Parse(urlSrc) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if parseErr != nil { | 	if err != nil { | ||||||
| 		return "", parseErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Split the hostname by period | 	// Split the hostname by period | ||||||
|   | |||||||
| @@ -38,15 +38,15 @@ func TestParseUsers(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // Test the get root url function | // Test the get upper domain function | ||||||
| func TestGetRootURL(t *testing.T) { | func TestGetUpperDomain(t *testing.T) { | ||||||
| 	t.Log("Testing get root url with a valid url") | 	t.Log("Testing get upper domain with a valid url") | ||||||
|  |  | ||||||
| 	// Test the get root url function with a valid url | 	// Test the get upper domain function with a valid url | ||||||
| 	url := "https://sub1.sub2.domain.com:8080" | 	url := "https://sub1.sub2.domain.com:8080" | ||||||
| 	expected := "sub2.domain.com" | 	expected := "sub2.domain.com" | ||||||
|  |  | ||||||
| 	result, err := utils.GetRootURL(url) | 	result, err := utils.GetUpperDomain(url) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user