From 695feca71c373ad4d7eb81e9a6f7929f9e700b3f Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 31 May 2026 20:10:53 +0300 Subject: [PATCH] refactor: rework oidc session storage --- .env.example | 4 +- .../postgres/000002_oidc_rework.down.sql | 46 ++ .../postgres/000002_oidc_rework.up.sql | 28 + .../sqlite/000010_oidc_rework.down.sql | 46 ++ .../sqlite/000010_oidc_rework.up.sql | 28 + internal/bootstrap/db_bootstrap.go | 5 +- internal/controller/oidc_controller.go | 95 +--- internal/repository/memory/memory_test.go | 4 + internal/repository/memory/oidc_queries.go | 4 + internal/repository/memory/session_queries.go | 4 + internal/repository/memory/store.go | 18 +- internal/repository/models.go | 84 +-- internal/repository/postgres/db.go | 2 +- internal/repository/postgres/models.go | 39 +- .../repository/postgres/oidc_queries.sql.go | 505 +++--------------- .../postgres/session_queries.sql.go | 2 +- internal/repository/postgres/store.go | 144 +---- internal/repository/sqlite/db.go | 2 +- internal/repository/sqlite/models.go | 39 +- .../repository/sqlite/oidc_queries.sql.go | 499 +++-------------- .../repository/sqlite/session_queries.sql.go | 2 +- internal/repository/sqlite/store.go | 144 +---- internal/repository/store.go | 33 +- internal/service/oidc_service.go | 387 +++++++------- sql/postgres/oidc_queries.sql | 141 +---- sql/postgres/oidc_schemas.sql | 51 +- sql/sqlite/oidc_queries.sql | 141 +---- sql/sqlite/oidc_schemas.sql | 45 +- sqlc.yml | 6 +- 29 files changed, 668 insertions(+), 1880 deletions(-) create mode 100644 internal/assets/migrations/postgres/000002_oidc_rework.down.sql create mode 100644 internal/assets/migrations/postgres/000002_oidc_rework.up.sql create mode 100644 internal/assets/migrations/sqlite/000010_oidc_rework.down.sql create mode 100644 internal/assets/migrations/sqlite/000010_oidc_rework.up.sql diff --git a/.env.example b/.env.example index a48204f3..100b0e9d 100644 --- a/.env.example +++ b/.env.example @@ -7,9 +7,9 @@ TINYAUTH_APPURL= # database config -# The database driver to use. Valid values: sqlite, memory. +# The database driver to use. Valid values: sqlite, postgres, memory. TINYAUTH_DATABASE_DRIVER="sqlite" -# The path to the SQLite database, including file name. Only used when driver is sqlite. +# The path to the SQLite database file, or connection URL when driver is postgres. TINYAUTH_DATABASE_PATH="./tinyauth.db" # analytics config diff --git a/internal/assets/migrations/postgres/000002_oidc_rework.down.sql b/internal/assets/migrations/postgres/000002_oidc_rework.down.sql new file mode 100644 index 00000000..7e8dda01 --- /dev/null +++ b/internal/assets/migrations/postgres/000002_oidc_rework.down.sql @@ -0,0 +1,46 @@ +DROP TABLE IF EXISTS "oidc_sessions"; + +CREATE TABLE "oidc_codes" ( + "sub" TEXT NOT NULL UNIQUE, + "code_hash" TEXT NOT NULL PRIMARY KEY, + "scope" TEXT NOT NULL, + "redirect_uri" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "expires_at" BIGINT NOT NULL, + "nonce" TEXT NOT NULL DEFAULT '', + "code_challenge" TEXT NOT NULL DEFAULT '' +); + +CREATE TABLE "oidc_tokens" ( + "sub" TEXT NOT NULL UNIQUE, + "access_token_hash" TEXT NOT NULL PRIMARY KEY, + "refresh_token_hash" TEXT NOT NULL, + "code_hash" TEXT NOT NULL, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" BIGINT NOT NULL, + "refresh_token_expires_at" BIGINT NOT NULL, + "nonce" TEXT NOT NULL DEFAULT '' +); + +CREATE TABLE "oidc_userinfo" ( + "sub" TEXT NOT NULL PRIMARY KEY, + "name" TEXT NOT NULL, + "preferred_username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "groups" TEXT NOT NULL, + "updated_at" BIGINT NOT NULL, + "given_name" TEXT NOT NULL, + "family_name" TEXT NOT NULL, + "middle_name" TEXT NOT NULL, + "nickname" TEXT NOT NULL, + "profile" TEXT NOT NULL, + "picture" TEXT NOT NULL, + "website" TEXT NOT NULL, + "gender" TEXT NOT NULL, + "birthdate" TEXT NOT NULL, + "zoneinfo" TEXT NOT NULL, + "locale" TEXT NOT NULL, + "phone_number" TEXT NOT NULL, + "address" TEXT NOT NULL +); diff --git a/internal/assets/migrations/postgres/000002_oidc_rework.up.sql b/internal/assets/migrations/postgres/000002_oidc_rework.up.sql new file mode 100644 index 00000000..1104f20e --- /dev/null +++ b/internal/assets/migrations/postgres/000002_oidc_rework.up.sql @@ -0,0 +1,28 @@ +/* +This migration will nuke the entire setup of OIDC sessions and merge everything +into one table. +*/ + +/* +Drop all the old tables. Yes, we will log out all OIDC users, but not really a big deal +*/ + +DROP TABLE IF EXISTS "oidc_tokens"; +DROP TABLE IF EXISTS "oidc_userinfo"; +DROP TABLE IF EXISTS "oidc_codes"; + +/* +Create a new simple OIDC sessions table that will hold tokens + userinfo. +*/ + +CREATE TABLE IF NOT EXISTS "oidc_sessions" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "access_token_hash" TEXT NOT NULL UNIQUE, + "refresh_token_hash" TEXT NOT NULL UNIQUE, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" BIGINT NOT NULL, + "refresh_token_expires_at" BIGINT NOT NULL, + "nonce" TEXT NOT NULL DEFAULT '', + "userinfo_json" TEXT NOT NULL +); diff --git a/internal/assets/migrations/sqlite/000010_oidc_rework.down.sql b/internal/assets/migrations/sqlite/000010_oidc_rework.down.sql new file mode 100644 index 00000000..94618c51 --- /dev/null +++ b/internal/assets/migrations/sqlite/000010_oidc_rework.down.sql @@ -0,0 +1,46 @@ +DROP TABLE IF EXISTS "oidc_sessions"; + +CREATE TABLE IF NOT EXISTS "oidc_codes" ( + "sub" TEXT NOT NULL UNIQUE, + "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, + "scope" TEXT NOT NULL, + "redirect_uri" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "expires_at" INTEGER NOT NULL, + "nonce" TEXT DEFAULT "", + "code_challenge" TEXT DEFAULT "" +); + +CREATE TABLE IF NOT EXISTS "oidc_tokens" ( + "sub" TEXT NOT NULL UNIQUE, + "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, + "refresh_token_hash" TEXT NOT NULL, + "code_hash" TEXT NOT NULL, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" INTEGER NOT NULL, + "refresh_token_expires_at" INTEGER NOT NULL, + "nonce" TEXT DEFAULT "" +); + +CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "name" TEXT NOT NULL, + "preferred_username" TEXT NOT NULL, + "email" TEXT NOT NULL, + "groups" TEXT NOT NULL, + "updated_at" INTEGER NOT NULL, + "given_name" TEXT NOT NULL, + "family_name" TEXT NOT NULL, + "middle_name" TEXT NOT NULL, + "nickname" TEXT NOT NULL, + "profile" TEXT NOT NULL, + "picture" TEXT NOT NULL, + "website" TEXT NOT NULL, + "gender" TEXT NOT NULL, + "birthdate" TEXT NOT NULL, + "zoneinfo" TEXT NOT NULL, + "locale" TEXT NOT NULL, + "phone_number" TEXT NOT NULL, + "address" TEXT NOT NULL +); diff --git a/internal/assets/migrations/sqlite/000010_oidc_rework.up.sql b/internal/assets/migrations/sqlite/000010_oidc_rework.up.sql new file mode 100644 index 00000000..e086250b --- /dev/null +++ b/internal/assets/migrations/sqlite/000010_oidc_rework.up.sql @@ -0,0 +1,28 @@ +/* +This migration will nuke the entire setup of OIDC sessions and merge everything +into one table. +*/ + +/* +Drop all the old tables. Yes, we will log out all OIDC users, but not really a big deal +*/ + +DROP TABLE IF EXISTS "oidc_tokens"; +DROP TABLE IF EXISTS "oidc_userinfo"; +DROP TABLE IF EXISTS "oidc_codes"; + +/* +Create a new simple OIDC sessions table that will hold tokens + userinfo. +*/ + +CREATE TABLE IF NOT EXISTS "oidc_sessions" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "access_token_hash" TEXT NOT NULL UNIQUE, + "refresh_token_hash" TEXT NOT NULL UNIQUE, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" INTEGER NOT NULL, + "refresh_token_expires_at" INTEGER NOT NULL, + "nonce" TEXT DEFAULT "", + "userinfo_json" TEXT NOT NULL +); diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 67d6549a..c59c5cf3 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -15,15 +15,14 @@ import ( "github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/repository" - "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/postgres" "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" ) func (app *BootstrapApp) SetupStore() (repository.Store, error) { switch app.config.Database.Driver { - case "memory": - return memory.New(), nil + // case "memory": + // return memory.New(), nil case "sqlite", "": return app.setupSQLite(app.config.Database.Path) case "postgres": diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 40170a78..bf6d1f2f 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -1,6 +1,7 @@ package controller import ( + "encoding/json" "errors" "fmt" "net/http" @@ -12,7 +13,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) @@ -169,7 +169,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { return } - client, ok := controller.oidc.GetClient(req.ClientID) + _, ok := controller.oidc.GetClient(req.ClientID) if !ok { controller.authorizeError(c, authorizeErrorParams{ @@ -203,9 +203,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) { return } - // WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too. - sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID)) - code := utils.GenerateString(32) + // Create the sub to find and delete old sessions + sub := controller.oidc.CreateSub(*userContext, req.ClientID) // Before storing the code, delete old session err = controller.oidc.DeleteOldSession(c, sub) @@ -221,37 +220,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) { return } - err = controller.oidc.StoreCode(c, sub, code, req) - - if err != nil { - controller.authorizeError(c, authorizeErrorParams{ - err: err, - reason: "Failed to store code", - reasonPublic: "Failed to store code", - callback: req.RedirectURI, - callbackError: "server_error", - state: req.State, - }) - return - } - - // We also need a snapshot of the user that authorized this (skip if no openid scope) - if slices.Contains(strings.Fields(req.Scope), "openid") { - err = controller.oidc.StoreUserinfo(c, sub, *userContext, req) - - if err != nil { - controller.log.App.Error().Err(err).Msg("Failed to store user info") - controller.authorizeError(c, authorizeErrorParams{ - err: err, - reason: "Failed to store user info", - reasonPublic: "Failed to store user info", - callback: req.RedirectURI, - callbackError: "server_error", - state: req.State, - }) - return - } - } + // Create the authorization code + code := controller.oidc.CreateCode(req, *userContext) queries, err := query.Values(AuthorizeCallback{ Code: code, @@ -354,35 +324,12 @@ func (controller *OIDCController) Token(c *gin.Context) { switch req.GrantType { case "authorization_code": - entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) - if err != nil { - if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil { - controller.log.App.Error().Err(err).Msg("Failed to revoke tokens for replayed code") - } - if errors.Is(err, service.ErrCodeNotFound) { - controller.log.App.Warn().Msg("Code not found") - c.JSON(400, gin.H{ - "error": "invalid_grant", - }) - return - } - if errors.Is(err, service.ErrCodeExpired) { - controller.log.App.Warn().Msg("Code expired") - c.JSON(400, gin.H{ - "error": "invalid_grant", - }) - return - } - if errors.Is(err, service.ErrInvalidClient) { - controller.log.App.Warn().Msg("Code does not belong to client") - c.JSON(400, gin.H{ - "error": "invalid_client", - }) - return - } - controller.log.App.Error().Err(err).Msg("Failed to get code entry") + entry, ok := controller.oidc.GetCodeEntry(controller.oidc.Hash(req.Code), client.ClientID) + + if !ok { + controller.log.App.Warn().Msg("Code not found") c.JSON(400, gin.H{ - "error": "server_error", + "error": "invalid_grant", }) return } @@ -395,7 +342,7 @@ func (controller *OIDCController) Token(c *gin.Context) { return } - ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) + ok = controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) if !ok { controller.log.App.Warn().Msg("PKCE validation failed") @@ -405,7 +352,7 @@ func (controller *OIDCController) Token(c *gin.Context) { return } - tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry) + tokenRes, err := controller.oidc.GenerateAccessToken(c, client, *entry) if err != nil { controller.log.App.Error().Err(err).Msg("Failed to generate access token") @@ -415,7 +362,7 @@ func (controller *OIDCController) Token(c *gin.Context) { return } - tokenResponse = tokenRes + tokenResponse = *tokenRes case "refresh_token": tokenRes, err := controller.oidc.RefreshAccessToken(c, req.RefreshToken, creds.ClientID) @@ -443,7 +390,7 @@ func (controller *OIDCController) Token(c *gin.Context) { return } - tokenResponse = tokenRes + tokenResponse = *tokenRes } c.Header("cache-control", "no-store") @@ -507,7 +454,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { return } - entry, err := controller.oidc.GetAccessToken(c, controller.oidc.Hash(token)) + entry, err := controller.oidc.GetSessionByToken(c, controller.oidc.Hash(token)) if err != nil { if errors.Is(err, service.ErrTokenNotFound) { @@ -526,15 +473,17 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { } // If we don't have the openid scope, return an error - if !slices.Contains(strings.Split(entry.Scope, ","), "openid") { - controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope") + if !slices.Contains(strings.Split(entry.Scope, " "), "openid") { + controller.log.App.Warn().Msg("OIDC userinfo accessed with missing openid scope") c.JSON(401, gin.H{ "error": "invalid_scope", }) return } - user, err := controller.oidc.GetUserinfo(c, entry.Sub) + var userinfo service.UserinfoResponse + + err = json.Unmarshal([]byte(entry.UserinfoJson), &userinfo) if err != nil { controller.log.App.Error().Err(err).Msg("Failed to get user info") @@ -544,7 +493,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { return } - c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope)) + c.JSON(200, controller.oidc.CompileUserinfo(userinfo, entry.Scope)) } func (controller *OIDCController) authorizeError(c *gin.Context, params authorizeErrorParams) { diff --git a/internal/repository/memory/memory_test.go b/internal/repository/memory/memory_test.go index 16f20b13..07fee88d 100644 --- a/internal/repository/memory/memory_test.go +++ b/internal/repository/memory/memory_test.go @@ -1,3 +1,7 @@ +//go:build exclude + +// temporary + package memory_test import ( diff --git a/internal/repository/memory/oidc_queries.go b/internal/repository/memory/oidc_queries.go index d2798c3e..0b4d758f 100644 --- a/internal/repository/memory/oidc_queries.go +++ b/internal/repository/memory/oidc_queries.go @@ -1,3 +1,7 @@ +//go:build exclude + +// temporary + package memory import ( diff --git a/internal/repository/memory/session_queries.go b/internal/repository/memory/session_queries.go index 2edde6b1..fbbb43cf 100644 --- a/internal/repository/memory/session_queries.go +++ b/internal/repository/memory/session_queries.go @@ -1,3 +1,7 @@ +//go:build exclude + +// temporary + package memory import ( diff --git a/internal/repository/memory/store.go b/internal/repository/memory/store.go index 969cba66..a2a56ad3 100644 --- a/internal/repository/memory/store.go +++ b/internal/repository/memory/store.go @@ -1,3 +1,7 @@ +//go:build exclude + +// temporary + // Package memory provides an in-memory implementation of repository.Store for use in tests. package memory @@ -9,19 +13,15 @@ import ( // Store is a thread-safe in-memory implementation of repository.Store. type Store struct { - mu sync.RWMutex - sessions map[string]repository.Session - oidcCodes map[string]repository.OidcCode - oidcTokens map[string]repository.OidcToken - oidcUsers map[string]repository.OidcUserinfo + mu sync.RWMutex + sessions map[string]repository.Session + oidcSessions map[string]repository.OidcSession } // New returns a new empty in-memory Store. func New() repository.Store { return &Store{ - sessions: make(map[string]repository.Session), - oidcCodes: make(map[string]repository.OidcCode), - oidcTokens: make(map[string]repository.OidcToken), - oidcUsers: make(map[string]repository.OidcUserinfo), + sessions: make(map[string]repository.Session), + oidcSessions: make(map[string]repository.OidcSession), } } diff --git a/internal/repository/models.go b/internal/repository/models.go index 3f58dd66..39538a00 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -17,49 +17,16 @@ type Session struct { OAuthSub string } -type OidcCode struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} - -type OidcToken struct { +type OidcSession struct { Sub string AccessTokenHash string RefreshTokenHash string - CodeHash string Scope string ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 Nonce string -} - -type OidcUserinfo struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string + UserinfoJson string } type CreateSessionParams struct { @@ -89,18 +56,7 @@ type UpdateSessionParams struct { UUID string } -type CreateOidcCodeParams struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} - -type CreateOidcTokenParams struct { +type CreateOIDCSessionParams struct { Sub string AccessTokenHash string RefreshTokenHash string @@ -108,41 +64,23 @@ type CreateOidcTokenParams struct { ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 - CodeHash string Nonce string + UserinfoJson string } -type UpdateOidcTokenByRefreshTokenParams struct { +type UpdateOIDCSessionParams struct { AccessTokenHash string RefreshTokenHash string + Scope string + ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 - RefreshTokenHash_2 string + Nonce string + UserinfoJson string + Sub string } -type DeleteExpiredOidcTokensParams struct { +type DeleteExpiredOIDCSessionsParams struct { TokenExpiresAt int64 RefreshTokenExpiresAt int64 } - -type CreateOidcUserInfoParams struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string -} diff --git a/internal/repository/postgres/db.go b/internal/repository/postgres/db.go index e546ecca..76b783ec 100644 --- a/internal/repository/postgres/db.go +++ b/internal/repository/postgres/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 package postgres diff --git a/internal/repository/postgres/models.go b/internal/repository/postgres/models.go index be3999da..c2247402 100644 --- a/internal/repository/postgres/models.go +++ b/internal/repository/postgres/models.go @@ -1,52 +1,19 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 package postgres -type OidcCode struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} - -type OidcToken struct { +type OidcSession struct { Sub string AccessTokenHash string RefreshTokenHash string - CodeHash string Scope string ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 Nonce string -} - -type OidcUserinfo struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string + UserinfoJson string } type Session struct { diff --git a/internal/repository/postgres/oidc_queries.sql.go b/internal/repository/postgres/oidc_queries.sql.go index 637bb701..81259f4a 100644 --- a/internal/repository/postgres/oidc_queries.sql.go +++ b/internal/repository/postgres/oidc_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 // source: oidc_queries.sql package postgres @@ -9,60 +9,8 @@ import ( "context" ) -const createOidcCode = `-- name: CreateOidcCode :one -INSERT INTO "oidc_codes" ( - "sub", - "code_hash", - "scope", - "redirect_uri", - "client_id", - "expires_at", - "nonce", - "code_challenge" -) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8 -) -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -type CreateOidcCodeParams struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} - -func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, createOidcCode, - arg.Sub, - arg.CodeHash, - arg.Scope, - arg.RedirectURI, - arg.ClientID, - arg.ExpiresAt, - arg.Nonce, - arg.CodeChallenge, - ) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const createOidcToken = `-- name: CreateOidcToken :one -INSERT INTO "oidc_tokens" ( +const createOIDCSession = `-- name: CreateOIDCSession :one +INSERT INTO "oidc_sessions" ( "sub", "access_token_hash", "refresh_token_hash", @@ -70,15 +18,15 @@ INSERT INTO "oidc_tokens" ( "client_id", "token_expires_at", "refresh_token_expires_at", - "code_hash", - "nonce" + "nonce", + "userinfo_json" ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9 ) -RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json ` -type CreateOidcTokenParams struct { +type CreateOIDCSessionParams struct { Sub string AccessTokenHash string RefreshTokenHash string @@ -86,12 +34,12 @@ type CreateOidcTokenParams struct { ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 - CodeHash string Nonce string + UserinfoJson string } -func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, createOidcToken, +func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, createOIDCSession, arg.Sub, arg.AccessTokenHash, arg.RefreshTokenHash, @@ -99,483 +47,164 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams arg.ClientID, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt, - arg.CodeHash, arg.Nonce, + arg.UserinfoJson, ) - var i OidcToken + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const createOidcUserInfo = `-- name: CreateOidcUserInfo :one -INSERT INTO "oidc_userinfo" ( - "sub", - "name", - "preferred_username", - "email", - "groups", - "updated_at", - "given_name", - "family_name", - "middle_name", - "nickname", - "profile", - "picture", - "website", - "gender", - "birthdate", - "zoneinfo", - "locale", - "phone_number", - "address" -) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19 -) -RETURNING sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address -` - -type CreateOidcUserInfoParams struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string -} - -func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) { - row := q.db.QueryRowContext(ctx, createOidcUserInfo, - arg.Sub, - arg.Name, - arg.PreferredUsername, - arg.Email, - arg.Groups, - arg.UpdatedAt, - arg.GivenName, - arg.FamilyName, - arg.MiddleName, - arg.Nickname, - arg.Profile, - arg.Picture, - arg.Website, - arg.Gender, - arg.Birthdate, - arg.Zoneinfo, - arg.Locale, - arg.PhoneNumber, - arg.Address, - ) - var i OidcUserinfo - err := row.Scan( - &i.Sub, - &i.Name, - &i.PreferredUsername, - &i.Email, - &i.Groups, - &i.UpdatedAt, - &i.GivenName, - &i.FamilyName, - &i.MiddleName, - &i.Nickname, - &i.Profile, - &i.Picture, - &i.Website, - &i.Gender, - &i.Birthdate, - &i.Zoneinfo, - &i.Locale, - &i.PhoneNumber, - &i.Address, - ) - return i, err -} - -const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many -DELETE FROM "oidc_codes" -WHERE "expires_at" < $1 -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) { - rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt) - if err != nil { - return nil, err - } - defer rows.Close() - var items []OidcCode - for rows.Next() { - var i OidcCode - if err := rows.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many -DELETE FROM "oidc_tokens" +const deleteExpiredOIDCSessions = `-- name: DeleteExpiredOIDCSessions :exec +DELETE FROM "oidc_sessions" WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2 -RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce ` -type DeleteExpiredOidcTokensParams struct { +type DeleteExpiredOIDCSessionsParams struct { TokenExpiresAt int64 RefreshTokenExpiresAt int64 } -func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) { - rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) - if err != nil { - return nil, err - } - defer rows.Close() - var items []OidcToken - for rows.Next() { - var i OidcToken - if err := rows.Scan( - &i.Sub, - &i.AccessTokenHash, - &i.RefreshTokenHash, - &i.CodeHash, - &i.Scope, - &i.ClientID, - &i.TokenExpiresAt, - &i.RefreshTokenExpiresAt, - &i.Nonce, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const deleteOidcCode = `-- name: DeleteOidcCode :exec -DELETE FROM "oidc_codes" -WHERE "code_hash" = $1 -` - -func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error { - _, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash) +func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error { + _, err := q.db.ExecContext(ctx, deleteExpiredOIDCSessions, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) return err } -const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec -DELETE FROM "oidc_codes" +const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec +DELETE FROM "oidc_sessions" WHERE "sub" = $1 ` -func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - _, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub) +func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { + _, err := q.db.ExecContext(ctx, deleteOIDCSessionBySub, sub) return err } -const deleteOidcToken = `-- name: DeleteOidcToken :exec -DELETE FROM "oidc_tokens" +const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "access_token_hash" = $1 ` -func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - _, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash) - return err -} - -const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec -DELETE FROM "oidc_tokens" -WHERE "code_hash" = $1 -` - -func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - _, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash) - return err -} - -const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec -DELETE FROM "oidc_tokens" -WHERE "sub" = $1 -` - -func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - _, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub) - return err -} - -const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec -DELETE FROM "oidc_userinfo" -WHERE "sub" = $1 -` - -func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error { - _, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub) - return err -} - -const getOidcCode = `-- name: GetOidcCode :one -DELETE FROM "oidc_codes" -WHERE "code_hash" = $1 -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCode, codeHash) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one -DELETE FROM "oidc_codes" -WHERE "sub" = $1 -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one -SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes" -WHERE "sub" = $1 -` - -func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one -SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes" -WHERE "code_hash" = $1 -` - -func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcToken = `-- name: GetOidcToken :one -SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" -WHERE "access_token_hash" = $1 -` - -func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash) - var i OidcToken +func (q *Queries) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, getOIDCSessionByAccessTokenHash, accessTokenHash) + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one -SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +const getOIDCSessionByRefreshTokenHash = `-- name: GetOIDCSessionByRefreshTokenHash :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "refresh_token_hash" = $1 ` -func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash) - var i OidcToken +func (q *Queries) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, getOIDCSessionByRefreshTokenHash, refreshTokenHash) + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one -SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +const getOIDCSessionBySub = `-- name: GetOIDCSessionBySub :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "sub" = $1 ` -func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub) - var i OidcToken +func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, getOIDCSessionBySub, sub) + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const getOidcUserInfo = `-- name: GetOidcUserInfo :one -SELECT sub, name, preferred_username, email, groups, updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address FROM "oidc_userinfo" -WHERE "sub" = $1 +const updateOIDCSession = `-- name: UpdateOIDCSession :one +UPDATE "oidc_sessions" SET + "access_token_hash" = $1, + "refresh_token_hash" = $2, + "scope" = $3, + "client_id" = $4, + "token_expires_at" = $5, + "refresh_token_expires_at" = $6, + "nonce" = $7, + "userinfo_json" = $8 +WHERE "sub" = $9 +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json ` -func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) { - row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub) - var i OidcUserinfo - err := row.Scan( - &i.Sub, - &i.Name, - &i.PreferredUsername, - &i.Email, - &i.Groups, - &i.UpdatedAt, - &i.GivenName, - &i.FamilyName, - &i.MiddleName, - &i.Nickname, - &i.Profile, - &i.Picture, - &i.Website, - &i.Gender, - &i.Birthdate, - &i.Zoneinfo, - &i.Locale, - &i.PhoneNumber, - &i.Address, - ) - return i, err -} - -const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one -UPDATE "oidc_tokens" SET - "access_token_hash" = $1, - "refresh_token_hash" = $2, - "token_expires_at" = $3, - "refresh_token_expires_at" = $4 -WHERE "refresh_token_hash" = $5 -RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce -` - -type UpdateOidcTokenByRefreshTokenParams struct { +type UpdateOIDCSessionParams struct { AccessTokenHash string RefreshTokenHash string + Scope string + ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 - RefreshTokenHash_2 string + Nonce string + UserinfoJson string + Sub string } -func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken, +func (q *Queries) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, updateOIDCSession, arg.AccessTokenHash, arg.RefreshTokenHash, + arg.Scope, + arg.ClientID, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt, - arg.RefreshTokenHash_2, + arg.Nonce, + arg.UserinfoJson, + arg.Sub, ) - var i OidcToken + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } diff --git a/internal/repository/postgres/session_queries.sql.go b/internal/repository/postgres/session_queries.sql.go index c7ea71d4..89cc0888 100644 --- a/internal/repository/postgres/session_queries.sql.go +++ b/internal/repository/postgres/session_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 // source: session_queries.sql package postgres diff --git a/internal/repository/postgres/store.go b/internal/repository/postgres/store.go index ed4bbb73..b3e79c80 100644 --- a/internal/repository/postgres/store.go +++ b/internal/repository/postgres/store.go @@ -32,28 +32,12 @@ func mapErr(err error) error { return err } -func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { - r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) +func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { + r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil -} - -func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { - r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { - r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) - if err != nil { - return repository.OidcUserinfo{}, mapErr(err) - } - return repository.OidcUserinfo(r), nil + return repository.OidcSession(r), nil } func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { @@ -64,124 +48,44 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP return repository.Session(r), nil } -func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { - rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) - if err != nil { - return nil, mapErr(err) - } - out := make([]repository.OidcCode, len(rows)) - for i, row := range rows { - out[i] = repository.OidcCode(row) - } - return out, nil -} - -func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { - rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) - if err != nil { - return nil, mapErr(err) - } - out := make([]repository.OidcToken, len(rows)) - for i, row := range rows { - out[i] = repository.OidcToken(row) - } - return out, nil +func (s *Store) DeleteExpiredOIDCSessions(ctx context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error { + return mapErr(s.q.DeleteExpiredOIDCSessions(ctx, DeleteExpiredOIDCSessionsParams(arg))) } func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) } -func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { - return mapErr(s.q.DeleteOidcCode(ctx, codeHash)) -} - -func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub)) -} - -func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash)) -} - -func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)) -} - -func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub)) -} - -func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { - return mapErr(s.q.DeleteOidcUserInfo(ctx, sub)) +func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { + return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) } func (s *Store) DeleteSession(ctx context.Context, uuid string) error { return mapErr(s.q.DeleteSession(ctx, uuid)) } -func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCode(ctx, codeHash) +func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { + r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil + return repository.OidcSession(r), nil } -func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCodeBySub(ctx, sub) +func (s *Store) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (repository.OidcSession, error) { + r, err := s.q.GetOIDCSessionByRefreshTokenHash(ctx, refreshTokenHash) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil + return repository.OidcSession(r), nil } -func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) +func (s *Store) GetOIDCSessionBySub(ctx context.Context, sub string) (repository.OidcSession, error) { + r, err := s.q.GetOIDCSessionBySub(ctx, sub) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil -} - -func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) - if err != nil { - return repository.OidcCode{}, mapErr(err) - } - return repository.OidcCode(r), nil -} - -func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { - r, err := s.q.GetOidcToken(ctx, accessTokenHash) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { - r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { - r, err := s.q.GetOidcTokenBySub(ctx, sub) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { - r, err := s.q.GetOidcUserInfo(ctx, sub) - if err != nil { - return repository.OidcUserinfo{}, mapErr(err) - } - return repository.OidcUserinfo(r), nil + return repository.OidcSession(r), nil } func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { @@ -192,12 +96,12 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session return repository.Session(r), nil } -func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { - r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) +func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { + r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) if err != nil { - return repository.OidcToken{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcToken(r), nil + return repository.OidcSession(r), nil } func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { diff --git a/internal/repository/sqlite/db.go b/internal/repository/sqlite/db.go index 51a4906a..3c39218d 100644 --- a/internal/repository/sqlite/db.go +++ b/internal/repository/sqlite/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 package sqlite diff --git a/internal/repository/sqlite/models.go b/internal/repository/sqlite/models.go index fd6f78da..a00bbb11 100644 --- a/internal/repository/sqlite/models.go +++ b/internal/repository/sqlite/models.go @@ -1,52 +1,19 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 package sqlite -type OidcCode struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} - -type OidcToken struct { +type OidcSession struct { Sub string AccessTokenHash string RefreshTokenHash string - CodeHash string Scope string ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 Nonce string -} - -type OidcUserinfo struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string + UserinfoJson string } type Session struct { diff --git a/internal/repository/sqlite/oidc_queries.sql.go b/internal/repository/sqlite/oidc_queries.sql.go index e5d08bc2..b5859460 100644 --- a/internal/repository/sqlite/oidc_queries.sql.go +++ b/internal/repository/sqlite/oidc_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 // source: oidc_queries.sql package sqlite @@ -9,60 +9,8 @@ import ( "context" ) -const createOidcCode = `-- name: CreateOidcCode :one -INSERT INTO "oidc_codes" ( - "sub", - "code_hash", - "scope", - "redirect_uri", - "client_id", - "expires_at", - "nonce", - "code_challenge" -) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ? -) -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -type CreateOidcCodeParams struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} - -func (q *Queries) CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, createOidcCode, - arg.Sub, - arg.CodeHash, - arg.Scope, - arg.RedirectURI, - arg.ClientID, - arg.ExpiresAt, - arg.Nonce, - arg.CodeChallenge, - ) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const createOidcToken = `-- name: CreateOidcToken :one -INSERT INTO "oidc_tokens" ( +const createOIDCSession = `-- name: CreateOIDCSession :one +INSERT INTO "oidc_sessions" ( "sub", "access_token_hash", "refresh_token_hash", @@ -70,15 +18,15 @@ INSERT INTO "oidc_tokens" ( "client_id", "token_expires_at", "refresh_token_expires_at", - "code_hash", - "nonce" + "nonce", + "userinfo_json" ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ? ) -RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json ` -type CreateOidcTokenParams struct { +type CreateOIDCSessionParams struct { Sub string AccessTokenHash string RefreshTokenHash string @@ -86,12 +34,12 @@ type CreateOidcTokenParams struct { ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 - CodeHash string Nonce string + UserinfoJson string } -func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, createOidcToken, +func (q *Queries) CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, createOIDCSession, arg.Sub, arg.AccessTokenHash, arg.RefreshTokenHash, @@ -99,483 +47,164 @@ func (q *Queries) CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams arg.ClientID, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt, - arg.CodeHash, arg.Nonce, + arg.UserinfoJson, ) - var i OidcToken + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const createOidcUserInfo = `-- name: CreateOidcUserInfo :one -INSERT INTO "oidc_userinfo" ( - "sub", - "name", - "preferred_username", - "email", - "groups", - "updated_at", - "given_name", - "family_name", - "middle_name", - "nickname", - "profile", - "picture", - "website", - "gender", - "birthdate", - "zoneinfo", - "locale", - "phone_number", - "address" -) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? -) -RETURNING sub, name, preferred_username, email, "groups", updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address -` - -type CreateOidcUserInfoParams struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string -} - -func (q *Queries) CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) { - row := q.db.QueryRowContext(ctx, createOidcUserInfo, - arg.Sub, - arg.Name, - arg.PreferredUsername, - arg.Email, - arg.Groups, - arg.UpdatedAt, - arg.GivenName, - arg.FamilyName, - arg.MiddleName, - arg.Nickname, - arg.Profile, - arg.Picture, - arg.Website, - arg.Gender, - arg.Birthdate, - arg.Zoneinfo, - arg.Locale, - arg.PhoneNumber, - arg.Address, - ) - var i OidcUserinfo - err := row.Scan( - &i.Sub, - &i.Name, - &i.PreferredUsername, - &i.Email, - &i.Groups, - &i.UpdatedAt, - &i.GivenName, - &i.FamilyName, - &i.MiddleName, - &i.Nickname, - &i.Profile, - &i.Picture, - &i.Website, - &i.Gender, - &i.Birthdate, - &i.Zoneinfo, - &i.Locale, - &i.PhoneNumber, - &i.Address, - ) - return i, err -} - -const deleteExpiredOidcCodes = `-- name: DeleteExpiredOidcCodes :many -DELETE FROM "oidc_codes" -WHERE "expires_at" < ? -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -func (q *Queries) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) { - rows, err := q.db.QueryContext(ctx, deleteExpiredOidcCodes, expiresAt) - if err != nil { - return nil, err - } - defer rows.Close() - var items []OidcCode - for rows.Next() { - var i OidcCode - if err := rows.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const deleteExpiredOidcTokens = `-- name: DeleteExpiredOidcTokens :many -DELETE FROM "oidc_tokens" +const deleteExpiredOIDCSessions = `-- name: DeleteExpiredOIDCSessions :exec +DELETE FROM "oidc_sessions" WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? -RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce ` -type DeleteExpiredOidcTokensParams struct { +type DeleteExpiredOIDCSessionsParams struct { TokenExpiresAt int64 RefreshTokenExpiresAt int64 } -func (q *Queries) DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) { - rows, err := q.db.QueryContext(ctx, deleteExpiredOidcTokens, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) - if err != nil { - return nil, err - } - defer rows.Close() - var items []OidcToken - for rows.Next() { - var i OidcToken - if err := rows.Scan( - &i.Sub, - &i.AccessTokenHash, - &i.RefreshTokenHash, - &i.CodeHash, - &i.Scope, - &i.ClientID, - &i.TokenExpiresAt, - &i.RefreshTokenExpiresAt, - &i.Nonce, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const deleteOidcCode = `-- name: DeleteOidcCode :exec -DELETE FROM "oidc_codes" -WHERE "code_hash" = ? -` - -func (q *Queries) DeleteOidcCode(ctx context.Context, codeHash string) error { - _, err := q.db.ExecContext(ctx, deleteOidcCode, codeHash) +func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error { + _, err := q.db.ExecContext(ctx, deleteExpiredOIDCSessions, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt) return err } -const deleteOidcCodeBySub = `-- name: DeleteOidcCodeBySub :exec -DELETE FROM "oidc_codes" +const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec +DELETE FROM "oidc_sessions" WHERE "sub" = ? ` -func (q *Queries) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - _, err := q.db.ExecContext(ctx, deleteOidcCodeBySub, sub) +func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { + _, err := q.db.ExecContext(ctx, deleteOIDCSessionBySub, sub) return err } -const deleteOidcToken = `-- name: DeleteOidcToken :exec -DELETE FROM "oidc_tokens" +const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "access_token_hash" = ? ` -func (q *Queries) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - _, err := q.db.ExecContext(ctx, deleteOidcToken, accessTokenHash) - return err -} - -const deleteOidcTokenByCodeHash = `-- name: DeleteOidcTokenByCodeHash :exec -DELETE FROM "oidc_tokens" -WHERE "code_hash" = ? -` - -func (q *Queries) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - _, err := q.db.ExecContext(ctx, deleteOidcTokenByCodeHash, codeHash) - return err -} - -const deleteOidcTokenBySub = `-- name: DeleteOidcTokenBySub :exec -DELETE FROM "oidc_tokens" -WHERE "sub" = ? -` - -func (q *Queries) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - _, err := q.db.ExecContext(ctx, deleteOidcTokenBySub, sub) - return err -} - -const deleteOidcUserInfo = `-- name: DeleteOidcUserInfo :exec -DELETE FROM "oidc_userinfo" -WHERE "sub" = ? -` - -func (q *Queries) DeleteOidcUserInfo(ctx context.Context, sub string) error { - _, err := q.db.ExecContext(ctx, deleteOidcUserInfo, sub) - return err -} - -const getOidcCode = `-- name: GetOidcCode :one -DELETE FROM "oidc_codes" -WHERE "code_hash" = ? -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -func (q *Queries) GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCode, codeHash) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcCodeBySub = `-- name: GetOidcCodeBySub :one -DELETE FROM "oidc_codes" -WHERE "sub" = ? -RETURNING sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge -` - -func (q *Queries) GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCodeBySub, sub) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcCodeBySubUnsafe = `-- name: GetOidcCodeBySubUnsafe :one -SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes" -WHERE "sub" = ? -` - -func (q *Queries) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCodeBySubUnsafe, sub) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcCodeUnsafe = `-- name: GetOidcCodeUnsafe :one -SELECT sub, code_hash, scope, redirect_uri, client_id, expires_at, nonce, code_challenge FROM "oidc_codes" -WHERE "code_hash" = ? -` - -func (q *Queries) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) { - row := q.db.QueryRowContext(ctx, getOidcCodeUnsafe, codeHash) - var i OidcCode - err := row.Scan( - &i.Sub, - &i.CodeHash, - &i.Scope, - &i.RedirectURI, - &i.ClientID, - &i.ExpiresAt, - &i.Nonce, - &i.CodeChallenge, - ) - return i, err -} - -const getOidcToken = `-- name: GetOidcToken :one -SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" -WHERE "access_token_hash" = ? -` - -func (q *Queries) GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, getOidcToken, accessTokenHash) - var i OidcToken +func (q *Queries) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, getOIDCSessionByAccessTokenHash, accessTokenHash) + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const getOidcTokenByRefreshToken = `-- name: GetOidcTokenByRefreshToken :one -SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +const getOIDCSessionByRefreshTokenHash = `-- name: GetOIDCSessionByRefreshTokenHash :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "refresh_token_hash" = ? ` -func (q *Queries) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, getOidcTokenByRefreshToken, refreshTokenHash) - var i OidcToken +func (q *Queries) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, getOIDCSessionByRefreshTokenHash, refreshTokenHash) + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const getOidcTokenBySub = `-- name: GetOidcTokenBySub :one -SELECT sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce FROM "oidc_tokens" +const getOIDCSessionBySub = `-- name: GetOIDCSessionBySub :one +SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "sub" = ? ` -func (q *Queries) GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, getOidcTokenBySub, sub) - var i OidcToken +func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, getOIDCSessionBySub, sub) + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } -const getOidcUserInfo = `-- name: GetOidcUserInfo :one -SELECT sub, name, preferred_username, email, "groups", updated_at, given_name, family_name, middle_name, nickname, profile, picture, website, gender, birthdate, zoneinfo, locale, phone_number, address FROM "oidc_userinfo" -WHERE "sub" = ? -` - -func (q *Queries) GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) { - row := q.db.QueryRowContext(ctx, getOidcUserInfo, sub) - var i OidcUserinfo - err := row.Scan( - &i.Sub, - &i.Name, - &i.PreferredUsername, - &i.Email, - &i.Groups, - &i.UpdatedAt, - &i.GivenName, - &i.FamilyName, - &i.MiddleName, - &i.Nickname, - &i.Profile, - &i.Picture, - &i.Website, - &i.Gender, - &i.Birthdate, - &i.Zoneinfo, - &i.Locale, - &i.PhoneNumber, - &i.Address, - ) - return i, err -} - -const updateOidcTokenByRefreshToken = `-- name: UpdateOidcTokenByRefreshToken :one -UPDATE "oidc_tokens" SET +const updateOIDCSession = `-- name: UpdateOIDCSession :one +UPDATE "oidc_sessions" SET "access_token_hash" = ?, "refresh_token_hash" = ?, + "scope" = ?, + "client_id" = ?, "token_expires_at" = ?, - "refresh_token_expires_at" = ? -WHERE "refresh_token_hash" = ? -RETURNING sub, access_token_hash, refresh_token_hash, code_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce + "refresh_token_expires_at" = ?, + "nonce" = ?, + "userinfo_json" = ? +WHERE "sub" = ? +RETURNING sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json ` -type UpdateOidcTokenByRefreshTokenParams struct { +type UpdateOIDCSessionParams struct { AccessTokenHash string RefreshTokenHash string + Scope string + ClientID string TokenExpiresAt int64 RefreshTokenExpiresAt int64 - RefreshTokenHash_2 string + Nonce string + UserinfoJson string + Sub string } -func (q *Queries) UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) { - row := q.db.QueryRowContext(ctx, updateOidcTokenByRefreshToken, +func (q *Queries) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) { + row := q.db.QueryRowContext(ctx, updateOIDCSession, arg.AccessTokenHash, arg.RefreshTokenHash, + arg.Scope, + arg.ClientID, arg.TokenExpiresAt, arg.RefreshTokenExpiresAt, - arg.RefreshTokenHash_2, + arg.Nonce, + arg.UserinfoJson, + arg.Sub, ) - var i OidcToken + var i OidcSession err := row.Scan( &i.Sub, &i.AccessTokenHash, &i.RefreshTokenHash, - &i.CodeHash, &i.Scope, &i.ClientID, &i.TokenExpiresAt, &i.RefreshTokenExpiresAt, &i.Nonce, + &i.UserinfoJson, ) return i, err } diff --git a/internal/repository/sqlite/session_queries.sql.go b/internal/repository/sqlite/session_queries.sql.go index 7792fc4b..d71ecf51 100644 --- a/internal/repository/sqlite/session_queries.sql.go +++ b/internal/repository/sqlite/session_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.1 +// sqlc v1.30.0 // source: session_queries.sql package sqlite diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go index e7ce1792..a567c871 100644 --- a/internal/repository/sqlite/store.go +++ b/internal/repository/sqlite/store.go @@ -32,28 +32,12 @@ func mapErr(err error) error { return err } -func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { - r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) +func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { + r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil -} - -func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { - r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { - r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) - if err != nil { - return repository.OidcUserinfo{}, mapErr(err) - } - return repository.OidcUserinfo(r), nil + return repository.OidcSession(r), nil } func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { @@ -64,124 +48,44 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP return repository.Session(r), nil } -func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { - rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) - if err != nil { - return nil, mapErr(err) - } - out := make([]repository.OidcCode, len(rows)) - for i, row := range rows { - out[i] = repository.OidcCode(row) - } - return out, nil -} - -func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { - rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) - if err != nil { - return nil, mapErr(err) - } - out := make([]repository.OidcToken, len(rows)) - for i, row := range rows { - out[i] = repository.OidcToken(row) - } - return out, nil +func (s *Store) DeleteExpiredOIDCSessions(ctx context.Context, arg repository.DeleteExpiredOIDCSessionsParams) error { + return mapErr(s.q.DeleteExpiredOIDCSessions(ctx, DeleteExpiredOIDCSessionsParams(arg))) } func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) } -func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { - return mapErr(s.q.DeleteOidcCode(ctx, codeHash)) -} - -func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub)) -} - -func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash)) -} - -func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)) -} - -func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub)) -} - -func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { - return mapErr(s.q.DeleteOidcUserInfo(ctx, sub)) +func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { + return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) } func (s *Store) DeleteSession(ctx context.Context, uuid string) error { return mapErr(s.q.DeleteSession(ctx, uuid)) } -func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCode(ctx, codeHash) +func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { + r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil + return repository.OidcSession(r), nil } -func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCodeBySub(ctx, sub) +func (s *Store) GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (repository.OidcSession, error) { + r, err := s.q.GetOIDCSessionByRefreshTokenHash(ctx, refreshTokenHash) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil + return repository.OidcSession(r), nil } -func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) +func (s *Store) GetOIDCSessionBySub(ctx context.Context, sub string) (repository.OidcSession, error) { + r, err := s.q.GetOIDCSessionBySub(ctx, sub) if err != nil { - return repository.OidcCode{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcCode(r), nil -} - -func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { - r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) - if err != nil { - return repository.OidcCode{}, mapErr(err) - } - return repository.OidcCode(r), nil -} - -func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { - r, err := s.q.GetOidcToken(ctx, accessTokenHash) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { - r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { - r, err := s.q.GetOidcTokenBySub(ctx, sub) - if err != nil { - return repository.OidcToken{}, mapErr(err) - } - return repository.OidcToken(r), nil -} - -func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { - r, err := s.q.GetOidcUserInfo(ctx, sub) - if err != nil { - return repository.OidcUserinfo{}, mapErr(err) - } - return repository.OidcUserinfo(r), nil + return repository.OidcSession(r), nil } func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { @@ -192,12 +96,12 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session return repository.Session(r), nil } -func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { - r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) +func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { + r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) if err != nil { - return repository.OidcToken{}, mapErr(err) + return repository.OidcSession{}, mapErr(err) } - return repository.OidcToken(r), nil + return repository.OidcSession(r), nil } func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { diff --git a/internal/repository/store.go b/internal/repository/store.go index 302f2f10..abd70bd3 100644 --- a/internal/repository/store.go +++ b/internal/repository/store.go @@ -19,29 +19,12 @@ type Store interface { DeleteSession(ctx context.Context, uuid string) error DeleteExpiredSessions(ctx context.Context, expiry int64) error - // OIDC codes - CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) - GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) - GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) - GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) - GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) - DeleteOidcCode(ctx context.Context, codeHash string) error - DeleteOidcCodeBySub(ctx context.Context, sub string) error - DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) - - // OIDC tokens - CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) - GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) - GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) - GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) - UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) - DeleteOidcToken(ctx context.Context, accessTokenHash string) error - DeleteOidcTokenBySub(ctx context.Context, sub string) error - DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error - DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) - - // OIDC userinfo - CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) - GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) - DeleteOidcUserInfo(ctx context.Context, sub string) error + // OIDC sessions + CreateOIDCSession(ctx context.Context, arg CreateOIDCSessionParams) (OidcSession, error) + DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpiredOIDCSessionsParams) error + DeleteOIDCSessionBySub(ctx context.Context, sub string) error + GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (OidcSession, error) + GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) + GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) + UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) } diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index e4d7e975..5bd11fcf 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -19,7 +19,6 @@ import ( "slices" - "github.com/gin-gonic/gin" "github.com/go-jose/go-jose/v4" "github.com/steveiliop56/ding" "github.com/tinyauthapp/tinyauth/internal/model" @@ -42,6 +41,10 @@ var ( ErrInvalidClient = errors.New("invalid_client") ) +// This is not spec-compliant, the ID token SHOULD NOT contain user info claims but, +// it has became a "standard" and apps are looking for the claims in the ID tokens +// instead of calling the userinfo endpoint, so we include them in the ID token as well +// for better compatibility with existing apps type ClaimSet struct { Iss string `json:"iss"` Aud string `json:"aud"` @@ -67,6 +70,8 @@ type ClaimSet struct { Nonce string `json:"nonce,omitempty"` } +// We use this struct as both a response struct and a struct to store userinfo +// in the database type UserinfoResponse struct { Sub string `json:"sub"` Name string `json:"name,omitempty"` @@ -111,6 +116,16 @@ type AuthorizeRequest struct { CodeChallengeMethod string `json:"code_challenge_method"` } +type AuthorizeCodeEntry struct { + CodeHash string + Scope string + RedirectURI string + ClientID string + Nonce string + CodeChallenge string + Userinfo UserinfoResponse +} + type OIDCService struct { log *logger.Logger config model.Config @@ -121,6 +136,10 @@ type OIDCService struct { privateKey *rsa.PrivateKey publicKey *rsa.PublicKey issuer string + + caches struct { + code *CacheStore[AuthorizeCodeEntry] + } } func NewOIDCService( @@ -282,7 +301,26 @@ func NewOIDCService( } // Start cleanup routine - dg.Go(service.cleanupRoutine, ding.RingMinor) + // dg.Go(service.cleanupRoutine, ding.RingMinor) + + // Create caches + codeCash := NewCacheStore[AuthorizeCodeEntry](256) + service.caches.code = codeCash + + // Start cache cleanup routine + dg.Go(func(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + service.caches.code.Sweep() + case <-ctx.Done(): + return + } + } + }, ding.RingMinor) return service, nil } @@ -345,19 +383,17 @@ func (service *OIDCService) filterScopes(scopes []string) []string { }) } -func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error { - // Fixed 10 minutes - expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix() +func (service *OIDCService) CreateCode(req AuthorizeRequest, userContext model.UserContext) string { + code := utils.GenerateString(32) + sub := service.CreateSub(userContext, req.ClientID) - entry := repository.CreateOidcCodeParams{ - Sub: sub, - CodeHash: service.Hash(code), - // Here it's safe to split and trust the output since, we validated the scopes before - Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","), + entry := AuthorizeCodeEntry{ + CodeHash: service.Hash(code), + Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), " "), RedirectURI: req.RedirectURI, ClientID: req.ClientID, - ExpiresAt: expiresAt, Nonce: req.Nonce, + Userinfo: service.userinfoFromContext(userContext, sub), } if req.CodeChallenge != "" { @@ -369,14 +405,14 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r } } - // Insert the code into the database - _, err := service.queries.CreateOidcCode(c, entry) + // Store the code in the cache + service.caches.code.Set(entry.CodeHash, entry, 10*time.Minute) - return err + return code } -func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error { - userInfoParams := repository.CreateOidcUserInfoParams{ +func (service *OIDCService) userinfoFromContext(userContext model.UserContext, sub string) UserinfoResponse { + userInfo := UserinfoResponse{ Sub: sub, Name: userContext.GetName(), Email: userContext.GetEmail(), @@ -385,37 +421,31 @@ func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContex } if userContext.IsLocal() { - addressJSON, err := json.Marshal(userContext.Local.Attributes.Address) - if err != nil { - return err - } - userInfoParams.GivenName = userContext.Local.Attributes.GivenName - userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName - userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName - userInfoParams.Nickname = userContext.Local.Attributes.Nickname - userInfoParams.Profile = userContext.Local.Attributes.Profile - userInfoParams.Picture = userContext.Local.Attributes.Picture - userInfoParams.Website = userContext.Local.Attributes.Website - userInfoParams.Gender = userContext.Local.Attributes.Gender - userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate - userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo - userInfoParams.Locale = userContext.Local.Attributes.Locale - userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber - userInfoParams.Address = string(addressJSON) + userInfo.GivenName = userContext.Local.Attributes.GivenName + userInfo.FamilyName = userContext.Local.Attributes.FamilyName + userInfo.MiddleName = userContext.Local.Attributes.MiddleName + userInfo.Nickname = userContext.Local.Attributes.Nickname + userInfo.Profile = userContext.Local.Attributes.Profile + userInfo.Picture = userContext.Local.Attributes.Picture + userInfo.Website = userContext.Local.Attributes.Website + userInfo.Gender = userContext.Local.Attributes.Gender + userInfo.Birthdate = userContext.Local.Attributes.Birthdate + userInfo.Zoneinfo = userContext.Local.Attributes.Zoneinfo + userInfo.Locale = userContext.Local.Attributes.Locale + userInfo.PhoneNumber = userContext.Local.Attributes.PhoneNumber + userInfo.Address = &userContext.Local.Attributes.Address } // Tinyauth will pass through the groups it got from an LDAP or an OIDC server if userContext.IsLDAP() { - userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",") + userInfo.Groups = userContext.LDAP.Groups } if userContext.IsOAuth() { - userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",") + userInfo.Groups = userContext.OAuth.Groups } - _, err := service.queries.CreateOidcUserInfo(c, userInfoParams) - - return err + return userInfo } func (service *OIDCService) ValidateGrantType(grantType string) error { @@ -426,36 +456,24 @@ func (service *OIDCService) ValidateGrantType(grantType string) error { return nil } -func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, clientId string) (repository.OidcCode, error) { - oidcCode, err := service.queries.GetOidcCode(c, codeHash) +func (service *OIDCService) GetCodeEntry(codeHash string, clientId string) (*AuthorizeCodeEntry, bool) { + entry, ok := service.caches.code.Get(codeHash) - if err != nil { - if errors.Is(err, repository.ErrNotFound) { - return repository.OidcCode{}, ErrCodeNotFound - } - return repository.OidcCode{}, err + if !ok { + return nil, false } - if time.Now().Unix() > oidcCode.ExpiresAt { - err = service.queries.DeleteOidcCode(c, codeHash) - if err != nil { - return repository.OidcCode{}, err - } - err = service.DeleteUserinfo(c, oidcCode.Sub) - if err != nil { - return repository.OidcCode{}, err - } - return repository.OidcCode{}, ErrCodeExpired + if entry.ClientID != clientId { + return nil, false } - if oidcCode.ClientID != clientId { - return repository.OidcCode{}, ErrInvalidClient - } + // Since the code can only be used once, we delete it from the cache after retrieving it + service.caches.code.Delete(codeHash) - return oidcCode, nil + return &entry, true } -func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { +func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user UserinfoResponse, scope string, nonce string) (string, error) { createdAt := time.Now().Unix() expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() @@ -521,17 +539,11 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user return token, nil } -func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) { - user, err := service.GetUserinfo(c, codeEntry.Sub) +func (service *OIDCService) GenerateAccessToken(ctx context.Context, client model.OIDCClientConfig, codeEntry AuthorizeCodeEntry) (*TokenResponse, error) { + idToken, err := service.generateIDToken(client, codeEntry.Userinfo, codeEntry.Scope, codeEntry.Nonce) if err != nil { - return TokenResponse{}, err - } - - idToken, err := service.generateIDToken(client, user, codeEntry.Scope, codeEntry.Nonce) - - if err != nil { - return TokenResponse{}, err + return nil, err } accessToken := utils.GenerateString(32) @@ -551,56 +563,68 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "), } - _, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{ - Sub: codeEntry.Sub, + var userInfoJson []byte + + userInfoJson, err = json.Marshal(codeEntry.Userinfo) + + if err != nil { + return nil, err + } + + _, err = service.queries.CreateOIDCSession(ctx, repository.CreateOIDCSessionParams{ + Sub: codeEntry.Userinfo.Sub, AccessTokenHash: service.Hash(accessToken), RefreshTokenHash: service.Hash(refreshToken), - ClientID: client.ClientID, Scope: codeEntry.Scope, + ClientID: client.ClientID, TokenExpiresAt: tokenExpiresAt, RefreshTokenExpiresAt: refreshTokenExpiresAt, Nonce: codeEntry.Nonce, - CodeHash: codeEntry.CodeHash, + UserinfoJson: string(userInfoJson), }) if err != nil { - return TokenResponse{}, err + return nil, err } - return tokenResponse, nil + return &tokenResponse, nil } -func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken string, reqClientId string) (TokenResponse, error) { - entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken)) +func (service *OIDCService) RefreshAccessToken(ctx context.Context, refreshToken string, clientId string) (*TokenResponse, error) { + entry, err := service.queries.GetOIDCSessionByRefreshTokenHash(ctx, service.Hash(refreshToken)) if err != nil { if errors.Is(err, repository.ErrNotFound) { - return TokenResponse{}, ErrTokenNotFound + return nil, ErrTokenNotFound } - return TokenResponse{}, err + return nil, err } if entry.RefreshTokenExpiresAt < time.Now().Unix() { - return TokenResponse{}, ErrTokenExpired + return nil, ErrTokenExpired } // Ensure the client ID in the request matches the client ID in the token - if entry.ClientID != reqClientId { - return TokenResponse{}, ErrInvalidClient + if entry.ClientID != clientId { + return nil, ErrInvalidClient } - user, err := service.GetUserinfo(c, entry.Sub) + // we need to unmarshal the userinfo from the database to include it in the new ID token, + // since the ID token includes user claims for better compatibility with existing apps + var userInfo UserinfoResponse + + err = json.Unmarshal([]byte(entry.UserinfoJson), &userInfo) if err != nil { - return TokenResponse{}, err + return nil, err } idToken, err := service.generateIDToken(model.OIDCClientConfig{ ClientID: entry.ClientID, - }, user, entry.Scope, entry.Nonce) + }, userInfo, entry.Scope, entry.Nonce) if err != nil { - return TokenResponse{}, err + return nil, err } accessToken := utils.GenerateString(32) @@ -618,71 +642,54 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri Scope: strings.ReplaceAll(entry.Scope, ",", " "), } - _, err = service.queries.UpdateOidcTokenByRefreshToken(c, repository.UpdateOidcTokenByRefreshTokenParams{ + _, err = service.queries.UpdateOIDCSession(ctx, repository.UpdateOIDCSessionParams{ + Sub: entry.Sub, AccessTokenHash: service.Hash(accessToken), RefreshTokenHash: service.Hash(newRefreshToken), + Scope: entry.Scope, + ClientID: entry.ClientID, TokenExpiresAt: tokenExpiresAt, RefreshTokenExpiresAt: refreshTokenExpiresAt, - RefreshTokenHash_2: service.Hash(refreshToken), // that's the selector, it's not stored in the db + Nonce: entry.Nonce, + UserinfoJson: entry.UserinfoJson, }) if err != nil { - return TokenResponse{}, err + return nil, err } - return tokenResponse, nil + return &tokenResponse, nil } -func (service *OIDCService) DeleteCodeEntry(c *gin.Context, codeHash string) error { - return service.queries.DeleteOidcCode(c, codeHash) -} - -func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error { - return service.queries.DeleteOidcUserInfo(c, sub) -} - -func (service *OIDCService) DeleteToken(c *gin.Context, tokenHash string) error { - return service.queries.DeleteOidcToken(c, tokenHash) -} - -func (service *OIDCService) DeleteTokenByCodeHash(c *gin.Context, codeHash string) error { - return service.queries.DeleteOidcTokenByCodeHash(c, codeHash) -} - -func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (repository.OidcToken, error) { - entry, err := service.queries.GetOidcToken(c, tokenHash) +func (service *OIDCService) GetSessionByToken(ctx context.Context, tokenHash string) (*repository.OidcSession, error) { + entry, err := service.queries.GetOIDCSessionByAccessTokenHash(ctx, tokenHash) if err != nil { if errors.Is(err, repository.ErrNotFound) { - return repository.OidcToken{}, ErrTokenNotFound + return nil, ErrTokenNotFound } - return repository.OidcToken{}, err + return nil, err } if entry.TokenExpiresAt < time.Now().Unix() { - // If refresh token is expired, delete the token and userinfo since there is no way for the client to access anything anymore + // If refresh token is expired, delete the session + // since there is no way for the client to access anything anymore if entry.RefreshTokenExpiresAt < time.Now().Unix() { - err := service.DeleteToken(c, tokenHash) + // Deletes by sub + err := service.queries.DeleteSession(ctx, entry.Sub) if err != nil { - return repository.OidcToken{}, err - } - err = service.DeleteUserinfo(c, entry.Sub) - if err != nil { - return repository.OidcToken{}, err + return nil, err } + return nil, ErrTokenExpired } - return repository.OidcToken{}, ErrTokenExpired + return nil, ErrTokenExpired } - return entry, nil + return &entry, nil } -func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) { - return service.queries.GetOidcUserInfo(c, sub) -} - -func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse { - scopes := strings.Split(scope, ",") // split by comma since it's a db entry +func (service *OIDCService) CompileUserinfo(user UserinfoResponse, scope string) UserinfoResponse { + scopes := strings.Split(scope, " ") userInfo := UserinfoResponse{ Sub: user.Sub, UpdatedAt: user.UpdatedAt, @@ -710,11 +717,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope } if slices.Contains(scopes, "groups") { - if user.Groups != "" { - userInfo.Groups = strings.Split(user.Groups, ",") - } else { - userInfo.Groups = []string{} - } + userInfo.Groups = user.Groups } if slices.Contains(scopes, "phone") { @@ -724,10 +727,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope } if slices.Contains(scopes, "address") { - var addr model.AddressClaim - if err := json.Unmarshal([]byte(user.Address), &addr); err == nil { - userInfo.Address = &addr - } + userInfo.Address = user.Address } return userInfo @@ -740,83 +740,75 @@ func (service *OIDCService) Hash(token string) string { } func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { - err := service.queries.DeleteOidcCodeBySub(ctx, sub) - if err != nil && !errors.Is(err, repository.ErrNotFound) { - return err - } - err = service.queries.DeleteOidcTokenBySub(ctx, sub) - if err != nil && !errors.Is(err, repository.ErrNotFound) { - return err - } - err = service.queries.DeleteOidcUserInfo(ctx, sub) + err := service.queries.DeleteOIDCSessionBySub(ctx, sub) if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } return nil } -// Cleanup routine - Resource heavy due to the linked tables -func (service *OIDCService) cleanupRoutine(ctx context.Context) { - service.log.App.Debug().Msg("Starting OIDC cleanup routine") - ticker := time.NewTicker(time.Duration(30) * time.Minute) - defer ticker.Stop() +// // Cleanup routine - Resource heavy due to the linked tables +// func (service *OIDCService) cleanupRoutine(ctx context.Context) { +// service.log.App.Debug().Msg("Starting OIDC cleanup routine") +// ticker := time.NewTicker(time.Duration(30) * time.Minute) +// defer ticker.Stop() - for { - select { - case <-ticker.C: - service.log.App.Debug().Msg("Performing OIDC cleanup routine") +// for { +// select { +// case <-ticker.C: +// service.log.App.Debug().Msg("Performing OIDC cleanup routine") - currentTime := time.Now().Unix() +// currentTime := time.Now().Unix() - // For the OIDC tokens, if they are expired we delete the userinfo and codes - expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ - TokenExpiresAt: currentTime, - RefreshTokenExpiresAt: currentTime, - }) +// // For the OIDC tokens, if they are expired we delete the userinfo and codes +// expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ +// TokenExpiresAt: currentTime, +// RefreshTokenExpiresAt: currentTime, +// }) - if err != nil { - service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens") - } +// if err != nil { +// service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens") +// } - for _, expiredToken := range expiredTokens { - err := service.DeleteOldSession(ctx, expiredToken.Sub) - if err != nil { - service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") - } - } +// for _, expiredToken := range expiredTokens { +// err := service.DeleteOldSession(ctx, expiredToken.Sub) +// if err != nil { +// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") +// } +// } - // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything - expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) +// // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything +// expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) - if err != nil { - service.log.App.Warn().Err(err).Msg("Failed to delete expired codes") - } +// if err != nil { +// service.log.App.Warn().Err(err).Msg("Failed to delete expired codes") +// } - for _, expiredCode := range expiredCodes { - token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) +// for _, expiredCode := range expiredCodes { +// token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) - if err != nil { - if !errors.Is(err, repository.ErrNotFound) { - service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code") - } - continue - } +// if err != nil { +// if !errors.Is(err, repository.ErrNotFound) { +// service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code") +// } +// continue +// } - if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { - err := service.DeleteOldSession(ctx, expiredCode.Sub) - if err != nil { - service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code") - } - } - } +// if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { +// err := service.DeleteOldSession(ctx, expiredCode.Sub) +// if err != nil { +// service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code") +// } +// } +// } - service.log.App.Debug().Msg("Finished OIDC cleanup routine") - case <-ctx.Done(): - service.log.App.Debug().Msg("Stopping OIDC cleanup routine") - return - } - } -} +// service.log.App.Debug().Msg("Finished OIDC cleanup routine") +// case <-ctx.Done(): +// service.log.App.Debug().Msg("Stopping OIDC cleanup routine") +// return +// } +// } +// } func (service *OIDCService) GetJWK() ([]byte, error) { hasher := sha256.New() @@ -851,3 +843,10 @@ func (service *OIDCService) hashAndEncodePKCE(codeVerifier string) string { hasher.Write([]byte(codeVerifier)) return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil)) } + +// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. +// We will just create a uuid out of the username and client name which remains stable, +// but if username or client name changes then sub changes too. +func (service *OIDCService) CreateSub(userContext model.UserContext, clientId string) string { + return utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), clientId)) +} diff --git a/sql/postgres/oidc_queries.sql b/sql/postgres/oidc_queries.sql index 8109d5cc..3cd5ff99 100644 --- a/sql/postgres/oidc_queries.sql +++ b/sql/postgres/oidc_queries.sql @@ -1,46 +1,17 @@ --- name: CreateOidcCode :one -INSERT INTO "oidc_codes" ( - "sub", - "code_hash", - "scope", - "redirect_uri", - "client_id", - "expires_at", - "nonce", - "code_challenge" -) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8 -) -RETURNING *; - --- name: GetOidcCodeUnsafe :one -SELECT * FROM "oidc_codes" -WHERE "code_hash" = $1; - --- name: GetOidcCode :one -DELETE FROM "oidc_codes" -WHERE "code_hash" = $1 -RETURNING *; - --- name: GetOidcCodeBySubUnsafe :one -SELECT * FROM "oidc_codes" +-- name: GetOIDCSessionBySub :one +SELECT * FROM "oidc_sessions" WHERE "sub" = $1; --- name: GetOidcCodeBySub :one -DELETE FROM "oidc_codes" -WHERE "sub" = $1 -RETURNING *; +-- name: GetOIDCSessionByAccessTokenHash :one +SELECT * FROM "oidc_sessions" +WHERE "access_token_hash" = $1; --- name: DeleteOidcCode :exec -DELETE FROM "oidc_codes" -WHERE "code_hash" = $1; +-- name: GetOIDCSessionByRefreshTokenHash :one +SELECT * FROM "oidc_sessions" +WHERE "refresh_token_hash" = $1; --- name: DeleteOidcCodeBySub :exec -DELETE FROM "oidc_codes" -WHERE "sub" = $1; - --- name: CreateOidcToken :one -INSERT INTO "oidc_tokens" ( +-- name: CreateOIDCSession :one +INSERT INTO "oidc_sessions" ( "sub", "access_token_hash", "refresh_token_hash", @@ -48,86 +19,30 @@ INSERT INTO "oidc_tokens" ( "client_id", "token_expires_at", "refresh_token_expires_at", - "code_hash", - "nonce" + "nonce", + "userinfo_json" ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9 ) RETURNING *; --- name: UpdateOidcTokenByRefreshToken :one -UPDATE "oidc_tokens" SET - "access_token_hash" = $1, - "refresh_token_hash" = $2, - "token_expires_at" = $3, - "refresh_token_expires_at" = $4 -WHERE "refresh_token_hash" = $5 -RETURNING *; - --- name: GetOidcToken :one -SELECT * FROM "oidc_tokens" -WHERE "access_token_hash" = $1; - --- name: GetOidcTokenByRefreshToken :one -SELECT * FROM "oidc_tokens" -WHERE "refresh_token_hash" = $1; - --- name: GetOidcTokenBySub :one -SELECT * FROM "oidc_tokens" +-- name: DeleteOIDCSessionBySub :exec +DELETE FROM "oidc_sessions" WHERE "sub" = $1; --- name: DeleteOidcTokenByCodeHash :exec -DELETE FROM "oidc_tokens" -WHERE "code_hash" = $1; +-- name: DeleteExpiredOIDCSessions :exec +DELETE FROM "oidc_sessions" +WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2; --- name: DeleteOidcToken :exec -DELETE FROM "oidc_tokens" -WHERE "access_token_hash" = $1; - --- name: DeleteOidcTokenBySub :exec -DELETE FROM "oidc_tokens" -WHERE "sub" = $1; - --- name: CreateOidcUserInfo :one -INSERT INTO "oidc_userinfo" ( - "sub", - "name", - "preferred_username", - "email", - "groups", - "updated_at", - "given_name", - "family_name", - "middle_name", - "nickname", - "profile", - "picture", - "website", - "gender", - "birthdate", - "zoneinfo", - "locale", - "phone_number", - "address" -) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19 -) -RETURNING *; - --- name: GetOidcUserInfo :one -SELECT * FROM "oidc_userinfo" -WHERE "sub" = $1; - --- name: DeleteOidcUserInfo :exec -DELETE FROM "oidc_userinfo" -WHERE "sub" = $1; - --- name: DeleteExpiredOidcCodes :many -DELETE FROM "oidc_codes" -WHERE "expires_at" < $1 -RETURNING *; - --- name: DeleteExpiredOidcTokens :many -DELETE FROM "oidc_tokens" -WHERE "token_expires_at" < $1 AND "refresh_token_expires_at" < $2 +-- name: UpdateOIDCSession :one +UPDATE "oidc_sessions" SET + "access_token_hash" = $1, + "refresh_token_hash" = $2, + "scope" = $3, + "client_id" = $4, + "token_expires_at" = $5, + "refresh_token_expires_at" = $6, + "nonce" = $7, + "userinfo_json" = $8 +WHERE "sub" = $9 RETURNING *; diff --git a/sql/postgres/oidc_schemas.sql b/sql/postgres/oidc_schemas.sql index 96fac7fc..2376c1d4 100644 --- a/sql/postgres/oidc_schemas.sql +++ b/sql/postgres/oidc_schemas.sql @@ -1,44 +1,11 @@ -CREATE TABLE IF NOT EXISTS "oidc_codes" ( - "sub" TEXT NOT NULL UNIQUE, - "code_hash" TEXT NOT NULL PRIMARY KEY, - "scope" TEXT NOT NULL, - "redirect_uri" TEXT NOT NULL, - "client_id" TEXT NOT NULL, - "expires_at" BIGINT NOT NULL, - "nonce" TEXT NOT NULL DEFAULT '', - "code_challenge" TEXT NOT NULL DEFAULT '' -); - -CREATE TABLE IF NOT EXISTS "oidc_tokens" ( - "sub" TEXT NOT NULL UNIQUE, - "access_token_hash" TEXT NOT NULL PRIMARY KEY, - "refresh_token_hash" TEXT NOT NULL, - "code_hash" TEXT NOT NULL, - "scope" TEXT NOT NULL, - "client_id" TEXT NOT NULL, - "token_expires_at" BIGINT NOT NULL, +CREATE TABLE IF NOT EXISTS "oidc_sessions" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "access_token_hash" TEXT NOT NULL UNIQUE, + "refresh_token_hash" TEXT NOT NULL UNIQUE, + "scope" TEXT NOT NULL, + "client_id" TEXT NOT NULL, + "token_expires_at" BIGINT NOT NULL, "refresh_token_expires_at" BIGINT NOT NULL, - "nonce" TEXT NOT NULL DEFAULT '' -); - -CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( - "sub" TEXT NOT NULL PRIMARY KEY, - "name" TEXT NOT NULL, - "preferred_username" TEXT NOT NULL, - "email" TEXT NOT NULL, - "groups" TEXT NOT NULL, - "updated_at" BIGINT NOT NULL, - "given_name" TEXT NOT NULL, - "family_name" TEXT NOT NULL, - "middle_name" TEXT NOT NULL, - "nickname" TEXT NOT NULL, - "profile" TEXT NOT NULL, - "picture" TEXT NOT NULL, - "website" TEXT NOT NULL, - "gender" TEXT NOT NULL, - "birthdate" TEXT NOT NULL, - "zoneinfo" TEXT NOT NULL, - "locale" TEXT NOT NULL, - "phone_number" TEXT NOT NULL, - "address" TEXT NOT NULL + "nonce" TEXT NOT NULL DEFAULT '', + "userinfo_json" TEXT NOT NULL ); diff --git a/sql/sqlite/oidc_queries.sql b/sql/sqlite/oidc_queries.sql index 67b7b95e..49b33cff 100644 --- a/sql/sqlite/oidc_queries.sql +++ b/sql/sqlite/oidc_queries.sql @@ -1,46 +1,17 @@ --- name: CreateOidcCode :one -INSERT INTO "oidc_codes" ( - "sub", - "code_hash", - "scope", - "redirect_uri", - "client_id", - "expires_at", - "nonce", - "code_challenge" -) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ? -) -RETURNING *; - --- name: GetOidcCodeUnsafe :one -SELECT * FROM "oidc_codes" -WHERE "code_hash" = ?; - --- name: GetOidcCode :one -DELETE FROM "oidc_codes" -WHERE "code_hash" = ? -RETURNING *; - --- name: GetOidcCodeBySubUnsafe :one -SELECT * FROM "oidc_codes" +-- name: GetOIDCSessionBySub :one +SELECT * FROM "oidc_sessions" WHERE "sub" = ?; --- name: GetOidcCodeBySub :one -DELETE FROM "oidc_codes" -WHERE "sub" = ? -RETURNING *; +-- name: GetOIDCSessionByAccessTokenHash :one +SELECT * FROM "oidc_sessions" +WHERE "access_token_hash" = ?; --- name: DeleteOidcCode :exec -DELETE FROM "oidc_codes" -WHERE "code_hash" = ?; +-- name: GetOIDCSessionByRefreshTokenHash :one +SELECT * FROM "oidc_sessions" +WHERE "refresh_token_hash" = ?; --- name: DeleteOidcCodeBySub :exec -DELETE FROM "oidc_codes" -WHERE "sub" = ?; - --- name: CreateOidcToken :one -INSERT INTO "oidc_tokens" ( +-- name: CreateOIDCSession :one +INSERT INTO "oidc_sessions" ( "sub", "access_token_hash", "refresh_token_hash", @@ -48,86 +19,30 @@ INSERT INTO "oidc_tokens" ( "client_id", "token_expires_at", "refresh_token_expires_at", - "code_hash", - "nonce" + "nonce", + "userinfo_json" ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ? ) RETURNING *; --- name: UpdateOidcTokenByRefreshToken :one -UPDATE "oidc_tokens" SET +-- name: DeleteOIDCSessionBySub :exec +DELETE FROM "oidc_sessions" +WHERE "sub" = ?; + +-- name: DeleteExpiredOIDCSessions :exec +DELETE FROM "oidc_sessions" +WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ?; + +-- name: UpdateOIDCSession :one +UPDATE "oidc_sessions" SET "access_token_hash" = ?, "refresh_token_hash" = ?, + "scope" = ?, + "client_id" = ?, "token_expires_at" = ?, - "refresh_token_expires_at" = ? -WHERE "refresh_token_hash" = ? -RETURNING *; - --- name: GetOidcToken :one -SELECT * FROM "oidc_tokens" -WHERE "access_token_hash" = ?; - --- name: GetOidcTokenByRefreshToken :one -SELECT * FROM "oidc_tokens" -WHERE "refresh_token_hash" = ?; - --- name: GetOidcTokenBySub :one -SELECT * FROM "oidc_tokens" -WHERE "sub" = ?; - --- name: DeleteOidcTokenByCodeHash :exec -DELETE FROM "oidc_tokens" -WHERE "code_hash" = ?; - --- name: DeleteOidcToken :exec -DELETE FROM "oidc_tokens" -WHERE "access_token_hash" = ?; - --- name: DeleteOidcTokenBySub :exec -DELETE FROM "oidc_tokens" -WHERE "sub" = ?; - --- name: CreateOidcUserInfo :one -INSERT INTO "oidc_userinfo" ( - "sub", - "name", - "preferred_username", - "email", - "groups", - "updated_at", - "given_name", - "family_name", - "middle_name", - "nickname", - "profile", - "picture", - "website", - "gender", - "birthdate", - "zoneinfo", - "locale", - "phone_number", - "address" -) VALUES ( - ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? -) -RETURNING *; - --- name: GetOidcUserInfo :one -SELECT * FROM "oidc_userinfo" -WHERE "sub" = ?; - --- name: DeleteOidcUserInfo :exec -DELETE FROM "oidc_userinfo" -WHERE "sub" = ?; - --- name: DeleteExpiredOidcCodes :many -DELETE FROM "oidc_codes" -WHERE "expires_at" < ? -RETURNING *; - --- name: DeleteExpiredOidcTokens :many -DELETE FROM "oidc_tokens" -WHERE "token_expires_at" < ? AND "refresh_token_expires_at" < ? + "refresh_token_expires_at" = ?, + "nonce" = ?, + "userinfo_json" = ? +WHERE "sub" = ? RETURNING *; diff --git a/sql/sqlite/oidc_schemas.sql b/sql/sqlite/oidc_schemas.sql index d9a7ba4e..ce55a717 100644 --- a/sql/sqlite/oidc_schemas.sql +++ b/sql/sqlite/oidc_schemas.sql @@ -1,44 +1,11 @@ -CREATE TABLE IF NOT EXISTS "oidc_codes" ( - "sub" TEXT NOT NULL UNIQUE, - "code_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, - "scope" TEXT NOT NULL, - "redirect_uri" TEXT NOT NULL, - "client_id" TEXT NOT NULL, - "expires_at" INTEGER NOT NULL, - "nonce" TEXT DEFAULT "", - "code_challenge" TEXT DEFAULT "" -); - -CREATE TABLE IF NOT EXISTS "oidc_tokens" ( - "sub" TEXT NOT NULL UNIQUE, - "access_token_hash" TEXT NOT NULL PRIMARY KEY UNIQUE, - "refresh_token_hash" TEXT NOT NULL, - "code_hash" TEXT NOT NULL, +CREATE TABLE IF NOT EXISTS "oidc_sessions" ( + "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, + "access_token_hash" TEXT NOT NULL UNIQUE, + "refresh_token_hash" TEXT NOT NULL UNIQUE, "scope" TEXT NOT NULL, "client_id" TEXT NOT NULL, "token_expires_at" INTEGER NOT NULL, "refresh_token_expires_at" INTEGER NOT NULL, - "nonce" TEXT DEFAULT "" -); - -CREATE TABLE IF NOT EXISTS "oidc_userinfo" ( - "sub" TEXT NOT NULL UNIQUE PRIMARY KEY, - "name" TEXT NOT NULL, - "preferred_username" TEXT NOT NULL, - "email" TEXT NOT NULL, - "groups" TEXT NOT NULL, - "updated_at" INTEGER NOT NULL, - "given_name" TEXT NOT NULL, - "family_name" TEXT NOT NULL, - "middle_name" TEXT NOT NULL, - "nickname" TEXT NOT NULL, - "profile" TEXT NOT NULL, - "picture" TEXT NOT NULL, - "website" TEXT NOT NULL, - "gender" TEXT NOT NULL, - "birthdate" TEXT NOT NULL, - "zoneinfo" TEXT NOT NULL, - "locale" TEXT NOT NULL, - "phone_number" TEXT NOT NULL, - "address" TEXT NOT NULL + "nonce" TEXT DEFAULT "", + "userinfo_json" TEXT NOT NULL ); diff --git a/sqlc.yml b/sqlc.yml index a6fbab5c..e4f98a25 100644 --- a/sqlc.yml +++ b/sqlc.yml @@ -22,11 +22,7 @@ sql: go_type: "string" - column: "sessions.ldap_groups" go_type: "string" - - column: "oidc_codes.nonce" - go_type: "string" - - column: "oidc_tokens.nonce" - go_type: "string" - - column: "oidc_codes.code_challenge" + - column: "oidc_sessions.nonce" go_type: "string" - engine: "postgresql" queries: "sql/postgres/*_queries.sql"