mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-11-03 23:55:44 +00:00 
			
		
		
		
	Compare commits
	
		
			5 Commits
		
	
	
		
			v3.6.1
			...
			v3.6.2-bet
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					2233557990 | ||
| 
						 | 
					d3bec635f8 | ||
| 
						 | 
					6519644fc1 | ||
| 
						 | 
					736f65b7b2 | ||
| 
						 | 
					63d39b5500 | 
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							@@ -17,6 +17,7 @@ require (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
require (
 | 
					require (
 | 
				
			||||||
	github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
 | 
						github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
 | 
				
			||||||
 | 
						github.com/cenkalti/backoff/v5 v5.0.2 // indirect
 | 
				
			||||||
	github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
 | 
						github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
 | 
				
			||||||
	github.com/charmbracelet/x/cellbuf v0.0.13 // indirect
 | 
						github.com/charmbracelet/x/cellbuf v0.0.13 // indirect
 | 
				
			||||||
	github.com/containerd/errdefs v1.0.0 // indirect
 | 
						github.com/containerd/errdefs v1.0.0 // indirect
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							@@ -26,6 +26,8 @@ github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY=
 | 
				
			|||||||
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
 | 
					github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
 | 
				
			||||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
 | 
					github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
 | 
				
			||||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
 | 
					github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
 | 
				
			||||||
 | 
					github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8=
 | 
				
			||||||
 | 
					github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
 | 
				
			||||||
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
 | 
					github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
 | 
				
			||||||
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
 | 
					github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
 | 
				
			||||||
github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=
 | 
					github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -88,7 +88,9 @@ func (auth *Auth) SearchUser(username string) types.UserSearch {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return types.UserSearch{}
 | 
						return types.UserSearch{
 | 
				
			||||||
 | 
							Type: "unknown",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool {
 | 
					func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,7 +5,7 @@ type Claims struct {
 | 
				
			|||||||
	Name              string `json:"name"`
 | 
						Name              string `json:"name"`
 | 
				
			||||||
	Email             string `json:"email"`
 | 
						Email             string `json:"email"`
 | 
				
			||||||
	PreferredUsername string `json:"preferred_username"`
 | 
						PreferredUsername string `json:"preferred_username"`
 | 
				
			||||||
	Groups            []string `json:"groups"`
 | 
						Groups            any    `json:"groups"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Version information
 | 
					// Version information
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package server_test
 | 
					package handlers_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
@@ -189,7 +189,7 @@ func (h *Handlers) OAuthCallbackHandler(c *gin.Context) {
 | 
				
			|||||||
		Name:        name,
 | 
							Name:        name,
 | 
				
			||||||
		Email:       user.Email,
 | 
							Email:       user.Email,
 | 
				
			||||||
		Provider:    providerName.Provider,
 | 
							Provider:    providerName.Provider,
 | 
				
			||||||
		OAuthGroups: strings.Join(user.Groups, ","),
 | 
							OAuthGroups: utils.CoalesceToString(user.Groups),
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if we have a redirect URI
 | 
						// Check if we have a redirect URI
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -40,10 +40,7 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
 | 
				
			|||||||
	proto := c.Request.Header.Get("X-Forwarded-Proto")
 | 
						proto := c.Request.Header.Get("X-Forwarded-Proto")
 | 
				
			||||||
	host := c.Request.Header.Get("X-Forwarded-Host")
 | 
						host := c.Request.Header.Get("X-Forwarded-Host")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Remove the port from the host if it exists
 | 
					 | 
				
			||||||
	hostPortless := strings.Split(host, ":")[0] // *lol*
 | 
						hostPortless := strings.Split(host, ":")[0] // *lol*
 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Get the id
 | 
					 | 
				
			||||||
	id := strings.Split(hostPortless, ".")[0]
 | 
						id := strings.Split(hostPortless, ".")[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	labels, err := h.Docker.GetLabels(id, hostPortless)
 | 
						labels, err := h.Docker.GetLabels(id, hostPortless)
 | 
				
			||||||
@@ -66,10 +63,10 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	ip := c.ClientIP()
 | 
						ip := c.ClientIP()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if the IP is in bypass list
 | 
					 | 
				
			||||||
	if h.Auth.BypassedIP(labels, ip) {
 | 
						if h.Auth.BypassedIP(labels, ip) {
 | 
				
			||||||
		headersParsed := utils.ParseHeaders(labels.Headers)
 | 
							c.Header("Authorization", c.Request.Header.Get("Authorization"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							headersParsed := utils.ParseHeaders(labels.Headers)
 | 
				
			||||||
		for key, value := range headersParsed {
 | 
							for key, value := range headersParsed {
 | 
				
			||||||
			log.Debug().Str("key", key).Msg("Setting header")
 | 
								log.Debug().Str("key", key).Msg("Setting header")
 | 
				
			||||||
			c.Header(key, value)
 | 
								c.Header(key, value)
 | 
				
			||||||
@@ -87,7 +84,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if the IP is allowed/blocked
 | 
					 | 
				
			||||||
	if !h.Auth.CheckIP(labels, ip) {
 | 
						if !h.Auth.CheckIP(labels, ip) {
 | 
				
			||||||
		if proxy.Proxy == "nginx" || !isBrowser {
 | 
							if proxy.Proxy == "nginx" || !isBrowser {
 | 
				
			||||||
			c.JSON(403, gin.H{
 | 
								c.JSON(403, gin.H{
 | 
				
			||||||
@@ -113,7 +109,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if auth is enabled
 | 
					 | 
				
			||||||
	authEnabled, err := h.Auth.AuthEnabled(uri, labels)
 | 
						authEnabled, err := h.Auth.AuthEnabled(uri, labels)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Error().Err(err).Msg("Failed to check if app is allowed")
 | 
							log.Error().Err(err).Msg("Failed to check if app is allowed")
 | 
				
			||||||
@@ -129,8 +124,9 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// If auth is not enabled, return 200
 | 
					 | 
				
			||||||
	if !authEnabled {
 | 
						if !authEnabled {
 | 
				
			||||||
 | 
							c.Header("Authorization", c.Request.Header.Get("Authorization"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		headersParsed := utils.ParseHeaders(labels.Headers)
 | 
							headersParsed := utils.ParseHeaders(labels.Headers)
 | 
				
			||||||
		for key, value := range headersParsed {
 | 
							for key, value := range headersParsed {
 | 
				
			||||||
			log.Debug().Str("key", key).Msg("Setting header")
 | 
								log.Debug().Str("key", key).Msg("Setting header")
 | 
				
			||||||
@@ -150,7 +146,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Get user context
 | 
					 | 
				
			||||||
	userContext := h.Hooks.UseUserContext(c)
 | 
						userContext := h.Hooks.UseUserContext(c)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth
 | 
						// If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth
 | 
				
			||||||
@@ -159,7 +154,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
 | 
				
			|||||||
		userContext.IsLoggedIn = false
 | 
							userContext.IsLoggedIn = false
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check if user is logged in
 | 
					 | 
				
			||||||
	if userContext.IsLoggedIn {
 | 
						if userContext.IsLoggedIn {
 | 
				
			||||||
		log.Debug().Msg("Authenticated")
 | 
							log.Debug().Msg("Authenticated")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -200,7 +194,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
 | 
				
			|||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Check groups if using OAuth
 | 
					 | 
				
			||||||
		if userContext.OAuth {
 | 
							if userContext.OAuth {
 | 
				
			||||||
			groupOk := h.Auth.OAuthGroup(c, userContext, labels)
 | 
								groupOk := h.Auth.OAuthGroup(c, userContext, labels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -239,19 +232,18 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							c.Header("Authorization", c.Request.Header.Get("Authorization"))
 | 
				
			||||||
		c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
 | 
							c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
 | 
				
			||||||
		c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
 | 
							c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
 | 
				
			||||||
		c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
 | 
							c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
 | 
				
			||||||
		c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
 | 
							c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Set the rest of the headers
 | 
					 | 
				
			||||||
		parsedHeaders := utils.ParseHeaders(labels.Headers)
 | 
							parsedHeaders := utils.ParseHeaders(labels.Headers)
 | 
				
			||||||
		for key, value := range parsedHeaders {
 | 
							for key, value := range parsedHeaders {
 | 
				
			||||||
			log.Debug().Str("key", key).Msg("Setting header")
 | 
								log.Debug().Str("key", key).Msg("Setting header")
 | 
				
			||||||
			c.Header(key, value)
 | 
								c.Header(key, value)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Set basic auth headers if configured
 | 
					 | 
				
			||||||
		if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" {
 | 
							if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" {
 | 
				
			||||||
			log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers")
 | 
								log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers")
 | 
				
			||||||
			c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File))))
 | 
								c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File))))
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -37,15 +37,15 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		userSearch := hooks.Auth.SearchUser(basic.Username)
 | 
							userSearch := hooks.Auth.SearchUser(basic.Username)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if userSearch.Type == "" {
 | 
							if userSearch.Type == "unkown" {
 | 
				
			||||||
			log.Error().Str("username", basic.Username).Msg("User does not exist")
 | 
								log.Warn().Str("username", basic.Username).Msg("Basic auth user does not exist, skipping")
 | 
				
			||||||
			return types.UserContext{}
 | 
								goto session
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Verify the user
 | 
							// Verify the user
 | 
				
			||||||
		if !hooks.Auth.VerifyUser(userSearch, basic.Password) {
 | 
							if !hooks.Auth.VerifyUser(userSearch, basic.Password) {
 | 
				
			||||||
			log.Error().Str("username", basic.Username).Msg("Password incorrect")
 | 
								log.Error().Str("username", basic.Username).Msg("Basic auth user password incorrect, skipping")
 | 
				
			||||||
			return types.UserContext{}
 | 
								goto session
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Get the user type
 | 
							// Get the user type
 | 
				
			||||||
@@ -75,6 +75,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					session:
 | 
				
			||||||
	// Check cookie error after basic auth
 | 
						// Check cookie error after basic auth
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Error().Err(err).Msg("Failed to get session cookie")
 | 
							log.Error().Err(err).Msg("Failed to get session cookie")
 | 
				
			||||||
@@ -98,7 +99,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		userSearch := hooks.Auth.SearchUser(cookie.Username)
 | 
							userSearch := hooks.Auth.SearchUser(cookie.Username)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if userSearch.Type == "" {
 | 
							if userSearch.Type == "unknown" {
 | 
				
			||||||
			log.Error().Str("username", cookie.Username).Msg("User does not exist")
 | 
								log.Error().Str("username", cookie.Username).Msg("User does not exist")
 | 
				
			||||||
			return types.UserContext{}
 | 
								return types.UserContext{}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,11 +1,13 @@
 | 
				
			|||||||
package ldap
 | 
					package ldap
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"crypto/tls"
 | 
						"crypto/tls"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
	"tinyauth/internal/types"
 | 
						"tinyauth/internal/types"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/cenkalti/backoff/v5"
 | 
				
			||||||
	ldapgo "github.com/go-ldap/ldap/v3"
 | 
						ldapgo "github.com/go-ldap/ldap/v3"
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -30,6 +32,11 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) {
 | 
				
			|||||||
			err := ldap.heartbeat()
 | 
								err := ldap.heartbeat()
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				log.Error().Err(err).Msg("LDAP connection heartbeat failed")
 | 
									log.Error().Err(err).Msg("LDAP connection heartbeat failed")
 | 
				
			||||||
 | 
									if reconnectErr := ldap.reconnect(); reconnectErr != nil {
 | 
				
			||||||
 | 
										log.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									log.Info().Msg("Successfully reconnected to LDAP server")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
@@ -38,6 +45,7 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (l *LDAP) connect() (*ldapgo.Conn, error) {
 | 
					func (l *LDAP) connect() (*ldapgo.Conn, error) {
 | 
				
			||||||
 | 
						log.Debug().Msg("Connecting to LDAP server")
 | 
				
			||||||
	conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
 | 
						conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
 | 
				
			||||||
		InsecureSkipVerify: l.Config.Insecure,
 | 
							InsecureSkipVerify: l.Config.Insecure,
 | 
				
			||||||
		MinVersion:         tls.VersionTLS12,
 | 
							MinVersion:         tls.VersionTLS12,
 | 
				
			||||||
@@ -46,6 +54,7 @@ func (l *LDAP) connect() (*ldapgo.Conn, error) {
 | 
				
			|||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Debug().Msg("Binding to LDAP server")
 | 
				
			||||||
	err = conn.Bind(l.Config.BindDN, l.Config.BindPassword)
 | 
						err = conn.Bind(l.Config.BindDN, l.Config.BindPassword)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
@@ -109,3 +118,30 @@ func (l *LDAP) heartbeat() error {
 | 
				
			|||||||
	// No error means the connection is alive
 | 
						// No error means the connection is alive
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (l *LDAP) reconnect() error {
 | 
				
			||||||
 | 
						log.Info().Msg("Reconnecting to LDAP server")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						exp := backoff.NewExponentialBackOff()
 | 
				
			||||||
 | 
						exp.InitialInterval = 500 * time.Millisecond
 | 
				
			||||||
 | 
						exp.RandomizationFactor = 0.1
 | 
				
			||||||
 | 
						exp.Multiplier = 1.5
 | 
				
			||||||
 | 
						exp.Reset()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						operation := func() (*ldapgo.Conn, error) {
 | 
				
			||||||
 | 
							l.Conn.Close()
 | 
				
			||||||
 | 
							_, err := l.connect()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return nil, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err := backoff.Retry(context.TODO(), operation, backoff.WithBackOff(exp), backoff.WithMaxTries(3))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -327,3 +327,15 @@ func DeriveKey(secret string, info string) (string, error) {
 | 
				
			|||||||
	encodedKey := base64.StdEncoding.EncodeToString(key)
 | 
						encodedKey := base64.StdEncoding.EncodeToString(key)
 | 
				
			||||||
	return encodedKey, nil
 | 
						return encodedKey, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func CoalesceToString(value any) string {
 | 
				
			||||||
 | 
						switch v := value.(type) {
 | 
				
			||||||
 | 
						case []string:
 | 
				
			||||||
 | 
							return strings.Join(v, ",")
 | 
				
			||||||
 | 
						case string:
 | 
				
			||||||
 | 
							return v
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							log.Warn().Interface("value", value).Msg("Unsupported type, returning empty string")
 | 
				
			||||||
 | 
							return ""
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -511,3 +511,38 @@ func TestDeriveKey(t *testing.T) {
 | 
				
			|||||||
		t.Fatalf("Expected %v, got %v", expected, result)
 | 
							t.Fatalf("Expected %v, got %v", expected, result)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCoalesceToString(t *testing.T) {
 | 
				
			||||||
 | 
						t.Log("Testing coalesce to string with a string")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						value := "test"
 | 
				
			||||||
 | 
						expected := "test"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						result := utils.CoalesceToString(value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if result != expected {
 | 
				
			||||||
 | 
							t.Fatalf("Expected %v, got %v", expected, result)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log("Testing coalesce to string with a slice of strings")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						valueSlice := []string{"test1", "test2"}
 | 
				
			||||||
 | 
						expected = "test1,test2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						result = utils.CoalesceToString(valueSlice)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if result != expected {
 | 
				
			||||||
 | 
							t.Fatalf("Expected %v, got %v", expected, result)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log("Testing coalesce to string with an unsupported type")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						valueUnsupported := 12345
 | 
				
			||||||
 | 
						expected = ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						result = utils.CoalesceToString(valueUnsupported)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if result != expected {
 | 
				
			||||||
 | 
							t.Fatalf("Expected %v, got %v", expected, result)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user