mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-03-24 07:27:52 +00:00
feat: move oauth logic into auth service and handle multiple sessions
This commit is contained in:
@@ -69,13 +69,14 @@ type AuthService struct {
|
||||
|
||||
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
|
||||
return &AuthService{
|
||||
config: config,
|
||||
docker: docker,
|
||||
loginAttempts: make(map[string]*LoginAttempt),
|
||||
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
||||
ldap: ldap,
|
||||
queries: queries,
|
||||
oauthBroker: oauthBroker,
|
||||
config: config,
|
||||
docker: docker,
|
||||
loginAttempts: make(map[string]*LoginAttempt),
|
||||
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
||||
oauthPendingSessions: make(map[string]*OAuthPendingSession),
|
||||
ldap: ldap,
|
||||
queries: queries,
|
||||
oauthBroker: oauthBroker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -568,66 +569,51 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *AuthService) NewOAuthSession(serviceName string) (string, error) {
|
||||
func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) {
|
||||
service, ok := auth.oauthBroker.GetService(serviceName)
|
||||
|
||||
if !ok {
|
||||
return "", fmt.Errorf("oauth service not found: %s", serviceName)
|
||||
return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName)
|
||||
}
|
||||
|
||||
sessionId, err := uuid.NewRandom()
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate session ID: %w", err)
|
||||
return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err)
|
||||
}
|
||||
|
||||
state := uuid.New().String()
|
||||
verifier := uuid.New().String()
|
||||
state := service.NewRandom()
|
||||
verifier := service.NewRandom()
|
||||
|
||||
auth.oauthMutex.Lock()
|
||||
auth.oauthPendingSessions[sessionId.String()] = &OAuthPendingSession{
|
||||
session := OAuthPendingSession{
|
||||
State: state,
|
||||
Verifier: verifier,
|
||||
Service: &service,
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
}
|
||||
|
||||
auth.oauthMutex.Lock()
|
||||
auth.oauthPendingSessions[sessionId.String()] = &session
|
||||
auth.oauthMutex.Unlock()
|
||||
|
||||
return sessionId.String(), nil
|
||||
return sessionId.String(), session, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
||||
auth.oauthMutex.RLock()
|
||||
defer auth.oauthMutex.RUnlock()
|
||||
session, err := auth.getOAuthPendingSession(sessionId)
|
||||
|
||||
session, exists := auth.oauthPendingSessions[sessionId]
|
||||
|
||||
if !exists {
|
||||
return "", fmt.Errorf("oauth session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
delete(auth.oauthPendingSessions, sessionId)
|
||||
return "", fmt.Errorf("oauth session expired: %s", sessionId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
||||
auth.oauthMutex.RLock()
|
||||
session, exists := auth.oauthPendingSessions[sessionId]
|
||||
auth.oauthMutex.RUnlock()
|
||||
session, err := auth.getOAuthPendingSession(sessionId)
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
auth.oauthMutex.Lock()
|
||||
delete(auth.oauthPendingSessions, sessionId)
|
||||
auth.oauthMutex.Unlock()
|
||||
return nil, fmt.Errorf("oauth session expired: %s", sessionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token, err := (*session.Service).GetToken(code, session.Verifier)
|
||||
@@ -644,19 +630,10 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
|
||||
auth.oauthMutex.RLock()
|
||||
session, exists := auth.oauthPendingSessions[sessionId]
|
||||
auth.oauthMutex.RUnlock()
|
||||
session, err := auth.getOAuthPendingSession(sessionId)
|
||||
|
||||
if !exists {
|
||||
return config.Claims{}, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
auth.oauthMutex.Lock()
|
||||
delete(auth.oauthPendingSessions, sessionId)
|
||||
auth.oauthMutex.Unlock()
|
||||
return config.Claims{}, fmt.Errorf("oauth session expired: %s", sessionId)
|
||||
if err != nil {
|
||||
return config.Claims{}, err
|
||||
}
|
||||
|
||||
if session.Token == nil {
|
||||
@@ -669,13 +646,19 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, erro
|
||||
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
|
||||
}
|
||||
|
||||
auth.oauthMutex.Lock()
|
||||
delete(auth.oauthPendingSessions, sessionId)
|
||||
auth.oauthMutex.Unlock()
|
||||
|
||||
return userinfo, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
|
||||
session, err := auth.getOAuthPendingSession(sessionId)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return *session.Service, nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
||||
auth.oauthMutex.Lock()
|
||||
delete(auth.oauthPendingSessions, sessionId)
|
||||
@@ -699,3 +682,22 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
|
||||
auth.oauthMutex.RLock()
|
||||
session, exists := auth.oauthPendingSessions[sessionId]
|
||||
auth.oauthMutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
auth.oauthMutex.Lock()
|
||||
delete(auth.oauthPendingSessions, sessionId)
|
||||
auth.oauthMutex.Unlock()
|
||||
return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/steveiliop56/tinyauth/internal/config"
|
||||
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||
)
|
||||
|
||||
type GithubEmailResponse []struct {
|
||||
@@ -24,42 +23,22 @@ type GithubUserInfoResponse struct {
|
||||
}
|
||||
|
||||
func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
|
||||
var claims config.Claims
|
||||
|
||||
res, err := client.Get(url)
|
||||
if err != nil {
|
||||
return config.Claims{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
return config.Claims{}, fmt.Errorf("request failed with status: %s", res.Status)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return config.Claims{}, err
|
||||
}
|
||||
|
||||
tlog.App.Trace().Str("body", string(body)).Msg("Userinfo response body")
|
||||
|
||||
err = json.Unmarshal(body, &claims)
|
||||
if err != nil {
|
||||
return config.Claims{}, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
return simpleReq[config.Claims](client, url, nil)
|
||||
}
|
||||
|
||||
func githubExtractor(client *http.Client, url string) (config.Claims, error) {
|
||||
var user config.Claims
|
||||
|
||||
userInfo, err := githubRequest[GithubUserInfoResponse](client, "https://api.github.com/user")
|
||||
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
|
||||
"accept": "application/vnd.github+json",
|
||||
})
|
||||
if err != nil {
|
||||
return config.Claims{}, err
|
||||
}
|
||||
|
||||
userEmails, err := githubRequest[GithubEmailResponse](client, "https://api.github.com/user/emails")
|
||||
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
|
||||
"accept": "application/vnd.github+json",
|
||||
})
|
||||
if err != nil {
|
||||
return config.Claims{}, err
|
||||
}
|
||||
@@ -87,35 +66,37 @@ func githubExtractor(client *http.Client, url string) (config.Claims, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func githubRequest[T any](client *http.Client, url string) (T, error) {
|
||||
var githubRes T
|
||||
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) {
|
||||
var decodedRes T
|
||||
|
||||
req, err := http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return githubRes, err
|
||||
return decodedRes, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
for key, value := range headers {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return githubRes, err
|
||||
return decodedRes, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||
return githubRes, fmt.Errorf("request failed with status: %s", res.Status)
|
||||
return decodedRes, fmt.Errorf("request failed with status: %s", res.Status)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return githubRes, err
|
||||
return decodedRes, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(body, &githubRes)
|
||||
err = json.Unmarshal(body, &decodedRes)
|
||||
if err != nil {
|
||||
return githubRes, err
|
||||
return decodedRes, err
|
||||
}
|
||||
|
||||
return githubRes, nil
|
||||
return decodedRes, nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/steveiliop56/tinyauth/internal/config"
|
||||
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@@ -69,6 +70,7 @@ func (s *OAuthService) GetAuthURL(state string, verifier string) string {
|
||||
}
|
||||
|
||||
func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, error) {
|
||||
tlog.App.Debug().Str("code", code).Str("verifier", verifier).Msg("Exchanging code for token")
|
||||
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user