mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-06-03 01:50:14 +00:00
refactor: rework oidc session storage
This commit is contained in:
+193
-194
@@ -19,7 +19,6 @@ import (
|
||||
|
||||
"slices"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/steveiliop56/ding"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
@@ -42,6 +41,10 @@ var (
|
||||
ErrInvalidClient = errors.New("invalid_client")
|
||||
)
|
||||
|
||||
// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but,
|
||||
// it has became a "standard" and apps are looking for the claims in the ID tokens
|
||||
// instead of calling the userinfo endpoint, so we include them in the ID token as well
|
||||
// for better compatibility with existing apps
|
||||
type ClaimSet struct {
|
||||
Iss string `json:"iss"`
|
||||
Aud string `json:"aud"`
|
||||
@@ -67,6 +70,8 @@ type ClaimSet struct {
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
}
|
||||
|
||||
// We use this struct as both a response struct and a struct to store userinfo
|
||||
// in the database
|
||||
type UserinfoResponse struct {
|
||||
Sub string `json:"sub"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -111,6 +116,16 @@ type AuthorizeRequest struct {
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
}
|
||||
|
||||
type AuthorizeCodeEntry struct {
|
||||
CodeHash string
|
||||
Scope string
|
||||
RedirectURI string
|
||||
ClientID string
|
||||
Nonce string
|
||||
CodeChallenge string
|
||||
Userinfo UserinfoResponse
|
||||
}
|
||||
|
||||
type OIDCService struct {
|
||||
log *logger.Logger
|
||||
config model.Config
|
||||
@@ -121,6 +136,10 @@ type OIDCService struct {
|
||||
privateKey *rsa.PrivateKey
|
||||
publicKey *rsa.PublicKey
|
||||
issuer string
|
||||
|
||||
caches struct {
|
||||
code *CacheStore[AuthorizeCodeEntry]
|
||||
}
|
||||
}
|
||||
|
||||
func NewOIDCService(
|
||||
@@ -282,7 +301,26 @@ func NewOIDCService(
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
dg.Go(service.cleanupRoutine, ding.RingMinor)
|
||||
// dg.Go(service.cleanupRoutine, ding.RingMinor)
|
||||
|
||||
// Create caches
|
||||
codeCash := NewCacheStore[AuthorizeCodeEntry](256)
|
||||
service.caches.code = codeCash
|
||||
|
||||
// Start cache cleanup routine
|
||||
dg.Go(func(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
service.caches.code.Sweep()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}, ding.RingMinor)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
@@ -345,19 +383,17 @@ func (service *OIDCService) filterScopes(scopes []string) []string {
|
||||
})
|
||||
}
|
||||
|
||||
func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error {
|
||||
// Fixed 10 minutes
|
||||
expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
|
||||
func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.UserContext) string {
|
||||
code := utils.GenerateString(32)
|
||||
sub := service.CreateSub(userContext, req.ClientID)
|
||||
|
||||
entry := repository.CreateOidcCodeParams{
|
||||
Sub: sub,
|
||||
CodeHash: service.Hash(code),
|
||||
// Here it's safe to split and trust the output since, we validated the scopes before
|
||||
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
|
||||
entry := AuthorizeCodeEntry{
|
||||
CodeHash: service.Hash(code),
|
||||
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), " "),
|
||||
RedirectURI: req.RedirectURI,
|
||||
ClientID: req.ClientID,
|
||||
ExpiresAt: expiresAt,
|
||||
Nonce: req.Nonce,
|
||||
Userinfo: service.userinfoFromContext(userContext, sub),
|
||||
}
|
||||
|
||||
if req.CodeChallenge != "" {
|
||||
@@ -369,14 +405,14 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
|
||||
}
|
||||
}
|
||||
|
||||
// Insert the code into the database
|
||||
_, err := service.queries.CreateOidcCode(c, entry)
|
||||
// Store the code in the cache
|
||||
service.caches.code.Set(entry.CodeHash, entry, 10*time.Minute)
|
||||
|
||||
return err
|
||||
return code
|
||||
}
|
||||
|
||||
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error {
|
||||
userInfoParams := repository.CreateOidcUserInfoParams{
|
||||
func (service *OIDCService) userinfoFromContext(userContext model.UserContext, sub string) UserinfoResponse {
|
||||
userInfo := UserinfoResponse{
|
||||
Sub: sub,
|
||||
Name: userContext.GetName(),
|
||||
Email: userContext.GetEmail(),
|
||||
@@ -385,37 +421,31 @@ func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContex
|
||||
}
|
||||
|
||||
if userContext.IsLocal() {
|
||||
addressJSON, err := json.Marshal(userContext.Local.Attributes.Address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userInfoParams.GivenName = userContext.Local.Attributes.GivenName
|
||||
userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName
|
||||
userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName
|
||||
userInfoParams.Nickname = userContext.Local.Attributes.Nickname
|
||||
userInfoParams.Profile = userContext.Local.Attributes.Profile
|
||||
userInfoParams.Picture = userContext.Local.Attributes.Picture
|
||||
userInfoParams.Website = userContext.Local.Attributes.Website
|
||||
userInfoParams.Gender = userContext.Local.Attributes.Gender
|
||||
userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate
|
||||
userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo
|
||||
userInfoParams.Locale = userContext.Local.Attributes.Locale
|
||||
userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber
|
||||
userInfoParams.Address = string(addressJSON)
|
||||
userInfo.GivenName = userContext.Local.Attributes.GivenName
|
||||
userInfo.FamilyName = userContext.Local.Attributes.FamilyName
|
||||
userInfo.MiddleName = userContext.Local.Attributes.MiddleName
|
||||
userInfo.Nickname = userContext.Local.Attributes.Nickname
|
||||
userInfo.Profile = userContext.Local.Attributes.Profile
|
||||
userInfo.Picture = userContext.Local.Attributes.Picture
|
||||
userInfo.Website = userContext.Local.Attributes.Website
|
||||
userInfo.Gender = userContext.Local.Attributes.Gender
|
||||
userInfo.Birthdate = userContext.Local.Attributes.Birthdate
|
||||
userInfo.Zoneinfo = userContext.Local.Attributes.Zoneinfo
|
||||
userInfo.Locale = userContext.Local.Attributes.Locale
|
||||
userInfo.PhoneNumber = userContext.Local.Attributes.PhoneNumber
|
||||
userInfo.Address = &userContext.Local.Attributes.Address
|
||||
}
|
||||
|
||||
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
|
||||
if userContext.IsLDAP() {
|
||||
userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",")
|
||||
userInfo.Groups = userContext.LDAP.Groups
|
||||
}
|
||||
|
||||
if userContext.IsOAuth() {
|
||||
userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",")
|
||||
userInfo.Groups = userContext.OAuth.Groups
|
||||
}
|
||||
|
||||
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
|
||||
|
||||
return err
|
||||
return userInfo
|
||||
}
|
||||
|
||||
func (service *OIDCService) ValidateGrantType(grantType string) error {
|
||||
@@ -426,36 +456,24 @@ func (service *OIDCService) ValidateGrantType(grantType string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, clientId string) (repository.OidcCode, error) {
|
||||
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
|
||||
func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*AuthorizeCodeEntry, bool) {
|
||||
entry, ok := service.caches.code.Get(codeHash)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return repository.OidcCode{}, ErrCodeNotFound
|
||||
}
|
||||
return repository.OidcCode{}, err
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if time.Now().Unix() > oidcCode.ExpiresAt {
|
||||
err = service.queries.DeleteOidcCode(c, codeHash)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, err
|
||||
}
|
||||
err = service.DeleteUserinfo(c, oidcCode.Sub)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, err
|
||||
}
|
||||
return repository.OidcCode{}, ErrCodeExpired
|
||||
if entry.ClientID != clientId {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if oidcCode.ClientID != clientId {
|
||||
return repository.OidcCode{}, ErrInvalidClient
|
||||
}
|
||||
// Since the code can only be used once, we delete it from the cache after retrieving it
|
||||
service.caches.code.Delete(codeHash)
|
||||
|
||||
return oidcCode, nil
|
||||
return &entry, true
|
||||
}
|
||||
|
||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
|
||||
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) {
|
||||
createdAt := time.Now().Unix()
|
||||
expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix()
|
||||
|
||||
@@ -521,17 +539,11 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
|
||||
user, err := service.GetUserinfo(c, codeEntry.Sub)
|
||||
func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) {
|
||||
idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce)
|
||||
|
||||
if err != nil {
|
||||
return TokenResponse{}, err
|
||||
}
|
||||
|
||||
idToken, err := service.generateIDToken(client, user, codeEntry.Scope, codeEntry.Nonce)
|
||||
|
||||
if err != nil {
|
||||
return TokenResponse{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessToken := utils.GenerateString(32)
|
||||
@@ -551,56 +563,68 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID
|
||||
Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "),
|
||||
}
|
||||
|
||||
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
|
||||
Sub: codeEntry.Sub,
|
||||
var userInfoJson []byte
|
||||
|
||||
userInfoJson, err = json.Marshal(codeEntry.Userinfo)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = service.queries.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{
|
||||
Sub: codeEntry.Userinfo.Sub,
|
||||
AccessTokenHash: service.Hash(accessToken),
|
||||
RefreshTokenHash: service.Hash(refreshToken),
|
||||
ClientID: client.ClientID,
|
||||
Scope: codeEntry.Scope,
|
||||
ClientID: client.ClientID,
|
||||
TokenExpiresAt: tokenExpiresAt,
|
||||
RefreshTokenExpiresAt: refreshTokenExpiresAt,
|
||||
Nonce: codeEntry.Nonce,
|
||||
CodeHash: codeEntry.CodeHash,
|
||||
UserinfoJson: string(userInfoJson),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return TokenResponse{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokenResponse, nil
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) {
|
||||
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
|
||||
func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken string, clientId string) (*TokenResponse, error) {
|
||||
entry, err := service.queries.GetOIDCSessionByRefreshTokenHash(ctx, service.Hash(refreshToken))
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return TokenResponse{}, ErrTokenNotFound
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
return TokenResponse{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
||||
return TokenResponse{}, ErrTokenExpired
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
|
||||
// Ensure the client ID in the request matches the client ID in the token
|
||||
if entry.ClientID != reqClientId {
|
||||
return TokenResponse{}, ErrInvalidClient
|
||||
if entry.ClientID != clientId {
|
||||
return nil, ErrInvalidClient
|
||||
}
|
||||
|
||||
user, err := service.GetUserinfo(c, entry.Sub)
|
||||
// we need to unmarshal the userinfo from the database to include it in the new ID token,
|
||||
// since the ID token includes user claims for better compatibility with existing apps
|
||||
var userInfo UserinfoResponse
|
||||
|
||||
err = json.Unmarshal([]byte(entry.UserinfoJson), &userInfo)
|
||||
|
||||
if err != nil {
|
||||
return TokenResponse{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||
ClientID: entry.ClientID,
|
||||
}, user, entry.Scope, entry.Nonce)
|
||||
}, userInfo, entry.Scope, entry.Nonce)
|
||||
|
||||
if err != nil {
|
||||
return TokenResponse{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessToken := utils.GenerateString(32)
|
||||
@@ -618,71 +642,54 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
|
||||
Scope: strings.ReplaceAll(entry.Scope, ",", " "),
|
||||
}
|
||||
|
||||
_, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{
|
||||
_, err = service.queries.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{
|
||||
Sub: entry.Sub,
|
||||
AccessTokenHash: service.Hash(accessToken),
|
||||
RefreshTokenHash: service.Hash(newRefreshToken),
|
||||
Scope: entry.Scope,
|
||||
ClientID: entry.ClientID,
|
||||
TokenExpiresAt: tokenExpiresAt,
|
||||
RefreshTokenExpiresAt: refreshTokenExpiresAt,
|
||||
RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db
|
||||
Nonce: entry.Nonce,
|
||||
UserinfoJson: entry.UserinfoJson,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return TokenResponse{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokenResponse, nil
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error {
|
||||
return service.queries.DeleteOidcCode(c, codeHash)
|
||||
}
|
||||
|
||||
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
|
||||
return service.queries.DeleteOidcUserInfo(c, sub)
|
||||
}
|
||||
|
||||
func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error {
|
||||
return service.queries.DeleteOidcToken(c, tokenHash)
|
||||
}
|
||||
|
||||
func (service *OIDCService) DeleteTokenByCodeHash(c *gin.Context, codeHash string) error {
|
||||
return service.queries.DeleteOidcTokenByCodeHash(c, codeHash)
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) {
|
||||
entry, err := service.queries.GetOidcToken(c, tokenHash)
|
||||
func (service *OIDCService) GetSessionByToken(ctx context.Context, tokenHash string) (*repository.OidcSession, error) {
|
||||
entry, err := service.queries.GetOIDCSessionByAccessTokenHash(ctx, tokenHash)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return repository.OidcToken{}, ErrTokenNotFound
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
return repository.OidcToken{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if entry.TokenExpiresAt < time.Now().Unix() {
|
||||
// If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore
|
||||
// If refresh token is expired, delete the session
|
||||
// since there is no way for the client to access anything anymore
|
||||
if entry.RefreshTokenExpiresAt < time.Now().Unix() {
|
||||
err := service.DeleteToken(c, tokenHash)
|
||||
// Deletes by sub
|
||||
err := service.queries.DeleteSession(ctx, entry.Sub)
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, err
|
||||
}
|
||||
err = service.DeleteUserinfo(c, entry.Sub)
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, err
|
||||
return nil, err
|
||||
}
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
return repository.OidcToken{}, ErrTokenExpired
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
|
||||
return entry, nil
|
||||
return &entry, nil
|
||||
}
|
||||
|
||||
func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) {
|
||||
return service.queries.GetOidcUserInfo(c, sub)
|
||||
}
|
||||
|
||||
func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse {
|
||||
scopes := strings.Split(scope, ",") // split by comma since it's a db entry
|
||||
func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string) UserinfoResponse {
|
||||
scopes := strings.Split(scope, " ")
|
||||
userInfo := UserinfoResponse{
|
||||
Sub: user.Sub,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
@@ -710,11 +717,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "groups") {
|
||||
if user.Groups != "" {
|
||||
userInfo.Groups = strings.Split(user.Groups, ",")
|
||||
} else {
|
||||
userInfo.Groups = []string{}
|
||||
}
|
||||
userInfo.Groups = user.Groups
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "phone") {
|
||||
@@ -724,10 +727,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "address") {
|
||||
var addr model.AddressClaim
|
||||
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
|
||||
userInfo.Address = &addr
|
||||
}
|
||||
userInfo.Address = user.Address
|
||||
}
|
||||
|
||||
return userInfo
|
||||
@@ -740,83 +740,75 @@ func (service *OIDCService) Hash(token string) string {
|
||||
}
|
||||
|
||||
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
|
||||
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
|
||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
|
||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
err = service.queries.DeleteOidcUserInfo(ctx, sub)
|
||||
err := service.queries.DeleteOIDCSessionBySub(ctx, sub)
|
||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cleanup routine - Resource heavy due to the linked tables
|
||||
func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
||||
service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||
defer ticker.Stop()
|
||||
// // Cleanup routine - Resource heavy due to the linked tables
|
||||
// func (service *OIDCService) cleanupRoutine(ctx context.Context) {
|
||||
// service.log.App.Debug().Msg("Starting OIDC cleanup routine")
|
||||
// ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||
// defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
service.log.App.Debug().Msg("Performing OIDC cleanup routine")
|
||||
// for {
|
||||
// select {
|
||||
// case <-ticker.C:
|
||||
// service.log.App.Debug().Msg("Performing OIDC cleanup routine")
|
||||
|
||||
currentTime := time.Now().Unix()
|
||||
// currentTime := time.Now().Unix()
|
||||
|
||||
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
||||
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||
TokenExpiresAt: currentTime,
|
||||
RefreshTokenExpiresAt: currentTime,
|
||||
})
|
||||
// // For the OIDC tokens, if they are expired we delete the userinfo and codes
|
||||
// expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||
// TokenExpiresAt: currentTime,
|
||||
// RefreshTokenExpiresAt: currentTime,
|
||||
// })
|
||||
|
||||
if err != nil {
|
||||
service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
||||
}
|
||||
// if err != nil {
|
||||
// service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
||||
// }
|
||||
|
||||
for _, expiredToken := range expiredTokens {
|
||||
err := service.DeleteOldSession(ctx, expiredToken.Sub)
|
||||
if err != nil {
|
||||
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
|
||||
}
|
||||
}
|
||||
// for _, expiredToken := range expiredTokens {
|
||||
// err := service.DeleteOldSession(ctx, expiredToken.Sub)
|
||||
// if err != nil {
|
||||
// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token")
|
||||
// }
|
||||
// }
|
||||
|
||||
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
||||
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
|
||||
// // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
||||
// expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
|
||||
|
||||
if err != nil {
|
||||
service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||
}
|
||||
// if err != nil {
|
||||
// service.log.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||
// }
|
||||
|
||||
for _, expiredCode := range expiredCodes {
|
||||
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
||||
// for _, expiredCode := range expiredCodes {
|
||||
// token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
||||
|
||||
if err != nil {
|
||||
if !errors.Is(err, repository.ErrNotFound) {
|
||||
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
||||
}
|
||||
continue
|
||||
}
|
||||
// if err != nil {
|
||||
// if !errors.Is(err, repository.ErrNotFound) {
|
||||
// service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
||||
// }
|
||||
// continue
|
||||
// }
|
||||
|
||||
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||
err := service.DeleteOldSession(ctx, expiredCode.Sub)
|
||||
if err != nil {
|
||||
service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
|
||||
}
|
||||
}
|
||||
}
|
||||
// if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||
// err := service.DeleteOldSession(ctx, expiredCode.Sub)
|
||||
// if err != nil {
|
||||
// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code")
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
||||
case <-ctx.Done():
|
||||
service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// service.log.App.Debug().Msg("Finished OIDC cleanup routine")
|
||||
// case <-ctx.Done():
|
||||
// service.log.App.Debug().Msg("Stopping OIDC cleanup routine")
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
func (service *OIDCService) GetJWK() ([]byte, error) {
|
||||
hasher := sha256.New()
|
||||
@@ -851,3 +843,10 @@ func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string {
|
||||
hasher.Write([]byte(codeVerifier))
|
||||
return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes.
|
||||
// We will just create a uuid out of the username and client name which remains stable,
|
||||
// but if username or client name changes then sub changes too.
|
||||
func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string {
|
||||
return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user