refactor: rework ldap group fetching logic

This commit is contained in:
Stavros
2026-01-16 17:06:22 +02:00
parent 4917202879
commit 4fc18325a0
20 changed files with 129 additions and 76 deletions

13
.zed/debug.json Normal file
View File

@@ -0,0 +1,13 @@
[
{
"label": "Attach to remote Delve",
"adapter": "Delve",
"mode": "remote",
"remotePath": "/tinyauth",
"request": "attach",
"tcp_connection": {
"host": "127.0.0.1",
"port": 4000,
},
},
]

View File

@@ -62,3 +62,8 @@ develop:
# Production
prod:
docker compose -f $(PROD_COMPOSE) up --force-recreate --pull=always --remove-orphans
# SQL
.PHONY: sql
sql:
sqlc generate

View File

@@ -50,10 +50,12 @@ export const LoginPage = () => {
const redirectUri = searchParams.get("redirect_uri");
const oauthProviders = providers.filter(
(provider) => provider.id !== "username",
(provider) => provider.id !== "local" && provider.id != "ldap",
);
const userAuthConfigured =
providers.find((provider) => provider.id === "username") !== undefined;
providers.find(
(provider) => provider.id === "local" || provider.id === "ldap",
) !== undefined;
const oauthMutation = useMutation({
mutationFn: (provider: string) =>

View File

@@ -1 +0,0 @@
ALTER TABLE "sessions" DROP COLUMN "ldap_groups";

View File

@@ -1 +0,0 @@
ALTER TABLE "sessions" ADD COLUMN "ldap_groups" TEXT;

View File

@@ -144,10 +144,18 @@ func (app *BootstrapApp) Setup() error {
return configuredProviders[i].Name < configuredProviders[j].Name
})
if services.authService.UserAuthConfigured() {
if services.authService.LocalAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{
Name: "Username",
ID: "username",
Name: "Local",
ID: "local",
OAuth: false,
})
}
if services.authService.LdapAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{
Name: "LDAP",
ID: "ldap",
OAuth: false,
})
}

View File

@@ -153,6 +153,7 @@ type UserContext struct {
Name string
Email string
IsLoggedIn bool
IsBasicAuth bool
OAuth bool
Provider string
TotpPending bool
@@ -189,6 +190,7 @@ type App struct {
IP AppIP `description:"IP access configuration." yaml:"ip"`
Response AppResponse `description:"Response customization." yaml:"response"`
Path AppPath `description:"Path access configuration." yaml:"path"`
LDAP AppLDAP `description:"LDAP access configuration." yaml:"ldap"`
}
type AppConfig struct {
@@ -205,6 +207,10 @@ type AppOAuth struct {
Groups string `description:"Comma-separated list of required OAuth groups." yaml:"groups"`
}
type AppLDAP struct {
Groups string `description:"Comma-separated list of required LDAP groups." yaml:"groups"`
}
type AppIP struct {
Allow []string `description:"List of allowed IPs or CIDR ranges." yaml:"allow"`
Block []string `description:"List of blocked IPs or CIDR ranges." yaml:"block"`

View File

@@ -21,7 +21,6 @@ type UserContextResponse struct {
OAuth bool `json:"oauth"`
TotpPending bool `json:"totpPending"`
OAuthName string `json:"oauthName"`
OAuthSub string `json:"oauthSub"`
}
type AppContextResponse struct {
@@ -90,7 +89,6 @@ func (controller *ContextController) userContextHandler(c *gin.Context) {
OAuth: context.OAuth,
TotpPending: context.TotpPending,
OAuthName: context.OAuthName,
OAuthSub: context.OAuthSub,
}
if err != nil {

View File

@@ -16,8 +16,8 @@ import (
var controllerCfg = controller.ContextControllerConfig{
Providers: []controller.Provider{
{
Name: "Username",
ID: "username",
Name: "Local",
ID: "local",
OAuth: false,
},
{
@@ -40,6 +40,7 @@ var userContext = config.UserContext{
Name: "testuser",
Email: "test@example.com",
IsLoggedIn: true,
IsBasicAuth: false,
OAuth: false,
Provider: "username",
TotpPending: false,

View File

@@ -189,7 +189,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = user.PreferredUsername
} else {
tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username")
username = strings.Replace(user.Email, "@", "_", -1)
username = strings.Replace(user.Email, "@", "_", 1)
}
sessionCookie := repository.Session{

View File

@@ -173,7 +173,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
if userContext.Provider == "basic" && userContext.TotpEnabled {
if userContext.IsBasicAuth && userContext.TotpEnabled {
tlog.App.Debug().Msg("User has TOTP enabled, denying basic auth access")
userContext.IsLoggedIn = false
}
@@ -212,11 +212,17 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return
}
if userContext.OAuth {
groupOK := controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
if userContext.OAuth || userContext.Provider == "ldap" {
var groupOK bool
if userContext.OAuth {
groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
} else {
groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups)
}
if !groupOK {
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User OAuth groups do not match resource requirements")
tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(host, ".")[0]).Msg("User groups do not match resource requirements")
if req.Proxy == "nginx" || !isBrowser {
c.JSON(403, gin.H{
@@ -251,7 +257,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
if userContext.Provider == "ldap" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups))
} else if userContext.Provider != "local" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
}
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub))
controller.setHeaders(c, acls)

View File

@@ -192,6 +192,7 @@ func TestProxyHandler(t *testing.T) {
Name: "testuser",
Email: "testuser@example.com",
IsLoggedIn: true,
IsBasicAuth: true,
OAuth: false,
Provider: "basic",
TotpPending: false,

View File

@@ -116,7 +116,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Username: user.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
Provider: "username",
Provider: "local",
TotpPending: true,
})
@@ -142,22 +142,11 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Username: req.Username,
Name: utils.Capitalize(req.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
Provider: "username",
Provider: "local",
}
if userSearch.Type == "ldap" {
ldapUser, err := controller.auth.GetLdapUser(userSearch.Username)
if err != nil {
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Failed to get LDAP user details")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
sessionCookie.LdapGroups = strings.Join(ldapUser.Groups, ",")
sessionCookie.Provider = "ldap"
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
@@ -267,7 +256,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.config.CookieDomain),
Provider: "username",
Provider: "local",
}
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")

View File

@@ -204,7 +204,7 @@ func TestTotpHandler(t *testing.T) {
Email: "totpuser@example.com",
IsLoggedIn: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: true,
OAuthGroups: "",
TotpEnabled: true,
@@ -267,7 +267,7 @@ func TestTotpHandler(t *testing.T) {
Email: "totpuser@example.com",
IsLoggedIn: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: true,
OAuthGroups: "",
TotpEnabled: true,
@@ -290,7 +290,7 @@ func TestTotpHandler(t *testing.T) {
Email: "totpuser@example.com",
IsLoggedIn: false,
OAuth: false,
Provider: "username",
Provider: "local",
TotpPending: false,
OAuthGroups: "",
TotpEnabled: false,

View File

@@ -49,7 +49,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "username",
Provider: "local",
TotpPending: true,
TotpEnabled: true,
})
@@ -58,23 +58,37 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
}
switch cookie.Provider {
case "username":
case "local", "ldap":
userSearch := m.auth.SearchUser(cookie.Username)
if userSearch.Type == "unknown" || userSearch.Type == "error" {
if userSearch.Type == "unknown" {
tlog.App.Debug().Msg("User from session cookie not found")
m.auth.DeleteSessionCookie(c)
goto basic
}
var ldapGroups []string
if userSearch.Type == "ldap" {
ldapUser, err := m.auth.GetLdapUser(userSearch.Username)
if err != nil {
tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details")
c.Next()
return
}
ldapGroups = ldapUser.Groups
}
m.auth.RefreshSessionCookie(c)
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "username",
Provider: userSearch.Type,
IsLoggedIn: true,
LdapGroups: cookie.LdapGroups,
LdapGroups: strings.Join(ldapGroups, ","),
})
c.Next()
return
@@ -156,9 +170,10 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.config.CookieDomain),
Provider: "username",
Provider: "local",
IsLoggedIn: true,
TotpEnabled: user.TotpSecret != "",
IsBasicAuth: true,
})
c.Next()
return
@@ -174,12 +189,13 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
}
c.Set("context", &config.UserContext{
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
Provider: "ldap",
IsLoggedIn: true,
LdapGroups: strings.Join(ldapUser.Groups, ","),
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
Provider: "ldap",
IsLoggedIn: true,
LdapGroups: strings.Join(ldapUser.Groups, ","),
IsBasicAuth: true,
})
c.Next()
return

View File

@@ -16,5 +16,4 @@ type Session struct {
CreatedAt int64
OAuthName string
OAuthSub string
LdapGroups string
}

View File

@@ -21,12 +21,11 @@ INSERT INTO sessions (
"expiry",
"created_at",
"oauth_name",
"oauth_sub",
"ldap_groups"
"oauth_sub"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
`
type CreateSessionParams struct {
@@ -41,7 +40,6 @@ type CreateSessionParams struct {
CreatedAt int64
OAuthName string
OAuthSub string
LdapGroups string
}
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
@@ -57,7 +55,6 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
arg.CreatedAt,
arg.OAuthName,
arg.OAuthSub,
arg.LdapGroups,
)
var i Session
err := row.Scan(
@@ -72,7 +69,6 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
&i.CreatedAt,
&i.OAuthName,
&i.OAuthSub,
&i.LdapGroups,
)
return i, err
}
@@ -98,7 +94,7 @@ func (q *Queries) DeleteSession(ctx context.Context, uuid string) error {
}
const getSession = `-- name: GetSession :one
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups FROM "sessions"
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub FROM "sessions"
WHERE "uuid" = ?
`
@@ -117,7 +113,6 @@ func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error)
&i.CreatedAt,
&i.OAuthName,
&i.OAuthSub,
&i.LdapGroups,
)
return i, err
}
@@ -132,10 +127,9 @@ UPDATE "sessions" SET
"oauth_groups" = ?,
"expiry" = ?,
"oauth_name" = ?,
"oauth_sub" = ?,
"ldap_groups" = ?
"oauth_sub" = ?
WHERE "uuid" = ?
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
`
type UpdateSessionParams struct {
@@ -148,7 +142,6 @@ type UpdateSessionParams struct {
Expiry int64
OAuthName string
OAuthSub string
LdapGroups string
UUID string
}
@@ -163,7 +156,6 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
arg.Expiry,
arg.OAuthName,
arg.OAuthSub,
arg.LdapGroups,
arg.UUID,
)
var i Session
@@ -179,7 +171,6 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
&i.CreatedAt,
&i.OAuthName,
&i.OAuthSub,
&i.LdapGroups,
)
return i, err
}

View File

@@ -75,7 +75,7 @@ func (auth *AuthService) SearchUser(username string) config.UserSearch {
if err != nil {
tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
return config.UserSearch{
Type: "error",
Type: "unknown",
}
}
@@ -230,7 +230,6 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
CreatedAt: time.Now().Unix(),
OAuthName: data.OAuthName,
OAuthSub: data.OAuthSub,
LdapGroups: data.LdapGroups,
}
_, err = auth.queries.CreateSession(c, session)
@@ -284,7 +283,6 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub,
UUID: session.UUID,
LdapGroups: session.LdapGroups,
})
if err != nil {
@@ -361,12 +359,15 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e
OAuthGroups: session.OAuthGroups,
OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub,
LdapGroups: session.LdapGroups,
}, nil
}
func (auth *AuthService) UserAuthConfigured() bool {
return len(auth.config.Users) > 0 || auth.ldap != nil
func (auth *AuthService) LocalAuthConfigured() bool {
return len(auth.config.Users) > 0
}
func (auth *AuthService) LdapAuthConfigured() bool {
return auth.ldap != nil
}
func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
@@ -409,6 +410,22 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte
return false
}
func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if requiredGroups == "" {
return true
}
for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
return true
}
}
tlog.App.Debug().Msg("No groups matched")
return false
}
func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
// Check for block list
if path.Block != "" {

View File

@@ -10,10 +10,9 @@ INSERT INTO sessions (
"expiry",
"created_at",
"oauth_name",
"oauth_sub",
"ldap_groups"
"oauth_sub"
) VALUES (
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
)
RETURNING *;
@@ -35,8 +34,7 @@ UPDATE "sessions" SET
"oauth_groups" = ?,
"expiry" = ?,
"oauth_name" = ?,
"oauth_sub" = ?,
"ldap_groups" = ?
"oauth_sub" = ?
WHERE "uuid" = ?
RETURNING *;

View File

@@ -9,6 +9,5 @@ CREATE TABLE IF NOT EXISTS "sessions" (
"expiry" INTEGER NOT NULL,
"created_at" INTEGER NOT NULL,
"oauth_name" TEXT NULL,
"oauth_sub" TEXT NULL,
"ldap_groups" TEXT NULL
"oauth_sub" TEXT NULL
);