mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-01-11 09:52:30 +00:00
feat: retrieve and store groups from ldap provider
This commit is contained in:
1
internal/assets/migrations/000005_ldap_groups.down.sql
Normal file
1
internal/assets/migrations/000005_ldap_groups.down.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TABLE "sessions" DROP COLUMN "ldap_groups";
|
||||
1
internal/assets/migrations/000005_ldap_groups.up.sql
Normal file
1
internal/assets/migrations/000005_ldap_groups.up.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TABLE "sessions" ADD COLUMN "ldap_groups" TEXT;
|
||||
@@ -122,6 +122,11 @@ type User struct {
|
||||
TotpSecret string
|
||||
}
|
||||
|
||||
type LdapUser struct {
|
||||
DN string
|
||||
Groups []string
|
||||
}
|
||||
|
||||
type UserSearch struct {
|
||||
Username string
|
||||
Type string // local, ldap or unknown
|
||||
|
||||
@@ -141,6 +141,21 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
||||
Provider: "username",
|
||||
}
|
||||
|
||||
if userSearch.Type == "ldap" {
|
||||
ldapUser, err := controller.auth.GetLdapUser(userSearch.Username)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("username", req.Username).Msg("Failed to get LDAP user details")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
"message": "Internal Server Error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sessionCookie.LdapGroups = strings.Join(ldapUser.Groups, ",")
|
||||
}
|
||||
|
||||
log.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
||||
|
||||
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
|
||||
|
||||
@@ -74,6 +74,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||
Email: cookie.Email,
|
||||
Provider: "username",
|
||||
IsLoggedIn: true,
|
||||
LdapGroups: cookie.LdapGroups,
|
||||
})
|
||||
c.Next()
|
||||
return
|
||||
@@ -155,7 +156,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||
Username: user.Username,
|
||||
Name: utils.Capitalize(user.Username),
|
||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.config.CookieDomain),
|
||||
Provider: "basic",
|
||||
Provider: "username",
|
||||
IsLoggedIn: true,
|
||||
TotpEnabled: user.TotpSecret != "",
|
||||
})
|
||||
@@ -163,12 +164,22 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
||||
return
|
||||
case "ldap":
|
||||
log.Debug().Msg("Basic auth user is LDAP")
|
||||
|
||||
ldapUser, err := m.auth.GetLdapUser(basic.Username)
|
||||
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Msg("Error retrieving LDAP user details")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("context", &config.UserContext{
|
||||
Username: basic.Username,
|
||||
Name: utils.Capitalize(basic.Username),
|
||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
|
||||
Provider: "basic",
|
||||
Provider: "ldap",
|
||||
IsLoggedIn: true,
|
||||
LdapGroups: strings.Join(ldapUser.Groups, ","),
|
||||
})
|
||||
c.Next()
|
||||
return
|
||||
|
||||
@@ -16,4 +16,5 @@ type Session struct {
|
||||
CreatedAt int64
|
||||
OAuthName string
|
||||
OAuthSub string
|
||||
LdapGroups string
|
||||
}
|
||||
|
||||
@@ -21,11 +21,12 @@ INSERT INTO sessions (
|
||||
"expiry",
|
||||
"created_at",
|
||||
"oauth_name",
|
||||
"oauth_sub"
|
||||
"oauth_sub",
|
||||
"ldap_groups"
|
||||
) VALUES (
|
||||
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||
)
|
||||
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
|
||||
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups
|
||||
`
|
||||
|
||||
type CreateSessionParams struct {
|
||||
@@ -40,6 +41,7 @@ type CreateSessionParams struct {
|
||||
CreatedAt int64
|
||||
OAuthName string
|
||||
OAuthSub string
|
||||
LdapGroups string
|
||||
}
|
||||
|
||||
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
|
||||
@@ -55,6 +57,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
|
||||
arg.CreatedAt,
|
||||
arg.OAuthName,
|
||||
arg.OAuthSub,
|
||||
arg.LdapGroups,
|
||||
)
|
||||
var i Session
|
||||
err := row.Scan(
|
||||
@@ -69,6 +72,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
|
||||
&i.CreatedAt,
|
||||
&i.OAuthName,
|
||||
&i.OAuthSub,
|
||||
&i.LdapGroups,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -94,7 +98,7 @@ func (q *Queries) DeleteSession(ctx context.Context, uuid string) error {
|
||||
}
|
||||
|
||||
const getSession = `-- name: GetSession :one
|
||||
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub FROM "sessions"
|
||||
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups FROM "sessions"
|
||||
WHERE "uuid" = ?
|
||||
`
|
||||
|
||||
@@ -113,6 +117,7 @@ func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error)
|
||||
&i.CreatedAt,
|
||||
&i.OAuthName,
|
||||
&i.OAuthSub,
|
||||
&i.LdapGroups,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -127,9 +132,10 @@ UPDATE "sessions" SET
|
||||
"oauth_groups" = ?,
|
||||
"expiry" = ?,
|
||||
"oauth_name" = ?,
|
||||
"oauth_sub" = ?
|
||||
"oauth_sub" = ?,
|
||||
"ldap_groups" = ?
|
||||
WHERE "uuid" = ?
|
||||
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
|
||||
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups
|
||||
`
|
||||
|
||||
type UpdateSessionParams struct {
|
||||
@@ -142,6 +148,7 @@ type UpdateSessionParams struct {
|
||||
Expiry int64
|
||||
OAuthName string
|
||||
OAuthSub string
|
||||
LdapGroups string
|
||||
UUID string
|
||||
}
|
||||
|
||||
@@ -156,6 +163,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
|
||||
arg.Expiry,
|
||||
arg.OAuthName,
|
||||
arg.OAuthSub,
|
||||
arg.LdapGroups,
|
||||
arg.UUID,
|
||||
)
|
||||
var i Session
|
||||
@@ -171,6 +179,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
|
||||
&i.CreatedAt,
|
||||
&i.OAuthName,
|
||||
&i.OAuthSub,
|
||||
&i.LdapGroups,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func (auth *AuthService) SearchUser(username string) config.UserSearch {
|
||||
}
|
||||
|
||||
if auth.ldap != nil {
|
||||
userDN, err := auth.ldap.Search(username)
|
||||
userDN, err := auth.ldap.GetUserDN(username)
|
||||
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
|
||||
@@ -131,6 +131,19 @@ func (auth *AuthService) GetLocalUser(username string) config.User {
|
||||
return config.User{}
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
||||
groups, err := auth.ldap.GetUserGroups(userDN)
|
||||
|
||||
if err != nil {
|
||||
return config.LdapUser{}, err
|
||||
}
|
||||
|
||||
return config.LdapUser{
|
||||
DN: userDN,
|
||||
Groups: groups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
|
||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
|
||||
}
|
||||
@@ -217,6 +230,7 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
|
||||
CreatedAt: time.Now().Unix(),
|
||||
OAuthName: data.OAuthName,
|
||||
OAuthSub: data.OAuthSub,
|
||||
LdapGroups: data.LdapGroups,
|
||||
}
|
||||
|
||||
_, err = auth.queries.CreateSession(c, session)
|
||||
@@ -270,6 +284,7 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
|
||||
OAuthName: session.OAuthName,
|
||||
OAuthSub: session.OAuthSub,
|
||||
UUID: session.UUID,
|
||||
LdapGroups: session.LdapGroups,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -346,6 +361,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e
|
||||
OAuthGroups: session.OAuthGroups,
|
||||
OAuthName: session.OAuthName,
|
||||
OAuthSub: session.OAuthSub,
|
||||
LdapGroups: session.LdapGroups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -117,7 +118,7 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
||||
return ldap.conn, nil
|
||||
}
|
||||
|
||||
func (ldap *LdapService) Search(username string) (string, error) {
|
||||
func (ldap *LdapService) GetUserDN(username string) (string, error) {
|
||||
// Escape the username to prevent LDAP injection
|
||||
escapedUsername := ldapgo.EscapeFilter(username)
|
||||
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
|
||||
@@ -146,7 +147,7 @@ func (ldap *LdapService) Search(username string) (string, error) {
|
||||
return userDN, nil
|
||||
}
|
||||
|
||||
func (ldap *LdapService) GetUserGroups(username string) ([]string, error) {
|
||||
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
||||
searchRequest := ldapgo.NewSearchRequest(
|
||||
ldap.config.BaseDN,
|
||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||
@@ -163,13 +164,24 @@ func (ldap *LdapService) GetUserGroups(username string) ([]string, error) {
|
||||
return []string{}, err
|
||||
}
|
||||
|
||||
groups := []string{}
|
||||
groupDNs := []string{}
|
||||
|
||||
for _, entry := range searchResult.Entries {
|
||||
memberAttributes := entry.GetAttributeValues("uniquemember")
|
||||
// no need to escape username here, if it's malicious it won't match anything
|
||||
if slices.Contains(memberAttributes, fmt.Sprintf(ldap.config.SearchFilter, username)) {
|
||||
groups = append(groups, entry.DN)
|
||||
if slices.Contains(memberAttributes, userDN) {
|
||||
groupDNs = append(groupDNs, entry.DN)
|
||||
}
|
||||
}
|
||||
|
||||
// Should work for most ldap providers?
|
||||
groups := []string{}
|
||||
|
||||
for _, groupDN := range groupDNs {
|
||||
groupDN = strings.TrimPrefix(groupDN, "cn=")
|
||||
parts := strings.SplitN(groupDN, ",", 2)
|
||||
if len(parts) > 0 {
|
||||
groups = append(groups, parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,9 +10,10 @@ INSERT INTO sessions (
|
||||
"expiry",
|
||||
"created_at",
|
||||
"oauth_name",
|
||||
"oauth_sub"
|
||||
"oauth_sub",
|
||||
"ldap_groups"
|
||||
) VALUES (
|
||||
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
@@ -34,7 +35,8 @@ UPDATE "sessions" SET
|
||||
"oauth_groups" = ?,
|
||||
"expiry" = ?,
|
||||
"oauth_name" = ?,
|
||||
"oauth_sub" = ?
|
||||
"oauth_sub" = ?,
|
||||
"ldap_groups" = ?
|
||||
WHERE "uuid" = ?
|
||||
RETURNING *;
|
||||
|
||||
|
||||
@@ -9,5 +9,6 @@ CREATE TABLE IF NOT EXISTS "sessions" (
|
||||
"expiry" INTEGER NOT NULL,
|
||||
"created_at" INTEGER NOT NULL,
|
||||
"oauth_name" TEXT NULL,
|
||||
"oauth_sub" TEXT NULL
|
||||
"oauth_sub" TEXT NULL,
|
||||
"ldap_groups" TEXT NULL
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user