refactor: oauth flow (#726)

* wip

* feat: add oauth session impl in auth service

* feat: move oauth logic into auth service and handle multiple sessions

* tests: fix tests

* fix: review comments

* fix: prevent ddos attacks in oauth rate limit
This commit is contained in:
Stavros
2026-03-22 21:03:32 +02:00
committed by GitHub
parent d71a8e03cc
commit f26c217161
15 changed files with 520 additions and 558 deletions

View File

@@ -21,26 +21,25 @@ type OAuthRequest struct {
}
type OAuthControllerConfig struct {
CSRFCookieName string
RedirectCookieName string
SecureCookie bool
AppURL string
CookieDomain string
CSRFCookieName string
OAuthSessionCookieName string
RedirectCookieName string
SecureCookie bool
AppURL string
CookieDomain string
}
type OAuthController struct {
config OAuthControllerConfig
router *gin.RouterGroup
auth *service.AuthService
broker *service.OAuthBrokerService
}
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService, broker *service.OAuthBrokerService) *OAuthController {
func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController {
return &OAuthController{
config: config,
router: router,
auth: auth,
broker: broker,
}
}
@@ -63,21 +62,30 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return
}
service, exists := controller.broker.GetService(req.Provider)
sessionId, session, err := controller.auth.NewOAuthSession(req.Provider)
if !exists {
tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider)
c.JSON(404, gin.H{
"status": 404,
"message": "Not Found",
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
service.GenerateVerifier()
state := service.GenerateState()
authURL := service.GetAuthURL(state)
c.SetCookie(controller.config.CSRFCookieName, state, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
authUrl, err := controller.auth.GetOAuthURL(sessionId)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth URL")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
c.SetCookie(controller.config.CSRFCookieName, session.State, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
redirectURI := c.Query("redirect_uri")
isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain)
@@ -95,7 +103,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
c.JSON(200, gin.H{
"status": 200,
"message": "OK",
"url": authURL,
"url": authUrl,
})
}
@@ -112,6 +120,17 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return
}
sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName)
if err != nil {
tlog.App.Warn().Err(err).Msg("OAuth session cookie missing")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
defer controller.auth.EndOAuthSession(sessionIdCookie)
state := c.Query("state")
csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)
@@ -125,29 +144,16 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
code := c.Query("code")
service, exists := controller.broker.GetService(req.Provider)
if !exists {
tlog.App.Warn().Msgf("OAuth provider not found: %s", req.Provider)
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
err = service.VerifyCode(code)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to verify OAuth code")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
user, err := controller.broker.GetUser(req.Provider)
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user from OAuth provider")
tlog.App.Error().Err(err).Msg("Failed to exchange code for token")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
if user.Email == "" {
tlog.App.Error().Msg("OAuth provider did not return an email")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
@@ -192,13 +198,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = strings.Replace(user.Email, "@", "_", 1)
}
service, err := controller.auth.GetOAuthService(sessionIdCookie)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
return
}
sessionCookie := repository.Session{
Username: username,
Name: name,
Email: user.Email,
Provider: req.Provider,
OAuthGroups: utils.CoalesceToString(user.Groups),
OAuthName: service.GetName(),
OAuthName: service.Name(),
OAuthSub: user.Sub,
}

View File

@@ -85,7 +85,7 @@ func setupProxyController(t *testing.T, middlewares []gin.HandlerFunc) (*gin.Eng
LoginTimeout: 300,
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}, dockerService, nil, queries)
}, dockerService, nil, queries, &service.OAuthBrokerService{})
// Controller
ctrl := controller.NewProxyController(controller.ProxyControllerConfig{

View File

@@ -71,7 +71,7 @@ func setupUserController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.Eng
LoginTimeout: 300,
LoginMaxRetries: 3,
SessionCookieName: "tinyauth-session",
}, nil, nil, queries)
}, nil, nil, queries, &service.OAuthBrokerService{})
// Controller
ctrl := controller.NewUserController(controller.UserControllerConfig{