mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-11 14:58:10 +00:00
Merge branch 'main' into feat/tailscale
This commit is contained in:
+262
-233
@@ -5,20 +5,22 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
|
||||
"slices"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@@ -28,6 +30,10 @@ const MaxOAuthPendingSessions = 256
|
||||
const OAuthCleanupCount = 16
|
||||
const MaxLoginAttemptRecords = 256
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
)
|
||||
|
||||
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
|
||||
// parameters and pass them to the authorize page if needed
|
||||
type OAuthURLParams struct {
|
||||
@@ -66,41 +72,42 @@ type Lockdown struct {
|
||||
ActiveUntil time.Time
|
||||
}
|
||||
|
||||
type AuthServiceConfig struct {
|
||||
Users []config.User
|
||||
OauthWhitelist []string
|
||||
SessionExpiry int
|
||||
SessionMaxLifetime int
|
||||
SecureCookie bool
|
||||
CookieDomain string
|
||||
LoginTimeout int
|
||||
LoginMaxRetries int
|
||||
SessionCookieName string
|
||||
IP config.IPConfig
|
||||
LDAPGroupsCacheTTL int
|
||||
}
|
||||
|
||||
type AuthService struct {
|
||||
config AuthServiceConfig
|
||||
docker *DockerService
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
runtime model.RuntimeConfig
|
||||
context context.Context
|
||||
|
||||
ldap *LdapService
|
||||
queries *repository.Queries
|
||||
oauthBroker *OAuthBrokerService
|
||||
|
||||
loginAttempts map[string]*LoginAttempt
|
||||
ldapGroupsCache map[string]*LdapGroupsCache
|
||||
oauthPendingSessions map[string]*OAuthPendingSession
|
||||
oauthMutex sync.RWMutex
|
||||
loginMutex sync.RWMutex
|
||||
ldapGroupsMutex sync.RWMutex
|
||||
ldap *LdapService
|
||||
queries *repository.Queries
|
||||
oauthBroker *OAuthBrokerService
|
||||
lockdown *Lockdown
|
||||
lockdownCtx context.Context
|
||||
lockdownCancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
|
||||
return &AuthService{
|
||||
func NewAuthService(
|
||||
log *logger.Logger,
|
||||
config model.Config,
|
||||
runtime model.RuntimeConfig,
|
||||
ctx context.Context,
|
||||
wg *sync.WaitGroup,
|
||||
ldap *LdapService,
|
||||
queries *repository.Queries,
|
||||
oauthBroker *OAuthBrokerService,
|
||||
) *AuthService {
|
||||
service := &AuthService{
|
||||
log: log,
|
||||
runtime: runtime,
|
||||
context: ctx,
|
||||
config: config,
|
||||
docker: docker,
|
||||
loginAttempts: make(map[string]*LoginAttempt),
|
||||
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
||||
oauthPendingSessions: make(map[string]*OAuthPendingSession),
|
||||
@@ -108,86 +115,79 @@ func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapS
|
||||
queries: queries,
|
||||
oauthBroker: oauthBroker,
|
||||
}
|
||||
|
||||
wg.Go(service.CleanupOAuthSessionsRoutine)
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
func (auth *AuthService) Init() error {
|
||||
go auth.CleanupOAuthSessionsRoutine()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) SearchUser(username string) config.UserSearch {
|
||||
if auth.GetLocalUser(username).Username != "" {
|
||||
return config.UserSearch{
|
||||
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
|
||||
if auth.GetLocalUser(username) != nil {
|
||||
return &model.UserSearch{
|
||||
Username: username,
|
||||
Type: "local",
|
||||
}
|
||||
Type: model.UserLocal,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if auth.ldap.IsConfigured() {
|
||||
if auth.ldap != nil {
|
||||
userDN, err := auth.ldap.GetUserDN(username)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
|
||||
return config.UserSearch{
|
||||
Type: "unknown",
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get ldap user: %w", err)
|
||||
}
|
||||
|
||||
return config.UserSearch{
|
||||
return &model.UserSearch{
|
||||
Username: userDN,
|
||||
Type: "ldap",
|
||||
}
|
||||
Type: model.UserLDAP,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return config.UserSearch{
|
||||
Type: "unknown",
|
||||
}
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
|
||||
func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error {
|
||||
switch search.Type {
|
||||
case "local":
|
||||
case model.UserLocal:
|
||||
user := auth.GetLocalUser(search.Username)
|
||||
return auth.CheckPassword(user, password)
|
||||
case "ldap":
|
||||
if auth.ldap.IsConfigured() {
|
||||
if user == nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||
case model.UserLDAP:
|
||||
if auth.ldap != nil {
|
||||
err := auth.ldap.Bind(search.Username, password)
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
|
||||
return false
|
||||
return fmt.Errorf("failed to bind to ldap user: %w", err)
|
||||
}
|
||||
|
||||
err = auth.ldap.BindService(true)
|
||||
if err != nil {
|
||||
tlog.App.Error().Err(err).Msg("Failed to rebind with service account after user authentication")
|
||||
return false
|
||||
return fmt.Errorf("failed to bind to ldap service account: %w", err)
|
||||
}
|
||||
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
tlog.App.Debug().Str("type", search.Type).Msg("Unknown user type for authentication")
|
||||
return false
|
||||
return errors.New("unknown user search type")
|
||||
}
|
||||
|
||||
tlog.App.Warn().Str("username", search.Username).Msg("User authentication failed")
|
||||
return false
|
||||
return errors.New("user authentication failed")
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetLocalUser(username string) config.User {
|
||||
for _, user := range auth.config.Users {
|
||||
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
||||
if auth.runtime.LocalUsers == nil {
|
||||
return nil
|
||||
}
|
||||
for _, user := range auth.runtime.LocalUsers {
|
||||
if user.Username == username {
|
||||
return user
|
||||
return &user
|
||||
}
|
||||
}
|
||||
|
||||
tlog.App.Warn().Str("username", username).Msg("Local user not found")
|
||||
return config.User{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
||||
if !auth.ldap.IsConfigured() {
|
||||
return config.LdapUser{}, errors.New("LDAP service not initialized")
|
||||
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
||||
if auth.ldap == nil {
|
||||
return nil, errors.New("ldap service not configured")
|
||||
}
|
||||
|
||||
auth.ldapGroupsMutex.RLock()
|
||||
@@ -195,7 +195,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
||||
auth.ldapGroupsMutex.RUnlock()
|
||||
|
||||
if exists && time.Now().Before(entry.Expires) {
|
||||
return config.LdapUser{
|
||||
return &model.LDAPUser{
|
||||
DN: userDN,
|
||||
Groups: entry.Groups,
|
||||
}, nil
|
||||
@@ -204,26 +204,22 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
||||
groups, err := auth.ldap.GetUserGroups(userDN)
|
||||
|
||||
if err != nil {
|
||||
return config.LdapUser{}, err
|
||||
return nil, fmt.Errorf("failed to get ldap groups: %w", err)
|
||||
}
|
||||
|
||||
auth.ldapGroupsMutex.Lock()
|
||||
auth.ldapGroupsCache[userDN] = &LdapGroupsCache{
|
||||
Groups: groups,
|
||||
Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second),
|
||||
Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second),
|
||||
}
|
||||
auth.ldapGroupsMutex.Unlock()
|
||||
|
||||
return config.LdapUser{
|
||||
return &model.LDAPUser{
|
||||
DN: userDN,
|
||||
Groups: groups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
|
||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
||||
auth.loginMutex.RLock()
|
||||
defer auth.loginMutex.RUnlock()
|
||||
@@ -233,7 +229,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
||||
return true, remaining
|
||||
}
|
||||
|
||||
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
||||
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
|
||||
return false, 0
|
||||
}
|
||||
|
||||
@@ -251,7 +247,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
||||
}
|
||||
|
||||
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 {
|
||||
if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -282,21 +278,21 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
|
||||
attempt.FailedAttempts++
|
||||
|
||||
if attempt.FailedAttempts >= auth.config.LoginMaxRetries {
|
||||
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second)
|
||||
tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts")
|
||||
if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries {
|
||||
attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second)
|
||||
auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts")
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
||||
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
|
||||
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
||||
}
|
||||
|
||||
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error {
|
||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||
uuid, err := uuid.NewRandom()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to generate session uuid: %w", err)
|
||||
}
|
||||
|
||||
var expiry int
|
||||
@@ -304,9 +300,11 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
|
||||
if data.TotpPending {
|
||||
expiry = 3600
|
||||
} else {
|
||||
expiry = auth.config.SessionExpiry
|
||||
expiry = auth.config.Auth.SessionExpiry
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(expiry) * time.Second)
|
||||
|
||||
session := repository.CreateSessionParams{
|
||||
UUID: uuid.String(),
|
||||
Username: data.Username,
|
||||
@@ -315,63 +313,74 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
|
||||
Provider: data.Provider,
|
||||
TotpPending: data.TotpPending,
|
||||
OAuthGroups: data.OAuthGroups,
|
||||
Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
|
||||
Expiry: expiresAt.Unix(),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
OAuthName: data.OAuthName,
|
||||
OAuthSub: data.OAuthSub,
|
||||
}
|
||||
|
||||
_, err = auth.queries.CreateSession(c, session)
|
||||
_, err = auth.queries.CreateSession(ctx, session)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to create session entry: %w", err)
|
||||
}
|
||||
|
||||
if data.Provider == "tailscale" {
|
||||
// TODO: use domain from tailscale to set cookie, this is mostly a hack for now
|
||||
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", c.Request.Host))
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
|
||||
}
|
||||
c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", tsCookieDomain), auth.config.SecureCookie, true)
|
||||
return nil
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", tsCookieDomain),
|
||||
Expires: expiresAt,
|
||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
|
||||
|
||||
return nil
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||
Expires: expiresAt,
|
||||
MaxAge: int(time.Until(expiresAt).Seconds()),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
|
||||
cookie, err := c.Cookie(auth.config.SessionCookieName)
|
||||
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) {
|
||||
session, err := auth.queries.GetSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, err := auth.queries.GetSession(c, cookie)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to retrieve session: %w", err)
|
||||
}
|
||||
|
||||
currentTime := time.Now().Unix()
|
||||
|
||||
var refreshThreshold int64
|
||||
|
||||
if auth.config.SessionExpiry <= int(time.Hour.Seconds()) {
|
||||
refreshThreshold = int64(auth.config.SessionExpiry / 2)
|
||||
if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) {
|
||||
refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2)
|
||||
} else {
|
||||
refreshThreshold = int64(time.Hour.Seconds())
|
||||
}
|
||||
|
||||
if session.Expiry-currentTime > refreshThreshold {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
newExpiry := session.Expiry + refreshThreshold
|
||||
|
||||
_, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{
|
||||
_, err = auth.queries.UpdateSession(ctx, repository.UpdateSessionParams{
|
||||
Username: session.Username,
|
||||
Email: session.Email,
|
||||
Name: session.Name,
|
||||
@@ -385,150 +394,160 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, fmt.Errorf("failed to update session expiry: %w", err)
|
||||
}
|
||||
|
||||
c.SetCookie(auth.config.SessionCookieName, cookie, int(newExpiry-currentTime), "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
|
||||
tlog.App.Trace().Str("username", session.Username).Msg("Session cookie refreshed")
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: session.UUID,
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||
MaxAge: int(newExpiry - currentTime),
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
|
||||
cookie, err := c.Cookie(auth.config.SessionCookieName)
|
||||
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) {
|
||||
err := auth.queries.DeleteSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database")
|
||||
}
|
||||
|
||||
err = auth.queries.DeleteSession(c, cookie)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
|
||||
|
||||
return nil
|
||||
return &http.Cookie{
|
||||
Name: auth.runtime.SessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain),
|
||||
Expires: time.Now(),
|
||||
MaxAge: -1,
|
||||
Secure: auth.config.Auth.SecureCookie,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) {
|
||||
cookie, err := c.Cookie(auth.config.SessionCookieName)
|
||||
|
||||
if err != nil {
|
||||
return repository.Session{}, err
|
||||
}
|
||||
|
||||
session, err := auth.queries.GetSession(c, cookie)
|
||||
func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*repository.Session, error) {
|
||||
session, err := auth.queries.GetSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return repository.Session{}, fmt.Errorf("session not found")
|
||||
return nil, errors.New("session not found")
|
||||
}
|
||||
return repository.Session{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentTime := time.Now().Unix()
|
||||
|
||||
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
||||
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
|
||||
err = auth.queries.DeleteSession(c, cookie)
|
||||
if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
||||
if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) {
|
||||
err = auth.queries.DeleteSession(ctx, uuid)
|
||||
if err != nil {
|
||||
tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
|
||||
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
||||
}
|
||||
return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded")
|
||||
return nil, fmt.Errorf("session max lifetime exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
if currentTime > session.Expiry {
|
||||
err = auth.queries.DeleteSession(c, cookie)
|
||||
err = auth.queries.DeleteSession(ctx, uuid)
|
||||
if err != nil {
|
||||
tlog.App.Error().Err(err).Msg("Failed to delete expired session")
|
||||
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
||||
}
|
||||
return repository.Session{}, fmt.Errorf("session expired")
|
||||
return nil, fmt.Errorf("session expired")
|
||||
}
|
||||
|
||||
return repository.Session{
|
||||
UUID: session.UUID,
|
||||
Username: session.Username,
|
||||
Email: session.Email,
|
||||
Name: session.Name,
|
||||
Provider: session.Provider,
|
||||
TotpPending: session.TotpPending,
|
||||
OAuthGroups: session.OAuthGroups,
|
||||
OAuthName: session.OAuthName,
|
||||
OAuthSub: session.OAuthSub,
|
||||
}, nil
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) LocalAuthConfigured() bool {
|
||||
return len(auth.config.Users) > 0
|
||||
return len(auth.runtime.LocalUsers) > 0
|
||||
}
|
||||
|
||||
func (auth *AuthService) LdapAuthConfigured() bool {
|
||||
return auth.ldap.IsConfigured()
|
||||
func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||
return auth.ldap != nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
|
||||
if context.OAuth {
|
||||
tlog.App.Debug().Msg("Checking OAuth whitelist")
|
||||
return utils.CheckFilter(acls.OAuth.Whitelist, context.Email)
|
||||
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if context.Provider == model.ProviderOAuth {
|
||||
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
|
||||
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
||||
}
|
||||
|
||||
if acls.Users.Block != "" {
|
||||
tlog.App.Debug().Msg("Checking blocked users")
|
||||
if utils.CheckFilter(acls.Users.Block, context.Username) {
|
||||
auth.log.App.Debug().Msg("Checking users block list")
|
||||
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
tlog.App.Debug().Msg("Checking users")
|
||||
return utils.CheckFilter(acls.Users.Allow, context.Username)
|
||||
auth.log.App.Debug().Msg("Checking users allow list")
|
||||
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
|
||||
if requiredGroups == "" {
|
||||
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
for id := range config.OverrideProviders {
|
||||
if context.Provider == id {
|
||||
tlog.App.Info().Str("provider", id).Msg("OAuth groups not supported for this provider")
|
||||
return true
|
||||
}
|
||||
if !context.IsOAuth() {
|
||||
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
||||
return false
|
||||
}
|
||||
|
||||
for userGroup := range strings.SplitSeq(context.OAuthGroups, ",") {
|
||||
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
|
||||
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
tlog.App.Debug().Msg("No groups matched")
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
|
||||
if requiredGroups == "" {
|
||||
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
||||
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
|
||||
return true
|
||||
}
|
||||
|
||||
for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
|
||||
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
|
||||
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
|
||||
for _, userGroup := range context.OAuth.Groups {
|
||||
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
tlog.App.Debug().Msg("No groups matched")
|
||||
auth.log.App.Debug().Msg("No groups matched")
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
|
||||
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if !context.IsLDAP() {
|
||||
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
||||
return false
|
||||
}
|
||||
|
||||
for _, userGroup := range context.LDAP.Groups {
|
||||
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
auth.log.App.Debug().Msg("No groups matched")
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) {
|
||||
if acls == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check for block list
|
||||
if path.Block != "" {
|
||||
regex, err := regexp.Compile(path.Block)
|
||||
if acls.Path.Block != "" {
|
||||
regex, err := regexp.Compile(acls.Path.Block)
|
||||
|
||||
if err != nil {
|
||||
return true, err
|
||||
@@ -540,8 +559,8 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e
|
||||
}
|
||||
|
||||
// Check for allow list
|
||||
if path.Allow != "" {
|
||||
regex, err := regexp.Compile(path.Allow)
|
||||
if acls.Path.Allow != "" {
|
||||
regex, err := regexp.Compile(acls.Path.Allow)
|
||||
|
||||
if err != nil {
|
||||
return true, err
|
||||
@@ -555,31 +574,23 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
|
||||
username, password, ok := c.Request.BasicAuth()
|
||||
if !ok {
|
||||
tlog.App.Debug().Msg("No basic auth provided")
|
||||
return nil
|
||||
func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
}
|
||||
return &config.User{
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
|
||||
// Merge the global and app IP filter
|
||||
blockedIps := append(auth.config.IP.Block, acls.Block...)
|
||||
allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
|
||||
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
|
||||
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
|
||||
|
||||
for _, blocked := range blockedIps {
|
||||
res, err := utils.FilterIP(blocked, ip)
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||
continue
|
||||
}
|
||||
if res {
|
||||
tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access")
|
||||
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -587,38 +598,42 @@ func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
|
||||
for _, allowed := range allowedIPs {
|
||||
res, err := utils.FilterIP(allowed, ip)
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||
continue
|
||||
}
|
||||
if res {
|
||||
tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access")
|
||||
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedIPs) > 0 {
|
||||
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
||||
return false
|
||||
}
|
||||
|
||||
tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default")
|
||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
|
||||
return true
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
|
||||
for _, bypassed := range acls.Bypass {
|
||||
func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, bypassed := range acls.IP.Bypass {
|
||||
res, err := utils.FilterIP(bypassed, ip)
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||
continue
|
||||
}
|
||||
if res {
|
||||
tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access")
|
||||
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication")
|
||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -685,21 +700,21 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
|
||||
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, error) {
|
||||
session, err := auth.GetOAuthPendingSession(sessionId)
|
||||
|
||||
if err != nil {
|
||||
return config.Claims{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if session.Token == nil {
|
||||
return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
|
||||
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
|
||||
}
|
||||
|
||||
userinfo, err := (*session.Service).GetUserinfo(session.Token)
|
||||
|
||||
if err != nil {
|
||||
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
|
||||
return nil, fmt.Errorf("failed to get userinfo: %w", err)
|
||||
}
|
||||
|
||||
return userinfo, nil
|
||||
@@ -722,21 +737,32 @@ func (auth *AuthService) EndOAuthSession(sessionId string) {
|
||||
}
|
||||
|
||||
func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
||||
auth.log.App.Debug().Msg("Starting OAuth session cleanup routine")
|
||||
|
||||
ticker := time.NewTicker(30 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
auth.oauthMutex.Lock()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
auth.log.App.Debug().Msg("Running OAuth session cleanup")
|
||||
|
||||
now := time.Now()
|
||||
auth.oauthMutex.Lock()
|
||||
|
||||
for sessionId, session := range auth.oauthPendingSessions {
|
||||
if now.After(session.ExpiresAt) {
|
||||
delete(auth.oauthPendingSessions, sessionId)
|
||||
now := time.Now()
|
||||
|
||||
for sessionId, session := range auth.oauthPendingSessions {
|
||||
if now.After(session.ExpiresAt) {
|
||||
delete(auth.oauthPendingSessions, sessionId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auth.oauthMutex.Unlock()
|
||||
auth.oauthMutex.Unlock()
|
||||
auth.log.App.Debug().Msg("OAuth session cleanup completed")
|
||||
case <-auth.context.Done():
|
||||
auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -805,11 +831,11 @@ func (auth *AuthService) lockdownMode() {
|
||||
|
||||
auth.loginMutex.Lock()
|
||||
|
||||
tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.")
|
||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||
|
||||
auth.lockdown = &Lockdown{
|
||||
Active: true,
|
||||
ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second),
|
||||
ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second),
|
||||
}
|
||||
|
||||
// At this point all login attemps will also expire so,
|
||||
@@ -826,11 +852,14 @@ func (auth *AuthService) lockdownMode() {
|
||||
// Timer expired, end lockdown
|
||||
case <-ctx.Done():
|
||||
// Context cancelled, end lockdown
|
||||
case <-auth.context.Done():
|
||||
// Service is shutting down, end lockdown
|
||||
}
|
||||
|
||||
auth.loginMutex.Lock()
|
||||
|
||||
tlog.App.Info().Msg("Lockdown period ended, resuming normal operation")
|
||||
auth.log.App.Info().Msg("Exiting lockdown mode")
|
||||
|
||||
auth.lockdown = nil
|
||||
auth.loginMutex.Unlock()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user