feat: move oauth logic into auth service and handle multiple sessions

This commit is contained in:
Stavros
2026-03-21 16:37:04 +02:00
parent 2491d453cf
commit 7bead41ae9
8 changed files with 169 additions and 160 deletions

View File

@@ -22,16 +22,17 @@ import (
type BootstrapApp struct { type BootstrapApp struct {
config config.Config config config.Config
context struct { context struct {
appUrl string appUrl string
uuid string uuid string
cookieDomain string cookieDomain string
sessionCookieName string sessionCookieName string
csrfCookieName string csrfCookieName string
redirectCookieName string redirectCookieName string
users []config.User oauthSessionCookieName string
oauthProviders map[string]config.OAuthServiceConfig users []config.User
configuredProviders []controller.Provider oauthProviders map[string]config.OAuthServiceConfig
oidcClients []config.OIDCClientConfig configuredProviders []controller.Provider
oidcClients []config.OIDCClientConfig
} }
services Services services Services
} }
@@ -113,6 +114,7 @@ func (app *BootstrapApp) Setup() error {
app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId) app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId) app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId) app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId)
// Dumps // Dumps
tlog.App.Trace().Interface("config", app.config).Msg("Config dump") tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
@@ -190,12 +192,12 @@ func (app *BootstrapApp) Setup() error {
// Start db cleanup routine // Start db cleanup routine
tlog.App.Debug().Msg("Starting database cleanup routine") tlog.App.Debug().Msg("Starting database cleanup routine")
go app.dbCleanup(queries) go app.dbCleanupRoutine(queries)
// If analytics are not disabled, start heartbeat // If analytics are not disabled, start heartbeat
if app.config.Analytics.Enabled { if app.config.Analytics.Enabled {
tlog.App.Debug().Msg("Starting heartbeat routine") tlog.App.Debug().Msg("Starting heartbeat routine")
go app.heartbeat() go app.heartbeatRoutine()
} }
// If we have an socket path, bind to it // If we have an socket path, bind to it
@@ -226,7 +228,7 @@ func (app *BootstrapApp) Setup() error {
return nil return nil
} }
func (app *BootstrapApp) heartbeat() { func (app *BootstrapApp) heartbeatRoutine() {
ticker := time.NewTicker(time.Duration(12) * time.Hour) ticker := time.NewTicker(time.Duration(12) * time.Hour)
defer ticker.Stop() defer ticker.Stop()
@@ -280,7 +282,7 @@ func (app *BootstrapApp) heartbeat() {
} }
} }
func (app *BootstrapApp) dbCleanup(queries *repository.Queries) { func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
ticker := time.NewTicker(time.Duration(30) * time.Minute) ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop() defer ticker.Stop()
ctx := context.Background() ctx := context.Background()

View File

@@ -77,12 +77,13 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
contextController.SetupRoutes() contextController.SetupRoutes()
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
AppURL: app.config.AppURL, AppURL: app.config.AppURL,
SecureCookie: app.config.Auth.SecureCookie, SecureCookie: app.config.Auth.SecureCookie,
CSRFCookieName: app.context.csrfCookieName, CSRFCookieName: app.context.csrfCookieName,
RedirectCookieName: app.context.redirectCookieName, RedirectCookieName: app.context.redirectCookieName,
CookieDomain: app.context.cookieDomain, CookieDomain: app.context.cookieDomain,
}, apiRouter, app.services.authService, app.services.oauthBrokerService) OAuthSessionCookieName: app.context.oauthSessionCookieName,
}, apiRouter, app.services.authService)
oauthController.SetupRoutes() oauthController.SetupRoutes()

View File

@@ -58,6 +58,16 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
services.accessControlService = accessControlsService services.accessControlService = accessControlsService
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
err = oauthBrokerService.Init()
if err != nil {
return Services{}, err
}
services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(service.AuthServiceConfig{ authService := service.NewAuthService(service.AuthServiceConfig{
Users: app.context.users, Users: app.context.users,
OauthWhitelist: app.config.OAuth.Whitelist, OauthWhitelist: app.config.OAuth.Whitelist,
@@ -70,7 +80,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
SessionCookieName: app.context.sessionCookieName, SessionCookieName: app.context.sessionCookieName,
IP: app.config.Auth.IP, IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL, LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
}, dockerService, services.ldapService, queries) }, dockerService, services.ldapService, queries, services.oauthBrokerService)
err = authService.Init() err = authService.Init()
@@ -80,16 +90,6 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
services.authService = authService services.authService = authService
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
err = oauthBrokerService.Init()
if err != nil {
return Services{}, err
}
services.oauthBrokerService = oauthBrokerService
oidcService := service.NewOIDCService(service.OIDCServiceConfig{ oidcService := service.NewOIDCService(service.OIDCServiceConfig{
Clients: app.config.OIDC.Clients, Clients: app.config.OIDC.Clients,
PrivateKeyPath: app.config.OIDC.PrivateKeyPath, PrivateKeyPath: app.config.OIDC.PrivateKeyPath,

View File

@@ -73,6 +73,7 @@ var BuildTimestamp = "0000-00-00T00:00:00Z"
var SessionCookieName = "tinyauth-session" var SessionCookieName = "tinyauth-session"
var CSRFCookieName = "tinyauth-csrf" var CSRFCookieName = "tinyauth-csrf"
var RedirectCookieName = "tinyauth-redirect" var RedirectCookieName = "tinyauth-redirect"
var OAuthSessionCookieName = "tinyauth-oauth"
// Main app config // Main app config

View File

@@ -21,26 +21,25 @@ type OAuthRequest struct {
} }
type OAuthControllerConfig struct { type OAuthControllerConfig struct {
CSRFCookieName string CSRFCookieName string
RedirectCookieName string OAuthSessionCookieName string
SecureCookie bool RedirectCookieName string
AppURL string SecureCookie bool
CookieDomain string AppURL string
CookieDomain string
} }
type OAuthController struct { type OAuthController struct {
config OAuthControllerConfig config OAuthControllerConfig
router *gin.RouterGroup router *gin.RouterGroup
auth *service.AuthService auth *service.AuthService
broker *service.OAuthBrokerService
} }
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController { func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController {
return &OAuthController{ return &OAuthController{
config: config, config: config,
router: router, router: router,
auth: auth, auth: auth,
broker: broker,
} }
} }
@@ -63,21 +62,32 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
service, exists := controller.broker.GetService(req.Provider) sessionId, session, err := controller.auth.NewOAuthSession(req.Provider)
if !exists { if err != nil {
tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider) tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
c.JSON(404, gin.H{ c.JSON(500, gin.H{
"status": 404, "status": 500,
"message": "Not Found", "message": "Internal Server Error",
}) })
return return
} }
service.GenerateVerifier() tlog.App.Debug().Interface("session", session).Msg("Created new OAuth session")
state := service.GenerateState()
authURL := service.GetAuthURL(state) authUrl, err := controller.auth.GetOAuthURL(sessionId)
c.SetCookie(controller.config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth URL")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
c.SetCookie(controller.config.CSRFCookieName, session.State, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
redirectURI := c.Query("redirect_uri") redirectURI := c.Query("redirect_uri")
isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain)
@@ -95,7 +105,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "OK", "message": "OK",
"url": authURL, "url": authUrl,
}) })
} }
@@ -112,6 +122,16 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
if err != nil {
tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
state := c.Query("state") state := c.Query("state")
csrfCookie, err := c.Cookie(controller.config.CSRFCookieName) csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)
@@ -125,29 +145,18 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
code := c.Query("code") code := c.Query("code")
service, exists := controller.broker.GetService(req.Provider)
if !exists { tlog.App.Debug().Str("code", code).Str("state", state).Msg("Received OAuth callback")
tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider) _, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
err = service.VerifyCode(code)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to verify OAuth code")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
user, err := controller.broker.GetUser(req.Provider)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user from OAuth provider") tlog.App.Error().Err(err).Msg("Failed to exchange code for token")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return return
} }
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
if user.Email == "" { if user.Email == "" {
tlog.App.Error().Msg("OAuth provider did not return an email") tlog.App.Error().Msg("OAuth provider did not return an email")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
@@ -192,13 +201,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = strings.Replace(user.Email, "@", "_", 1) username = strings.Replace(user.Email, "@", "_", 1)
} }
service, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: username, Username: username,
Name: name, Name: name,
Email: user.Email, Email: user.Email,
Provider: req.Provider, Provider: req.Provider,
OAuthGroups: utils.CoalesceToString(user.Groups), OAuthGroups: utils.CoalesceToString(user.Groups),
OAuthName: service.GetName(), OAuthName: service.Name(),
OAuthSub: user.Sub, OAuthSub: user.Sub,
} }
@@ -214,6 +231,9 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
// Clear OAuth session
controller.auth.EndOAuthSession(sessionIdCookie)
redirectURI, err := c.Cookie(controller.config.RedirectCookieName) redirectURI, err := c.Cookie(controller.config.RedirectCookieName)
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) { if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) {

View File

@@ -69,13 +69,14 @@ type AuthService struct {
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
return &AuthService{ return &AuthService{
config: config, config: config,
docker: docker, docker: docker,
loginAttempts: make(map[string]*LoginAttempt), loginAttempts: make(map[string]*LoginAttempt),
ldapGroupsCache: make(map[string]*LdapGroupsCache), ldapGroupsCache: make(map[string]*LdapGroupsCache),
ldap: ldap, oauthPendingSessions: make(map[string]*OAuthPendingSession),
queries: queries, ldap: ldap,
oauthBroker: oauthBroker, queries: queries,
oauthBroker: oauthBroker,
} }
} }
@@ -568,66 +569,51 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
return false return false
} }
func (auth *AuthService) NewOAuthSession(serviceName string) (string, error) { func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) {
service, ok := auth.oauthBroker.GetService(serviceName) service, ok := auth.oauthBroker.GetService(serviceName)
if !ok { if !ok {
return "", fmt.Errorf("oauth service not found: %s", serviceName) return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName)
} }
sessionId, err := uuid.NewRandom() sessionId, err := uuid.NewRandom()
if err != nil { if err != nil {
return "", fmt.Errorf("failed to generate session ID: %w", err) return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err)
} }
state := uuid.New().String() state := service.NewRandom()
verifier := uuid.New().String() verifier := service.NewRandom()
auth.oauthMutex.Lock() session := OAuthPendingSession{
auth.oauthPendingSessions[sessionId.String()] = &OAuthPendingSession{
State: state, State: state,
Verifier: verifier, Verifier: verifier,
Service: &service, Service: &service,
ExpiresAt: time.Now().Add(1 * time.Hour), ExpiresAt: time.Now().Add(1 * time.Hour),
} }
auth.oauthMutex.Lock()
auth.oauthPendingSessions[sessionId.String()] = &session
auth.oauthMutex.Unlock() auth.oauthMutex.Unlock()
return sessionId.String(), nil return sessionId.String(), session, nil
} }
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
auth.oauthMutex.RLock() session, err := auth.getOAuthPendingSession(sessionId)
defer auth.oauthMutex.RUnlock()
session, exists := auth.oauthPendingSessions[sessionId] if err != nil {
return "", err
if !exists {
return "", fmt.Errorf("oauth session not found: %s", sessionId)
}
if time.Now().After(session.ExpiresAt) {
delete(auth.oauthPendingSessions, sessionId)
return "", fmt.Errorf("oauth session expired: %s", sessionId)
} }
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
} }
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
auth.oauthMutex.RLock() session, err := auth.getOAuthPendingSession(sessionId)
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
if !exists { if err != nil {
return nil, fmt.Errorf("oauth session not found: %s", sessionId) return nil, err
}
if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return nil, fmt.Errorf("oauth session expired: %s", sessionId)
} }
token, err := (*session.Service).GetToken(code, session.Verifier) token, err := (*session.Service).GetToken(code, session.Verifier)
@@ -644,19 +630,10 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
} }
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) { func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
auth.oauthMutex.RLock() session, err := auth.getOAuthPendingSession(sessionId)
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
if !exists { if err != nil {
return config.Claims{}, fmt.Errorf("oauth session not found: %s", sessionId) return config.Claims{}, err
}
if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return config.Claims{}, fmt.Errorf("oauth session expired: %s", sessionId)
} }
if session.Token == nil { if session.Token == nil {
@@ -669,13 +646,19 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, erro
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err) return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
} }
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return userinfo, nil return userinfo, nil
} }
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
session, err := auth.getOAuthPendingSession(sessionId)
if err != nil {
return nil, err
}
return *session.Service, nil
}
func (auth *AuthService) EndOAuthSession(sessionId string) { func (auth *AuthService) EndOAuthSession(sessionId string) {
auth.oauthMutex.Lock() auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId) delete(auth.oauthPendingSessions, sessionId)
@@ -699,3 +682,22 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
} }
} }
} }
func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
auth.oauthMutex.RLock()
session, exists := auth.oauthPendingSessions[sessionId]
auth.oauthMutex.RUnlock()
if !exists {
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
}
if time.Now().After(session.ExpiresAt) {
auth.oauthMutex.Lock()
delete(auth.oauthPendingSessions, sessionId)
auth.oauthMutex.Unlock()
return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId)
}
return session, nil
}

View File

@@ -9,7 +9,6 @@ import (
"strconv" "strconv"
"github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
) )
type GithubEmailResponse []struct { type GithubEmailResponse []struct {
@@ -24,42 +23,22 @@ type GithubUserInfoResponse struct {
} }
func defaultExtractor(client *http.Client, url string) (config.Claims, error) { func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
var claims config.Claims return simpleReq[config.Claims](client, url, nil)
res, err := client.Get(url)
if err != nil {
return config.Claims{}, err
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
return config.Claims{}, fmt.Errorf("request failed with status: %s", res.Status)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return config.Claims{}, err
}
tlog.App.Trace().Str("body", string(body)).Msg("Userinfo response body")
err = json.Unmarshal(body, &claims)
if err != nil {
return config.Claims{}, err
}
return claims, nil
} }
func githubExtractor(client *http.Client, url string) (config.Claims, error) { func githubExtractor(client *http.Client, url string) (config.Claims, error) {
var user config.Claims var user config.Claims
userInfo, err := githubRequest[GithubUserInfoResponse](client, "https://api.github.com/user") userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil { if err != nil {
return config.Claims{}, err return config.Claims{}, err
} }
userEmails, err := githubRequest[GithubEmailResponse](client, "https://api.github.com/user/emails") userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
"accept": "application/vnd.github+json",
})
if err != nil { if err != nil {
return config.Claims{}, err return config.Claims{}, err
} }
@@ -87,35 +66,37 @@ func githubExtractor(client *http.Client, url string) (config.Claims, error) {
return user, nil return user, nil
} }
func githubRequest[T any](client *http.Client, url string) (T, error) { func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) {
var githubRes T var decodedRes T
req, err := http.NewRequest("GET", "https://api.github.com/user", nil) req, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {
return githubRes, err return decodedRes, err
} }
req.Header.Set("Accept", "application/vnd.github+json") for key, value := range headers {
req.Header.Add(key, value)
}
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
return githubRes, err return decodedRes, err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 { if res.StatusCode < 200 || res.StatusCode >= 300 {
return githubRes, fmt.Errorf("request failed with status: %s", res.Status) return decodedRes, fmt.Errorf("request failed with status: %s", res.Status)
} }
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return githubRes, err return decodedRes, err
} }
err = json.Unmarshal(body, &githubRes) err = json.Unmarshal(body, &decodedRes)
if err != nil { if err != nil {
return githubRes, err return decodedRes, err
} }
return githubRes, nil return decodedRes, nil
} }

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@@ -69,6 +70,7 @@ func (s *OAuthService) GetAuthURL(state string, verifier string) string {
} }
func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, error) { func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, error) {
tlog.App.Debug().Str("code", code).Str("verifier", verifier).Msg("Exchanging code for token")
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier)) return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
} }