mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-01-12 02:12:29 +00:00
Compare commits
1 Commits
feat/ldap-
...
refactor/u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c453b57440 |
@@ -1 +0,0 @@
|
|||||||
ALTER TABLE "sessions" DROP COLUMN "ldap_groups";
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
ALTER TABLE "sessions" ADD COLUMN "ldap_groups" TEXT;
|
|
||||||
@@ -122,16 +122,23 @@ type User struct {
|
|||||||
TotpSecret string
|
TotpSecret string
|
||||||
}
|
}
|
||||||
|
|
||||||
type LdapUser struct {
|
|
||||||
DN string
|
|
||||||
Groups []string
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserSearch struct {
|
type UserSearch struct {
|
||||||
Username string
|
Username string
|
||||||
Type string // local, ldap or unknown
|
Type string // local, ldap or unknown
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SessionCookie struct {
|
||||||
|
UUID string
|
||||||
|
Username string
|
||||||
|
Name string
|
||||||
|
Email string
|
||||||
|
Provider string
|
||||||
|
TotpPending bool
|
||||||
|
OAuthGroups string
|
||||||
|
OAuthName string
|
||||||
|
OAuthSub string
|
||||||
|
}
|
||||||
|
|
||||||
type UserContext struct {
|
type UserContext struct {
|
||||||
Username string
|
Username string
|
||||||
Name string
|
Name string
|
||||||
@@ -144,7 +151,6 @@ type UserContext struct {
|
|||||||
TotpEnabled bool
|
TotpEnabled bool
|
||||||
OAuthName string
|
OAuthName string
|
||||||
OAuthSub string
|
OAuthSub string
|
||||||
LdapGroups string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// API responses and queries
|
// API responses and queries
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/steveiliop56/tinyauth/internal/config"
|
"github.com/steveiliop56/tinyauth/internal/config"
|
||||||
"github.com/steveiliop56/tinyauth/internal/repository"
|
|
||||||
"github.com/steveiliop56/tinyauth/internal/service"
|
"github.com/steveiliop56/tinyauth/internal/service"
|
||||||
"github.com/steveiliop56/tinyauth/internal/utils"
|
"github.com/steveiliop56/tinyauth/internal/utils"
|
||||||
|
|
||||||
@@ -191,7 +190,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
username = strings.Replace(user.Email, "@", "_", -1)
|
username = strings.Replace(user.Email, "@", "_", -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionCookie := repository.Session{
|
sessionCookie := config.SessionCookie{
|
||||||
Username: username,
|
Username: username,
|
||||||
Name: name,
|
Name: name,
|
||||||
Email: user.Email,
|
Email: user.Email,
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ func TestProxyHandler(t *testing.T) {
|
|||||||
// Test logged in user
|
// Test logged in user
|
||||||
c := gin.CreateTestContextOnly(recorder, router)
|
c := gin.CreateTestContextOnly(recorder, router)
|
||||||
|
|
||||||
err := authService.CreateSessionCookie(c, &repository.Session{
|
err := authService.CreateSessionCookie(c, &config.SessionCookie{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Name: "testuser",
|
Name: "testuser",
|
||||||
Email: "testuser@example.com",
|
Email: "testuser@example.com",
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/steveiliop56/tinyauth/internal/repository"
|
"github.com/steveiliop56/tinyauth/internal/config"
|
||||||
"github.com/steveiliop56/tinyauth/internal/service"
|
"github.com/steveiliop56/tinyauth/internal/service"
|
||||||
"github.com/steveiliop56/tinyauth/internal/utils"
|
"github.com/steveiliop56/tinyauth/internal/utils"
|
||||||
|
|
||||||
@@ -108,7 +108,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
if user.TotpSecret != "" {
|
if user.TotpSecret != "" {
|
||||||
log.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
|
log.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
|
||||||
|
|
||||||
err := controller.auth.CreateSessionCookie(c, &repository.Session{
|
err := controller.auth.CreateSessionCookie(c, &config.SessionCookie{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(req.Username),
|
Name: utils.Capitalize(req.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
|
||||||
@@ -134,28 +134,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionCookie := repository.Session{
|
sessionCookie := config.SessionCookie{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Name: utils.Capitalize(req.Username),
|
Name: utils.Capitalize(req.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(req.Username), controller.config.CookieDomain),
|
||||||
Provider: "username",
|
Provider: "username",
|
||||||
}
|
}
|
||||||
|
|
||||||
if userSearch.Type == "ldap" {
|
|
||||||
ldapUser, err := controller.auth.GetLdapUser(userSearch.Username)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Str("username", req.Username).Msg("Failed to get LDAP user details")
|
|
||||||
c.JSON(500, gin.H{
|
|
||||||
"status": 500,
|
|
||||||
"message": "Internal Server Error",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sessionCookie.LdapGroups = strings.Join(ldapUser.Groups, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
log.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
|
||||||
|
|
||||||
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
|
err = controller.auth.CreateSessionCookie(c, &sessionCookie)
|
||||||
@@ -252,7 +237,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
|
|||||||
|
|
||||||
controller.auth.RecordLoginAttempt(context.Username, true)
|
controller.auth.RecordLoginAttempt(context.Username, true)
|
||||||
|
|
||||||
sessionCookie := repository.Session{
|
sessionCookie := config.SessionCookie{
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(user.Username),
|
Name: utils.Capitalize(user.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.config.CookieDomain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), controller.config.CookieDomain),
|
||||||
|
|||||||
@@ -74,7 +74,6 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
Email: cookie.Email,
|
Email: cookie.Email,
|
||||||
Provider: "username",
|
Provider: "username",
|
||||||
IsLoggedIn: true,
|
IsLoggedIn: true,
|
||||||
LdapGroups: cookie.LdapGroups,
|
|
||||||
})
|
})
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
@@ -156,7 +155,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Name: utils.Capitalize(user.Username),
|
Name: utils.Capitalize(user.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.config.CookieDomain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), m.config.CookieDomain),
|
||||||
Provider: "username",
|
Provider: "basic",
|
||||||
IsLoggedIn: true,
|
IsLoggedIn: true,
|
||||||
TotpEnabled: user.TotpSecret != "",
|
TotpEnabled: user.TotpSecret != "",
|
||||||
})
|
})
|
||||||
@@ -164,22 +163,12 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
case "ldap":
|
case "ldap":
|
||||||
log.Debug().Msg("Basic auth user is LDAP")
|
log.Debug().Msg("Basic auth user is LDAP")
|
||||||
|
|
||||||
ldapUser, err := m.auth.GetLdapUser(basic.Username)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Debug().Err(err).Msg("Error retrieving LDAP user details")
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Set("context", &config.UserContext{
|
c.Set("context", &config.UserContext{
|
||||||
Username: basic.Username,
|
Username: basic.Username,
|
||||||
Name: utils.Capitalize(basic.Username),
|
Name: utils.Capitalize(basic.Username),
|
||||||
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
|
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), m.config.CookieDomain),
|
||||||
Provider: "ldap",
|
Provider: "basic",
|
||||||
IsLoggedIn: true,
|
IsLoggedIn: true,
|
||||||
LdapGroups: strings.Join(ldapUser.Groups, ","),
|
|
||||||
})
|
})
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -16,5 +16,4 @@ type Session struct {
|
|||||||
CreatedAt int64
|
CreatedAt int64
|
||||||
OAuthName string
|
OAuthName string
|
||||||
OAuthSub string
|
OAuthSub string
|
||||||
LdapGroups string
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,12 +21,11 @@ INSERT INTO sessions (
|
|||||||
"expiry",
|
"expiry",
|
||||||
"created_at",
|
"created_at",
|
||||||
"oauth_name",
|
"oauth_name",
|
||||||
"oauth_sub",
|
"oauth_sub"
|
||||||
"ldap_groups"
|
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
|
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||||
)
|
)
|
||||||
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups
|
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
|
||||||
`
|
`
|
||||||
|
|
||||||
type CreateSessionParams struct {
|
type CreateSessionParams struct {
|
||||||
@@ -41,7 +40,6 @@ type CreateSessionParams struct {
|
|||||||
CreatedAt int64
|
CreatedAt int64
|
||||||
OAuthName string
|
OAuthName string
|
||||||
OAuthSub string
|
OAuthSub string
|
||||||
LdapGroups string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
|
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
|
||||||
@@ -57,7 +55,6 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
|
|||||||
arg.CreatedAt,
|
arg.CreatedAt,
|
||||||
arg.OAuthName,
|
arg.OAuthName,
|
||||||
arg.OAuthSub,
|
arg.OAuthSub,
|
||||||
arg.LdapGroups,
|
|
||||||
)
|
)
|
||||||
var i Session
|
var i Session
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
@@ -72,7 +69,6 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
|
|||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
&i.OAuthName,
|
&i.OAuthName,
|
||||||
&i.OAuthSub,
|
&i.OAuthSub,
|
||||||
&i.LdapGroups,
|
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
@@ -98,7 +94,7 @@ func (q *Queries) DeleteSession(ctx context.Context, uuid string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getSession = `-- name: GetSession :one
|
const getSession = `-- name: GetSession :one
|
||||||
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups FROM "sessions"
|
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub FROM "sessions"
|
||||||
WHERE "uuid" = ?
|
WHERE "uuid" = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
@@ -117,7 +113,6 @@ func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error)
|
|||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
&i.OAuthName,
|
&i.OAuthName,
|
||||||
&i.OAuthSub,
|
&i.OAuthSub,
|
||||||
&i.LdapGroups,
|
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
@@ -132,10 +127,9 @@ UPDATE "sessions" SET
|
|||||||
"oauth_groups" = ?,
|
"oauth_groups" = ?,
|
||||||
"expiry" = ?,
|
"expiry" = ?,
|
||||||
"oauth_name" = ?,
|
"oauth_name" = ?,
|
||||||
"oauth_sub" = ?,
|
"oauth_sub" = ?
|
||||||
"ldap_groups" = ?
|
|
||||||
WHERE "uuid" = ?
|
WHERE "uuid" = ?
|
||||||
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub, ldap_groups
|
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, created_at, oauth_name, oauth_sub
|
||||||
`
|
`
|
||||||
|
|
||||||
type UpdateSessionParams struct {
|
type UpdateSessionParams struct {
|
||||||
@@ -148,7 +142,6 @@ type UpdateSessionParams struct {
|
|||||||
Expiry int64
|
Expiry int64
|
||||||
OAuthName string
|
OAuthName string
|
||||||
OAuthSub string
|
OAuthSub string
|
||||||
LdapGroups string
|
|
||||||
UUID string
|
UUID string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,7 +156,6 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
|
|||||||
arg.Expiry,
|
arg.Expiry,
|
||||||
arg.OAuthName,
|
arg.OAuthName,
|
||||||
arg.OAuthSub,
|
arg.OAuthSub,
|
||||||
arg.LdapGroups,
|
|
||||||
arg.UUID,
|
arg.UUID,
|
||||||
)
|
)
|
||||||
var i Session
|
var i Session
|
||||||
@@ -179,7 +171,6 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
|
|||||||
&i.CreatedAt,
|
&i.CreatedAt,
|
||||||
&i.OAuthName,
|
&i.OAuthName,
|
||||||
&i.OAuthSub,
|
&i.OAuthSub,
|
||||||
&i.LdapGroups,
|
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func (auth *AuthService) SearchUser(username string) config.UserSearch {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if auth.ldap != nil {
|
if auth.ldap != nil {
|
||||||
userDN, err := auth.ldap.GetUserDN(username)
|
userDN, err := auth.ldap.Search(username)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
|
log.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
|
||||||
@@ -131,19 +131,6 @@ func (auth *AuthService) GetLocalUser(username string) config.User {
|
|||||||
return config.User{}
|
return config.User{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
|
||||||
groups, err := auth.ldap.GetUserGroups(userDN)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return config.LdapUser{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return config.LdapUser{
|
|
||||||
DN: userDN,
|
|
||||||
Groups: groups,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
|
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
|
||||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
|
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
|
||||||
}
|
}
|
||||||
@@ -203,7 +190,7 @@ func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
|||||||
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
|
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error {
|
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.SessionCookie) error {
|
||||||
uuid, err := uuid.NewRandom()
|
uuid, err := uuid.NewRandom()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -230,7 +217,6 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
|
|||||||
CreatedAt: time.Now().Unix(),
|
CreatedAt: time.Now().Unix(),
|
||||||
OAuthName: data.OAuthName,
|
OAuthName: data.OAuthName,
|
||||||
OAuthSub: data.OAuthSub,
|
OAuthSub: data.OAuthSub,
|
||||||
LdapGroups: data.LdapGroups,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = auth.queries.CreateSession(c, session)
|
_, err = auth.queries.CreateSession(c, session)
|
||||||
@@ -284,7 +270,6 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
|
|||||||
OAuthName: session.OAuthName,
|
OAuthName: session.OAuthName,
|
||||||
OAuthSub: session.OAuthSub,
|
OAuthSub: session.OAuthSub,
|
||||||
UUID: session.UUID,
|
UUID: session.UUID,
|
||||||
LdapGroups: session.LdapGroups,
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -315,20 +300,20 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) {
|
func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie, error) {
|
||||||
cookie, err := c.Cookie(auth.config.SessionCookieName)
|
cookie, err := c.Cookie(auth.config.SessionCookieName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return repository.Session{}, err
|
return config.SessionCookie{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := auth.queries.GetSession(c, cookie)
|
session, err := auth.queries.GetSession(c, cookie)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return repository.Session{}, fmt.Errorf("session not found")
|
return config.SessionCookie{}, fmt.Errorf("session not found")
|
||||||
}
|
}
|
||||||
return repository.Session{}, err
|
return config.SessionCookie{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
@@ -339,7 +324,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
|
log.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
|
||||||
}
|
}
|
||||||
return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded")
|
return config.SessionCookie{}, fmt.Errorf("session expired due to max lifetime exceeded")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -348,10 +333,10 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to delete expired session")
|
log.Error().Err(err).Msg("Failed to delete expired session")
|
||||||
}
|
}
|
||||||
return repository.Session{}, fmt.Errorf("session expired")
|
return config.SessionCookie{}, fmt.Errorf("session expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
return repository.Session{
|
return config.SessionCookie{
|
||||||
UUID: session.UUID,
|
UUID: session.UUID,
|
||||||
Username: session.Username,
|
Username: session.Username,
|
||||||
Email: session.Email,
|
Email: session.Email,
|
||||||
@@ -361,7 +346,6 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e
|
|||||||
OAuthGroups: session.OAuthGroups,
|
OAuthGroups: session.OAuthGroups,
|
||||||
OAuthName: session.OAuthName,
|
OAuthName: session.OAuthName,
|
||||||
OAuthSub: session.OAuthSub,
|
OAuthSub: session.OAuthSub,
|
||||||
LdapGroups: session.LdapGroups,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -118,7 +116,7 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
|||||||
return ldap.conn, nil
|
return ldap.conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) GetUserDN(username string) (string, error) {
|
func (ldap *LdapService) Search(username string) (string, error) {
|
||||||
// Escape the username to prevent LDAP injection
|
// Escape the username to prevent LDAP injection
|
||||||
escapedUsername := ldapgo.EscapeFilter(username)
|
escapedUsername := ldapgo.EscapeFilter(username)
|
||||||
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
|
filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername)
|
||||||
@@ -147,47 +145,6 @@ func (ldap *LdapService) GetUserDN(username string) (string, error) {
|
|||||||
return userDN, nil
|
return userDN, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) {
|
|
||||||
searchRequest := ldapgo.NewSearchRequest(
|
|
||||||
ldap.config.BaseDN,
|
|
||||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
|
||||||
"(objectclass=groupOfUniqueNames)",
|
|
||||||
[]string{"uniquemember"},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
|
|
||||||
ldap.mutex.Lock()
|
|
||||||
defer ldap.mutex.Unlock()
|
|
||||||
|
|
||||||
searchResult, err := ldap.conn.Search(searchRequest)
|
|
||||||
if err != nil {
|
|
||||||
return []string{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
groupDNs := []string{}
|
|
||||||
|
|
||||||
for _, entry := range searchResult.Entries {
|
|
||||||
memberAttributes := entry.GetAttributeValues("uniquemember")
|
|
||||||
// no need to escape username here, if it's malicious it won't match anything
|
|
||||||
if slices.Contains(memberAttributes, userDN) {
|
|
||||||
groupDNs = append(groupDNs, entry.DN)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should work for most ldap providers?
|
|
||||||
groups := []string{}
|
|
||||||
|
|
||||||
for _, groupDN := range groupDNs {
|
|
||||||
groupDN = strings.TrimPrefix(groupDN, "cn=")
|
|
||||||
parts := strings.SplitN(groupDN, ",", 2)
|
|
||||||
if len(parts) > 0 {
|
|
||||||
groups = append(groups, parts[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return groups, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ldap *LdapService) BindService(rebind bool) error {
|
func (ldap *LdapService) BindService(rebind bool) error {
|
||||||
// Locks must not be used for initial binding attempt
|
// Locks must not be used for initial binding attempt
|
||||||
if rebind {
|
if rebind {
|
||||||
|
|||||||
@@ -10,10 +10,9 @@ INSERT INTO sessions (
|
|||||||
"expiry",
|
"expiry",
|
||||||
"created_at",
|
"created_at",
|
||||||
"oauth_name",
|
"oauth_name",
|
||||||
"oauth_sub",
|
"oauth_sub"
|
||||||
"ldap_groups"
|
|
||||||
) VALUES (
|
) VALUES (
|
||||||
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
|
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||||
)
|
)
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
@@ -35,8 +34,7 @@ UPDATE "sessions" SET
|
|||||||
"oauth_groups" = ?,
|
"oauth_groups" = ?,
|
||||||
"expiry" = ?,
|
"expiry" = ?,
|
||||||
"oauth_name" = ?,
|
"oauth_name" = ?,
|
||||||
"oauth_sub" = ?,
|
"oauth_sub" = ?
|
||||||
"ldap_groups" = ?
|
|
||||||
WHERE "uuid" = ?
|
WHERE "uuid" = ?
|
||||||
RETURNING *;
|
RETURNING *;
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,5 @@ CREATE TABLE IF NOT EXISTS "sessions" (
|
|||||||
"expiry" INTEGER NOT NULL,
|
"expiry" INTEGER NOT NULL,
|
||||||
"created_at" INTEGER NOT NULL,
|
"created_at" INTEGER NOT NULL,
|
||||||
"oauth_name" TEXT NULL,
|
"oauth_name" TEXT NULL,
|
||||||
"oauth_sub" TEXT NULL,
|
"oauth_sub" TEXT NULL
|
||||||
"ldap_groups" TEXT NULL
|
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user