From 1b2bf3902cbb2b589b60c3f71548c01e480d99e7 Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 9 Jan 2026 23:23:36 +0200 Subject: [PATCH] feat: retrieve and store groups from ldap provider --- .../migrations/000005_ldap_groups.down.sql | 1 + .../migrations/000005_ldap_groups.up.sql | 1 + internal/config/config.go | 5 +++++ internal/controller/user_controller.go | 15 +++++++++++++ internal/middleware/context_middleware.go | 15 +++++++++++-- internal/repository/models.go | 1 + internal/repository/queries.sql.go | 21 +++++++++++++----- internal/service/auth_service.go | 18 ++++++++++++++- internal/service/ldap_service.go | 22 ++++++++++++++----- sql/queries.sql | 8 ++++--- sql/schema.sql | 3 ++- sqlc.yml | 2 ++ 12 files changed, 94 insertions(+), 18 deletions(-) create mode 100644 internal/assets/migrations/000005_ldap_groups.down.sql create mode 100644 internal/assets/migrations/000005_ldap_groups.up.sql diff --git a/internal/assets/migrations/000005_ldap_groups.down.sql b/internal/assets/migrations/000005_ldap_groups.down.sql new file mode 100644 index 0000000..047c05c --- /dev/null +++ b/internal/assets/migrations/000005_ldap_groups.down.sql @@ -0,0 +1 @@ +ALTER TABLE "sessions" DROP COLUMN "ldap_groups"; diff --git a/internal/assets/migrations/000005_ldap_groups.up.sql b/internal/assets/migrations/000005_ldap_groups.up.sql new file mode 100644 index 0000000..b75f36f --- /dev/null +++ b/internal/assets/migrations/000005_ldap_groups.up.sql @@ -0,0 +1 @@ +ALTER TABLE "sessions" ADD COLUMN "ldap_groups" TEXT; diff --git a/internal/config/config.go b/internal/config/config.go index ad6f25f..2e6cc6b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -122,6 +122,11 @@ type User struct { TotpSecret string } +type LdapUser struct { + DN string + Groups []string +} + type UserSearch struct { Username string Type string // local, ldap or unknown diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index 8d32681..c85c451 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -141,6 +141,21 @@ func (controller *UserController) loginHandler(c *gin.Context) { Provider: "username", } + if userSearch.Type == "ldap" { + ldapUser, err := controller.auth.GetLdapUser(userSearch.Username) + + if err != nil { + log.Error().Err(err).Str("username", req.Username).Msg("Failed to get LDAP user details") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + sessionCookie.LdapGroups = strings.Join(ldapUser.Groups, ",") + } + log.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") err = controller.auth.CreateSessionCookie(c, &sessionCookie) diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index a6bddc9..4ed1050 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -74,6 +74,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { Email: cookie.Email, Provider: "username", IsLoggedIn: true, + LdapGroups: cookie.LdapGroups, }) c.Next() return @@ -155,7 +156,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { Username: user.Username, Name: utils.Capitalize(user.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.config.CookieDomain), - Provider: "basic", + Provider: "username", IsLoggedIn: true, TotpEnabled: user.TotpSecret != "", }) @@ -163,12 +164,22 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return case "ldap": log.Debug().Msg("Basic auth user is LDAP") + + ldapUser, err := m.auth.GetLdapUser(basic.Username) + + if err != nil { + log.Debug().Err(err).Msg("Error retrieving LDAP user details") + c.Next() + return + } + c.Set("context", &config.UserContext{ Username: basic.Username, Name: utils.Capitalize(basic.Username), Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain), - Provider: "basic", + Provider: "ldap", IsLoggedIn: true, + LdapGroups: strings.Join(ldapUser.Groups, ","), }) c.Next() return diff --git a/internal/repository/models.go b/internal/repository/models.go index 61f7f80..b1879c0 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -16,4 +16,5 @@ type Session struct { CreatedAt int64 OAuthName string OAuthSub string + LdapGroups string } diff --git a/internal/repository/queries.sql.go b/internal/repository/queries.sql.go index e171b7a..78057e5 100644 --- a/internal/repository/queries.sql.go +++ b/internal/repository/queries.sql.go @@ -21,11 +21,12 @@ INSERT INTO sessions ( "expiry", "created_at", "oauth_name", - "oauth_sub" + "oauth_sub", + "ldap_groups" ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ) -RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub +RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups ` type CreateSessionParams struct { @@ -40,6 +41,7 @@ type CreateSessionParams struct { CreatedAt int64 OAuthName string OAuthSub string + LdapGroups string } func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) { @@ -55,6 +57,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S arg.CreatedAt, arg.OAuthName, arg.OAuthSub, + arg.LdapGroups, ) var i Session err := row.Scan( @@ -69,6 +72,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S &i.CreatedAt, &i.OAuthName, &i.OAuthSub, + &i.LdapGroups, ) return i, err } @@ -94,7 +98,7 @@ func (q *Queries) DeleteSession(ctx context.Context, uuid string) error { } const getSession = `-- name: GetSession :one -SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub FROM "sessions" +SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups FROM "sessions" WHERE "uuid" = ? ` @@ -113,6 +117,7 @@ func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error) &i.CreatedAt, &i.OAuthName, &i.OAuthSub, + &i.LdapGroups, ) return i, err } @@ -127,9 +132,10 @@ UPDATE "sessions" SET "oauth_groups" = ?, "expiry" = ?, "oauth_name" = ?, - "oauth_sub" = ? + "oauth_sub" = ?, + "ldap_groups" = ? WHERE "uuid" = ? -RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub +RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups ` type UpdateSessionParams struct { @@ -142,6 +148,7 @@ type UpdateSessionParams struct { Expiry int64 OAuthName string OAuthSub string + LdapGroups string UUID string } @@ -156,6 +163,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S arg.Expiry, arg.OAuthName, arg.OAuthSub, + arg.LdapGroups, arg.UUID, ) var i Session @@ -171,6 +179,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S &i.CreatedAt, &i.OAuthName, &i.OAuthSub, + &i.LdapGroups, ) return i, err } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index cb93c92..4e14c59 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -70,7 +70,7 @@ func (auth *AuthService) SearchUser(username string) config.UserSearch { } if auth.ldap != nil { - userDN, err := auth.ldap.Search(username) + userDN, err := auth.ldap.GetUserDN(username) if err != nil { log.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP") @@ -131,6 +131,19 @@ func (auth *AuthService) GetLocalUser(username string) config.User { return config.User{} } +func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) { + groups, err := auth.ldap.GetUserGroups(userDN) + + if err != nil { + return config.LdapUser{}, err + } + + return config.LdapUser{ + DN: userDN, + Groups: groups, + }, nil +} + func (auth *AuthService) CheckPassword(user config.User, password string) bool { return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil } @@ -217,6 +230,7 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se CreatedAt: time.Now().Unix(), OAuthName: data.OAuthName, OAuthSub: data.OAuthSub, + LdapGroups: data.LdapGroups, } _, err = auth.queries.CreateSession(c, session) @@ -270,6 +284,7 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error { OAuthName: session.OAuthName, OAuthSub: session.OAuthSub, UUID: session.UUID, + LdapGroups: session.LdapGroups, }) if err != nil { @@ -346,6 +361,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e OAuthGroups: session.OAuthGroups, OAuthName: session.OAuthName, OAuthSub: session.OAuthSub, + LdapGroups: session.LdapGroups, }, nil } diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index d88d180..fc0e6c6 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" "slices" + "strings" "sync" "time" @@ -117,7 +118,7 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) { return ldap.conn, nil } -func (ldap *LdapService) Search(username string) (string, error) { +func (ldap *LdapService) GetUserDN(username string) (string, error) { // Escape the username to prevent LDAP injection escapedUsername := ldapgo.EscapeFilter(username) filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername) @@ -146,7 +147,7 @@ func (ldap *LdapService) Search(username string) (string, error) { return userDN, nil } -func (ldap *LdapService) GetUserGroups(username string) ([]string, error) { +func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) { searchRequest := ldapgo.NewSearchRequest( ldap.config.BaseDN, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, @@ -163,13 +164,24 @@ func (ldap *LdapService) GetUserGroups(username string) ([]string, error) { return []string{}, err } - groups := []string{} + groupDNs := []string{} for _, entry := range searchResult.Entries { memberAttributes := entry.GetAttributeValues("uniquemember") // no need to escape username here, if it's malicious it won't match anything - if slices.Contains(memberAttributes, fmt.Sprintf(ldap.config.SearchFilter, username)) { - groups = append(groups, entry.DN) + if slices.Contains(memberAttributes, userDN) { + groupDNs = append(groupDNs, entry.DN) + } + } + + // Should work for most ldap providers? + groups := []string{} + + for _, groupDN := range groupDNs { + groupDN = strings.TrimPrefix(groupDN, "cn=") + parts := strings.SplitN(groupDN, ",", 2) + if len(parts) > 0 { + groups = append(groups, parts[0]) } } diff --git a/sql/queries.sql b/sql/queries.sql index 9fde4e2..9004edd 100644 --- a/sql/queries.sql +++ b/sql/queries.sql @@ -10,9 +10,10 @@ INSERT INTO sessions ( "expiry", "created_at", "oauth_name", - "oauth_sub" + "oauth_sub", + "ldap_groups" ) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? + ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ) RETURNING *; @@ -34,7 +35,8 @@ UPDATE "sessions" SET "oauth_groups" = ?, "expiry" = ?, "oauth_name" = ?, - "oauth_sub" = ? + "oauth_sub" = ?, + "ldap_groups" = ? WHERE "uuid" = ? RETURNING *; diff --git a/sql/schema.sql b/sql/schema.sql index a7f37eb..bd5abdf 100644 --- a/sql/schema.sql +++ b/sql/schema.sql @@ -9,5 +9,6 @@ CREATE TABLE IF NOT EXISTS "sessions" ( "expiry" INTEGER NOT NULL, "created_at" INTEGER NOT NULL, "oauth_name" TEXT NULL, - "oauth_sub" TEXT NULL + "oauth_sub" TEXT NULL, + "ldap_groups" TEXT NULL ); diff --git a/sqlc.yml b/sqlc.yml index 77b3a71..b9cf1ea 100644 --- a/sqlc.yml +++ b/sqlc.yml @@ -19,3 +19,5 @@ sql: go_type: "string" - column: "sessions.oauth_sub" go_type: "string" + - column: "sessions.ldap_groups" + go_type: "string"