fix: fix oauth oidc flow

This commit is contained in:
Stavros
2026-06-06 17:02:06 +03:00
parent da9079246a
commit f078e3549e
3 changed files with 20 additions and 25 deletions
+5 -8
View File
@@ -61,7 +61,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return return
} }
var reqParams service.OAuthURLParams var reqParams service.OAuthCallbackParams
err = c.BindQuery(&reqParams) err = c.BindQuery(&reqParams)
@@ -83,7 +83,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
} }
} }
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams) sessionId, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session") controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session")
@@ -272,7 +272,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/oidc/authorize?%s", controller.runtime.AppURL, queries.Encode()))
return return
} }
@@ -294,11 +294,8 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL) c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
} }
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool { func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackParams) bool {
return params.Scope != "" && return params.LoginFor == "oidc"
params.ResponseType != "" &&
params.ClientID != "" &&
params.RedirectURI != ""
} }
func (controller *OAuthController) getCookieDomain() string { func (controller *OAuthController) getCookieDomain() string {
+2 -1
View File
@@ -9,6 +9,7 @@ import (
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
@@ -116,7 +117,7 @@ func (controller *OIDCController) authorize(c *gin.Context) {
var req service.AuthorizeRequest var req service.AuthorizeRequest
err := c.Bind(&req) err := c.ShouldBindWith(&req, binding.Query)
if err != nil { if err != nil {
controller.authorizeError(c, authorizeErrorParams{ controller.authorizeError(c, authorizeErrorParams{
+13 -16
View File
@@ -30,17 +30,14 @@ var (
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = errors.New("user not found")
) )
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all // We either store params for redirecting to an app after OAuth login,
// parameters and pass them to the authorize page if needed // or for redirecting back to the authorize screen to continue OIDC
type OAuthURLParams struct { type OAuthCallbackParams struct {
Scope string `form:"scope" url:"scope"` LoginFor string `form:"login_for" url:"login_for"`
ResponseType string `form:"response_type" url:"response_type"` OIDCTicket string `form:"oidc_ticket" url:"oidc_ticket"`
ClientID string `form:"client_id" url:"client_id"` OIDCScope string `form:"oidc_scope" url:"oidc_scope"`
RedirectURI string `form:"redirect_uri" url:"redirect_uri"` OIDCName string `form:"oidc_name" url:"oidc_name"`
State string `form:"state" url:"state"` RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
Nonce string `form:"nonce" url:"nonce"`
CodeChallenge string `form:"code_challenge" url:"code_challenge"`
CodeChallengeMethod string `form:"code_challenge_method" url:"code_challenge_method"`
} }
type OAuthPendingSession struct { type OAuthPendingSession struct {
@@ -49,7 +46,7 @@ type OAuthPendingSession struct {
Token *oauth2.Token Token *oauth2.Token
Service *OAuthServiceImpl Service *OAuthServiceImpl
ExpiresAt time.Time ExpiresAt time.Time
CallbackParams OAuthURLParams CallbackParams OAuthCallbackParams
} }
type LoginAttempt struct { type LoginAttempt struct {
@@ -516,17 +513,17 @@ func (auth *AuthService) LDAPAuthConfigured() bool {
return auth.ldap != nil return auth.ldap != nil
} }
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) { func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbackParams) (string, error) {
service, ok := auth.oauthBroker.GetService(serviceName) service, ok := auth.oauthBroker.GetService(serviceName)
if !ok { if !ok {
return "", OAuthPendingSession{}, fmt.Errorf("oauth service not found: %s", serviceName) return "", fmt.Errorf("oauth service not found: %s", serviceName)
} }
sessionId, err := uuid.NewRandom() sessionId, err := uuid.NewRandom()
if err != nil { if err != nil {
return "", OAuthPendingSession{}, fmt.Errorf("failed to generate session ID: %w", err) return "", fmt.Errorf("failed to generate session ID: %w", err)
} }
state := service.NewRandom() state := service.NewRandom()
@@ -542,7 +539,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLPara
auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10) auth.caches.oauth.Set(sessionId.String(), session, time.Minute*10)
return sessionId.String(), session, nil return sessionId.String(), nil
} }
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {