refactor: rework oidc session storage

This commit is contained in:
Stavros
2026-05-31 20:10:53 +03:00
parent 82d21c3b28
commit 695feca71c
29 changed files with 668 additions and 1880 deletions
+193 -194
View File
@@ -19,7 +19,6 @@ import (
"slices"
"github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4"
"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/model"
@@ -42,6 +41,10 @@ var (
ErrInvalidClient = errors.New("invalid_client")
)
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
// it has became a "standard" and apps are looking for the claims in the ID tokens
// instead of calling the userinfo endpoint, so we include them in the ID token as well
// for better compatibility with existing apps
type ClaimSet struct {
Iss string `json:"iss"`
Aud string `json:"aud"`
@@ -67,6 +70,8 @@ type ClaimSet struct {
Nonce string `json:"nonce,omitempty"`
}
// We use this struct as both a response struct and a struct to store userinfo
// in the database
type UserinfoResponse struct {
Sub string `json:"sub"`
Name string `json:"name,omitempty"`
@@ -111,6 +116,16 @@ type AuthorizeRequest struct {
CodeChallengeMethod string `json:"code_challenge_method"`
}
type AuthorizeCodeEntry struct {
CodeHash string
Scope string
RedirectURI string
ClientID string
Nonce string
CodeChallenge string
Userinfo UserinfoResponse
}
type OIDCService struct {
log *logger.Logger
config model.Config
@@ -121,6 +136,10 @@ type OIDCService struct {
privateKey *rsa.PrivateKey
publicKey *rsa.PublicKey
issuer string
caches struct {
code *CacheStore[AuthorizeCodeEntry]
}
}
func NewOIDCService(
@@ -282,7 +301,26 @@ func NewOIDCService(
}
// Start cleanup routine
dg.Go(service.cleanupRoutine, ding.RingMinor)
// dg.Go(service.cleanupRoutine, ding.RingMinor)
// Create caches
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
service.caches.code = codeCash
// Start cache cleanup routine
dg.Go(func(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
service.caches.code.Sweep()
case <-ctx.Done():
return
}
}
}, ding.RingMinor)
return service, nil
}
@@ -345,19 +383,17 @@ func (service *OIDCService) filterScopes(scopes []string) []string {
})
}
func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error {
// Fixed 10 minutes
expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.UserContext) string {
code := utils.GenerateString(32)
sub := service.CreateSub(userContext, req.ClientID)
entry := repository.CreateOidcCodeParams{
Sub: sub,
CodeHash: service.Hash(code),
// Here it's safe to split and trust the output since, we validated the scopes before
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
entry := AuthorizeCodeEntry{
CodeHash: service.Hash(code),
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), " "),
RedirectURI: req.RedirectURI,
ClientID: req.ClientID,
ExpiresAt: expiresAt,
Nonce: req.Nonce,
Userinfo: service.userinfoFromContext(userContext, sub),
}
if req.CodeChallenge != "" {
@@ -369,14 +405,14 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
}
}
// Insert the code into the database
_, err := service.queries.CreateOidcCode(c, entry)
// Store the code in the cache
service.caches.code.Set(entry.CodeHash, entry, 10*time.Minute)
return err
return code
}
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error {
userInfoParams := repository.CreateOidcUserInfoParams{
func (service *OIDCService) userinfoFromContext(userContext model.UserContext, sub string) UserinfoResponse {
userInfo := UserinfoResponse{
Sub: sub,
Name: userContext.GetName(),
Email: userContext.GetEmail(),
@@ -385,37 +421,31 @@ func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContex
}
if userContext.IsLocal() {
addressJSON, err := json.Marshal(userContext.Local.Attributes.Address)
if err != nil {
return err
}
userInfoParams.GivenName = userContext.Local.Attributes.GivenName
userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName
userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName
userInfoParams.Nickname = userContext.Local.Attributes.Nickname
userInfoParams.Profile = userContext.Local.Attributes.Profile
userInfoParams.Picture = userContext.Local.Attributes.Picture
userInfoParams.Website = userContext.Local.Attributes.Website
userInfoParams.Gender = userContext.Local.Attributes.Gender
userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate
userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo
userInfoParams.Locale = userContext.Local.Attributes.Locale
userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber
userInfoParams.Address = string(addressJSON)
userInfo.GivenName = userContext.Local.Attributes.GivenName
userInfo.FamilyName = userContext.Local.Attributes.FamilyName
userInfo.MiddleName = userContext.Local.Attributes.MiddleName
userInfo.Nickname = userContext.Local.Attributes.Nickname
userInfo.Profile = userContext.Local.Attributes.Profile
userInfo.Picture = userContext.Local.Attributes.Picture
userInfo.Website = userContext.Local.Attributes.Website
userInfo.Gender = userContext.Local.Attributes.Gender
userInfo.Birthdate = userContext.Local.Attributes.Birthdate
userInfo.Zoneinfo = userContext.Local.Attributes.Zoneinfo
userInfo.Locale = userContext.Local.Attributes.Locale
userInfo.PhoneNumber = userContext.Local.Attributes.PhoneNumber
userInfo.Address = &userContext.Local.Attributes.Address
}
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
if userContext.IsLDAP() {
userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",")
userInfo.Groups = userContext.LDAP.Groups
}
if userContext.IsOAuth() {
userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",")
userInfo.Groups = userContext.OAuth.Groups
}
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
return err
return userInfo
}
func (service *OIDCService) ValidateGrantType(grantType string) error {
@@ -426,36 +456,24 @@ func (service *OIDCService) ValidateGrantType(grantType string) error {
return nil
}
func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, clientId string) (repository.OidcCode, error) {
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*AuthorizeCodeEntry, bool) {
entry, ok := service.caches.code.Get(codeHash)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return repository.OidcCode{}, ErrCodeNotFound
}
return repository.OidcCode{}, err
if !ok {
return nil, false
}
if time.Now().Unix() > oidcCode.ExpiresAt {
err = service.queries.DeleteOidcCode(c, codeHash)
if err != nil {
return repository.OidcCode{}, err
}
err = service.DeleteUserinfo(c, oidcCode.Sub)
if err != nil {
return repository.OidcCode{}, err
}
return repository.OidcCode{}, ErrCodeExpired
if entry.ClientID != clientId {
return nil, false
}
if oidcCode.ClientID != clientId {
return repository.OidcCode{}, ErrInvalidClient
}
// Since the code can only be used once, we delete it from the cache after retrieving it
service.caches.code.Delete(codeHash)
return oidcCode, nil
return &entry, true
}
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
@@ -521,17 +539,11 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil
}
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
user, err := service.GetUserinfo(c, codeEntry.Sub)
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
if err != nil {
return TokenResponse{}, err
}
idToken, err := service.generateIDToken(client, user, codeEntry.Scope, codeEntry.Nonce)
if err != nil {
return TokenResponse{}, err
return nil, err
}
accessToken := utils.GenerateString(32)
@@ -551,56 +563,68 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
}
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
Sub: codeEntry.Sub,
var userInfoJson []byte
userInfoJson, err = json.Marshal(codeEntry.Userinfo)
if err != nil {
return nil, err
}
_, err = service.queries.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
Sub: codeEntry.Userinfo.Sub,
AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(refreshToken),
ClientID: client.ClientID,
Scope: codeEntry.Scope,
ClientID: client.ClientID,
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refreshTokenExpiresAt,
Nonce: codeEntry.Nonce,
CodeHash: codeEntry.CodeHash,
UserinfoJson: string(userInfoJson),
})
if err != nil {
return TokenResponse{}, err
return nil, err
}
return tokenResponse, nil
return &tokenResponse, nil
}
func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) {
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken string, clientId string) (*TokenResponse, error) {
entry, err := service.queries.GetOIDCSessionByRefreshTokenHash(ctx, service.Hash(refreshToken))
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return TokenResponse{}, ErrTokenNotFound
return nil, ErrTokenNotFound
}
return TokenResponse{}, err
return nil, err
}
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
return TokenResponse{}, ErrTokenExpired
return nil, ErrTokenExpired
}
// Ensure the client ID in the request matches the client ID in the token
if entry.ClientID != reqClientId {
return TokenResponse{}, ErrInvalidClient
if entry.ClientID != clientId {
return nil, ErrInvalidClient
}
user, err := service.GetUserinfo(c, entry.Sub)
// we need to unmarshal the userinfo from the database to include it in the new ID token,
// since the ID token includes user claims for better compatibility with existing apps
var userInfo UserinfoResponse
err = json.Unmarshal([]byte(entry.UserinfoJson), &userInfo)
if err != nil {
return TokenResponse{}, err
return nil, err
}
idToken, err := service.generateIDToken(model.OIDCClientConfig{
ClientID: entry.ClientID,
}, user, entry.Scope, entry.Nonce)
}, userInfo, entry.Scope, entry.Nonce)
if err != nil {
return TokenResponse{}, err
return nil, err
}
accessToken := utils.GenerateString(32)
@@ -618,71 +642,54 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
}
_, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{
_, err = service.queries.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{
Sub: entry.Sub,
AccessTokenHash: service.Hash(accessToken),
RefreshTokenHash: service.Hash(newRefreshToken),
Scope: entry.Scope,
ClientID: entry.ClientID,
TokenExpiresAt: tokenExpiresAt,
RefreshTokenExpiresAt: refreshTokenExpiresAt,
RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
Nonce: entry.Nonce,
UserinfoJson: entry.UserinfoJson,
})
if err != nil {
return TokenResponse{}, err
return nil, err
}
return tokenResponse, nil
return &tokenResponse, nil
}
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error {
return service.queries.DeleteOidcCode(c, codeHash)
}
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
return service.queries.DeleteOidcUserInfo(c, sub)
}
func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error {
return service.queries.DeleteOidcToken(c, tokenHash)
}
func (service *OIDCService) DeleteTokenByCodeHash(c *gin.Context, codeHash string) error {
return service.queries.DeleteOidcTokenByCodeHash(c, codeHash)
}
func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
entry, err := service.queries.GetOidcToken(c, tokenHash)
func (service *OIDCService) GetSessionByToken(ctx context.Context, tokenHash string) (*repository.OidcSession, error) {
entry, err := service.queries.GetOIDCSessionByAccessTokenHash(ctx, tokenHash)
if err != nil {
if errors.Is(err, repository.ErrNotFound) {
return repository.OidcToken{}, ErrTokenNotFound
return nil, ErrTokenNotFound
}
return repository.OidcToken{}, err
return nil, err
}
if entry.TokenExpiresAt < time.Now().Unix() {
// If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore
// If refresh token is expired, delete the session
// since there is no way for the client to access anything anymore
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
err := service.DeleteToken(c, tokenHash)
// Deletes by sub
err := service.queries.DeleteSession(ctx, entry.Sub)
if err != nil {
return repository.OidcToken{}, err
}
err = service.DeleteUserinfo(c, entry.Sub)
if err != nil {
return repository.OidcToken{}, err
return nil, err
}
return nil, ErrTokenExpired
}
return repository.OidcToken{}, ErrTokenExpired
return nil, ErrTokenExpired
}
return entry, nil
return &entry, nil
}
func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) {
return service.queries.GetOidcUserInfo(c, sub)
}
func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse {
scopes := strings.Split(scope, ",") // split by comma since it's a db entry
func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string) UserinfoResponse {
scopes := strings.Split(scope, " ")
userInfo := UserinfoResponse{
Sub: user.Sub,
UpdatedAt: user.UpdatedAt,
@@ -710,11 +717,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
}
if slices.Contains(scopes, "groups") {
if user.Groups != "" {
userInfo.Groups = strings.Split(user.Groups, ",")
} else {
userInfo.Groups = []string{}
}
userInfo.Groups = user.Groups
}
if slices.Contains(scopes, "phone") {
@@ -724,10 +727,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
}
if slices.Contains(scopes, "address") {
var addr model.AddressClaim
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
userInfo.Address = &addr
}
userInfo.Address = user.Address
}
return userInfo
@@ -740,83 +740,75 @@ func (service *OIDCService) Hash(token string) string {
}
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
err = service.queries.DeleteOidcUserInfo(ctx, sub)
err := service.queries.DeleteOIDCSessionBySub(ctx, sub)
if err != nil && !errors.Is(err, repository.ErrNotFound) {
return err
}
return nil
}
// Cleanup routine - Resource heavy due to the linked tables
func (service *OIDCService) cleanupRoutine(ctx context.Context) {
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
ticker := time.NewTicker(time.Duration(30) * time.Minute)
defer ticker.Stop()
// // Cleanup routine - Resource heavy due to the linked tables
// func (service *OIDCService) cleanupRoutine(ctx context.Context) {
// service.log.App.Debug().Msg("Starting OIDC cleanup routine")
// ticker := time.NewTicker(time.Duration(30) * time.Minute)
// defer ticker.Stop()
for {
select {
case <-ticker.C:
service.log.App.Debug().Msg("Performing OIDC cleanup routine")
// for {
// select {
// case <-ticker.C:
// service.log.App.Debug().Msg("Performing OIDC cleanup routine")
currentTime := time.Now().Unix()
// currentTime := time.Now().Unix()
// For the OIDC tokens, if they are expired we delete the userinfo and codes
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
TokenExpiresAt: currentTime,
RefreshTokenExpiresAt: currentTime,
})
// // For the OIDC tokens, if they are expired we delete the userinfo and codes
// expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
// TokenExpiresAt: currentTime,
// RefreshTokenExpiresAt: currentTime,
// })
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
}
// if err != nil {
// service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
// }
for _, expiredToken := range expiredTokens {
err := service.DeleteOldSession(ctx, expiredToken.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
}
}
// for _, expiredToken := range expiredTokens {
// err := service.DeleteOldSession(ctx, expiredToken.Sub)
// if err != nil {
// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
// }
// }
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
// // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
// expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
}
// if err != nil {
// service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
// }
for _, expiredCode := range expiredCodes {
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
// for _, expiredCode := range expiredCodes {
// token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
if err != nil {
if !errors.Is(err, repository.ErrNotFound) {
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
}
continue
}
// if err != nil {
// if !errors.Is(err, repository.ErrNotFound) {
// service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
// }
// continue
// }
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
err := service.DeleteOldSession(ctx, expiredCode.Sub)
if err != nil {
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
}
}
}
// if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
// err := service.DeleteOldSession(ctx, expiredCode.Sub)
// if err != nil {
// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
// }
// }
// }
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
case <-ctx.Done():
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
return
}
}
}
// service.log.App.Debug().Msg("Finished OIDC cleanup routine")
// case <-ctx.Done():
// service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
// return
// }
// }
// }
func (service *OIDCService) GetJWK() ([]byte, error) {
hasher := sha256.New()
@@ -851,3 +843,10 @@ func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
hasher.Write([]byte(codeVerifier))
return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
}
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes.
// We will just create a uuid out of the username and client name which remains stable,
// but if username or client name changes then sub changes too.
func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string {
return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId))
}