mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-24 05:00:15 +00:00
Merge branch 'main' into feat/tailscale
This commit is contained in:
@@ -2,11 +2,9 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -18,7 +16,6 @@ import (
|
||||
|
||||
"slices"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/oauth2"
|
||||
@@ -79,7 +76,7 @@ type AuthService struct {
|
||||
context context.Context
|
||||
|
||||
ldap *LdapService
|
||||
queries *repository.Queries
|
||||
queries repository.Store
|
||||
oauthBroker *OAuthBrokerService
|
||||
tailscale *TailscaleService
|
||||
|
||||
@@ -101,7 +98,7 @@ func NewAuthService(
|
||||
ctx context.Context,
|
||||
wg *sync.WaitGroup,
|
||||
ldap *LdapService,
|
||||
queries *repository.Queries,
|
||||
queries repository.Store,
|
||||
oauthBroker *OAuthBrokerService,
|
||||
tailscale *TailscaleService,
|
||||
) *AuthService {
|
||||
@@ -133,7 +130,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
|
||||
}
|
||||
|
||||
if auth.ldap != nil {
|
||||
userDN, err := auth.ldap.GetUserDN(username)
|
||||
userDN, email, err := auth.ldap.GetUserInfo(username)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get ldap user: %w", err)
|
||||
@@ -141,6 +138,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error)
|
||||
|
||||
return &model.UserSearch{
|
||||
Username: userDN,
|
||||
Email: email,
|
||||
Type: model.UserLDAP,
|
||||
}, nil
|
||||
}
|
||||
@@ -288,7 +286,12 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
||||
return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
||||
match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern")
|
||||
return false
|
||||
}
|
||||
return match
|
||||
}
|
||||
|
||||
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||
@@ -445,7 +448,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
||||
session, err := auth.queries.GetSession(ctx, uuid)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return nil, errors.New("session not found")
|
||||
}
|
||||
return nil, err
|
||||
@@ -482,171 +485,6 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||
return auth.ldap != nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if context.Provider == model.ProviderOAuth {
|
||||
auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist")
|
||||
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
||||
}
|
||||
|
||||
if acls.Users.Block != "" {
|
||||
auth.log.App.Debug().Msg("Checking users block list")
|
||||
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
auth.log.App.Debug().Msg("Checking users allow list")
|
||||
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if !context.IsOAuth() {
|
||||
auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
||||
return false
|
||||
}
|
||||
|
||||
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
||||
auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check")
|
||||
return true
|
||||
}
|
||||
|
||||
for _, userGroup := range context.OAuth.Groups {
|
||||
if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) {
|
||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
auth.log.App.Debug().Msg("No groups matched")
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if !context.IsLDAP() {
|
||||
auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
||||
return false
|
||||
}
|
||||
|
||||
for _, userGroup := range context.LDAP.Groups {
|
||||
if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) {
|
||||
auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
auth.log.App.Debug().Msg("No groups matched")
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsAuthEnabled(uri string, acls *model.App) (bool, error) {
|
||||
if acls == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check for block list
|
||||
if acls.Path.Block != "" {
|
||||
regex, err := regexp.Compile(acls.Path.Block)
|
||||
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
if !regex.MatchString(uri) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check for allow list
|
||||
if acls.Path.Allow != "" {
|
||||
regex, err := regexp.Compile(acls.Path.Allow)
|
||||
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
if regex.MatchString(uri) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) CheckIP(ip string, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
// Merge the global and app IP filter
|
||||
blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...)
|
||||
allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...)
|
||||
|
||||
for _, blocked := range blockedIps {
|
||||
res, err := utils.FilterIP(blocked, ip)
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||
continue
|
||||
}
|
||||
if res {
|
||||
auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
for _, allowed := range allowedIPs {
|
||||
res, err := utils.FilterIP(allowed, ip)
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||
continue
|
||||
}
|
||||
if res {
|
||||
auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowedIPs) > 0 {
|
||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access")
|
||||
return false
|
||||
}
|
||||
|
||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default")
|
||||
return true
|
||||
}
|
||||
|
||||
func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool {
|
||||
if acls == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, bypassed := range acls.IP.Bypass {
|
||||
res, err := utils.FilterIP(bypassed, ip)
|
||||
if err != nil {
|
||||
auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||
continue
|
||||
}
|
||||
if res {
|
||||
auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication")
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
|
||||
auth.ensureOAuthSessionLimit()
|
||||
|
||||
@@ -801,46 +639,49 @@ func (auth *AuthService) ensureOAuthSessionLimit() {
|
||||
auth.oauthMutex.Lock()
|
||||
defer auth.oauthMutex.Unlock()
|
||||
|
||||
if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions {
|
||||
if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions {
|
||||
return
|
||||
}
|
||||
|
||||
cleanupIds := make([]string, 0, OAuthCleanupCount)
|
||||
type entry struct {
|
||||
id string
|
||||
expiresAt int64
|
||||
}
|
||||
|
||||
for range OAuthCleanupCount {
|
||||
oldestId := ""
|
||||
oldestTime := int64(0)
|
||||
entries := make([]entry, 0, len(auth.oauthPendingSessions))
|
||||
for id, session := range auth.oauthPendingSessions {
|
||||
entries = append(entries, entry{id, session.ExpiresAt.Unix()})
|
||||
}
|
||||
|
||||
for id, session := range auth.oauthPendingSessions {
|
||||
if oldestTime == 0 {
|
||||
oldestId = id
|
||||
oldestTime = session.ExpiresAt.Unix()
|
||||
continue
|
||||
}
|
||||
if slices.Contains(cleanupIds, id) {
|
||||
continue
|
||||
}
|
||||
if session.ExpiresAt.Unix() < oldestTime {
|
||||
oldestId = id
|
||||
oldestTime = session.ExpiresAt.Unix()
|
||||
}
|
||||
}
|
||||
|
||||
cleanupIds = append(cleanupIds, oldestId)
|
||||
slices.SortFunc(entries, func(a, b entry) int {
|
||||
if a.expiresAt < b.expiresAt {
|
||||
return -1
|
||||
}
|
||||
|
||||
for _, id := range cleanupIds {
|
||||
delete(auth.oauthPendingSessions, id)
|
||||
if a.expiresAt > b.expiresAt {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
|
||||
for _, e := range entries[:OAuthCleanupCount] {
|
||||
delete(auth.oauthPendingSessions, e.id)
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *AuthService) lockdownMode() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
auth.lockdownCtx = ctx
|
||||
auth.lockdownCancelFunc = cancel
|
||||
|
||||
auth.loginMutex.Lock()
|
||||
|
||||
if auth.lockdown != nil && auth.lockdown.Active {
|
||||
auth.loginMutex.Unlock()
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
auth.lockdownCtx = ctx
|
||||
auth.lockdownCancelFunc = cancel
|
||||
|
||||
auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")
|
||||
|
||||
auth.lockdown = &Lockdown{
|
||||
@@ -853,10 +694,12 @@ func (auth *AuthService) lockdownMode() {
|
||||
auth.loginAttempts = make(map[string]*LoginAttempt)
|
||||
|
||||
timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
|
||||
defer timer.Stop()
|
||||
|
||||
auth.loginMutex.Unlock()
|
||||
|
||||
defer cancel()
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
// Timer expired, end lockdown
|
||||
|
||||
Reference in New Issue
Block a user