feat: move oauth logic into auth service and handle multiple sessions

This commit is contained in:
Stavros
2026-03-21 16:37:04 +02:00
parent 2491d453cf
commit 7bead41ae9
8 changed files with 169 additions and 160 deletions

View File

@@ -69,13 +69,14 @@ type AuthService struct {
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
return &AuthService{
config: config,
docker: docker,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
ldap: ldap,
queries: queries,
oauthBroker: oauthBroker,
config: config,
docker: docker,
loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache),
oauthPendingSessions: make(map[string]*OAuthPendingSession),
ldap: ldap,
queries: queries,
oauthBroker: oauthBroker,
}
}
@@ -568,66 +569,51 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
return false
}
func (auth *AuthService) NewOAuthSession(serviceName string) (string, error) {
func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) {
service, ok := auth.oauthBroker.GetService(serviceName)
if !ok {
return "", fmt.Errorf("oauth service not found: %s", serviceName)
return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName)
}
sessionId, err := uuid.NewRandom()
if err != nil {
return "", fmt.Errorf("failed to generate session ID: %w", err)
return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err)
}
state := uuid.New().String()
verifier := uuid.New().String()
state := service.NewRandom()
verifier := service.NewRandom()
auth.oauthMutex.Lock()
auth.oauthPendingSessions[sessionId.String()] = &OAuthPendingSession{
session := OAuthPendingSession{
State: state,
Verifier: verifier,
Service: &service,
ExpiresAt: time.Now().Add(1 * time.Hour),
}
auth.oauthMutex.Lock()
auth.oauthPendingSessions[sessionId.String()] = &session
auth.oauthMutex.Unlock()
return sessionId.String(), nil
return sessionId.String(), session, nil
}
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
auth.oauthMutex.RLock()
defer auth.oauthMutex.RUnlock()
session, err := auth.getOAuthPendingSession(sessionId)
session, exists := auth.oauthPendingSessions[sessionId]
if !exists {
return "", fmt.Errorf("oauth session not found: %s", sessionId)
}
if time.Now().After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
return "", fmt.Errorf("oauth session expired: %s", sessionId)
if err != nil {
return "", err
}
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
}
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
session, err := auth.getOAuthPendingSession(sessionId)
if !exists {
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
}
if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return nil, fmt.Errorf("oauth session expired: %s", sessionId)
if err != nil {
return nil, err
}
token, err := (*session.Service).GetToken(code, session.Verifier)
@@ -644,19 +630,10 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
}
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
session, err := auth.getOAuthPendingSession(sessionId)
if !exists {
return config.Claims{}, fmt.Errorf("oauth session not found: %s", sessionId)
}
if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return config.Claims{}, fmt.Errorf("oauth session expired: %s", sessionId)
if err != nil {
return config.Claims{}, err
}
if session.Token == nil {
@@ -669,13 +646,19 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, erro
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
}
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return userinfo, nil
}
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
session, err := auth.getOAuthPendingSession(sessionId)
if err != nil {
return nil, err
}
return *session.Service, nil
}
func (auth *AuthService) EndOAuthSession(sessionId string) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
@@ -699,3 +682,22 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
}
}
}
func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
if !exists {
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
}
if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId)
}
return session, nil
}