mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-11-03 07:35:44 +00:00 
			
		
		
		
	Compare commits
	
		
			13 Commits
		
	
	
		
			v2.1.0-alp
			...
			chore/comm
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					1b145fd531 | ||
| 
						 | 
					7a3a463489 | ||
| 
						 | 
					e09f241364 | ||
| 
						 | 
					d2ee382f92 | ||
| 
						 | 
					4e8a2443a6 | ||
| 
						 | 
					22777a16a1 | ||
| 
						 | 
					0872556c1a | ||
| 
						 | 
					daad2abc33 | ||
| 
						 | 
					ce567ae3de | ||
| 
						 | 
					87393d3c64 | ||
| 
						 | 
					97830a309b | ||
| 
						 | 
					fe594d2755 | ||
| 
						 | 
					b3aac26644 | 
							
								
								
									
										8
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -5,7 +5,7 @@ internal/assets/dist
 | 
				
			|||||||
tinyauth
 | 
					tinyauth
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# test docker compose
 | 
					# test docker compose
 | 
				
			||||||
docker-compose.test.yml
 | 
					docker-compose.test*
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# users file
 | 
					# users file
 | 
				
			||||||
users.txt
 | 
					users.txt
 | 
				
			||||||
@@ -13,3 +13,9 @@ users.txt
 | 
				
			|||||||
# secret test file
 | 
					# secret test file
 | 
				
			||||||
secret.txt
 | 
					secret.txt
 | 
				
			||||||
secret_oauth.txt
 | 
					secret_oauth.txt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# vscode
 | 
				
			||||||
 | 
					.vscode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# apple stuff
 | 
				
			||||||
 | 
					.DS_Store
 | 
				
			||||||
@@ -22,9 +22,13 @@ Tinyauth is a simple authentication middleware that adds simple username/passwor
 | 
				
			|||||||
> [!NOTE]
 | 
					> [!NOTE]
 | 
				
			||||||
> Tinyauth is intended for homelab use and it is not made for production use cases. If you are looking for something production ready please use [authentik](https://goauthentik.io).
 | 
					> Tinyauth is intended for homelab use and it is not made for production use cases. If you are looking for something production ready please use [authentik](https://goauthentik.io).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Discord
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					I just made a Discord server for Tinyauth! It is not only for Tinyauth but general self-hosting because I just like chatting with people! The link is [here](https://discord.gg/gWpzrksk), see you there!
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Getting Started
 | 
					## Getting Started
 | 
				
			||||||
 | 
					
 | 
				
			||||||
You can easily get started with tinyauth by following the guide on the documentation [here](https://tinyauth.doesmycode.work/docs/getting-started.html). There is also an available docker compose file [here](./docker-compose.example.yml) that has traefik, nginx and tinyauth to demonstrate its capabilities.
 | 
					You can easily get started with tinyauth by following the guide on the [documentation](https://tinyauth.doesmycode.work/docs/getting-started.html). There is also an available [docker compose file](./docker-compose.example.yml) that has traefik, nginx and tinyauth to demonstrate its capabilities.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Documentation
 | 
					## Documentation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										22
									
								
								assets/discohook.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								assets/discohook.json
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
				
			|||||||
 | 
					{
 | 
				
			||||||
 | 
					  "content": null,
 | 
				
			||||||
 | 
					  "embeds": [
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      "title": "Welcome to Tinyauth Discord!",
 | 
				
			||||||
 | 
					      "description": "Tinyauth is a simple authentication middleware that adds simple username/password login or OAuth with Google, Github and any generic OAuth provider to all of your docker apps.\n\n**Information**\n\n• Github: <https://github.com/steveiliop56/tinyauth>\n• Website: <https://tinyauth.doesmycode.work>",
 | 
				
			||||||
 | 
					      "url": "https://tinyauth.doesmycode.work",
 | 
				
			||||||
 | 
					      "color": 7002085,
 | 
				
			||||||
 | 
					      "author": {
 | 
				
			||||||
 | 
					        "name": "Tinyauth"
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
 | 
					      "footer": {
 | 
				
			||||||
 | 
					        "text": "Updated at"
 | 
				
			||||||
 | 
					      },
 | 
				
			||||||
 | 
					      "timestamp": "2025-02-06T22:00:00.000Z",
 | 
				
			||||||
 | 
					      "thumbnail": {
 | 
				
			||||||
 | 
					        "url": "https://github.com/steveiliop56/tinyauth/blob/main/site/public/logo.png?raw=true"
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  ],
 | 
				
			||||||
 | 
					  "attachments": []
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										29
									
								
								cmd/root.go
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								cmd/root.go
									
									
									
									
									
								
							@@ -1,6 +1,7 @@
 | 
				
			|||||||
package cmd
 | 
					package cmd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@@ -54,8 +55,10 @@ var rootCmd = &cobra.Command{
 | 
				
			|||||||
		log.Info().Msg("Parsing users")
 | 
							log.Info().Msg("Parsing users")
 | 
				
			||||||
		users, usersErr := utils.GetUsers(config.Users, config.UsersFile)
 | 
							users, usersErr := utils.GetUsers(config.Users, config.UsersFile)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if (len(users) == 0 || usersErr != nil) && !utils.OAuthConfigured(config) {
 | 
							HandleError(usersErr, "Failed to parse users")
 | 
				
			||||||
			log.Fatal().Err(usersErr).Msg("Failed to parse users")
 | 
					
 | 
				
			||||||
 | 
							if len(users) == 0 && !utils.OAuthConfigured(config) {
 | 
				
			||||||
 | 
								HandleError(errors.New("no users or OAuth configured"), "No users or OAuth configured")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Create oauth whitelist
 | 
							// Create oauth whitelist
 | 
				
			||||||
@@ -89,7 +92,7 @@ var rootCmd = &cobra.Command{
 | 
				
			|||||||
		HandleError(dockerErr, "Failed to initialize docker")
 | 
							HandleError(dockerErr, "Failed to initialize docker")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Create auth service
 | 
							// Create auth service
 | 
				
			||||||
		auth := auth.NewAuth(docker, users, oauthWhitelist)
 | 
							auth := auth.NewAuth(docker, users, oauthWhitelist, config.SessionExpiry)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Create OAuth providers service
 | 
							// Create OAuth providers service
 | 
				
			||||||
		providers := providers.NewProviders(oauthConfig)
 | 
							providers := providers.NewProviders(oauthConfig)
 | 
				
			||||||
@@ -108,7 +111,7 @@ var rootCmd = &cobra.Command{
 | 
				
			|||||||
			AppURL:          config.AppURL,
 | 
								AppURL:          config.AppURL,
 | 
				
			||||||
			CookieSecure:    config.CookieSecure,
 | 
								CookieSecure:    config.CookieSecure,
 | 
				
			||||||
			DisableContinue: config.DisableContinue,
 | 
								DisableContinue: config.DisableContinue,
 | 
				
			||||||
			CookieExpiry:    config.CookieExpiry,
 | 
								CookieExpiry:    config.SessionExpiry,
 | 
				
			||||||
		}, hooks, auth, providers)
 | 
							}, hooks, auth, providers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Setup routes
 | 
							// Setup routes
 | 
				
			||||||
@@ -122,20 +125,24 @@ var rootCmd = &cobra.Command{
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func Execute() {
 | 
					func Execute() {
 | 
				
			||||||
	err := rootCmd.Execute()
 | 
						err := rootCmd.Execute()
 | 
				
			||||||
	if err != nil {
 | 
						HandleError(err, "Failed to execute root command")
 | 
				
			||||||
		log.Fatal().Err(err).Msg("Failed to execute command")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func HandleError(err error, msg string) {
 | 
					func HandleError(err error, msg string) {
 | 
				
			||||||
 | 
						// If error log it and exit
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Fatal().Err(err).Msg(msg)
 | 
							log.Fatal().Err(err).Msg(msg)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
 | 
						// Add user command
 | 
				
			||||||
	rootCmd.AddCommand(cmd.UserCmd())
 | 
						rootCmd.AddCommand(cmd.UserCmd())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Read environment variables
 | 
				
			||||||
	viper.AutomaticEnv()
 | 
						viper.AutomaticEnv()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Flags
 | 
				
			||||||
	rootCmd.Flags().Int("port", 3000, "Port to run the server on.")
 | 
						rootCmd.Flags().Int("port", 3000, "Port to run the server on.")
 | 
				
			||||||
	rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.")
 | 
						rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.")
 | 
				
			||||||
	rootCmd.Flags().String("secret", "", "Secret to use for the cookie.")
 | 
						rootCmd.Flags().String("secret", "", "Secret to use for the cookie.")
 | 
				
			||||||
@@ -162,8 +169,10 @@ func init() {
 | 
				
			|||||||
	rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.")
 | 
						rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.")
 | 
				
			||||||
	rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.")
 | 
						rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.")
 | 
				
			||||||
	rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.")
 | 
						rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.")
 | 
				
			||||||
	rootCmd.Flags().Int("cookie-expiry", 86400, "Cookie expiration time in seconds.")
 | 
						rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.")
 | 
				
			||||||
	rootCmd.Flags().Int("log-level", 1, "Log level.")
 | 
						rootCmd.Flags().Int("log-level", 1, "Log level.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Bind flags to environment
 | 
				
			||||||
	viper.BindEnv("port", "PORT")
 | 
						viper.BindEnv("port", "PORT")
 | 
				
			||||||
	viper.BindEnv("address", "ADDRESS")
 | 
						viper.BindEnv("address", "ADDRESS")
 | 
				
			||||||
	viper.BindEnv("secret", "SECRET")
 | 
						viper.BindEnv("secret", "SECRET")
 | 
				
			||||||
@@ -190,7 +199,9 @@ func init() {
 | 
				
			|||||||
	viper.BindEnv("generic-user-url", "GENERIC_USER_URL")
 | 
						viper.BindEnv("generic-user-url", "GENERIC_USER_URL")
 | 
				
			||||||
	viper.BindEnv("disable-continue", "DISABLE_CONTINUE")
 | 
						viper.BindEnv("disable-continue", "DISABLE_CONTINUE")
 | 
				
			||||||
	viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST")
 | 
						viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST")
 | 
				
			||||||
	viper.BindEnv("cookie-expiry", "COOKIE_EXPIRY")
 | 
						viper.BindEnv("session-expiry", "SESSION_EXPIRY")
 | 
				
			||||||
	viper.BindEnv("log-level", "LOG_LEVEL")
 | 
						viper.BindEnv("log-level", "LOG_LEVEL")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Bind flags to viper
 | 
				
			||||||
	viper.BindPFlags(rootCmd.Flags())
 | 
						viper.BindPFlags(rootCmd.Flags())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -22,9 +22,12 @@ var CreateCmd = &cobra.Command{
 | 
				
			|||||||
	Short: "Create a user",
 | 
						Short: "Create a user",
 | 
				
			||||||
	Long:  `Create a user either interactively or by passing flags.`,
 | 
						Long:  `Create a user either interactively or by passing flags.`,
 | 
				
			||||||
	Run: func(cmd *cobra.Command, args []string) {
 | 
						Run: func(cmd *cobra.Command, args []string) {
 | 
				
			||||||
 | 
							// Setup logger
 | 
				
			||||||
		log.Logger = log.Level(zerolog.InfoLevel)
 | 
							log.Logger = log.Level(zerolog.InfoLevel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if interactive
 | 
				
			||||||
		if interactive {
 | 
							if interactive {
 | 
				
			||||||
 | 
								// Create huh form
 | 
				
			||||||
			form := huh.NewForm(
 | 
								form := huh.NewForm(
 | 
				
			||||||
				huh.NewGroup(
 | 
									huh.NewGroup(
 | 
				
			||||||
					huh.NewInput().Title("Username").Value(&username).Validate((func(s string) error {
 | 
										huh.NewInput().Title("Username").Value(&username).Validate((func(s string) error {
 | 
				
			||||||
@@ -43,6 +46,7 @@ var CreateCmd = &cobra.Command{
 | 
				
			|||||||
				),
 | 
									),
 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Use simple theme
 | 
				
			||||||
			var baseTheme *huh.Theme = huh.ThemeBase()
 | 
								var baseTheme *huh.Theme = huh.ThemeBase()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			formErr := form.WithTheme(baseTheme).Run()
 | 
								formErr := form.WithTheme(baseTheme).Run()
 | 
				
			||||||
@@ -52,12 +56,14 @@ var CreateCmd = &cobra.Command{
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Do we have username and password?
 | 
				
			||||||
		if username == "" || password == "" {
 | 
							if username == "" || password == "" {
 | 
				
			||||||
			log.Error().Msg("Username and password cannot be empty")
 | 
								log.Error().Msg("Username and password cannot be empty")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Info().Str("username", username).Str("password", password).Bool("docker", docker).Msg("Creating user")
 | 
							log.Info().Str("username", username).Str("password", password).Bool("docker", docker).Msg("Creating user")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Hash password
 | 
				
			||||||
		passwordByte, passwordErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
 | 
							passwordByte, passwordErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if passwordErr != nil {
 | 
							if passwordErr != nil {
 | 
				
			||||||
@@ -66,15 +72,18 @@ var CreateCmd = &cobra.Command{
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		passwordString := string(passwordByte)
 | 
							passwordString := string(passwordByte)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Escape $ for docker
 | 
				
			||||||
		if docker {
 | 
							if docker {
 | 
				
			||||||
			passwordString = strings.ReplaceAll(passwordString, "$", "$$")
 | 
								passwordString = strings.ReplaceAll(passwordString, "$", "$$")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Log user created
 | 
				
			||||||
		log.Info().Str("user", fmt.Sprintf("%s:%s", username, passwordString)).Msg("User created")
 | 
							log.Info().Str("user", fmt.Sprintf("%s:%s", username, passwordString)).Msg("User created")
 | 
				
			||||||
	},
 | 
						},
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
 | 
						// Flags
 | 
				
			||||||
	CreateCmd.Flags().BoolVar(&interactive, "interactive", false, "Create a user interactively")
 | 
						CreateCmd.Flags().BoolVar(&interactive, "interactive", false, "Create a user interactively")
 | 
				
			||||||
	CreateCmd.Flags().BoolVar(&docker, "docker", false, "Format output for docker")
 | 
						CreateCmd.Flags().BoolVar(&docker, "docker", false, "Format output for docker")
 | 
				
			||||||
	CreateCmd.Flags().StringVar(&username, "username", "", "Username")
 | 
						CreateCmd.Flags().StringVar(&username, "username", "", "Username")
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -22,9 +22,12 @@ var VerifyCmd = &cobra.Command{
 | 
				
			|||||||
	Short: "Verify a user is set up correctly",
 | 
						Short: "Verify a user is set up correctly",
 | 
				
			||||||
	Long:  `Verify a user is set up correctly meaning that it has a correct username and password.`,
 | 
						Long:  `Verify a user is set up correctly meaning that it has a correct username and password.`,
 | 
				
			||||||
	Run: func(cmd *cobra.Command, args []string) {
 | 
						Run: func(cmd *cobra.Command, args []string) {
 | 
				
			||||||
 | 
							// Setup logger
 | 
				
			||||||
		log.Logger = log.Level(zerolog.InfoLevel)
 | 
							log.Logger = log.Level(zerolog.InfoLevel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if interactive
 | 
				
			||||||
		if interactive {
 | 
							if interactive {
 | 
				
			||||||
 | 
								// Create huh form
 | 
				
			||||||
			form := huh.NewForm(
 | 
								form := huh.NewForm(
 | 
				
			||||||
				huh.NewGroup(
 | 
									huh.NewGroup(
 | 
				
			||||||
					huh.NewInput().Title("User (username:hash)").Value(&user).Validate((func(s string) error {
 | 
										huh.NewInput().Title("User (username:hash)").Value(&user).Validate((func(s string) error {
 | 
				
			||||||
@@ -49,6 +52,7 @@ var VerifyCmd = &cobra.Command{
 | 
				
			|||||||
				),
 | 
									),
 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Use simple theme
 | 
				
			||||||
			var baseTheme *huh.Theme = huh.ThemeBase()
 | 
								var baseTheme *huh.Theme = huh.ThemeBase()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			formErr := form.WithTheme(baseTheme).Run()
 | 
								formErr := form.WithTheme(baseTheme).Run()
 | 
				
			||||||
@@ -58,22 +62,26 @@ var VerifyCmd = &cobra.Command{
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Do we have username, password and user?
 | 
				
			||||||
		if username == "" || password == "" || user == "" {
 | 
							if username == "" || password == "" || user == "" {
 | 
				
			||||||
			log.Fatal().Msg("Username, password and user cannot be empty")
 | 
								log.Fatal().Msg("Username, password and user cannot be empty")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Info().Str("user", user).Str("username", username).Str("password", password).Bool("docker", docker).Msg("Verifying user")
 | 
							log.Info().Str("user", user).Str("username", username).Str("password", password).Bool("docker", docker).Msg("Verifying user")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Split username and password
 | 
				
			||||||
		userSplit := strings.Split(user, ":")
 | 
							userSplit := strings.Split(user, ":")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if userSplit[1] == "" {
 | 
							if userSplit[1] == "" {
 | 
				
			||||||
			log.Fatal().Msg("User is not formatted correctly")
 | 
								log.Fatal().Msg("User is not formatted correctly")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Replace $$ with $ if formatted for docker
 | 
				
			||||||
		if docker {
 | 
							if docker {
 | 
				
			||||||
			userSplit[1] = strings.ReplaceAll(userSplit[1], "$$", "$")
 | 
								userSplit[1] = strings.ReplaceAll(userSplit[1], "$$", "$")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Compare username and password
 | 
				
			||||||
		verifyErr := bcrypt.CompareHashAndPassword([]byte(userSplit[1]), []byte(password))
 | 
							verifyErr := bcrypt.CompareHashAndPassword([]byte(userSplit[1]), []byte(password))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if verifyErr != nil || username != userSplit[0] {
 | 
							if verifyErr != nil || username != userSplit[0] {
 | 
				
			||||||
@@ -85,6 +93,7 @@ var VerifyCmd = &cobra.Command{
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
 | 
						// Flags
 | 
				
			||||||
	VerifyCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively")
 | 
						VerifyCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively")
 | 
				
			||||||
	VerifyCmd.Flags().BoolVar(&docker, "docker", false, "Is the user formatted for docker?")
 | 
						VerifyCmd.Flags().BoolVar(&docker, "docker", false, "Is the user formatted for docker?")
 | 
				
			||||||
	VerifyCmd.Flags().StringVar(&username, "username", "", "Username")
 | 
						VerifyCmd.Flags().StringVar(&username, "username", "", "Username")
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -41,11 +41,15 @@ type API struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (api *API) Init() {
 | 
					func (api *API) Init() {
 | 
				
			||||||
 | 
						// Disable gin logs
 | 
				
			||||||
	gin.SetMode(gin.ReleaseMode)
 | 
						gin.SetMode(gin.ReleaseMode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create router and use zerolog for logs
 | 
				
			||||||
	log.Debug().Msg("Setting up router")
 | 
						log.Debug().Msg("Setting up router")
 | 
				
			||||||
	router := gin.New()
 | 
						router := gin.New()
 | 
				
			||||||
	router.Use(zerolog())
 | 
						router.Use(zerolog())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Read UI assets
 | 
				
			||||||
	log.Debug().Msg("Setting up assets")
 | 
						log.Debug().Msg("Setting up assets")
 | 
				
			||||||
	dist, distErr := fs.Sub(assets.Assets, "dist")
 | 
						dist, distErr := fs.Sub(assets.Assets, "dist")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -53,11 +57,15 @@ func (api *API) Init() {
 | 
				
			|||||||
		log.Fatal().Err(distErr).Msg("Failed to get UI assets")
 | 
							log.Fatal().Err(distErr).Msg("Failed to get UI assets")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create file server
 | 
				
			||||||
	log.Debug().Msg("Setting up file server")
 | 
						log.Debug().Msg("Setting up file server")
 | 
				
			||||||
	fileServer := http.FileServer(http.FS(dist))
 | 
						fileServer := http.FileServer(http.FS(dist))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Setup cookie store
 | 
				
			||||||
	log.Debug().Msg("Setting up cookie store")
 | 
						log.Debug().Msg("Setting up cookie store")
 | 
				
			||||||
	store := cookie.NewStore([]byte(api.Config.Secret))
 | 
						store := cookie.NewStore([]byte(api.Config.Secret))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get domain to use for session cookies
 | 
				
			||||||
	log.Debug().Msg("Getting domain")
 | 
						log.Debug().Msg("Getting domain")
 | 
				
			||||||
	domain, domainErr := utils.GetRootURL(api.Config.AppURL)
 | 
						domain, domainErr := utils.GetRootURL(api.Config.AppURL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -70,6 +78,7 @@ func (api *API) Init() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	api.Domain = fmt.Sprintf(".%s", domain)
 | 
						api.Domain = fmt.Sprintf(".%s", domain)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Use session middleware
 | 
				
			||||||
	store.Options(sessions.Options{
 | 
						store.Options(sessions.Options{
 | 
				
			||||||
		Domain:   api.Domain,
 | 
							Domain:   api.Domain,
 | 
				
			||||||
		Path:     "/",
 | 
							Path:     "/",
 | 
				
			||||||
@@ -80,79 +89,169 @@ func (api *API) Init() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	router.Use(sessions.Sessions("tinyauth", store))
 | 
						router.Use(sessions.Sessions("tinyauth", store))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// UI middleware
 | 
				
			||||||
	router.Use(func(c *gin.Context) {
 | 
						router.Use(func(c *gin.Context) {
 | 
				
			||||||
 | 
							// If not an API request, serve the UI
 | 
				
			||||||
		if !strings.HasPrefix(c.Request.URL.Path, "/api") {
 | 
							if !strings.HasPrefix(c.Request.URL.Path, "/api") {
 | 
				
			||||||
			_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/"))
 | 
								_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// If the file doesn't exist, serve the index.html
 | 
				
			||||||
			if os.IsNotExist(err) {
 | 
								if os.IsNotExist(err) {
 | 
				
			||||||
				c.Request.URL.Path = "/"
 | 
									c.Request.URL.Path = "/"
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Serve the file
 | 
				
			||||||
			fileServer.ServeHTTP(c.Writer, c.Request)
 | 
								fileServer.ServeHTTP(c.Writer, c.Request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Stop further processing
 | 
				
			||||||
			c.Abort()
 | 
								c.Abort()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Set router
 | 
				
			||||||
	api.Router = router
 | 
						api.Router = router
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (api *API) SetupRoutes() {
 | 
					func (api *API) SetupRoutes() {
 | 
				
			||||||
	api.Router.GET("/api/auth", func(c *gin.Context) {
 | 
						api.Router.GET("/api/auth/:proxy", func(c *gin.Context) {
 | 
				
			||||||
		log.Debug().Msg("Checking auth")
 | 
							// Create struct for proxy
 | 
				
			||||||
 | 
							var proxy types.Proxy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Bind URI
 | 
				
			||||||
 | 
							bindErr := c.BindUri(&proxy)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Handle error
 | 
				
			||||||
 | 
							if api.handleError(c, "Failed to bind URI", bindErr) {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get user context
 | 
				
			||||||
		userContext := api.Hooks.UseUserContext(c)
 | 
							userContext := api.Hooks.UseUserContext(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get headers
 | 
				
			||||||
		uri := c.Request.Header.Get("X-Forwarded-Uri")
 | 
							uri := c.Request.Header.Get("X-Forwarded-Uri")
 | 
				
			||||||
		proto := c.Request.Header.Get("X-Forwarded-Proto")
 | 
							proto := c.Request.Header.Get("X-Forwarded-Proto")
 | 
				
			||||||
		host := c.Request.Header.Get("X-Forwarded-Host")
 | 
							host := c.Request.Header.Get("X-Forwarded-Host")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if user is logged in
 | 
				
			||||||
		if userContext.IsLoggedIn {
 | 
							if userContext.IsLoggedIn {
 | 
				
			||||||
			log.Debug().Msg("Authenticated")
 | 
								log.Debug().Msg("Authenticated")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx
 | 
				
			||||||
			appAllowed, appAllowedErr := api.Auth.ResourceAllowed(userContext, host)
 | 
								appAllowed, appAllowedErr := api.Auth.ResourceAllowed(userContext, host)
 | 
				
			||||||
			if handleApiError(c, "Failed to check if resource is allowed", appAllowedErr) {
 | 
					
 | 
				
			||||||
				return
 | 
								// Check if there was an error
 | 
				
			||||||
 | 
								if appAllowedErr != nil {
 | 
				
			||||||
 | 
									// Return 501 if nginx is the proxy or if the request is using an Authorization header
 | 
				
			||||||
 | 
									if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" {
 | 
				
			||||||
 | 
										log.Error().Err(appAllowedErr).Msg("Failed to check if app is allowed")
 | 
				
			||||||
 | 
										c.JSON(501, gin.H{
 | 
				
			||||||
 | 
											"status":  501,
 | 
				
			||||||
 | 
											"message": "Internal Server Error",
 | 
				
			||||||
 | 
										})
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// Return the internal server error page
 | 
				
			||||||
 | 
									if api.handleError(c, "Failed to check if app is allowed", appAllowedErr) {
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// The user is not allowed to access the app
 | 
				
			||||||
			if !appAllowed {
 | 
								if !appAllowed {
 | 
				
			||||||
				log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed")
 | 
									log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// Build query
 | 
				
			||||||
				queries, queryErr := query.Values(types.UnauthorizedQuery{
 | 
									queries, queryErr := query.Values(types.UnauthorizedQuery{
 | 
				
			||||||
					Username: userContext.Username,
 | 
										Username: userContext.Username,
 | 
				
			||||||
					Resource: strings.Split(host, ".")[0],
 | 
										Resource: strings.Split(host, ".")[0],
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
				if handleApiError(c, "Failed to build query", queryErr) {
 | 
					
 | 
				
			||||||
 | 
									// Check if there was an error
 | 
				
			||||||
 | 
									if queryErr != nil {
 | 
				
			||||||
 | 
										// Return 501 if nginx is the proxy or if the request is using an Authorization header
 | 
				
			||||||
 | 
										if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" {
 | 
				
			||||||
 | 
											log.Error().Err(queryErr).Msg("Failed to build query")
 | 
				
			||||||
 | 
											c.JSON(501, gin.H{
 | 
				
			||||||
 | 
												"status":  501,
 | 
				
			||||||
 | 
												"message": "Internal Server Error",
 | 
				
			||||||
 | 
											})
 | 
				
			||||||
 | 
											return
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										// Return the internal server error page
 | 
				
			||||||
 | 
										if api.handleError(c, "Failed to build query", queryErr) {
 | 
				
			||||||
 | 
											return
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// Return 401 if nginx is the proxy or if the request is using an Authorization header
 | 
				
			||||||
 | 
									if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" {
 | 
				
			||||||
 | 
										c.JSON(401, gin.H{
 | 
				
			||||||
 | 
											"status":  401,
 | 
				
			||||||
 | 
											"message": "Unauthorized",
 | 
				
			||||||
 | 
										})
 | 
				
			||||||
					return
 | 
										return
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// We are using caddy/traefik so redirect
 | 
				
			||||||
				c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, queries.Encode()))
 | 
									c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, queries.Encode()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// Stop further processing
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// The user is allowed to access the app
 | 
				
			||||||
			c.JSON(200, gin.H{
 | 
								c.JSON(200, gin.H{
 | 
				
			||||||
				"status":  200,
 | 
									"status":  200,
 | 
				
			||||||
				"message": "Authenticated",
 | 
									"message": "Authenticated",
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Stop further processing
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// The user is not logged in
 | 
				
			||||||
 | 
							log.Debug().Msg("Unauthorized")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Return 401 if nginx is the proxy or if the request is using an Authorization header
 | 
				
			||||||
 | 
							if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" {
 | 
				
			||||||
 | 
								c.JSON(401, gin.H{
 | 
				
			||||||
 | 
									"status":  401,
 | 
				
			||||||
 | 
									"message": "Unauthorized",
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Build query
 | 
				
			||||||
		queries, queryErr := query.Values(types.LoginQuery{
 | 
							queries, queryErr := query.Values(types.LoginQuery{
 | 
				
			||||||
			RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri),
 | 
								RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri),
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login")
 | 
							log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if queryErr != nil {
 | 
							// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik)
 | 
				
			||||||
			log.Error().Err(queryErr).Msg("Failed to build query")
 | 
							if api.handleError(c, "Failed to build query", queryErr) {
 | 
				
			||||||
			c.JSON(501, gin.H{
 | 
					 | 
				
			||||||
				"status":  501,
 | 
					 | 
				
			||||||
				"message": "Internal Server Error",
 | 
					 | 
				
			||||||
			})
 | 
					 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Redirect to login
 | 
				
			||||||
		c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode()))
 | 
							c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode()))
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	api.Router.POST("/api/login", func(c *gin.Context) {
 | 
						api.Router.POST("/api/login", func(c *gin.Context) {
 | 
				
			||||||
 | 
							// Create login struct
 | 
				
			||||||
		var login types.LoginRequest
 | 
							var login types.LoginRequest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Bind JSON
 | 
				
			||||||
		err := c.BindJSON(&login)
 | 
							err := c.BindJSON(&login)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Handle error
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Error().Err(err).Msg("Failed to bind JSON")
 | 
								log.Error().Err(err).Msg("Failed to bind JSON")
 | 
				
			||||||
			c.JSON(400, gin.H{
 | 
								c.JSON(400, gin.H{
 | 
				
			||||||
@@ -164,8 +263,10 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got login request")
 | 
							log.Debug().Msg("Got login request")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get user based on username
 | 
				
			||||||
		user := api.Auth.GetUser(login.Username)
 | 
							user := api.Auth.GetUser(login.Username)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// User does not exist
 | 
				
			||||||
		if user == nil {
 | 
							if user == nil {
 | 
				
			||||||
			log.Debug().Str("username", login.Username).Msg("User not found")
 | 
								log.Debug().Str("username", login.Username).Msg("User not found")
 | 
				
			||||||
			c.JSON(401, gin.H{
 | 
								c.JSON(401, gin.H{
 | 
				
			||||||
@@ -175,6 +276,9 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							log.Debug().Msg("Got user")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if password is correct
 | 
				
			||||||
		if !api.Auth.CheckPassword(*user, login.Password) {
 | 
							if !api.Auth.CheckPassword(*user, login.Password) {
 | 
				
			||||||
			log.Debug().Str("username", login.Username).Msg("Password incorrect")
 | 
								log.Debug().Str("username", login.Username).Msg("Password incorrect")
 | 
				
			||||||
			c.JSON(401, gin.H{
 | 
								c.JSON(401, gin.H{
 | 
				
			||||||
@@ -186,11 +290,13 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Password correct, logging in")
 | 
							log.Debug().Msg("Password correct, logging in")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Create session cookie with username as provider
 | 
				
			||||||
		api.Auth.CreateSessionCookie(c, &types.SessionCookie{
 | 
							api.Auth.CreateSessionCookie(c, &types.SessionCookie{
 | 
				
			||||||
			Username: login.Username,
 | 
								Username: login.Username,
 | 
				
			||||||
			Provider: "username",
 | 
								Provider: "username",
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Return logged in
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(200, gin.H{
 | 
				
			||||||
			"status":  200,
 | 
								"status":  200,
 | 
				
			||||||
			"message": "Logged in",
 | 
								"message": "Logged in",
 | 
				
			||||||
@@ -198,12 +304,17 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	api.Router.POST("/api/logout", func(c *gin.Context) {
 | 
						api.Router.POST("/api/logout", func(c *gin.Context) {
 | 
				
			||||||
 | 
							log.Debug().Msg("Logging out")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Delete session cookie
 | 
				
			||||||
		api.Auth.DeleteSessionCookie(c)
 | 
							api.Auth.DeleteSessionCookie(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Cleaning up redirect cookie")
 | 
							log.Debug().Msg("Cleaning up redirect cookie")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Clean up redirect cookie if it exists
 | 
				
			||||||
		c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
 | 
							c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Return logged out
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(200, gin.H{
 | 
				
			||||||
			"status":  200,
 | 
								"status":  200,
 | 
				
			||||||
			"message": "Logged out",
 | 
								"message": "Logged out",
 | 
				
			||||||
@@ -212,19 +323,24 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	api.Router.GET("/api/status", func(c *gin.Context) {
 | 
						api.Router.GET("/api/status", func(c *gin.Context) {
 | 
				
			||||||
		log.Debug().Msg("Checking status")
 | 
							log.Debug().Msg("Checking status")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get user context
 | 
				
			||||||
		userContext := api.Hooks.UseUserContext(c)
 | 
							userContext := api.Hooks.UseUserContext(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get configured providers
 | 
				
			||||||
		configuredProviders := api.Providers.GetConfiguredProviders()
 | 
							configuredProviders := api.Providers.GetConfiguredProviders()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// We have username/password configured so add it to our providers
 | 
				
			||||||
		if api.Auth.UserAuthConfigured() {
 | 
							if api.Auth.UserAuthConfigured() {
 | 
				
			||||||
			configuredProviders = append(configuredProviders, "username")
 | 
								configuredProviders = append(configuredProviders, "username")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// We are not logged in so return unauthorized
 | 
				
			||||||
		if !userContext.IsLoggedIn {
 | 
							if !userContext.IsLoggedIn {
 | 
				
			||||||
			log.Debug().Msg("Unauthenticated")
 | 
								log.Debug().Msg("Unauthorized")
 | 
				
			||||||
			c.JSON(200, gin.H{
 | 
								c.JSON(200, gin.H{
 | 
				
			||||||
				"status":              200,
 | 
									"status":              200,
 | 
				
			||||||
				"message":             "Unauthenticated",
 | 
									"message":             "Unauthorized",
 | 
				
			||||||
				"username":            "",
 | 
									"username":            "",
 | 
				
			||||||
				"isLoggedIn":          false,
 | 
									"isLoggedIn":          false,
 | 
				
			||||||
				"oauth":               false,
 | 
									"oauth":               false,
 | 
				
			||||||
@@ -237,6 +353,7 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		log.Debug().Interface("userContext", userContext).Strs("configuredProviders", configuredProviders).Bool("disableContinue", api.Config.DisableContinue).Msg("Authenticated")
 | 
							log.Debug().Interface("userContext", userContext).Strs("configuredProviders", configuredProviders).Bool("disableContinue", api.Config.DisableContinue).Msg("Authenticated")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// We are logged in so return our user context
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(200, gin.H{
 | 
				
			||||||
			"status":              200,
 | 
								"status":              200,
 | 
				
			||||||
			"message":             "Authenticated",
 | 
								"message":             "Authenticated",
 | 
				
			||||||
@@ -249,18 +366,14 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	api.Router.GET("/api/healthcheck", func(c *gin.Context) {
 | 
					 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
					 | 
				
			||||||
			"status":  200,
 | 
					 | 
				
			||||||
			"message": "OK",
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) {
 | 
						api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) {
 | 
				
			||||||
 | 
							// Create struct for OAuth request
 | 
				
			||||||
		var request types.OAuthRequest
 | 
							var request types.OAuthRequest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Bind URI
 | 
				
			||||||
		bindErr := c.BindUri(&request)
 | 
							bindErr := c.BindUri(&request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Handle error
 | 
				
			||||||
		if bindErr != nil {
 | 
							if bindErr != nil {
 | 
				
			||||||
			log.Error().Err(bindErr).Msg("Failed to bind URI")
 | 
								log.Error().Err(bindErr).Msg("Failed to bind URI")
 | 
				
			||||||
			c.JSON(400, gin.H{
 | 
								c.JSON(400, gin.H{
 | 
				
			||||||
@@ -272,8 +385,10 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got OAuth request")
 | 
							log.Debug().Msg("Got OAuth request")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if provider exists
 | 
				
			||||||
		provider := api.Providers.GetProvider(request.Provider)
 | 
							provider := api.Providers.GetProvider(request.Provider)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Provider does not exist
 | 
				
			||||||
		if provider == nil {
 | 
							if provider == nil {
 | 
				
			||||||
			c.JSON(404, gin.H{
 | 
								c.JSON(404, gin.H{
 | 
				
			||||||
				"status":  404,
 | 
									"status":  404,
 | 
				
			||||||
@@ -284,24 +399,38 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		log.Debug().Str("provider", request.Provider).Msg("Got provider")
 | 
							log.Debug().Str("provider", request.Provider).Msg("Got provider")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get auth URL
 | 
				
			||||||
		authURL := provider.GetAuthURL()
 | 
							authURL := provider.GetAuthURL()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got auth URL")
 | 
							log.Debug().Msg("Got auth URL")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get redirect URI
 | 
				
			||||||
		redirectURI := c.Query("redirect_uri")
 | 
							redirectURI := c.Query("redirect_uri")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Set redirect cookie if redirect URI is provided
 | 
				
			||||||
		if redirectURI != "" {
 | 
							if redirectURI != "" {
 | 
				
			||||||
			log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie")
 | 
								log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie")
 | 
				
			||||||
			c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true)
 | 
								c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Tailscale does not have an auth url so we create a random code (does not need to be secure) to avoid caching and send it
 | 
				
			||||||
		if request.Provider == "tailscale" {
 | 
							if request.Provider == "tailscale" {
 | 
				
			||||||
 | 
								// Build tailscale query
 | 
				
			||||||
			tailscaleQuery, tailscaleQueryErr := query.Values(types.TailscaleQuery{
 | 
								tailscaleQuery, tailscaleQueryErr := query.Values(types.TailscaleQuery{
 | 
				
			||||||
				Code: (1000 + rand.IntN(9000)), // doesn't need to be secure, just there to avoid caching
 | 
									Code: (1000 + rand.IntN(9000)),
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
			if handleApiError(c, "Failed to build query", tailscaleQueryErr) {
 | 
					
 | 
				
			||||||
 | 
								// Handle error
 | 
				
			||||||
 | 
								if tailscaleQueryErr != nil {
 | 
				
			||||||
 | 
									log.Error().Err(tailscaleQueryErr).Msg("Failed to build query")
 | 
				
			||||||
 | 
									c.JSON(500, gin.H{
 | 
				
			||||||
 | 
										"status":  500,
 | 
				
			||||||
 | 
										"message": "Internal Server Error",
 | 
				
			||||||
 | 
									})
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Return tailscale URL (immidiately redirects to the callback)
 | 
				
			||||||
			c.JSON(200, gin.H{
 | 
								c.JSON(200, gin.H{
 | 
				
			||||||
				"status":  200,
 | 
									"status":  200,
 | 
				
			||||||
				"message": "Ok",
 | 
									"message": "Ok",
 | 
				
			||||||
@@ -310,6 +439,7 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Return auth URL
 | 
				
			||||||
		c.JSON(200, gin.H{
 | 
							c.JSON(200, gin.H{
 | 
				
			||||||
			"status":  200,
 | 
								"status":  200,
 | 
				
			||||||
			"message": "Ok",
 | 
								"message": "Ok",
 | 
				
			||||||
@@ -318,18 +448,23 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) {
 | 
						api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) {
 | 
				
			||||||
 | 
							// Create struct for OAuth request
 | 
				
			||||||
		var providerName types.OAuthRequest
 | 
							var providerName types.OAuthRequest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Bind URI
 | 
				
			||||||
		bindErr := c.BindUri(&providerName)
 | 
							bindErr := c.BindUri(&providerName)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if handleApiError(c, "Failed to bind URI", bindErr) {
 | 
							// Handle error
 | 
				
			||||||
 | 
							if api.handleError(c, "Failed to bind URI", bindErr) {
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name")
 | 
							log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get code
 | 
				
			||||||
		code := c.Query("code")
 | 
							code := c.Query("code")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Code empty so redirect to error
 | 
				
			||||||
		if code == "" {
 | 
							if code == "" {
 | 
				
			||||||
			log.Error().Msg("No code provided")
 | 
								log.Error().Msg("No code provided")
 | 
				
			||||||
			c.Redirect(http.StatusPermanentRedirect, "/error")
 | 
								c.Redirect(http.StatusPermanentRedirect, "/error")
 | 
				
			||||||
@@ -338,51 +473,67 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got code")
 | 
							log.Debug().Msg("Got code")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get provider
 | 
				
			||||||
		provider := api.Providers.GetProvider(providerName.Provider)
 | 
							provider := api.Providers.GetProvider(providerName.Provider)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Str("provider", providerName.Provider).Msg("Got provider")
 | 
							log.Debug().Str("provider", providerName.Provider).Msg("Got provider")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Provider does not exist
 | 
				
			||||||
		if provider == nil {
 | 
							if provider == nil {
 | 
				
			||||||
			c.Redirect(http.StatusPermanentRedirect, "/not-found")
 | 
								c.Redirect(http.StatusPermanentRedirect, "/not-found")
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Exchange token (authenticates user)
 | 
				
			||||||
		_, tokenErr := provider.ExchangeToken(code)
 | 
							_, tokenErr := provider.ExchangeToken(code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got token")
 | 
							log.Debug().Msg("Got token")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if handleApiError(c, "Failed to exchange token", tokenErr) {
 | 
							// Handle error
 | 
				
			||||||
 | 
							if api.handleError(c, "Failed to exchange token", tokenErr) {
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get email
 | 
				
			||||||
		email, emailErr := api.Providers.GetUser(providerName.Provider)
 | 
							email, emailErr := api.Providers.GetUser(providerName.Provider)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Str("email", email).Msg("Got email")
 | 
							log.Debug().Str("email", email).Msg("Got email")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if handleApiError(c, "Failed to get user", emailErr) {
 | 
							// Handle error
 | 
				
			||||||
 | 
							if api.handleError(c, "Failed to get user", emailErr) {
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Email is not whitelisted
 | 
				
			||||||
		if !api.Auth.EmailWhitelisted(email) {
 | 
							if !api.Auth.EmailWhitelisted(email) {
 | 
				
			||||||
			log.Warn().Str("email", email).Msg("Email not whitelisted")
 | 
								log.Warn().Str("email", email).Msg("Email not whitelisted")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Build query
 | 
				
			||||||
			unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{
 | 
								unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{
 | 
				
			||||||
				Username: email,
 | 
									Username: email,
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
			if handleApiError(c, "Failed to build query", unauthorizedQueryErr) {
 | 
					
 | 
				
			||||||
 | 
								// Handle error
 | 
				
			||||||
 | 
								if api.handleError(c, "Failed to build query", unauthorizedQueryErr) {
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Redirect to unauthorized
 | 
				
			||||||
			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode()))
 | 
								c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode()))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Email whitelisted")
 | 
							log.Debug().Msg("Email whitelisted")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Create session cookie
 | 
				
			||||||
		api.Auth.CreateSessionCookie(c, &types.SessionCookie{
 | 
							api.Auth.CreateSessionCookie(c, &types.SessionCookie{
 | 
				
			||||||
			Username: email,
 | 
								Username: email,
 | 
				
			||||||
			Provider: providerName.Provider,
 | 
								Provider: providerName.Provider,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get redirect URI
 | 
				
			||||||
		redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri")
 | 
							redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// If it is empty it means that no redirect_uri was provided to the login screen so we just log in
 | 
				
			||||||
		if redirectURIErr != nil {
 | 
							if redirectURIErr != nil {
 | 
				
			||||||
			c.JSON(200, gin.H{
 | 
								c.JSON(200, gin.H{
 | 
				
			||||||
				"status":  200,
 | 
									"status":  200,
 | 
				
			||||||
@@ -392,40 +543,71 @@ func (api *API) SetupRoutes() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI")
 | 
							log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Clean up redirect cookie since we already have the value
 | 
				
			||||||
		c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
 | 
							c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Build query
 | 
				
			||||||
		redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{
 | 
							redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{
 | 
				
			||||||
			RedirectURI: redirectURI,
 | 
								RedirectURI: redirectURI,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got redirect query")
 | 
							log.Debug().Msg("Got redirect query")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if handleApiError(c, "Failed to build query", redirectQueryErr) {
 | 
							// Handle error
 | 
				
			||||||
 | 
							if api.handleError(c, "Failed to build query", redirectQueryErr) {
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Redirect to continue with the redirect URI
 | 
				
			||||||
		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode()))
 | 
							c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode()))
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Simple healthcheck
 | 
				
			||||||
 | 
						api.Router.GET("/api/healthcheck", func(c *gin.Context) {
 | 
				
			||||||
 | 
							c.JSON(200, gin.H{
 | 
				
			||||||
 | 
								"status":  200,
 | 
				
			||||||
 | 
								"message": "OK",
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (api *API) Run() {
 | 
					func (api *API) Run() {
 | 
				
			||||||
	log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server")
 | 
						log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Run server
 | 
				
			||||||
	api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port))
 | 
						api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// handleError logs the error and redirects to the error page (only meant for stuff the user may access does not apply for login paths)
 | 
				
			||||||
 | 
					func (api *API) handleError(c *gin.Context, msg string, err error) bool {
 | 
				
			||||||
 | 
						// If error is not nil log it and redirect to error page also return true so we can stop further processing
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.Error().Err(err).Msg(msg)
 | 
				
			||||||
 | 
							c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", api.Config.AppURL))
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// zerolog is a middleware for gin that logs requests using zerolog
 | 
				
			||||||
func zerolog() gin.HandlerFunc {
 | 
					func zerolog() gin.HandlerFunc {
 | 
				
			||||||
	return func(c *gin.Context) {
 | 
						return func(c *gin.Context) {
 | 
				
			||||||
 | 
							// Get initial time
 | 
				
			||||||
		tStart := time.Now()
 | 
							tStart := time.Now()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Process request
 | 
				
			||||||
		c.Next()
 | 
							c.Next()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get status code, address, method and path
 | 
				
			||||||
		code := c.Writer.Status()
 | 
							code := c.Writer.Status()
 | 
				
			||||||
		address := c.Request.RemoteAddr
 | 
							address := c.Request.RemoteAddr
 | 
				
			||||||
		method := c.Request.Method
 | 
							method := c.Request.Method
 | 
				
			||||||
		path := c.Request.URL.Path
 | 
							path := c.Request.URL.Path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get latency
 | 
				
			||||||
		latency := time.Since(tStart).String()
 | 
							latency := time.Since(tStart).String()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Log request
 | 
				
			||||||
		switch {
 | 
							switch {
 | 
				
			||||||
		case code >= 200 && code < 300:
 | 
							case code >= 200 && code < 300:
 | 
				
			||||||
			log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")
 | 
								log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")
 | 
				
			||||||
@@ -436,12 +618,3 @@ func zerolog() gin.HandlerFunc {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func handleApiError(c *gin.Context, msg string, err error) bool {
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		log.Error().Err(err).Msg(msg)
 | 
					 | 
				
			||||||
		c.Redirect(http.StatusPermanentRedirect, "/error")
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return false
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,8 +4,12 @@ import (
 | 
				
			|||||||
	"embed"
 | 
						"embed"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// UI assets
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
//go:embed dist
 | 
					//go:embed dist
 | 
				
			||||||
var Assets embed.FS
 | 
					var Assets embed.FS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Version file
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
//go:embed version
 | 
					//go:embed version
 | 
				
			||||||
var Version string
 | 
					var Version string
 | 
				
			||||||
@@ -1 +1 @@
 | 
				
			|||||||
v2.1.0
 | 
					v3.0.0
 | 
				
			||||||
@@ -3,6 +3,7 @@ package auth
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"slices"
 | 
						"slices"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
	"tinyauth/internal/docker"
 | 
						"tinyauth/internal/docker"
 | 
				
			||||||
	"tinyauth/internal/types"
 | 
						"tinyauth/internal/types"
 | 
				
			||||||
	"tinyauth/internal/utils"
 | 
						"tinyauth/internal/utils"
 | 
				
			||||||
@@ -13,11 +14,12 @@ import (
 | 
				
			|||||||
	"golang.org/x/crypto/bcrypt"
 | 
						"golang.org/x/crypto/bcrypt"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewAuth(docker *docker.Docker, userList types.Users, oauthWhitelist []string) *Auth {
 | 
					func NewAuth(docker *docker.Docker, userList types.Users, oauthWhitelist []string, sessionExpiry int) *Auth {
 | 
				
			||||||
	return &Auth{
 | 
						return &Auth{
 | 
				
			||||||
		Docker:         docker,
 | 
							Docker:         docker,
 | 
				
			||||||
		Users:          userList,
 | 
							Users:          userList,
 | 
				
			||||||
		OAuthWhitelist: oauthWhitelist,
 | 
							OAuthWhitelist: oauthWhitelist,
 | 
				
			||||||
 | 
							SessionExpiry:  sessionExpiry,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -25,9 +27,11 @@ type Auth struct {
 | 
				
			|||||||
	Users          types.Users
 | 
						Users          types.Users
 | 
				
			||||||
	Docker         *docker.Docker
 | 
						Docker         *docker.Docker
 | 
				
			||||||
	OAuthWhitelist []string
 | 
						OAuthWhitelist []string
 | 
				
			||||||
 | 
						SessionExpiry  int
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (auth *Auth) GetUser(username string) *types.User {
 | 
					func (auth *Auth) GetUser(username string) *types.User {
 | 
				
			||||||
 | 
						// Loop through users and return the user if the username matches
 | 
				
			||||||
	for _, user := range auth.Users {
 | 
						for _, user := range auth.Users {
 | 
				
			||||||
		if user.Username == username {
 | 
							if user.Username == username {
 | 
				
			||||||
			return &user
 | 
								return &user
 | 
				
			||||||
@@ -37,91 +41,150 @@ func (auth *Auth) GetUser(username string) *types.User {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (auth *Auth) CheckPassword(user types.User, password string) bool {
 | 
					func (auth *Auth) CheckPassword(user types.User, password string) bool {
 | 
				
			||||||
	hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
 | 
						// Compare the hashed password with the password provided
 | 
				
			||||||
	return hashedPasswordErr == nil
 | 
						return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (auth *Auth) EmailWhitelisted(emailSrc string) bool {
 | 
					func (auth *Auth) EmailWhitelisted(emailSrc string) bool {
 | 
				
			||||||
 | 
						// If the whitelist is empty, allow all emails
 | 
				
			||||||
	if len(auth.OAuthWhitelist) == 0 {
 | 
						if len(auth.OAuthWhitelist) == 0 {
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Loop through the whitelist and return true if the email matches
 | 
				
			||||||
	for _, email := range auth.OAuthWhitelist {
 | 
						for _, email := range auth.OAuthWhitelist {
 | 
				
			||||||
		if email == emailSrc {
 | 
							if email == emailSrc {
 | 
				
			||||||
			return true
 | 
								return true
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If no emails match, return false
 | 
				
			||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) {
 | 
					func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) {
 | 
				
			||||||
	log.Debug().Msg("Creating session cookie")
 | 
						log.Debug().Msg("Creating session cookie")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get session
 | 
				
			||||||
	sessions := sessions.Default(c)
 | 
						sessions := sessions.Default(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Setting session cookie")
 | 
						log.Debug().Msg("Setting session cookie")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Set data
 | 
				
			||||||
	sessions.Set("username", data.Username)
 | 
						sessions.Set("username", data.Username)
 | 
				
			||||||
	sessions.Set("provider", data.Provider)
 | 
						sessions.Set("provider", data.Provider)
 | 
				
			||||||
 | 
						sessions.Set("expiry", time.Now().Add(time.Duration(auth.SessionExpiry)*time.Second).Unix())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Save session
 | 
				
			||||||
	sessions.Save()
 | 
						sessions.Save()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (auth *Auth) DeleteSessionCookie(c *gin.Context) {
 | 
					func (auth *Auth) DeleteSessionCookie(c *gin.Context) {
 | 
				
			||||||
	log.Debug().Msg("Deleting session cookie")
 | 
						log.Debug().Msg("Deleting session cookie")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get session
 | 
				
			||||||
	sessions := sessions.Default(c)
 | 
						sessions := sessions.Default(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Clear session
 | 
				
			||||||
	sessions.Clear()
 | 
						sessions.Clear()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Save session
 | 
				
			||||||
	sessions.Save()
 | 
						sessions.Save()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) {
 | 
					func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie {
 | 
				
			||||||
	log.Debug().Msg("Getting session cookie")
 | 
						log.Debug().Msg("Getting session cookie")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get session
 | 
				
			||||||
	sessions := sessions.Default(c)
 | 
						sessions := sessions.Default(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get data
 | 
				
			||||||
	cookieUsername := sessions.Get("username")
 | 
						cookieUsername := sessions.Get("username")
 | 
				
			||||||
	cookieProvider := sessions.Get("provider")
 | 
						cookieProvider := sessions.Get("provider")
 | 
				
			||||||
 | 
						cookieExpiry := sessions.Get("expiry")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Convert interfaces to correct types
 | 
				
			||||||
	username, usernameOk := cookieUsername.(string)
 | 
						username, usernameOk := cookieUsername.(string)
 | 
				
			||||||
	provider, providerOk := cookieProvider.(string)
 | 
						provider, providerOk := cookieProvider.(string)
 | 
				
			||||||
 | 
						expiry, expiryOk := cookieExpiry.(int64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Str("username", username).Str("provider", provider).Msg("Parsed cookie")
 | 
						// Check if the cookie is invalid
 | 
				
			||||||
 | 
						if !usernameOk || !providerOk || !expiryOk {
 | 
				
			||||||
	if !usernameOk || !providerOk {
 | 
					 | 
				
			||||||
		log.Warn().Msg("Session cookie invalid")
 | 
							log.Warn().Msg("Session cookie invalid")
 | 
				
			||||||
		return types.SessionCookie{}, nil
 | 
							return types.SessionCookie{}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if the cookie has expired
 | 
				
			||||||
 | 
						if time.Now().Unix() > expiry {
 | 
				
			||||||
 | 
							log.Warn().Msg("Session cookie expired")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// If it has, delete it
 | 
				
			||||||
 | 
							auth.DeleteSessionCookie(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Return empty cookie
 | 
				
			||||||
 | 
							return types.SessionCookie{}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Msg("Parsed cookie")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the cookie
 | 
				
			||||||
	return types.SessionCookie{
 | 
						return types.SessionCookie{
 | 
				
			||||||
		Username: username,
 | 
							Username: username,
 | 
				
			||||||
		Provider: provider,
 | 
							Provider: provider,
 | 
				
			||||||
	}, nil
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (auth *Auth) UserAuthConfigured() bool {
 | 
					func (auth *Auth) UserAuthConfigured() bool {
 | 
				
			||||||
 | 
						// If there are users, return true
 | 
				
			||||||
	return len(auth.Users) > 0
 | 
						return len(auth.Users) > 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, error) {
 | 
					func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, error) {
 | 
				
			||||||
 | 
						// Check if we have access to the Docker API
 | 
				
			||||||
 | 
						isConnected := auth.Docker.DockerConnected()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If we don't have access, it is assumed that the user has access
 | 
				
			||||||
 | 
						if !isConnected {
 | 
				
			||||||
 | 
							log.Debug().Msg("Docker not connected, allowing access")
 | 
				
			||||||
 | 
							return true, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get the app ID from the host
 | 
				
			||||||
	appId := strings.Split(host, ".")[0]
 | 
						appId := strings.Split(host, ".")[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get the containers
 | 
				
			||||||
	containers, containersErr := auth.Docker.GetContainers()
 | 
						containers, containersErr := auth.Docker.GetContainers()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If there is an error, return false
 | 
				
			||||||
	if containersErr != nil {
 | 
						if containersErr != nil {
 | 
				
			||||||
		return false, containersErr
 | 
							return false, containersErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Got containers")
 | 
						log.Debug().Msg("Got containers")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Loop through the containers
 | 
				
			||||||
	for _, container := range containers {
 | 
						for _, container := range containers {
 | 
				
			||||||
 | 
							// Inspect the container
 | 
				
			||||||
		inspect, inspectErr := auth.Docker.InspectContainer(container.ID)
 | 
							inspect, inspectErr := auth.Docker.InspectContainer(container.ID)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// If there is an error, return false
 | 
				
			||||||
		if inspectErr != nil {
 | 
							if inspectErr != nil {
 | 
				
			||||||
			return false, inspectErr
 | 
								return false, inspectErr
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get the container name (for some reason it is /name)
 | 
				
			||||||
		containerName := strings.Split(inspect.Name, "/")[1]
 | 
							containerName := strings.Split(inspect.Name, "/")[1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// There is a container with the same name as the app ID
 | 
				
			||||||
		if containerName == appId {
 | 
							if containerName == appId {
 | 
				
			||||||
			log.Debug().Str("container", containerName).Msg("Found container")
 | 
								log.Debug().Str("container", containerName).Msg("Found container")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Get only the tinyauth labels in a struct
 | 
				
			||||||
			labels := utils.GetTinyauthLabels(inspect.Config.Labels)
 | 
								labels := utils.GetTinyauthLabels(inspect.Config.Labels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			log.Debug().Msg("Got labels")
 | 
								log.Debug().Msg("Got labels")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// If the container has an oauth whitelist, check if the user is in it
 | 
				
			||||||
			if context.OAuth && len(labels.OAuthWhitelist) != 0 {
 | 
								if context.OAuth && len(labels.OAuthWhitelist) != 0 {
 | 
				
			||||||
				log.Debug().Msg("Checking OAuth whitelist")
 | 
									log.Debug().Msg("Checking OAuth whitelist")
 | 
				
			||||||
				if slices.Contains(labels.OAuthWhitelist, context.Username) {
 | 
									if slices.Contains(labels.OAuthWhitelist, context.Username) {
 | 
				
			||||||
@@ -130,6 +193,7 @@ func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool,
 | 
				
			|||||||
				return false, nil
 | 
									return false, nil
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// If the container has users, check if the user is in it
 | 
				
			||||||
			if len(labels.Users) != 0 {
 | 
								if len(labels.Users) != 0 {
 | 
				
			||||||
				log.Debug().Msg("Checking users")
 | 
									log.Debug().Msg("Checking users")
 | 
				
			||||||
				if slices.Contains(labels.Users, context.Username) {
 | 
									if slices.Contains(labels.Users, context.Username) {
 | 
				
			||||||
@@ -143,5 +207,42 @@ func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool,
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("No matching container found, allowing access")
 | 
						log.Debug().Msg("No matching container found, allowing access")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If no matching container is found, allow access
 | 
				
			||||||
	return true, nil
 | 
						return true, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (auth *Auth) GetBasicAuth(c *gin.Context) types.User {
 | 
				
			||||||
 | 
						// Get the Authorization header
 | 
				
			||||||
 | 
						header := c.GetHeader("Authorization")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If the header is empty, return an empty user
 | 
				
			||||||
 | 
						if header == "" {
 | 
				
			||||||
 | 
							return types.User{}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Split the header
 | 
				
			||||||
 | 
						headerSplit := strings.Split(header, " ")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(headerSplit) != 2 {
 | 
				
			||||||
 | 
							return types.User{}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if the header is Basic
 | 
				
			||||||
 | 
						if headerSplit[0] != "Basic" {
 | 
				
			||||||
 | 
							return types.User{}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Split the credentials
 | 
				
			||||||
 | 
						credentials := strings.Split(headerSplit[1], ":")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If the credentials are not in the correct format, return an empty user
 | 
				
			||||||
 | 
						if len(credentials) != 2 {
 | 
				
			||||||
 | 
							return types.User{}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the user
 | 
				
			||||||
 | 
						return types.User{
 | 
				
			||||||
 | 
							Username: credentials[0],
 | 
				
			||||||
 | 
							Password: credentials[1],
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +1,6 @@
 | 
				
			|||||||
package constants
 | 
					package constants
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TinyauthLabels is a list of labels that can be used in a tinyauth protected container
 | 
				
			||||||
var TinyauthLabels = []string{
 | 
					var TinyauthLabels = []string{
 | 
				
			||||||
	"tinyauth.oauth.whitelist",
 | 
						"tinyauth.oauth.whitelist",
 | 
				
			||||||
	"tinyauth.users",
 | 
						"tinyauth.users",
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,34 +18,50 @@ type Docker struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (docker *Docker) Init() error {
 | 
					func (docker *Docker) Init() error {
 | 
				
			||||||
 | 
						// Create a new docker client
 | 
				
			||||||
	apiClient, err := client.NewClientWithOpts(client.FromEnv)
 | 
						apiClient, err := client.NewClientWithOpts(client.FromEnv)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Set the context and api client
 | 
				
			||||||
	docker.Context = context.Background()
 | 
						docker.Context = context.Background()
 | 
				
			||||||
	docker.Client = apiClient
 | 
						docker.Client = apiClient
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Done
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (docker *Docker) GetContainers() ([]types.Container, error) {
 | 
					func (docker *Docker) GetContainers() ([]types.Container, error) {
 | 
				
			||||||
 | 
						// Get the list of containers
 | 
				
			||||||
	containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{})
 | 
						containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the containers
 | 
				
			||||||
	return containers, nil
 | 
						return containers, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (docker *Docker) InspectContainer(containerId string) (types.ContainerJSON, error) {
 | 
					func (docker *Docker) InspectContainer(containerId string) (types.ContainerJSON, error) {
 | 
				
			||||||
 | 
						// Inspect the container
 | 
				
			||||||
	inspect, err := docker.Client.ContainerInspect(docker.Context, containerId)
 | 
						inspect, err := docker.Client.ContainerInspect(docker.Context, containerId)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return types.ContainerJSON{}, err
 | 
							return types.ContainerJSON{}, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the inspect
 | 
				
			||||||
	return inspect, nil
 | 
						return inspect, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (docker *Docker) DockerConnected() bool {
 | 
				
			||||||
 | 
						// Ping the docker client if there is an error it is not connected
 | 
				
			||||||
 | 
						_, err := docker.Client.Ping(docker.Context)
 | 
				
			||||||
 | 
						return err == nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -22,39 +22,64 @@ type Hooks struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
					func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
				
			||||||
	cookie, cookiErr := hooks.Auth.GetSessionCookie(c)
 | 
						// Get session cookie and basic auth
 | 
				
			||||||
 | 
						cookie := hooks.Auth.GetSessionCookie(c)
 | 
				
			||||||
 | 
						basic := hooks.Auth.GetBasicAuth(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if cookiErr != nil {
 | 
						// Check if basic auth is set
 | 
				
			||||||
		log.Error().Err(cookiErr).Msg("Failed to get session cookie")
 | 
						if basic.Username != "" {
 | 
				
			||||||
		return types.UserContext{
 | 
							log.Debug().Msg("Got basic auth")
 | 
				
			||||||
			Username:   "",
 | 
					
 | 
				
			||||||
			IsLoggedIn: false,
 | 
							// Check if user exists and password is correct
 | 
				
			||||||
			OAuth:      false,
 | 
							user := hooks.Auth.GetUser(basic.Username)
 | 
				
			||||||
			Provider:   "",
 | 
					
 | 
				
			||||||
 | 
							if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) {
 | 
				
			||||||
 | 
								// Return user context since we are logged in with basic auth
 | 
				
			||||||
 | 
								return types.UserContext{
 | 
				
			||||||
 | 
									Username:   basic.Username,
 | 
				
			||||||
 | 
									IsLoggedIn: true,
 | 
				
			||||||
 | 
									OAuth:      false,
 | 
				
			||||||
 | 
									Provider:   "basic",
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if session cookie is username/password auth
 | 
				
			||||||
	if cookie.Provider == "username" {
 | 
						if cookie.Provider == "username" {
 | 
				
			||||||
		log.Debug().Msg("Provider is username")
 | 
							log.Debug().Msg("Provider is username")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if user exists
 | 
				
			||||||
		if hooks.Auth.GetUser(cookie.Username) != nil {
 | 
							if hooks.Auth.GetUser(cookie.Username) != nil {
 | 
				
			||||||
			log.Debug().Msg("User exists")
 | 
								log.Debug().Msg("User exists")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// It exists so we are logged in
 | 
				
			||||||
			return types.UserContext{
 | 
								return types.UserContext{
 | 
				
			||||||
				Username:   cookie.Username,
 | 
									Username:   cookie.Username,
 | 
				
			||||||
				IsLoggedIn: true,
 | 
									IsLoggedIn: true,
 | 
				
			||||||
				OAuth:      false,
 | 
									OAuth:      false,
 | 
				
			||||||
				Provider:   "",
 | 
									Provider:   "username",
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Provider is not username")
 | 
						log.Debug().Msg("Provider is not username")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// The provider is not username so we need to check if it is an oauth provider
 | 
				
			||||||
	provider := hooks.Providers.GetProvider(cookie.Provider)
 | 
						provider := hooks.Providers.GetProvider(cookie.Provider)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If we have a provider with this name
 | 
				
			||||||
	if provider != nil {
 | 
						if provider != nil {
 | 
				
			||||||
		log.Debug().Msg("Provider exists")
 | 
							log.Debug().Msg("Provider exists")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if the oauth email is whitelisted
 | 
				
			||||||
		if !hooks.Auth.EmailWhitelisted(cookie.Username) {
 | 
							if !hooks.Auth.EmailWhitelisted(cookie.Username) {
 | 
				
			||||||
			log.Error().Str("email", cookie.Username).Msg("Email is not whitelisted")
 | 
								log.Error().Str("email", cookie.Username).Msg("Email is not whitelisted")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// It isn't so we delete the cookie and return an empty context
 | 
				
			||||||
			hooks.Auth.DeleteSessionCookie(c)
 | 
								hooks.Auth.DeleteSessionCookie(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Return empty context
 | 
				
			||||||
			return types.UserContext{
 | 
								return types.UserContext{
 | 
				
			||||||
				Username:   "",
 | 
									Username:   "",
 | 
				
			||||||
				IsLoggedIn: false,
 | 
									IsLoggedIn: false,
 | 
				
			||||||
@@ -62,7 +87,10 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
				
			|||||||
				Provider:   "",
 | 
									Provider:   "",
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Email is whitelisted")
 | 
							log.Debug().Msg("Email is whitelisted")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Return user context since we are logged in with oauth
 | 
				
			||||||
		return types.UserContext{
 | 
							return types.UserContext{
 | 
				
			||||||
			Username:   cookie.Username,
 | 
								Username:   cookie.Username,
 | 
				
			||||||
			IsLoggedIn: true,
 | 
								IsLoggedIn: true,
 | 
				
			||||||
@@ -71,6 +99,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Neither basic auth or oauth is set so we return an empty context
 | 
				
			||||||
	return types.UserContext{
 | 
						return types.UserContext{
 | 
				
			||||||
		Username:   "",
 | 
							Username:   "",
 | 
				
			||||||
		IsLoggedIn: false,
 | 
							IsLoggedIn: false,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -21,23 +21,33 @@ type OAuth struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (oauth *OAuth) Init() {
 | 
					func (oauth *OAuth) Init() {
 | 
				
			||||||
 | 
						// Create a new context and verifier
 | 
				
			||||||
	oauth.Context = context.Background()
 | 
						oauth.Context = context.Background()
 | 
				
			||||||
	oauth.Verifier = oauth2.GenerateVerifier()
 | 
						oauth.Verifier = oauth2.GenerateVerifier()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (oauth *OAuth) GetAuthURL() string {
 | 
					func (oauth *OAuth) GetAuthURL() string {
 | 
				
			||||||
 | 
						// Return the auth url
 | 
				
			||||||
	return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
 | 
						return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (oauth *OAuth) ExchangeToken(code string) (string, error) {
 | 
					func (oauth *OAuth) ExchangeToken(code string) (string, error) {
 | 
				
			||||||
 | 
						// Exchange the code for a token
 | 
				
			||||||
	token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier))
 | 
						token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return "", err
 | 
							return "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Set the token
 | 
				
			||||||
	oauth.Token = token
 | 
						oauth.Token = token
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the access token
 | 
				
			||||||
	return oauth.Token.AccessToken, nil
 | 
						return oauth.Token.AccessToken, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (oauth *OAuth) GetClient() *http.Client {
 | 
					func (oauth *OAuth) GetClient() *http.Client {
 | 
				
			||||||
 | 
						// Return the http client with the token set
 | 
				
			||||||
	return oauth.Config.Client(oauth.Context, oauth.Token)
 | 
						return oauth.Config.Client(oauth.Context, oauth.Token)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,36 +8,45 @@ import (
 | 
				
			|||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// We are assuming that the generic provider will return a JSON object with an email field
 | 
				
			||||||
type GenericUserInfoResponse struct {
 | 
					type GenericUserInfoResponse struct {
 | 
				
			||||||
	Email string `json:"email"`
 | 
						Email string `json:"email"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetGenericEmail(client *http.Client, url string) (string, error) {
 | 
					func GetGenericEmail(client *http.Client, url string) (string, error) {
 | 
				
			||||||
 | 
						// Using the oauth client get the user info url
 | 
				
			||||||
	res, resErr := client.Get(url)
 | 
						res, resErr := client.Get(url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if resErr != nil {
 | 
						if resErr != nil {
 | 
				
			||||||
		return "", resErr
 | 
							return "", resErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Got response from generic provider")
 | 
						log.Debug().Msg("Got response from generic provider")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Read the body of the response
 | 
				
			||||||
	body, bodyErr := io.ReadAll(res.Body)
 | 
						body, bodyErr := io.ReadAll(res.Body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if bodyErr != nil {
 | 
						if bodyErr != nil {
 | 
				
			||||||
		return "", bodyErr
 | 
							return "", bodyErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Read body from generic provider")
 | 
						log.Debug().Msg("Read body from generic provider")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Parse the body into a user struct
 | 
				
			||||||
	var user GenericUserInfoResponse
 | 
						var user GenericUserInfoResponse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Unmarshal the body into the user struct
 | 
				
			||||||
	jsonErr := json.Unmarshal(body, &user)
 | 
						jsonErr := json.Unmarshal(body, &user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if jsonErr != nil {
 | 
						if jsonErr != nil {
 | 
				
			||||||
		return "", jsonErr
 | 
							return "", jsonErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Parsed user from generic provider")
 | 
						log.Debug().Msg("Parsed user from generic provider")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the email
 | 
				
			||||||
	return user.Email, nil
 | 
						return user.Email, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,47 +9,58 @@ import (
 | 
				
			|||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Github has a different response than the generic provider
 | 
				
			||||||
type GithubUserInfoResponse []struct {
 | 
					type GithubUserInfoResponse []struct {
 | 
				
			||||||
	Email   string `json:"email"`
 | 
						Email   string `json:"email"`
 | 
				
			||||||
	Primary bool   `json:"primary"`
 | 
						Primary bool   `json:"primary"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// The scopes required for the github provider
 | 
				
			||||||
func GithubScopes() []string {
 | 
					func GithubScopes() []string {
 | 
				
			||||||
	return []string{"user:email"}
 | 
						return []string{"user:email"}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetGithubEmail(client *http.Client) (string, error) {
 | 
					func GetGithubEmail(client *http.Client) (string, error) {
 | 
				
			||||||
 | 
						// Get the user emails from github using the oauth http client
 | 
				
			||||||
	res, resErr := client.Get("https://api.github.com/user/emails")
 | 
						res, resErr := client.Get("https://api.github.com/user/emails")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if resErr != nil {
 | 
						if resErr != nil {
 | 
				
			||||||
		return "", resErr
 | 
							return "", resErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Got response from github")
 | 
						log.Debug().Msg("Got response from github")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Read the body of the response
 | 
				
			||||||
	body, bodyErr := io.ReadAll(res.Body)
 | 
						body, bodyErr := io.ReadAll(res.Body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if bodyErr != nil {
 | 
						if bodyErr != nil {
 | 
				
			||||||
		return "", bodyErr
 | 
							return "", bodyErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Read body from github")
 | 
						log.Debug().Msg("Read body from github")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Parse the body into a user struct
 | 
				
			||||||
	var emails GithubUserInfoResponse
 | 
						var emails GithubUserInfoResponse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Unmarshal the body into the user struct
 | 
				
			||||||
	jsonErr := json.Unmarshal(body, &emails)
 | 
						jsonErr := json.Unmarshal(body, &emails)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if jsonErr != nil {
 | 
						if jsonErr != nil {
 | 
				
			||||||
		return "", jsonErr
 | 
							return "", jsonErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Parsed emails from github")
 | 
						log.Debug().Msg("Parsed emails from github")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Find and return the primary email
 | 
				
			||||||
	for _, email := range emails {
 | 
						for _, email := range emails {
 | 
				
			||||||
		if email.Primary {
 | 
							if email.Primary {
 | 
				
			||||||
			return email.Email, nil
 | 
								return email.Email, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// User does not have a primary email?
 | 
				
			||||||
	return "", errors.New("no primary email found")
 | 
						return "", errors.New("no primary email found")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,40 +8,50 @@ import (
 | 
				
			|||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Google works the same as the generic provider
 | 
				
			||||||
type GoogleUserInfoResponse struct {
 | 
					type GoogleUserInfoResponse struct {
 | 
				
			||||||
	Email string `json:"email"`
 | 
						Email string `json:"email"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// The scopes required for the google provider
 | 
				
			||||||
func GoogleScopes() []string {
 | 
					func GoogleScopes() []string {
 | 
				
			||||||
	return []string{"https://www.googleapis.com/auth/userinfo.email"}
 | 
						return []string{"https://www.googleapis.com/auth/userinfo.email"}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetGoogleEmail(client *http.Client) (string, error) {
 | 
					func GetGoogleEmail(client *http.Client) (string, error) {
 | 
				
			||||||
 | 
						// Get the user info from google using the oauth http client
 | 
				
			||||||
	res, resErr := client.Get("https://www.googleapis.com/userinfo/v2/me")
 | 
						res, resErr := client.Get("https://www.googleapis.com/userinfo/v2/me")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if resErr != nil {
 | 
						if resErr != nil {
 | 
				
			||||||
		return "", resErr
 | 
							return "", resErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Got response from google")
 | 
						log.Debug().Msg("Got response from google")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Read the body of the response
 | 
				
			||||||
	body, bodyErr := io.ReadAll(res.Body)
 | 
						body, bodyErr := io.ReadAll(res.Body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if bodyErr != nil {
 | 
						if bodyErr != nil {
 | 
				
			||||||
		return "", bodyErr
 | 
							return "", bodyErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Read body from google")
 | 
						log.Debug().Msg("Read body from google")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Parse the body into a user struct
 | 
				
			||||||
	var user GoogleUserInfoResponse
 | 
						var user GoogleUserInfoResponse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Unmarshal the body into the user struct
 | 
				
			||||||
	jsonErr := json.Unmarshal(body, &user)
 | 
						jsonErr := json.Unmarshal(body, &user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if jsonErr != nil {
 | 
						if jsonErr != nil {
 | 
				
			||||||
		return "", jsonErr
 | 
							return "", jsonErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Parsed user from google")
 | 
						log.Debug().Msg("Parsed user from google")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the email
 | 
				
			||||||
	return user.Email, nil
 | 
						return user.Email, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -25,8 +25,11 @@ type Providers struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (providers *Providers) Init() {
 | 
					func (providers *Providers) Init() {
 | 
				
			||||||
 | 
						// If we have a client id and secret for github, initialize the oauth provider
 | 
				
			||||||
	if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" {
 | 
						if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" {
 | 
				
			||||||
		log.Info().Msg("Initializing Github OAuth")
 | 
							log.Info().Msg("Initializing Github OAuth")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Create a new oauth provider with the github config
 | 
				
			||||||
		providers.Github = oauth.NewOAuth(oauth2.Config{
 | 
							providers.Github = oauth.NewOAuth(oauth2.Config{
 | 
				
			||||||
			ClientID:     providers.Config.GithubClientId,
 | 
								ClientID:     providers.Config.GithubClientId,
 | 
				
			||||||
			ClientSecret: providers.Config.GithubClientSecret,
 | 
								ClientSecret: providers.Config.GithubClientSecret,
 | 
				
			||||||
@@ -34,10 +37,16 @@ func (providers *Providers) Init() {
 | 
				
			|||||||
			Scopes:       GithubScopes(),
 | 
								Scopes:       GithubScopes(),
 | 
				
			||||||
			Endpoint:     endpoints.GitHub,
 | 
								Endpoint:     endpoints.GitHub,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Initialize the oauth provider
 | 
				
			||||||
		providers.Github.Init()
 | 
							providers.Github.Init()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If we have a client id and secret for google, initialize the oauth provider
 | 
				
			||||||
	if providers.Config.GoogleClientId != "" && providers.Config.GoogleClientSecret != "" {
 | 
						if providers.Config.GoogleClientId != "" && providers.Config.GoogleClientSecret != "" {
 | 
				
			||||||
		log.Info().Msg("Initializing Google OAuth")
 | 
							log.Info().Msg("Initializing Google OAuth")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Create a new oauth provider with the google config
 | 
				
			||||||
		providers.Google = oauth.NewOAuth(oauth2.Config{
 | 
							providers.Google = oauth.NewOAuth(oauth2.Config{
 | 
				
			||||||
			ClientID:     providers.Config.GoogleClientId,
 | 
								ClientID:     providers.Config.GoogleClientId,
 | 
				
			||||||
			ClientSecret: providers.Config.GoogleClientSecret,
 | 
								ClientSecret: providers.Config.GoogleClientSecret,
 | 
				
			||||||
@@ -45,10 +54,15 @@ func (providers *Providers) Init() {
 | 
				
			|||||||
			Scopes:       GoogleScopes(),
 | 
								Scopes:       GoogleScopes(),
 | 
				
			||||||
			Endpoint:     endpoints.Google,
 | 
								Endpoint:     endpoints.Google,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Initialize the oauth provider
 | 
				
			||||||
		providers.Google.Init()
 | 
							providers.Google.Init()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if providers.Config.TailscaleClientId != "" && providers.Config.TailscaleClientSecret != "" {
 | 
						if providers.Config.TailscaleClientId != "" && providers.Config.TailscaleClientSecret != "" {
 | 
				
			||||||
		log.Info().Msg("Initializing Tailscale OAuth")
 | 
							log.Info().Msg("Initializing Tailscale OAuth")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Create a new oauth provider with the tailscale config
 | 
				
			||||||
		providers.Tailscale = oauth.NewOAuth(oauth2.Config{
 | 
							providers.Tailscale = oauth.NewOAuth(oauth2.Config{
 | 
				
			||||||
			ClientID:     providers.Config.TailscaleClientId,
 | 
								ClientID:     providers.Config.TailscaleClientId,
 | 
				
			||||||
			ClientSecret: providers.Config.TailscaleClientSecret,
 | 
								ClientSecret: providers.Config.TailscaleClientSecret,
 | 
				
			||||||
@@ -56,10 +70,16 @@ func (providers *Providers) Init() {
 | 
				
			|||||||
			Scopes:       TailscaleScopes(),
 | 
								Scopes:       TailscaleScopes(),
 | 
				
			||||||
			Endpoint:     TailscaleEndpoint,
 | 
								Endpoint:     TailscaleEndpoint,
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Initialize the oauth provider
 | 
				
			||||||
		providers.Tailscale.Init()
 | 
							providers.Tailscale.Init()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If we have a client id and secret for generic oauth, initialize the oauth provider
 | 
				
			||||||
	if providers.Config.GenericClientId != "" && providers.Config.GenericClientSecret != "" {
 | 
						if providers.Config.GenericClientId != "" && providers.Config.GenericClientSecret != "" {
 | 
				
			||||||
		log.Info().Msg("Initializing Generic OAuth")
 | 
							log.Info().Msg("Initializing Generic OAuth")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Create a new oauth provider with the generic config
 | 
				
			||||||
		providers.Generic = oauth.NewOAuth(oauth2.Config{
 | 
							providers.Generic = oauth.NewOAuth(oauth2.Config{
 | 
				
			||||||
			ClientID:     providers.Config.GenericClientId,
 | 
								ClientID:     providers.Config.GenericClientId,
 | 
				
			||||||
			ClientSecret: providers.Config.GenericClientSecret,
 | 
								ClientSecret: providers.Config.GenericClientSecret,
 | 
				
			||||||
@@ -70,11 +90,14 @@ func (providers *Providers) Init() {
 | 
				
			|||||||
				TokenURL: providers.Config.GenericTokenURL,
 | 
									TokenURL: providers.Config.GenericTokenURL,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Initialize the oauth provider
 | 
				
			||||||
		providers.Generic.Init()
 | 
							providers.Generic.Init()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
 | 
					func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
 | 
				
			||||||
 | 
						// Return the provider based on the provider string
 | 
				
			||||||
	switch provider {
 | 
						switch provider {
 | 
				
			||||||
	case "github":
 | 
						case "github":
 | 
				
			||||||
		return providers.Github
 | 
							return providers.Github
 | 
				
			||||||
@@ -90,58 +113,103 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (providers *Providers) GetUser(provider string) (string, error) {
 | 
					func (providers *Providers) GetUser(provider string) (string, error) {
 | 
				
			||||||
 | 
						// Get the email from the provider
 | 
				
			||||||
	switch provider {
 | 
						switch provider {
 | 
				
			||||||
	case "github":
 | 
						case "github":
 | 
				
			||||||
 | 
							// If the github provider is not configured, return an error
 | 
				
			||||||
		if providers.Github == nil {
 | 
							if providers.Github == nil {
 | 
				
			||||||
			log.Debug().Msg("Github provider not configured")
 | 
								log.Debug().Msg("Github provider not configured")
 | 
				
			||||||
			return "", nil
 | 
								return "", nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get the client from the github provider
 | 
				
			||||||
		client := providers.Github.GetClient()
 | 
							client := providers.Github.GetClient()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got client from github")
 | 
							log.Debug().Msg("Got client from github")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get the email from the github provider
 | 
				
			||||||
		email, emailErr := GetGithubEmail(client)
 | 
							email, emailErr := GetGithubEmail(client)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if there was an error
 | 
				
			||||||
		if emailErr != nil {
 | 
							if emailErr != nil {
 | 
				
			||||||
			return "", emailErr
 | 
								return "", emailErr
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got email from github")
 | 
							log.Debug().Msg("Got email from github")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Return the email
 | 
				
			||||||
		return email, nil
 | 
							return email, nil
 | 
				
			||||||
	case "google":
 | 
						case "google":
 | 
				
			||||||
 | 
							// If the google provider is not configured, return an error
 | 
				
			||||||
		if providers.Google == nil {
 | 
							if providers.Google == nil {
 | 
				
			||||||
			log.Debug().Msg("Google provider not configured")
 | 
								log.Debug().Msg("Google provider not configured")
 | 
				
			||||||
			return "", nil
 | 
								return "", nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get the client from the google provider
 | 
				
			||||||
		client := providers.Google.GetClient()
 | 
							client := providers.Google.GetClient()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got client from google")
 | 
							log.Debug().Msg("Got client from google")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get the email from the google provider
 | 
				
			||||||
		email, emailErr := GetGoogleEmail(client)
 | 
							email, emailErr := GetGoogleEmail(client)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if there was an error
 | 
				
			||||||
		if emailErr != nil {
 | 
							if emailErr != nil {
 | 
				
			||||||
			return "", emailErr
 | 
								return "", emailErr
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got email from google")
 | 
							log.Debug().Msg("Got email from google")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Return the email
 | 
				
			||||||
		return email, nil
 | 
							return email, nil
 | 
				
			||||||
	case "tailscale":
 | 
						case "tailscale":
 | 
				
			||||||
 | 
							// If the tailscale provider is not configured, return an error
 | 
				
			||||||
		if providers.Tailscale == nil {
 | 
							if providers.Tailscale == nil {
 | 
				
			||||||
			log.Debug().Msg("Tailscale provider not configured")
 | 
								log.Debug().Msg("Tailscale provider not configured")
 | 
				
			||||||
			return "", nil
 | 
								return "", nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get the client from the tailscale provider
 | 
				
			||||||
		client := providers.Tailscale.GetClient()
 | 
							client := providers.Tailscale.GetClient()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got client from tailscale")
 | 
							log.Debug().Msg("Got client from tailscale")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get the email from the tailscale provider
 | 
				
			||||||
		email, emailErr := GetTailscaleEmail(client)
 | 
							email, emailErr := GetTailscaleEmail(client)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if there was an error
 | 
				
			||||||
		if emailErr != nil {
 | 
							if emailErr != nil {
 | 
				
			||||||
			return "", emailErr
 | 
								return "", emailErr
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got email from tailscale")
 | 
							log.Debug().Msg("Got email from tailscale")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Return the email
 | 
				
			||||||
		return email, nil
 | 
							return email, nil
 | 
				
			||||||
	case "generic":
 | 
						case "generic":
 | 
				
			||||||
 | 
							// If the generic provider is not configured, return an error
 | 
				
			||||||
		if providers.Generic == nil {
 | 
							if providers.Generic == nil {
 | 
				
			||||||
			log.Debug().Msg("Generic provider not configured")
 | 
								log.Debug().Msg("Generic provider not configured")
 | 
				
			||||||
			return "", nil
 | 
								return "", nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get the client from the generic provider
 | 
				
			||||||
		client := providers.Generic.GetClient()
 | 
							client := providers.Generic.GetClient()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got client from generic")
 | 
							log.Debug().Msg("Got client from generic")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get the email from the generic provider
 | 
				
			||||||
		email, emailErr := GetGenericEmail(client, providers.Config.GenericUserURL)
 | 
							email, emailErr := GetGenericEmail(client, providers.Config.GenericUserURL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if there was an error
 | 
				
			||||||
		if emailErr != nil {
 | 
							if emailErr != nil {
 | 
				
			||||||
			return "", emailErr
 | 
								return "", emailErr
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		log.Debug().Msg("Got email from generic")
 | 
							log.Debug().Msg("Got email from generic")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Return the email
 | 
				
			||||||
		return email, nil
 | 
							return email, nil
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		return "", nil
 | 
							return "", nil
 | 
				
			||||||
@@ -149,6 +217,7 @@ func (providers *Providers) GetUser(provider string) (string, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (provider *Providers) GetConfiguredProviders() []string {
 | 
					func (provider *Providers) GetConfiguredProviders() []string {
 | 
				
			||||||
 | 
						// Create a list of the configured providers
 | 
				
			||||||
	providers := []string{}
 | 
						providers := []string{}
 | 
				
			||||||
	if provider.Github != nil {
 | 
						if provider.Github != nil {
 | 
				
			||||||
		providers = append(providers, "github")
 | 
							providers = append(providers, "github")
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,48 +9,60 @@ import (
 | 
				
			|||||||
	"golang.org/x/oauth2"
 | 
						"golang.org/x/oauth2"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// The tailscale email is the loginName
 | 
				
			||||||
type TailscaleUser struct {
 | 
					type TailscaleUser struct {
 | 
				
			||||||
	LoginName string `json:"loginName"`
 | 
						LoginName string `json:"loginName"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// The response from the tailscale user info endpoint
 | 
				
			||||||
type TailscaleUserInfoResponse struct {
 | 
					type TailscaleUserInfoResponse struct {
 | 
				
			||||||
	Users []TailscaleUser `json:"users"`
 | 
						Users []TailscaleUser `json:"users"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// The scopes required for the tailscale provider
 | 
				
			||||||
func TailscaleScopes() []string {
 | 
					func TailscaleScopes() []string {
 | 
				
			||||||
	return []string{"users:read"}
 | 
						return []string{"users:read"}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// The tailscale endpoint
 | 
				
			||||||
var TailscaleEndpoint = oauth2.Endpoint{
 | 
					var TailscaleEndpoint = oauth2.Endpoint{
 | 
				
			||||||
	TokenURL: "https://api.tailscale.com/api/v2/oauth/token",
 | 
						TokenURL: "https://api.tailscale.com/api/v2/oauth/token",
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func GetTailscaleEmail(client *http.Client) (string, error) {
 | 
					func GetTailscaleEmail(client *http.Client) (string, error) {
 | 
				
			||||||
 | 
						// Get the user info from tailscale using the oauth http client
 | 
				
			||||||
	res, resErr := client.Get("https://api.tailscale.com/api/v2/tailnet/-/users")
 | 
						res, resErr := client.Get("https://api.tailscale.com/api/v2/tailnet/-/users")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if resErr != nil {
 | 
						if resErr != nil {
 | 
				
			||||||
		return "", resErr
 | 
							return "", resErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Got response from tailscale")
 | 
						log.Debug().Msg("Got response from tailscale")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Read the body of the response
 | 
				
			||||||
	body, bodyErr := io.ReadAll(res.Body)
 | 
						body, bodyErr := io.ReadAll(res.Body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if bodyErr != nil {
 | 
						if bodyErr != nil {
 | 
				
			||||||
		return "", bodyErr
 | 
							return "", bodyErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Read body from tailscale")
 | 
						log.Debug().Msg("Read body from tailscale")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Parse the body into a user struct
 | 
				
			||||||
	var users TailscaleUserInfoResponse
 | 
						var users TailscaleUserInfoResponse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Unmarshal the body into the user struct
 | 
				
			||||||
	jsonErr := json.Unmarshal(body, &users)
 | 
						jsonErr := json.Unmarshal(body, &users)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if jsonErr != nil {
 | 
						if jsonErr != nil {
 | 
				
			||||||
		return "", jsonErr
 | 
							return "", jsonErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Parsed users from tailscale")
 | 
						log.Debug().Msg("Parsed users from tailscale")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the email of the first user
 | 
				
			||||||
	return users.Users[0].LoginName, nil
 | 
						return users.Users[0].LoginName, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,22 +2,27 @@ package types
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import "tinyauth/internal/oauth"
 | 
					import "tinyauth/internal/oauth"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// LoginQuery is the query parameters for the login endpoint
 | 
				
			||||||
type LoginQuery struct {
 | 
					type LoginQuery struct {
 | 
				
			||||||
	RedirectURI string `url:"redirect_uri"`
 | 
						RedirectURI string `url:"redirect_uri"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// LoginRequest is the request body for the login endpoint
 | 
				
			||||||
type LoginRequest struct {
 | 
					type LoginRequest struct {
 | 
				
			||||||
	Username string `json:"username"`
 | 
						Username string `json:"username"`
 | 
				
			||||||
	Password string `json:"password"`
 | 
						Password string `json:"password"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// User is the struct for a user
 | 
				
			||||||
type User struct {
 | 
					type User struct {
 | 
				
			||||||
	Username string
 | 
						Username string
 | 
				
			||||||
	Password string
 | 
						Password string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Users is a list of users
 | 
				
			||||||
type Users []User
 | 
					type Users []User
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Config is the configuration for the tinyauth server
 | 
				
			||||||
type Config struct {
 | 
					type Config struct {
 | 
				
			||||||
	Port                      int    `mapstructure:"port" validate:"required"`
 | 
						Port                      int    `mapstructure:"port" validate:"required"`
 | 
				
			||||||
	Address                   string `validate:"required,ip4_addr" mapstructure:"address"`
 | 
						Address                   string `validate:"required,ip4_addr" mapstructure:"address"`
 | 
				
			||||||
@@ -45,10 +50,11 @@ type Config struct {
 | 
				
			|||||||
	GenericUserURL            string `mapstructure:"generic-user-url"`
 | 
						GenericUserURL            string `mapstructure:"generic-user-url"`
 | 
				
			||||||
	DisableContinue           bool   `mapstructure:"disable-continue"`
 | 
						DisableContinue           bool   `mapstructure:"disable-continue"`
 | 
				
			||||||
	OAuthWhitelist            string `mapstructure:"oauth-whitelist"`
 | 
						OAuthWhitelist            string `mapstructure:"oauth-whitelist"`
 | 
				
			||||||
	CookieExpiry              int    `mapstructure:"cookie-expiry"`
 | 
						SessionExpiry             int    `mapstructure:"session-expiry"`
 | 
				
			||||||
	LogLevel                  int8   `mapstructure:"log-level" validate:"min=-1,max=5"`
 | 
						LogLevel                  int8   `mapstructure:"log-level" validate:"min=-1,max=5"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// UserContext is the context for the user
 | 
				
			||||||
type UserContext struct {
 | 
					type UserContext struct {
 | 
				
			||||||
	Username   string
 | 
						Username   string
 | 
				
			||||||
	IsLoggedIn bool
 | 
						IsLoggedIn bool
 | 
				
			||||||
@@ -56,6 +62,7 @@ type UserContext struct {
 | 
				
			|||||||
	Provider   string
 | 
						Provider   string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// APIConfig is the configuration for the API
 | 
				
			||||||
type APIConfig struct {
 | 
					type APIConfig struct {
 | 
				
			||||||
	Port            int
 | 
						Port            int
 | 
				
			||||||
	Address         string
 | 
						Address         string
 | 
				
			||||||
@@ -66,6 +73,7 @@ type APIConfig struct {
 | 
				
			|||||||
	DisableContinue bool
 | 
						DisableContinue bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// OAuthConfig is the configuration for the providers
 | 
				
			||||||
type OAuthConfig struct {
 | 
					type OAuthConfig struct {
 | 
				
			||||||
	GithubClientId        string
 | 
						GithubClientId        string
 | 
				
			||||||
	GithubClientSecret    string
 | 
						GithubClientSecret    string
 | 
				
			||||||
@@ -82,31 +90,42 @@ type OAuthConfig struct {
 | 
				
			|||||||
	AppURL                string
 | 
						AppURL                string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// OAuthRequest is the request for the OAuth endpoint
 | 
				
			||||||
type OAuthRequest struct {
 | 
					type OAuthRequest struct {
 | 
				
			||||||
	Provider string `uri:"provider" binding:"required"`
 | 
						Provider string `uri:"provider" binding:"required"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// OAuthProviders is the struct for the OAuth providers
 | 
				
			||||||
type OAuthProviders struct {
 | 
					type OAuthProviders struct {
 | 
				
			||||||
	Github    *oauth.OAuth
 | 
						Github    *oauth.OAuth
 | 
				
			||||||
	Google    *oauth.OAuth
 | 
						Google    *oauth.OAuth
 | 
				
			||||||
	Microsoft *oauth.OAuth
 | 
						Microsoft *oauth.OAuth
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// UnauthorizedQuery is the query parameters for the unauthorized endpoint
 | 
				
			||||||
type UnauthorizedQuery struct {
 | 
					type UnauthorizedQuery struct {
 | 
				
			||||||
	Username string `url:"username"`
 | 
						Username string `url:"username"`
 | 
				
			||||||
	Resource string `url:"resource"`
 | 
						Resource string `url:"resource"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SessionCookie is the cookie for the session (exculding the expiry)
 | 
				
			||||||
type SessionCookie struct {
 | 
					type SessionCookie struct {
 | 
				
			||||||
	Username string
 | 
						Username string
 | 
				
			||||||
	Provider string
 | 
						Provider string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TinyauthLabels is the labels for the tinyauth container
 | 
				
			||||||
type TinyauthLabels struct {
 | 
					type TinyauthLabels struct {
 | 
				
			||||||
	OAuthWhitelist []string
 | 
						OAuthWhitelist []string
 | 
				
			||||||
	Users          []string
 | 
						Users          []string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TailscaleQuery is the query parameters for the tailscale endpoint
 | 
				
			||||||
type TailscaleQuery struct {
 | 
					type TailscaleQuery struct {
 | 
				
			||||||
	Code int `url:"code"`
 | 
						Code int `url:"code"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Proxy is the uri parameters for the proxy endpoint
 | 
				
			||||||
 | 
					type Proxy struct {
 | 
				
			||||||
 | 
						Proxy string `uri:"proxy" binding:"required"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,20 +12,32 @@ import (
 | 
				
			|||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Parses a list of comma separated users in a struct
 | 
				
			||||||
func ParseUsers(users string) (types.Users, error) {
 | 
					func ParseUsers(users string) (types.Users, error) {
 | 
				
			||||||
	log.Debug().Msg("Parsing users")
 | 
						log.Debug().Msg("Parsing users")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create a new users struct
 | 
				
			||||||
	var usersParsed types.Users
 | 
						var usersParsed types.Users
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Split the users by comma
 | 
				
			||||||
	userList := strings.Split(users, ",")
 | 
						userList := strings.Split(users, ",")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there are any users
 | 
				
			||||||
	if len(userList) == 0 {
 | 
						if len(userList) == 0 {
 | 
				
			||||||
		return types.Users{}, errors.New("invalid user format")
 | 
							return types.Users{}, errors.New("invalid user format")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Loop through the users and split them by colon
 | 
				
			||||||
	for _, user := range userList {
 | 
						for _, user := range userList {
 | 
				
			||||||
 | 
							// Split the user by colon
 | 
				
			||||||
		userSplit := strings.Split(user, ":")
 | 
							userSplit := strings.Split(user, ":")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if the user is in the correct format
 | 
				
			||||||
		if len(userSplit) != 2 {
 | 
							if len(userSplit) != 2 {
 | 
				
			||||||
			return types.Users{}, errors.New("invalid user format")
 | 
								return types.Users{}, errors.New("invalid user format")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Append the user to the users struct
 | 
				
			||||||
		usersParsed = append(usersParsed, types.User{
 | 
							usersParsed = append(usersParsed, types.User{
 | 
				
			||||||
			Username: userSplit[0],
 | 
								Username: userSplit[0],
 | 
				
			||||||
			Password: userSplit[1],
 | 
								Password: userSplit[1],
 | 
				
			||||||
@@ -34,43 +46,61 @@ func ParseUsers(users string) (types.Users, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	log.Debug().Msg("Parsed users")
 | 
						log.Debug().Msg("Parsed users")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the users struct
 | 
				
			||||||
	return usersParsed, nil
 | 
						return usersParsed, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Root url parses parses a hostname and returns the root domain (e.g. sub1.sub2.domain.com -> domain.com)
 | 
				
			||||||
func GetRootURL(urlSrc string) (string, error) {
 | 
					func GetRootURL(urlSrc string) (string, error) {
 | 
				
			||||||
 | 
						// Make sure the url is valid
 | 
				
			||||||
	urlParsed, parseErr := url.Parse(urlSrc)
 | 
						urlParsed, parseErr := url.Parse(urlSrc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if parseErr != nil {
 | 
						if parseErr != nil {
 | 
				
			||||||
		return "", parseErr
 | 
							return "", parseErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Split the hostname by period
 | 
				
			||||||
	urlSplitted := strings.Split(urlParsed.Hostname(), ".")
 | 
						urlSplitted := strings.Split(urlParsed.Hostname(), ".")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get the last part of the url
 | 
				
			||||||
	urlFinal := strings.Join(urlSplitted[1:], ".")
 | 
						urlFinal := strings.Join(urlSplitted[1:], ".")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the root domain
 | 
				
			||||||
	return urlFinal, nil
 | 
						return urlFinal, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Reads a file and returns the contents
 | 
				
			||||||
func ReadFile(file string) (string, error) {
 | 
					func ReadFile(file string) (string, error) {
 | 
				
			||||||
 | 
						// Check if the file exists
 | 
				
			||||||
	_, statErr := os.Stat(file)
 | 
						_, statErr := os.Stat(file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if statErr != nil {
 | 
						if statErr != nil {
 | 
				
			||||||
		return "", statErr
 | 
							return "", statErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Read the file
 | 
				
			||||||
	data, readErr := os.ReadFile(file)
 | 
						data, readErr := os.ReadFile(file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if readErr != nil {
 | 
						if readErr != nil {
 | 
				
			||||||
		return "", readErr
 | 
							return "", readErr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the file contents
 | 
				
			||||||
	return string(data), nil
 | 
						return string(data), nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Parses a file into a comma separated list of users
 | 
				
			||||||
func ParseFileToLine(content string) string {
 | 
					func ParseFileToLine(content string) string {
 | 
				
			||||||
 | 
						// Split the content by newline
 | 
				
			||||||
	lines := strings.Split(content, "\n")
 | 
						lines := strings.Split(content, "\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create a list of users
 | 
				
			||||||
	users := make([]string, 0)
 | 
						users := make([]string, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Loop through the lines, trimming the whitespace and appending to the users list
 | 
				
			||||||
	for _, line := range lines {
 | 
						for _, line := range lines {
 | 
				
			||||||
		if strings.TrimSpace(line) == "" {
 | 
							if strings.TrimSpace(line) == "" {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
@@ -79,63 +109,92 @@ func ParseFileToLine(content string) string {
 | 
				
			|||||||
		users = append(users, strings.TrimSpace(line))
 | 
							users = append(users, strings.TrimSpace(line))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the users as a comma separated string
 | 
				
			||||||
	return strings.Join(users, ",")
 | 
						return strings.Join(users, ",")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Get the secret from the config or file
 | 
				
			||||||
func GetSecret(conf string, file string) string {
 | 
					func GetSecret(conf string, file string) string {
 | 
				
			||||||
 | 
						// If neither the config or file is set, return an empty string
 | 
				
			||||||
	if conf == "" && file == "" {
 | 
						if conf == "" && file == "" {
 | 
				
			||||||
		return ""
 | 
							return ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If the config is set, return the config (environment variable)
 | 
				
			||||||
	if conf != "" {
 | 
						if conf != "" {
 | 
				
			||||||
		return conf
 | 
							return conf
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If the file is set, read the file
 | 
				
			||||||
	contents, err := ReadFile(file)
 | 
						contents, err := ReadFile(file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check if there was an error
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return ""
 | 
							return ""
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the contents of the file
 | 
				
			||||||
	return contents
 | 
						return contents
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Get the users from the config or file
 | 
				
			||||||
func GetUsers(conf string, file string) (types.Users, error) {
 | 
					func GetUsers(conf string, file string) (types.Users, error) {
 | 
				
			||||||
 | 
						// Create a string to store the users
 | 
				
			||||||
	var users string
 | 
						var users string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If neither the config or file is set, return an empty users struct
 | 
				
			||||||
	if conf == "" && file == "" {
 | 
						if conf == "" && file == "" {
 | 
				
			||||||
		return types.Users{}, errors.New("no users provided")
 | 
							return types.Users{}, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If the config (environment) is set, append the users to the users string
 | 
				
			||||||
	if conf != "" {
 | 
						if conf != "" {
 | 
				
			||||||
		log.Debug().Msg("Using users from config")
 | 
							log.Debug().Msg("Using users from config")
 | 
				
			||||||
		users += conf
 | 
							users += conf
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If the file is set, read the file and append the users to the users string
 | 
				
			||||||
	if file != "" {
 | 
						if file != "" {
 | 
				
			||||||
 | 
							// Read the file
 | 
				
			||||||
		fileContents, fileErr := ReadFile(file)
 | 
							fileContents, fileErr := ReadFile(file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// If there isn't an error we can append the users to the users string
 | 
				
			||||||
		if fileErr == nil {
 | 
							if fileErr == nil {
 | 
				
			||||||
			log.Debug().Msg("Using users from file")
 | 
								log.Debug().Msg("Using users from file")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Append the users to the users string
 | 
				
			||||||
			if users != "" {
 | 
								if users != "" {
 | 
				
			||||||
				users += ","
 | 
									users += ","
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Parse the file contents into a comma separated list of users
 | 
				
			||||||
			users += ParseFileToLine(fileContents)
 | 
								users += ParseFileToLine(fileContents)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the parsed users
 | 
				
			||||||
	return ParseUsers(users)
 | 
						return ParseUsers(users)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Check if any of the OAuth providers are configured based on the client id and secret
 | 
				
			||||||
func OAuthConfigured(config types.Config) bool {
 | 
					func OAuthConfigured(config types.Config) bool {
 | 
				
			||||||
	return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "")
 | 
						return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "") || (config.TailscaleClientId != "" && config.TailscaleClientSecret != "")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Parse the docker labels to the tinyauth labels struct
 | 
				
			||||||
func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels {
 | 
					func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels {
 | 
				
			||||||
 | 
						// Create a new tinyauth labels struct
 | 
				
			||||||
	var tinyauthLabels types.TinyauthLabels
 | 
						var tinyauthLabels types.TinyauthLabels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Loop through the labels
 | 
				
			||||||
	for label, value := range labels {
 | 
						for label, value := range labels {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Check if the label is in the tinyauth labels
 | 
				
			||||||
		if slices.Contains(constants.TinyauthLabels, label) {
 | 
							if slices.Contains(constants.TinyauthLabels, label) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			log.Debug().Str("label", label).Msg("Found label")
 | 
								log.Debug().Str("label", label).Msg("Found label")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Add the label value to the tinyauth labels struct
 | 
				
			||||||
			switch label {
 | 
								switch label {
 | 
				
			||||||
			case "tinyauth.oauth.whitelist":
 | 
								case "tinyauth.oauth.whitelist":
 | 
				
			||||||
				tinyauthLabels.OAuthWhitelist = strings.Split(value, ",")
 | 
									tinyauthLabels.OAuthWhitelist = strings.Split(value, ",")
 | 
				
			||||||
@@ -144,5 +203,7 @@ func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Return the tinyauth labels
 | 
				
			||||||
	return tinyauthLabels
 | 
						return tinyauthLabels
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user