From f26c2171610d5c2dfbba2edb6ccd39490e349803 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 22 Mar 2026 21:03:32 +0200 Subject: [PATCH] refactor: oauth flow (#726) * wip * feat: add oauth session impl in auth service * feat: move oauth logic into auth service and handle multiple sessions * tests: fix tests * fix: review comments * fix: prevent ddos attacks in oauth rate limit --- internal/bootstrap/app_bootstrap.go | 30 +-- internal/bootstrap/router_bootstrap.go | 13 +- internal/bootstrap/service_bootstrap.go | 22 +- internal/config/config.go | 1 + internal/controller/oauth_controller.go | 88 +++++--- internal/controller/proxy_controller_test.go | 2 +- internal/controller/user_controller_test.go | 2 +- internal/service/auth_service.go | 223 +++++++++++++++++-- internal/service/generic_oauth_service.go | 132 ----------- internal/service/github_oauth_service.go | 184 --------------- internal/service/google_oauth_service.go | 116 ---------- internal/service/oauth_broker_service.go | 62 ++---- internal/service/oauth_extractors.go | 102 +++++++++ internal/service/oauth_presets.go | 23 ++ internal/service/oauth_service.go | 78 +++++++ 15 files changed, 520 insertions(+), 558 deletions(-) delete mode 100644 internal/service/generic_oauth_service.go delete mode 100644 internal/service/github_oauth_service.go delete mode 100644 internal/service/google_oauth_service.go create mode 100644 internal/service/oauth_extractors.go create mode 100644 internal/service/oauth_presets.go create mode 100644 internal/service/oauth_service.go diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index b1598f7..12e4f7f 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -22,16 +22,17 @@ import ( type BootstrapApp struct { config config.Config context struct { - appUrl string - uuid string - cookieDomain string - sessionCookieName string - csrfCookieName string - redirectCookieName string - users []config.User - oauthProviders map[string]config.OAuthServiceConfig - configuredProviders []controller.Provider - oidcClients []config.OIDCClientConfig + appUrl string + uuid string + cookieDomain string + sessionCookieName string + csrfCookieName string + redirectCookieName string + oauthSessionCookieName string + users []config.User + oauthProviders map[string]config.OAuthServiceConfig + configuredProviders []controller.Provider + oidcClients []config.OIDCClientConfig } services Services } @@ -113,6 +114,7 @@ func (app *BootstrapApp) Setup() error { app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId) app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId) app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId) + app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId) // Dumps tlog.App.Trace().Interface("config", app.config).Msg("Config dump") @@ -190,12 +192,12 @@ func (app *BootstrapApp) Setup() error { // Start db cleanup routine tlog.App.Debug().Msg("Starting database cleanup routine") - go app.dbCleanup(queries) + go app.dbCleanupRoutine(queries) // If analytics are not disabled, start heartbeat if app.config.Analytics.Enabled { tlog.App.Debug().Msg("Starting heartbeat routine") - go app.heartbeat() + go app.heartbeatRoutine() } // If we have an socket path, bind to it @@ -226,7 +228,7 @@ func (app *BootstrapApp) Setup() error { return nil } -func (app *BootstrapApp) heartbeat() { +func (app *BootstrapApp) heartbeatRoutine() { ticker := time.NewTicker(time.Duration(12) * time.Hour) defer ticker.Stop() @@ -280,7 +282,7 @@ func (app *BootstrapApp) heartbeat() { } } -func (app *BootstrapApp) dbCleanup(queries *repository.Queries) { +func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() ctx := context.Background() diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 6ae7e94..ef92d10 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -77,12 +77,13 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { contextController.SetupRoutes() oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ - AppURL: app.config.AppURL, - SecureCookie: app.config.Auth.SecureCookie, - CSRFCookieName: app.context.csrfCookieName, - RedirectCookieName: app.context.redirectCookieName, - CookieDomain: app.context.cookieDomain, - }, apiRouter, app.services.authService, app.services.oauthBrokerService) + AppURL: app.config.AppURL, + SecureCookie: app.config.Auth.SecureCookie, + CSRFCookieName: app.context.csrfCookieName, + RedirectCookieName: app.context.redirectCookieName, + CookieDomain: app.context.cookieDomain, + OAuthSessionCookieName: app.context.oauthSessionCookieName, + }, apiRouter, app.services.authService) oauthController.SetupRoutes() diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index e9a27a7..7bd4a62 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -58,6 +58,16 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er services.accessControlService = accessControlsService + oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders) + + err = oauthBrokerService.Init() + + if err != nil { + return Services{}, err + } + + services.oauthBrokerService = oauthBrokerService + authService := service.NewAuthService(service.AuthServiceConfig{ Users: app.context.users, OauthWhitelist: app.config.OAuth.Whitelist, @@ -70,7 +80,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er SessionCookieName: app.context.sessionCookieName, IP: app.config.Auth.IP, LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL, - }, dockerService, services.ldapService, queries) + }, dockerService, services.ldapService, queries, services.oauthBrokerService) err = authService.Init() @@ -80,16 +90,6 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er services.authService = authService - oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders) - - err = oauthBrokerService.Init() - - if err != nil { - return Services{}, err - } - - services.oauthBrokerService = oauthBrokerService - oidcService := service.NewOIDCService(service.OIDCServiceConfig{ Clients: app.config.OIDC.Clients, PrivateKeyPath: app.config.OIDC.PrivateKeyPath, diff --git a/internal/config/config.go b/internal/config/config.go index 4633953..b8db08a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -73,6 +73,7 @@ var BuildTimestamp = "0000-00-00T00:00:00Z" var SessionCookieName = "tinyauth-session" var CSRFCookieName = "tinyauth-csrf" var RedirectCookieName = "tinyauth-redirect" +var OAuthSessionCookieName = "tinyauth-oauth" // Main app config diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 019bae7..d5dfc39 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -21,26 +21,25 @@ type OAuthRequest struct { } type OAuthControllerConfig struct { - CSRFCookieName string - RedirectCookieName string - SecureCookie bool - AppURL string - CookieDomain string + CSRFCookieName string + OAuthSessionCookieName string + RedirectCookieName string + SecureCookie bool + AppURL string + CookieDomain string } type OAuthController struct { config OAuthControllerConfig router *gin.RouterGroup auth *service.AuthService - broker *service.OAuthBrokerService } -func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController { +func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController { return &OAuthController{ config: config, router: router, auth: auth, - broker: broker, } } @@ -63,21 +62,30 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - service, exists := controller.broker.GetService(req.Provider) + sessionId, session, err := controller.auth.NewOAuthSession(req.Provider) - if !exists { - tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider) - c.JSON(404, gin.H{ - "status": 404, - "message": "Not Found", + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to create OAuth session") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", }) return } - service.GenerateVerifier() - state := service.GenerateState() - authURL := service.GetAuthURL(state) - c.SetCookie(controller.config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + authUrl, err := controller.auth.GetOAuthURL(sessionId) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to get OAuth URL") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + c.SetCookie(controller.config.CSRFCookieName, session.State, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) redirectURI := c.Query("redirect_uri") isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) @@ -95,7 +103,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { c.JSON(200, gin.H{ "status": 200, "message": "OK", - "url": authURL, + "url": authUrl, }) } @@ -112,6 +120,17 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } + sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName) + + if err != nil { + tlog.App.Warn().Err(err).Msg("OAuth session cookie missing") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + + c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) + defer controller.auth.EndOAuthSession(sessionIdCookie) + state := c.Query("state") csrfCookie, err := c.Cookie(controller.config.CSRFCookieName) @@ -125,29 +144,16 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true) code := c.Query("code") - service, exists := controller.broker.GetService(req.Provider) - - if !exists { - tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider) - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) - return - } - - err = service.VerifyCode(code) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to verify OAuth code") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) - return - } - - user, err := controller.broker.GetUser(req.Provider) + _, err = controller.auth.GetOAuthToken(sessionIdCookie, code) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get user from OAuth provider") + tlog.App.Error().Err(err).Msg("Failed to exchange code for token") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } + user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie) + if user.Email == "" { tlog.App.Error().Msg("OAuth provider did not return an email") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) @@ -192,13 +198,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { username = strings.Replace(user.Email, "@", "_", 1) } + service, err := controller.auth.GetOAuthService(sessionIdCookie) + + if err != nil { + tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + sessionCookie := repository.Session{ Username: username, Name: name, Email: user.Email, Provider: req.Provider, OAuthGroups: utils.CoalesceToString(user.Groups), - OAuthName: service.GetName(), + OAuthName: service.Name(), OAuthSub: user.Sub, } diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index e22e7c4..f7e73ec 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -85,7 +85,7 @@ func setupProxyController(t *testing.T, middlewares []gin.HandlerFunc) (*gin.Eng LoginTimeout: 300, LoginMaxRetries: 3, SessionCookieName: "tinyauth-session", - }, dockerService, nil, queries) + }, dockerService, nil, queries, &service.OAuthBrokerService{}) // Controller ctrl := controller.NewProxyController(controller.ProxyControllerConfig{ diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index fedc98c..672740c 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -71,7 +71,7 @@ func setupUserController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Eng LoginTimeout: 300, LoginMaxRetries: 3, SessionCookieName: "tinyauth-session", - }, nil, nil, queries) + }, nil, nil, queries, &service.OAuthBrokerService{}) // Controller ctrl := controller.NewUserController(controller.UserControllerConfig{ diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 69bfad4..53c879d 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -17,8 +17,21 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" "golang.org/x/crypto/bcrypt" + "golang.org/x/exp/slices" + "golang.org/x/oauth2" ) +const MaxOAuthPendingSessions = 256 +const OAuthCleanupCount = 16 + +type OAuthPendingSession struct { + State string + Verifier string + Token *oauth2.Token + Service *OAuthServiceImpl + ExpiresAt time.Time +} + type LdapGroupsCache struct { Groups []string Expires time.Time @@ -45,28 +58,34 @@ type AuthServiceConfig struct { } type AuthService struct { - config AuthServiceConfig - docker *DockerService - loginAttempts map[string]*LoginAttempt - ldapGroupsCache map[string]*LdapGroupsCache - loginMutex sync.RWMutex - ldapGroupsMutex sync.RWMutex - ldap *LdapService - queries *repository.Queries + config AuthServiceConfig + docker *DockerService + loginAttempts map[string]*LoginAttempt + ldapGroupsCache map[string]*LdapGroupsCache + oauthPendingSessions map[string]*OAuthPendingSession + oauthMutex sync.RWMutex + loginMutex sync.RWMutex + ldapGroupsMutex sync.RWMutex + ldap *LdapService + queries *repository.Queries + oauthBroker *OAuthBrokerService } -func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries) *AuthService { +func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { return &AuthService{ - config: config, - docker: docker, - loginAttempts: make(map[string]*LoginAttempt), - ldapGroupsCache: make(map[string]*LdapGroupsCache), - ldap: ldap, - queries: queries, + config: config, + docker: docker, + loginAttempts: make(map[string]*LoginAttempt), + ldapGroupsCache: make(map[string]*LdapGroupsCache), + oauthPendingSessions: make(map[string]*OAuthPendingSession), + ldap: ldap, + queries: queries, + oauthBroker: oauthBroker, } } func (auth *AuthService) Init() error { + go auth.CleanupOAuthSessionsRoutine() return nil } @@ -553,3 +572,177 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool { tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication") return false } + +func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) { + auth.ensureOAuthSessionLimit() + + service, ok := auth.oauthBroker.GetService(serviceName) + + if !ok { + return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName) + } + + sessionId, err := uuid.NewRandom() + + if err != nil { + return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err) + } + + state := service.NewRandom() + verifier := service.NewRandom() + + session := OAuthPendingSession{ + State: state, + Verifier: verifier, + Service: &service, + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + auth.oauthMutex.Lock() + auth.oauthPendingSessions[sessionId.String()] = &session + auth.oauthMutex.Unlock() + + return sessionId.String(), session, nil +} + +func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { + session, err := auth.getOAuthPendingSession(sessionId) + + if err != nil { + return "", err + } + + return (*session.Service).GetAuthURL(session.State, session.Verifier), nil +} + +func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { + session, err := auth.getOAuthPendingSession(sessionId) + + if err != nil { + return nil, err + } + + token, err := (*session.Service).GetToken(code, session.Verifier) + + if err != nil { + return nil, fmt.Errorf("failed to exchange code for token: %w", err) + } + + auth.oauthMutex.Lock() + session.Token = token + auth.oauthMutex.Unlock() + + return token, nil +} + +func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) { + session, err := auth.getOAuthPendingSession(sessionId) + + if err != nil { + return config.Claims{}, err + } + + if session.Token == nil { + return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId) + } + + userinfo, err := (*session.Service).GetUserinfo(session.Token) + + if err != nil { + return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err) + } + + return userinfo, nil +} + +func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) { + session, err := auth.getOAuthPendingSession(sessionId) + + if err != nil { + return nil, err + } + + return *session.Service, nil +} + +func (auth *AuthService) EndOAuthSession(sessionId string) { + auth.oauthMutex.Lock() + delete(auth.oauthPendingSessions, sessionId) + auth.oauthMutex.Unlock() +} + +func (auth *AuthService) CleanupOAuthSessionsRoutine() { + ticker := time.NewTicker(30 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + auth.oauthMutex.Lock() + + now := time.Now() + + for sessionId, session := range auth.oauthPendingSessions { + if now.After(session.ExpiresAt) { + delete(auth.oauthPendingSessions, sessionId) + } + } + + auth.oauthMutex.Unlock() + } +} + +func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) { + auth.ensureOAuthSessionLimit() + + auth.oauthMutex.RLock() + session, exists := auth.oauthPendingSessions[sessionId] + auth.oauthMutex.RUnlock() + + if !exists { + return &OAuthPendingSession{}, fmt.Errorf("oauth session not found: %s", sessionId) + } + + if time.Now().After(session.ExpiresAt) { + auth.oauthMutex.Lock() + delete(auth.oauthPendingSessions, sessionId) + auth.oauthMutex.Unlock() + return &OAuthPendingSession{}, fmt.Errorf("oauth session expired: %s", sessionId) + } + + return session, nil +} + +func (auth *AuthService) ensureOAuthSessionLimit() { + auth.oauthMutex.Lock() + defer auth.oauthMutex.Unlock() + + if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions { + + cleanupIds := make([]string, 0, OAuthCleanupCount) + + for range OAuthCleanupCount { + oldestId := "" + oldestTime := int64(0) + + for id, session := range auth.oauthPendingSessions { + if oldestTime == 0 { + oldestId = id + oldestTime = session.ExpiresAt.Unix() + continue + } + if slices.Contains(cleanupIds, id) { + continue + } + if session.ExpiresAt.Unix() < oldestTime { + oldestId = id + oldestTime = session.ExpiresAt.Unix() + } + } + + cleanupIds = append(cleanupIds, oldestId) + } + + for _, id := range cleanupIds { + delete(auth.oauthPendingSessions, id) + } + } +} diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go deleted file mode 100644 index ef17b0e..0000000 --- a/internal/service/generic_oauth_service.go +++ /dev/null @@ -1,132 +0,0 @@ -package service - -import ( - "context" - "crypto/rand" - "crypto/tls" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "time" - - "github.com/steveiliop56/tinyauth/internal/config" - "github.com/steveiliop56/tinyauth/internal/utils/tlog" - - "golang.org/x/oauth2" -) - -type GenericOAuthService struct { - config oauth2.Config - context context.Context - token *oauth2.Token - verifier string - insecureSkipVerify bool - userinfoUrl string - name string -} - -func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { - return &GenericOAuthService{ - config: oauth2.Config{ - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - RedirectURL: config.RedirectURL, - Scopes: config.Scopes, - Endpoint: oauth2.Endpoint{ - AuthURL: config.AuthURL, - TokenURL: config.TokenURL, - }, - }, - insecureSkipVerify: config.Insecure, - userinfoUrl: config.UserinfoURL, - name: config.Name, - } -} - -func (generic *GenericOAuthService) Init() error { - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: generic.insecureSkipVerify, - MinVersion: tls.VersionTLS12, - }, - } - - httpClient := &http.Client{ - Transport: transport, - Timeout: 30 * time.Second, - } - - ctx := context.Background() - - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - - generic.context = ctx - return nil -} - -func (generic *GenericOAuthService) GenerateState() string { - b := make([]byte, 128) - _, err := rand.Read(b) - if err != nil { - return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) - } - state := base64.RawURLEncoding.EncodeToString(b) - return state -} - -func (generic *GenericOAuthService) GenerateVerifier() string { - verifier := oauth2.GenerateVerifier() - generic.verifier = verifier - return verifier -} - -func (generic *GenericOAuthService) GetAuthURL(state string) string { - return generic.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.verifier)) -} - -func (generic *GenericOAuthService) VerifyCode(code string) error { - token, err := generic.config.Exchange(generic.context, code, oauth2.VerifierOption(generic.verifier)) - - if err != nil { - return err - } - - generic.token = token - return nil -} - -func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { - var user config.Claims - - client := generic.config.Client(generic.context, generic.token) - - res, err := client.Get(generic.userinfoUrl) - if err != nil { - return user, err - } - defer res.Body.Close() - - if res.StatusCode < 200 || res.StatusCode >= 300 { - return user, fmt.Errorf("request failed with status: %s", res.Status) - } - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - tlog.App.Trace().Str("body", string(body)).Msg("Userinfo response body") - - err = json.Unmarshal(body, &user) - if err != nil { - return user, err - } - - return user, nil -} - -func (generic *GenericOAuthService) GetName() string { - return generic.name -} diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go deleted file mode 100644 index 35b552a..0000000 --- a/internal/service/github_oauth_service.go +++ /dev/null @@ -1,184 +0,0 @@ -package service - -import ( - "context" - "crypto/rand" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "time" - - "github.com/steveiliop56/tinyauth/internal/config" - - "golang.org/x/oauth2" - "golang.org/x/oauth2/endpoints" -) - -var GithubOAuthScopes = []string{"user:email", "read:user"} - -type GithubEmailResponse []struct { - Email string `json:"email"` - Primary bool `json:"primary"` -} - -type GithubUserInfoResponse struct { - Login string `json:"login"` - Name string `json:"name"` - ID int `json:"id"` -} - -type GithubOAuthService struct { - config oauth2.Config - context context.Context - token *oauth2.Token - verifier string - name string -} - -func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { - return &GithubOAuthService{ - config: oauth2.Config{ - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - RedirectURL: config.RedirectURL, - Scopes: GithubOAuthScopes, - Endpoint: endpoints.GitHub, - }, - name: config.Name, - } -} - -func (github *GithubOAuthService) Init() error { - httpClient := &http.Client{ - Timeout: 30 * time.Second, - } - ctx := context.Background() - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - github.context = ctx - return nil -} - -func (github *GithubOAuthService) GenerateState() string { - b := make([]byte, 128) - _, err := rand.Read(b) - if err != nil { - return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) - } - state := base64.RawURLEncoding.EncodeToString(b) - return state -} - -func (github *GithubOAuthService) GenerateVerifier() string { - verifier := oauth2.GenerateVerifier() - github.verifier = verifier - return verifier -} - -func (github *GithubOAuthService) GetAuthURL(state string) string { - return github.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.verifier)) -} - -func (github *GithubOAuthService) VerifyCode(code string) error { - token, err := github.config.Exchange(github.context, code, oauth2.VerifierOption(github.verifier)) - - if err != nil { - return err - } - - github.token = token - return nil -} - -func (github *GithubOAuthService) Userinfo() (config.Claims, error) { - var user config.Claims - - client := github.config.Client(github.context, github.token) - - req, err := http.NewRequest("GET", "https://api.github.com/user", nil) - if err != nil { - return user, err - } - - req.Header.Set("Accept", "application/vnd.github+json") - - res, err := client.Do(req) - if err != nil { - return user, err - } - defer res.Body.Close() - - if res.StatusCode < 200 || res.StatusCode >= 300 { - return user, fmt.Errorf("request failed with status: %s", res.Status) - } - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - var userInfo GithubUserInfoResponse - - err = json.Unmarshal(body, &userInfo) - if err != nil { - return user, err - } - - req, err = http.NewRequest("GET", "https://api.github.com/user/emails", nil) - if err != nil { - return user, err - } - - req.Header.Set("Accept", "application/vnd.github+json") - - res, err = client.Do(req) - if err != nil { - return user, err - } - defer res.Body.Close() - - if res.StatusCode < 200 || res.StatusCode >= 300 { - return user, fmt.Errorf("request failed with status: %s", res.Status) - } - - body, err = io.ReadAll(res.Body) - if err != nil { - return user, err - } - - var emails GithubEmailResponse - - err = json.Unmarshal(body, &emails) - if err != nil { - return user, err - } - - for _, email := range emails { - if email.Primary { - user.Email = email.Email - break - } - } - - if len(emails) == 0 { - return user, errors.New("no emails found") - } - - // Use first available email if no primary email was found - if user.Email == "" { - user.Email = emails[0].Email - } - - user.PreferredUsername = userInfo.Login - user.Name = userInfo.Name - user.Sub = strconv.Itoa(userInfo.ID) - - return user, nil -} - -func (github *GithubOAuthService) GetName() string { - return github.name -} diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go deleted file mode 100644 index 6dfbeaf..0000000 --- a/internal/service/google_oauth_service.go +++ /dev/null @@ -1,116 +0,0 @@ -package service - -import ( - "context" - "crypto/rand" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" - - "github.com/steveiliop56/tinyauth/internal/config" - - "golang.org/x/oauth2" - "golang.org/x/oauth2/endpoints" -) - -var GoogleOAuthScopes = []string{"openid", "email", "profile"} - -type GoogleOAuthService struct { - config oauth2.Config - context context.Context - token *oauth2.Token - verifier string - name string -} - -func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { - return &GoogleOAuthService{ - config: oauth2.Config{ - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - RedirectURL: config.RedirectURL, - Scopes: GoogleOAuthScopes, - Endpoint: endpoints.Google, - }, - name: config.Name, - } -} - -func (google *GoogleOAuthService) Init() error { - httpClient := &http.Client{ - Timeout: 30 * time.Second, - } - ctx := context.Background() - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - google.context = ctx - return nil -} - -func (oauth *GoogleOAuthService) GenerateState() string { - b := make([]byte, 128) - _, err := rand.Read(b) - if err != nil { - return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) - } - state := base64.RawURLEncoding.EncodeToString(b) - return state -} - -func (google *GoogleOAuthService) GenerateVerifier() string { - verifier := oauth2.GenerateVerifier() - google.verifier = verifier - return verifier -} - -func (google *GoogleOAuthService) GetAuthURL(state string) string { - return google.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.verifier)) -} - -func (google *GoogleOAuthService) VerifyCode(code string) error { - token, err := google.config.Exchange(google.context, code, oauth2.VerifierOption(google.verifier)) - - if err != nil { - return err - } - - google.token = token - return nil -} - -func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { - var user config.Claims - - client := google.config.Client(google.context, google.token) - - res, err := client.Get("https://openidconnect.googleapis.com/v1/userinfo") - if err != nil { - return config.Claims{}, err - } - defer res.Body.Close() - - if res.StatusCode < 200 || res.StatusCode >= 300 { - return user, fmt.Errorf("request failed with status: %s", res.Status) - } - - body, err := io.ReadAll(res.Body) - if err != nil { - return config.Claims{}, err - } - - err = json.Unmarshal(body, &user) - if err != nil { - return config.Claims{}, err - } - - user.PreferredUsername = strings.SplitN(user.Email, "@", 2)[0] - - return user, nil -} - -func (google *GoogleOAuthService) GetName() string { - return google.name -} diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 76c23e9..40b6734 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -1,60 +1,48 @@ package service import ( - "errors" - "github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/utils/tlog" "golang.org/x/exp/slices" + "golang.org/x/oauth2" ) -type OAuthService interface { - Init() error - GenerateState() string - GenerateVerifier() string - GetAuthURL(state string) string - VerifyCode(code string) error - Userinfo() (config.Claims, error) - GetName() string +type OAuthServiceImpl interface { + Name() string + NewRandom() string + GetAuthURL(state string, verifier string) string + GetToken(code string, verifier string) (*oauth2.Token, error) + GetUserinfo(token *oauth2.Token) (config.Claims, error) } type OAuthBrokerService struct { - services map[string]OAuthService + services map[string]OAuthServiceImpl configs map[string]config.OAuthServiceConfig } +var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{ + "github": newGitHubOAuthService, + "google": newGoogleOAuthService, +} + func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService { return &OAuthBrokerService{ - services: make(map[string]OAuthService), + services: make(map[string]OAuthServiceImpl), configs: configs, } } func (broker *OAuthBrokerService) Init() error { for name, cfg := range broker.configs { - switch name { - case "github": - service := NewGithubOAuthService(cfg) - broker.services[name] = service - case "google": - service := NewGoogleOAuthService(cfg) - broker.services[name] = service - default: - service := NewGenericOAuthService(cfg) - broker.services[name] = service + if presetFunc, exists := presets[name]; exists { + broker.services[name] = presetFunc(cfg) + tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") + } else { + broker.services[name] = NewOAuthService(cfg) + tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config") } } - - for name, service := range broker.services { - err := service.Init() - if err != nil { - tlog.App.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name) - return err - } - tlog.App.Info().Str("service", name).Msg("Initialized OAuth service") - } - return nil } @@ -67,15 +55,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string { return services } -func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) { +func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) { service, exists := broker.services[name] return service, exists } - -func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) { - oauthService, exists := broker.services[service] - if !exists { - return config.Claims{}, errors.New("oauth service not found") - } - return oauthService.Userinfo() -} diff --git a/internal/service/oauth_extractors.go b/internal/service/oauth_extractors.go new file mode 100644 index 0000000..91b1387 --- /dev/null +++ b/internal/service/oauth_extractors.go @@ -0,0 +1,102 @@ +package service + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + + "github.com/steveiliop56/tinyauth/internal/config" +) + +type GithubEmailResponse []struct { + Email string `json:"email"` + Primary bool `json:"primary"` +} + +type GithubUserInfoResponse struct { + Login string `json:"login"` + Name string `json:"name"` + ID int `json:"id"` +} + +func defaultExtractor(client *http.Client, url string) (config.Claims, error) { + return simpleReq[config.Claims](client, url, nil) +} + +func githubExtractor(client *http.Client, url string) (config.Claims, error) { + var user config.Claims + + userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{ + "accept": "application/vnd.github+json", + }) + if err != nil { + return config.Claims{}, err + } + + userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{ + "accept": "application/vnd.github+json", + }) + if err != nil { + return config.Claims{}, err + } + + if len(userEmails) == 0 { + return user, errors.New("no emails found") + } + + for _, email := range userEmails { + if email.Primary { + user.Email = email.Email + break + } + } + + // Use first available email if no primary email was found + if user.Email == "" { + user.Email = userEmails[0].Email + } + + user.PreferredUsername = userInfo.Login + user.Name = userInfo.Name + user.Sub = strconv.Itoa(userInfo.ID) + + return user, nil +} + +func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) { + var decodedRes T + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return decodedRes, err + } + + for key, value := range headers { + req.Header.Add(key, value) + } + + res, err := client.Do(req) + if err != nil { + return decodedRes, err + } + defer res.Body.Close() + + if res.StatusCode < 200 || res.StatusCode >= 300 { + return decodedRes, fmt.Errorf("request failed with status: %s", res.Status) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + return decodedRes, err + } + + err = json.Unmarshal(body, &decodedRes) + if err != nil { + return decodedRes, err + } + + return decodedRes, nil +} diff --git a/internal/service/oauth_presets.go b/internal/service/oauth_presets.go new file mode 100644 index 0000000..6c658dc --- /dev/null +++ b/internal/service/oauth_presets.go @@ -0,0 +1,23 @@ +package service + +import ( + "github.com/steveiliop56/tinyauth/internal/config" + "golang.org/x/oauth2/endpoints" +) + +func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService { + scopes := []string{"openid", "email", "profile"} + config.Scopes = scopes + config.AuthURL = endpoints.Google.AuthURL + config.TokenURL = endpoints.Google.TokenURL + config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo" + return NewOAuthService(config) +} + +func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService { + scopes := []string{"read:user", "user:email"} + config.Scopes = scopes + config.AuthURL = endpoints.GitHub.AuthURL + config.TokenURL = endpoints.GitHub.TokenURL + return NewOAuthService(config).WithUserinfoExtractor(githubExtractor) +} diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go new file mode 100644 index 0000000..76f5a92 --- /dev/null +++ b/internal/service/oauth_service.go @@ -0,0 +1,78 @@ +package service + +import ( + "context" + "crypto/tls" + "net/http" + "time" + + "github.com/steveiliop56/tinyauth/internal/config" + "golang.org/x/oauth2" +) + +type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error) + +type OAuthService struct { + serviceCfg config.OAuthServiceConfig + config *oauth2.Config + ctx context.Context + userinfoExtractor UserinfoExtractor +} + +func NewOAuthService(config config.OAuthServiceConfig) *OAuthService { + httpClient := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: config.Insecure, + }, + }, + } + ctx := context.Background() + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + return &OAuthService{ + serviceCfg: config, + config: &oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: config.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: config.AuthURL, + TokenURL: config.TokenURL, + }, + }, + ctx: ctx, + userinfoExtractor: defaultExtractor, + } +} + +func (s *OAuthService) WithUserinfoExtractor(extractor UserinfoExtractor) *OAuthService { + s.userinfoExtractor = extractor + return s +} + +func (s *OAuthService) Name() string { + return s.serviceCfg.Name +} + +func (s *OAuthService) NewRandom() string { + // The generate verifier function just creates a random string, + // so we can use it to generate a random state as well + random := oauth2.GenerateVerifier() + return random +} + +func (s *OAuthService) GetAuthURL(state string, verifier string) string { + return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier)) +} + +func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, error) { + return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier)) +} + +func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) { + client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token)) + return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL) +}