diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index b1598f7..12e4f7f 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -22,16 +22,17 @@ import ( type BootstrapApp struct { config config.Config context struct { - appUrl string - uuid string - cookieDomain string - sessionCookieName string - csrfCookieName string - redirectCookieName string - users []config.User - oauthProviders map[string]config.OAuthServiceConfig - configuredProviders []controller.Provider - oidcClients []config.OIDCClientConfig + appUrl string + uuid string + cookieDomain string + sessionCookieName string + csrfCookieName string + redirectCookieName string + oauthSessionCookieName string + users []config.User + oauthProviders map[string]config.OAuthServiceConfig + configuredProviders []controller.Provider + oidcClients []config.OIDCClientConfig } services Services } @@ -113,6 +114,7 @@ func (app *BootstrapApp) Setup() error { app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId) app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId) app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId) + app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId) // Dumps tlog.App.Trace().Interface("config", app.config).Msg("Config dump") @@ -190,12 +192,12 @@ func (app *BootstrapApp) Setup() error { // Start db cleanup routine tlog.App.Debug().Msg("Starting database cleanup routine") - go app.dbCleanup(queries) + go app.dbCleanupRoutine(queries) // If analytics are not disabled, start heartbeat if app.config.Analytics.Enabled { tlog.App.Debug().Msg("Starting heartbeat routine") - go app.heartbeat() + go app.heartbeatRoutine() } // If we have an socket path, bind to it @@ -226,7 +228,7 @@ func (app *BootstrapApp) Setup() error { return nil } -func (app *BootstrapApp) heartbeat() { +func (app *BootstrapApp) heartbeatRoutine() { ticker := time.NewTicker(time.Duration(12) * time.Hour) defer ticker.Stop() @@ -280,7 +282,7 @@ func (app *BootstrapApp) heartbeat() { } } -func (app *BootstrapApp) dbCleanup(queries *repository.Queries) { +func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() ctx := context.Background() diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 6ae7e94..ef92d10 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -77,12 +77,13 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { contextController.SetupRoutes() oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ - AppURL: app.config.AppURL, - SecureCookie: app.config.Auth.SecureCookie, - CSRFCookieName: app.context.csrfCookieName, - RedirectCookieName: app.context.redirectCookieName, - CookieDomain: app.context.cookieDomain, - }, apiRouter, app.services.authService, app.services.oauthBrokerService) + AppURL: app.config.AppURL, + SecureCookie: app.config.Auth.SecureCookie, + CSRFCookieName: app.context.csrfCookieName, + RedirectCookieName: app.context.redirectCookieName, + CookieDomain: app.context.cookieDomain, + OAuthSessionCookieName: app.context.oauthSessionCookieName, + }, apiRouter, app.services.authService) oauthController.SetupRoutes() diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index e9a27a7..7bd4a62 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -58,6 +58,16 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er services.accessControlService = accessControlsService + oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders) + + err = oauthBrokerService.Init() + + if err != nil { + return Services{}, err + } + + services.oauthBrokerService = oauthBrokerService + authService := service.NewAuthService(service.AuthServiceConfig{ Users: app.context.users, OauthWhitelist: app.config.OAuth.Whitelist, @@ -70,7 +80,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er SessionCookieName: app.context.sessionCookieName, IP: app.config.Auth.IP, LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL, - }, dockerService, services.ldapService, queries) + }, dockerService, services.ldapService, queries, services.oauthBrokerService) err = authService.Init() @@ -80,16 +90,6 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er services.authService = authService - oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders) - - err = oauthBrokerService.Init() - - if err != nil { - return Services{}, err - } - - services.oauthBrokerService = oauthBrokerService - oidcService := service.NewOIDCService(service.OIDCServiceConfig{ Clients: app.config.OIDC.Clients, PrivateKeyPath: app.config.OIDC.PrivateKeyPath, diff --git a/internal/config/config.go b/internal/config/config.go index 4633953..b8db08a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -73,6 +73,7 @@ var BuildTimestamp = "0000-00-00T00:00:00Z" var SessionCookieName = "tinyauth-session" var CSRFCookieName = "tinyauth-csrf" var RedirectCookieName = "tinyauth-redirect" +var OAuthSessionCookieName = "tinyauth-oauth" // Main app config diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 019bae7..acf20f1 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -21,26 +21,25 @@ type OAuthRequest struct { } type OAuthControllerConfig struct { - CSRFCookieName string - RedirectCookieName string - SecureCookie bool - AppURL string - CookieDomain string + CSRFCookieName string + OAuthSessionCookieName string + RedirectCookieName string + SecureCookie bool + AppURL string + CookieDomain string } type OAuthController struct { config OAuthControllerConfig router *gin.RouterGroup auth *service.AuthService - broker *service.OAuthBrokerService } -func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController { +func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController { return &OAuthController{ config: config, router: router, auth: auth, - broker: broker, } } @@ -63,21 +62,32 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - service, exists := controller.broker.GetService(req.Provider) + sessionId, session, err := controller.auth.NewOAuthSession(req.Provider) - if !exists { - tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider) - c.JSON(404, gin.H{ - "status": 404, - "message": "Not Found", + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to create OAuth session") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", }) return } - service.GenerateVerifier() - state := service.GenerateState() - authURL := service.GetAuthURL(state) - c.SetCookie(controller.config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + tlog.App.Debug().Interface("session", session).Msg("Created new OAuth session") + + authUrl, err := controller.auth.GetOAuthURL(sessionId) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to get OAuth URL") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + c.SetCookie(controller.config.CSRFCookieName, session.State, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) redirectURI := c.Query("redirect_uri") isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) @@ -95,7 +105,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { c.JSON(200, gin.H{ "status": 200, "message": "OK", - "url": authURL, + "url": authUrl, }) } @@ -112,6 +122,16 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } + sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName) + + if err != nil { + tlog.App.Warn().Err(err).Msg("OAuth session cookie missing") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + + c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + state := c.Query("state") csrfCookie, err := c.Cookie(controller.config.CSRFCookieName) @@ -125,29 +145,18 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) code := c.Query("code") - service, exists := controller.broker.GetService(req.Provider) - if !exists { - tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider) - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) - return - } - - err = service.VerifyCode(code) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to verify OAuth code") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) - return - } - - user, err := controller.broker.GetUser(req.Provider) + tlog.App.Debug().Str("code", code).Str("state", state).Msg("Received OAuth callback") + _, err = controller.auth.GetOAuthToken(sessionIdCookie, code) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get user from OAuth provider") + tlog.App.Error().Err(err).Msg("Failed to exchange code for token") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } + user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie) + if user.Email == "" { tlog.App.Error().Msg("OAuth provider did not return an email") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) @@ -192,13 +201,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { username = strings.Replace(user.Email, "@", "_", 1) } + service, err := controller.auth.GetOAuthService(sessionIdCookie) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + sessionCookie := repository.Session{ Username: username, Name: name, Email: user.Email, Provider: req.Provider, OAuthGroups: utils.CoalesceToString(user.Groups), - OAuthName: service.GetName(), + OAuthName: service.Name(), OAuthSub: user.Sub, } @@ -214,6 +231,9 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) + // Clear OAuth session + controller.auth.EndOAuthSession(sessionIdCookie) + redirectURI, err := c.Cookie(controller.config.RedirectCookieName) if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) { diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index c95e792..90a2aec 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -69,13 +69,14 @@ type AuthService struct { func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { return &AuthService{ - config: config, - docker: docker, - loginAttempts: make(map[string]*LoginAttempt), - ldapGroupsCache: make(map[string]*LdapGroupsCache), - ldap: ldap, - queries: queries, - oauthBroker: oauthBroker, + config: config, + docker: docker, + loginAttempts: make(map[string]*LoginAttempt), + ldapGroupsCache: make(map[string]*LdapGroupsCache), + oauthPendingSessions: make(map[string]*OAuthPendingSession), + ldap: ldap, + queries: queries, + oauthBroker: oauthBroker, } } @@ -568,66 +569,51 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool { return false } -func (auth *AuthService) NewOAuthSession(serviceName string) (string, error) { +func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) { service, ok := auth.oauthBroker.GetService(serviceName) if !ok { - return "", fmt.Errorf("oauth service not found: %s", serviceName) + return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName) } sessionId, err := uuid.NewRandom() if err != nil { - return "", fmt.Errorf("failed to generate session ID: %w", err) + return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err) } - state := uuid.New().String() - verifier := uuid.New().String() + state := service.NewRandom() + verifier := service.NewRandom() - auth.oauthMutex.Lock() - auth.oauthPendingSessions[sessionId.String()] = &OAuthPendingSession{ + session := OAuthPendingSession{ State: state, Verifier: verifier, Service: &service, ExpiresAt: time.Now().Add(1 * time.Hour), } + + auth.oauthMutex.Lock() + auth.oauthPendingSessions[sessionId.String()] = &session auth.oauthMutex.Unlock() - return sessionId.String(), nil + return sessionId.String(), session, nil } func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { - auth.oauthMutex.RLock() - defer auth.oauthMutex.RUnlock() + session, err := auth.getOAuthPendingSession(sessionId) - session, exists := auth.oauthPendingSessions[sessionId] - - if !exists { - return "", fmt.Errorf("oauth session not found: %s", sessionId) - } - - if time.Now().After(session.ExpiresAt) { - delete(auth.oauthPendingSessions, sessionId) - return "", fmt.Errorf("oauth session expired: %s", sessionId) + if err != nil { + return "", err } return (*session.Service).GetAuthURL(session.State, session.Verifier), nil } func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { - auth.oauthMutex.RLock() - session, exists := auth.oauthPendingSessions[sessionId] - auth.oauthMutex.RUnlock() + session, err := auth.getOAuthPendingSession(sessionId) - if !exists { - return nil, fmt.Errorf("oauth session not found: %s", sessionId) - } - - if time.Now().After(session.ExpiresAt) { - auth.oauthMutex.Lock() - delete(auth.oauthPendingSessions, sessionId) - auth.oauthMutex.Unlock() - return nil, fmt.Errorf("oauth session expired: %s", sessionId) + if err != nil { + return nil, err } token, err := (*session.Service).GetToken(code, session.Verifier) @@ -644,19 +630,10 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T } func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) { - auth.oauthMutex.RLock() - session, exists := auth.oauthPendingSessions[sessionId] - auth.oauthMutex.RUnlock() + session, err := auth.getOAuthPendingSession(sessionId) - if !exists { - return config.Claims{}, fmt.Errorf("oauth session not found: %s", sessionId) - } - - if time.Now().After(session.ExpiresAt) { - auth.oauthMutex.Lock() - delete(auth.oauthPendingSessions, sessionId) - auth.oauthMutex.Unlock() - return config.Claims{}, fmt.Errorf("oauth session expired: %s", sessionId) + if err != nil { + return config.Claims{}, err } if session.Token == nil { @@ -669,13 +646,19 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, erro return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err) } - auth.oauthMutex.Lock() - delete(auth.oauthPendingSessions, sessionId) - auth.oauthMutex.Unlock() - return userinfo, nil } +func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) { + session, err := auth.getOAuthPendingSession(sessionId) + + if err != nil { + return nil, err + } + + return *session.Service, nil +} + func (auth *AuthService) EndOAuthSession(sessionId string) { auth.oauthMutex.Lock() delete(auth.oauthPendingSessions, sessionId) @@ -699,3 +682,22 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() { } } } + +func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) { + auth.oauthMutex.RLock() + session, exists := auth.oauthPendingSessions[sessionId] + auth.oauthMutex.RUnlock() + + if !exists { + return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId) + } + + if time.Now().After(session.ExpiresAt) { + auth.oauthMutex.Lock() + delete(auth.oauthPendingSessions, sessionId) + auth.oauthMutex.Unlock() + return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId) + } + + return session, nil +} diff --git a/internal/service/oauth_extractors.go b/internal/service/oauth_extractors.go index 2bf4ab8..91b1387 100644 --- a/internal/service/oauth_extractors.go +++ b/internal/service/oauth_extractors.go @@ -9,7 +9,6 @@ import ( "strconv" "github.com/steveiliop56/tinyauth/internal/config" - "github.com/steveiliop56/tinyauth/internal/utils/tlog" ) type GithubEmailResponse []struct { @@ -24,42 +23,22 @@ type GithubUserInfoResponse struct { } func defaultExtractor(client *http.Client, url string) (config.Claims, error) { - var claims config.Claims - - res, err := client.Get(url) - if err != nil { - return config.Claims{}, err - } - defer res.Body.Close() - - if res.StatusCode < 200 || res.StatusCode >= 300 { - return config.Claims{}, fmt.Errorf("request failed with status: %s", res.Status) - } - - body, err := io.ReadAll(res.Body) - if err != nil { - return config.Claims{}, err - } - - tlog.App.Trace().Str("body", string(body)).Msg("Userinfo response body") - - err = json.Unmarshal(body, &claims) - if err != nil { - return config.Claims{}, err - } - - return claims, nil + return simpleReq[config.Claims](client, url, nil) } func githubExtractor(client *http.Client, url string) (config.Claims, error) { var user config.Claims - userInfo, err := githubRequest[GithubUserInfoResponse](client, "https://api.github.com/user") + userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{ + "accept": "application/vnd.github+json", + }) if err != nil { return config.Claims{}, err } - userEmails, err := githubRequest[GithubEmailResponse](client, "https://api.github.com/user/emails") + userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{ + "accept": "application/vnd.github+json", + }) if err != nil { return config.Claims{}, err } @@ -87,35 +66,37 @@ func githubExtractor(client *http.Client, url string) (config.Claims, error) { return user, nil } -func githubRequest[T any](client *http.Client, url string) (T, error) { - var githubRes T +func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) { + var decodedRes T - req, err := http.NewRequest("GET", "https://api.github.com/user", nil) + req, err := http.NewRequest("GET", url, nil) if err != nil { - return githubRes, err + return decodedRes, err } - req.Header.Set("Accept", "application/vnd.github+json") + for key, value := range headers { + req.Header.Add(key, value) + } res, err := client.Do(req) if err != nil { - return githubRes, err + return decodedRes, err } defer res.Body.Close() if res.StatusCode < 200 || res.StatusCode >= 300 { - return githubRes, fmt.Errorf("request failed with status: %s", res.Status) + return decodedRes, fmt.Errorf("request failed with status: %s", res.Status) } body, err := io.ReadAll(res.Body) if err != nil { - return githubRes, err + return decodedRes, err } - err = json.Unmarshal(body, &githubRes) + err = json.Unmarshal(body, &decodedRes) if err != nil { - return githubRes, err + return decodedRes, err } - return githubRes, nil + return decodedRes, nil } diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go index 76f5a92..fc478a8 100644 --- a/internal/service/oauth_service.go +++ b/internal/service/oauth_service.go @@ -7,6 +7,7 @@ import ( "time" "github.com/steveiliop56/tinyauth/internal/config" + "github.com/steveiliop56/tinyauth/internal/utils/tlog" "golang.org/x/oauth2" ) @@ -69,6 +70,7 @@ func (s *OAuthService) GetAuthURL(state string, verifier string) string { } func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, error) { + tlog.App.Debug().Str("code", code).Str("verifier", verifier).Msg("Exchanging code for token") return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier)) }