mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-01-26 17:22:29 +00:00
feat: cleanup expired oidc sessions
This commit is contained in:
@@ -247,7 +247,7 @@ func (app *BootstrapApp) heartbeat() {
|
||||
|
||||
heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"
|
||||
|
||||
for ; true; <-ticker.C {
|
||||
for range ticker.C {
|
||||
tlog.App.Debug().Msg("Sending heartbeat")
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson))
|
||||
@@ -279,7 +279,7 @@ func (app *BootstrapApp) dbCleanup(queries *repository.Queries) {
|
||||
defer ticker.Stop()
|
||||
ctx := context.Background()
|
||||
|
||||
for ; true; <-ticker.C {
|
||||
for range ticker.C {
|
||||
tlog.App.Debug().Msg("Cleaning up old database sessions")
|
||||
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
|
||||
if err != nil {
|
||||
|
||||
@@ -137,10 +137,10 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||
sub := utils.GenerateUUID(userContext.Username)
|
||||
code := rand.Text()
|
||||
|
||||
// Before storing the code, clean up old sessions
|
||||
err = controller.oidc.CleanupOldSessions(c, sub)
|
||||
// Before storing the code, delete old session
|
||||
err = controller.oidc.DeleteOldSession(c, sub)
|
||||
if err != nil {
|
||||
controller.authorizeError(c, err, "Failed to clean up old sessions", "Failed to clean up old sessions", req.RedirectURI, "server_error", req.State)
|
||||
controller.authorizeError(c, err, "Failed to delete old sessions", "Failed to delete old sessions", req.RedirectURI, "server_error", req.State)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -145,6 +145,84 @@ func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfo
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many
|
||||
DELETE FROM "oidc_codes"
|
||||
WHERE "expires_at" < ?
|
||||
RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) {
|
||||
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []OidcCode
|
||||
for rows.Next() {
|
||||
var i OidcCode
|
||||
if err := rows.Scan(
|
||||
&i.Sub,
|
||||
&i.CodeHash,
|
||||
&i.Scope,
|
||||
&i.RedirectURI,
|
||||
&i.ClientID,
|
||||
&i.ExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many
|
||||
DELETE FROM "oidc_tokens"
|
||||
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
|
||||
RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at
|
||||
`
|
||||
|
||||
type DeleteExpiredOidcTokensParams struct {
|
||||
TokenExpiresAt int64
|
||||
RefreshTokenExpiresAt int64
|
||||
}
|
||||
|
||||
func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) {
|
||||
rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []OidcToken
|
||||
for rows.Next() {
|
||||
var i OidcToken
|
||||
if err := rows.Scan(
|
||||
&i.Sub,
|
||||
&i.AccessTokenHash,
|
||||
&i.RefreshTokenHash,
|
||||
&i.Scope,
|
||||
&i.ClientID,
|
||||
&i.TokenExpiresAt,
|
||||
&i.RefreshTokenExpiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const deleteOidcCode = `-- name: DeleteOidcCode :exec
|
||||
DELETE FROM "oidc_codes"
|
||||
WHERE "code_hash" = ?
|
||||
@@ -214,6 +292,25 @@ func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, e
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one
|
||||
SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at FROM "oidc_codes"
|
||||
WHERE "sub" = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) {
|
||||
row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub)
|
||||
var i OidcCode
|
||||
err := row.Scan(
|
||||
&i.Sub,
|
||||
&i.CodeHash,
|
||||
&i.Scope,
|
||||
&i.RedirectURI,
|
||||
&i.ClientID,
|
||||
&i.ExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOidcToken = `-- name: GetOidcToken :one
|
||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
|
||||
WHERE "access_token_hash" = ?
|
||||
@@ -254,6 +351,26 @@ func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHa
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one
|
||||
SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at FROM "oidc_tokens"
|
||||
WHERE "sub" = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) {
|
||||
row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub)
|
||||
var i OidcToken
|
||||
err := row.Scan(
|
||||
&i.Sub,
|
||||
&i.AccessTokenHash,
|
||||
&i.RefreshTokenHash,
|
||||
&i.Scope,
|
||||
&i.ClientID,
|
||||
&i.TokenExpiresAt,
|
||||
&i.RefreshTokenExpiresAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOidcUserInfo = `-- name: GetOidcUserInfo :one
|
||||
SELECT sub, name, preferred_username, email, "groups", updated_at FROM "oidc_userinfo"
|
||||
WHERE "sub" = ?
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -523,18 +524,73 @@ func (service *OIDCService) Hash(token string) string {
|
||||
return fmt.Sprintf("%x", hasher.Sum(nil))
|
||||
}
|
||||
|
||||
func (service *OIDCService) CleanupOldSessions(c *gin.Context, sub string) error {
|
||||
err := service.queries.DeleteOidcCodeBySub(c, sub)
|
||||
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
|
||||
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
err = service.queries.DeleteOidcTokenBySub(c, sub)
|
||||
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
err = service.queries.DeleteOidcUserInfo(c, sub)
|
||||
err = service.queries.DeleteOidcUserInfo(ctx, sub)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cleanup routine - Resource heavy due to the linked tables
|
||||
func (service *OIDCService) Cleanup() {
|
||||
// We need a context for the routine
|
||||
ctx := context.Background()
|
||||
|
||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
currentTime := time.Now().Unix()
|
||||
|
||||
// For the OIDC tokens, if they are expired we delete the userinfo and codes
|
||||
expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{
|
||||
TokenExpiresAt: currentTime,
|
||||
RefreshTokenExpiresAt: currentTime,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens")
|
||||
}
|
||||
|
||||
for _, expiredToken := range expiredTokens {
|
||||
err := service.DeleteOldSession(ctx, expiredToken.Sub)
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Msg("Failed to delete old session")
|
||||
}
|
||||
}
|
||||
|
||||
// For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything
|
||||
expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime)
|
||||
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Msg("Failed to delete expired codes")
|
||||
}
|
||||
|
||||
for _, expiredCode := range expiredCodes {
|
||||
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue
|
||||
}
|
||||
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
|
||||
}
|
||||
|
||||
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||
err := service.queries.DeleteSession(ctx, expiredCode.Sub)
|
||||
if err != nil {
|
||||
tlog.App.Warn().Err(err).Msg("Failed to delete session")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,14 @@ INSERT INTO "oidc_codes" (
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetOidcCode :one
|
||||
SELECT * FROM "oidc_codes"
|
||||
WHERE "code_hash" = ?;
|
||||
|
||||
-- name: GetOidcCodeBySub :one
|
||||
SELECT * FROM "oidc_codes"
|
||||
WHERE "sub" = ?;
|
||||
|
||||
-- name: DeleteOidcCode :exec
|
||||
DELETE FROM "oidc_codes"
|
||||
WHERE "code_hash" = ?;
|
||||
@@ -19,10 +27,6 @@ WHERE "code_hash" = ?;
|
||||
DELETE FROM "oidc_codes"
|
||||
WHERE "sub" = ?;
|
||||
|
||||
-- name: GetOidcCode :one
|
||||
SELECT * FROM "oidc_codes"
|
||||
WHERE "code_hash" = ?;
|
||||
|
||||
-- name: CreateOidcToken :one
|
||||
INSERT INTO "oidc_tokens" (
|
||||
"sub",
|
||||
@@ -46,14 +50,6 @@ UPDATE "oidc_tokens" SET
|
||||
WHERE "refresh_token_hash" = ?
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteOidcToken :exec
|
||||
DELETE FROM "oidc_tokens"
|
||||
WHERE "access_token_hash" = ?;
|
||||
|
||||
-- name: DeleteOidcTokenBySub :exec
|
||||
DELETE FROM "oidc_tokens"
|
||||
WHERE "sub" = ?;
|
||||
|
||||
-- name: GetOidcToken :one
|
||||
SELECT * FROM "oidc_tokens"
|
||||
WHERE "access_token_hash" = ?;
|
||||
@@ -62,6 +58,19 @@ WHERE "access_token_hash" = ?;
|
||||
SELECT * FROM "oidc_tokens"
|
||||
WHERE "refresh_token_hash" = ?;
|
||||
|
||||
-- name: GetOidcTokenBySub :one
|
||||
SELECT * FROM "oidc_tokens"
|
||||
WHERE "sub" = ?;
|
||||
|
||||
|
||||
-- name: DeleteOidcToken :exec
|
||||
DELETE FROM "oidc_tokens"
|
||||
WHERE "access_token_hash" = ?;
|
||||
|
||||
-- name: DeleteOidcTokenBySub :exec
|
||||
DELETE FROM "oidc_tokens"
|
||||
WHERE "sub" = ?;
|
||||
|
||||
-- name: CreateOidcUserInfo :one
|
||||
INSERT INTO "oidc_userinfo" (
|
||||
"sub",
|
||||
@@ -75,10 +84,20 @@ INSERT INTO "oidc_userinfo" (
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetOidcUserInfo :one
|
||||
SELECT * FROM "oidc_userinfo"
|
||||
WHERE "sub" = ?;
|
||||
|
||||
-- name: DeleteOidcUserInfo :exec
|
||||
DELETE FROM "oidc_userinfo"
|
||||
WHERE "sub" = ?;
|
||||
|
||||
-- name: GetOidcUserInfo :one
|
||||
SELECT * FROM "oidc_userinfo"
|
||||
WHERE "sub" = ?;
|
||||
-- name: DeleteExpiredOidcCodes :many
|
||||
DELETE FROM "oidc_codes"
|
||||
WHERE "expires_at" < ?
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteExpiredOidcTokens :many
|
||||
DELETE FROM "oidc_tokens"
|
||||
WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?
|
||||
RETURNING *;
|
||||
|
||||
Reference in New Issue
Block a user