Compare commits

...

5 Commits

Author SHA1 Message Date
Stavros
2233557990 tests: move handlers test to handlers package 2025-07-15 01:38:01 +03:00
Stavros
d3bec635f8 fix: make tinyauth not "eat" the authorization header 2025-07-15 01:34:25 +03:00
Stavros
6519644fc1 fix: handle type string for oauth groups 2025-07-15 00:17:41 +03:00
Stavros
736f65b7b2 refactor: close connection before trying to reconnect 2025-07-14 20:10:15 +03:00
Stavros
63d39b5500 feat: try to reconnect to ldap server if heartbeat fails 2025-07-14 20:02:16 +03:00
11 changed files with 107 additions and 26 deletions

1
go.mod
View File

@@ -17,6 +17,7 @@ require (
require (
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
github.com/cenkalti/backoff/v5 v5.0.2 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/cellbuf v0.0.13 // indirect
github.com/containerd/errdefs v1.0.0 // indirect

2
go.sum
View File

@@ -26,6 +26,8 @@ github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY=
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8=
github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=

View File

@@ -88,7 +88,9 @@ func (auth *Auth) SearchUser(username string) types.UserSearch {
}
}
return types.UserSearch{}
return types.UserSearch{
Type: "unknown",
}
}
func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool {

View File

@@ -2,10 +2,10 @@ package constants
// Claims are the OIDC supported claims (prefered username is included for convinience)
type Claims struct {
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups []string `json:"groups"`
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups any `json:"groups"`
}
// Version information

View File

@@ -1,4 +1,4 @@
package server_test
package handlers_test
import (
"encoding/json"

View File

@@ -189,7 +189,7 @@ func (h *Handlers) OAuthCallbackHandler(c *gin.Context) {
Name: name,
Email: user.Email,
Provider: providerName.Provider,
OAuthGroups: strings.Join(user.Groups, ","),
OAuthGroups: utils.CoalesceToString(user.Groups),
})
// Check if we have a redirect URI

View File

@@ -40,10 +40,7 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
proto := c.Request.Header.Get("X-Forwarded-Proto")
host := c.Request.Header.Get("X-Forwarded-Host")
// Remove the port from the host if it exists
hostPortless := strings.Split(host, ":")[0] // *lol*
// Get the id
id := strings.Split(hostPortless, ".")[0]
labels, err := h.Docker.GetLabels(id, hostPortless)
@@ -66,10 +63,10 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
ip := c.ClientIP()
// Check if the IP is in bypass list
if h.Auth.BypassedIP(labels, ip) {
headersParsed := utils.ParseHeaders(labels.Headers)
c.Header("Authorization", c.Request.Header.Get("Authorization"))
headersParsed := utils.ParseHeaders(labels.Headers)
for key, value := range headersParsed {
log.Debug().Str("key", key).Msg("Setting header")
c.Header(key, value)
@@ -87,7 +84,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
return
}
// Check if the IP is allowed/blocked
if !h.Auth.CheckIP(labels, ip) {
if proxy.Proxy == "nginx" || !isBrowser {
c.JSON(403, gin.H{
@@ -113,7 +109,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
return
}
// Check if auth is enabled
authEnabled, err := h.Auth.AuthEnabled(uri, labels)
if err != nil {
log.Error().Err(err).Msg("Failed to check if app is allowed")
@@ -129,8 +124,9 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
return
}
// If auth is not enabled, return 200
if !authEnabled {
c.Header("Authorization", c.Request.Header.Get("Authorization"))
headersParsed := utils.ParseHeaders(labels.Headers)
for key, value := range headersParsed {
log.Debug().Str("key", key).Msg("Setting header")
@@ -150,7 +146,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
return
}
// Get user context
userContext := h.Hooks.UseUserContext(c)
// If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth
@@ -159,7 +154,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
userContext.IsLoggedIn = false
}
// Check if user is logged in
if userContext.IsLoggedIn {
log.Debug().Msg("Authenticated")
@@ -200,7 +194,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
return
}
// Check groups if using OAuth
if userContext.OAuth {
groupOk := h.Auth.OAuthGroup(c, userContext, labels)
@@ -239,19 +232,18 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
}
}
c.Header("Authorization", c.Request.Header.Get("Authorization"))
c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
// Set the rest of the headers
parsedHeaders := utils.ParseHeaders(labels.Headers)
for key, value := range parsedHeaders {
log.Debug().Str("key", key).Msg("Setting header")
c.Header(key, value)
}
// Set basic auth headers if configured
if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" {
log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers")
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File))))

View File

@@ -37,15 +37,15 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
userSearch := hooks.Auth.SearchUser(basic.Username)
if userSearch.Type == "" {
log.Error().Str("username", basic.Username).Msg("User does not exist")
return types.UserContext{}
if userSearch.Type == "unkown" {
log.Warn().Str("username", basic.Username).Msg("Basic auth user does not exist, skipping")
goto session
}
// Verify the user
if !hooks.Auth.VerifyUser(userSearch, basic.Password) {
log.Error().Str("username", basic.Username).Msg("Password incorrect")
return types.UserContext{}
log.Error().Str("username", basic.Username).Msg("Basic auth user password incorrect, skipping")
goto session
}
// Get the user type
@@ -75,6 +75,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
}
session:
// Check cookie error after basic auth
if err != nil {
log.Error().Err(err).Msg("Failed to get session cookie")
@@ -98,7 +99,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
userSearch := hooks.Auth.SearchUser(cookie.Username)
if userSearch.Type == "" {
if userSearch.Type == "unknown" {
log.Error().Str("username", cookie.Username).Msg("User does not exist")
return types.UserContext{}
}

View File

@@ -1,11 +1,13 @@
package ldap
import (
"context"
"crypto/tls"
"fmt"
"time"
"tinyauth/internal/types"
"github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3"
"github.com/rs/zerolog/log"
)
@@ -30,6 +32,11 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) {
err := ldap.heartbeat()
if err != nil {
log.Error().Err(err).Msg("LDAP connection heartbeat failed")
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
log.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
continue
}
log.Info().Msg("Successfully reconnected to LDAP server")
}
}
}()
@@ -38,6 +45,7 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) {
}
func (l *LDAP) connect() (*ldapgo.Conn, error) {
log.Debug().Msg("Connecting to LDAP server")
conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: l.Config.Insecure,
MinVersion: tls.VersionTLS12,
@@ -46,6 +54,7 @@ func (l *LDAP) connect() (*ldapgo.Conn, error) {
return nil, err
}
log.Debug().Msg("Binding to LDAP server")
err = conn.Bind(l.Config.BindDN, l.Config.BindPassword)
if err != nil {
return nil, err
@@ -109,3 +118,30 @@ func (l *LDAP) heartbeat() error {
// No error means the connection is alive
return nil
}
func (l *LDAP) reconnect() error {
log.Info().Msg("Reconnecting to LDAP server")
exp := backoff.NewExponentialBackOff()
exp.InitialInterval = 500 * time.Millisecond
exp.RandomizationFactor = 0.1
exp.Multiplier = 1.5
exp.Reset()
operation := func() (*ldapgo.Conn, error) {
l.Conn.Close()
_, err := l.connect()
if err != nil {
return nil, nil
}
return nil, nil
}
_, err := backoff.Retry(context.TODO(), operation, backoff.WithBackOff(exp), backoff.WithMaxTries(3))
if err != nil {
return err
}
return nil
}

View File

@@ -327,3 +327,15 @@ func DeriveKey(secret string, info string) (string, error) {
encodedKey := base64.StdEncoding.EncodeToString(key)
return encodedKey, nil
}
func CoalesceToString(value any) string {
switch v := value.(type) {
case []string:
return strings.Join(v, ",")
case string:
return v
default:
log.Warn().Interface("value", value).Msg("Unsupported type, returning empty string")
return ""
}
}

View File

@@ -511,3 +511,38 @@ func TestDeriveKey(t *testing.T) {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
func TestCoalesceToString(t *testing.T) {
t.Log("Testing coalesce to string with a string")
value := "test"
expected := "test"
result := utils.CoalesceToString(value)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing coalesce to string with a slice of strings")
valueSlice := []string{"test1", "test2"}
expected = "test1,test2"
result = utils.CoalesceToString(valueSlice)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing coalesce to string with an unsupported type")
valueUnsupported := 12345
expected = ""
result = utils.CoalesceToString(valueUnsupported)
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}