mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-02 17:40:14 +00:00
refactor: rework oidc session storage
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -12,7 +13,6 @@ import (
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
@@ -169,7 +169,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||
_, ok := controller.oidc.GetClient(req.ClientID)
|
||||
|
||||
if !ok {
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
@@ -203,9 +203,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
|
||||
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID))
|
||||
code := utils.GenerateString(32)
|
||||
// Create the sub to find and delete old sessions
|
||||
sub := controller.oidc.CreateSub(*userContext, req.ClientID)
|
||||
|
||||
// Before storing the code, delete old session
|
||||
err = controller.oidc.DeleteOldSession(c, sub)
|
||||
@@ -221,37 +220,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err = controller.oidc.StoreCode(c, sub, code, req)
|
||||
|
||||
if err != nil {
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: err,
|
||||
reason: "Failed to store code",
|
||||
reasonPublic: "Failed to store code",
|
||||
callback: req.RedirectURI,
|
||||
callbackError: "server_error",
|
||||
state: req.State,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// We also need a snapshot of the user that authorized this (skip if no openid scope)
|
||||
if slices.Contains(strings.Fields(req.Scope), "openid") {
|
||||
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to store user info")
|
||||
controller.authorizeError(c, authorizeErrorParams{
|
||||
err: err,
|
||||
reason: "Failed to store user info",
|
||||
reasonPublic: "Failed to store user info",
|
||||
callback: req.RedirectURI,
|
||||
callbackError: "server_error",
|
||||
state: req.State,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
// Create the authorization code
|
||||
code := controller.oidc.CreateCode(req, *userContext)
|
||||
|
||||
queries, err := query.Values(AuthorizeCallback{
|
||||
Code: code,
|
||||
@@ -354,35 +324,12 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
|
||||
switch req.GrantType {
|
||||
case "authorization_code":
|
||||
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
|
||||
if err != nil {
|
||||
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to revoke tokens for replayed code")
|
||||
}
|
||||
if errors.Is(err, service.ErrCodeNotFound) {
|
||||
controller.log.App.Warn().Msg("Code not found")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "invalid_grant",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, service.ErrCodeExpired) {
|
||||
controller.log.App.Warn().Msg("Code expired")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "invalid_grant",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, service.ErrInvalidClient) {
|
||||
controller.log.App.Warn().Msg("Code does not belong to client")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "invalid_client",
|
||||
})
|
||||
return
|
||||
}
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get code entry")
|
||||
entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID)
|
||||
|
||||
if !ok {
|
||||
controller.log.App.Warn().Msg("Code not found")
|
||||
c.JSON(400, gin.H{
|
||||
"error": "server_error",
|
||||
"error": "invalid_grant",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -395,7 +342,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
|
||||
ok = controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)
|
||||
|
||||
if !ok {
|
||||
controller.log.App.Warn().Msg("PKCE validation failed")
|
||||
@@ -405,7 +352,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)
|
||||
tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to generate access token")
|
||||
@@ -415,7 +362,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse = tokenRes
|
||||
tokenResponse = *tokenRes
|
||||
case "refresh_token":
|
||||
tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, creds.ClientID)
|
||||
|
||||
@@ -443,7 +390,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse = tokenRes
|
||||
tokenResponse = *tokenRes
|
||||
}
|
||||
|
||||
c.Header("cache-control", "no-store")
|
||||
@@ -507,7 +454,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token))
|
||||
entry, err := controller.oidc.GetSessionByToken(c, controller.oidc.Hash(token))
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrTokenNotFound) {
|
||||
@@ -526,15 +473,17 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||
}
|
||||
|
||||
// If we don't have the openid scope, return an error
|
||||
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
|
||||
controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope")
|
||||
if !slices.Contains(strings.Split(entry.Scope, " "), "openid") {
|
||||
controller.log.App.Warn().Msg("OIDC userinfo accessed with missing openid scope")
|
||||
c.JSON(401, gin.H{
|
||||
"error": "invalid_scope",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
|
||||
var userinfo service.UserinfoResponse
|
||||
|
||||
err = json.Unmarshal([]byte(entry.UserinfoJson), &userinfo)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get user info")
|
||||
@@ -544,7 +493,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope))
|
||||
c.JSON(200, controller.oidc.CompileUserinfo(userinfo, entry.Scope))
|
||||
}
|
||||
|
||||
func (controller *OIDCController) authorizeError(c *gin.Context, params authorizeErrorParams) {
|
||||
|
||||
Reference in New Issue
Block a user