This commit is contained in:
Stavros
2026-04-29 19:21:07 +03:00
parent 956d2f55c3
commit eec75a6f49
17 changed files with 667 additions and 561 deletions
+130 -138
View File
@@ -5,12 +5,13 @@ import (
"database/sql"
"errors"
"fmt"
"net/http"
"regexp"
"strings"
"sync"
"time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -68,7 +69,7 @@ type Lockdown struct {
}
type AuthServiceConfig struct {
Users []config.User
LocalUsers []model.LocalUser
OauthWhitelist []string
SessionExpiry int
SessionMaxLifetime int
@@ -77,7 +78,7 @@ type AuthServiceConfig struct {
LoginTimeout int
LoginMaxRetries int
SessionCookieName string
IP config.IPConfig
IP model.IPConfig
LDAPGroupsCacheTTL int
}
@@ -106,7 +107,7 @@ func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *reposi
ldap: ldap,
queries: queries,
oauthBroker: oauthBroker,
}
}
}
func (auth *AuthService) Init() error {
@@ -114,79 +115,67 @@ func (auth *AuthService) Init() error {
return nil
}
func (auth *AuthService) SearchUser(username string) config.UserSearch {
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
if auth.GetLocalUser(username).Username != "" {
return config.UserSearch{
return &model.UserSearch{
Username: username,
Type: "local",
}
Type: model.UserLocal,
}, nil
}
if auth.ldap.IsConfigured() {
userDN, err := auth.ldap.GetUserDN(username)
if err != nil {
tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
return config.UserSearch{
Type: "unknown",
}
return nil, fmt.Errorf("failed to get ldap user: %w", err)
}
return config.UserSearch{
return &model.UserSearch{
Username: userDN,
Type: "ldap",
}
Type: model.UserLDAP,
}, nil
}
return config.UserSearch{
Type: "unknown",
}
return nil, fmt.Errorf("user not found")
}
func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error {
switch search.Type {
case "local":
case model.UserLocal:
user := auth.GetLocalUser(search.Username)
return auth.CheckPassword(user, password)
case "ldap":
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
case model.UserLDAP:
if auth.ldap.IsConfigured() {
err := auth.ldap.Bind(search.Username, password)
if err != nil {
tlog.App.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
return false
return fmt.Errorf("failed to bind to ldap user: %w", err)
}
err = auth.ldap.BindService(true)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to rebind with service account after user authentication")
return false
return fmt.Errorf("failed to bind to ldap service account: %w", err)
}
return true
return nil
}
default:
tlog.App.Debug().Str("type", search.Type).Msg("Unknown user type for authentication")
return false
return errors.New("unknown user search type")
}
tlog.App.Warn().Str("username", search.Username).Msg("User authentication failed")
return false
return errors.New("user authentication failed")
}
func (auth *AuthService) GetLocalUser(username string) config.User {
for _, user := range auth.config.Users {
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
for _, user := range auth.config.LocalUsers {
if user.Username == username {
return user
return &user
}
}
tlog.App.Warn().Str("username", username).Msg("Local user not found")
return config.User{}
return nil
}
func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
if !auth.ldap.IsConfigured() {
return config.LdapUser{}, errors.New("LDAP service not initialized")
return nil, errors.New("ldap service not configured")
}
auth.ldapGroupsMutex.RLock()
@@ -194,7 +183,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
auth.ldapGroupsMutex.RUnlock()
if exists && time.Now().Before(entry.Expires) {
return config.LdapUser{
return &model.LDAPUser{
DN: userDN,
Groups: entry.Groups,
}, nil
@@ -203,7 +192,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
groups, err := auth.ldap.GetUserGroups(userDN)
if err != nil {
return config.LdapUser{}, err
return nil, fmt.Errorf("failed to get ldap groups: %w", err)
}
auth.ldapGroupsMutex.Lock()
@@ -213,16 +202,12 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
}
auth.ldapGroupsMutex.Unlock()
return config.LdapUser{
return &model.LDAPUser{
DN: userDN,
Groups: groups,
}, nil
}
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
}
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
auth.loginMutex.RLock()
defer auth.loginMutex.RUnlock()
@@ -291,11 +276,11 @@ func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
}
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error {
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
uuid, err := uuid.NewRandom()
if err != nil {
return err
return nil, fmt.Errorf("failed to generate session uuid: %w", err)
}
var expiry int
@@ -320,28 +305,30 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
OAuthSub: data.OAuthSub,
}
_, err = auth.queries.CreateSession(c, session)
_, err = auth.queries.CreateSession(ctx, session)
if err != nil {
return err
return nil, fmt.Errorf("failed to create session entry: %w", err)
}
c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
return nil
return &http.Cookie{
Name: auth.config.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: time.Now().Add(time.Duration(expiry) * time.Second),
MaxAge: expiry,
Secure: auth.config.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
cookie, err := c.Cookie(auth.config.SessionCookieName)
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) {
session, err := auth.queries.GetSession(ctx, uuid)
if err != nil {
return err
}
session, err := auth.queries.GetSession(c, cookie)
if err != nil {
return err
return nil, fmt.Errorf("failed to retrieve session: %w", err)
}
currentTime := time.Now().Unix()
@@ -355,12 +342,12 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
}
if session.Expiry-currentTime > refreshThreshold {
return nil
return nil, fmt.Errorf("session not eligible for refresh yet")
}
newExpiry := session.Expiry + refreshThreshold
_, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{
_, err = auth.queries.UpdateSession(ctx, repository.UpdateSessionParams{
Username: session.Username,
Email: session.Email,
Name: session.Name,
@@ -374,120 +361,117 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
})
if err != nil {
return err
return nil, fmt.Errorf("failed to update session expiry: %w", err)
}
c.SetCookie(auth.config.SessionCookieName, cookie, int(newExpiry-currentTime), "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
tlog.App.Trace().Str("username", session.Username).Msg("Session cookie refreshed")
return &http.Cookie{
Name: auth.config.SessionCookieName,
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: auth.config.SessionExpiry,
Secure: auth.config.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
return nil
}
func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
cookie, err := c.Cookie(auth.config.SessionCookieName)
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) {
err := auth.queries.DeleteSession(ctx, uuid)
if err != nil {
return err
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
}
err = auth.queries.DeleteSession(c, cookie)
if err != nil {
return err
}
c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
return nil
return &http.Cookie{
Name: auth.config.SessionCookieName,
Value: "",
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: time.Now(),
MaxAge: -1,
Secure: auth.config.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
}
func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) {
cookie, err := c.Cookie(auth.config.SessionCookieName)
if err != nil {
return repository.Session{}, err
}
session, err := auth.queries.GetSession(c, cookie)
func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*repository.Session, error) {
session, err := auth.queries.GetSession(ctx, uuid)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return repository.Session{}, fmt.Errorf("session not found")
return nil, errors.New("session not found")
}
return repository.Session{}, err
return nil, err
}
currentTime := time.Now().Unix()
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
err = auth.queries.DeleteSession(c, cookie)
err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
return nil, fmt.Errorf("failed to delete expired session: %w", err)
}
return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded")
return nil, fmt.Errorf("session max lifetime exceeded")
}
}
if currentTime > session.Expiry {
err = auth.queries.DeleteSession(c, cookie)
err = auth.queries.DeleteSession(ctx, uuid)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to delete expired session")
return nil, fmt.Errorf("failed to delete expired session: %w", err)
}
return repository.Session{}, fmt.Errorf("session expired")
return nil, fmt.Errorf("session expired")
}
return repository.Session{
UUID: session.UUID,
Username: session.Username,
Email: session.Email,
Name: session.Name,
Provider: session.Provider,
TotpPending: session.TotpPending,
OAuthGroups: session.OAuthGroups,
OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub,
}, nil
return &session, nil
}
func (auth *AuthService) LocalAuthConfigured() bool {
return len(auth.config.Users) > 0
return len(auth.config.LocalUsers) > 0
}
func (auth *AuthService) LdapAuthConfigured() bool {
func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap.IsConfigured()
}
func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
if context.OAuth {
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls model.App) bool {
if context.Provider == model.ProviderOAuth {
tlog.App.Debug().Msg("Checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.Email)
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
}
if acls.Users.Block != "" {
tlog.App.Debug().Msg("Checking blocked users")
if utils.CheckFilter(acls.Users.Block, context.Username) {
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
return false
}
}
tlog.App.Debug().Msg("Checking users")
return utils.CheckFilter(acls.Users.Allow, context.Username)
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
}
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool {
if requiredGroups == "" {
return true
}
for id := range config.OverrideProviders {
if context.Provider == id {
tlog.App.Info().Str("provider", id).Msg("OAuth groups not supported for this provider")
return true
}
if !context.IsOAuth() {
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
return false
}
for userGroup := range strings.SplitSeq(context.OAuthGroups, ",") {
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
return true
}
for _, userGroup := range context.OAuth.Groups {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
return true
@@ -498,12 +482,17 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte
return false
}
func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool {
if requiredGroups == "" {
return true
}
for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
if !context.IsLDAP() {
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return false
}
for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
return true
@@ -514,7 +503,7 @@ func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContex
return false
}
func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
func (auth *AuthService) IsAuthEnabled(uri string, path model.AppPath) (bool, error) {
// Check for block list
if path.Block != "" {
regex, err := regexp.Compile(path.Block)
@@ -544,19 +533,22 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e
return true, nil
}
func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
username, password, ok := c.Request.BasicAuth()
if !ok {
tlog.App.Debug().Msg("No basic auth provided")
return nil
// local user is used only as a medium to pass the basic auth credentials, user can be ldap too
func (auth *AuthService) GetBasicAuth(req *http.Request) (*model.LocalUser, error) {
if req == nil {
return nil, errors.New("request is nil")
}
return &config.User{
username, password, ok := req.BasicAuth()
if !ok {
return nil, errors.New("no basic auth credentials provided")
}
return &model.LocalUser{
Username: username,
Password: password,
}
}, nil
}
func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
func (auth *AuthService) CheckIP(acls model.AppIP, ip string) bool {
// Merge the global and app IP filter
blockedIps := append(auth.config.IP.Block, acls.Block...)
allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
@@ -594,7 +586,7 @@ func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
return true
}
func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
func (auth *AuthService) IsBypassedIP(acls model.AppIP, ip string) bool {
for _, bypassed := range acls.Bypass {
res, err := utils.FilterIP(bypassed, ip)
if err != nil {
@@ -674,21 +666,21 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return token, nil
}
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, error) {
session, err := auth.GetOAuthPendingSession(sessionId)
if err != nil {
return config.Claims{}, err
return nil, err
}
if session.Token == nil {
return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
}
userinfo, err := (*session.Service).GetUserinfo(session.Token)
if err != nil {
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
return nil, fmt.Errorf("failed to get userinfo: %w", err)
}
return userinfo, nil