feat: retrieve and store groups from ldap provider

This commit is contained in:
Stavros
2026-01-09 23:23:36 +02:00
parent 467c580ec4
commit 1b2bf3902c
12 changed files with 94 additions and 18 deletions

View File

@@ -0,0 +1 @@
ALTER TABLE "sessions" DROP COLUMN "ldap_groups";

View File

@@ -0,0 +1 @@
ALTER TABLE "sessions" ADD COLUMN "ldap_groups" TEXT;

View File

@@ -122,6 +122,11 @@ type User struct {
TotpSecret string TotpSecret string
} }
type LdapUser struct {
DN string
Groups []string
}
type UserSearch struct { type UserSearch struct {
Username string Username string
Type string // local, ldap or unknown Type string // local, ldap or unknown

View File

@@ -141,6 +141,21 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Provider: "username", Provider: "username",
} }
if userSearch.Type == "ldap" {
ldapUser, err := controller.auth.GetLdapUser(userSearch.Username)
if err != nil {
log.Error().Err(err).Str("username", req.Username).Msg("Failed to get LDAP user details")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
sessionCookie.LdapGroups = strings.Join(ldapUser.Groups, ",")
}
log.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") log.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
err = controller.auth.CreateSessionCookie(c, &sessionCookie) err = controller.auth.CreateSessionCookie(c, &sessionCookie)

View File

@@ -74,6 +74,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
Email: cookie.Email, Email: cookie.Email,
Provider: "username", Provider: "username",
IsLoggedIn: true, IsLoggedIn: true,
LdapGroups: cookie.LdapGroups,
}) })
c.Next() c.Next()
return return
@@ -155,7 +156,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
Username: user.Username, Username: user.Username,
Name: utils.Capitalize(user.Username), Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.config.CookieDomain), Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.config.CookieDomain),
Provider: "basic", Provider: "username",
IsLoggedIn: true, IsLoggedIn: true,
TotpEnabled: user.TotpSecret != "", TotpEnabled: user.TotpSecret != "",
}) })
@@ -163,12 +164,22 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return return
case "ldap": case "ldap":
log.Debug().Msg("Basic auth user is LDAP") log.Debug().Msg("Basic auth user is LDAP")
ldapUser, err := m.auth.GetLdapUser(basic.Username)
if err != nil {
log.Debug().Err(err).Msg("Error retrieving LDAP user details")
c.Next()
return
}
c.Set("context", &config.UserContext{ c.Set("context", &config.UserContext{
Username: basic.Username, Username: basic.Username,
Name: utils.Capitalize(basic.Username), Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain), Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
Provider: "basic", Provider: "ldap",
IsLoggedIn: true, IsLoggedIn: true,
LdapGroups: strings.Join(ldapUser.Groups, ","),
}) })
c.Next() c.Next()
return return

View File

@@ -16,4 +16,5 @@ type Session struct {
CreatedAt int64 CreatedAt int64
OAuthName string OAuthName string
OAuthSub string OAuthSub string
LdapGroups string
} }

View File

@@ -21,11 +21,12 @@ INSERT INTO sessions (
"expiry", "expiry",
"created_at", "created_at",
"oauth_name", "oauth_name",
"oauth_sub" "oauth_sub",
"ldap_groups"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups
` `
type CreateSessionParams struct { type CreateSessionParams struct {
@@ -40,6 +41,7 @@ type CreateSessionParams struct {
CreatedAt int64 CreatedAt int64
OAuthName string OAuthName string
OAuthSub string OAuthSub string
LdapGroups string
} }
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) { func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
@@ -55,6 +57,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
arg.CreatedAt, arg.CreatedAt,
arg.OAuthName, arg.OAuthName,
arg.OAuthSub, arg.OAuthSub,
arg.LdapGroups,
) )
var i Session var i Session
err := row.Scan( err := row.Scan(
@@ -69,6 +72,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
&i.CreatedAt, &i.CreatedAt,
&i.OAuthName, &i.OAuthName,
&i.OAuthSub, &i.OAuthSub,
&i.LdapGroups,
) )
return i, err return i, err
} }
@@ -94,7 +98,7 @@ func (q *Queries) DeleteSession(ctx context.Context, uuid string) error {
} }
const getSession = `-- name: GetSession :one const getSession = `-- name: GetSession :one
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub FROM "sessions" SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups FROM "sessions"
WHERE "uuid" = ? WHERE "uuid" = ?
` `
@@ -113,6 +117,7 @@ func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error)
&i.CreatedAt, &i.CreatedAt,
&i.OAuthName, &i.OAuthName,
&i.OAuthSub, &i.OAuthSub,
&i.LdapGroups,
) )
return i, err return i, err
} }
@@ -127,9 +132,10 @@ UPDATE "sessions" SET
"oauth_groups" = ?, "oauth_groups" = ?,
"expiry" = ?, "expiry" = ?,
"oauth_name" = ?, "oauth_name" = ?,
"oauth_sub" = ? "oauth_sub" = ?,
"ldap_groups" = ?
WHERE "uuid" = ? WHERE "uuid" = ?
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups
` `
type UpdateSessionParams struct { type UpdateSessionParams struct {
@@ -142,6 +148,7 @@ type UpdateSessionParams struct {
Expiry int64 Expiry int64
OAuthName string OAuthName string
OAuthSub string OAuthSub string
LdapGroups string
UUID string UUID string
} }
@@ -156,6 +163,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
arg.Expiry, arg.Expiry,
arg.OAuthName, arg.OAuthName,
arg.OAuthSub, arg.OAuthSub,
arg.LdapGroups,
arg.UUID, arg.UUID,
) )
var i Session var i Session
@@ -171,6 +179,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
&i.CreatedAt, &i.CreatedAt,
&i.OAuthName, &i.OAuthName,
&i.OAuthSub, &i.OAuthSub,
&i.LdapGroups,
) )
return i, err return i, err
} }

View File

@@ -70,7 +70,7 @@ func (auth *AuthService) SearchUser(username string) config.UserSearch {
} }
if auth.ldap != nil { if auth.ldap != nil {
userDN, err := auth.ldap.Search(username) userDN, err := auth.ldap.GetUserDN(username)
if err != nil { if err != nil {
log.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP") log.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
@@ -131,6 +131,19 @@ func (auth *AuthService) GetLocalUser(username string) config.User {
return config.User{} return config.User{}
} }
func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
groups, err := auth.ldap.GetUserGroups(userDN)
if err != nil {
return config.LdapUser{}, err
}
return config.LdapUser{
DN: userDN,
Groups: groups,
}, nil
}
func (auth *AuthService) CheckPassword(user config.User, password string) bool { func (auth *AuthService) CheckPassword(user config.User, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
} }
@@ -217,6 +230,7 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
CreatedAt: time.Now().Unix(), CreatedAt: time.Now().Unix(),
OAuthName: data.OAuthName, OAuthName: data.OAuthName,
OAuthSub: data.OAuthSub, OAuthSub: data.OAuthSub,
LdapGroups: data.LdapGroups,
} }
_, err = auth.queries.CreateSession(c, session) _, err = auth.queries.CreateSession(c, session)
@@ -270,6 +284,7 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
OAuthName: session.OAuthName, OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub, OAuthSub: session.OAuthSub,
UUID: session.UUID, UUID: session.UUID,
LdapGroups: session.LdapGroups,
}) })
if err != nil { if err != nil {
@@ -346,6 +361,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e
OAuthGroups: session.OAuthGroups, OAuthGroups: session.OAuthGroups,
OAuthName: session.OAuthName, OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub, OAuthSub: session.OAuthSub,
LdapGroups: session.LdapGroups,
}, nil }, nil
} }

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"slices" "slices"
"strings"
"sync" "sync"
"time" "time"
@@ -117,7 +118,7 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
return ldap.conn, nil return ldap.conn, nil
} }
func (ldap *LdapService) Search(username string) (string, error) { func (ldap *LdapService) GetUserDN(username string) (string, error) {
// Escape the username to prevent LDAP injection // Escape the username to prevent LDAP injection
escapedUsername := ldapgo.EscapeFilter(username) escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername) filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
@@ -146,7 +147,7 @@ func (ldap *LdapService) Search(username string) (string, error) {
return userDN, nil return userDN, nil
} }
func (ldap *LdapService) GetUserGroups(username string) ([]string, error) { func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
searchRequest := ldapgo.NewSearchRequest( searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN, ldap.config.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
@@ -163,13 +164,24 @@ func (ldap *LdapService) GetUserGroups(username string) ([]string, error) {
return []string{}, err return []string{}, err
} }
groups := []string{} groupDNs := []string{}
for _, entry := range searchResult.Entries { for _, entry := range searchResult.Entries {
memberAttributes := entry.GetAttributeValues("uniquemember") memberAttributes := entry.GetAttributeValues("uniquemember")
// no need to escape username here, if it's malicious it won't match anything // no need to escape username here, if it's malicious it won't match anything
if slices.Contains(memberAttributes, fmt.Sprintf(ldap.config.SearchFilter, username)) { if slices.Contains(memberAttributes, userDN) {
groups = append(groups, entry.DN) groupDNs = append(groupDNs, entry.DN)
}
}
// Should work for most ldap providers?
groups := []string{}
for _, groupDN := range groupDNs {
groupDN = strings.TrimPrefix(groupDN, "cn=")
parts := strings.SplitN(groupDN, ",", 2)
if len(parts) > 0 {
groups = append(groups, parts[0])
} }
} }

View File

@@ -10,9 +10,10 @@ INSERT INTO sessions (
"expiry", "expiry",
"created_at", "created_at",
"oauth_name", "oauth_name",
"oauth_sub" "oauth_sub",
"ldap_groups"
) VALUES ( ) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
) )
RETURNING *; RETURNING *;
@@ -34,7 +35,8 @@ UPDATE "sessions" SET
"oauth_groups" = ?, "oauth_groups" = ?,
"expiry" = ?, "expiry" = ?,
"oauth_name" = ?, "oauth_name" = ?,
"oauth_sub" = ? "oauth_sub" = ?,
"ldap_groups" = ?
WHERE "uuid" = ? WHERE "uuid" = ?
RETURNING *; RETURNING *;

View File

@@ -9,5 +9,6 @@ CREATE TABLE IF NOT EXISTS "sessions" (
"expiry" INTEGER NOT NULL, "expiry" INTEGER NOT NULL,
"created_at" INTEGER NOT NULL, "created_at" INTEGER NOT NULL,
"oauth_name" TEXT NULL, "oauth_name" TEXT NULL,
"oauth_sub" TEXT NULL "oauth_sub" TEXT NULL,
"ldap_groups" TEXT NULL
); );

View File

@@ -19,3 +19,5 @@ sql:
go_type: "string" go_type: "string"
- column: "sessions.oauth_sub" - column: "sessions.oauth_sub"
go_type: "string" go_type: "string"
- column: "sessions.ldap_groups"
go_type: "string"