Compare commits

...

3 Commits

Author SHA1 Message Date
Stavros 5093200023 fix: review nitpicks 2026-01-16 17:53:08 +02:00
Stavros 95ed36db8b feat: store ldap group results in cache 2026-01-16 17:50:04 +02:00
Stavros 4fc18325a0 refactor: rework ldap group fetching logic 2026-01-16 17:06:22 +02:00
23 changed files with 208 additions and 119 deletions
+13
View File
@@ -0,0 +1,13 @@
[
{
"label": "Attach to remote Delve",
"adapter": "Delve",
"mode": "remote",
"remotePath": "/tinyauth",
"request": "attach",
"tcp_connection": {
"host": "127.0.0.1",
"port": 4000,
},
},
]
+5
View File
@@ -62,3 +62,8 @@ develop:
# Production
prod:
docker compose -f $(PROD_COMPOSE) up --force-recreate --pull=always --remove-orphans
# SQL
.PHONY: sql
sql:
sqlc generate
+6 -5
View File
@@ -21,9 +21,9 @@ func NewTinyauthCmdConfiguration() *config.Config {
Address: "0.0.0.0",
},
Auth: config.AuthConfig{
SessionExpiry: 3600,
SessionMaxLifetime: 0,
LoginTimeout: 300,
SessionExpiry: 86400, // 1 day
SessionMaxLifetime: 0, // disabled
LoginTimeout: 300, // 5 minutes
LoginMaxRetries: 3,
},
UI: config.UIConfig{
@@ -32,8 +32,9 @@ func NewTinyauthCmdConfiguration() *config.Config {
BackgroundImage: "/background.jpg",
},
Ldap: config.LdapConfig{
Insecure: false,
SearchFilter: "(uid=%s)",
Insecure: false,
SearchFilter: "(uid=%s)",
GroupCacheTTL: 900, // 15 minutes
},
Log: config.LogConfig{
Level: "info",
+4 -2
View File
@@ -50,10 +50,12 @@ export const LoginPage = () => {
const redirectUri = searchParams.get("redirect_uri");
const oauthProviders = providers.filter(
(provider) => provider.id !== "username",
(provider) => provider.id !== "local" && provider.id !== "ldap",
);
const userAuthConfigured =
providers.find((provider) => provider.id === "username") !== undefined;
providers.find(
(provider) => provider.id === "local" || provider.id === "ldap",
) !== undefined;
const oauthMutation = useMutation({
mutationFn: (provider: string) =>
@@ -1 +0,0 @@
ALTER TABLE "sessions" DROP COLUMN "ldap_groups";
@@ -1 +0,0 @@
ALTER TABLE "sessions" ADD COLUMN "ldap_groups" TEXT;
+11 -3
View File
@@ -144,10 +144,18 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name
})
if services.authService.UserAuthConfigured() {
if services.authService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{
Name: "Username",
ID: "username",
Name: "Local",
ID: "local",
OAuth: false,
})
}
if services.authService.LdapAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{
Name: "LDAP",
ID: "ldap",
OAuth: false,
})
}
+1
View File
@@ -67,6 +67,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.context.sessionCookieName,
IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
}, dockerService, services.ldapService, queries)
err = authService.Init()
+15 -8
View File
@@ -67,14 +67,15 @@ type UIConfig struct {
}
type LdapConfig struct {
Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
BaseDN string `description:"Base DN for LDAP searches." yaml:"baseDn"`
Insecure bool `description:"Allow insecure LDAP connections." yaml:"insecure"`
SearchFilter string `description:"LDAP search filter." yaml:"searchFilter"`
AuthCert string `description:"Certificate for mTLS authentication." yaml:"authCert"`
AuthKey string `description:"Certificate key for mTLS authentication." yaml:"authKey"`
GroupCacheTTL int `description:"Cache duration for LDAP group membership in seconds." yaml:"groupCacheTTL"`
}
type LogConfig struct {
@@ -153,6 +154,7 @@ type UserContext struct {
Name string
Email string
IsLoggedIn bool
IsBasicAuth bool
OAuth bool
Provider string
TotpPending bool
@@ -189,6 +191,7 @@ type App struct {
IP AppIP `description:"IP access configuration." yaml:"ip"`
Response AppResponse `description:"Response customization." yaml:"response"`
Path AppPath `description:"Path access configuration." yaml:"path"`
LDAP AppLDAP `description:"LDAP access configuration." yaml:"ldap"`
}
type AppConfig struct {
@@ -205,6 +208,10 @@ type AppOAuth struct {
Groups string `description:"Comma-separated list of required OAuth groups." yaml:"groups"`
}
type AppLDAP struct {
Groups string `description:"Comma-separated list of required LDAP groups." yaml:"groups"`
}
type AppIP struct {
Allow []string `description:"List of allowed IPs or CIDR ranges." yaml:"allow"`
Block []string `description:"List of blocked IPs or CIDR ranges." yaml:"block"`
@@ -21,7 +21,6 @@ type UserContextResponse struct {
OAuth bool `json:"oauth"`
TotpPending bool `json:"totpPending"`
OAuthName string `json:"oauthName"`
OAuthSub string `json:"oauthSub"`
}
type AppContextResponse struct {
@@ -90,7 +89,6 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
OAuth: context.OAuth,
TotpPending: context.TotpPending,
OAuthName: context.OAuthName,
OAuthSub: context.OAuthSub,
}
if err != nil {
@@ -16,8 +16,8 @@ import (
var controllerCfg = controller.ContextControllerConfig{
Providers: []controller.Provider{
{
Name: "Username",
ID: "username",
Name: "Local",
ID: "local",
OAuth: false,
},
{
@@ -40,8 +40,9 @@ var userContext = config.UserContext{
Name: "testuser",
Email: "test@example.com",
IsLoggedIn: true,
IsBasicAuth: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,
+1 -1
View File
@@ -189,7 +189,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = user.PreferredUsername
} else {
tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
username = strings.Replace(user.Email, "@", "_", -1)
username = strings.Replace(user.Email, "@", "_", 1)
}
sessionCookie := repository.Session{
+17 -5
View File
@@ -173,7 +173,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
if userContext.Provider == "basic" && userContext.TotpEnabled {
if userContext.IsBasicAuth && userContext.TotpEnabled {
tlog.App.Debug().Msg("User has TOTP enabled, denying basic auth access")
userContext.IsLoggedIn = false
}
@@ -212,11 +212,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth {
groupOK := controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
if userContext.OAuth || userContext.Provider == "ldap" {
var groupOK bool
if userContext.OAuth {
groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
} else {
groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups)
}
if !groupOK {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User OAuth groups do not match resource requirements")
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User groups do not match resource requirements")
if req.Proxy == "nginx" || !isBrowser {
c.JSON(403, gin.H{
@@ -251,7 +257,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
if userContext.Provider == "ldap" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups))
} else if userContext.Provider != "local" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
}
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub))
controller.setHeaders(c, acls)
+4 -3
View File
@@ -147,7 +147,7 @@ func TestProxyHandler(t *testing.T) {
Username: "testuser",
Name: "testuser",
Email: "testuser@example.com",
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
})
@@ -164,7 +164,7 @@ func TestProxyHandler(t *testing.T) {
Email: "testuser@example.com",
IsLoggedIn: true,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,
@@ -192,8 +192,9 @@ func TestProxyHandler(t *testing.T) {
Name: "testuser",
Email: "testuser@example.com",
IsLoggedIn: true,
IsBasicAuth: true,
OAuth: false,
Provider: "basic",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: true,
+4 -15
View File
@@ -116,7 +116,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Username: user.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
Provider: "username",
Provider: "local",
TotpPending: true,
})
@@ -142,22 +142,11 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Username: req.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
Provider: "username",
Provider: "local",
}
if userSearch.Type == "ldap" {
ldapUser, err := controller.auth.GetLdapUser(userSearch.Username)
if err != nil {
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Failed to get LDAP user details")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
sessionCookie.LdapGroups = strings.Join(ldapUser.Groups, ",")
sessionCookie.Provider = "ldap"
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
@@ -267,7 +256,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.config.CookieDomain),
Provider: "username",
Provider: "local",
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
+3 -3
View File
@@ -204,7 +204,7 @@ func TestTotpHandler(t *testing.T) {
Email: "totpuser@example.com",
IsLoggedIn: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: true,
OAuthGroups: "",
TotpEnabled: true,
@@ -267,7 +267,7 @@ func TestTotpHandler(t *testing.T) {
Email: "totpuser@example.com",
IsLoggedIn: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: true,
OAuthGroups: "",
TotpEnabled: true,
@@ -290,7 +290,7 @@ func TestTotpHandler(t *testing.T) {
Email: "totpuser@example.com",
IsLoggedIn: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,
+35 -12
View File
@@ -49,7 +49,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "username",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
})
@@ -58,23 +58,44 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
}
switch cookie.Provider {
case "username":
case "local", "ldap":
userSearch := m.auth.SearchUser(cookie.Username)
if userSearch.Type == "unknown" || userSearch.Type == "error" {
if userSearch.Type == "unknown" {
tlog.App.Debug().Msg("User from session cookie not found")
m.auth.DeleteSessionCookie(c)
goto basic
}
if userSearch.Type != cookie.Provider {
tlog.App.Warn().Msg("User type from session cookie does not match user search type")
m.auth.DeleteSessionCookie(c)
c.Next()
return
}
var ldapGroups []string
if cookie.Provider == "ldap" {
ldapUser, err := m.auth.GetLdapUser(userSearch.Username)
if err != nil {
tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details")
c.Next()
return
}
ldapGroups = ldapUser.Groups
}
m.auth.RefreshSessionCookie(c)
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "username",
Provider: cookie.Provider,
IsLoggedIn: true,
LdapGroups: cookie.LdapGroups,
LdapGroups: strings.Join(ldapGroups, ","),
})
c.Next()
return
@@ -156,9 +177,10 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.config.CookieDomain),
Provider: "username",
Provider: "local",
IsLoggedIn: true,
TotpEnabled: user.TotpSecret != "",
IsBasicAuth: true,
})
c.Next()
return
@@ -174,12 +196,13 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
}
c.Set("context", &config.UserContext{
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
Provider: "ldap",
IsLoggedIn: true,
LdapGroups: strings.Join(ldapUser.Groups, ","),
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
Provider: "ldap",
IsLoggedIn: true,
LdapGroups: strings.Join(ldapUser.Groups, ","),
IsBasicAuth: true,
})
c.Next()
return
-1
View File
@@ -16,5 +16,4 @@ type Session struct {
CreatedAt int64
OAuthName string
OAuthSub string
LdapGroups string
}
+6 -15
View File
@@ -21,12 +21,11 @@ INSERT INTO sessions (
"expiry",
"created_at",
"oauth_name",
"oauth_sub",
"ldap_groups"
"oauth_sub"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
`
type CreateSessionParams struct {
@@ -41,7 +40,6 @@ type CreateSessionParams struct {
CreatedAt int64
OAuthName string
OAuthSub string
LdapGroups string
}
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
@@ -57,7 +55,6 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
arg.CreatedAt,
arg.OAuthName,
arg.OAuthSub,
arg.LdapGroups,
)
var i Session
err := row.Scan(
@@ -72,7 +69,6 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
&i.CreatedAt,
&i.OAuthName,
&i.OAuthSub,
&i.LdapGroups,
)
return i, err
}
@@ -98,7 +94,7 @@ func (q *Queries) DeleteSession(ctx context.Context, uuid string) error {
}
const getSession = `-- name: GetSession :one
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups FROM "sessions"
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub FROM "sessions"
WHERE "uuid" = ?
`
@@ -117,7 +113,6 @@ func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error)
&i.CreatedAt,
&i.OAuthName,
&i.OAuthSub,
&i.LdapGroups,
)
return i, err
}
@@ -132,10 +127,9 @@ UPDATE "sessions" SET
"oauth_groups" = ?,
"expiry" = ?,
"oauth_name" = ?,
"oauth_sub" = ?,
"ldap_groups" = ?
"oauth_sub" = ?
WHERE "uuid" = ?
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
`
type UpdateSessionParams struct {
@@ -148,7 +142,6 @@ type UpdateSessionParams struct {
Expiry int64
OAuthName string
OAuthSub string
LdapGroups string
UUID string
}
@@ -163,7 +156,6 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
arg.Expiry,
arg.OAuthName,
arg.OAuthSub,
arg.LdapGroups,
arg.UUID,
)
var i Session
@@ -179,7 +171,6 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
&i.CreatedAt,
&i.OAuthName,
&i.OAuthSub,
&i.LdapGroups,
)
return i, err
}
+61 -17
View File
@@ -19,6 +19,11 @@ import (
"golang.org/x/crypto/bcrypt"
)
type LdapGroupsCache struct {
Groups []string
Expires time.Time
}
type LoginAttempt struct {
FailedAttempts int
LastAttempt time.Time
@@ -36,24 +41,28 @@ type AuthServiceConfig struct {
LoginMaxRetries int
SessionCookieName string
IP config.IPConfig
LDAPGroupsCacheTTL int
}
type AuthService struct {
config AuthServiceConfig
docker *DockerService
loginAttempts map[string]*LoginAttempt
loginMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
config AuthServiceConfig
docker *DockerService
loginAttempts map[string]*LoginAttempt
ldapGroupsCache map[string]*LdapGroupsCache
loginMutex sync.RWMutex
ldapGroupsMutex sync.RWMutex
ldap *LdapService
queries *repository.Queries
}
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries) *AuthService {
return &AuthService{
config: config,
docker: docker,
loginAttempts: make(map[string]*LoginAttempt),
ldap: ldap,
queries: queries,
config: config,
docker: docker,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
ldap: ldap,
queries: queries,
}
}
@@ -75,7 +84,7 @@ func (auth *AuthService) SearchUser(username string) config.UserSearch {
if err != nil {
tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
return config.UserSearch{
Type: "error",
Type: "unknown",
}
}
@@ -132,12 +141,30 @@ func (auth *AuthService) GetLocalUser(username string) config.User {
}
func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
auth.ldapGroupsMutex.Lock()
entry, exists := auth.ldapGroupsCache[userDN]
auth.ldapGroupsMutex.Unlock()
if exists && time.Now().Before(entry.Expires) {
return config.LdapUser{
DN: userDN,
Groups: entry.Groups,
}, nil
}
groups, err := auth.ldap.GetUserGroups(userDN)
if err != nil {
return config.LdapUser{}, err
}
auth.ldapGroupsMutex.Lock()
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
Groups: groups,
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
}
auth.ldapGroupsMutex.Unlock()
return config.LdapUser{
DN: userDN,
Groups: groups,
@@ -230,7 +257,6 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
CreatedAt: time.Now().Unix(),
OAuthName: data.OAuthName,
OAuthSub: data.OAuthSub,
LdapGroups: data.LdapGroups,
}
_, err = auth.queries.CreateSession(c, session)
@@ -284,7 +310,6 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub,
UUID: session.UUID,
LdapGroups: session.LdapGroups,
})
if err != nil {
@@ -361,12 +386,15 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e
OAuthGroups: session.OAuthGroups,
OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub,
LdapGroups: session.LdapGroups,
}, nil
}
func (auth *AuthService) UserAuthConfigured() bool {
return len(auth.config.Users) > 0 || auth.ldap != nil
func (auth *AuthService) LocalAuthConfigured() bool {
return len(auth.config.Users) > 0
}
func (auth *AuthService) LdapAuthConfigured() bool {
return auth.ldap != nil
}
func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
@@ -409,6 +437,22 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte
return false
}
func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if requiredGroups == "" {
return true
}
for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
return false
}
func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
// Check for block list
if path.Block != "" {
+14 -15
View File
@@ -4,8 +4,6 @@ import (
"context"
"crypto/tls"
"fmt"
"slices"
"strings"
"sync"
"time"
@@ -148,11 +146,13 @@ func (ldap *LdapService) GetUserDN(username string) (string, error) {
}
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
escapedUserDN := ldapgo.EscapeFilter(userDN)
searchRequest := ldapgo.NewSearchRequest(
ldap.config.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
"(objectclass=groupOfUniqueNames)",
[]string{"uniquemember"},
fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN),
[]string{"dn"},
nil,
)
@@ -167,22 +167,21 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
groupDNs := []string{}
for _, entry := range searchResult.Entries {
memberAttributes := entry.GetAttributeValues("uniquemember")
// no need to escape username here, if it's malicious it won't match anything
if slices.Contains(memberAttributes, userDN) {
groupDNs = append(groupDNs, entry.DN)
}
groupDNs = append(groupDNs, entry.DN)
}
// Should work for most ldap providers?
groups := []string{}
for _, groupDN := range groupDNs {
groupDN = strings.TrimPrefix(groupDN, "cn=")
parts := strings.SplitN(groupDN, ",", 2)
if len(parts) > 0 {
groups = append(groups, parts[0])
// I guess it should work for most ldap providers
for _, dn := range groupDNs {
rdnParts, err := ldapgo.ParseDN(dn)
if err != nil {
return []string{}, err
}
if len(rdnParts.RDNs) == 0 || len(rdnParts.RDNs[0].Attributes) == 0 {
return []string{}, fmt.Errorf("invalid DN format: %s", dn)
}
groups = append(groups, rdnParts.RDNs[0].Attributes[0].Value)
}
return groups, nil
+3 -5
View File
@@ -10,10 +10,9 @@ INSERT INTO sessions (
"expiry",
"created_at",
"oauth_name",
"oauth_sub",
"ldap_groups"
"oauth_sub"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING *;
@@ -35,8 +34,7 @@ UPDATE "sessions" SET
"oauth_groups" = ?,
"expiry" = ?,
"oauth_name" = ?,
"oauth_sub" = ?,
"ldap_groups" = ?
"oauth_sub" = ?
WHERE "uuid" = ?
RETURNING *;
+1 -2
View File
@@ -9,6 +9,5 @@ CREATE TABLE IF NOT EXISTS "sessions" (
"expiry" INTEGER NOT NULL,
"created_at" INTEGER NOT NULL,
"oauth_name" TEXT NULL,
"oauth_sub" TEXT NULL,
"ldap_groups" TEXT NULL
"oauth_sub" TEXT NULL
);