mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-03-23 06:57:52 +00:00
feat: move oauth logic into auth service and handle multiple sessions
This commit is contained in:
@@ -22,16 +22,17 @@ import (
|
|||||||
type BootstrapApp struct {
|
type BootstrapApp struct {
|
||||||
config config.Config
|
config config.Config
|
||||||
context struct {
|
context struct {
|
||||||
appUrl string
|
appUrl string
|
||||||
uuid string
|
uuid string
|
||||||
cookieDomain string
|
cookieDomain string
|
||||||
sessionCookieName string
|
sessionCookieName string
|
||||||
csrfCookieName string
|
csrfCookieName string
|
||||||
redirectCookieName string
|
redirectCookieName string
|
||||||
users []config.User
|
oauthSessionCookieName string
|
||||||
oauthProviders map[string]config.OAuthServiceConfig
|
users []config.User
|
||||||
configuredProviders []controller.Provider
|
oauthProviders map[string]config.OAuthServiceConfig
|
||||||
oidcClients []config.OIDCClientConfig
|
configuredProviders []controller.Provider
|
||||||
|
oidcClients []config.OIDCClientConfig
|
||||||
}
|
}
|
||||||
services Services
|
services Services
|
||||||
}
|
}
|
||||||
@@ -113,6 +114,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
|
app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
|
||||||
app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
|
app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
|
||||||
app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)
|
app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)
|
||||||
|
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId)
|
||||||
|
|
||||||
// Dumps
|
// Dumps
|
||||||
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
|
tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
|
||||||
@@ -190,12 +192,12 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
|
|
||||||
// Start db cleanup routine
|
// Start db cleanup routine
|
||||||
tlog.App.Debug().Msg("Starting database cleanup routine")
|
tlog.App.Debug().Msg("Starting database cleanup routine")
|
||||||
go app.dbCleanup(queries)
|
go app.dbCleanupRoutine(queries)
|
||||||
|
|
||||||
// If analytics are not disabled, start heartbeat
|
// If analytics are not disabled, start heartbeat
|
||||||
if app.config.Analytics.Enabled {
|
if app.config.Analytics.Enabled {
|
||||||
tlog.App.Debug().Msg("Starting heartbeat routine")
|
tlog.App.Debug().Msg("Starting heartbeat routine")
|
||||||
go app.heartbeat()
|
go app.heartbeatRoutine()
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have an socket path, bind to it
|
// If we have an socket path, bind to it
|
||||||
@@ -226,7 +228,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) heartbeat() {
|
func (app *BootstrapApp) heartbeatRoutine() {
|
||||||
ticker := time.NewTicker(time.Duration(12) * time.Hour)
|
ticker := time.NewTicker(time.Duration(12) * time.Hour)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -280,7 +282,7 @@ func (app *BootstrapApp) heartbeat() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) dbCleanup(queries *repository.Queries) {
|
func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) {
|
||||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|||||||
@@ -77,12 +77,13 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
|||||||
contextController.SetupRoutes()
|
contextController.SetupRoutes()
|
||||||
|
|
||||||
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
|
oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{
|
||||||
AppURL: app.config.AppURL,
|
AppURL: app.config.AppURL,
|
||||||
SecureCookie: app.config.Auth.SecureCookie,
|
SecureCookie: app.config.Auth.SecureCookie,
|
||||||
CSRFCookieName: app.context.csrfCookieName,
|
CSRFCookieName: app.context.csrfCookieName,
|
||||||
RedirectCookieName: app.context.redirectCookieName,
|
RedirectCookieName: app.context.redirectCookieName,
|
||||||
CookieDomain: app.context.cookieDomain,
|
CookieDomain: app.context.cookieDomain,
|
||||||
}, apiRouter, app.services.authService, app.services.oauthBrokerService)
|
OAuthSessionCookieName: app.context.oauthSessionCookieName,
|
||||||
|
}, apiRouter, app.services.authService)
|
||||||
|
|
||||||
oauthController.SetupRoutes()
|
oauthController.SetupRoutes()
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,16 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
|
|||||||
|
|
||||||
services.accessControlService = accessControlsService
|
services.accessControlService = accessControlsService
|
||||||
|
|
||||||
|
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
|
||||||
|
|
||||||
|
err = oauthBrokerService.Init()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return Services{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
services.oauthBrokerService = oauthBrokerService
|
||||||
|
|
||||||
authService := service.NewAuthService(service.AuthServiceConfig{
|
authService := service.NewAuthService(service.AuthServiceConfig{
|
||||||
Users: app.context.users,
|
Users: app.context.users,
|
||||||
OauthWhitelist: app.config.OAuth.Whitelist,
|
OauthWhitelist: app.config.OAuth.Whitelist,
|
||||||
@@ -70,7 +80,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
|
|||||||
SessionCookieName: app.context.sessionCookieName,
|
SessionCookieName: app.context.sessionCookieName,
|
||||||
IP: app.config.Auth.IP,
|
IP: app.config.Auth.IP,
|
||||||
LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
|
LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
|
||||||
}, dockerService, services.ldapService, queries)
|
}, dockerService, services.ldapService, queries, services.oauthBrokerService)
|
||||||
|
|
||||||
err = authService.Init()
|
err = authService.Init()
|
||||||
|
|
||||||
@@ -80,16 +90,6 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
|
|||||||
|
|
||||||
services.authService = authService
|
services.authService = authService
|
||||||
|
|
||||||
oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders)
|
|
||||||
|
|
||||||
err = oauthBrokerService.Init()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return Services{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
services.oauthBrokerService = oauthBrokerService
|
|
||||||
|
|
||||||
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
|
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
|
||||||
Clients: app.config.OIDC.Clients,
|
Clients: app.config.OIDC.Clients,
|
||||||
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
|
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ var BuildTimestamp = "0000-00-00T00:00:00Z"
|
|||||||
var SessionCookieName = "tinyauth-session"
|
var SessionCookieName = "tinyauth-session"
|
||||||
var CSRFCookieName = "tinyauth-csrf"
|
var CSRFCookieName = "tinyauth-csrf"
|
||||||
var RedirectCookieName = "tinyauth-redirect"
|
var RedirectCookieName = "tinyauth-redirect"
|
||||||
|
var OAuthSessionCookieName = "tinyauth-oauth"
|
||||||
|
|
||||||
// Main app config
|
// Main app config
|
||||||
|
|
||||||
|
|||||||
@@ -21,26 +21,25 @@ type OAuthRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OAuthControllerConfig struct {
|
type OAuthControllerConfig struct {
|
||||||
CSRFCookieName string
|
CSRFCookieName string
|
||||||
RedirectCookieName string
|
OAuthSessionCookieName string
|
||||||
SecureCookie bool
|
RedirectCookieName string
|
||||||
AppURL string
|
SecureCookie bool
|
||||||
CookieDomain string
|
AppURL string
|
||||||
|
CookieDomain string
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthController struct {
|
type OAuthController struct {
|
||||||
config OAuthControllerConfig
|
config OAuthControllerConfig
|
||||||
router *gin.RouterGroup
|
router *gin.RouterGroup
|
||||||
auth *service.AuthService
|
auth *service.AuthService
|
||||||
broker *service.OAuthBrokerService
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController {
|
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController {
|
||||||
return &OAuthController{
|
return &OAuthController{
|
||||||
config: config,
|
config: config,
|
||||||
router: router,
|
router: router,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
broker: broker,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,21 +62,32 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
service, exists := controller.broker.GetService(req.Provider)
|
sessionId, session, err := controller.auth.NewOAuthSession(req.Provider)
|
||||||
|
|
||||||
if !exists {
|
if err != nil {
|
||||||
tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider)
|
tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
|
||||||
c.JSON(404, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"status": 404,
|
"status": 500,
|
||||||
"message": "Not Found",
|
"message": "Internal Server Error",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
service.GenerateVerifier()
|
tlog.App.Debug().Interface("session", session).Msg("Created new OAuth session")
|
||||||
state := service.GenerateState()
|
|
||||||
authURL := service.GetAuthURL(state)
|
authUrl, err := controller.auth.GetOAuthURL(sessionId)
|
||||||
c.SetCookie(controller.config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
|
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Error().Err(err).Msg("Failed to get OAuth URL")
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": 500,
|
||||||
|
"message": "Internal Server Error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
|
||||||
|
c.SetCookie(controller.config.CSRFCookieName, session.State, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
|
||||||
|
|
||||||
redirectURI := c.Query("redirect_uri")
|
redirectURI := c.Query("redirect_uri")
|
||||||
isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain)
|
isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain)
|
||||||
@@ -95,7 +105,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
|
|||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "OK",
|
"message": "OK",
|
||||||
"url": authURL,
|
"url": authUrl,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,6 +122,16 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
|
||||||
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
|
||||||
|
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)
|
csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)
|
||||||
|
|
||||||
@@ -125,29 +145,18 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
|
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
|
||||||
|
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
service, exists := controller.broker.GetService(req.Provider)
|
|
||||||
|
|
||||||
if !exists {
|
tlog.App.Debug().Str("code", code).Str("state", state).Msg("Received OAuth callback")
|
||||||
tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider)
|
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = service.VerifyCode(code)
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Error().Err(err).Msg("Failed to verify OAuth code")
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := controller.broker.GetUser(req.Provider)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to get user from OAuth provider")
|
tlog.App.Error().Err(err).Msg("Failed to exchange code for token")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
|
||||||
|
|
||||||
if user.Email == "" {
|
if user.Email == "" {
|
||||||
tlog.App.Error().Msg("OAuth provider did not return an email")
|
tlog.App.Error().Msg("OAuth provider did not return an email")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
@@ -192,13 +201,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
username = strings.Replace(user.Email, "@", "_", 1)
|
username = strings.Replace(user.Email, "@", "_", 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
service, err := controller.auth.GetOAuthService(sessionIdCookie)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
||||||
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
sessionCookie := repository.Session{
|
sessionCookie := repository.Session{
|
||||||
Username: username,
|
Username: username,
|
||||||
Name: name,
|
Name: name,
|
||||||
Email: user.Email,
|
Email: user.Email,
|
||||||
Provider: req.Provider,
|
Provider: req.Provider,
|
||||||
OAuthGroups: utils.CoalesceToString(user.Groups),
|
OAuthGroups: utils.CoalesceToString(user.Groups),
|
||||||
OAuthName: service.GetName(),
|
OAuthName: service.Name(),
|
||||||
OAuthSub: user.Sub,
|
OAuthSub: user.Sub,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -214,6 +231,9 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
|
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
|
||||||
|
|
||||||
|
// Clear OAuth session
|
||||||
|
controller.auth.EndOAuthSession(sessionIdCookie)
|
||||||
|
|
||||||
redirectURI, err := c.Cookie(controller.config.RedirectCookieName)
|
redirectURI, err := c.Cookie(controller.config.RedirectCookieName)
|
||||||
|
|
||||||
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) {
|
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) {
|
||||||
|
|||||||
@@ -69,13 +69,14 @@ type AuthService struct {
|
|||||||
|
|
||||||
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
|
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService {
|
||||||
return &AuthService{
|
return &AuthService{
|
||||||
config: config,
|
config: config,
|
||||||
docker: docker,
|
docker: docker,
|
||||||
loginAttempts: make(map[string]*LoginAttempt),
|
loginAttempts: make(map[string]*LoginAttempt),
|
||||||
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
ldapGroupsCache: make(map[string]*LdapGroupsCache),
|
||||||
ldap: ldap,
|
oauthPendingSessions: make(map[string]*OAuthPendingSession),
|
||||||
queries: queries,
|
ldap: ldap,
|
||||||
oauthBroker: oauthBroker,
|
queries: queries,
|
||||||
|
oauthBroker: oauthBroker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -568,66 +569,51 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) NewOAuthSession(serviceName string) (string, error) {
|
func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) {
|
||||||
service, ok := auth.oauthBroker.GetService(serviceName)
|
service, ok := auth.oauthBroker.GetService(serviceName)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("oauth service not found: %s", serviceName)
|
return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionId, err := uuid.NewRandom()
|
sessionId, err := uuid.NewRandom()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to generate session ID: %w", err)
|
return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
state := uuid.New().String()
|
state := service.NewRandom()
|
||||||
verifier := uuid.New().String()
|
verifier := service.NewRandom()
|
||||||
|
|
||||||
auth.oauthMutex.Lock()
|
session := OAuthPendingSession{
|
||||||
auth.oauthPendingSessions[sessionId.String()] = &OAuthPendingSession{
|
|
||||||
State: state,
|
State: state,
|
||||||
Verifier: verifier,
|
Verifier: verifier,
|
||||||
Service: &service,
|
Service: &service,
|
||||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auth.oauthMutex.Lock()
|
||||||
|
auth.oauthPendingSessions[sessionId.String()] = &session
|
||||||
auth.oauthMutex.Unlock()
|
auth.oauthMutex.Unlock()
|
||||||
|
|
||||||
return sessionId.String(), nil
|
return sessionId.String(), session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
|
||||||
auth.oauthMutex.RLock()
|
session, err := auth.getOAuthPendingSession(sessionId)
|
||||||
defer auth.oauthMutex.RUnlock()
|
|
||||||
|
|
||||||
session, exists := auth.oauthPendingSessions[sessionId]
|
if err != nil {
|
||||||
|
return "", err
|
||||||
if !exists {
|
|
||||||
return "", fmt.Errorf("oauth session not found: %s", sessionId)
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Now().After(session.ExpiresAt) {
|
|
||||||
delete(auth.oauthPendingSessions, sessionId)
|
|
||||||
return "", fmt.Errorf("oauth session expired: %s", sessionId)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
|
return (*session.Service).GetAuthURL(session.State, session.Verifier), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
|
||||||
auth.oauthMutex.RLock()
|
session, err := auth.getOAuthPendingSession(sessionId)
|
||||||
session, exists := auth.oauthPendingSessions[sessionId]
|
|
||||||
auth.oauthMutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth session not found: %s", sessionId)
|
return nil, err
|
||||||
}
|
|
||||||
|
|
||||||
if time.Now().After(session.ExpiresAt) {
|
|
||||||
auth.oauthMutex.Lock()
|
|
||||||
delete(auth.oauthPendingSessions, sessionId)
|
|
||||||
auth.oauthMutex.Unlock()
|
|
||||||
return nil, fmt.Errorf("oauth session expired: %s", sessionId)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := (*session.Service).GetToken(code, session.Verifier)
|
token, err := (*session.Service).GetToken(code, session.Verifier)
|
||||||
@@ -644,19 +630,10 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
|
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
|
||||||
auth.oauthMutex.RLock()
|
session, err := auth.getOAuthPendingSession(sessionId)
|
||||||
session, exists := auth.oauthPendingSessions[sessionId]
|
|
||||||
auth.oauthMutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
if err != nil {
|
||||||
return config.Claims{}, fmt.Errorf("oauth session not found: %s", sessionId)
|
return config.Claims{}, err
|
||||||
}
|
|
||||||
|
|
||||||
if time.Now().After(session.ExpiresAt) {
|
|
||||||
auth.oauthMutex.Lock()
|
|
||||||
delete(auth.oauthPendingSessions, sessionId)
|
|
||||||
auth.oauthMutex.Unlock()
|
|
||||||
return config.Claims{}, fmt.Errorf("oauth session expired: %s", sessionId)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if session.Token == nil {
|
if session.Token == nil {
|
||||||
@@ -669,13 +646,19 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, erro
|
|||||||
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
|
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.oauthMutex.Lock()
|
|
||||||
delete(auth.oauthPendingSessions, sessionId)
|
|
||||||
auth.oauthMutex.Unlock()
|
|
||||||
|
|
||||||
return userinfo, nil
|
return userinfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
|
||||||
|
session, err := auth.getOAuthPendingSession(sessionId)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return *session.Service, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
func (auth *AuthService) EndOAuthSession(sessionId string) {
|
||||||
auth.oauthMutex.Lock()
|
auth.oauthMutex.Lock()
|
||||||
delete(auth.oauthPendingSessions, sessionId)
|
delete(auth.oauthPendingSessions, sessionId)
|
||||||
@@ -699,3 +682,22 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
|
||||||
|
auth.oauthMutex.RLock()
|
||||||
|
session, exists := auth.oauthPendingSessions[sessionId]
|
||||||
|
auth.oauthMutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId)
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(session.ExpiresAt) {
|
||||||
|
auth.oauthMutex.Lock()
|
||||||
|
delete(auth.oauthPendingSessions, sessionId)
|
||||||
|
auth.oauthMutex.Unlock()
|
||||||
|
return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/steveiliop56/tinyauth/internal/config"
|
"github.com/steveiliop56/tinyauth/internal/config"
|
||||||
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type GithubEmailResponse []struct {
|
type GithubEmailResponse []struct {
|
||||||
@@ -24,42 +23,22 @@ type GithubUserInfoResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
|
func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
|
||||||
var claims config.Claims
|
return simpleReq[config.Claims](client, url, nil)
|
||||||
|
|
||||||
res, err := client.Get(url)
|
|
||||||
if err != nil {
|
|
||||||
return config.Claims{}, err
|
|
||||||
}
|
|
||||||
defer res.Body.Close()
|
|
||||||
|
|
||||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
|
||||||
return config.Claims{}, fmt.Errorf("request failed with status: %s", res.Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return config.Claims{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tlog.App.Trace().Str("body", string(body)).Msg("Userinfo response body")
|
|
||||||
|
|
||||||
err = json.Unmarshal(body, &claims)
|
|
||||||
if err != nil {
|
|
||||||
return config.Claims{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return claims, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func githubExtractor(client *http.Client, url string) (config.Claims, error) {
|
func githubExtractor(client *http.Client, url string) (config.Claims, error) {
|
||||||
var user config.Claims
|
var user config.Claims
|
||||||
|
|
||||||
userInfo, err := githubRequest[GithubUserInfoResponse](client, "https://api.github.com/user")
|
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
|
||||||
|
"accept": "application/vnd.github+json",
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.Claims{}, err
|
return config.Claims{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userEmails, err := githubRequest[GithubEmailResponse](client, "https://api.github.com/user/emails")
|
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
|
||||||
|
"accept": "application/vnd.github+json",
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.Claims{}, err
|
return config.Claims{}, err
|
||||||
}
|
}
|
||||||
@@ -87,35 +66,37 @@ func githubExtractor(client *http.Client, url string) (config.Claims, error) {
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func githubRequest[T any](client *http.Client, url string) (T, error) {
|
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) {
|
||||||
var githubRes T
|
var decodedRes T
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", "https://api.github.com/user", nil)
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return githubRes, err
|
return decodedRes, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Accept", "application/vnd.github+json")
|
for key, value := range headers {
|
||||||
|
req.Header.Add(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return githubRes, err
|
return decodedRes, err
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||||
return githubRes, fmt.Errorf("request failed with status: %s", res.Status)
|
return decodedRes, fmt.Errorf("request failed with status: %s", res.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
body, err := io.ReadAll(res.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return githubRes, err
|
return decodedRes, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = json.Unmarshal(body, &githubRes)
|
err = json.Unmarshal(body, &decodedRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return githubRes, err
|
return decodedRes, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return githubRes, nil
|
return decodedRes, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/steveiliop56/tinyauth/internal/config"
|
"github.com/steveiliop56/tinyauth/internal/config"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -69,6 +70,7 @@ func (s *OAuthService) GetAuthURL(state string, verifier string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, error) {
|
func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, error) {
|
||||||
|
tlog.App.Debug().Str("code", code).Str("verifier", verifier).Msg("Exchanging code for token")
|
||||||
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
|
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user