mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-10-30 21:55:43 +00:00 
			
		
		
		
	Compare commits
	
		
			7 Commits
		
	
	
		
			v3.6.1
			...
			v3.6.2-bet
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 5854d973ea | ||
|   | f25ab72747 | ||
|   | 2233557990 | ||
|   | d3bec635f8 | ||
|   | 6519644fc1 | ||
|   | 736f65b7b2 | ||
|   | 63d39b5500 | 
| @@ -12,6 +12,7 @@ import { | |||||||
| } from "../ui/form"; | } from "../ui/form"; | ||||||
| import { Button } from "../ui/button"; | import { Button } from "../ui/button"; | ||||||
| import { loginSchema, LoginSchema } from "@/schemas/login-schema"; | import { loginSchema, LoginSchema } from "@/schemas/login-schema"; | ||||||
|  | import z from "zod"; | ||||||
|  |  | ||||||
| interface Props { | interface Props { | ||||||
|   onSubmit: (data: LoginSchema) => void; |   onSubmit: (data: LoginSchema) => void; | ||||||
| @@ -22,6 +23,11 @@ export const LoginForm = (props: Props) => { | |||||||
|   const { onSubmit, loading } = props; |   const { onSubmit, loading } = props; | ||||||
|   const { t } = useTranslation(); |   const { t } = useTranslation(); | ||||||
|  |  | ||||||
|  |   z.config({ | ||||||
|  |     customError: (iss) => | ||||||
|  |       iss.input === undefined ? t("fieldRequired") : t("invalidInput"), | ||||||
|  |   }); | ||||||
|  |  | ||||||
|   const form = useForm<LoginSchema>({ |   const form = useForm<LoginSchema>({ | ||||||
|     resolver: zodResolver(loginSchema), |     resolver: zodResolver(loginSchema), | ||||||
|   }); |   }); | ||||||
|   | |||||||
| @@ -8,6 +8,8 @@ import { | |||||||
| import { zodResolver } from "@hookform/resolvers/zod"; | import { zodResolver } from "@hookform/resolvers/zod"; | ||||||
| import { useForm } from "react-hook-form"; | import { useForm } from "react-hook-form"; | ||||||
| import { totpSchema, TotpSchema } from "@/schemas/totp-schema"; | import { totpSchema, TotpSchema } from "@/schemas/totp-schema"; | ||||||
|  | import { useTranslation } from "react-i18next"; | ||||||
|  | import z from "zod"; | ||||||
|  |  | ||||||
| interface Props { | interface Props { | ||||||
|   formId: string; |   formId: string; | ||||||
| @@ -17,6 +19,12 @@ interface Props { | |||||||
|  |  | ||||||
| export const TotpForm = (props: Props) => { | export const TotpForm = (props: Props) => { | ||||||
|   const { formId, onSubmit, loading } = props; |   const { formId, onSubmit, loading } = props; | ||||||
|  |   const { t } = useTranslation(); | ||||||
|  |  | ||||||
|  |   z.config({ | ||||||
|  |     customError: (iss) => | ||||||
|  |       iss.input === undefined ? t("fieldRequired") : t("invalidInput"), | ||||||
|  |   }); | ||||||
|  |  | ||||||
|   const form = useForm<TotpSchema>({ |   const form = useForm<TotpSchema>({ | ||||||
|     resolver: zodResolver(totpSchema), |     resolver: zodResolver(totpSchema), | ||||||
|   | |||||||
| @@ -51,5 +51,7 @@ | |||||||
|     "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", |     "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", | ||||||
|     "errorTitle": "An error occurred", |     "errorTitle": "An error occurred", | ||||||
|     "errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.", |     "errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.", | ||||||
|     "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable." |     "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.", | ||||||
|  |     "fieldRequired": "This field is required", | ||||||
|  |     "invalidInput": "Invalid input" | ||||||
| } | } | ||||||
| @@ -51,5 +51,7 @@ | |||||||
|     "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", |     "failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.", | ||||||
|     "errorTitle": "An error occurred", |     "errorTitle": "An error occurred", | ||||||
|     "errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.", |     "errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.", | ||||||
|     "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable." |     "forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.", | ||||||
|  |     "fieldRequired": "This field is required", | ||||||
|  |     "invalidInput": "Invalid input" | ||||||
| } | } | ||||||
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							| @@ -17,6 +17,7 @@ require ( | |||||||
|  |  | ||||||
| require ( | require ( | ||||||
| 	github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect | 	github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect | ||||||
|  | 	github.com/cenkalti/backoff/v5 v5.0.2 // indirect | ||||||
| 	github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect | 	github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect | ||||||
| 	github.com/charmbracelet/x/cellbuf v0.0.13 // indirect | 	github.com/charmbracelet/x/cellbuf v0.0.13 // indirect | ||||||
| 	github.com/containerd/errdefs v1.0.0 // indirect | 	github.com/containerd/errdefs v1.0.0 // indirect | ||||||
|   | |||||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							| @@ -26,6 +26,8 @@ github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY= | |||||||
| github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc= | github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc= | ||||||
| github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= | github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= | ||||||
| github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= | github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= | ||||||
|  | github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8= | ||||||
|  | github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= | ||||||
| github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= | github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= | ||||||
| github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= | github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= | ||||||
| github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI= | github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI= | ||||||
|   | |||||||
| @@ -50,7 +50,7 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { | |||||||
|  |  | ||||||
| 	// If there was an error getting the session, it might be invalid so let's clear it and retry | 	// If there was an error getting the session, it might be invalid so let's clear it and retry | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Warn().Err(err).Msg("Invalid session, clearing cookie and retrying") | 		log.Error().Err(err).Msg("Invalid session, clearing cookie and retrying") | ||||||
| 		c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true) | 		c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true) | ||||||
| 		session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName) | 		session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -79,7 +79,7 @@ func (auth *Auth) SearchUser(username string) types.UserSearch { | |||||||
| 		log.Debug().Str("username", username).Msg("Checking LDAP for user") | 		log.Debug().Str("username", username).Msg("Checking LDAP for user") | ||||||
| 		userDN, err := auth.LDAP.Search(username) | 		userDN, err := auth.LDAP.Search(username) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP") | 			log.Error().Err(err).Str("username", username).Msg("Failed to find user in LDAP") | ||||||
| 			return types.UserSearch{} | 			return types.UserSearch{} | ||||||
| 		} | 		} | ||||||
| 		return types.UserSearch{ | 		return types.UserSearch{ | ||||||
| @@ -88,7 +88,9 @@ func (auth *Auth) SearchUser(username string) types.UserSearch { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return types.UserSearch{} | 	return types.UserSearch{ | ||||||
|  | 		Type: "unknown", | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { | func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { | ||||||
| @@ -105,7 +107,7 @@ func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { | |||||||
|  |  | ||||||
| 			err := auth.LDAP.Bind(search.Username, password) | 			err := auth.LDAP.Bind(search.Username, password) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				log.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP") | 				log.Error().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP") | ||||||
| 				return false | 				return false | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| @@ -370,7 +372,7 @@ func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) { | |||||||
|  |  | ||||||
| 	// If there is an error, invalid regex, auth enabled | 	// If there is an error, invalid regex, auth enabled | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Warn().Err(err).Msg("Invalid regex") | 		log.Error().Err(err).Msg("Invalid regex") | ||||||
| 		return true, err | 		return true, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -399,7 +401,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { | |||||||
| 	for _, blocked := range labels.IP.Block { | 	for _, blocked := range labels.IP.Block { | ||||||
| 		res, err := utils.FilterIP(blocked, ip) | 		res, err := utils.FilterIP(blocked, ip) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") | 			log.Error().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		if res { | 		if res { | ||||||
| @@ -412,7 +414,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { | |||||||
| 	for _, allowed := range labels.IP.Allow { | 	for _, allowed := range labels.IP.Allow { | ||||||
| 		res, err := utils.FilterIP(allowed, ip) | 		res, err := utils.FilterIP(allowed, ip) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") | 			log.Error().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		if res { | 		if res { | ||||||
| @@ -436,7 +438,7 @@ func (auth *Auth) BypassedIP(labels types.Labels, ip string) bool { | |||||||
| 	for _, bypassed := range labels.IP.Bypass { | 	for _, bypassed := range labels.IP.Bypass { | ||||||
| 		res, err := utils.FilterIP(bypassed, ip) | 		res, err := utils.FilterIP(bypassed, ip) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") | 			log.Error().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		if res { | 		if res { | ||||||
|   | |||||||
| @@ -5,7 +5,7 @@ type Claims struct { | |||||||
| 	Name              string `json:"name"` | 	Name              string `json:"name"` | ||||||
| 	Email             string `json:"email"` | 	Email             string `json:"email"` | ||||||
| 	PreferredUsername string `json:"preferred_username"` | 	PreferredUsername string `json:"preferred_username"` | ||||||
| 	Groups            []string `json:"groups"` | 	Groups            any    `json:"groups"` | ||||||
| } | } | ||||||
|  |  | ||||||
| // Version information | // Version information | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| package server_test | package handlers_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| @@ -189,7 +189,7 @@ func (h *Handlers) OAuthCallbackHandler(c *gin.Context) { | |||||||
| 		Name:        name, | 		Name:        name, | ||||||
| 		Email:       user.Email, | 		Email:       user.Email, | ||||||
| 		Provider:    providerName.Provider, | 		Provider:    providerName.Provider, | ||||||
| 		OAuthGroups: strings.Join(user.Groups, ","), | 		OAuthGroups: utils.CoalesceToString(user.Groups), | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	// Check if we have a redirect URI | 	// Check if we have a redirect URI | ||||||
|   | |||||||
| @@ -40,10 +40,7 @@ func (h *Handlers) ProxyHandler(c *gin.Context) { | |||||||
| 	proto := c.Request.Header.Get("X-Forwarded-Proto") | 	proto := c.Request.Header.Get("X-Forwarded-Proto") | ||||||
| 	host := c.Request.Header.Get("X-Forwarded-Host") | 	host := c.Request.Header.Get("X-Forwarded-Host") | ||||||
|  |  | ||||||
| 	// Remove the port from the host if it exists |  | ||||||
| 	hostPortless := strings.Split(host, ":")[0] // *lol* | 	hostPortless := strings.Split(host, ":")[0] // *lol* | ||||||
|  |  | ||||||
| 	// Get the id |  | ||||||
| 	id := strings.Split(hostPortless, ".")[0] | 	id := strings.Split(hostPortless, ".")[0] | ||||||
|  |  | ||||||
| 	labels, err := h.Docker.GetLabels(id, hostPortless) | 	labels, err := h.Docker.GetLabels(id, hostPortless) | ||||||
| @@ -66,10 +63,10 @@ func (h *Handlers) ProxyHandler(c *gin.Context) { | |||||||
|  |  | ||||||
| 	ip := c.ClientIP() | 	ip := c.ClientIP() | ||||||
|  |  | ||||||
| 	// Check if the IP is in bypass list |  | ||||||
| 	if h.Auth.BypassedIP(labels, ip) { | 	if h.Auth.BypassedIP(labels, ip) { | ||||||
| 		headersParsed := utils.ParseHeaders(labels.Headers) | 		c.Header("Authorization", c.Request.Header.Get("Authorization")) | ||||||
|  |  | ||||||
|  | 		headersParsed := utils.ParseHeaders(labels.Headers) | ||||||
| 		for key, value := range headersParsed { | 		for key, value := range headersParsed { | ||||||
| 			log.Debug().Str("key", key).Msg("Setting header") | 			log.Debug().Str("key", key).Msg("Setting header") | ||||||
| 			c.Header(key, value) | 			c.Header(key, value) | ||||||
| @@ -87,7 +84,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Check if the IP is allowed/blocked |  | ||||||
| 	if !h.Auth.CheckIP(labels, ip) { | 	if !h.Auth.CheckIP(labels, ip) { | ||||||
| 		if proxy.Proxy == "nginx" || !isBrowser { | 		if proxy.Proxy == "nginx" || !isBrowser { | ||||||
| 			c.JSON(403, gin.H{ | 			c.JSON(403, gin.H{ | ||||||
| @@ -113,7 +109,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Check if auth is enabled |  | ||||||
| 	authEnabled, err := h.Auth.AuthEnabled(uri, labels) | 	authEnabled, err := h.Auth.AuthEnabled(uri, labels) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(err).Msg("Failed to check if app is allowed") | 		log.Error().Err(err).Msg("Failed to check if app is allowed") | ||||||
| @@ -129,8 +124,9 @@ func (h *Handlers) ProxyHandler(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// If auth is not enabled, return 200 |  | ||||||
| 	if !authEnabled { | 	if !authEnabled { | ||||||
|  | 		c.Header("Authorization", c.Request.Header.Get("Authorization")) | ||||||
|  |  | ||||||
| 		headersParsed := utils.ParseHeaders(labels.Headers) | 		headersParsed := utils.ParseHeaders(labels.Headers) | ||||||
| 		for key, value := range headersParsed { | 		for key, value := range headersParsed { | ||||||
| 			log.Debug().Str("key", key).Msg("Setting header") | 			log.Debug().Str("key", key).Msg("Setting header") | ||||||
| @@ -150,7 +146,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Get user context |  | ||||||
| 	userContext := h.Hooks.UseUserContext(c) | 	userContext := h.Hooks.UseUserContext(c) | ||||||
|  |  | ||||||
| 	// If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth | 	// If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth | ||||||
| @@ -159,7 +154,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) { | |||||||
| 		userContext.IsLoggedIn = false | 		userContext.IsLoggedIn = false | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Check if user is logged in |  | ||||||
| 	if userContext.IsLoggedIn { | 	if userContext.IsLoggedIn { | ||||||
| 		log.Debug().Msg("Authenticated") | 		log.Debug().Msg("Authenticated") | ||||||
|  |  | ||||||
| @@ -200,7 +194,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Check groups if using OAuth |  | ||||||
| 		if userContext.OAuth { | 		if userContext.OAuth { | ||||||
| 			groupOk := h.Auth.OAuthGroup(c, userContext, labels) | 			groupOk := h.Auth.OAuthGroup(c, userContext, labels) | ||||||
|  |  | ||||||
| @@ -239,19 +232,18 @@ func (h *Handlers) ProxyHandler(c *gin.Context) { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		c.Header("Authorization", c.Request.Header.Get("Authorization")) | ||||||
| 		c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) | 		c.Header("Remote-User", utils.SanitizeHeader(userContext.Username)) | ||||||
| 		c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) | 		c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name)) | ||||||
| 		c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) | 		c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email)) | ||||||
| 		c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) | 		c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups)) | ||||||
|  |  | ||||||
| 		// Set the rest of the headers |  | ||||||
| 		parsedHeaders := utils.ParseHeaders(labels.Headers) | 		parsedHeaders := utils.ParseHeaders(labels.Headers) | ||||||
| 		for key, value := range parsedHeaders { | 		for key, value := range parsedHeaders { | ||||||
| 			log.Debug().Str("key", key).Msg("Setting header") | 			log.Debug().Str("key", key).Msg("Setting header") | ||||||
| 			c.Header(key, value) | 			c.Header(key, value) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Set basic auth headers if configured |  | ||||||
| 		if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { | 		if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" { | ||||||
| 			log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") | 			log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers") | ||||||
| 			c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) | 			c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File)))) | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"tinyauth/internal/auth" | 	"tinyauth/internal/auth" | ||||||
|  | 	"tinyauth/internal/oauth" | ||||||
| 	"tinyauth/internal/providers" | 	"tinyauth/internal/providers" | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
| 	"tinyauth/internal/utils" | 	"tinyauth/internal/utils" | ||||||
| @@ -27,28 +28,92 @@ func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Pr | |||||||
| } | } | ||||||
|  |  | ||||||
| func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | ||||||
| 	// Get session cookie and basic auth |  | ||||||
| 	cookie, err := hooks.Auth.GetSessionCookie(c) | 	cookie, err := hooks.Auth.GetSessionCookie(c) | ||||||
|  | 	var provider *oauth.OAuth | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Error().Err(err).Msg("Failed to get session cookie") | ||||||
|  | 		goto basic | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if cookie.TotpPending { | ||||||
|  | 		log.Debug().Msg("Totp pending") | ||||||
|  | 		return types.UserContext{ | ||||||
|  | 			Username:    cookie.Username, | ||||||
|  | 			Name:        cookie.Name, | ||||||
|  | 			Email:       cookie.Email, | ||||||
|  | 			Provider:    cookie.Provider, | ||||||
|  | 			TotpPending: true, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if cookie.Provider == "username" { | ||||||
|  | 		log.Debug().Msg("Provider is username") | ||||||
|  |  | ||||||
|  | 		userSearch := hooks.Auth.SearchUser(cookie.Username) | ||||||
|  |  | ||||||
|  | 		if userSearch.Type == "unknown" { | ||||||
|  | 			log.Warn().Str("username", cookie.Username).Msg("User does not exist") | ||||||
|  | 			goto basic | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		log.Debug().Str("type", userSearch.Type).Msg("User exists") | ||||||
|  |  | ||||||
|  | 		return types.UserContext{ | ||||||
|  | 			Username:   cookie.Username, | ||||||
|  | 			Name:       cookie.Name, | ||||||
|  | 			Email:      cookie.Email, | ||||||
|  | 			IsLoggedIn: true, | ||||||
|  | 			Provider:   "username", | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Provider is not username") | ||||||
|  |  | ||||||
|  | 	provider = hooks.Providers.GetProvider(cookie.Provider) | ||||||
|  |  | ||||||
|  | 	if provider != nil { | ||||||
|  | 		log.Debug().Msg("Provider exists") | ||||||
|  |  | ||||||
|  | 		if !hooks.Auth.EmailWhitelisted(cookie.Email) { | ||||||
|  | 			log.Warn().Str("email", cookie.Email).Msg("Email is not whitelisted") | ||||||
|  | 			hooks.Auth.DeleteSessionCookie(c) | ||||||
|  | 			goto basic | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		log.Debug().Msg("Email is whitelisted") | ||||||
|  |  | ||||||
|  | 		return types.UserContext{ | ||||||
|  | 			Username:    cookie.Username, | ||||||
|  | 			Name:        cookie.Name, | ||||||
|  | 			Email:       cookie.Email, | ||||||
|  | 			IsLoggedIn:  true, | ||||||
|  | 			OAuth:       true, | ||||||
|  | 			Provider:    cookie.Provider, | ||||||
|  | 			OAuthGroups: cookie.OAuthGroups, | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | basic: | ||||||
|  | 	log.Debug().Msg("Trying basic auth") | ||||||
|  |  | ||||||
| 	basic := hooks.Auth.GetBasicAuth(c) | 	basic := hooks.Auth.GetBasicAuth(c) | ||||||
|  |  | ||||||
| 	// Check if basic auth is set |  | ||||||
| 	if basic != nil { | 	if basic != nil { | ||||||
| 		log.Debug().Msg("Got basic auth") | 		log.Debug().Msg("Got basic auth") | ||||||
|  |  | ||||||
| 		userSearch := hooks.Auth.SearchUser(basic.Username) | 		userSearch := hooks.Auth.SearchUser(basic.Username) | ||||||
|  |  | ||||||
| 		if userSearch.Type == "" { | 		if userSearch.Type == "unkown" { | ||||||
| 			log.Error().Str("username", basic.Username).Msg("User does not exist") | 			log.Error().Str("username", basic.Username).Msg("Basic auth user does not exist") | ||||||
| 			return types.UserContext{} | 			return types.UserContext{} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Verify the user |  | ||||||
| 		if !hooks.Auth.VerifyUser(userSearch, basic.Password) { | 		if !hooks.Auth.VerifyUser(userSearch, basic.Password) { | ||||||
| 			log.Error().Str("username", basic.Username).Msg("Password incorrect") | 			log.Error().Str("username", basic.Username).Msg("Basic auth user password incorrect") | ||||||
| 			return types.UserContext{} | 			return types.UserContext{} | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		// Get the user type |  | ||||||
| 		if userSearch.Type == "ldap" { | 		if userSearch.Type == "ldap" { | ||||||
| 			log.Debug().Msg("User is LDAP") | 			log.Debug().Msg("User is LDAP") | ||||||
|  |  | ||||||
| @@ -75,73 +140,5 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { | |||||||
|  |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Check cookie error after basic auth |  | ||||||
| 	if err != nil { |  | ||||||
| 		log.Error().Err(err).Msg("Failed to get session cookie") |  | ||||||
| 		return types.UserContext{} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if cookie.TotpPending { |  | ||||||
| 		log.Debug().Msg("Totp pending") |  | ||||||
| 		return types.UserContext{ |  | ||||||
| 			Username:    cookie.Username, |  | ||||||
| 			Name:        cookie.Name, |  | ||||||
| 			Email:       cookie.Email, |  | ||||||
| 			Provider:    cookie.Provider, |  | ||||||
| 			TotpPending: true, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Check if session cookie is username/password auth |  | ||||||
| 	if cookie.Provider == "username" { |  | ||||||
| 		log.Debug().Msg("Provider is username") |  | ||||||
|  |  | ||||||
| 		userSearch := hooks.Auth.SearchUser(cookie.Username) |  | ||||||
|  |  | ||||||
| 		if userSearch.Type == "" { |  | ||||||
| 			log.Error().Str("username", cookie.Username).Msg("User does not exist") |  | ||||||
| 			return types.UserContext{} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Str("type", userSearch.Type).Msg("User exists") |  | ||||||
|  |  | ||||||
| 		return types.UserContext{ |  | ||||||
| 			Username:   cookie.Username, |  | ||||||
| 			Name:       cookie.Name, |  | ||||||
| 			Email:      cookie.Email, |  | ||||||
| 			IsLoggedIn: true, |  | ||||||
| 			Provider:   "username", |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Provider is not username") |  | ||||||
|  |  | ||||||
| 	// The provider is not username so we need to check if it is an oauth provider |  | ||||||
| 	provider := hooks.Providers.GetProvider(cookie.Provider) |  | ||||||
|  |  | ||||||
| 	// If we have a provider with this name |  | ||||||
| 	if provider != nil { |  | ||||||
| 		log.Debug().Msg("Provider exists") |  | ||||||
|  |  | ||||||
| 		// If the email is not whitelisted we delete the cookie and return an empty context |  | ||||||
| 		if !hooks.Auth.EmailWhitelisted(cookie.Email) { |  | ||||||
| 			log.Error().Str("email", cookie.Email).Msg("Email is not whitelisted") |  | ||||||
| 			hooks.Auth.DeleteSessionCookie(c) |  | ||||||
| 			return types.UserContext{} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Email is whitelisted") |  | ||||||
|  |  | ||||||
| 		return types.UserContext{ |  | ||||||
| 			Username:    cookie.Username, |  | ||||||
| 			Name:        cookie.Name, |  | ||||||
| 			Email:       cookie.Email, |  | ||||||
| 			IsLoggedIn:  true, |  | ||||||
| 			OAuth:       true, |  | ||||||
| 			Provider:    cookie.Provider, |  | ||||||
| 			OAuthGroups: cookie.OAuthGroups, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return types.UserContext{} | 	return types.UserContext{} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,11 +1,13 @@ | |||||||
| package ldap | package ldap | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"time" | 	"time" | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
|  |  | ||||||
|  | 	"github.com/cenkalti/backoff/v5" | ||||||
| 	ldapgo "github.com/go-ldap/ldap/v3" | 	ldapgo "github.com/go-ldap/ldap/v3" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
| @@ -30,6 +32,11 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) { | |||||||
| 			err := ldap.heartbeat() | 			err := ldap.heartbeat() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				log.Error().Err(err).Msg("LDAP connection heartbeat failed") | 				log.Error().Err(err).Msg("LDAP connection heartbeat failed") | ||||||
|  | 				if reconnectErr := ldap.reconnect(); reconnectErr != nil { | ||||||
|  | 					log.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") | ||||||
|  | 					continue | ||||||
|  | 				} | ||||||
|  | 				log.Info().Msg("Successfully reconnected to LDAP server") | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| @@ -38,6 +45,7 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (l *LDAP) connect() (*ldapgo.Conn, error) { | func (l *LDAP) connect() (*ldapgo.Conn, error) { | ||||||
|  | 	log.Debug().Msg("Connecting to LDAP server") | ||||||
| 	conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ | 	conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ | ||||||
| 		InsecureSkipVerify: l.Config.Insecure, | 		InsecureSkipVerify: l.Config.Insecure, | ||||||
| 		MinVersion:         tls.VersionTLS12, | 		MinVersion:         tls.VersionTLS12, | ||||||
| @@ -46,6 +54,7 @@ func (l *LDAP) connect() (*ldapgo.Conn, error) { | |||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	log.Debug().Msg("Binding to LDAP server") | ||||||
| 	err = conn.Bind(l.Config.BindDN, l.Config.BindPassword) | 	err = conn.Bind(l.Config.BindDN, l.Config.BindPassword) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -109,3 +118,30 @@ func (l *LDAP) heartbeat() error { | |||||||
| 	// No error means the connection is alive | 	// No error means the connection is alive | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (l *LDAP) reconnect() error { | ||||||
|  | 	log.Info().Msg("Reconnecting to LDAP server") | ||||||
|  |  | ||||||
|  | 	exp := backoff.NewExponentialBackOff() | ||||||
|  | 	exp.InitialInterval = 500 * time.Millisecond | ||||||
|  | 	exp.RandomizationFactor = 0.1 | ||||||
|  | 	exp.Multiplier = 1.5 | ||||||
|  | 	exp.Reset() | ||||||
|  |  | ||||||
|  | 	operation := func() (*ldapgo.Conn, error) { | ||||||
|  | 		l.Conn.Close() | ||||||
|  | 		_, err := l.connect() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, nil | ||||||
|  | 		} | ||||||
|  | 		return nil, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err := backoff.Retry(context.TODO(), operation, backoff.WithBackOff(exp), backoff.WithMaxTries(3)) | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|   | |||||||
| @@ -327,3 +327,15 @@ func DeriveKey(secret string, info string) (string, error) { | |||||||
| 	encodedKey := base64.StdEncoding.EncodeToString(key) | 	encodedKey := base64.StdEncoding.EncodeToString(key) | ||||||
| 	return encodedKey, nil | 	return encodedKey, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func CoalesceToString(value any) string { | ||||||
|  | 	switch v := value.(type) { | ||||||
|  | 	case []string: | ||||||
|  | 		return strings.Join(v, ",") | ||||||
|  | 	case string: | ||||||
|  | 		return v | ||||||
|  | 	default: | ||||||
|  | 		log.Warn().Interface("value", value).Msg("Unsupported type, returning empty string") | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -511,3 +511,38 @@ func TestDeriveKey(t *testing.T) { | |||||||
| 		t.Fatalf("Expected %v, got %v", expected, result) | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestCoalesceToString(t *testing.T) { | ||||||
|  | 	t.Log("Testing coalesce to string with a string") | ||||||
|  |  | ||||||
|  | 	value := "test" | ||||||
|  | 	expected := "test" | ||||||
|  |  | ||||||
|  | 	result := utils.CoalesceToString(value) | ||||||
|  |  | ||||||
|  | 	if result != expected { | ||||||
|  | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	t.Log("Testing coalesce to string with a slice of strings") | ||||||
|  |  | ||||||
|  | 	valueSlice := []string{"test1", "test2"} | ||||||
|  | 	expected = "test1,test2" | ||||||
|  |  | ||||||
|  | 	result = utils.CoalesceToString(valueSlice) | ||||||
|  |  | ||||||
|  | 	if result != expected { | ||||||
|  | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	t.Log("Testing coalesce to string with an unsupported type") | ||||||
|  |  | ||||||
|  | 	valueUnsupported := 12345 | ||||||
|  | 	expected = "" | ||||||
|  |  | ||||||
|  | 	result = utils.CoalesceToString(valueUnsupported) | ||||||
|  |  | ||||||
|  | 	if result != expected { | ||||||
|  | 		t.Fatalf("Expected %v, got %v", expected, result) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user