mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-10-31 06:05:43 +00:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			97639ae903
			...
			refactor/e
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 3745394c8c | 
| @@ -117,8 +117,8 @@ var rootCmd = &cobra.Command{ | |||||||
| 		docker := docker.NewDocker() | 		docker := docker.NewDocker() | ||||||
|  |  | ||||||
| 		// Initialize docker | 		// Initialize docker | ||||||
| 		dockerErr := docker.Init() | 		err = docker.Init() | ||||||
| 		HandleError(dockerErr, "Failed to initialize docker") | 		HandleError(err, "Failed to initialize docker") | ||||||
|  |  | ||||||
| 		// Create auth service | 		// Create auth service | ||||||
| 		auth := auth.NewAuth(docker, users, oauthWhitelist, config.SessionExpiry) | 		auth := auth.NewAuth(docker, users, oauthWhitelist, config.SessionExpiry) | ||||||
|   | |||||||
| @@ -18,7 +18,7 @@ import ( | |||||||
| // Interactive flag | // Interactive flag | ||||||
| var interactive bool | var interactive bool | ||||||
|  |  | ||||||
| // i stands for input | // Input user | ||||||
| var iUser string | var iUser string | ||||||
|  |  | ||||||
| var GenerateCmd = &cobra.Command{ | var GenerateCmd = &cobra.Command{ | ||||||
| @@ -46,18 +46,18 @@ var GenerateCmd = &cobra.Command{ | |||||||
| 			) | 			) | ||||||
|  |  | ||||||
| 			// Run form | 			// Run form | ||||||
| 			formErr := form.WithTheme(baseTheme).Run() | 			err := form.WithTheme(baseTheme).Run() | ||||||
|  |  | ||||||
| 			if formErr != nil { | 			if err != nil { | ||||||
| 				log.Fatal().Err(formErr).Msg("Form failed") | 				log.Fatal().Err(err).Msg("Form failed") | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Parse user | 		// Parse user | ||||||
| 		user, parseErr := utils.ParseUser(iUser) | 		user, err := utils.ParseUser(iUser) | ||||||
|  |  | ||||||
| 		if parseErr != nil { | 		if err != nil { | ||||||
| 			log.Fatal().Err(parseErr).Msg("Failed to parse user") | 			log.Fatal().Err(err).Msg("Failed to parse user") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Check if user was using docker escape | 		// Check if user was using docker escape | ||||||
| @@ -73,13 +73,13 @@ var GenerateCmd = &cobra.Command{ | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Generate totp secret | 		// Generate totp secret | ||||||
| 		key, keyErr := totp.Generate(totp.GenerateOpts{ | 		key, err := totp.Generate(totp.GenerateOpts{ | ||||||
| 			Issuer:      "Tinyauth", | 			Issuer:      "Tinyauth", | ||||||
| 			AccountName: user.Username, | 			AccountName: user.Username, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
| 		if keyErr != nil { | 		if err != nil { | ||||||
| 			log.Fatal().Err(keyErr).Msg("Failed to generate totp secret") | 			log.Fatal().Err(err).Msg("Failed to generate totp secret") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Create secret | 		// Create secret | ||||||
|   | |||||||
| @@ -12,7 +12,10 @@ import ( | |||||||
| 	"golang.org/x/crypto/bcrypt" | 	"golang.org/x/crypto/bcrypt" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // Interactive flag | ||||||
| var interactive bool | var interactive bool | ||||||
|  |  | ||||||
|  | // Docker flag | ||||||
| var docker bool | var docker bool | ||||||
|  |  | ||||||
| // i stands for input | // i stands for input | ||||||
| @@ -51,10 +54,10 @@ var CreateCmd = &cobra.Command{ | |||||||
| 			// Use simple theme | 			// Use simple theme | ||||||
| 			var baseTheme *huh.Theme = huh.ThemeBase() | 			var baseTheme *huh.Theme = huh.ThemeBase() | ||||||
|  |  | ||||||
| 			formErr := form.WithTheme(baseTheme).Run() | 			err := form.WithTheme(baseTheme).Run() | ||||||
|  |  | ||||||
| 			if formErr != nil { | 			if err != nil { | ||||||
| 				log.Fatal().Err(formErr).Msg("Form failed") | 				log.Fatal().Err(err).Msg("Form failed") | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| @@ -66,10 +69,10 @@ var CreateCmd = &cobra.Command{ | |||||||
| 		log.Info().Str("username", iUsername).Str("password", iPassword).Bool("docker", docker).Msg("Creating user") | 		log.Info().Str("username", iUsername).Str("password", iPassword).Bool("docker", docker).Msg("Creating user") | ||||||
|  |  | ||||||
| 		// Hash password | 		// Hash password | ||||||
| 		password, passwordErr := bcrypt.GenerateFromPassword([]byte(iPassword), bcrypt.DefaultCost) | 		password, err := bcrypt.GenerateFromPassword([]byte(iPassword), bcrypt.DefaultCost) | ||||||
|  |  | ||||||
| 		if passwordErr != nil { | 		if err != nil { | ||||||
| 			log.Fatal().Err(passwordErr).Msg("Failed to hash password") | 			log.Fatal().Err(err).Msg("Failed to hash password") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Convert password to string | 		// Convert password to string | ||||||
|   | |||||||
| @@ -12,7 +12,10 @@ import ( | |||||||
| 	"golang.org/x/crypto/bcrypt" | 	"golang.org/x/crypto/bcrypt" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // Interactive flag | ||||||
| var interactive bool | var interactive bool | ||||||
|  |  | ||||||
|  | // Docker flag | ||||||
| var docker bool | var docker bool | ||||||
|  |  | ||||||
| // i stands for input | // i stands for input | ||||||
| @@ -60,18 +63,18 @@ var VerifyCmd = &cobra.Command{ | |||||||
| 			) | 			) | ||||||
|  |  | ||||||
| 			// Run form | 			// Run form | ||||||
| 			formErr := form.WithTheme(baseTheme).Run() | 			err := form.WithTheme(baseTheme).Run() | ||||||
|  |  | ||||||
| 			if formErr != nil { | 			if err != nil { | ||||||
| 				log.Fatal().Err(formErr).Msg("Form failed") | 				log.Fatal().Err(err).Msg("Form failed") | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Parse user | 		// Parse user | ||||||
| 		user, userErr := utils.ParseUser(iUser) | 		user, err := utils.ParseUser(iUser) | ||||||
|  |  | ||||||
| 		if userErr != nil { | 		if err != nil { | ||||||
| 			log.Fatal().Err(userErr).Msg("Failed to parse user") | 			log.Fatal().Err(err).Msg("Failed to parse user") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Compare username | 		// Compare username | ||||||
| @@ -80,9 +83,9 @@ var VerifyCmd = &cobra.Command{ | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Compare password | 		// Compare password | ||||||
| 		verifyErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(iPassword)) | 		err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(iPassword)) | ||||||
|  |  | ||||||
| 		if verifyErr != nil { | 		if err != nil { | ||||||
| 			log.Fatal().Msg("Ppassword is incorrect") | 			log.Fatal().Msg("Ppassword is incorrect") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| @@ -96,9 +99,9 @@ var VerifyCmd = &cobra.Command{ | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Check totp code | 		// Check totp code | ||||||
| 		totpOk := totp.Validate(iTotp, user.TotpSecret) | 		ok := totp.Validate(iTotp, user.TotpSecret) | ||||||
|  |  | ||||||
| 		if !totpOk { | 		if !ok { | ||||||
| 			log.Fatal().Msg("Totp code incorrect") | 			log.Fatal().Msg("Totp code incorrect") | ||||||
|  |  | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -53,11 +53,11 @@ func getAPI(t *testing.T) *api.API { | |||||||
| 	docker := docker.NewDocker() | 	docker := docker.NewDocker() | ||||||
|  |  | ||||||
| 	// Initialize docker | 	// Initialize docker | ||||||
| 	dockerErr := docker.Init() | 	err := docker.Init() | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if dockerErr != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Failed to initialize docker: %v", dockerErr) | 		t.Fatalf("Failed to initialize docker: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Create auth service | 	// Create auth service | ||||||
| @@ -167,21 +167,21 @@ func TestAppContext(t *testing.T) { | |||||||
| 	assert.Equal(t, recorder.Code, http.StatusOK) | 	assert.Equal(t, recorder.Code, http.StatusOK) | ||||||
|  |  | ||||||
| 	// Read the body of the response | 	// Read the body of the response | ||||||
| 	body, bodyErr := io.ReadAll(recorder.Body) | 	body, err := io.ReadAll(recorder.Body) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if bodyErr != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Error getting body: %v", bodyErr) | 		t.Fatalf("Error getting body: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Unmarshal the body into the user struct | 	// Unmarshal the body into the user struct | ||||||
| 	var app types.AppContext | 	var app types.AppContext | ||||||
|  |  | ||||||
| 	jsonErr := json.Unmarshal(body, &app) | 	err = json.Unmarshal(body, &app) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if jsonErr != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Error unmarshalling body: %v", jsonErr) | 		t.Fatalf("Error unmarshalling body: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Create tests values | 	// Create tests values | ||||||
| @@ -231,11 +231,11 @@ func TestUserContext(t *testing.T) { | |||||||
| 	assert.Equal(t, recorder.Code, http.StatusOK) | 	assert.Equal(t, recorder.Code, http.StatusOK) | ||||||
|  |  | ||||||
| 	// Read the body of the response | 	// Read the body of the response | ||||||
| 	body, bodyErr := io.ReadAll(recorder.Body) | 	body, err := io.ReadAll(recorder.Body) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if bodyErr != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Error getting body: %v", bodyErr) | 		t.Fatalf("Error getting body: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Unmarshal the body into the user struct | 	// Unmarshal the body into the user struct | ||||||
| @@ -245,11 +245,11 @@ func TestUserContext(t *testing.T) { | |||||||
|  |  | ||||||
| 	var user User | 	var user User | ||||||
|  |  | ||||||
| 	jsonErr := json.Unmarshal(body, &user) | 	err = json.Unmarshal(body, &user) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if jsonErr != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Error unmarshalling body: %v", jsonErr) | 		t.Fatalf("Error unmarshalling body: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// We should get the username back | 	// We should get the username back | ||||||
|   | |||||||
| @@ -160,7 +160,7 @@ func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext) (bo | |||||||
| 	appId := strings.Split(host, ".")[0] | 	appId := strings.Split(host, ".")[0] | ||||||
|  |  | ||||||
| 	// Check if resource is allowed | 	// Check if resource is allowed | ||||||
| 	allowed, allowedErr := auth.Docker.ContainerAction(appId, func(labels types.TinyauthLabels) (bool, error) { | 	allowed, err := auth.Docker.ContainerAction(appId, func(labels types.TinyauthLabels) (bool, error) { | ||||||
| 		// If the container has an oauth whitelist, check if the user is in it | 		// If the container has an oauth whitelist, check if the user is in it | ||||||
| 		if context.OAuth { | 		if context.OAuth { | ||||||
| 			if len(labels.OAuthWhitelist) == 0 { | 			if len(labels.OAuthWhitelist) == 0 { | ||||||
| @@ -187,9 +187,9 @@ func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext) (bo | |||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	// If there is an error, return false | 	// If there is an error, return false | ||||||
| 	if allowedErr != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(allowedErr).Msg("Error checking if resource is allowed") | 		log.Error().Err(err).Msg("Error checking if resource is allowed") | ||||||
| 		return false, allowedErr | 		return false, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Return if the resource is allowed | 	// Return if the resource is allowed | ||||||
| @@ -205,7 +205,7 @@ func (auth *Auth) AuthEnabled(c *gin.Context) (bool, error) { | |||||||
| 	appId := strings.Split(host, ".")[0] | 	appId := strings.Split(host, ".")[0] | ||||||
|  |  | ||||||
| 	// Check if auth is enabled | 	// Check if auth is enabled | ||||||
| 	enabled, enabledErr := auth.Docker.ContainerAction(appId, func(labels types.TinyauthLabels) (bool, error) { | 	enabled, err := auth.Docker.ContainerAction(appId, func(labels types.TinyauthLabels) (bool, error) { | ||||||
| 		// Check if the allowed label is empty | 		// Check if the allowed label is empty | ||||||
| 		if labels.Allowed == "" { | 		if labels.Allowed == "" { | ||||||
| 			// Auth enabled | 			// Auth enabled | ||||||
| @@ -213,12 +213,12 @@ func (auth *Auth) AuthEnabled(c *gin.Context) (bool, error) { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Compile regex | 		// Compile regex | ||||||
| 		regex, regexErr := regexp.Compile(labels.Allowed) | 		regex, err := regexp.Compile(labels.Allowed) | ||||||
|  |  | ||||||
| 		// If there is an error, invalid regex, auth enabled | 		// If there is an error, invalid regex, auth enabled | ||||||
| 		if regexErr != nil { | 		if err != nil { | ||||||
| 			log.Warn().Err(regexErr).Msg("Invalid regex") | 			log.Warn().Err(err).Msg("Invalid regex") | ||||||
| 			return true, regexErr | 			return true, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Check if the uri matches the regex | 		// Check if the uri matches the regex | ||||||
| @@ -232,9 +232,9 @@ func (auth *Auth) AuthEnabled(c *gin.Context) (bool, error) { | |||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	// If there is an error, auth enabled | 	// If there is an error, auth enabled | ||||||
| 	if enabledErr != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(enabledErr).Msg("Error checking if auth is enabled") | 		log.Error().Err(err).Msg("Error checking if auth is enabled") | ||||||
| 		return true, enabledErr | 		return true, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return enabled, nil | 	return enabled, nil | ||||||
|   | |||||||
| @@ -23,7 +23,7 @@ type Docker struct { | |||||||
|  |  | ||||||
| func (docker *Docker) Init() error { | func (docker *Docker) Init() error { | ||||||
| 	// Create a new docker client | 	// Create a new docker client | ||||||
| 	apiClient, err := client.NewClientWithOpts(client.FromEnv) | 	client, err := client.NewClientWithOpts(client.FromEnv) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -32,7 +32,7 @@ func (docker *Docker) Init() error { | |||||||
|  |  | ||||||
| 	// Set the context and api client | 	// Set the context and api client | ||||||
| 	docker.Context = context.Background() | 	docker.Context = context.Background() | ||||||
| 	docker.Client = apiClient | 	docker.Client = client | ||||||
|  |  | ||||||
| 	// Done | 	// Done | ||||||
| 	return nil | 	return nil | ||||||
| @@ -81,11 +81,11 @@ func (docker *Docker) ContainerAction(appId string, runCheck func(labels appType | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Get the containers | 	// Get the containers | ||||||
| 	containers, containersErr := docker.GetContainers() | 	containers, err := docker.GetContainers() | ||||||
|  |  | ||||||
| 	// If there is an error, return false | 	// If there is an error, return false | ||||||
| 	if containersErr != nil { | 	if err != nil { | ||||||
| 		return false, containersErr | 		return false, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got containers") | 	log.Debug().Msg("Got containers") | ||||||
| @@ -93,11 +93,11 @@ func (docker *Docker) ContainerAction(appId string, runCheck func(labels appType | |||||||
| 	// Loop through the containers | 	// Loop through the containers | ||||||
| 	for _, container := range containers { | 	for _, container := range containers { | ||||||
| 		// Inspect the container | 		// Inspect the container | ||||||
| 		inspect, inspectErr := docker.InspectContainer(container.ID) | 		inspect, err := docker.InspectContainer(container.ID) | ||||||
|  |  | ||||||
| 		// If there is an error, return false | 		// If there is an error, return false | ||||||
| 		if inspectErr != nil { | 		if err != nil { | ||||||
| 			return false, inspectErr | 			return false, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Get the container name (for some reason it is /name) | 		// Get the container name (for some reason it is /name) | ||||||
|   | |||||||
| @@ -144,7 +144,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
|  |  | ||||||
| 			// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) | 			// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				log.Error().Err(err).Msg("Failed to build query") | 				log.Error().Err(err).Msg("Failed to build queries") | ||||||
| 				c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | 				c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| @@ -184,7 +184,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) { | |||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(err).Msg("Failed to build query") | 		log.Error().Err(err).Msg("Failed to build queries") | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -323,10 +323,10 @@ func (h *Handlers) TotpHandler(c *gin.Context) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Check if totp is correct | 	// Check if totp is correct | ||||||
| 	totpOk := totp.Validate(totpReq.Code, user.TotpSecret) | 	ok := totp.Validate(totpReq.Code, user.TotpSecret) | ||||||
|  |  | ||||||
| 	// TOTP is incorrect | 	// TOTP is incorrect | ||||||
| 	if !totpOk { | 	if !ok { | ||||||
| 		log.Debug().Msg("Totp incorrect") | 		log.Debug().Msg("Totp incorrect") | ||||||
| 		c.JSON(401, gin.H{ | 		c.JSON(401, gin.H{ | ||||||
| 			"status":  401, | 			"status":  401, | ||||||
| @@ -473,13 +473,13 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) { | |||||||
| 	// Tailscale does not have an auth url so we create a random code (does not need to be secure) to avoid caching and send it | 	// Tailscale does not have an auth url so we create a random code (does not need to be secure) to avoid caching and send it | ||||||
| 	if request.Provider == "tailscale" { | 	if request.Provider == "tailscale" { | ||||||
| 		// Build tailscale query | 		// Build tailscale query | ||||||
| 		tailscaleQuery, err := query.Values(types.TailscaleQuery{ | 		queries, err := query.Values(types.TailscaleQuery{ | ||||||
| 			Code: (1000 + rand.IntN(9000)), | 			Code: (1000 + rand.IntN(9000)), | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
| 		// Handle error | 		// Handle error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error().Err(err).Msg("Failed to build query") | 			log.Error().Err(err).Msg("Failed to build queries") | ||||||
| 			c.JSON(500, gin.H{ | 			c.JSON(500, gin.H{ | ||||||
| 				"status":  500, | 				"status":  500, | ||||||
| 				"message": "Internal Server Error", | 				"message": "Internal Server Error", | ||||||
| @@ -491,7 +491,7 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) { | |||||||
| 		c.JSON(200, gin.H{ | 		c.JSON(200, gin.H{ | ||||||
| 			"status":  200, | 			"status":  200, | ||||||
| 			"message": "OK", | 			"message": "OK", | ||||||
| 			"url":     fmt.Sprintf("%s/api/oauth/callback/tailscale?%s", h.Config.AppURL, tailscaleQuery.Encode()), | 			"url":     fmt.Sprintf("%s/api/oauth/callback/tailscale?%s", h.Config.AppURL, queries.Encode()), | ||||||
| 		}) | 		}) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -572,19 +572,19 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { | |||||||
| 		log.Warn().Str("email", email).Msg("Email not whitelisted") | 		log.Warn().Str("email", email).Msg("Email not whitelisted") | ||||||
|  |  | ||||||
| 		// Build query | 		// Build query | ||||||
| 		unauthorizedQuery, err := query.Values(types.UnauthorizedQuery{ | 		queries, err := query.Values(types.UnauthorizedQuery{ | ||||||
| 			Username: email, | 			Username: email, | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
| 		// Handle error | 		// Handle error | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Error().Msg("Failed to build query") | 			log.Error().Msg("Failed to build queries") | ||||||
| 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | 			c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Redirect to unauthorized | 		// Redirect to unauthorized | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, unauthorizedQuery.Encode())) | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode())) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Email whitelisted") | 	log.Debug().Msg("Email whitelisted") | ||||||
| @@ -596,10 +596,10 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { | |||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	// Get redirect URI | 	// Get redirect URI | ||||||
| 	redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") | 	redirectURI, err := c.Cookie("tinyauth_redirect_uri") | ||||||
|  |  | ||||||
| 	// If it is empty it means that no redirect_uri was provided to the login screen so we just log in | 	// If it is empty it means that no redirect_uri was provided to the login screen so we just log in | ||||||
| 	if redirectURIErr != nil { | 	if err != nil { | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, h.Config.AppURL) | 		c.Redirect(http.StatusPermanentRedirect, h.Config.AppURL) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -609,7 +609,7 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { | |||||||
| 	c.SetCookie("tinyauth_redirect_uri", "", -1, "/", h.Config.Domain, h.Config.CookieSecure, true) | 	c.SetCookie("tinyauth_redirect_uri", "", -1, "/", h.Config.Domain, h.Config.CookieSecure, true) | ||||||
|  |  | ||||||
| 	// Build query | 	// Build query | ||||||
| 	redirectQuery, err := query.Values(types.LoginQuery{ | 	queries, err := query.Values(types.LoginQuery{ | ||||||
| 		RedirectURI: redirectURI, | 		RedirectURI: redirectURI, | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| @@ -617,13 +617,13 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) { | |||||||
|  |  | ||||||
| 	// Handle error | 	// Handle error | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Msg("Failed to build query") | 		log.Error().Msg("Failed to build queries") | ||||||
| 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | 		c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL)) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Redirect to continue with the redirect URI | 	// Redirect to continue with the redirect URI | ||||||
| 	c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", h.Config.AppURL, redirectQuery.Encode())) | 	c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", h.Config.AppURL, queries.Encode())) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (h *Handlers) HealthcheckHandler(c *gin.Context) { | func (h *Handlers) HealthcheckHandler(c *gin.Context) { | ||||||
|   | |||||||
| @@ -15,21 +15,21 @@ type GenericUserInfoResponse struct { | |||||||
|  |  | ||||||
| func GetGenericEmail(client *http.Client, url string) (string, error) { | func GetGenericEmail(client *http.Client, url string) (string, error) { | ||||||
| 	// Using the oauth client get the user info url | 	// Using the oauth client get the user info url | ||||||
| 	res, resErr := client.Get(url) | 	res, err := client.Get(url) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if resErr != nil { | 	if err != nil { | ||||||
| 		return "", resErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from generic provider") | 	log.Debug().Msg("Got response from generic provider") | ||||||
|  |  | ||||||
| 	// Read the body of the response | 	// Read the body of the response | ||||||
| 	body, bodyErr := io.ReadAll(res.Body) | 	body, err := io.ReadAll(res.Body) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if bodyErr != nil { | 	if err != nil { | ||||||
| 		return "", bodyErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from generic provider") | 	log.Debug().Msg("Read body from generic provider") | ||||||
| @@ -38,11 +38,11 @@ func GetGenericEmail(client *http.Client, url string) (string, error) { | |||||||
| 	var user GenericUserInfoResponse | 	var user GenericUserInfoResponse | ||||||
|  |  | ||||||
| 	// Unmarshal the body into the user struct | 	// Unmarshal the body into the user struct | ||||||
| 	jsonErr := json.Unmarshal(body, &user) | 	err = json.Unmarshal(body, &user) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if jsonErr != nil { | 	if err != nil { | ||||||
| 		return "", jsonErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed user from generic provider") | 	log.Debug().Msg("Parsed user from generic provider") | ||||||
|   | |||||||
| @@ -22,21 +22,21 @@ func GithubScopes() []string { | |||||||
|  |  | ||||||
| func GetGithubEmail(client *http.Client) (string, error) { | func GetGithubEmail(client *http.Client) (string, error) { | ||||||
| 	// Get the user emails from github using the oauth http client | 	// Get the user emails from github using the oauth http client | ||||||
| 	res, resErr := client.Get("https://api.github.com/user/emails") | 	res, err := client.Get("https://api.github.com/user/emails") | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if resErr != nil { | 	if err != nil { | ||||||
| 		return "", resErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from github") | 	log.Debug().Msg("Got response from github") | ||||||
|  |  | ||||||
| 	// Read the body of the response | 	// Read the body of the response | ||||||
| 	body, bodyErr := io.ReadAll(res.Body) | 	body, err := io.ReadAll(res.Body) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if bodyErr != nil { | 	if err != nil { | ||||||
| 		return "", bodyErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from github") | 	log.Debug().Msg("Read body from github") | ||||||
| @@ -45,11 +45,11 @@ func GetGithubEmail(client *http.Client) (string, error) { | |||||||
| 	var emails GithubUserInfoResponse | 	var emails GithubUserInfoResponse | ||||||
|  |  | ||||||
| 	// Unmarshal the body into the user struct | 	// Unmarshal the body into the user struct | ||||||
| 	jsonErr := json.Unmarshal(body, &emails) | 	err = json.Unmarshal(body, &emails) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if jsonErr != nil { | 	if err != nil { | ||||||
| 		return "", jsonErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed emails from github") | 	log.Debug().Msg("Parsed emails from github") | ||||||
|   | |||||||
| @@ -20,21 +20,21 @@ func GoogleScopes() []string { | |||||||
|  |  | ||||||
| func GetGoogleEmail(client *http.Client) (string, error) { | func GetGoogleEmail(client *http.Client) (string, error) { | ||||||
| 	// Get the user info from google using the oauth http client | 	// Get the user info from google using the oauth http client | ||||||
| 	res, resErr := client.Get("https://www.googleapis.com/userinfo/v2/me") | 	res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if resErr != nil { | 	if err != nil { | ||||||
| 		return "", resErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from google") | 	log.Debug().Msg("Got response from google") | ||||||
|  |  | ||||||
| 	// Read the body of the response | 	// Read the body of the response | ||||||
| 	body, bodyErr := io.ReadAll(res.Body) | 	body, err := io.ReadAll(res.Body) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if bodyErr != nil { | 	if err != nil { | ||||||
| 		return "", bodyErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from google") | 	log.Debug().Msg("Read body from google") | ||||||
| @@ -43,11 +43,11 @@ func GetGoogleEmail(client *http.Client) (string, error) { | |||||||
| 	var user GoogleUserInfoResponse | 	var user GoogleUserInfoResponse | ||||||
|  |  | ||||||
| 	// Unmarshal the body into the user struct | 	// Unmarshal the body into the user struct | ||||||
| 	jsonErr := json.Unmarshal(body, &user) | 	err = json.Unmarshal(body, &user) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if jsonErr != nil { | 	if err != nil { | ||||||
| 		return "", jsonErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed user from google") | 	log.Debug().Msg("Parsed user from google") | ||||||
|   | |||||||
| @@ -128,11 +128,11 @@ func (providers *Providers) GetUser(provider string) (string, error) { | |||||||
| 		log.Debug().Msg("Got client from github") | 		log.Debug().Msg("Got client from github") | ||||||
|  |  | ||||||
| 		// Get the email from the github provider | 		// Get the email from the github provider | ||||||
| 		email, emailErr := GetGithubEmail(client) | 		email, err := GetGithubEmail(client) | ||||||
|  |  | ||||||
| 		// Check if there was an error | 		// Check if there was an error | ||||||
| 		if emailErr != nil { | 		if err != nil { | ||||||
| 			return "", emailErr | 			return "", err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from github") | 		log.Debug().Msg("Got email from github") | ||||||
| @@ -152,11 +152,11 @@ func (providers *Providers) GetUser(provider string) (string, error) { | |||||||
| 		log.Debug().Msg("Got client from google") | 		log.Debug().Msg("Got client from google") | ||||||
|  |  | ||||||
| 		// Get the email from the google provider | 		// Get the email from the google provider | ||||||
| 		email, emailErr := GetGoogleEmail(client) | 		email, err := GetGoogleEmail(client) | ||||||
|  |  | ||||||
| 		// Check if there was an error | 		// Check if there was an error | ||||||
| 		if emailErr != nil { | 		if err != nil { | ||||||
| 			return "", emailErr | 			return "", err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from google") | 		log.Debug().Msg("Got email from google") | ||||||
| @@ -176,11 +176,11 @@ func (providers *Providers) GetUser(provider string) (string, error) { | |||||||
| 		log.Debug().Msg("Got client from tailscale") | 		log.Debug().Msg("Got client from tailscale") | ||||||
|  |  | ||||||
| 		// Get the email from the tailscale provider | 		// Get the email from the tailscale provider | ||||||
| 		email, emailErr := GetTailscaleEmail(client) | 		email, err := GetTailscaleEmail(client) | ||||||
|  |  | ||||||
| 		// Check if there was an error | 		// Check if there was an error | ||||||
| 		if emailErr != nil { | 		if err != nil { | ||||||
| 			return "", emailErr | 			return "", err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from tailscale") | 		log.Debug().Msg("Got email from tailscale") | ||||||
| @@ -200,11 +200,11 @@ func (providers *Providers) GetUser(provider string) (string, error) { | |||||||
| 		log.Debug().Msg("Got client from generic") | 		log.Debug().Msg("Got client from generic") | ||||||
|  |  | ||||||
| 		// Get the email from the generic provider | 		// Get the email from the generic provider | ||||||
| 		email, emailErr := GetGenericEmail(client, providers.Config.GenericUserURL) | 		email, err := GetGenericEmail(client, providers.Config.GenericUserURL) | ||||||
|  |  | ||||||
| 		// Check if there was an error | 		// Check if there was an error | ||||||
| 		if emailErr != nil { | 		if err != nil { | ||||||
| 			return "", emailErr | 			return "", err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got email from generic") | 		log.Debug().Msg("Got email from generic") | ||||||
|   | |||||||
| @@ -31,21 +31,21 @@ var TailscaleEndpoint = oauth2.Endpoint{ | |||||||
|  |  | ||||||
| func GetTailscaleEmail(client *http.Client) (string, error) { | func GetTailscaleEmail(client *http.Client) (string, error) { | ||||||
| 	// Get the user info from tailscale using the oauth http client | 	// Get the user info from tailscale using the oauth http client | ||||||
| 	res, resErr := client.Get("https://api.tailscale.com/api/v2/tailnet/-/users") | 	res, err := client.Get("https://api.tailscale.com/api/v2/tailnet/-/users") | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if resErr != nil { | 	if err != nil { | ||||||
| 		return "", resErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from tailscale") | 	log.Debug().Msg("Got response from tailscale") | ||||||
|  |  | ||||||
| 	// Read the body of the response | 	// Read the body of the response | ||||||
| 	body, bodyErr := io.ReadAll(res.Body) | 	body, err := io.ReadAll(res.Body) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if bodyErr != nil { | 	if err != nil { | ||||||
| 		return "", bodyErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from tailscale") | 	log.Debug().Msg("Read body from tailscale") | ||||||
| @@ -54,11 +54,11 @@ func GetTailscaleEmail(client *http.Client) (string, error) { | |||||||
| 	var users TailscaleUserInfoResponse | 	var users TailscaleUserInfoResponse | ||||||
|  |  | ||||||
| 	// Unmarshal the body into the user struct | 	// Unmarshal the body into the user struct | ||||||
| 	jsonErr := json.Unmarshal(body, &users) | 	err = json.Unmarshal(body, &users) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if jsonErr != nil { | 	if err != nil { | ||||||
| 		return "", jsonErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed users from tailscale") | 	log.Debug().Msg("Parsed users from tailscale") | ||||||
|   | |||||||
| @@ -29,11 +29,11 @@ func ParseUsers(users string) (types.Users, error) { | |||||||
|  |  | ||||||
| 	// Loop through the users and split them by colon | 	// Loop through the users and split them by colon | ||||||
| 	for _, user := range userList { | 	for _, user := range userList { | ||||||
| 		parsed, parseErr := ParseUser(user) | 		parsed, err := ParseUser(user) | ||||||
|  |  | ||||||
| 		// Check if there was an error | 		// Check if there was an error | ||||||
| 		if parseErr != nil { | 		if err != nil { | ||||||
| 			return types.Users{}, parseErr | 			return types.Users{}, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Append the user to the users struct | 		// Append the user to the users struct | ||||||
| @@ -69,19 +69,19 @@ func GetUpperDomain(urlSrc string) (string, error) { | |||||||
| // Reads a file and returns the contents | // Reads a file and returns the contents | ||||||
| func ReadFile(file string) (string, error) { | func ReadFile(file string) (string, error) { | ||||||
| 	// Check if the file exists | 	// Check if the file exists | ||||||
| 	_, statErr := os.Stat(file) | 	_, err := os.Stat(file) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if statErr != nil { | 	if err != nil { | ||||||
| 		return "", statErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Read the file | 	// Read the file | ||||||
| 	data, readErr := os.ReadFile(file) | 	data, err := os.ReadFile(file) | ||||||
|  |  | ||||||
| 	// Check if there was an error | 	// Check if there was an error | ||||||
| 	if readErr != nil { | 	if err != nil { | ||||||
| 		return "", readErr | 		return "", err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Return the file contents | 	// Return the file contents | ||||||
| @@ -152,10 +152,10 @@ func GetUsers(conf string, file string) (types.Users, error) { | |||||||
| 	// If the file is set, read the file and append the users to the users string | 	// If the file is set, read the file and append the users to the users string | ||||||
| 	if file != "" { | 	if file != "" { | ||||||
| 		// Read the file | 		// Read the file | ||||||
| 		fileContents, fileErr := ReadFile(file) | 		contents, err := ReadFile(file) | ||||||
|  |  | ||||||
| 		// If there isn't an error we can append the users to the users string | 		// If there isn't an error we can append the users to the users string | ||||||
| 		if fileErr == nil { | 		if err == nil { | ||||||
| 			log.Debug().Msg("Using users from file") | 			log.Debug().Msg("Using users from file") | ||||||
|  |  | ||||||
| 			// Append the users to the users string | 			// Append the users to the users string | ||||||
| @@ -164,7 +164,7 @@ func GetUsers(conf string, file string) (types.Users, error) { | |||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			// Parse the file contents into a comma separated list of users | 			// Parse the file contents into a comma separated list of users | ||||||
| 			users += ParseFileToLine(fileContents) | 			users += ParseFileToLine(contents) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -102,7 +102,7 @@ func TestParseFileToLine(t *testing.T) { | |||||||
| 	t.Log("Testing parse file to line with a valid string") | 	t.Log("Testing parse file to line with a valid string") | ||||||
|  |  | ||||||
| 	// Test the parse file to line function with a valid string | 	// Test the parse file to line function with a valid string | ||||||
| 	content := "user1:pass1\nuser2:pass2" | 	content := "\nuser1:pass1\nuser2:pass2\n" | ||||||
| 	expected := "user1:pass1,user2:pass2" | 	expected := "user1:pass1,user2:pass2" | ||||||
|  |  | ||||||
| 	result := utils.ParseFileToLine(content) | 	result := utils.ParseFileToLine(content) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user