From 8dd731b21ef170a5425018d1b036167de6b78089 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 25 Jan 2026 19:45:17 +0200 Subject: [PATCH] feat: cleanup expired oidc sessions --- internal/bootstrap/app_bootstrap.go | 4 +- internal/controller/oidc_controller.go | 6 +- internal/repository/oidc_queries.sql.go | 117 ++++++++++++++++++++++++ internal/service/oidc_service.go | 64 ++++++++++++- sql/oidc_queries.sql | 49 +++++++--- 5 files changed, 216 insertions(+), 24 deletions(-) diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index e9cdd5a..9da1d84 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -247,7 +247,7 @@ func (app *BootstrapApp) heartbeat() { heartbeatURL := config.ApiServer + "/v1/instances/heartbeat" - for ; true; <-ticker.C { + for range ticker.C { tlog.App.Debug().Msg("Sending heartbeat") req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) @@ -279,7 +279,7 @@ func (app *BootstrapApp) dbCleanup(queries *repository.Queries) { defer ticker.Stop() ctx := context.Background() - for ; true; <-ticker.C { + for range ticker.C { tlog.App.Debug().Msg("Cleaning up old database sessions") err := queries.DeleteExpiredSessions(ctx, time.Now().Unix()) if err != nil { diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 44eaa73..60dab7b 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -137,10 +137,10 @@ func (controller *OIDCController) Authorize(c *gin.Context) { sub := utils.GenerateUUID(userContext.Username) code := rand.Text() - // Before storing the code, clean up old sessions - err = controller.oidc.CleanupOldSessions(c, sub) + // Before storing the code, delete old session + err = controller.oidc.DeleteOldSession(c, sub) if err != nil { - controller.authorizeError(c, err, "Failed to clean up old sessions", "Failed to clean up old sessions", req.RedirectURI, "server_error", req.State) + controller.authorizeError(c, err, "Failed to delete old sessions", "Failed to delete old sessions", req.RedirectURI, "server_error", req.State) return } diff --git a/internal/repository/oidc_queries.sql.go b/internal/repository/oidc_queries.sql.go index 0833d90..a6535d1 100644 --- a/internal/repository/oidc_queries.sql.go +++ b/internal/repository/oidc_queries.sql.go @@ -145,6 +145,84 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo return i, err } +const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many +DELETE FROM "oidc_codes" +WHERE "expires_at" < ? +RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at +` + +func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) { + rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []OidcCode + for rows.Next() { + var i OidcCode + if err := rows.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many +DELETE FROM "oidc_tokens" +WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at +` + +type DeleteExpiredOidcTokensParams struct { + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 +} + +func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) { + rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) + if err != nil { + return nil, err + } + defer rows.Close() + var items []OidcToken + for rows.Next() { + var i OidcToken + if err := rows.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const deleteOidcCode = `-- name: DeleteOidcCode :exec DELETE FROM "oidc_codes" WHERE "code_hash" = ? @@ -214,6 +292,25 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e return i, err } +const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one +SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes" +WHERE "sub" = ? +` + +func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) { + row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub) + var i OidcCode + err := row.Scan( + &i.Sub, + &i.CodeHash, + &i.Scope, + &i.RedirectURI, + &i.ClientID, + &i.ExpiresAt, + ) + return i, err +} + const getOidcToken = `-- name: GetOidcToken :one SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens" WHERE "access_token_hash" = ? @@ -254,6 +351,26 @@ func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHa return i, err } +const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens" +WHERE "sub" = ? +` + +func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) { + row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub) + var i OidcToken + err := row.Scan( + &i.Sub, + &i.AccessTokenHash, + &i.RefreshTokenHash, + &i.Scope, + &i.ClientID, + &i.TokenExpiresAt, + &i.RefreshTokenExpiresAt, + ) + return i, err +} + const getOidcUserInfo = `-- name: GetOidcUserInfo :one SELECT sub, name, preferred_username, email, "groups", updated_at FROM "oidc_userinfo" WHERE "sub" = ? diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index fc13274..ca55e0c 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -1,6 +1,7 @@ package service import ( + "context" "crypto" "crypto/rand" "crypto/rsa" @@ -523,18 +524,73 @@ func (service *OIDCService) Hash(token string) string { return fmt.Sprintf("%x", hasher.Sum(nil)) } -func (service *OIDCService) CleanupOldSessions(c *gin.Context, sub string) error { - err := service.queries.DeleteOidcCodeBySub(c, sub) +func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { + err := service.queries.DeleteOidcCodeBySub(ctx, sub) if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } - err = service.queries.DeleteOidcTokenBySub(c, sub) + err = service.queries.DeleteOidcTokenBySub(ctx, sub) if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } - err = service.queries.DeleteOidcUserInfo(c, sub) + err = service.queries.DeleteOidcUserInfo(ctx, sub) if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } return nil } + +// Cleanup routine - Resource heavy due to the linked tables +func (service *OIDCService) Cleanup() { + // We need a context for the routine + ctx := context.Background() + + ticker := time.NewTicker(time.Duration(30) * time.Minute) + defer ticker.Stop() + + for range ticker.C { + currentTime := time.Now().Unix() + + // For the OIDC tokens, if they are expired we delete the userinfo and codes + expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ + TokenExpiresAt: currentTime, + RefreshTokenExpiresAt: currentTime, + }) + + if err != nil { + tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens") + } + + for _, expiredToken := range expiredTokens { + err := service.DeleteOldSession(ctx, expiredToken.Sub) + if err != nil { + tlog.App.Warn().Err(err).Msg("Failed to delete old session") + } + } + + // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything + expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) + + if err != nil { + tlog.App.Warn().Err(err).Msg("Failed to delete expired codes") + } + + for _, expiredCode := range expiredCodes { + token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) + + if err != nil { + if err == sql.ErrNoRows { + continue + } + tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub") + } + + if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { + err := service.queries.DeleteSession(ctx, expiredCode.Sub) + if err != nil { + tlog.App.Warn().Err(err).Msg("Failed to delete session") + } + } + } + } +} diff --git a/sql/oidc_queries.sql b/sql/oidc_queries.sql index 18a3485..4089133 100644 --- a/sql/oidc_queries.sql +++ b/sql/oidc_queries.sql @@ -11,6 +11,14 @@ INSERT INTO "oidc_codes" ( ) RETURNING *; +-- name: GetOidcCode :one +SELECT * FROM "oidc_codes" +WHERE "code_hash" = ?; + +-- name: GetOidcCodeBySub :one +SELECT * FROM "oidc_codes" +WHERE "sub" = ?; + -- name: DeleteOidcCode :exec DELETE FROM "oidc_codes" WHERE "code_hash" = ?; @@ -19,10 +27,6 @@ WHERE "code_hash" = ?; DELETE FROM "oidc_codes" WHERE "sub" = ?; --- name: GetOidcCode :one -SELECT * FROM "oidc_codes" -WHERE "code_hash" = ?; - -- name: CreateOidcToken :one INSERT INTO "oidc_tokens" ( "sub", @@ -46,14 +50,6 @@ UPDATE "oidc_tokens" SET WHERE "refresh_token_hash" = ? RETURNING *; --- name: DeleteOidcToken :exec -DELETE FROM "oidc_tokens" -WHERE "access_token_hash" = ?; - --- name: DeleteOidcTokenBySub :exec -DELETE FROM "oidc_tokens" -WHERE "sub" = ?; - -- name: GetOidcToken :one SELECT * FROM "oidc_tokens" WHERE "access_token_hash" = ?; @@ -62,6 +58,19 @@ WHERE "access_token_hash" = ?; SELECT * FROM "oidc_tokens" WHERE "refresh_token_hash" = ?; +-- name: GetOidcTokenBySub :one +SELECT * FROM "oidc_tokens" +WHERE "sub" = ?; + + +-- name: DeleteOidcToken :exec +DELETE FROM "oidc_tokens" +WHERE "access_token_hash" = ?; + +-- name: DeleteOidcTokenBySub :exec +DELETE FROM "oidc_tokens" +WHERE "sub" = ?; + -- name: CreateOidcUserInfo :one INSERT INTO "oidc_userinfo" ( "sub", @@ -75,10 +84,20 @@ INSERT INTO "oidc_userinfo" ( ) RETURNING *; +-- name: GetOidcUserInfo :one +SELECT * FROM "oidc_userinfo" +WHERE "sub" = ?; + -- name: DeleteOidcUserInfo :exec DELETE FROM "oidc_userinfo" WHERE "sub" = ?; --- name: GetOidcUserInfo :one -SELECT * FROM "oidc_userinfo" -WHERE "sub" = ?; +-- name: DeleteExpiredOidcCodes :many +DELETE FROM "oidc_codes" +WHERE "expires_at" < ? +RETURNING *; + +-- name: DeleteExpiredOidcTokens :many +DELETE FROM "oidc_tokens" +WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? +RETURNING *;