mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-10-30 21:55:43 +00:00 
			
		
		
		
	Compare commits
	
		
			12 Commits
		
	
	
		
			v2.1.0
			...
			chore/comm
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 1b145fd531 | ||
|   | 7a3a463489 | ||
|   | e09f241364 | ||
|   | d2ee382f92 | ||
|   | 4e8a2443a6 | ||
|   | 22777a16a1 | ||
|   | 0872556c1a | ||
|   | daad2abc33 | ||
|   | ce567ae3de | ||
|   | 87393d3c64 | ||
|   | 97830a309b | ||
|   | fe594d2755 | 
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -5,7 +5,7 @@ internal/assets/dist | |||||||
| tinyauth | tinyauth | ||||||
|  |  | ||||||
| # test docker compose | # test docker compose | ||||||
| docker-compose.test.yml | docker-compose.test* | ||||||
|  |  | ||||||
| # users file | # users file | ||||||
| users.txt | users.txt | ||||||
|   | |||||||
| @@ -22,9 +22,13 @@ Tinyauth is a simple authentication middleware that adds simple username/passwor | |||||||
| > [!NOTE] | > [!NOTE] | ||||||
| > Tinyauth is intended for homelab use and it is not made for production use cases. If you are looking for something production ready please use [authentik](https://goauthentik.io). | > Tinyauth is intended for homelab use and it is not made for production use cases. If you are looking for something production ready please use [authentik](https://goauthentik.io). | ||||||
|  |  | ||||||
|  | ## Discord | ||||||
|  |  | ||||||
|  | I just made a Discord server for Tinyauth! It is not only for Tinyauth but general self-hosting because I just like chatting with people! The link is [here](https://discord.gg/gWpzrksk), see you there! | ||||||
|  |  | ||||||
| ## Getting Started | ## Getting Started | ||||||
|  |  | ||||||
| You can easily get started with tinyauth by following the guide on the documentation [here](https://tinyauth.doesmycode.work/docs/getting-started.html). There is also an available docker compose file [here](./docker-compose.example.yml) that has traefik, nginx and tinyauth to demonstrate its capabilities. | You can easily get started with tinyauth by following the guide on the [documentation](https://tinyauth.doesmycode.work/docs/getting-started.html). There is also an available [docker compose file](./docker-compose.example.yml) that has traefik, nginx and tinyauth to demonstrate its capabilities. | ||||||
|  |  | ||||||
| ## Documentation | ## Documentation | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										22
									
								
								assets/discohook.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								assets/discohook.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | |||||||
|  | { | ||||||
|  |   "content": null, | ||||||
|  |   "embeds": [ | ||||||
|  |     { | ||||||
|  |       "title": "Welcome to Tinyauth Discord!", | ||||||
|  |       "description": "Tinyauth is a simple authentication middleware that adds simple username/password login or OAuth with Google, Github and any generic OAuth provider to all of your docker apps.\n\n**Information**\n\n• Github: <https://github.com/steveiliop56/tinyauth>\n• Website: <https://tinyauth.doesmycode.work>", | ||||||
|  |       "url": "https://tinyauth.doesmycode.work", | ||||||
|  |       "color": 7002085, | ||||||
|  |       "author": { | ||||||
|  |         "name": "Tinyauth" | ||||||
|  |       }, | ||||||
|  |       "footer": { | ||||||
|  |         "text": "Updated at" | ||||||
|  |       }, | ||||||
|  |       "timestamp": "2025-02-06T22:00:00.000Z", | ||||||
|  |       "thumbnail": { | ||||||
|  |         "url": "https://github.com/steveiliop56/tinyauth/blob/main/site/public/logo.png?raw=true" | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   ], | ||||||
|  |   "attachments": [] | ||||||
|  | } | ||||||
							
								
								
									
										29
									
								
								cmd/root.go
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								cmd/root.go
									
									
									
									
									
								
							| @@ -1,6 +1,7 @@ | |||||||
| package cmd | package cmd | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
| 	"os" | 	"os" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -54,8 +55,10 @@ var rootCmd = &cobra.Command{ | |||||||
| 		log.Info().Msg("Parsing users") | 		log.Info().Msg("Parsing users") | ||||||
| 		users, usersErr := utils.GetUsers(config.Users, config.UsersFile) | 		users, usersErr := utils.GetUsers(config.Users, config.UsersFile) | ||||||
|  |  | ||||||
| 		if (len(users) == 0 || usersErr != nil) && !utils.OAuthConfigured(config) { | 		HandleError(usersErr, "Failed to parse users") | ||||||
| 			log.Fatal().Err(usersErr).Msg("Failed to parse users") |  | ||||||
|  | 		if len(users) == 0 && !utils.OAuthConfigured(config) { | ||||||
|  | 			HandleError(errors.New("no users or OAuth configured"), "No users or OAuth configured") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Create oauth whitelist | 		// Create oauth whitelist | ||||||
| @@ -89,7 +92,7 @@ var rootCmd = &cobra.Command{ | |||||||
| 		HandleError(dockerErr, "Failed to initialize docker") | 		HandleError(dockerErr, "Failed to initialize docker") | ||||||
|  |  | ||||||
| 		// Create auth service | 		// Create auth service | ||||||
| 		auth := auth.NewAuth(docker, users, oauthWhitelist) | 		auth := auth.NewAuth(docker, users, oauthWhitelist, config.SessionExpiry) | ||||||
|  |  | ||||||
| 		// Create OAuth providers service | 		// Create OAuth providers service | ||||||
| 		providers := providers.NewProviders(oauthConfig) | 		providers := providers.NewProviders(oauthConfig) | ||||||
| @@ -108,7 +111,7 @@ var rootCmd = &cobra.Command{ | |||||||
| 			AppURL:          config.AppURL, | 			AppURL:          config.AppURL, | ||||||
| 			CookieSecure:    config.CookieSecure, | 			CookieSecure:    config.CookieSecure, | ||||||
| 			DisableContinue: config.DisableContinue, | 			DisableContinue: config.DisableContinue, | ||||||
| 			CookieExpiry:    config.CookieExpiry, | 			CookieExpiry:    config.SessionExpiry, | ||||||
| 		}, hooks, auth, providers) | 		}, hooks, auth, providers) | ||||||
|  |  | ||||||
| 		// Setup routes | 		// Setup routes | ||||||
| @@ -122,20 +125,24 @@ var rootCmd = &cobra.Command{ | |||||||
|  |  | ||||||
| func Execute() { | func Execute() { | ||||||
| 	err := rootCmd.Execute() | 	err := rootCmd.Execute() | ||||||
| 	if err != nil { | 	HandleError(err, "Failed to execute root command") | ||||||
| 		log.Fatal().Err(err).Msg("Failed to execute command") |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func HandleError(err error, msg string) { | func HandleError(err error, msg string) { | ||||||
|  | 	// If error log it and exit | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatal().Err(err).Msg(msg) | 		log.Fatal().Err(err).Msg(msg) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
|  | 	// Add user command | ||||||
| 	rootCmd.AddCommand(cmd.UserCmd()) | 	rootCmd.AddCommand(cmd.UserCmd()) | ||||||
|  |  | ||||||
|  | 	// Read environment variables | ||||||
| 	viper.AutomaticEnv() | 	viper.AutomaticEnv() | ||||||
|  |  | ||||||
|  | 	// Flags | ||||||
| 	rootCmd.Flags().Int("port", 3000, "Port to run the server on.") | 	rootCmd.Flags().Int("port", 3000, "Port to run the server on.") | ||||||
| 	rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.") | 	rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.") | ||||||
| 	rootCmd.Flags().String("secret", "", "Secret to use for the cookie.") | 	rootCmd.Flags().String("secret", "", "Secret to use for the cookie.") | ||||||
| @@ -162,8 +169,10 @@ func init() { | |||||||
| 	rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") | 	rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") | ||||||
| 	rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") | 	rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") | ||||||
| 	rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") | 	rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") | ||||||
| 	rootCmd.Flags().Int("cookie-expiry", 86400, "Cookie expiration time in seconds.") | 	rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.") | ||||||
| 	rootCmd.Flags().Int("log-level", 1, "Log level.") | 	rootCmd.Flags().Int("log-level", 1, "Log level.") | ||||||
|  |  | ||||||
|  | 	// Bind flags to environment | ||||||
| 	viper.BindEnv("port", "PORT") | 	viper.BindEnv("port", "PORT") | ||||||
| 	viper.BindEnv("address", "ADDRESS") | 	viper.BindEnv("address", "ADDRESS") | ||||||
| 	viper.BindEnv("secret", "SECRET") | 	viper.BindEnv("secret", "SECRET") | ||||||
| @@ -190,7 +199,9 @@ func init() { | |||||||
| 	viper.BindEnv("generic-user-url", "GENERIC_USER_URL") | 	viper.BindEnv("generic-user-url", "GENERIC_USER_URL") | ||||||
| 	viper.BindEnv("disable-continue", "DISABLE_CONTINUE") | 	viper.BindEnv("disable-continue", "DISABLE_CONTINUE") | ||||||
| 	viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST") | 	viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST") | ||||||
| 	viper.BindEnv("cookie-expiry", "COOKIE_EXPIRY") | 	viper.BindEnv("session-expiry", "SESSION_EXPIRY") | ||||||
| 	viper.BindEnv("log-level", "LOG_LEVEL") | 	viper.BindEnv("log-level", "LOG_LEVEL") | ||||||
|  |  | ||||||
|  | 	// Bind flags to viper | ||||||
| 	viper.BindPFlags(rootCmd.Flags()) | 	viper.BindPFlags(rootCmd.Flags()) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -22,9 +22,12 @@ var CreateCmd = &cobra.Command{ | |||||||
| 	Short: "Create a user", | 	Short: "Create a user", | ||||||
| 	Long:  `Create a user either interactively or by passing flags.`, | 	Long:  `Create a user either interactively or by passing flags.`, | ||||||
| 	Run: func(cmd *cobra.Command, args []string) { | 	Run: func(cmd *cobra.Command, args []string) { | ||||||
|  | 		// Setup logger | ||||||
| 		log.Logger = log.Level(zerolog.InfoLevel) | 		log.Logger = log.Level(zerolog.InfoLevel) | ||||||
|  |  | ||||||
|  | 		// Check if interactive | ||||||
| 		if interactive { | 		if interactive { | ||||||
|  | 			// Create huh form | ||||||
| 			form := huh.NewForm( | 			form := huh.NewForm( | ||||||
| 				huh.NewGroup( | 				huh.NewGroup( | ||||||
| 					huh.NewInput().Title("Username").Value(&username).Validate((func(s string) error { | 					huh.NewInput().Title("Username").Value(&username).Validate((func(s string) error { | ||||||
| @@ -43,6 +46,7 @@ var CreateCmd = &cobra.Command{ | |||||||
| 				), | 				), | ||||||
| 			) | 			) | ||||||
|  |  | ||||||
|  | 			// Use simple theme | ||||||
| 			var baseTheme *huh.Theme = huh.ThemeBase() | 			var baseTheme *huh.Theme = huh.ThemeBase() | ||||||
|  |  | ||||||
| 			formErr := form.WithTheme(baseTheme).Run() | 			formErr := form.WithTheme(baseTheme).Run() | ||||||
| @@ -52,12 +56,14 @@ var CreateCmd = &cobra.Command{ | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Do we have username and password? | ||||||
| 		if username == "" || password == "" { | 		if username == "" || password == "" { | ||||||
| 			log.Error().Msg("Username and password cannot be empty") | 			log.Error().Msg("Username and password cannot be empty") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Info().Str("username", username).Str("password", password).Bool("docker", docker).Msg("Creating user") | 		log.Info().Str("username", username).Str("password", password).Bool("docker", docker).Msg("Creating user") | ||||||
|  |  | ||||||
|  | 		// Hash password | ||||||
| 		passwordByte, passwordErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) | 		passwordByte, passwordErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) | ||||||
|  |  | ||||||
| 		if passwordErr != nil { | 		if passwordErr != nil { | ||||||
| @@ -66,15 +72,18 @@ var CreateCmd = &cobra.Command{ | |||||||
|  |  | ||||||
| 		passwordString := string(passwordByte) | 		passwordString := string(passwordByte) | ||||||
|  |  | ||||||
|  | 		// Escape $ for docker | ||||||
| 		if docker { | 		if docker { | ||||||
| 			passwordString = strings.ReplaceAll(passwordString, "$", "$$") | 			passwordString = strings.ReplaceAll(passwordString, "$", "$$") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Log user created | ||||||
| 		log.Info().Str("user", fmt.Sprintf("%s:%s", username, passwordString)).Msg("User created") | 		log.Info().Str("user", fmt.Sprintf("%s:%s", username, passwordString)).Msg("User created") | ||||||
| 	}, | 	}, | ||||||
| } | } | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
|  | 	// Flags | ||||||
| 	CreateCmd.Flags().BoolVar(&interactive, "interactive", false, "Create a user interactively") | 	CreateCmd.Flags().BoolVar(&interactive, "interactive", false, "Create a user interactively") | ||||||
| 	CreateCmd.Flags().BoolVar(&docker, "docker", false, "Format output for docker") | 	CreateCmd.Flags().BoolVar(&docker, "docker", false, "Format output for docker") | ||||||
| 	CreateCmd.Flags().StringVar(&username, "username", "", "Username") | 	CreateCmd.Flags().StringVar(&username, "username", "", "Username") | ||||||
|   | |||||||
| @@ -22,9 +22,12 @@ var VerifyCmd = &cobra.Command{ | |||||||
| 	Short: "Verify a user is set up correctly", | 	Short: "Verify a user is set up correctly", | ||||||
| 	Long:  `Verify a user is set up correctly meaning that it has a correct username and password.`, | 	Long:  `Verify a user is set up correctly meaning that it has a correct username and password.`, | ||||||
| 	Run: func(cmd *cobra.Command, args []string) { | 	Run: func(cmd *cobra.Command, args []string) { | ||||||
|  | 		// Setup logger | ||||||
| 		log.Logger = log.Level(zerolog.InfoLevel) | 		log.Logger = log.Level(zerolog.InfoLevel) | ||||||
|  |  | ||||||
|  | 		// Check if interactive | ||||||
| 		if interactive { | 		if interactive { | ||||||
|  | 			// Create huh form | ||||||
| 			form := huh.NewForm( | 			form := huh.NewForm( | ||||||
| 				huh.NewGroup( | 				huh.NewGroup( | ||||||
| 					huh.NewInput().Title("User (username:hash)").Value(&user).Validate((func(s string) error { | 					huh.NewInput().Title("User (username:hash)").Value(&user).Validate((func(s string) error { | ||||||
| @@ -49,6 +52,7 @@ var VerifyCmd = &cobra.Command{ | |||||||
| 				), | 				), | ||||||
| 			) | 			) | ||||||
|  |  | ||||||
|  | 			// Use simple theme | ||||||
| 			var baseTheme *huh.Theme = huh.ThemeBase() | 			var baseTheme *huh.Theme = huh.ThemeBase() | ||||||
|  |  | ||||||
| 			formErr := form.WithTheme(baseTheme).Run() | 			formErr := form.WithTheme(baseTheme).Run() | ||||||
| @@ -58,22 +62,26 @@ var VerifyCmd = &cobra.Command{ | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Do we have username, password and user? | ||||||
| 		if username == "" || password == "" || user == "" { | 		if username == "" || password == "" || user == "" { | ||||||
| 			log.Fatal().Msg("Username, password and user cannot be empty") | 			log.Fatal().Msg("Username, password and user cannot be empty") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Info().Str("user", user).Str("username", username).Str("password", password).Bool("docker", docker).Msg("Verifying user") | 		log.Info().Str("user", user).Str("username", username).Str("password", password).Bool("docker", docker).Msg("Verifying user") | ||||||
|  |  | ||||||
|  | 		// Split username and password | ||||||
| 		userSplit := strings.Split(user, ":") | 		userSplit := strings.Split(user, ":") | ||||||
|  |  | ||||||
| 		if userSplit[1] == "" { | 		if userSplit[1] == "" { | ||||||
| 			log.Fatal().Msg("User is not formatted correctly") | 			log.Fatal().Msg("User is not formatted correctly") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Replace $$ with $ if formatted for docker | ||||||
| 		if docker { | 		if docker { | ||||||
| 			userSplit[1] = strings.ReplaceAll(userSplit[1], "$$", "$") | 			userSplit[1] = strings.ReplaceAll(userSplit[1], "$$", "$") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Compare username and password | ||||||
| 		verifyErr := bcrypt.CompareHashAndPassword([]byte(userSplit[1]), []byte(password)) | 		verifyErr := bcrypt.CompareHashAndPassword([]byte(userSplit[1]), []byte(password)) | ||||||
|  |  | ||||||
| 		if verifyErr != nil || username != userSplit[0] { | 		if verifyErr != nil || username != userSplit[0] { | ||||||
| @@ -85,6 +93,7 @@ var VerifyCmd = &cobra.Command{ | |||||||
| } | } | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
|  | 	// Flags | ||||||
| 	VerifyCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively") | 	VerifyCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively") | ||||||
| 	VerifyCmd.Flags().BoolVar(&docker, "docker", false, "Is the user formatted for docker?") | 	VerifyCmd.Flags().BoolVar(&docker, "docker", false, "Is the user formatted for docker?") | ||||||
| 	VerifyCmd.Flags().StringVar(&username, "username", "", "Username") | 	VerifyCmd.Flags().StringVar(&username, "username", "", "Username") | ||||||
|   | |||||||
| @@ -41,11 +41,15 @@ type API struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (api *API) Init() { | func (api *API) Init() { | ||||||
|  | 	// Disable gin logs | ||||||
| 	gin.SetMode(gin.ReleaseMode) | 	gin.SetMode(gin.ReleaseMode) | ||||||
|  |  | ||||||
|  | 	// Create router and use zerolog for logs | ||||||
| 	log.Debug().Msg("Setting up router") | 	log.Debug().Msg("Setting up router") | ||||||
| 	router := gin.New() | 	router := gin.New() | ||||||
| 	router.Use(zerolog()) | 	router.Use(zerolog()) | ||||||
|  |  | ||||||
|  | 	// Read UI assets | ||||||
| 	log.Debug().Msg("Setting up assets") | 	log.Debug().Msg("Setting up assets") | ||||||
| 	dist, distErr := fs.Sub(assets.Assets, "dist") | 	dist, distErr := fs.Sub(assets.Assets, "dist") | ||||||
|  |  | ||||||
| @@ -53,11 +57,15 @@ func (api *API) Init() { | |||||||
| 		log.Fatal().Err(distErr).Msg("Failed to get UI assets") | 		log.Fatal().Err(distErr).Msg("Failed to get UI assets") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Create file server | ||||||
| 	log.Debug().Msg("Setting up file server") | 	log.Debug().Msg("Setting up file server") | ||||||
| 	fileServer := http.FileServer(http.FS(dist)) | 	fileServer := http.FileServer(http.FS(dist)) | ||||||
|  |  | ||||||
|  | 	// Setup cookie store | ||||||
| 	log.Debug().Msg("Setting up cookie store") | 	log.Debug().Msg("Setting up cookie store") | ||||||
| 	store := cookie.NewStore([]byte(api.Config.Secret)) | 	store := cookie.NewStore([]byte(api.Config.Secret)) | ||||||
|  |  | ||||||
|  | 	// Get domain to use for session cookies | ||||||
| 	log.Debug().Msg("Getting domain") | 	log.Debug().Msg("Getting domain") | ||||||
| 	domain, domainErr := utils.GetRootURL(api.Config.AppURL) | 	domain, domainErr := utils.GetRootURL(api.Config.AppURL) | ||||||
|  |  | ||||||
| @@ -70,6 +78,7 @@ func (api *API) Init() { | |||||||
|  |  | ||||||
| 	api.Domain = fmt.Sprintf(".%s", domain) | 	api.Domain = fmt.Sprintf(".%s", domain) | ||||||
|  |  | ||||||
|  | 	// Use session middleware | ||||||
| 	store.Options(sessions.Options{ | 	store.Options(sessions.Options{ | ||||||
| 		Domain:   api.Domain, | 		Domain:   api.Domain, | ||||||
| 		Path:     "/", | 		Path:     "/", | ||||||
| @@ -80,63 +89,93 @@ func (api *API) Init() { | |||||||
|  |  | ||||||
| 	router.Use(sessions.Sessions("tinyauth", store)) | 	router.Use(sessions.Sessions("tinyauth", store)) | ||||||
|  |  | ||||||
|  | 	// UI middleware | ||||||
| 	router.Use(func(c *gin.Context) { | 	router.Use(func(c *gin.Context) { | ||||||
|  | 		// If not an API request, serve the UI | ||||||
| 		if !strings.HasPrefix(c.Request.URL.Path, "/api") { | 		if !strings.HasPrefix(c.Request.URL.Path, "/api") { | ||||||
| 			_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) | 			_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) | ||||||
|  |  | ||||||
|  | 			// If the file doesn't exist, serve the index.html | ||||||
| 			if os.IsNotExist(err) { | 			if os.IsNotExist(err) { | ||||||
| 				c.Request.URL.Path = "/" | 				c.Request.URL.Path = "/" | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			// Serve the file | ||||||
| 			fileServer.ServeHTTP(c.Writer, c.Request) | 			fileServer.ServeHTTP(c.Writer, c.Request) | ||||||
|  |  | ||||||
|  | 			// Stop further processing | ||||||
| 			c.Abort() | 			c.Abort() | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
|  | 	// Set router | ||||||
| 	api.Router = router | 	api.Router = router | ||||||
| } | } | ||||||
|  |  | ||||||
| func (api *API) SetupRoutes() { | func (api *API) SetupRoutes() { | ||||||
| 	api.Router.GET("/api/auth", func(c *gin.Context) { | 	api.Router.GET("/api/auth/:proxy", func(c *gin.Context) { | ||||||
| 		log.Debug().Msg("Checking auth") | 		// Create struct for proxy | ||||||
|  | 		var proxy types.Proxy | ||||||
|  |  | ||||||
|  | 		// Bind URI | ||||||
|  | 		bindErr := c.BindUri(&proxy) | ||||||
|  |  | ||||||
|  | 		// Handle error | ||||||
|  | 		if api.handleError(c, "Failed to bind URI", bindErr) { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy") | ||||||
|  |  | ||||||
|  | 		// Get user context | ||||||
| 		userContext := api.Hooks.UseUserContext(c) | 		userContext := api.Hooks.UseUserContext(c) | ||||||
|  |  | ||||||
|  | 		// Get headers | ||||||
| 		uri := c.Request.Header.Get("X-Forwarded-Uri") | 		uri := c.Request.Header.Get("X-Forwarded-Uri") | ||||||
| 		proto := c.Request.Header.Get("X-Forwarded-Proto") | 		proto := c.Request.Header.Get("X-Forwarded-Proto") | ||||||
| 		host := c.Request.Header.Get("X-Forwarded-Host") | 		host := c.Request.Header.Get("X-Forwarded-Host") | ||||||
|  |  | ||||||
|  | 		// Check if user is logged in | ||||||
| 		if userContext.IsLoggedIn { | 		if userContext.IsLoggedIn { | ||||||
| 			log.Debug().Msg("Authenticated") | 			log.Debug().Msg("Authenticated") | ||||||
|  |  | ||||||
|  | 			// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx | ||||||
| 			appAllowed, appAllowedErr := api.Auth.ResourceAllowed(userContext, host) | 			appAllowed, appAllowedErr := api.Auth.ResourceAllowed(userContext, host) | ||||||
| 			if handleApiError(c, "Failed to check if resource is allowed", appAllowedErr) { |  | ||||||
|  | 			// Check if there was an error | ||||||
|  | 			if appAllowedErr != nil { | ||||||
|  | 				// Return 501 if nginx is the proxy or if the request is using an Authorization header | ||||||
|  | 				if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" { | ||||||
|  | 					log.Error().Err(appAllowedErr).Msg("Failed to check if app is allowed") | ||||||
|  | 					c.JSON(501, gin.H{ | ||||||
|  | 						"status":  501, | ||||||
|  | 						"message": "Internal Server Error", | ||||||
|  | 					}) | ||||||
| 					return | 					return | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
|  | 				// Return the internal server error page | ||||||
|  | 				if api.handleError(c, "Failed to check if app is allowed", appAllowedErr) { | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed") | ||||||
|  |  | ||||||
|  | 			// The user is not allowed to access the app | ||||||
| 			if !appAllowed { | 			if !appAllowed { | ||||||
| 				log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") | 				log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed") | ||||||
|  |  | ||||||
|  | 				// Build query | ||||||
| 				queries, queryErr := query.Values(types.UnauthorizedQuery{ | 				queries, queryErr := query.Values(types.UnauthorizedQuery{ | ||||||
| 					Username: userContext.Username, | 					Username: userContext.Username, | ||||||
| 					Resource: strings.Split(host, ".")[0], | 					Resource: strings.Split(host, ".")[0], | ||||||
| 				}) | 				}) | ||||||
| 				if handleApiError(c, "Failed to build query", queryErr) { |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 				c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, queries.Encode())) |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			c.JSON(200, gin.H{ |  | ||||||
| 				"status":  200, |  | ||||||
| 				"message": "Authenticated", |  | ||||||
| 			}) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		queries, queryErr := query.Values(types.LoginQuery{ |  | ||||||
| 			RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), |  | ||||||
| 		}) |  | ||||||
|  |  | ||||||
| 		log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") |  | ||||||
|  |  | ||||||
|  | 				// Check if there was an error | ||||||
| 				if queryErr != nil { | 				if queryErr != nil { | ||||||
|  | 					// Return 501 if nginx is the proxy or if the request is using an Authorization header | ||||||
|  | 					if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" { | ||||||
| 						log.Error().Err(queryErr).Msg("Failed to build query") | 						log.Error().Err(queryErr).Msg("Failed to build query") | ||||||
| 						c.JSON(501, gin.H{ | 						c.JSON(501, gin.H{ | ||||||
| 							"status":  501, | 							"status":  501, | ||||||
| @@ -145,14 +184,74 @@ func (api *API) SetupRoutes() { | |||||||
| 						return | 						return | ||||||
| 					} | 					} | ||||||
|  |  | ||||||
|  | 					// Return the internal server error page | ||||||
|  | 					if api.handleError(c, "Failed to build query", queryErr) { | ||||||
|  | 						return | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				// Return 401 if nginx is the proxy or if the request is using an Authorization header | ||||||
|  | 				if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" { | ||||||
|  | 					c.JSON(401, gin.H{ | ||||||
|  | 						"status":  401, | ||||||
|  | 						"message": "Unauthorized", | ||||||
|  | 					}) | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				// We are using caddy/traefik so redirect | ||||||
|  | 				c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, queries.Encode())) | ||||||
|  |  | ||||||
|  | 				// Stop further processing | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// The user is allowed to access the app | ||||||
|  | 			c.JSON(200, gin.H{ | ||||||
|  | 				"status":  200, | ||||||
|  | 				"message": "Authenticated", | ||||||
|  | 			}) | ||||||
|  |  | ||||||
|  | 			// Stop further processing | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// The user is not logged in | ||||||
|  | 		log.Debug().Msg("Unauthorized") | ||||||
|  |  | ||||||
|  | 		// Return 401 if nginx is the proxy or if the request is using an Authorization header | ||||||
|  | 		if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" { | ||||||
|  | 			c.JSON(401, gin.H{ | ||||||
|  | 				"status":  401, | ||||||
|  | 				"message": "Unauthorized", | ||||||
|  | 			}) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Build query | ||||||
|  | 		queries, queryErr := query.Values(types.LoginQuery{ | ||||||
|  | 			RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri), | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login") | ||||||
|  |  | ||||||
|  | 		// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) | ||||||
|  | 		if api.handleError(c, "Failed to build query", queryErr) { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Redirect to login | ||||||
| 		c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode())) | 		c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode())) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	api.Router.POST("/api/login", func(c *gin.Context) { | 	api.Router.POST("/api/login", func(c *gin.Context) { | ||||||
|  | 		// Create login struct | ||||||
| 		var login types.LoginRequest | 		var login types.LoginRequest | ||||||
|  |  | ||||||
|  | 		// Bind JSON | ||||||
| 		err := c.BindJSON(&login) | 		err := c.BindJSON(&login) | ||||||
|  |  | ||||||
|  | 		// Handle error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error().Err(err).Msg("Failed to bind JSON") | 			log.Error().Err(err).Msg("Failed to bind JSON") | ||||||
| 			c.JSON(400, gin.H{ | 			c.JSON(400, gin.H{ | ||||||
| @@ -164,8 +263,10 @@ func (api *API) SetupRoutes() { | |||||||
|  |  | ||||||
| 		log.Debug().Msg("Got login request") | 		log.Debug().Msg("Got login request") | ||||||
|  |  | ||||||
|  | 		// Get user based on username | ||||||
| 		user := api.Auth.GetUser(login.Username) | 		user := api.Auth.GetUser(login.Username) | ||||||
|  |  | ||||||
|  | 		// User does not exist | ||||||
| 		if user == nil { | 		if user == nil { | ||||||
| 			log.Debug().Str("username", login.Username).Msg("User not found") | 			log.Debug().Str("username", login.Username).Msg("User not found") | ||||||
| 			c.JSON(401, gin.H{ | 			c.JSON(401, gin.H{ | ||||||
| @@ -175,6 +276,9 @@ func (api *API) SetupRoutes() { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		log.Debug().Msg("Got user") | ||||||
|  |  | ||||||
|  | 		// Check if password is correct | ||||||
| 		if !api.Auth.CheckPassword(*user, login.Password) { | 		if !api.Auth.CheckPassword(*user, login.Password) { | ||||||
| 			log.Debug().Str("username", login.Username).Msg("Password incorrect") | 			log.Debug().Str("username", login.Username).Msg("Password incorrect") | ||||||
| 			c.JSON(401, gin.H{ | 			c.JSON(401, gin.H{ | ||||||
| @@ -186,11 +290,13 @@ func (api *API) SetupRoutes() { | |||||||
|  |  | ||||||
| 		log.Debug().Msg("Password correct, logging in") | 		log.Debug().Msg("Password correct, logging in") | ||||||
|  |  | ||||||
|  | 		// Create session cookie with username as provider | ||||||
| 		api.Auth.CreateSessionCookie(c, &types.SessionCookie{ | 		api.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
| 			Username: login.Username, | 			Username: login.Username, | ||||||
| 			Provider: "username", | 			Provider: "username", | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
|  | 		// Return logged in | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"status":  200, | 			"status":  200, | ||||||
| 			"message": "Logged in", | 			"message": "Logged in", | ||||||
| @@ -198,12 +304,17 @@ func (api *API) SetupRoutes() { | |||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	api.Router.POST("/api/logout", func(c *gin.Context) { | 	api.Router.POST("/api/logout", func(c *gin.Context) { | ||||||
|  | 		log.Debug().Msg("Logging out") | ||||||
|  |  | ||||||
|  | 		// Delete session cookie | ||||||
| 		api.Auth.DeleteSessionCookie(c) | 		api.Auth.DeleteSessionCookie(c) | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Cleaning up redirect cookie") | 		log.Debug().Msg("Cleaning up redirect cookie") | ||||||
|  |  | ||||||
|  | 		// Clean up redirect cookie if it exists | ||||||
| 		c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) | 		c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) | ||||||
|  |  | ||||||
|  | 		// Return logged out | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"status":  200, | 			"status":  200, | ||||||
| 			"message": "Logged out", | 			"message": "Logged out", | ||||||
| @@ -212,19 +323,24 @@ func (api *API) SetupRoutes() { | |||||||
|  |  | ||||||
| 	api.Router.GET("/api/status", func(c *gin.Context) { | 	api.Router.GET("/api/status", func(c *gin.Context) { | ||||||
| 		log.Debug().Msg("Checking status") | 		log.Debug().Msg("Checking status") | ||||||
|  |  | ||||||
|  | 		// Get user context | ||||||
| 		userContext := api.Hooks.UseUserContext(c) | 		userContext := api.Hooks.UseUserContext(c) | ||||||
|  |  | ||||||
|  | 		// Get configured providers | ||||||
| 		configuredProviders := api.Providers.GetConfiguredProviders() | 		configuredProviders := api.Providers.GetConfiguredProviders() | ||||||
|  |  | ||||||
|  | 		// We have username/password configured so add it to our providers | ||||||
| 		if api.Auth.UserAuthConfigured() { | 		if api.Auth.UserAuthConfigured() { | ||||||
| 			configuredProviders = append(configuredProviders, "username") | 			configuredProviders = append(configuredProviders, "username") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// We are not logged in so return unauthorized | ||||||
| 		if !userContext.IsLoggedIn { | 		if !userContext.IsLoggedIn { | ||||||
| 			log.Debug().Msg("Unauthenticated") | 			log.Debug().Msg("Unauthorized") | ||||||
| 			c.JSON(200, gin.H{ | 			c.JSON(200, gin.H{ | ||||||
| 				"status":              200, | 				"status":              200, | ||||||
| 				"message":             "Unauthenticated", | 				"message":             "Unauthorized", | ||||||
| 				"username":            "", | 				"username":            "", | ||||||
| 				"isLoggedIn":          false, | 				"isLoggedIn":          false, | ||||||
| 				"oauth":               false, | 				"oauth":               false, | ||||||
| @@ -237,6 +353,7 @@ func (api *API) SetupRoutes() { | |||||||
|  |  | ||||||
| 		log.Debug().Interface("userContext", userContext).Strs("configuredProviders", configuredProviders).Bool("disableContinue", api.Config.DisableContinue).Msg("Authenticated") | 		log.Debug().Interface("userContext", userContext).Strs("configuredProviders", configuredProviders).Bool("disableContinue", api.Config.DisableContinue).Msg("Authenticated") | ||||||
|  |  | ||||||
|  | 		// We are logged in so return our user context | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"status":              200, | 			"status":              200, | ||||||
| 			"message":             "Authenticated", | 			"message":             "Authenticated", | ||||||
| @@ -249,18 +366,14 @@ func (api *API) SetupRoutes() { | |||||||
| 		}) | 		}) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	api.Router.GET("/api/healthcheck", func(c *gin.Context) { |  | ||||||
| 		c.JSON(200, gin.H{ |  | ||||||
| 			"status":  200, |  | ||||||
| 			"message": "OK", |  | ||||||
| 		}) |  | ||||||
| 	}) |  | ||||||
|  |  | ||||||
| 	api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) { | 	api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) { | ||||||
|  | 		// Create struct for OAuth request | ||||||
| 		var request types.OAuthRequest | 		var request types.OAuthRequest | ||||||
|  |  | ||||||
|  | 		// Bind URI | ||||||
| 		bindErr := c.BindUri(&request) | 		bindErr := c.BindUri(&request) | ||||||
|  |  | ||||||
|  | 		// Handle error | ||||||
| 		if bindErr != nil { | 		if bindErr != nil { | ||||||
| 			log.Error().Err(bindErr).Msg("Failed to bind URI") | 			log.Error().Err(bindErr).Msg("Failed to bind URI") | ||||||
| 			c.JSON(400, gin.H{ | 			c.JSON(400, gin.H{ | ||||||
| @@ -272,8 +385,10 @@ func (api *API) SetupRoutes() { | |||||||
|  |  | ||||||
| 		log.Debug().Msg("Got OAuth request") | 		log.Debug().Msg("Got OAuth request") | ||||||
|  |  | ||||||
|  | 		// Check if provider exists | ||||||
| 		provider := api.Providers.GetProvider(request.Provider) | 		provider := api.Providers.GetProvider(request.Provider) | ||||||
|  |  | ||||||
|  | 		// Provider does not exist | ||||||
| 		if provider == nil { | 		if provider == nil { | ||||||
| 			c.JSON(404, gin.H{ | 			c.JSON(404, gin.H{ | ||||||
| 				"status":  404, | 				"status":  404, | ||||||
| @@ -284,24 +399,38 @@ func (api *API) SetupRoutes() { | |||||||
|  |  | ||||||
| 		log.Debug().Str("provider", request.Provider).Msg("Got provider") | 		log.Debug().Str("provider", request.Provider).Msg("Got provider") | ||||||
|  |  | ||||||
|  | 		// Get auth URL | ||||||
| 		authURL := provider.GetAuthURL() | 		authURL := provider.GetAuthURL() | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got auth URL") | 		log.Debug().Msg("Got auth URL") | ||||||
|  |  | ||||||
|  | 		// Get redirect URI | ||||||
| 		redirectURI := c.Query("redirect_uri") | 		redirectURI := c.Query("redirect_uri") | ||||||
|  |  | ||||||
|  | 		// Set redirect cookie if redirect URI is provided | ||||||
| 		if redirectURI != "" { | 		if redirectURI != "" { | ||||||
| 			log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") | 			log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") | ||||||
| 			c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true) | 			c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Tailscale does not have an auth url so we create a random code (does not need to be secure) to avoid caching and send it | ||||||
| 		if request.Provider == "tailscale" { | 		if request.Provider == "tailscale" { | ||||||
|  | 			// Build tailscale query | ||||||
| 			tailscaleQuery, tailscaleQueryErr := query.Values(types.TailscaleQuery{ | 			tailscaleQuery, tailscaleQueryErr := query.Values(types.TailscaleQuery{ | ||||||
| 				Code: (1000 + rand.IntN(9000)), // doesn't need to be secure, just there to avoid caching | 				Code: (1000 + rand.IntN(9000)), | ||||||
|  | 			}) | ||||||
|  |  | ||||||
|  | 			// Handle error | ||||||
|  | 			if tailscaleQueryErr != nil { | ||||||
|  | 				log.Error().Err(tailscaleQueryErr).Msg("Failed to build query") | ||||||
|  | 				c.JSON(500, gin.H{ | ||||||
|  | 					"status":  500, | ||||||
|  | 					"message": "Internal Server Error", | ||||||
| 				}) | 				}) | ||||||
| 			if handleApiError(c, "Failed to build query", tailscaleQueryErr) { |  | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			// Return tailscale URL (immidiately redirects to the callback) | ||||||
| 			c.JSON(200, gin.H{ | 			c.JSON(200, gin.H{ | ||||||
| 				"status":  200, | 				"status":  200, | ||||||
| 				"message": "Ok", | 				"message": "Ok", | ||||||
| @@ -310,6 +439,7 @@ func (api *API) SetupRoutes() { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Return auth URL | ||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"status":  200, | 			"status":  200, | ||||||
| 			"message": "Ok", | 			"message": "Ok", | ||||||
| @@ -318,18 +448,23 @@ func (api *API) SetupRoutes() { | |||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) { | 	api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) { | ||||||
|  | 		// Create struct for OAuth request | ||||||
| 		var providerName types.OAuthRequest | 		var providerName types.OAuthRequest | ||||||
|  |  | ||||||
|  | 		// Bind URI | ||||||
| 		bindErr := c.BindUri(&providerName) | 		bindErr := c.BindUri(&providerName) | ||||||
|  |  | ||||||
| 		if handleApiError(c, "Failed to bind URI", bindErr) { | 		// Handle error | ||||||
|  | 		if api.handleError(c, "Failed to bind URI", bindErr) { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") | 		log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name") | ||||||
|  |  | ||||||
|  | 		// Get code | ||||||
| 		code := c.Query("code") | 		code := c.Query("code") | ||||||
|  |  | ||||||
|  | 		// Code empty so redirect to error | ||||||
| 		if code == "" { | 		if code == "" { | ||||||
| 			log.Error().Msg("No code provided") | 			log.Error().Msg("No code provided") | ||||||
| 			c.Redirect(http.StatusPermanentRedirect, "/error") | 			c.Redirect(http.StatusPermanentRedirect, "/error") | ||||||
| @@ -338,51 +473,67 @@ func (api *API) SetupRoutes() { | |||||||
|  |  | ||||||
| 		log.Debug().Msg("Got code") | 		log.Debug().Msg("Got code") | ||||||
|  |  | ||||||
|  | 		// Get provider | ||||||
| 		provider := api.Providers.GetProvider(providerName.Provider) | 		provider := api.Providers.GetProvider(providerName.Provider) | ||||||
|  |  | ||||||
| 		log.Debug().Str("provider", providerName.Provider).Msg("Got provider") | 		log.Debug().Str("provider", providerName.Provider).Msg("Got provider") | ||||||
|  |  | ||||||
|  | 		// Provider does not exist | ||||||
| 		if provider == nil { | 		if provider == nil { | ||||||
| 			c.Redirect(http.StatusPermanentRedirect, "/not-found") | 			c.Redirect(http.StatusPermanentRedirect, "/not-found") | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Exchange token (authenticates user) | ||||||
| 		_, tokenErr := provider.ExchangeToken(code) | 		_, tokenErr := provider.ExchangeToken(code) | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got token") | 		log.Debug().Msg("Got token") | ||||||
|  |  | ||||||
| 		if handleApiError(c, "Failed to exchange token", tokenErr) { | 		// Handle error | ||||||
|  | 		if api.handleError(c, "Failed to exchange token", tokenErr) { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Get email | ||||||
| 		email, emailErr := api.Providers.GetUser(providerName.Provider) | 		email, emailErr := api.Providers.GetUser(providerName.Provider) | ||||||
|  |  | ||||||
| 		log.Debug().Str("email", email).Msg("Got email") | 		log.Debug().Str("email", email).Msg("Got email") | ||||||
|  |  | ||||||
| 		if handleApiError(c, "Failed to get user", emailErr) { | 		// Handle error | ||||||
|  | 		if api.handleError(c, "Failed to get user", emailErr) { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Email is not whitelisted | ||||||
| 		if !api.Auth.EmailWhitelisted(email) { | 		if !api.Auth.EmailWhitelisted(email) { | ||||||
| 			log.Warn().Str("email", email).Msg("Email not whitelisted") | 			log.Warn().Str("email", email).Msg("Email not whitelisted") | ||||||
|  |  | ||||||
|  | 			// Build query | ||||||
| 			unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{ | 			unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{ | ||||||
| 				Username: email, | 				Username: email, | ||||||
| 			}) | 			}) | ||||||
| 			if handleApiError(c, "Failed to build query", unauthorizedQueryErr) { |  | ||||||
|  | 			// Handle error | ||||||
|  | 			if api.handleError(c, "Failed to build query", unauthorizedQueryErr) { | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			// Redirect to unauthorized | ||||||
| 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode())) | 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode())) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Email whitelisted") | 		log.Debug().Msg("Email whitelisted") | ||||||
|  |  | ||||||
|  | 		// Create session cookie | ||||||
| 		api.Auth.CreateSessionCookie(c, &types.SessionCookie{ | 		api.Auth.CreateSessionCookie(c, &types.SessionCookie{ | ||||||
| 			Username: email, | 			Username: email, | ||||||
| 			Provider: providerName.Provider, | 			Provider: providerName.Provider, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
|  | 		// Get redirect URI | ||||||
| 		redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") | 		redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") | ||||||
|  |  | ||||||
|  | 		// If it is empty it means that no redirect_uri was provided to the login screen so we just log in | ||||||
| 		if redirectURIErr != nil { | 		if redirectURIErr != nil { | ||||||
| 			c.JSON(200, gin.H{ | 			c.JSON(200, gin.H{ | ||||||
| 				"status":  200, | 				"status":  200, | ||||||
| @@ -392,40 +543,71 @@ func (api *API) SetupRoutes() { | |||||||
|  |  | ||||||
| 		log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI") | 		log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI") | ||||||
|  |  | ||||||
|  | 		// Clean up redirect cookie since we already have the value | ||||||
| 		c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) | 		c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) | ||||||
|  |  | ||||||
|  | 		// Build query | ||||||
| 		redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{ | 		redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{ | ||||||
| 			RedirectURI: redirectURI, | 			RedirectURI: redirectURI, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got redirect query") | 		log.Debug().Msg("Got redirect query") | ||||||
|  |  | ||||||
| 		if handleApiError(c, "Failed to build query", redirectQueryErr) { | 		// Handle error | ||||||
|  | 		if api.handleError(c, "Failed to build query", redirectQueryErr) { | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Redirect to continue with the redirect URI | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode())) | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode())) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
|  | 	// Simple healthcheck | ||||||
|  | 	api.Router.GET("/api/healthcheck", func(c *gin.Context) { | ||||||
|  | 		c.JSON(200, gin.H{ | ||||||
|  | 			"status":  200, | ||||||
|  | 			"message": "OK", | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (api *API) Run() { | func (api *API) Run() { | ||||||
| 	log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server") | 	log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server") | ||||||
|  |  | ||||||
|  | 	// Run server | ||||||
| 	api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port)) | 	api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // handleError logs the error and redirects to the error page (only meant for stuff the user may access does not apply for login paths) | ||||||
|  | func (api *API) handleError(c *gin.Context, msg string, err error) bool { | ||||||
|  | 	// If error is not nil log it and redirect to error page also return true so we can stop further processing | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error().Err(err).Msg(msg) | ||||||
|  | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", api.Config.AppURL)) | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // zerolog is a middleware for gin that logs requests using zerolog | ||||||
| func zerolog() gin.HandlerFunc { | func zerolog() gin.HandlerFunc { | ||||||
| 	return func(c *gin.Context) { | 	return func(c *gin.Context) { | ||||||
|  | 		// Get initial time | ||||||
| 		tStart := time.Now() | 		tStart := time.Now() | ||||||
|  |  | ||||||
|  | 		// Process request | ||||||
| 		c.Next() | 		c.Next() | ||||||
|  |  | ||||||
|  | 		// Get status code, address, method and path | ||||||
| 		code := c.Writer.Status() | 		code := c.Writer.Status() | ||||||
| 		address := c.Request.RemoteAddr | 		address := c.Request.RemoteAddr | ||||||
| 		method := c.Request.Method | 		method := c.Request.Method | ||||||
| 		path := c.Request.URL.Path | 		path := c.Request.URL.Path | ||||||
|  |  | ||||||
|  | 		// Get latency | ||||||
| 		latency := time.Since(tStart).String() | 		latency := time.Since(tStart).String() | ||||||
|  |  | ||||||
|  | 		// Log request | ||||||
| 		switch { | 		switch { | ||||||
| 		case code >= 200 && code < 300: | 		case code >= 200 && code < 300: | ||||||
| 			log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") | 			log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") | ||||||
| @@ -436,12 +618,3 @@ func zerolog() gin.HandlerFunc { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func handleApiError(c *gin.Context, msg string, err error) bool { |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error().Err(err).Msg(msg) |  | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, "/error") |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -4,8 +4,12 @@ import ( | |||||||
| 	"embed" | 	"embed" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // UI assets | ||||||
|  | // | ||||||
| //go:embed dist | //go:embed dist | ||||||
| var Assets embed.FS | var Assets embed.FS | ||||||
|  |  | ||||||
|  | // Version file | ||||||
|  | // | ||||||
| //go:embed version | //go:embed version | ||||||
| var Version string | var Version string | ||||||
| @@ -1 +1 @@ | |||||||
| v2.1.0 | v3.0.0 | ||||||
| @@ -3,6 +3,7 @@ package auth | |||||||
| import ( | import ( | ||||||
| 	"slices" | 	"slices" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"time" | ||||||
| 	"tinyauth/internal/docker" | 	"tinyauth/internal/docker" | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
| 	"tinyauth/internal/utils" | 	"tinyauth/internal/utils" | ||||||
| @@ -13,11 +14,12 @@ import ( | |||||||
| 	"golang.org/x/crypto/bcrypt" | 	"golang.org/x/crypto/bcrypt" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func NewAuth(docker *docker.Docker, userList types.Users, oauthWhitelist []string) *Auth { | func NewAuth(docker *docker.Docker, userList types.Users, oauthWhitelist []string, sessionExpiry int) *Auth { | ||||||
| 	return &Auth{ | 	return &Auth{ | ||||||
| 		Docker:         docker, | 		Docker:         docker, | ||||||
| 		Users:          userList, | 		Users:          userList, | ||||||
| 		OAuthWhitelist: oauthWhitelist, | 		OAuthWhitelist: oauthWhitelist, | ||||||
|  | 		SessionExpiry:  sessionExpiry, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -25,9 +27,11 @@ type Auth struct { | |||||||
| 	Users          types.Users | 	Users          types.Users | ||||||
| 	Docker         *docker.Docker | 	Docker         *docker.Docker | ||||||
| 	OAuthWhitelist []string | 	OAuthWhitelist []string | ||||||
|  | 	SessionExpiry  int | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) GetUser(username string) *types.User { | func (auth *Auth) GetUser(username string) *types.User { | ||||||
|  | 	// Loop through users and return the user if the username matches | ||||||
| 	for _, user := range auth.Users { | 	for _, user := range auth.Users { | ||||||
| 		if user.Username == username { | 		if user.Username == username { | ||||||
| 			return &user | 			return &user | ||||||
| @@ -37,91 +41,150 @@ func (auth *Auth) GetUser(username string) *types.User { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) CheckPassword(user types.User, password string) bool { | func (auth *Auth) CheckPassword(user types.User, password string) bool { | ||||||
| 	hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) | 	// Compare the hashed password with the password provided | ||||||
| 	return hashedPasswordErr == nil | 	return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) EmailWhitelisted(emailSrc string) bool { | func (auth *Auth) EmailWhitelisted(emailSrc string) bool { | ||||||
|  | 	// If the whitelist is empty, allow all emails | ||||||
| 	if len(auth.OAuthWhitelist) == 0 { | 	if len(auth.OAuthWhitelist) == 0 { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Loop through the whitelist and return true if the email matches | ||||||
| 	for _, email := range auth.OAuthWhitelist { | 	for _, email := range auth.OAuthWhitelist { | ||||||
| 		if email == emailSrc { | 		if email == emailSrc { | ||||||
| 			return true | 			return true | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// If no emails match, return false | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) { | func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) { | ||||||
| 	log.Debug().Msg("Creating session cookie") | 	log.Debug().Msg("Creating session cookie") | ||||||
|  |  | ||||||
|  | 	// Get session | ||||||
| 	sessions := sessions.Default(c) | 	sessions := sessions.Default(c) | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Setting session cookie") | 	log.Debug().Msg("Setting session cookie") | ||||||
|  |  | ||||||
|  | 	// Set data | ||||||
| 	sessions.Set("username", data.Username) | 	sessions.Set("username", data.Username) | ||||||
| 	sessions.Set("provider", data.Provider) | 	sessions.Set("provider", data.Provider) | ||||||
|  | 	sessions.Set("expiry", time.Now().Add(time.Duration(auth.SessionExpiry)*time.Second).Unix()) | ||||||
|  |  | ||||||
|  | 	// Save session | ||||||
| 	sessions.Save() | 	sessions.Save() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) DeleteSessionCookie(c *gin.Context) { | func (auth *Auth) DeleteSessionCookie(c *gin.Context) { | ||||||
| 	log.Debug().Msg("Deleting session cookie") | 	log.Debug().Msg("Deleting session cookie") | ||||||
|  |  | ||||||
|  | 	// Get session | ||||||
| 	sessions := sessions.Default(c) | 	sessions := sessions.Default(c) | ||||||
|  |  | ||||||
|  | 	// Clear session | ||||||
| 	sessions.Clear() | 	sessions.Clear() | ||||||
|  |  | ||||||
|  | 	// Save session | ||||||
| 	sessions.Save() | 	sessions.Save() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { | func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie { | ||||||
| 	log.Debug().Msg("Getting session cookie") | 	log.Debug().Msg("Getting session cookie") | ||||||
|  |  | ||||||
|  | 	// Get session | ||||||
| 	sessions := sessions.Default(c) | 	sessions := sessions.Default(c) | ||||||
|  |  | ||||||
|  | 	// Get data | ||||||
| 	cookieUsername := sessions.Get("username") | 	cookieUsername := sessions.Get("username") | ||||||
| 	cookieProvider := sessions.Get("provider") | 	cookieProvider := sessions.Get("provider") | ||||||
|  | 	cookieExpiry := sessions.Get("expiry") | ||||||
|  |  | ||||||
|  | 	// Convert interfaces to correct types | ||||||
| 	username, usernameOk := cookieUsername.(string) | 	username, usernameOk := cookieUsername.(string) | ||||||
| 	provider, providerOk := cookieProvider.(string) | 	provider, providerOk := cookieProvider.(string) | ||||||
|  | 	expiry, expiryOk := cookieExpiry.(int64) | ||||||
|  |  | ||||||
| 	log.Debug().Str("username", username).Str("provider", provider).Msg("Parsed cookie") | 	// Check if the cookie is invalid | ||||||
|  | 	if !usernameOk || !providerOk || !expiryOk { | ||||||
| 	if !usernameOk || !providerOk { |  | ||||||
| 		log.Warn().Msg("Session cookie invalid") | 		log.Warn().Msg("Session cookie invalid") | ||||||
| 		return types.SessionCookie{}, nil | 		return types.SessionCookie{} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Check if the cookie has expired | ||||||
|  | 	if time.Now().Unix() > expiry { | ||||||
|  | 		log.Warn().Msg("Session cookie expired") | ||||||
|  |  | ||||||
|  | 		// If it has, delete it | ||||||
|  | 		auth.DeleteSessionCookie(c) | ||||||
|  |  | ||||||
|  | 		// Return empty cookie | ||||||
|  | 		return types.SessionCookie{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Msg("Parsed cookie") | ||||||
|  |  | ||||||
|  | 	// Return the cookie | ||||||
| 	return types.SessionCookie{ | 	return types.SessionCookie{ | ||||||
| 		Username: username, | 		Username: username, | ||||||
| 		Provider: provider, | 		Provider: provider, | ||||||
| 	}, nil | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) UserAuthConfigured() bool { | func (auth *Auth) UserAuthConfigured() bool { | ||||||
|  | 	// If there are users, return true | ||||||
| 	return len(auth.Users) > 0 | 	return len(auth.Users) > 0 | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, error) { | func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, error) { | ||||||
|  | 	// Check if we have access to the Docker API | ||||||
|  | 	isConnected := auth.Docker.DockerConnected() | ||||||
|  |  | ||||||
|  | 	// If we don't have access, it is assumed that the user has access | ||||||
|  | 	if !isConnected { | ||||||
|  | 		log.Debug().Msg("Docker not connected, allowing access") | ||||||
|  | 		return true, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Get the app ID from the host | ||||||
| 	appId := strings.Split(host, ".")[0] | 	appId := strings.Split(host, ".")[0] | ||||||
|  |  | ||||||
|  | 	// Get the containers | ||||||
| 	containers, containersErr := auth.Docker.GetContainers() | 	containers, containersErr := auth.Docker.GetContainers() | ||||||
|  |  | ||||||
|  | 	// If there is an error, return false | ||||||
| 	if containersErr != nil { | 	if containersErr != nil { | ||||||
| 		return false, containersErr | 		return false, containersErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got containers") | 	log.Debug().Msg("Got containers") | ||||||
|  |  | ||||||
|  | 	// Loop through the containers | ||||||
| 	for _, container := range containers { | 	for _, container := range containers { | ||||||
|  | 		// Inspect the container | ||||||
| 		inspect, inspectErr := auth.Docker.InspectContainer(container.ID) | 		inspect, inspectErr := auth.Docker.InspectContainer(container.ID) | ||||||
|  |  | ||||||
|  | 		// If there is an error, return false | ||||||
| 		if inspectErr != nil { | 		if inspectErr != nil { | ||||||
| 			return false, inspectErr | 			return false, inspectErr | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Get the container name (for some reason it is /name) | ||||||
| 		containerName := strings.Split(inspect.Name, "/")[1] | 		containerName := strings.Split(inspect.Name, "/")[1] | ||||||
|  |  | ||||||
|  | 		// There is a container with the same name as the app ID | ||||||
| 		if containerName == appId { | 		if containerName == appId { | ||||||
| 			log.Debug().Str("container", containerName).Msg("Found container") | 			log.Debug().Str("container", containerName).Msg("Found container") | ||||||
|  |  | ||||||
|  | 			// Get only the tinyauth labels in a struct | ||||||
| 			labels := utils.GetTinyauthLabels(inspect.Config.Labels) | 			labels := utils.GetTinyauthLabels(inspect.Config.Labels) | ||||||
|  |  | ||||||
| 			log.Debug().Msg("Got labels") | 			log.Debug().Msg("Got labels") | ||||||
|  |  | ||||||
|  | 			// If the container has an oauth whitelist, check if the user is in it | ||||||
| 			if context.OAuth && len(labels.OAuthWhitelist) != 0 { | 			if context.OAuth && len(labels.OAuthWhitelist) != 0 { | ||||||
| 				log.Debug().Msg("Checking OAuth whitelist") | 				log.Debug().Msg("Checking OAuth whitelist") | ||||||
| 				if slices.Contains(labels.OAuthWhitelist, context.Username) { | 				if slices.Contains(labels.OAuthWhitelist, context.Username) { | ||||||
| @@ -130,6 +193,7 @@ func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, | |||||||
| 				return false, nil | 				return false, nil | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			// If the container has users, check if the user is in it | ||||||
| 			if len(labels.Users) != 0 { | 			if len(labels.Users) != 0 { | ||||||
| 				log.Debug().Msg("Checking users") | 				log.Debug().Msg("Checking users") | ||||||
| 				if slices.Contains(labels.Users, context.Username) { | 				if slices.Contains(labels.Users, context.Username) { | ||||||
| @@ -143,5 +207,42 @@ func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, | |||||||
|  |  | ||||||
| 	log.Debug().Msg("No matching container found, allowing access") | 	log.Debug().Msg("No matching container found, allowing access") | ||||||
|  |  | ||||||
|  | 	// If no matching container is found, allow access | ||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (auth *Auth) GetBasicAuth(c *gin.Context) types.User { | ||||||
|  | 	// Get the Authorization header | ||||||
|  | 	header := c.GetHeader("Authorization") | ||||||
|  |  | ||||||
|  | 	// If the header is empty, return an empty user | ||||||
|  | 	if header == "" { | ||||||
|  | 		return types.User{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Split the header | ||||||
|  | 	headerSplit := strings.Split(header, " ") | ||||||
|  |  | ||||||
|  | 	if len(headerSplit) != 2 { | ||||||
|  | 		return types.User{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check if the header is Basic | ||||||
|  | 	if headerSplit[0] != "Basic" { | ||||||
|  | 		return types.User{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Split the credentials | ||||||
|  | 	credentials := strings.Split(headerSplit[1], ":") | ||||||
|  |  | ||||||
|  | 	// If the credentials are not in the correct format, return an empty user | ||||||
|  | 	if len(credentials) != 2 { | ||||||
|  | 		return types.User{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Return the user | ||||||
|  | 	return types.User{ | ||||||
|  | 		Username: credentials[0], | ||||||
|  | 		Password: credentials[1], | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| package constants | package constants | ||||||
|  |  | ||||||
|  | // TinyauthLabels is a list of labels that can be used in a tinyauth protected container | ||||||
| var TinyauthLabels = []string{ | var TinyauthLabels = []string{ | ||||||
| 	"tinyauth.oauth.whitelist", | 	"tinyauth.oauth.whitelist", | ||||||
| 	"tinyauth.users", | 	"tinyauth.users", | ||||||
|   | |||||||
| @@ -18,34 +18,50 @@ type Docker struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (docker *Docker) Init() error { | func (docker *Docker) Init() error { | ||||||
|  | 	// Create a new docker client | ||||||
| 	apiClient, err := client.NewClientWithOpts(client.FromEnv) | 	apiClient, err := client.NewClientWithOpts(client.FromEnv) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Set the context and api client | ||||||
| 	docker.Context = context.Background() | 	docker.Context = context.Background() | ||||||
| 	docker.Client = apiClient | 	docker.Client = apiClient | ||||||
|  |  | ||||||
|  | 	// Done | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (docker *Docker) GetContainers() ([]types.Container, error) { | func (docker *Docker) GetContainers() ([]types.Container, error) { | ||||||
|  | 	// Get the list of containers | ||||||
| 	containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) | 	containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Return the containers | ||||||
| 	return containers, nil | 	return containers, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (docker *Docker) InspectContainer(containerId string) (types.ContainerJSON, error) { | func (docker *Docker) InspectContainer(containerId string) (types.ContainerJSON, error) { | ||||||
|  | 	// Inspect the container | ||||||
| 	inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) | 	inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return types.ContainerJSON{}, err | 		return types.ContainerJSON{}, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Return the inspect | ||||||
| 	return inspect, nil | 	return inspect, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (docker *Docker) DockerConnected() bool { | ||||||
|  | 	// Ping the docker client if there is an error it is not connected | ||||||
|  | 	_, err := docker.Client.Ping(docker.Context) | ||||||
|  | 	return err == nil | ||||||
|  | } | ||||||
|   | |||||||
| @@ -22,39 +22,64 @@ type Hooks struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | ||||||
| 	cookie, cookiErr := hooks.Auth.GetSessionCookie(c) | 	// Get session cookie and basic auth | ||||||
|  | 	cookie := hooks.Auth.GetSessionCookie(c) | ||||||
|  | 	basic := hooks.Auth.GetBasicAuth(c) | ||||||
|  |  | ||||||
| 	if cookiErr != nil { | 	// Check if basic auth is set | ||||||
| 		log.Error().Err(cookiErr).Msg("Failed to get session cookie") | 	if basic.Username != "" { | ||||||
|  | 		log.Debug().Msg("Got basic auth") | ||||||
|  |  | ||||||
|  | 		// Check if user exists and password is correct | ||||||
|  | 		user := hooks.Auth.GetUser(basic.Username) | ||||||
|  |  | ||||||
|  | 		if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) { | ||||||
|  | 			// Return user context since we are logged in with basic auth | ||||||
| 			return types.UserContext{ | 			return types.UserContext{ | ||||||
| 			Username:   "", | 				Username:   basic.Username, | ||||||
| 			IsLoggedIn: false, | 				IsLoggedIn: true, | ||||||
| 				OAuth:      false, | 				OAuth:      false, | ||||||
| 			Provider:   "", | 				Provider:   "basic", | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check if session cookie is username/password auth | ||||||
| 	if cookie.Provider == "username" { | 	if cookie.Provider == "username" { | ||||||
| 		log.Debug().Msg("Provider is username") | 		log.Debug().Msg("Provider is username") | ||||||
|  |  | ||||||
|  | 		// Check if user exists | ||||||
| 		if hooks.Auth.GetUser(cookie.Username) != nil { | 		if hooks.Auth.GetUser(cookie.Username) != nil { | ||||||
| 			log.Debug().Msg("User exists") | 			log.Debug().Msg("User exists") | ||||||
|  |  | ||||||
|  | 			// It exists so we are logged in | ||||||
| 			return types.UserContext{ | 			return types.UserContext{ | ||||||
| 				Username:   cookie.Username, | 				Username:   cookie.Username, | ||||||
| 				IsLoggedIn: true, | 				IsLoggedIn: true, | ||||||
| 				OAuth:      false, | 				OAuth:      false, | ||||||
| 				Provider:   "", | 				Provider:   "username", | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Provider is not username") | 	log.Debug().Msg("Provider is not username") | ||||||
|  |  | ||||||
|  | 	// The provider is not username so we need to check if it is an oauth provider | ||||||
| 	provider := hooks.Providers.GetProvider(cookie.Provider) | 	provider := hooks.Providers.GetProvider(cookie.Provider) | ||||||
|  |  | ||||||
|  | 	// If we have a provider with this name | ||||||
| 	if provider != nil { | 	if provider != nil { | ||||||
| 		log.Debug().Msg("Provider exists") | 		log.Debug().Msg("Provider exists") | ||||||
|  |  | ||||||
|  | 		// Check if the oauth email is whitelisted | ||||||
| 		if !hooks.Auth.EmailWhitelisted(cookie.Username) { | 		if !hooks.Auth.EmailWhitelisted(cookie.Username) { | ||||||
| 			log.Error().Str("email", cookie.Username).Msg("Email is not whitelisted") | 			log.Error().Str("email", cookie.Username).Msg("Email is not whitelisted") | ||||||
|  |  | ||||||
|  | 			// It isn't so we delete the cookie and return an empty context | ||||||
| 			hooks.Auth.DeleteSessionCookie(c) | 			hooks.Auth.DeleteSessionCookie(c) | ||||||
|  |  | ||||||
|  | 			// Return empty context | ||||||
| 			return types.UserContext{ | 			return types.UserContext{ | ||||||
| 				Username:   "", | 				Username:   "", | ||||||
| 				IsLoggedIn: false, | 				IsLoggedIn: false, | ||||||
| @@ -62,7 +87,10 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 				Provider:   "", | 				Provider:   "", | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Email is whitelisted") | 		log.Debug().Msg("Email is whitelisted") | ||||||
|  |  | ||||||
|  | 		// Return user context since we are logged in with oauth | ||||||
| 		return types.UserContext{ | 		return types.UserContext{ | ||||||
| 			Username:   cookie.Username, | 			Username:   cookie.Username, | ||||||
| 			IsLoggedIn: true, | 			IsLoggedIn: true, | ||||||
| @@ -71,6 +99,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Neither basic auth or oauth is set so we return an empty context | ||||||
| 	return types.UserContext{ | 	return types.UserContext{ | ||||||
| 		Username:   "", | 		Username:   "", | ||||||
| 		IsLoggedIn: false, | 		IsLoggedIn: false, | ||||||
|   | |||||||
| @@ -21,23 +21,33 @@ type OAuth struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (oauth *OAuth) Init() { | func (oauth *OAuth) Init() { | ||||||
|  | 	// Create a new context and verifier | ||||||
| 	oauth.Context = context.Background() | 	oauth.Context = context.Background() | ||||||
| 	oauth.Verifier = oauth2.GenerateVerifier() | 	oauth.Verifier = oauth2.GenerateVerifier() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (oauth *OAuth) GetAuthURL() string { | func (oauth *OAuth) GetAuthURL() string { | ||||||
|  | 	// Return the auth url | ||||||
| 	return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) | 	return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (oauth *OAuth) ExchangeToken(code string) (string, error) { | func (oauth *OAuth) ExchangeToken(code string) (string, error) { | ||||||
|  | 	// Exchange the code for a token | ||||||
| 	token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) | 	token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Set the token | ||||||
| 	oauth.Token = token | 	oauth.Token = token | ||||||
|  |  | ||||||
|  | 	// Return the access token | ||||||
| 	return oauth.Token.AccessToken, nil | 	return oauth.Token.AccessToken, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (oauth *OAuth) GetClient() *http.Client { | func (oauth *OAuth) GetClient() *http.Client { | ||||||
|  | 	// Return the http client with the token set | ||||||
| 	return oauth.Config.Client(oauth.Context, oauth.Token) | 	return oauth.Config.Client(oauth.Context, oauth.Token) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -8,36 +8,45 @@ import ( | |||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // We are assuming that the generic provider will return a JSON object with an email field | ||||||
| type GenericUserInfoResponse struct { | type GenericUserInfoResponse struct { | ||||||
| 	Email string `json:"email"` | 	Email string `json:"email"` | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetGenericEmail(client *http.Client, url string) (string, error) { | func GetGenericEmail(client *http.Client, url string) (string, error) { | ||||||
|  | 	// Using the oauth client get the user info url | ||||||
| 	res, resErr := client.Get(url) | 	res, resErr := client.Get(url) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if resErr != nil { | 	if resErr != nil { | ||||||
| 		return "", resErr | 		return "", resErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from generic provider") | 	log.Debug().Msg("Got response from generic provider") | ||||||
|  |  | ||||||
|  | 	// Read the body of the response | ||||||
| 	body, bodyErr := io.ReadAll(res.Body) | 	body, bodyErr := io.ReadAll(res.Body) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if bodyErr != nil { | 	if bodyErr != nil { | ||||||
| 		return "", bodyErr | 		return "", bodyErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from generic provider") | 	log.Debug().Msg("Read body from generic provider") | ||||||
|  |  | ||||||
|  | 	// Parse the body into a user struct | ||||||
| 	var user GenericUserInfoResponse | 	var user GenericUserInfoResponse | ||||||
|  |  | ||||||
|  | 	// Unmarshal the body into the user struct | ||||||
| 	jsonErr := json.Unmarshal(body, &user) | 	jsonErr := json.Unmarshal(body, &user) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if jsonErr != nil { | 	if jsonErr != nil { | ||||||
| 		return "", jsonErr | 		return "", jsonErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed user from generic provider") | 	log.Debug().Msg("Parsed user from generic provider") | ||||||
|  |  | ||||||
|  | 	// Return the email | ||||||
| 	return user.Email, nil | 	return user.Email, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -9,47 +9,58 @@ import ( | |||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // Github has a different response than the generic provider | ||||||
| type GithubUserInfoResponse []struct { | type GithubUserInfoResponse []struct { | ||||||
| 	Email   string `json:"email"` | 	Email   string `json:"email"` | ||||||
| 	Primary bool   `json:"primary"` | 	Primary bool   `json:"primary"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // The scopes required for the github provider | ||||||
| func GithubScopes() []string { | func GithubScopes() []string { | ||||||
| 	return []string{"user:email"} | 	return []string{"user:email"} | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetGithubEmail(client *http.Client) (string, error) { | func GetGithubEmail(client *http.Client) (string, error) { | ||||||
|  | 	// Get the user emails from github using the oauth http client | ||||||
| 	res, resErr := client.Get("https://api.github.com/user/emails") | 	res, resErr := client.Get("https://api.github.com/user/emails") | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if resErr != nil { | 	if resErr != nil { | ||||||
| 		return "", resErr | 		return "", resErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from github") | 	log.Debug().Msg("Got response from github") | ||||||
|  |  | ||||||
|  | 	// Read the body of the response | ||||||
| 	body, bodyErr := io.ReadAll(res.Body) | 	body, bodyErr := io.ReadAll(res.Body) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if bodyErr != nil { | 	if bodyErr != nil { | ||||||
| 		return "", bodyErr | 		return "", bodyErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from github") | 	log.Debug().Msg("Read body from github") | ||||||
|  |  | ||||||
|  | 	// Parse the body into a user struct | ||||||
| 	var emails GithubUserInfoResponse | 	var emails GithubUserInfoResponse | ||||||
|  |  | ||||||
|  | 	// Unmarshal the body into the user struct | ||||||
| 	jsonErr := json.Unmarshal(body, &emails) | 	jsonErr := json.Unmarshal(body, &emails) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if jsonErr != nil { | 	if jsonErr != nil { | ||||||
| 		return "", jsonErr | 		return "", jsonErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed emails from github") | 	log.Debug().Msg("Parsed emails from github") | ||||||
|  |  | ||||||
|  | 	// Find and return the primary email | ||||||
| 	for _, email := range emails { | 	for _, email := range emails { | ||||||
| 		if email.Primary { | 		if email.Primary { | ||||||
| 			return email.Email, nil | 			return email.Email, nil | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// User does not have a primary email? | ||||||
| 	return "", errors.New("no primary email found") | 	return "", errors.New("no primary email found") | ||||||
| } | } | ||||||
|   | |||||||
| @@ -8,40 +8,50 @@ import ( | |||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // Google works the same as the generic provider | ||||||
| type GoogleUserInfoResponse struct { | type GoogleUserInfoResponse struct { | ||||||
| 	Email string `json:"email"` | 	Email string `json:"email"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // The scopes required for the google provider | ||||||
| func GoogleScopes() []string { | func GoogleScopes() []string { | ||||||
| 	return []string{"https://www.googleapis.com/auth/userinfo.email"} | 	return []string{"https://www.googleapis.com/auth/userinfo.email"} | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetGoogleEmail(client *http.Client) (string, error) { | func GetGoogleEmail(client *http.Client) (string, error) { | ||||||
|  | 	// Get the user info from google using the oauth http client | ||||||
| 	res, resErr := client.Get("https://www.googleapis.com/userinfo/v2/me") | 	res, resErr := client.Get("https://www.googleapis.com/userinfo/v2/me") | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if resErr != nil { | 	if resErr != nil { | ||||||
| 		return "", resErr | 		return "", resErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from google") | 	log.Debug().Msg("Got response from google") | ||||||
|  |  | ||||||
|  | 	// Read the body of the response | ||||||
| 	body, bodyErr := io.ReadAll(res.Body) | 	body, bodyErr := io.ReadAll(res.Body) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if bodyErr != nil { | 	if bodyErr != nil { | ||||||
| 		return "", bodyErr | 		return "", bodyErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from google") | 	log.Debug().Msg("Read body from google") | ||||||
|  |  | ||||||
|  | 	// Parse the body into a user struct | ||||||
| 	var user GoogleUserInfoResponse | 	var user GoogleUserInfoResponse | ||||||
|  |  | ||||||
|  | 	// Unmarshal the body into the user struct | ||||||
| 	jsonErr := json.Unmarshal(body, &user) | 	jsonErr := json.Unmarshal(body, &user) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if jsonErr != nil { | 	if jsonErr != nil { | ||||||
| 		return "", jsonErr | 		return "", jsonErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed user from google") | 	log.Debug().Msg("Parsed user from google") | ||||||
|  |  | ||||||
|  | 	// Return the email | ||||||
| 	return user.Email, nil | 	return user.Email, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -25,8 +25,11 @@ type Providers struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (providers *Providers) Init() { | func (providers *Providers) Init() { | ||||||
|  | 	// If we have a client id and secret for github, initialize the oauth provider | ||||||
| 	if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" { | 	if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" { | ||||||
| 		log.Info().Msg("Initializing Github OAuth") | 		log.Info().Msg("Initializing Github OAuth") | ||||||
|  |  | ||||||
|  | 		// Create a new oauth provider with the github config | ||||||
| 		providers.Github = oauth.NewOAuth(oauth2.Config{ | 		providers.Github = oauth.NewOAuth(oauth2.Config{ | ||||||
| 			ClientID:     providers.Config.GithubClientId, | 			ClientID:     providers.Config.GithubClientId, | ||||||
| 			ClientSecret: providers.Config.GithubClientSecret, | 			ClientSecret: providers.Config.GithubClientSecret, | ||||||
| @@ -34,10 +37,16 @@ func (providers *Providers) Init() { | |||||||
| 			Scopes:       GithubScopes(), | 			Scopes:       GithubScopes(), | ||||||
| 			Endpoint:     endpoints.GitHub, | 			Endpoint:     endpoints.GitHub, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
|  | 		// Initialize the oauth provider | ||||||
| 		providers.Github.Init() | 		providers.Github.Init() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// If we have a client id and secret for google, initialize the oauth provider | ||||||
| 	if providers.Config.GoogleClientId != "" && providers.Config.GoogleClientSecret != "" { | 	if providers.Config.GoogleClientId != "" && providers.Config.GoogleClientSecret != "" { | ||||||
| 		log.Info().Msg("Initializing Google OAuth") | 		log.Info().Msg("Initializing Google OAuth") | ||||||
|  |  | ||||||
|  | 		// Create a new oauth provider with the google config | ||||||
| 		providers.Google = oauth.NewOAuth(oauth2.Config{ | 		providers.Google = oauth.NewOAuth(oauth2.Config{ | ||||||
| 			ClientID:     providers.Config.GoogleClientId, | 			ClientID:     providers.Config.GoogleClientId, | ||||||
| 			ClientSecret: providers.Config.GoogleClientSecret, | 			ClientSecret: providers.Config.GoogleClientSecret, | ||||||
| @@ -45,10 +54,15 @@ func (providers *Providers) Init() { | |||||||
| 			Scopes:       GoogleScopes(), | 			Scopes:       GoogleScopes(), | ||||||
| 			Endpoint:     endpoints.Google, | 			Endpoint:     endpoints.Google, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
|  | 		// Initialize the oauth provider | ||||||
| 		providers.Google.Init() | 		providers.Google.Init() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if providers.Config.TailscaleClientId != "" && providers.Config.TailscaleClientSecret != "" { | 	if providers.Config.TailscaleClientId != "" && providers.Config.TailscaleClientSecret != "" { | ||||||
| 		log.Info().Msg("Initializing Tailscale OAuth") | 		log.Info().Msg("Initializing Tailscale OAuth") | ||||||
|  |  | ||||||
|  | 		// Create a new oauth provider with the tailscale config | ||||||
| 		providers.Tailscale = oauth.NewOAuth(oauth2.Config{ | 		providers.Tailscale = oauth.NewOAuth(oauth2.Config{ | ||||||
| 			ClientID:     providers.Config.TailscaleClientId, | 			ClientID:     providers.Config.TailscaleClientId, | ||||||
| 			ClientSecret: providers.Config.TailscaleClientSecret, | 			ClientSecret: providers.Config.TailscaleClientSecret, | ||||||
| @@ -56,10 +70,16 @@ func (providers *Providers) Init() { | |||||||
| 			Scopes:       TailscaleScopes(), | 			Scopes:       TailscaleScopes(), | ||||||
| 			Endpoint:     TailscaleEndpoint, | 			Endpoint:     TailscaleEndpoint, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
|  | 		// Initialize the oauth provider | ||||||
| 		providers.Tailscale.Init() | 		providers.Tailscale.Init() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// If we have a client id and secret for generic oauth, initialize the oauth provider | ||||||
| 	if providers.Config.GenericClientId != "" && providers.Config.GenericClientSecret != "" { | 	if providers.Config.GenericClientId != "" && providers.Config.GenericClientSecret != "" { | ||||||
| 		log.Info().Msg("Initializing Generic OAuth") | 		log.Info().Msg("Initializing Generic OAuth") | ||||||
|  |  | ||||||
|  | 		// Create a new oauth provider with the generic config | ||||||
| 		providers.Generic = oauth.NewOAuth(oauth2.Config{ | 		providers.Generic = oauth.NewOAuth(oauth2.Config{ | ||||||
| 			ClientID:     providers.Config.GenericClientId, | 			ClientID:     providers.Config.GenericClientId, | ||||||
| 			ClientSecret: providers.Config.GenericClientSecret, | 			ClientSecret: providers.Config.GenericClientSecret, | ||||||
| @@ -70,11 +90,14 @@ func (providers *Providers) Init() { | |||||||
| 				TokenURL: providers.Config.GenericTokenURL, | 				TokenURL: providers.Config.GenericTokenURL, | ||||||
| 			}, | 			}, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
|  | 		// Initialize the oauth provider | ||||||
| 		providers.Generic.Init() | 		providers.Generic.Init() | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (providers *Providers) GetProvider(provider string) *oauth.OAuth { | func (providers *Providers) GetProvider(provider string) *oauth.OAuth { | ||||||
|  | 	// Return the provider based on the provider string | ||||||
| 	switch provider { | 	switch provider { | ||||||
| 	case "github": | 	case "github": | ||||||
| 		return providers.Github | 		return providers.Github | ||||||
| @@ -90,58 +113,103 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (providers *Providers) GetUser(provider string) (string, error) { | func (providers *Providers) GetUser(provider string) (string, error) { | ||||||
|  | 	// Get the email from the provider | ||||||
| 	switch provider { | 	switch provider { | ||||||
| 	case "github": | 	case "github": | ||||||
|  | 		// If the github provider is not configured, return an error | ||||||
| 		if providers.Github == nil { | 		if providers.Github == nil { | ||||||
| 			log.Debug().Msg("Github provider not configured") | 			log.Debug().Msg("Github provider not configured") | ||||||
| 			return "", nil | 			return "", nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Get the client from the github provider | ||||||
| 		client := providers.Github.GetClient() | 		client := providers.Github.GetClient() | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from github") | 		log.Debug().Msg("Got client from github") | ||||||
|  |  | ||||||
|  | 		// Get the email from the github provider | ||||||
| 		email, emailErr := GetGithubEmail(client) | 		email, emailErr := GetGithubEmail(client) | ||||||
|  |  | ||||||
|  | 		// Check if there was an error | ||||||
| 		if emailErr != nil { | 		if emailErr != nil { | ||||||
| 			return "", emailErr | 			return "", emailErr | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from github") | 		log.Debug().Msg("Got email from github") | ||||||
|  |  | ||||||
|  | 		// Return the email | ||||||
| 		return email, nil | 		return email, nil | ||||||
| 	case "google": | 	case "google": | ||||||
|  | 		// If the google provider is not configured, return an error | ||||||
| 		if providers.Google == nil { | 		if providers.Google == nil { | ||||||
| 			log.Debug().Msg("Google provider not configured") | 			log.Debug().Msg("Google provider not configured") | ||||||
| 			return "", nil | 			return "", nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Get the client from the google provider | ||||||
| 		client := providers.Google.GetClient() | 		client := providers.Google.GetClient() | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from google") | 		log.Debug().Msg("Got client from google") | ||||||
|  |  | ||||||
|  | 		// Get the email from the google provider | ||||||
| 		email, emailErr := GetGoogleEmail(client) | 		email, emailErr := GetGoogleEmail(client) | ||||||
|  |  | ||||||
|  | 		// Check if there was an error | ||||||
| 		if emailErr != nil { | 		if emailErr != nil { | ||||||
| 			return "", emailErr | 			return "", emailErr | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from google") | 		log.Debug().Msg("Got email from google") | ||||||
|  |  | ||||||
|  | 		// Return the email | ||||||
| 		return email, nil | 		return email, nil | ||||||
| 	case "tailscale": | 	case "tailscale": | ||||||
|  | 		// If the tailscale provider is not configured, return an error | ||||||
| 		if providers.Tailscale == nil { | 		if providers.Tailscale == nil { | ||||||
| 			log.Debug().Msg("Tailscale provider not configured") | 			log.Debug().Msg("Tailscale provider not configured") | ||||||
| 			return "", nil | 			return "", nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Get the client from the tailscale provider | ||||||
| 		client := providers.Tailscale.GetClient() | 		client := providers.Tailscale.GetClient() | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from tailscale") | 		log.Debug().Msg("Got client from tailscale") | ||||||
|  |  | ||||||
|  | 		// Get the email from the tailscale provider | ||||||
| 		email, emailErr := GetTailscaleEmail(client) | 		email, emailErr := GetTailscaleEmail(client) | ||||||
|  |  | ||||||
|  | 		// Check if there was an error | ||||||
| 		if emailErr != nil { | 		if emailErr != nil { | ||||||
| 			return "", emailErr | 			return "", emailErr | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from tailscale") | 		log.Debug().Msg("Got email from tailscale") | ||||||
|  |  | ||||||
|  | 		// Return the email | ||||||
| 		return email, nil | 		return email, nil | ||||||
| 	case "generic": | 	case "generic": | ||||||
|  | 		// If the generic provider is not configured, return an error | ||||||
| 		if providers.Generic == nil { | 		if providers.Generic == nil { | ||||||
| 			log.Debug().Msg("Generic provider not configured") | 			log.Debug().Msg("Generic provider not configured") | ||||||
| 			return "", nil | 			return "", nil | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Get the client from the generic provider | ||||||
| 		client := providers.Generic.GetClient() | 		client := providers.Generic.GetClient() | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from generic") | 		log.Debug().Msg("Got client from generic") | ||||||
|  |  | ||||||
|  | 		// Get the email from the generic provider | ||||||
| 		email, emailErr := GetGenericEmail(client, providers.Config.GenericUserURL) | 		email, emailErr := GetGenericEmail(client, providers.Config.GenericUserURL) | ||||||
|  |  | ||||||
|  | 		// Check if there was an error | ||||||
| 		if emailErr != nil { | 		if emailErr != nil { | ||||||
| 			return "", emailErr | 			return "", emailErr | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from generic") | 		log.Debug().Msg("Got email from generic") | ||||||
|  |  | ||||||
|  | 		// Return the email | ||||||
| 		return email, nil | 		return email, nil | ||||||
| 	default: | 	default: | ||||||
| 		return "", nil | 		return "", nil | ||||||
| @@ -149,6 +217,7 @@ func (providers *Providers) GetUser(provider string) (string, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (provider *Providers) GetConfiguredProviders() []string { | func (provider *Providers) GetConfiguredProviders() []string { | ||||||
|  | 	// Create a list of the configured providers | ||||||
| 	providers := []string{} | 	providers := []string{} | ||||||
| 	if provider.Github != nil { | 	if provider.Github != nil { | ||||||
| 		providers = append(providers, "github") | 		providers = append(providers, "github") | ||||||
|   | |||||||
| @@ -9,48 +9,60 @@ import ( | |||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // The tailscale email is the loginName | ||||||
| type TailscaleUser struct { | type TailscaleUser struct { | ||||||
| 	LoginName string `json:"loginName"` | 	LoginName string `json:"loginName"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // The response from the tailscale user info endpoint | ||||||
| type TailscaleUserInfoResponse struct { | type TailscaleUserInfoResponse struct { | ||||||
| 	Users []TailscaleUser `json:"users"` | 	Users []TailscaleUser `json:"users"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // The scopes required for the tailscale provider | ||||||
| func TailscaleScopes() []string { | func TailscaleScopes() []string { | ||||||
| 	return []string{"users:read"} | 	return []string{"users:read"} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // The tailscale endpoint | ||||||
| var TailscaleEndpoint = oauth2.Endpoint{ | var TailscaleEndpoint = oauth2.Endpoint{ | ||||||
| 	TokenURL: "https://api.tailscale.com/api/v2/oauth/token", | 	TokenURL: "https://api.tailscale.com/api/v2/oauth/token", | ||||||
| } | } | ||||||
|  |  | ||||||
| func GetTailscaleEmail(client *http.Client) (string, error) { | func GetTailscaleEmail(client *http.Client) (string, error) { | ||||||
|  | 	// Get the user info from tailscale using the oauth http client | ||||||
| 	res, resErr := client.Get("https://api.tailscale.com/api/v2/tailnet/-/users") | 	res, resErr := client.Get("https://api.tailscale.com/api/v2/tailnet/-/users") | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if resErr != nil { | 	if resErr != nil { | ||||||
| 		return "", resErr | 		return "", resErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from tailscale") | 	log.Debug().Msg("Got response from tailscale") | ||||||
|  |  | ||||||
|  | 	// Read the body of the response | ||||||
| 	body, bodyErr := io.ReadAll(res.Body) | 	body, bodyErr := io.ReadAll(res.Body) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if bodyErr != nil { | 	if bodyErr != nil { | ||||||
| 		return "", bodyErr | 		return "", bodyErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from tailscale") | 	log.Debug().Msg("Read body from tailscale") | ||||||
|  |  | ||||||
|  | 	// Parse the body into a user struct | ||||||
| 	var users TailscaleUserInfoResponse | 	var users TailscaleUserInfoResponse | ||||||
|  |  | ||||||
|  | 	// Unmarshal the body into the user struct | ||||||
| 	jsonErr := json.Unmarshal(body, &users) | 	jsonErr := json.Unmarshal(body, &users) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if jsonErr != nil { | 	if jsonErr != nil { | ||||||
| 		return "", jsonErr | 		return "", jsonErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed users from tailscale") | 	log.Debug().Msg("Parsed users from tailscale") | ||||||
|  |  | ||||||
|  | 	// Return the email of the first user | ||||||
| 	return users.Users[0].LoginName, nil | 	return users.Users[0].LoginName, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -2,22 +2,27 @@ package types | |||||||
|  |  | ||||||
| import "tinyauth/internal/oauth" | import "tinyauth/internal/oauth" | ||||||
|  |  | ||||||
|  | // LoginQuery is the query parameters for the login endpoint | ||||||
| type LoginQuery struct { | type LoginQuery struct { | ||||||
| 	RedirectURI string `url:"redirect_uri"` | 	RedirectURI string `url:"redirect_uri"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // LoginRequest is the request body for the login endpoint | ||||||
| type LoginRequest struct { | type LoginRequest struct { | ||||||
| 	Username string `json:"username"` | 	Username string `json:"username"` | ||||||
| 	Password string `json:"password"` | 	Password string `json:"password"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // User is the struct for a user | ||||||
| type User struct { | type User struct { | ||||||
| 	Username string | 	Username string | ||||||
| 	Password string | 	Password string | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Users is a list of users | ||||||
| type Users []User | type Users []User | ||||||
|  |  | ||||||
|  | // Config is the configuration for the tinyauth server | ||||||
| type Config struct { | type Config struct { | ||||||
| 	Port                      int    `mapstructure:"port" validate:"required"` | 	Port                      int    `mapstructure:"port" validate:"required"` | ||||||
| 	Address                   string `validate:"required,ip4_addr" mapstructure:"address"` | 	Address                   string `validate:"required,ip4_addr" mapstructure:"address"` | ||||||
| @@ -45,10 +50,11 @@ type Config struct { | |||||||
| 	GenericUserURL            string `mapstructure:"generic-user-url"` | 	GenericUserURL            string `mapstructure:"generic-user-url"` | ||||||
| 	DisableContinue           bool   `mapstructure:"disable-continue"` | 	DisableContinue           bool   `mapstructure:"disable-continue"` | ||||||
| 	OAuthWhitelist            string `mapstructure:"oauth-whitelist"` | 	OAuthWhitelist            string `mapstructure:"oauth-whitelist"` | ||||||
| 	CookieExpiry              int    `mapstructure:"cookie-expiry"` | 	SessionExpiry             int    `mapstructure:"session-expiry"` | ||||||
| 	LogLevel                  int8   `mapstructure:"log-level" validate:"min=-1,max=5"` | 	LogLevel                  int8   `mapstructure:"log-level" validate:"min=-1,max=5"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // UserContext is the context for the user | ||||||
| type UserContext struct { | type UserContext struct { | ||||||
| 	Username   string | 	Username   string | ||||||
| 	IsLoggedIn bool | 	IsLoggedIn bool | ||||||
| @@ -56,6 +62,7 @@ type UserContext struct { | |||||||
| 	Provider   string | 	Provider   string | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // APIConfig is the configuration for the API | ||||||
| type APIConfig struct { | type APIConfig struct { | ||||||
| 	Port            int | 	Port            int | ||||||
| 	Address         string | 	Address         string | ||||||
| @@ -66,6 +73,7 @@ type APIConfig struct { | |||||||
| 	DisableContinue bool | 	DisableContinue bool | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // OAuthConfig is the configuration for the providers | ||||||
| type OAuthConfig struct { | type OAuthConfig struct { | ||||||
| 	GithubClientId        string | 	GithubClientId        string | ||||||
| 	GithubClientSecret    string | 	GithubClientSecret    string | ||||||
| @@ -82,31 +90,42 @@ type OAuthConfig struct { | |||||||
| 	AppURL                string | 	AppURL                string | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // OAuthRequest is the request for the OAuth endpoint | ||||||
| type OAuthRequest struct { | type OAuthRequest struct { | ||||||
| 	Provider string `uri:"provider" binding:"required"` | 	Provider string `uri:"provider" binding:"required"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // OAuthProviders is the struct for the OAuth providers | ||||||
| type OAuthProviders struct { | type OAuthProviders struct { | ||||||
| 	Github    *oauth.OAuth | 	Github    *oauth.OAuth | ||||||
| 	Google    *oauth.OAuth | 	Google    *oauth.OAuth | ||||||
| 	Microsoft *oauth.OAuth | 	Microsoft *oauth.OAuth | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // UnauthorizedQuery is the query parameters for the unauthorized endpoint | ||||||
| type UnauthorizedQuery struct { | type UnauthorizedQuery struct { | ||||||
| 	Username string `url:"username"` | 	Username string `url:"username"` | ||||||
| 	Resource string `url:"resource"` | 	Resource string `url:"resource"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // SessionCookie is the cookie for the session (exculding the expiry) | ||||||
| type SessionCookie struct { | type SessionCookie struct { | ||||||
| 	Username string | 	Username string | ||||||
| 	Provider string | 	Provider string | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // TinyauthLabels is the labels for the tinyauth container | ||||||
| type TinyauthLabels struct { | type TinyauthLabels struct { | ||||||
| 	OAuthWhitelist []string | 	OAuthWhitelist []string | ||||||
| 	Users          []string | 	Users          []string | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // TailscaleQuery is the query parameters for the tailscale endpoint | ||||||
| type TailscaleQuery struct { | type TailscaleQuery struct { | ||||||
| 	Code int `url:"code"` | 	Code int `url:"code"` | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Proxy is the uri parameters for the proxy endpoint | ||||||
|  | type Proxy struct { | ||||||
|  | 	Proxy string `uri:"proxy" binding:"required"` | ||||||
|  | } | ||||||
|   | |||||||
| @@ -12,20 +12,32 @@ import ( | |||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // Parses a list of comma separated users in a struct | ||||||
| func ParseUsers(users string) (types.Users, error) { | func ParseUsers(users string) (types.Users, error) { | ||||||
| 	log.Debug().Msg("Parsing users") | 	log.Debug().Msg("Parsing users") | ||||||
|  |  | ||||||
|  | 	// Create a new users struct | ||||||
| 	var usersParsed types.Users | 	var usersParsed types.Users | ||||||
|  |  | ||||||
|  | 	// Split the users by comma | ||||||
| 	userList := strings.Split(users, ",") | 	userList := strings.Split(users, ",") | ||||||
|  |  | ||||||
|  | 	// Check if there are any users | ||||||
| 	if len(userList) == 0 { | 	if len(userList) == 0 { | ||||||
| 		return types.Users{}, errors.New("invalid user format") | 		return types.Users{}, errors.New("invalid user format") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Loop through the users and split them by colon | ||||||
| 	for _, user := range userList { | 	for _, user := range userList { | ||||||
|  | 		// Split the user by colon | ||||||
| 		userSplit := strings.Split(user, ":") | 		userSplit := strings.Split(user, ":") | ||||||
|  |  | ||||||
|  | 		// Check if the user is in the correct format | ||||||
| 		if len(userSplit) != 2 { | 		if len(userSplit) != 2 { | ||||||
| 			return types.Users{}, errors.New("invalid user format") | 			return types.Users{}, errors.New("invalid user format") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Append the user to the users struct | ||||||
| 		usersParsed = append(usersParsed, types.User{ | 		usersParsed = append(usersParsed, types.User{ | ||||||
| 			Username: userSplit[0], | 			Username: userSplit[0], | ||||||
| 			Password: userSplit[1], | 			Password: userSplit[1], | ||||||
| @@ -34,43 +46,61 @@ func ParseUsers(users string) (types.Users, error) { | |||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed users") | 	log.Debug().Msg("Parsed users") | ||||||
|  |  | ||||||
|  | 	// Return the users struct | ||||||
| 	return usersParsed, nil | 	return usersParsed, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Root url parses parses a hostname and returns the root domain (e.g. sub1.sub2.domain.com -> domain.com) | ||||||
| func GetRootURL(urlSrc string) (string, error) { | func GetRootURL(urlSrc string) (string, error) { | ||||||
|  | 	// Make sure the url is valid | ||||||
| 	urlParsed, parseErr := url.Parse(urlSrc) | 	urlParsed, parseErr := url.Parse(urlSrc) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if parseErr != nil { | 	if parseErr != nil { | ||||||
| 		return "", parseErr | 		return "", parseErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Split the hostname by period | ||||||
| 	urlSplitted := strings.Split(urlParsed.Hostname(), ".") | 	urlSplitted := strings.Split(urlParsed.Hostname(), ".") | ||||||
|  |  | ||||||
|  | 	// Get the last part of the url | ||||||
| 	urlFinal := strings.Join(urlSplitted[1:], ".") | 	urlFinal := strings.Join(urlSplitted[1:], ".") | ||||||
|  |  | ||||||
|  | 	// Return the root domain | ||||||
| 	return urlFinal, nil | 	return urlFinal, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Reads a file and returns the contents | ||||||
| func ReadFile(file string) (string, error) { | func ReadFile(file string) (string, error) { | ||||||
|  | 	// Check if the file exists | ||||||
| 	_, statErr := os.Stat(file) | 	_, statErr := os.Stat(file) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if statErr != nil { | 	if statErr != nil { | ||||||
| 		return "", statErr | 		return "", statErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Read the file | ||||||
| 	data, readErr := os.ReadFile(file) | 	data, readErr := os.ReadFile(file) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if readErr != nil { | 	if readErr != nil { | ||||||
| 		return "", readErr | 		return "", readErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Return the file contents | ||||||
| 	return string(data), nil | 	return string(data), nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Parses a file into a comma separated list of users | ||||||
| func ParseFileToLine(content string) string { | func ParseFileToLine(content string) string { | ||||||
|  | 	// Split the content by newline | ||||||
| 	lines := strings.Split(content, "\n") | 	lines := strings.Split(content, "\n") | ||||||
|  |  | ||||||
|  | 	// Create a list of users | ||||||
| 	users := make([]string, 0) | 	users := make([]string, 0) | ||||||
|  |  | ||||||
|  | 	// Loop through the lines, trimming the whitespace and appending to the users list | ||||||
| 	for _, line := range lines { | 	for _, line := range lines { | ||||||
| 		if strings.TrimSpace(line) == "" { | 		if strings.TrimSpace(line) == "" { | ||||||
| 			continue | 			continue | ||||||
| @@ -79,63 +109,92 @@ func ParseFileToLine(content string) string { | |||||||
| 		users = append(users, strings.TrimSpace(line)) | 		users = append(users, strings.TrimSpace(line)) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Return the users as a comma separated string | ||||||
| 	return strings.Join(users, ",") | 	return strings.Join(users, ",") | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Get the secret from the config or file | ||||||
| func GetSecret(conf string, file string) string { | func GetSecret(conf string, file string) string { | ||||||
|  | 	// If neither the config or file is set, return an empty string | ||||||
| 	if conf == "" && file == "" { | 	if conf == "" && file == "" { | ||||||
| 		return "" | 		return "" | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// If the config is set, return the config (environment variable) | ||||||
| 	if conf != "" { | 	if conf != "" { | ||||||
| 		return conf | 		return conf | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// If the file is set, read the file | ||||||
| 	contents, err := ReadFile(file) | 	contents, err := ReadFile(file) | ||||||
|  |  | ||||||
|  | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "" | 		return "" | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Return the contents of the file | ||||||
| 	return contents | 	return contents | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Get the users from the config or file | ||||||
| func GetUsers(conf string, file string) (types.Users, error) { | func GetUsers(conf string, file string) (types.Users, error) { | ||||||
|  | 	// Create a string to store the users | ||||||
| 	var users string | 	var users string | ||||||
|  |  | ||||||
|  | 	// If neither the config or file is set, return an empty users struct | ||||||
| 	if conf == "" && file == "" { | 	if conf == "" && file == "" { | ||||||
| 		return types.Users{}, errors.New("no users provided") | 		return types.Users{}, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// If the config (environment) is set, append the users to the users string | ||||||
| 	if conf != "" { | 	if conf != "" { | ||||||
| 		log.Debug().Msg("Using users from config") | 		log.Debug().Msg("Using users from config") | ||||||
| 		users += conf | 		users += conf | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// If the file is set, read the file and append the users to the users string | ||||||
| 	if file != "" { | 	if file != "" { | ||||||
|  | 		// Read the file | ||||||
| 		fileContents, fileErr := ReadFile(file) | 		fileContents, fileErr := ReadFile(file) | ||||||
|  |  | ||||||
|  | 		// If there isn't an error we can append the users to the users string | ||||||
| 		if fileErr == nil { | 		if fileErr == nil { | ||||||
| 			log.Debug().Msg("Using users from file") | 			log.Debug().Msg("Using users from file") | ||||||
|  |  | ||||||
|  | 			// Append the users to the users string | ||||||
| 			if users != "" { | 			if users != "" { | ||||||
| 				users += "," | 				users += "," | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			// Parse the file contents into a comma separated list of users | ||||||
| 			users += ParseFileToLine(fileContents) | 			users += ParseFileToLine(fileContents) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Return the parsed users | ||||||
| 	return ParseUsers(users) | 	return ParseUsers(users) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Check if any of the OAuth providers are configured based on the client id and secret | ||||||
| func OAuthConfigured(config types.Config) bool { | func OAuthConfigured(config types.Config) bool { | ||||||
| 	return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "") | 	return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "") || (config.TailscaleClientId != "" && config.TailscaleClientSecret != "") | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Parse the docker labels to the tinyauth labels struct | ||||||
| func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels { | func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels { | ||||||
|  | 	// Create a new tinyauth labels struct | ||||||
| 	var tinyauthLabels types.TinyauthLabels | 	var tinyauthLabels types.TinyauthLabels | ||||||
|  |  | ||||||
|  | 	// Loop through the labels | ||||||
| 	for label, value := range labels { | 	for label, value := range labels { | ||||||
|  |  | ||||||
|  | 		// Check if the label is in the tinyauth labels | ||||||
| 		if slices.Contains(constants.TinyauthLabels, label) { | 		if slices.Contains(constants.TinyauthLabels, label) { | ||||||
|  |  | ||||||
| 			log.Debug().Str("label", label).Msg("Found label") | 			log.Debug().Str("label", label).Msg("Found label") | ||||||
|  |  | ||||||
|  | 			// Add the label value to the tinyauth labels struct | ||||||
| 			switch label { | 			switch label { | ||||||
| 			case "tinyauth.oauth.whitelist": | 			case "tinyauth.oauth.whitelist": | ||||||
| 				tinyauthLabels.OAuthWhitelist = strings.Split(value, ",") | 				tinyauthLabels.OAuthWhitelist = strings.Split(value, ",") | ||||||
| @@ -144,5 +203,7 @@ func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Return the tinyauth labels | ||||||
| 	return tinyauthLabels | 	return tinyauthLabels | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user